From 0948bf22745f1d80572e2b46ed547c7c8674cca9 Mon Sep 17 00:00:00 2001 From: Lewis Russell Date: Mon, 31 Oct 2022 10:52:52 +0000 Subject: feat(emmy): add emmylua annotations --- lua/nvim-treesitter/query.lua | 99 ++++++++++++++++++++++++++++++------------- 1 file changed, 70 insertions(+), 29 deletions(-) (limited to 'lua/nvim-treesitter/query.lua') diff --git a/lua/nvim-treesitter/query.lua b/lua/nvim-treesitter/query.lua index aeaa683b..bc80d51e 100644 --- a/lua/nvim-treesitter/query.lua +++ b/lua/nvim-treesitter/query.lua @@ -11,8 +11,10 @@ local EMPTY_ITER = function() end M.built_in_query_groups = { "highlights", "locals", "folds", "indents", "injections" } --- Creates a function that checks whether a given query exists --- for a specific language. +--- Creates a function that checks whether a given query exists +--- for a specific language. +---@param query string +---@return function(string): boolean local function get_query_guard(query) return function(lang) return M.has_query_files(lang, query) @@ -23,6 +25,7 @@ for _, query in ipairs(M.built_in_query_groups) do M["has_" .. query] = get_query_guard(query) end +---@return string[] function M.available_query_groups() local query_files = api.nvim_get_runtime_file("queries/*/*.scm", true) local groups = {} @@ -57,11 +60,19 @@ do end end +---@param lang string +---@param query_name string +---@return string[] local function runtime_queries(lang, query_name) return api.nvim_get_runtime_file(string.format("queries/%s/%s.scm", lang, query_name), true) or {} end +---@type table> local query_files_cache = {} + +---@param lang string +---@param query_name string +---@return boolean function M.has_query_files(lang, query_name) if not query_files_cache[lang] then query_files_cache[lang] = {} @@ -86,6 +97,8 @@ do local cache = setmetatable({}, mt) --- Same as `vim.treesitter.query` except will return cached values + ---@param lang string + ---@param query_name string function M.get_query(lang, query_name) if cache[lang][query_name] == nil then cache[lang][query_name] = tsq.get_query(lang, query_name) @@ -98,6 +111,8 @@ do --- If lang and query_name is both present, will reload for only the lang and query_name. --- If only lang is present, will reload all query_names for that lang --- If none are present, will reload everything + ---@param lang string + ---@param query_name string function M.invalidate_query_cache(lang, query_name) if lang and query_name then cache[lang][query_name] = nil @@ -106,14 +121,14 @@ do end elseif lang and not query_name then query_files_cache[lang] = nil - for query_name, _ in pairs(cache[lang]) do - M.invalidate_query_cache(lang, query_name) + for query_name0, _ in pairs(cache[lang]) do + M.invalidate_query_cache(lang, query_name0) end elseif not lang and not query_name then query_files_cache = {} - for lang, _ in pairs(cache) do - for query_name, _ in pairs(cache[lang]) do - M.invalidate_query_cache(lang, query_name) + for lang0, _ in pairs(cache) do + for query_name0, _ in pairs(cache[lang0]) do + M.invalidate_query_cache(lang0, query_name0) end end else @@ -123,11 +138,23 @@ do end --- This function is meant for an autocommand and not to be used. Only use if file is a query file. +---@param fname string function M.invalidate_query_file(fname) local fnamemodify = vim.fn.fnamemodify M.invalidate_query_cache(fnamemodify(fname, ":p:h:t"), fnamemodify(fname, ":t:r")) end +---@class QueryInfo +---@field root LanguageTree +---@field source integer +---@field start integer +---@field stop integer + +---@param bufnr integer +---@param query_name string +---@param root LanguageTree +---@param root_lang string|nil +---@return Query|nil, QueryInfo|nil local function prepare_query(bufnr, query_name, root, root_lang) local buf_lang = parsers.get_buf_lang(bufnr) @@ -181,6 +208,10 @@ local function prepare_query(bufnr, query_name, root, root_lang) } end +---@param query Query +---@param bufnr integer +---@param start_row integer +---@param end_row integer function M.iter_prepared_matches(query, qnode, bufnr, start_row, end_row) -- A function that splits a string on '.' local function split(string) @@ -249,15 +280,16 @@ function M.iter_prepared_matches(query, qnode, bufnr, start_row, end_row) end --- Return all nodes corresponding to a specific capture path (like @definition.var, @reference.type) --- Works like M.get_references or M.get_scopes except you can choose the capture --- Can also be a nested capture like @definition.function to get all nodes defining a function. --- --- @param bufnr the buffer --- @param captures a single string or a list of strings --- @param query_group the name of query group (highlights or injections for example) --- @param root (optional) node from where to start the search --- @param lang (optional) the language from where to get the captures. --- Root nodes can have several languages. +---Works like M.get_references or M.get_scopes except you can choose the capture +---Can also be a nested capture like @definition.function to get all nodes defining a function. +--- +---@param bufnr integer the buffer +---@param captures string|string[] +---@param query_group string the name of query group (highlights or injections for example) +---@param root LanguageTree|nil node from where to start the search +---@param lang string|nil the language from where to get the captures. +--- Root nodes can have several languages. +---@return table|nil function M.get_capture_matches(bufnr, captures, query_group, root, lang) if type(captures) == "string" then captures = { captures } @@ -289,6 +321,7 @@ function M.iter_captures(bufnr, query_name, root, lang) if not query then return EMPTY_ITER end + assert(params) local iter = query:iter_captures(params.root, params.source, params.start, params.stop) @@ -336,16 +369,17 @@ function M.find_best_match(bufnr, capture_string, query_group, filter_predicate, return best end --- Iterates matches from a query file. --- @param bufnr the buffer --- @param query_group the query file to use --- @param root the root node --- @param root the root node lang, if known +---Iterates matches from a query file. +---@param bufnr integer the buffer +---@param query_group string the query file to use +---@param root LanguageTree the root node +---@param root_lang string|nil the root node lang, if known function M.iter_group_results(bufnr, query_group, root, root_lang) local query, params = prepare_query(bufnr, query_group, root, root_lang) if not query then return EMPTY_ITER end + assert(params) return M.iter_prepared_matches(query, params.root, params.source, params.start, params.stop) end @@ -360,18 +394,25 @@ function M.collect_group_results(bufnr, query_group, root, lang) return matches end +---@alias CaptureResFn function(string, LanguageTree, LanguageTree): string, string + --- Same as get_capture_matches except this will recursively get matches for every language in the tree. --- @param bufnr The bufnr --- @param capture_or_fn The capture to get. If a function is provided then that --- function will be used to resolve both the capture and query argument. --- The function can return `nil` to ignore that tree. --- @param query_type The query to get the capture from. This is ignore if a function is provided --- for the captuer argument. +---@param bufnr integer The bufnr +---@param capture_or_fn string|CaptureResFn The capture to get. If a function is provided then that +--- function will be used to resolve both the capture and query argument. +--- The function can return `nil` to ignore that tree. +---@param query_type string The query to get the capture from. This is ignore if a function is provided +--- for the captuer argument. function M.get_capture_matches_recursively(bufnr, capture_or_fn, query_type) - local type_fn = type(capture_or_fn) == "function" and capture_or_fn - or function() + ---@type CaptureResFn + local type_fn + if type(capture_or_fn) == "function" then + type_fn = capture_or_fn + else + type_fn = function(_, _, _) return capture_or_fn, query_type end + end local parser = parsers.get_parser(bufnr) local matches = {} -- cgit v1.2.3