summaryrefslogtreecommitdiff
path: root/lua/nvim-treesitter/ts_utils.lua
diff options
context:
space:
mode:
Diffstat (limited to 'lua/nvim-treesitter/ts_utils.lua')
-rw-r--r--lua/nvim-treesitter/ts_utils.lua87
1 files changed, 71 insertions, 16 deletions
diff --git a/lua/nvim-treesitter/ts_utils.lua b/lua/nvim-treesitter/ts_utils.lua
index c5475d82..1f7f2093 100644
--- a/lua/nvim-treesitter/ts_utils.lua
+++ b/lua/nvim-treesitter/ts_utils.lua
@@ -114,10 +114,46 @@ function M.get_named_children(node)
end
function M.get_node_at_cursor(winnr)
- if not parsers.has_parser() then return end
local cursor = api.nvim_win_get_cursor(winnr or 0)
- local root = parsers.get_parser():parse()[1]:root()
- return root:named_descendant_for_range(cursor[1]-1,cursor[2],cursor[1]-1,cursor[2])
+ local cursor_range = { cursor[1] - 1, cursor[2] }
+ local root = M.get_root_for_position(unpack(cursor_range))
+
+ if not root then return end
+
+ return root:named_descendant_for_range(cursor_range[1], cursor_range[2], cursor_range[1], cursor_range[2])
+end
+
+function M.get_root_for_position(line, col, root_lang_tree)
+ if not root_lang_tree then
+ if not parsers.has_parser() then return end
+
+ root_lang_tree = parsers.get_parser()
+ end
+
+ local lang_tree = root_lang_tree:language_for_range({ line, col, line, col })
+
+ for _, tree in ipairs(lang_tree:trees()) do
+ local root = tree:root()
+
+ if root and M.is_in_node_range(root, line, col) then
+ return root, tree, lang_tree
+ end
+ end
+
+ -- This isn't a likely scenario, since the position must belong to a tree somewhere.
+ return nil, nil, lang_tree
+end
+
+function M.get_root_for_node(node)
+ local parent = node
+ local result = node
+
+ while parent ~= nil do
+ result = parent
+ parent = result:parent()
+ end
+
+ return result
end
function M.highlight_node(node, buf, hl_namespace, hl_group)
@@ -213,25 +249,44 @@ end
--- Memoizes a function based on the buffer tick of the provided bufnr.
-- The cache entry is cleared when the buffer is detached to avoid memory leaks.
-- @param fn: the fn to memoize, taking the bufnr as first argument
+-- @param options:
+-- - bufnr: extracts a bufnr from the given arguments.
+-- - key: extracts the cache key from the given arguments.
-- @returns a memoized function
-function M.memoize_by_buf_tick(fn)
+function M.memoize_by_buf_tick(fn, options)
+ options = options or {}
+
local cache = {}
+ local bufnr_fn = utils.to_func(options.bufnr or utils.identity)
+ local key_fn = utils.to_func(options.key or utils.identity)
- return function(bufnr)
- if cache[bufnr] then
- return cache[bufnr]
+ return function(...)
+ local bufnr = bufnr_fn(...)
+ local key = key_fn(...)
+ local tick = api.nvim_buf_get_changedtick(bufnr)
+
+ if cache[key] then
+ if cache[key].last_tick == tick then
+ return cache[key].result
+ end
else
- cache[bufnr] = {}
- api.nvim_buf_attach(bufnr, false,
- {
- on_changedtick = function() cache[bufnr] = fn(bufnr) end,
- on_detach = function() cache[bufnr] = nil end
- }
- )
+ local function detach_handler()
+ cache[key] = nil
+ end
+
+ -- Clean up logic only!
+ api.nvim_buf_attach(bufnr, false, {
+ on_detach = detach_handler,
+ on_reload = detach_handler
+ })
end
- cache[bufnr] = fn(bufnr)
- return cache[bufnr]
+ cache[key] = {
+ result = fn(...),
+ last_tick = tick
+ }
+
+ return cache[key].result
end
end