diff options
Diffstat (limited to 'lua/blink/cmp/sources/lib/tree.lua')
| -rw-r--r-- | lua/blink/cmp/sources/lib/tree.lua | 168 |
1 files changed, 168 insertions, 0 deletions
diff --git a/lua/blink/cmp/sources/lib/tree.lua b/lua/blink/cmp/sources/lib/tree.lua new file mode 100644 index 0000000..c89d04d --- /dev/null +++ b/lua/blink/cmp/sources/lib/tree.lua @@ -0,0 +1,168 @@ +--- @class blink.cmp.SourceTreeNode +--- @field id string +--- @field source blink.cmp.SourceProvider +--- @field dependencies blink.cmp.SourceTreeNode[] +--- @field dependents blink.cmp.SourceTreeNode[] + +--- @class blink.cmp.SourceTree +--- @field nodes blink.cmp.SourceTreeNode[] +--- @field new fun(context: blink.cmp.Context, all_sources: blink.cmp.SourceProvider[]): blink.cmp.SourceTree +--- @field get_completions fun(self: blink.cmp.SourceTree, context: blink.cmp.Context, on_items_by_provider: fun(items_by_provider: table<string, blink.cmp.CompletionItem[]>)): blink.cmp.Task +--- @field emit_completions fun(self: blink.cmp.SourceTree, items_by_provider: table<string, blink.cmp.CompletionItem[]>, on_items_by_provider: fun(items_by_provider: table<string, blink.cmp.CompletionItem[]>)): nil +--- @field get_top_level_nodes fun(self: blink.cmp.SourceTree): blink.cmp.SourceTreeNode[] +--- @field detect_cycle fun(node: blink.cmp.SourceTreeNode, visited?: table<string, boolean>, path?: table<string, boolean>): boolean + +local utils = require('blink.cmp.lib.utils') +local async = require('blink.cmp.lib.async') + +--- @type blink.cmp.SourceTree +--- @diagnostic disable-next-line: missing-fields +local tree = {} + +--- @param context blink.cmp.Context +--- @param all_sources blink.cmp.SourceProvider[] +function tree.new(context, all_sources) + -- only include enabled sources for the given context + local sources = vim.tbl_filter( + function(source) return vim.tbl_contains(context.providers, source.id) and source:enabled(context) end, + all_sources + ) + local source_ids = vim.tbl_map(function(source) return source.id end, sources) + + -- create a node for each source + local nodes = vim.tbl_map( + function(source) return { id = source.id, source = source, dependencies = {}, dependents = {} } end, + sources + ) + + -- build the tree + for idx, source in ipairs(sources) do + local node = nodes[idx] + for _, fallback_source_id in ipairs(source.config.fallbacks(context, source_ids)) do + local fallback_node = nodes[utils.index_of(source_ids, fallback_source_id)] + if fallback_node ~= nil then + table.insert(node.dependents, fallback_node) + table.insert(fallback_node.dependencies, node) + end + end + end + + -- circular dependency check + for _, node in ipairs(nodes) do + tree.detect_cycle(node) + end + + return setmetatable({ nodes = nodes }, { __index = tree }) +end + +function tree:get_completions(context, on_items_by_provider) + local should_push_upstream = false + local items_by_provider = {} + local is_all_cached = true + local nodes_falling_back = {} + + --- @param node blink.cmp.SourceTreeNode + local function get_completions_for_node(node) + -- check that all the dependencies have been triggered, and are falling back + for _, dependency in ipairs(node.dependencies) do + if not nodes_falling_back[dependency.id] then return async.task.empty() end + end + + return async.task.new(function(resolve, reject) + return node.source:get_completions(context, function(items, is_cached) + items_by_provider[node.id] = items + is_all_cached = is_all_cached and is_cached + + if should_push_upstream then self:emit_completions(items_by_provider, on_items_by_provider) end + if #items ~= 0 then return resolve() end + + -- run dependents if the source returned 0 items + nodes_falling_back[node.id] = true + local tasks = vim.tbl_map(function(dependent) return get_completions_for_node(dependent) end, node.dependents) + async.task.await_all(tasks):map(resolve):catch(reject) + end) + end) + end + + -- run the top level nodes and let them fall back to their dependents if needed + local tasks = vim.tbl_map(function(node) return get_completions_for_node(node) end, self:get_top_level_nodes()) + return async.task + .await_all(tasks) + :map(function() + should_push_upstream = true + + -- if atleast one of the results wasn't cached, emit the results + if not is_all_cached then self:emit_completions(items_by_provider, on_items_by_provider) end + end) + :catch(function(err) vim.print('failed to get completions with error: ' .. err) end) +end + +function tree:emit_completions(items_by_provider, on_items_by_provider) + local nodes_falling_back = {} + local final_items_by_provider = {} + + local add_node_items + add_node_items = function(node) + for _, dependency in ipairs(node.dependencies) do + if not nodes_falling_back[dependency.id] then return end + end + local items = items_by_provider[node.id] + if items ~= nil and #items > 0 then + final_items_by_provider[node.id] = items + else + nodes_falling_back[node.id] = true + for _, dependent in ipairs(node.dependents) do + add_node_items(dependent) + end + end + end + + for _, node in ipairs(self:get_top_level_nodes()) do + add_node_items(node) + end + + on_items_by_provider(final_items_by_provider) +end + +--- Internal --- + +function tree:get_top_level_nodes() + local top_level_nodes = {} + for _, node in ipairs(self.nodes) do + if #node.dependencies == 0 then table.insert(top_level_nodes, node) end + end + return top_level_nodes +end + +--- Helper function to detect cycles using DFS +--- @param node blink.cmp.SourceTreeNode +--- @param visited? table<string, boolean> +--- @param path? table<string, boolean> +--- @return boolean +function tree.detect_cycle(node, visited, path) + visited = visited or {} + path = path or {} + + if path[node.id] then + -- Found a cycle - construct the cycle path for error message + local cycle = { node.id } + for id, _ in pairs(path) do + table.insert(cycle, id) + end + error('Circular dependency detected: ' .. table.concat(cycle, ' -> ')) + end + + if visited[node.id] then return false end + + visited[node.id] = true + path[node.id] = true + + for _, dependent in ipairs(node.dependents) do + if tree.detect_cycle(dependent, visited, path) then return true end + end + + path[node.id] = nil + return false +end + +return tree |
