diff options
Diffstat (limited to 'lua/nvim-treesitter/ts_utils.lua')
| -rw-r--r-- | lua/nvim-treesitter/ts_utils.lua | 63 |
1 files changed, 41 insertions, 22 deletions
diff --git a/lua/nvim-treesitter/ts_utils.lua b/lua/nvim-treesitter/ts_utils.lua index 9720be7b..06f92c88 100644 --- a/lua/nvim-treesitter/ts_utils.lua +++ b/lua/nvim-treesitter/ts_utils.lua @@ -223,29 +223,33 @@ function M.find_definition(node, bufnr) local node_text = M.get_node_text(node)[1] local current_scope = M.containing_scope(node) local _, _, node_start = node:start() + local matching_def_nodes = {} -- If a scope wasn't found then use the root node if current_scope == node then current_scope = parsers.get_parser(bufnr).tree:root() end - - while current_scope ~= nil and current_scope ~= node do - for _, def in ipairs(locals.collect_locals(bufnr, current_scope)) do - if def.definition then - for _, def_node in ipairs(M.get_local_nodes(def.definition)) do - local _, _, def_start = def_node:start() - - if M.get_node_text(def_node)[1] == node_text and def_start < node_start then - return def_node, current_scope - end - end + + -- Get all definitions that match the node text + for _, def in ipairs(locals.get_definitions(bufnr)) do + for _, def_node in ipairs(M.get_local_nodes(def)) do + if M.get_node_text(def_node)[1] == node_text then + table.insert(matching_def_nodes, def_node) end end + end + -- Continue up each scope until we find the scope that contains the definition + while current_scope do + for _, def_node in ipairs(matching_def_nodes) do + if M.is_parent(current_scope, def_node) then + return def_node, current_scope + end + end current_scope = M.containing_scope(current_scope:parent()) end - return nil, nil + return node, parsers.get_parser(bufnr).tree:root() end -- Gets all nodes from a local list result. @@ -285,26 +289,41 @@ function M.recurse_local_nodes(local_def, accumulator, full_match, last_match) end end --- Finds usages of a node in a particula scope +-- Finds usages of a node in a given scope -- @param node the node to find usages for -- @param scope_node the node to look within -- @returns a list of nodes -function M.find_usages(node, scope_node) - local usages = {} +function M.find_usages(node, scope_node, bufnr) + local bufnr = bufnr or api.nvim_get_current_buf() local node_text = M.get_node_text(node)[1] if not node_text or #node_text < 1 then return {} end - for _, def in ipairs(locals.collect_locals(bufnr, scope_node)) do - if def.reference - and def.reference.node - and M.get_node_text(def.reference.node)[1] == node_text then - - table.insert(usages, def.reference.node) + local scope_node = scope_node or parsers.get_parser(bufnr).tree:root() + local references = locals.get_references(bufnr) + local usages = {} + + M.recurse_tree(scope_node, function(iter_node, _, next) + if vim.tbl_contains(references, iter_node) and M.get_node_text(iter_node)[1] == node_text then + table.insert(usages, iter_node) end - end + next() + end) return usages end +-- Recurses all child nodes of a tree. +-- The callback is provided the child node, parent_node, and a callback to recurse into +-- the child node. This allows for the ability to short circuit the recursion +-- if we found what we are looking for, we can then stop the recursion or skip a node +-- if need be. +-- @param tree the node root +-- @param cb the callback for each node +function M.recurse_tree(tree, cb) + for _, child in ipairs(M.get_named_children(tree)) do + cb(child, tree, function(next_node) M.recurse_tree(next_node or child, cb) end) + end +end + return M |
