summaryrefslogtreecommitdiff
path: root/lua/nvim-treesitter/query.lua
blob: 034df223db82ad39ff36a87168c9b9634503dbb9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
local api = vim.api
local ts = vim.treesitter
local utils = require'nvim-treesitter.utils'
local parsers = require'nvim-treesitter.parsers'

local M = {}

-- Metatable that turns a table into a "default dict": reading a missing key
-- creates, stores and returns a fresh empty table for that key.
local default_dict = {
  __index = function(tbl, key)
    -- __index only runs for keys that are absent from `tbl`, so there is no
    -- need to check for an existing value before creating one.
    local value = {}
    rawset(tbl, key, value)
    return value
  end,
}

-- Cache of collected matches, indexed first by query group name and then by
-- buffer number (written by update_cached_matches, read by M.get_matches).
local query_cache = setmetatable({}, default_dict)

-- Some tree-sitter grammars extend others. For such a language the queries of
-- its base language(s) are loaded as well, in the order listed here.
M.base_language_map = {
  cpp = { 'c' },
  typescript = { 'javascript' },
  tsx = { 'typescript', 'javascript' },
}

-- Extra query files loaded for a language. An entry is either a bare query
-- file name looked up in the language's own query directory ('jsx'), or a
-- qualified 'lang.name' pair pointing at another language's directory
-- ('javascript.jsx').
M.query_extensions = {
  javascript = { 'jsx' },
  tsx = { 'javascript.jsx' },
}

-- Query groups for which an M.has_<group> helper is generated.
M.built_in_query_groups = { 'highlights', 'locals', 'textobjects' }

--- Build a predicate that reports whether `query` files exist for a language.
-- @param query name of the query group, e.g. 'highlights'
-- @return function(lang) returning true when query files are found
local function get_query_guard(query)
  local function guard(lang)
    return M.has_query_files(lang, query)
  end
  return guard
end

-- Define M.has_highlights, M.has_locals, ... for every built-in query group.
for i = 1, #M.built_in_query_groups do
  local group = M.built_in_query_groups[i]
  M['has_' .. group] = get_query_guard(group)
end

--- Read the given files and join all of their lines into one query source.
-- @param filenames list of file paths, read in order
-- @return the concatenated contents, lines separated by "\n" ("" for no files)
local function read_query_files(filenames)
  local lines = {}
  for _, path in ipairs(filenames) do
    local file_lines = vim.fn.readfile(path)
    vim.list_extend(lines, file_lines)
  end
  return table.concat(lines, '\n')
end

--- Recompute and store the matches of `query_group` for buffer `bufnr`.
-- @param bufnr buffer number
-- @param changed_tick buffer changedtick the stored results correspond to
-- @param query_group name of the query group
local function update_cached_matches(bufnr, changed_tick, query_group)
  local results = M.collect_group_results(bufnr, query_group) or {}
  local per_buffer = query_cache[query_group]
  per_buffer[bufnr] = { tick = changed_tick, cache = results }
end

--- Return the collected matches of `query_group` for a buffer.
-- Results are cached per buffer and only recomputed when the buffer's
-- changedtick has advanced past the cached one.
-- @param bufnr buffer number; nil or 0 mean the current buffer
-- @param query_group name of the query group, e.g. 'locals'
-- @return list of prepared matches (as built by M.collect_group_results)
function M.get_matches(bufnr, query_group)
  -- Resolve 0 to a real buffer number: caching under the key 0 would mix up
  -- the results of different buffers when the current buffer changes.
  if bufnr == nil or bufnr == 0 then
    bufnr = api.nvim_get_current_buf()
  end

  local tick = api.nvim_buf_get_changedtick(bufnr)
  local cached = query_cache[query_group][bufnr]

  if not cached or tick > cached.tick then
    update_cached_matches(bufnr, tick, query_group)
  end

  return query_cache[query_group][bufnr].cache
end

