1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
|
local api = vim.api
local ts = vim.treesitter
local utils = require'nvim-treesitter.utils'
local parsers = require'nvim-treesitter.parsers'
local predicates = require'nvim-treesitter.query_predicates'
local M = {}

-- Metatable that materializes a fresh empty table for any missing key, so
-- nested writes such as `query_cache[group][bufnr] = ...` need no setup.
local default_dict = {
  __index = function(dict, key)
    local fresh = {}
    rawset(dict, key, fresh)
    return fresh
  end,
}

-- Per-query-group cache, indexed as query_cache[query_group][bufnr].
local query_cache = setmetatable({}, default_dict)

-- Some grammars build on top of other grammars; the queries of the listed
-- base languages are loaded together with the language's own queries.
M.base_language_map = {
  cpp = { 'c' },
  typescript = { 'javascript' },
  tsx = { 'typescript', 'javascript' },
}

-- Extra query files merged into a language's queries. An entry written as
-- 'lang.name' names a query of another language directory; a bare 'name'
-- names a query in the language's own directory.
M.query_extensions = {
  javascript = { 'jsx' },
  tsx = { 'javascript.jsx' },
}

-- Query groups for which an `M.has_<group>(lang)` predicate is generated.
M.built_in_query_groups = { 'highlights', 'locals', 'textobjects' }

-- Build a predicate that reports whether a `query` query exists for a language.
local function get_query_guard(query)
  return function(lang)
    local found = M.get_query(lang, query)
    return found ~= nil
  end
end

for i = 1, #M.built_in_query_groups do
  local group = M.built_in_query_groups[i]
  M['has_' .. group] = get_query_guard(group)
