1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
|
local api = vim.api
local tsq = require'vim.treesitter.query'
local tsrange = require'nvim-treesitter.tsrange'
local utils = require'nvim-treesitter.utils'
local parsers = require'nvim-treesitter.parsers'
local caching = require'nvim-treesitter.caching'
local M = {
  -- Query groups that get a generated `M.has_<group>(lang)` helper below.
  built_in_query_groups = { 'highlights', 'locals', 'folds', 'indents' },
}

-- Iterator that yields nothing; handed out whenever no matches can be produced.
local function EMPTY_ITER() end

-- Returns a predicate `fn(lang)` that reports whether files for the query
-- group `query` exist for `lang` (delegates to M.has_query_files at call time).
local function get_query_guard(query)
  return function(lang)
    return M.has_query_files(lang, query)
  end
end

-- Generate M.has_highlights, M.has_locals, M.has_folds and M.has_indents.
for index = 1, #M.built_in_query_groups do
  local group = M.built_in_query_groups[index]
  M['has_' .. group] = get_query_guard(group)
end
--- Lists the names of every query group found on the runtimepath.
-- A group name is the basename (without `.scm`) of any `queries/*/*.scm` file;
-- each name appears once. Order is unspecified (built from a `pairs` traversal).
-- @return list of query group names
function M.available_query_groups()
  local seen = {}
  for _, path in ipairs(api.nvim_get_runtime_file('queries/*/*.scm', true)) do
    local group_name = vim.fn.fnamemodify(path, ':t:r')
    seen[group_name] = true
  end

  local result = {}
  for group_name in pairs(seen) do
    result[#result + 1] = group_name
  end
  return result
end
do
  -- Per-buffer cache of collected query results, keyed by query group.
  local query_cache = caching.create_buffer_cache()

  -- Re-collects the results of `query_group` for `bufnr` and stores them
  -- together with the changedtick they were computed at.
  local function update_cached_matches(bufnr, changed_tick, query_group)
    local entry = {
      tick = changed_tick,
      cache = M.collect_group_results(bufnr, query_group) or {},
    }
    query_cache.set(query_group, bufnr, entry)
  end

  --- Returns the (cached) results of `query_group` for `bufnr`.
  -- Results are recomputed when nothing is cached yet, or when the buffer's
  -- changedtick is greater than the tick stored with the cached entry.
  -- @param bufnr buffer handle; defaults to the current buffer
  -- @param query_group query file name (e.g. "locals")
  -- @return list of prepared matches
  function M.get_matches(bufnr, query_group)
    bufnr = bufnr or api.nvim_get_current_buf()
    local cached = query_cache.get(query_group, bufnr)
    local is_stale = not cached
      or api.nvim_buf_get_changedtick(bufnr) > cached.tick
    if is_stale then
      update_cached_matches(bufnr, api.nvim_buf_get_changedtick(bufnr), query_group)
    end
    return query_cache.get(query_group, bufnr).cache
  end
end
-- Returns every `queries/<lang>/<query_name>.scm` file on the runtimepath
-- (an empty list when there are none).
local function runtime_queries(lang, query_name)
  local pattern = string.format('queries/%s/%s.scm', lang, query_name)
  return api.nvim_get_runtime_file(pattern, true) or {}
end
--- Reports whether at least one query file for `query_name` exists for `lang`.
-- @param lang the language name
-- @param query_name the query group (e.g. "highlights")
-- @return boolean
function M.has_query_files(lang, query_name)
  local files = runtime_queries(lang, query_name)
  return files and next(files) ~= nil
end
do
  -- Lazily creates the per-language sub-table on first access so that
  -- `cache[lang][query_name]` never indexes nil. (__index only fires when the
  -- raw value is absent.)
  local mt = {}
  mt.__index = function(tbl, key)
    local per_lang = {}
    rawset(tbl, key, per_lang)
    return per_lang
  end

  -- cache[lang][query_name] = { query = <result of tsq.get_query> }.
  -- Results are wrapped so that a nil result (no query files for that
  -- lang/group) is cached as well instead of being re-resolved from the
  -- runtimepath on every call.
  local cache = setmetatable({}, mt)

  --- Same as `vim.treesitter.query.get_query` except will return cached values.
  -- @param lang the language name
  -- @param query_name the query group (e.g. "highlights")
  -- @return the parsed query, or nil when tsq.get_query returned nil
  function M.get_query(lang, query_name)
    local entry = cache[lang][query_name]
    if entry == nil then
      entry = { query = tsq.get_query(lang, query_name) }
      cache[lang][query_name] = entry
    end
    return entry.query
  end

  --- Invalidates the query file cache.
  -- If lang and query_name are both present, only that entry is dropped.
  -- If only lang is present, all query groups for that lang are dropped.
  -- If neither is present, the whole cache is dropped.
  -- Passing query_name without lang raises an error.
  -- @param lang optional language name
  -- @param query_name optional query group name
  function M.invalidate_query_cache(lang, query_name)
    if lang and query_name then
      cache[lang][query_name] = nil
    elseif lang then
      cache[lang] = nil
    elseif not query_name then
      -- Clearing existing fields while traversing with pairs is allowed.
      for cached_lang in pairs(cache) do
        cache[cached_lang] = nil
      end
    else
      error("Cannot have query_name by itself!")
    end
  end
end
--- Drops the cached query that corresponds to a query file path.
-- Intended to be called from an autocommand; only pass paths of query files.
-- The language is taken from the file's parent directory name and the query
-- group from the file's basename without extension.
-- @param fname path to a `.scm` query file
function M.invalidate_query_file(fname)
  local lang = vim.fn.fnamemodify(fname, ':p:h:t')
  local query_name = vim.fn.fnamemodify(fname, ':t:r')
  M.invalidate_query_cache(lang, query_name)
end
-- Splits a dotted capture path (e.g. "definition.var") into its segments.
local function split_capture_path(path_string)
  local segments = {}
  for segment in string.gmatch(path_string, "([^.]+)") do
    segments[#segments + 1] = segment
  end
  return segments
end

-- Stores `value` in `object` at the nested location described by `path`
-- (a list of keys), creating intermediate tables as needed.
-- An empty path is ignored instead of indexing with a nil key.
local function insert_to_path(object, path, value)
  local depth = #path
  if depth == 0 then
    return
  end
  local curr_obj = object
  for index = 1, depth - 1 do
    local key = path[index]
    if curr_obj[key] == nil then
      curr_obj[key] = {}
    end
    curr_obj = curr_obj[key]
  end
  curr_obj[path[depth]] = value
end

--- Iterates over the matches of `query` under `qnode`, turning each match into
-- a nested table keyed by capture-name path.
-- A capture named "definition.var" is stored at `prepared.definition.var.node`.
-- `set!` directives with a string key store their value at that key path, and
-- `make-range!` directives (key plus two capture ids) store a TSRange built
-- from the two captured nodes at `<key>.node`.
-- @param query parsed query (its captures, info.patterns and iter_matches are used)
-- @param qnode node to run the query on
-- @param bufnr the buffer
-- @param start_row first row to search
-- @param end_row row bound passed through to query:iter_matches
-- @return iterator yielding one prepared match table per query match
function M.iter_prepared_matches(query, qnode, bufnr, start_row, end_row)
  local matches = query:iter_matches(qnode, bufnr, start_row, end_row)

  local function iterator()
    local pattern, match = matches()
    if pattern == nil then
      return
    end

    local prepared_match = {}

    -- Extract capture names from each match.
    for id, node in pairs(match) do
      local name = query.captures[id] -- name of the capture in the query
      if name ~= nil then
        insert_to_path(prepared_match, split_capture_path(name .. '.node'), node)
      end
    end

    -- Apply the supported directives attached to this pattern.
    local preds = query.info.patterns[pattern]
    if preds then
      for _, pred in pairs(preds) do
        if pred[1] == "set!" and type(pred[2]) == "string" then
          insert_to_path(prepared_match, split_capture_path(pred[2]), pred[3])
        end
        if pred[1] == "make-range!" and type(pred[2]) == "string" and #pred == 4 then
          insert_to_path(prepared_match, split_capture_path(pred[2] .. '.node'),
            tsrange.TSRange.from_nodes(bufnr, match[pred[3]], match[pred[4]]))
        end
      end
    end

    return prepared_match
  end

  return iterator
end
--- Return all nodes corresponding to a specific capture path (like @definition.var, @reference.type).
-- Works like M.get_references or M.get_scopes except you can choose the capture.
-- Can also be a nested capture like @definition.function to get all nodes defining a function.
-- @param bufnr the buffer
-- @param capture_string capture path, with or without a leading "@"
-- @param query_group the query file to use (e.g. "locals")
-- @param root optional root node to restrict the search to
-- @param lang optional language of `root`
-- @return list of the values found at the capture path (possibly empty)
function M.get_capture_matches(bufnr, capture_string, query_group, root, lang)
  -- Accept the capture path with or without a leading "@". (A check written as
  -- `not s == '@'` parses as `(not s) == '@'`, which is always false and would
  -- strip the first character of every capture unconditionally.)
  if string.sub(capture_string, 1, 1) == '@' then
    capture_string = string.sub(capture_string, 2)
  end

  local matches = {}
  for match in M.iter_group_results(bufnr, query_group, root, lang) do
    local insert = utils.get_at_path(match, capture_string)
    if insert then
      matches[#matches + 1] = insert
    end
  end
  return matches
end
--- Finds the highest-scoring capture among the results of a query group.
-- @param bufnr the buffer
-- @param capture_string capture path, with or without a leading "@"
-- @param query_group the query file to use
-- @param filter_predicate function(match); only matches for which it returns truthy are scored
-- @param scoring_function function(match) returning a comparable score
-- @param root optional root node
-- @return the first match reaching the strictly highest score, or nil
function M.find_best_match(bufnr, capture_string, query_group, filter_predicate, scoring_function, root)
  local path = capture_string
  if string.sub(path, 1, 1) == '@' then
    path = string.sub(path, 2)
  end

  local best, best_score
  for candidate in M.iter_group_results(bufnr, query_group, root) do
    local match = utils.get_at_path(candidate, path)
    if match and filter_predicate(match) then
      local score = scoring_function(match)
      if not best then
        best, best_score = match, score
      end
      if score > best_score then
        best, best_score = match, score
      end
    end
  end
  return best
end
--- Iterates over the prepared matches of a query group.
-- @param bufnr the buffer
-- @param query_group the query file to use (e.g. "highlights", "locals")
-- @param root the node to search under; defaults to the root of the parser's first tree
-- @param root_lang the language of `root`, if known; otherwise resolved with
--   parser:language_for_range over root's range
-- @return an iterator of prepared matches (an empty iterator when the buffer
--   language, parser, root, language or query is unavailable)
function M.iter_group_results(bufnr, query_group, root, root_lang)
  local buf_lang = parsers.get_buf_lang(bufnr)
  local parser = buf_lang and parsers.get_parser(bufnr, buf_lang)
  if not parser then
    return EMPTY_ITER
  end

  if not root then
    local first_tree = parser:trees()[1]
    root = first_tree and first_tree:root()
  end
  if not root then
    return EMPTY_ITER
  end

  local range = { root:range() }
  if not root_lang then
    local lang_tree = parser:language_for_range(range)
    root_lang = lang_tree and lang_tree:lang()
  end
  if not root_lang then
    return EMPTY_ITER
  end

  local query = M.get_query(root_lang, query_group)
  if not query then
    return EMPTY_ITER
  end

  -- range[1] and range[3] are the start and end rows; the end row bound is
  -- exclusive, so add 1 to include root's last row.
  local start_row, end_row = range[1], range[3] + 1
  return M.iter_prepared_matches(query, root, bufnr, start_row, end_row)
end
--- Collects every prepared match of a query group into a list.
-- Arguments are forwarded unchanged to M.iter_group_results.
-- @return list of prepared matches (empty when there are none)
function M.collect_group_results(bufnr, query_group, root, lang)
  local results = {}
  local count = 0
  for prepared in M.iter_group_results(bufnr, query_group, root, lang) do
    count = count + 1
    results[count] = prepared
  end
  return results
end
--- Same as get_capture_matches except this gets matches for every language tree
-- visited by the buffer parser's for_each_tree.
-- @param bufnr the buffer
-- @param capture_or_fn the capture to get. If a function is provided, it is
--   called as fn(lang, tree, lang_tree) and must return the capture and the
--   query group to use; returning nil for the capture skips that tree.
-- @param query_type the query group to get the capture from. Ignored when a
--   function is provided for capture_or_fn.
-- @return list of all matches collected across the visited trees
function M.get_capture_matches_recursively(bufnr, capture_or_fn, query_type)
  local resolve
  if type(capture_or_fn) == 'function' then
    resolve = capture_or_fn
  else
    resolve = function()
      return capture_or_fn, query_type
    end
  end

  local parser = parsers.get_parser(bufnr)
  local matches = {}
  if not parser then
    return matches
  end

  parser:for_each_tree(function(tree, lang_tree)
    local lang = lang_tree:lang()
    local capture, query_group = resolve(lang, tree, lang_tree)
    if capture then
      vim.list_extend(matches, M.get_capture_matches(bufnr, capture, query_group, tree:root(), lang))
    end
  end)
  return matches
end

return M
|