--- Choose which of the runtime query files found for one query to load.
-- Keeps the first file that is not inside an "after" directory (the earliest
-- entry wins), followed by every file that is inside an "after" directory,
-- in their original order.
-- @param file_list list of query file paths
-- @return new list of the selected paths
local function filter_files(file_list)
  local primary
  local after_files = {}

  for _, path in ipairs(file_list) do
    -- Name of the directory that contains the "queries" directory.
    local is_after = vim.fn.fnamemodify(path, ":p:h:h:h:t") == "after"
    if is_after then
      after_files[#after_files + 1] = path
    elseif primary == nil then
      primary = path
    end
  end

  local result = {}
  if primary ~= nil then
    result[1] = primary
  end
  vim.list_extend(result, after_files)
  return result
end

--- Find the runtime files for `queries/<lang>/<query_name>.scm` and keep only
-- the ones selected by filter_files.
-- @param lang language directory name
-- @param query_name query file name without extension
-- @return list of file paths (possibly empty)
local function filtered_runtime_queries(lang, query_name)
  local pattern = string.format('queries/%s/%s.scm', lang, query_name)
  local found = api.nvim_get_runtime_file(pattern, true) or {}
  return filter_files(found)
end

--- Collect all query files to load for `query_name` in language `lang`.
-- Files are returned in load order:
--   1. extension files listed in M.query_extensions[lang],
--   2. the `query_name` files of each base language in M.base_language_map[lang],
--   3. the `query_name` files of `lang` itself.
-- @param lang parser language name
-- @param query_name query group name, e.g. 'highlights'
-- @return list of file paths (possibly empty)
function M.get_query_files(lang, query_name)
  local query_files = {}

  for _, ext in ipairs(M.query_extensions[lang] or {}) do
    -- An extension is either a bare query file name looked up in `lang`'s own
    -- directory, or 'other_lang.name'. Both halves are split at the same
    -- (last) '.' so they never overlap.
    local ext_lang, ext_name = ext:match('^(.*)%.(.*)$')
    if ext_lang == nil then
      ext_lang, ext_name = lang, ext
    end
    -- NOTE(review): extension files are added for every `query_name`, not
    -- only the query group they were written for — confirm this is intended.
    vim.list_extend(query_files, filtered_runtime_queries(ext_lang, ext_name))
  end

  for _, base_lang in ipairs(M.base_language_map[lang] or {}) do
    vim.list_extend(query_files, filtered_runtime_queries(base_lang, query_name))
  end

  return vim.list_extend(query_files, filtered_runtime_queries(lang, query_name))
end

--- Report whether at least one query file exists for `query_name` in `lang`
-- (including extension and base-language files, see M.get_query_files).
-- @return boolean
function M.has_query_files(lang, query_name)
  return #M.get_query_files(lang, query_name) > 0
end

--- Load and parse the combined query `query_name` for `lang`.
-- All files from M.get_query_files are concatenated into a single source.
-- @param lang parser language name
-- @param query_name query group name, e.g. 'highlights'
-- @return the parsed query, or nothing when no query text was found
function M.get_query(lang, query_name)
  local source = read_query_files(M.get_query_files(lang, query_name))
  if #source == 0 then
    return
  end
  return ts.parse_query(lang, source)
end

-- Split a dotted capture name such as 'definition.var' into {'definition', 'var'}.
local function split_capture_name(name)
  local parts = {}
  for part in name:gmatch('[^.]+') do
    parts[#parts + 1] = part
  end
  return parts
end

-- Store `value` in `object` at the nested key `path` ({'a', 'b'} -> object.a.b),
-- creating intermediate tables as needed. An empty path stores nothing.
local function insert_to_path(object, path, value)
  local depth = #path
  if depth == 0 then
    return
  end

  local current = object
  for i = 1, depth - 1 do
    local key = path[i]
    if current[key] == nil then
      current[key] = {}
    end
    current = current[key]
  end

  current[path[depth]] = value
end

--- Iterate over the matches of `query`, returning each one as a nested table.
-- Every capture `@a.b` of a match is stored as `prepared.a.b = { node = node }`,
-- and every `set!` predicate `(set! a.b value)` of the matched pattern is
-- stored as `prepared.a.b = value`.
-- @param query parsed query (as returned by M.get_query)
-- @param qnode node to run the query on
-- @param bufnr buffer the node belongs to
-- @param start_row first row to match
-- @param end_row row at which matching stops
-- @return iterator function; each call returns the next prepared match table,
--   or nothing once the matches are exhausted
function M.iter_prepared_matches(query, qnode, bufnr, start_row, end_row)
  local matches = query:iter_matches(qnode, bufnr, start_row, end_row)

  return function()
    local pattern, match = matches()
    if pattern == nil then
      return
    end

    local prepared_match = {}

    -- Extract capture names from each match
    for id, node in pairs(match) do
      local name = query.captures[id] -- name of the capture in the query
      if name ~= nil then
        insert_to_path(prepared_match, split_capture_name(name), { node = node })
      end
    end

    -- Apply `set!` predicates of the matched pattern
    local preds = query.info.patterns[pattern]
    if preds then
      for _, pred in pairs(preds) do
        if pred[1] == "set!" and type(pred[2]) == "string" then
          insert_to_path(prepared_match, split_capture_name(pred[2]), pred[3])
        end
      end
    end

    return prepared_match
  end
end

--- Return all nodes corresponding to a specific capture path (like @definition.var, @reference.type)
-- Works like M.get_references or M.get_scopes except you can choose the capture
-- Can also be a nested capture like @definition.function to get all nodes defining a function
-- @param bufnr buffer number
-- @param capture_string capture path; must start with '@'
-- @param query_group query group to run, e.g. 'locals'
-- @return list of the values found at the capture path in each match, or
--   nothing (after reporting an error) if capture_string lacks the leading '@'
function M.get_capture_matches(bufnr, capture_string, query_group)
  -- Compare the first character explicitly: `not a == b` would parse as
  -- `(not a) == b` and never be true.
  if capture_string:sub(1, 1) ~= '@' then
    api.nvim_err_writeln('capture_string must start with "@"')
    return
  end

  -- remove leading "@"
  local path = capture_string:sub(2)

  local matches = {}
  for match in M.iter_group_results(bufnr, query_group) do
    local insert = utils.get_at_path(match, path)
    if insert then
      table.insert(matches, insert)
    end
  end
  return matches
end

--- Find the highest-scoring value at a capture path among a query group's matches.
-- @param bufnr buffer number
-- @param capture_string capture path; must start with '@'
-- @param query_group query group to run, e.g. 'textobjects'
-- @param filter_predicate function(match) -> boolean; only accepted values are scored
-- @param scoring_function function(match) -> comparable score
-- @return the accepted value with the greatest score (the first one wins on
--   ties), or nil if there is none; nothing after reporting an error when
--   capture_string lacks the leading '@'
function M.find_best_match(bufnr, capture_string, query_group, filter_predicate, scoring_function)
  -- Compare the first character explicitly: `not a == b` would parse as
  -- `(not a) == b` and never be true.
  if capture_string:sub(1, 1) ~= '@' then
    api.nvim_err_writeln('capture_string must start with "@"')
    return
  end

  -- remove leading "@"
  local path = capture_string:sub(2)

  local best, best_score

  for maybe_match in M.iter_group_results(bufnr, query_group) do
    local match = utils.get_at_path(maybe_match, path)

    if match and filter_predicate(match) then
      local score = scoring_function(match)
      if best == nil or score > best_score then
        best, best_score = match, score
      end
    end
  end

  return best
end

-- Iterator that yields nothing. It is returned when there is nothing to match,
-- so callers can always use the result directly in a generic `for`.
local function no_matches() end

--- Iterates matches from a query file.
-- @param bufnr the buffer
-- @param query_group the query file to use
-- @param root optional root node to match within; defaults to the root of the
--   buffer's freshly parsed tree
-- @return iterator over prepared matches (see M.iter_prepared_matches); it
--   yields nothing when the buffer has no language, query or parser
function M.iter_group_results(bufnr, query_group, root)
  local lang = parsers.get_buf_lang(bufnr)
  if not lang then return no_matches end

  local query = M.get_query(lang, query_group)
  if not query then return no_matches end

  local parser = parsers.get_parser(bufnr, lang)
  if not parser then return no_matches end

  root = root or parser:parse():root()
  local start_row, _, end_row, _ = root:range()

  -- The end row is exclusive so we need to add 1 to it.
  return M.iter_prepared_matches(query, root, bufnr, start_row, end_row + 1)
end

--- Gather every prepared match of a query group into a list.
-- @param bufnr the buffer
-- @param query_group the query file to use
-- @param root optional root node, forwarded to M.iter_group_results
-- @return list of prepared matches, in iteration order
function M.collect_group_results(bufnr, query_group, root)
  local results = {}
  local count = 0
  for prepared_match in M.iter_group_results(bufnr, query_group, root) do
    count = count + 1
    results[count] = prepared_match
  end
  return results
end

return M