diff options
Diffstat (limited to 'lua/nvim-treesitter/query.lua')
| -rw-r--r-- | lua/nvim-treesitter/query.lua | 85 |
1 files changed, 69 insertions, 16 deletions
diff --git a/lua/nvim-treesitter/query.lua b/lua/nvim-treesitter/query.lua index df3e7080..e6683139 100644 --- a/lua/nvim-treesitter/query.lua +++ b/lua/nvim-treesitter/query.lua @@ -7,6 +7,8 @@ local caching = require'nvim-treesitter.caching' local M = {} +local EMPTY_ITER = function() end + M.built_in_query_groups = {'highlights', 'locals', 'folds', 'indents'} -- Creates a function that checks whether a given query exists @@ -166,7 +168,7 @@ end --- Return all nodes corresponding to a specific capture path (like @definition.var, @reference.type) -- Works like M.get_references or M.get_scopes except you can choose the capture -- Can also be a nested capture like @definition.function to get all nodes defining a function -function M.get_capture_matches(bufnr, capture_string, query_group) +function M.get_capture_matches(bufnr, capture_string, query_group, root, lang) if not string.sub(capture_string, 1,2) == '@' then print('capture_string must start with "@"') return @@ -176,7 +178,7 @@ function M.get_capture_matches(bufnr, capture_string, query_group) capture_string = string.sub(capture_string, 2) local matches = {} - for match in M.iter_group_results(bufnr, query_group) do + for match in M.iter_group_results(bufnr, query_group, root, lang) do local insert = utils.get_at_path(match, capture_string) if insert then @@ -186,7 +188,7 @@ function M.get_capture_matches(bufnr, capture_string, query_group) return matches end -function M.find_best_match(bufnr, capture_string, query_group, filter_predicate, scoring_function) +function M.find_best_match(bufnr, capture_string, query_group, filter_predicate, scoring_function, root) if not string.sub(capture_string, 1,2) == '@' then api.nvim_err_writeln('capture_string must start with "@"') return @@ -198,7 +200,7 @@ function M.find_best_match(bufnr, capture_string, query_group, filter_predicate, local best local best_score - for maybe_match in M.iter_group_results(bufnr, query_group) do + for maybe_match in M.iter_group_results(bufnr, query_group, root) do local match = utils.get_at_path(maybe_match, capture_string) if match and filter_predicate(match) then @@ -220,31 +222,82 @@ end -- @param bufnr the buffer -- @param query_group the query file to use -- @param root the root node -function M.iter_group_results(bufnr, query_group, root) - local lang = parsers.get_buf_lang(bufnr) - if not lang then return function() end end +-- @param root the root node lang, if known +function M.iter_group_results(bufnr, query_group, root, root_lang) + local buf_lang = parsers.get_buf_lang(bufnr) + + if not buf_lang then return EMPTY_ITER end + + local parser = parsers.get_parser(bufnr, buf_lang) + if not parser then return EMPTY_ITER end - local query = M.get_query(lang, query_group) - if not query then return function() end end + if not root then + local first_tree = parser:trees()[1] + + if first_tree then + root = first_tree:root() + end + end + + if not root then return EMPTY_ITER end + + local range = {root:range()} + + if not root_lang then + local lang_tree = parser:language_for_range(range) + + if lang_tree then + root_lang = lang_tree:lang() + end + end - local parser = parsers.get_parser(bufnr, lang) - if not parser then return function() end end + if not root_lang then return EMPTY_ITER end - local root = root or parser:parse()[1]:root() - local start_row, _, end_row, _ = root:range() + local query = M.get_query(root_lang, query_group) + if not query then return EMPTY_ITER end -- The end row is exclusive so we need to add 1 to it. - return M.iter_prepared_matches(query, root, bufnr, start_row, end_row + 1) + return M.iter_prepared_matches(query, root, bufnr, range[1], range[3] + 1) end -function M.collect_group_results(bufnr, query_group, root) +function M.collect_group_results(bufnr, query_group, root, lang) local matches = {} - for prepared_match in M.iter_group_results(bufnr, query_group, root) do + for prepared_match in M.iter_group_results(bufnr, query_group, root, lang) do table.insert(matches, prepared_match) end return matches end +--- Same as get_capture_matches except this will recursively get matches for every language in the tree. +-- @param bufnr The bufnr +-- @param capture_or_fn The capture to get. If a function is provided then that +-- function will be used to resolve both the capture and query argument. +-- The function can return `nil` to ignore that tree. +-- @param query_type The query to get the capture from. This is ignore if a function is provided +-- for the captuer argument. +function M.get_capture_matches_recursively(bufnr, capture_or_fn, query_type) + local type_fn = type(capture_or_fn) == 'function' + and capture_or_fn + or function() + return capture_or_fn, query_type + end + local parser = parsers.get_parser(bufnr) + local matches = {} + + if parser then + parser:for_each_tree(function(tree, lang_tree) + local lang = lang_tree:lang() + local capture, type_ = type_fn(lang, tree, lang_tree) + + if capture then + vim.list_extend(matches, M.get_capture_matches(bufnr, capture, type_, tree:root(), lang)) + end + end) + end + + return matches +end + return M |
