diff options
| author | Stephan Seitz <stephan.seitz@fau.de> | 2020-07-10 22:17:51 +0200 |
|---|---|---|
| committer | Thomas Vigouroux <39092278+vigoux@users.noreply.github.com> | 2020-07-16 09:34:31 +0200 |
| commit | 8cf2dc7f9ad31c7467d28f90aec920018e240b7f (patch) | |
| tree | 3cfa3df9bc6aef2006b2bd6ec0fed530ffece02c /lua/nvim-treesitter/locals.lua | |
| parent | a4e2692c7b9fb562eca39ce0bb10ec2544bc7ccb (diff) | |
Refactor locals.lua:
- shared query group stuff -> query.lua
- local-specific stuff from ts_utils -> locals.lua
Diffstat (limited to 'lua/nvim-treesitter/locals.lua')
| -rw-r--r-- | lua/nvim-treesitter/locals.lua | 246 |
1 files changed, 176 insertions, 70 deletions
diff --git a/lua/nvim-treesitter/locals.lua b/lua/nvim-treesitter/locals.lua index fe99cd81..947187d5 100644 --- a/lua/nvim-treesitter/locals.lua +++ b/lua/nvim-treesitter/locals.lua @@ -1,73 +1,27 @@ -- Functions to handle locals -- Locals are a generalization of definition and scopes -- its the way nvim-treesitter uses to "understand" the code -local api = vim.api local queries = require'nvim-treesitter.query' local parsers = require'nvim-treesitter.parsers' -local utils = require'nvim-treesitter.utils' - -local default_dict = { - __index = function(table, key) - local exists = rawget(table, key) - if not exists then - table[key] = {} - end - return rawget(table, key) - end -} - -local query_cache = {} -setmetatable(query_cache, default_dict) +local ts_utils = require'nvim-treesitter.ts_utils' +local api = vim.api local M = {} -function M.collect_locals(bufnr, query_kind) - local locals = {} - - for prepared_match in M.iter_locals(bufnr, nil, query_kind) do - table.insert(locals, prepared_match) - end - - return locals -end - -local function update_cached_locals(bufnr, changed_tick, query_kind) - query_cache[query_kind][bufnr] = {tick=changed_tick, cache=( M.collect_locals(bufnr, query_kind) or {} )} +function M.collect_locals(bufnr) + return queries.collect_group_results(bufnr, 'locals') end -- Iterates matches from a locals query file. -- @param bufnr the buffer -- @param root the root node --- @param query_kind the query file to use -function M.iter_locals(bufnr, root, query_kind) - query_kind = query_kind or 'locals' - - local lang = parsers.get_buf_lang(bufnr) - if not lang then return end - - local query = queries.get_query(lang, query_kind) - if not query then return end - - local parser = parsers.get_parser(bufnr, lang) - if not parser then return end - - local root = root or parser:parse():root() - local start_row, _, end_row, _ = root:range() - - return queries.iter_prepared_matches(query, root, bufnr, start_row, end_row) +function M.iter_locals(bufnr, root) + return queries.iter_group_results(bufnr, 'locals', root) end -function M.get_locals(bufnr, query_kind) - query_kind = query_kind or 'locals' - - local bufnr = bufnr or api.nvim_get_current_buf() - local cached_local = query_cache[query_kind][bufnr] - if not cached_local or api.nvim_buf_get_changedtick(bufnr) > cached_local.tick then - update_cached_locals(bufnr,api.nvim_buf_get_changedtick(bufnr), query_kind) - end - - return query_cache[query_kind][bufnr].cache +function M.get_locals(bufnr) + return queries.get_matches(bufnr, 'locals') end function M.get_definitions(bufnr) @@ -112,27 +66,179 @@ function M.get_references(bufnr) return refs end ---- Return all nodes in locals corresponding to a specific capture (like @scope, @reference) --- Works like M.get_references or M.get_scopes except you can choose the capture --- Can also be a nested capture like @definition.function to get all nodes defining a function -function M.get_capture_matches(bufnr, capture_string, query_kind) - if not string.sub(capture_string, 1,2) == '@' then - print('capture_string must start with "@"') - return - end +-- Finds the definition node and it's scope node of a node +-- @param node starting node +-- @param bufnr buffer +-- @returns the definition node and the definition nodes scope node +function M.find_definition(node, bufnr) + local bufnr = bufnr or api.nvim_get_current_buf() + local node_text = ts_utils.get_node_text(node)[1] + local current_scope = M.containing_scope(node) + local matching_def_nodes = {} - --remove leading "@" - capture_string = string.sub(capture_string, 2) + -- If a scope wasn't found then use the root node + if current_scope == node then + current_scope = parsers.get_parser(bufnr).tree:root() + end - local matches = {} - for _, match in pairs(M.get_locals(bufnr, query_kind)) do - local insert = utils.get_at_path(match, capture_string) + -- Get all definitions that match the node text + for _, def in ipairs(M.get_definitions(bufnr)) do + for _, def_node in ipairs(M.get_local_nodes(def)) do + if ts_utils.get_node_text(def_node)[1] == node_text then + table.insert(matching_def_nodes, def_node) + end + end + end - if insert then - table.insert(matches, insert) + -- Continue up each scope until we find the scope that contains the definition + while current_scope do + for _, def_node in ipairs(matching_def_nodes) do + if ts_utils.is_parent(current_scope, def_node) then + return def_node, current_scope end end - return matches + current_scope = M.containing_scope(current_scope:parent()) + end + + return node, parsers.get_parser(bufnr).tree:root() +end + +-- Gets all nodes from a local list result. +-- @param local_def the local list result +-- @returns a list of nodes +function M.get_local_nodes(local_def) + local result = {} + + M.recurse_local_nodes(local_def, function(_, node) + table.insert(result, node) + end) + + return result +end + +-- Recurse locals results until a node is found. +-- The accumulator function is given +-- * The table of the node +-- * The node +-- * The full definition match `@definition.var.something` -> 'var.something' +-- * The last definition match `@definition.var.something` -> 'something' +-- @param The locals result +-- @param The accumulator function +-- @param The full match path to append to +-- @param The last match +function M.recurse_local_nodes(local_def, accumulator, full_match, last_match) + if local_def.node then + accumulator(local_def, local_def.node, full_match, last_match) + else + for match_key, def in pairs(local_def) do + M.recurse_local_nodes( + def, + accumulator, + full_match and (full_match..'.'..match_key) or match_key, + match_key) + end + end +end + +-- Finds usages of a node in a given scope +-- @param node the node to find usages for +-- @param scope_node the node to look within +-- @returns a list of nodes +function M.find_usages(node, scope_node, bufnr) + local bufnr = bufnr or api.nvim_get_current_buf() + local node_text = ts_utils.get_node_text(node)[1] + + if not node_text or #node_text < 1 then return {} end + + local scope_node = scope_node or parsers.get_parser(bufnr).tree:root() + local usages = {} + + for match in M.iter_locals(bufnr, scope_node) do + if match.reference + and match.reference.node + and ts_utils.get_node_text(match.reference.node)[1] == node_text + then + table.insert(usages, match.reference.node) + end + end + + return usages +end + +function M.containing_scope(node, bufnr) + local bufnr = bufnr or api.nvim_get_current_buf() + + local scopes = M.get_scopes(bufnr) + if not node or not scopes then return end + + local iter_node = node + + while iter_node ~= nil and not vim.tbl_contains(scopes, iter_node) do + iter_node = iter_node:parent() + end + + return iter_node or node +end + +function M.nested_scope(node, cursor_pos) + local bufnr = api.nvim_get_current_buf() + + local scopes = M.get_scopes(bufnr) + if not node or not scopes then return end + + local row = cursor_pos.row + local col = cursor_pos.col + local scope = M.containing_scope(node) + + for _, child in ipairs(ts_utils.get_named_children(scope)) do + local row_, col_ = child:start() + if vim.tbl_contains(scopes, child) and ((row_+1 == row and col_ > col) or row_+1 > row) then + return child + end + end +end + +function M.next_scope(node) + local bufnr = api.nvim_get_current_buf() + + local scopes = M.get_scopes(bufnr) + if not node or not scopes then return end + + local scope = M.containing_scope(node) + + local parent = scope:parent() + if not parent then return end + + local is_prev = true + for _, child in ipairs(ts_utils.get_named_children(parent)) do + if child == scope then + is_prev = false + elseif not is_prev and vim.tbl_contains(scopes, child) then + return child + end + end +end + +function M.previous_scope(node) + local bufnr = api.nvim_get_current_buf() + + local scopes = M.get_scopes(bufnr) + if not node or not scopes then return end + + local scope = M.containing_scope(node) + + local parent = scope:parent() + if not parent then return end + + local is_prev = true + local children = ts_utils.get_named_children(parent) + for i=#children,1,-1 do + if children[i] == scope then + is_prev = false + elseif not is_prev and vim.tbl_contains(scopes, children[i]) then + return children[i] + end + end end return M |