end
--- Read several query files and join all of their lines with newlines.
-- @param filenames list of file paths
-- @return the combined text ('' when no lines were read)
local function read_query_files(filenames)
  local lines = {}
  for i = 1, #filenames do
    for _, line in ipairs(vim.fn.readfile(filenames[i])) do
      lines[#lines + 1] = line
    end
  end
  return table.concat(lines, '\n')
end
--- Recompute the results of `query_group` for a buffer and store them.
-- @param bufnr buffer handle used as the cache key
-- @param changed_tick buffer changedtick the fresh results correspond to
-- @param query_group name of the query file whose results are cached
local function update_cached_matches(bufnr, changed_tick, query_group)
  local group_cache = query_cache[query_group]
  local results = M.collect_group_results(bufnr, query_group)
  group_cache[bufnr] = {
    tick = changed_tick,
    -- never store nil, so readers can always index `.cache`
    cache = results or {},
  }
end
--- Return the cached results of `query_group` for a buffer, recomputing them
-- whenever the buffer's changedtick differs from the cached one.
-- @param bufnr buffer handle; nil or 0 means the current buffer
-- @param query_group name of the query file (e.g. 'locals')
-- @return list of prepared matches (an empty table when nothing was collected)
function M.get_matches(bufnr, query_group)
  -- Resolve 0 to the real handle: caching under the pseudo-key 0 would let one
  -- buffer's results be served for another after the current buffer changes.
  if bufnr == nil or bufnr == 0 then
    bufnr = api.nvim_get_current_buf()
  end

  local tick = api.nvim_buf_get_changedtick(bufnr)
  local cached = query_cache[query_group][bufnr]
  if cached == nil or cached.tick ~= tick then
    update_cached_matches(bufnr, tick, query_group)
  end
  return query_cache[query_group][bufnr].cache
end
--- List every runtime file matching queries/<lang>/<query_name>.scm.
local function find_query_files(lang, query_name)
  local pattern = string.format('queries/%s/%s.scm', lang, query_name)
  return api.nvim_get_runtime_file(pattern, true) or {}
end

--- Resolve a query-extension spec to the (lang, query_name) pair to load.
-- 'javascript.jsx' -> ('javascript', 'jsx'): the language is everything before
-- the first dot and the query name is the rest. A spec without a dot names a
-- query of `default_lang`.
local function resolve_extension(default_lang, ext)
  local ext_lang, ext_query = ext:match('^([^.]+)%.(.+)$')
  if ext_lang then
    return ext_lang, ext_query
  end
  return default_lang, ext
end

--- Build the parsed query `query_name` for `lang`.
-- The text combines the language's own query files, the query files of its
-- base languages (M.base_language_map) and any extension queries
-- (M.query_extensions).
-- @param lang parser language name
-- @param query_name query group name (e.g. 'highlights')
-- @return the parsed query, or nil when no matching query file exists
function M.get_query(lang, query_name)
  local query_string = ''

  -- Each source is prepended, so sources added later appear earlier in the
  -- final text: extensions, then base languages, then the language's own files.
  local function prepend(files)
    if #files > 0 then
      query_string = read_query_files(files) .. '\n' .. query_string
    end
  end

  prepend(find_query_files(lang, query_name))
  for _, base_lang in ipairs(M.base_language_map[lang] or {}) do
    prepend(find_query_files(base_lang, query_name))
  end
  for _, ext in ipairs(M.query_extensions[lang] or {}) do
    prepend(find_query_files(resolve_extension(lang, ext)))
  end

  if #query_string > 0 then
    return ts.parse_query(lang, query_string)
  end
end
--- Split a dotted capture name such as "definition.var" into its parts.
-- Empty segments are dropped, so "a..b" yields { "a", "b" }.
local function split_capture_path(name)
  local parts = {}
  for part in name:gmatch('[^.]+') do
    parts[#parts + 1] = part
  end
  return parts
end

--- Store `value` in `object` along `path`, creating intermediate tables as
-- needed: path { "a", "b" } sets object.a.b = value.
-- An empty path is ignored instead of raising "table index is nil".
local function insert_at_path(object, path, value)
  local count = #path
  if count == 0 then
    return
  end
  local current = object
  for i = 1, count - 1 do
    local key = path[i]
    if current[key] == nil then
      current[key] = {}
    end
    current = current[key]
  end
  current[path[count]] = value
end

--- Turn one raw match into a nested capture table and apply the pattern's
-- `set!` directives and predicates.
-- @return the prepared match, or nil when a predicate rejects it
local function prepare_match(query, pattern, match)
  local prepared = {}

  -- Capture "@a.b" on node N becomes prepared.a.b = { node = N }.
  for id, node in pairs(match) do
    local name = query.captures[id]
    if name ~= nil then
      insert_at_path(prepared, split_capture_path(name), { node = node })
    end
  end

  local preds = query.info.patterns[pattern]
  if preds then
    for _, pred in pairs(preds) do
      -- Directive: (#set! path value) stores value at the dotted path.
      if pred[1] == 'set!' and type(pred[2]) == 'string' then
        insert_at_path(prepared, split_capture_path(pred[2]), pred[3])
      end
      -- NOTE(review): `set!` entries also reach the predicate checks below;
      -- this presumes query_predicates accepts names it does not handle - confirm.
      if type(pred[1]) == 'string' then
        if not predicates.check_predicate(query, prepared, pred)
            or not predicates.check_negated_predicate(query, prepared, pred) then
          return nil
        end
      end
    end
  end

  return prepared
end

--- Iterate the matches of `query` under `qnode` as prepared (nested) tables.
-- Matches rejected by a predicate are skipped.
-- @param query a parsed query (from M.get_query)
-- @param qnode node to search under
-- @param bufnr buffer the node belongs to
-- @param start_row first row to search
-- @param end_row last row to search
-- @return an iterator yielding prepared matches
function M.iter_prepared_matches(query, qnode, bufnr, start_row, end_row)
  local matches = query:iter_matches(qnode, bufnr, start_row, end_row)

  return function()
    -- Keep pulling raw matches until one passes its predicates or input ends.
    while true do
      local pattern, match = matches()
      if pattern == nil then
        return
      end
      local prepared = prepare_match(query, pattern, match)
      if prepared ~= nil then
        return prepared
      end
    end
  end
end
--- Return all nodes corresponding to a specific capture path (like @definition.var, @reference.type)
-- Works like M.get_references or M.get_scopes except you can choose the capture
-- Can also be a nested capture like @definition.function to get all nodes defining a function
-- @param bufnr the buffer
-- @param capture_string capture path; must start with "@"
-- @param query_group the query file to search (e.g. 'locals')
-- @return list of the values found at the capture path in each match (empty
--   when the buffer has no usable query), or nil when capture_string is invalid
function M.get_capture_matches(bufnr, capture_string, query_group)
  -- Compare only the first character. Written as `not s == '@'`, the test
  -- would parse as `(not s) == '@'` and never fire.
  if type(capture_string) ~= 'string' or capture_string:sub(1, 1) ~= '@' then
    print('capture_string must start with "@"')
    return
  end

  -- remove leading "@"
  local path = capture_string:sub(2)

  local matches = {}
  -- iter_group_results returns nothing when the buffer has no language,
  -- query or parser; iterating that would raise "attempt to call a nil value".
  local iter = M.iter_group_results(bufnr, query_group)
  if not iter then
    return matches
  end

  for match in iter do
    local found = utils.get_at_path(match, path)
    if found then
      matches[#matches + 1] = found
    end
  end
  return matches
end
--- Iterate the prepared matches of a query file over a buffer's syntax tree.
-- @param bufnr the buffer
-- @param query_group name of the query file to use (e.g. 'locals')
-- @param root optional node to search under; defaults to the root of the
--   tree returned by parsing the buffer
-- @return an iterator of prepared matches, or nothing when the buffer has no
--   language, the language has no such query, or no parser is available
function M.iter_group_results(bufnr, query_group, root)
  local lang = parsers.get_buf_lang(bufnr)
  if lang then
    local query = M.get_query(lang, query_group)
    if query then
      local parser = parsers.get_parser(bufnr, lang)
      if parser then
        local search_root = root or parser:parse():root()
        local first_row, _, last_row = search_root:range()
        return M.iter_prepared_matches(query, search_root, bufnr, first_row, last_row)
      end
    end
  end
end
--- Collect every prepared match of a query file into a list.
-- @param bufnr the buffer
-- @param query_group name of the query file to use
-- @param root optional node to search under (passed to M.iter_group_results)
-- @return list of prepared matches; empty when no iterator could be created
function M.collect_group_results(bufnr, query_group, root)
  local matches = {}

  -- iter_group_results returns nothing when the buffer has no language,
  -- query or parser; iterating that would raise "attempt to call a nil value".
  local iter = M.iter_group_results(bufnr, query_group, root)
  if not iter then
    return matches
  end

  for prepared_match in iter do
    matches[#matches + 1] = prepared_match
  end
  return matches
end
-- Export the module table holding the query helpers.
return M
|