diff options
Diffstat (limited to 'lua/nvim-treesitter/query.lua')
| -rw-r--r-- | lua/nvim-treesitter/query.lua | 113 |
1 files changed, 97 insertions, 16 deletions
diff --git a/lua/nvim-treesitter/query.lua b/lua/nvim-treesitter/query.lua index 69e52e1e..896b11da 100644 --- a/lua/nvim-treesitter/query.lua +++ b/lua/nvim-treesitter/query.lua @@ -1,25 +1,21 @@ local api = vim.api local ts = vim.treesitter +local utils = require'nvim-treesitter.utils' +local parsers = require'nvim-treesitter.parsers' local M = {} -local function read_query_files(filenames) - local contents = {} - - for _,filename in ipairs(filenames) do - vim.list_extend(contents, vim.fn.readfile(filename)) +local default_dict = { + __index = function(table, key) + local exists = rawget(table, key) + if not exists then + table[key] = {} + end + return rawget(table, key) end +} - return table.concat(contents, '\n') -end - --- Creates a function that checks whether a certain query exists --- for a specific language. -local function get_query_guard(query) - return function(lang) - return M.get_query(lang, query) ~= nil - end -end +local query_cache = setmetatable({}, default_dict) -- Some treesitter grammars extend others. -- We can use that to import the queries of the base language @@ -36,10 +32,42 @@ M.query_extensions = { M.built_in_query_groups = {'highlights', 'locals', 'textobjects'} +-- Creates a function that checks whether a certain query exists +-- for a specific language. +local function get_query_guard(query) + return function(lang) + return M.get_query(lang, query) ~= nil + end +end + for _, query in ipairs(M.built_in_query_groups) do M["has_" .. query] = get_query_guard(query) end +local function read_query_files(filenames) + local contents = {} + + for _,filename in ipairs(filenames) do + vim.list_extend(contents, vim.fn.readfile(filename)) + end + + return table.concat(contents, '\n') +end + +local function update_cached_matches(bufnr, changed_tick, query_group) + query_cache[query_group][bufnr] = {tick=changed_tick, cache=( M.collect_group_results(bufnr, query_group) or {} )} +end + +function M.get_matches(bufnr, query_group) + local bufnr = bufnr or api.nvim_get_current_buf() + local cached_local = query_cache[query_group][bufnr] + if not cached_local or api.nvim_buf_get_changedtick(bufnr) > cached_local.tick then + update_cached_matches(bufnr,api.nvim_buf_get_changedtick(bufnr), query_group) + end + + return query_cache[query_group][bufnr].cache +end + function M.get_query(lang, query_name) local query_files = api.nvim_get_runtime_file(string.format('queries/%s/%s.scm', lang, query_name), true) local query_string = '' @@ -84,7 +112,6 @@ function M.iter_prepared_matches(query, qnode, bufnr, start_row, end_row) return t end - -- Given a path (i.e. a List(String)) this functions inserts value at path local function insert_to_path(object, path, value) local curr_obj = object @@ -131,4 +158,58 @@ function M.iter_prepared_matches(query, qnode, bufnr, start_row, end_row) end end +--- Return all nodes corresponding to a specific capture path (like @definition.var, @reference.type) +-- Works like M.get_references or M.get_scopes except you can choose the capture +-- Can also be a nested capture like @definition.function to get all nodes defining a function +function M.get_capture_matches(bufnr, capture_string, query_group) + if not string.sub(capture_string, 1,2) == '@' then + print('capture_string must start with "@"') + return + end + + --remove leading "@" + capture_string = string.sub(capture_string, 2) + + local matches = {} + for match in M.iter_group_results(bufnr, query_group) do + local insert = utils.get_at_path(match, capture_string) + + if insert then + table.insert(matches, insert) + end + end + return matches +end + +-- Iterates matches from a query file. +-- @param bufnr the buffer +-- @param query_group the query file to use +-- @param root the root node +function M.iter_group_results(bufnr, query_group, root) + + local lang = parsers.get_buf_lang(bufnr) + if not lang then return end + + local query = M.get_query(lang, query_group) + if not query then return end + + local parser = parsers.get_parser(bufnr, lang) + if not parser then return end + + local root = root or parser:parse():root() + local start_row, _, end_row, _ = root:range() + + return M.iter_prepared_matches(query, root, bufnr, start_row, end_row) +end + +function M.collect_group_results(bufnr, query_group, root) + local matches = {} + + for prepared_match in M.iter_group_results(bufnr, query_group, root) do + table.insert(matches, prepared_match) + end + + return matches +end + return M |
