diff options
Diffstat (limited to 'lua/nvim-treesitter/query.lua')
| -rw-r--r-- | lua/nvim-treesitter/query.lua | 120 |
1 files changed, 80 insertions, 40 deletions
diff --git a/lua/nvim-treesitter/query.lua b/lua/nvim-treesitter/query.lua index 0d2e8cb3..7009e9f2 100644 --- a/lua/nvim-treesitter/query.lua +++ b/lua/nvim-treesitter/query.lua @@ -128,6 +128,59 @@ function M.invalidate_query_file(fname) M.invalidate_query_cache(fnamemodify(fname, ":p:h:t"), fnamemodify(fname, ":t:r")) end +local function prepare_query(bufnr, query_name, root, root_lang) + local buf_lang = parsers.get_buf_lang(bufnr) + + if not buf_lang then + return + end + + local parser = parsers.get_parser(bufnr, buf_lang) + if not parser then + return + end + + if not root then + local first_tree = parser:trees()[1] + + if first_tree then + root = first_tree:root() + end + end + + if not root then + return + end + + local range = { root:range() } + + if not root_lang then + local lang_tree = parser:language_for_range(range) + + if lang_tree then + root_lang = lang_tree:lang() + end + end + + if not root_lang then + return + end + + local query = M.get_query(root_lang, query_name) + if not query then + return + end + + return query, + { + root = root, + source = bufnr, + start = range[1], + -- The end row is exclusive so we need to add 1 to it. + stop = range[3] + 1, + } +end + function M.iter_prepared_matches(query, qnode, bufnr, start_row, end_row) -- A function that splits a string on '.' local function split(string) @@ -229,6 +282,31 @@ function M.get_capture_matches(bufnr, captures, query_group, root, lang) return matches end +function M.iter_captures(bufnr, query_name, root, lang) + local query, params = prepare_query(bufnr, query_name, root, lang) + if not query then + return EMPTY_ITER + end + + local iter = query:iter_captures(params.root, params.source, params.start, params.stop) + + local function wrapped_iter() + local id, node, metadata = iter() + if not id then + return + end + + local name = query.captures[id] + if string.sub(name, 1, 1) == "_" then + return wrapped_iter() + end + + return name, node, metadata + end + + return wrapped_iter +end + function M.find_best_match(bufnr, capture_string, query_group, filter_predicate, scoring_function, root) if string.sub(capture_string, 1, 1) == "@" then --remove leading "@" @@ -262,50 +340,12 @@ end -- @param root the root node -- @param root the root node lang, if known function M.iter_group_results(bufnr, query_group, root, root_lang) - local buf_lang = parsers.get_buf_lang(bufnr) - - if not buf_lang then - return EMPTY_ITER - end - - local parser = parsers.get_parser(bufnr, buf_lang) - if not parser then - return EMPTY_ITER - end - - if not root then - local first_tree = parser:trees()[1] - - if first_tree then - root = first_tree:root() - end - end - - if not root then - return EMPTY_ITER - end - - local range = { root:range() } - - if not root_lang then - local lang_tree = parser:language_for_range(range) - - if lang_tree then - root_lang = lang_tree:lang() - end - end - - if not root_lang then - return EMPTY_ITER - end - - local query = M.get_query(root_lang, query_group) + local query, params = prepare_query(bufnr, query_group, root, root_lang) if not query then return EMPTY_ITER end - -- The end row is exclusive so we need to add 1 to it. - return M.iter_prepared_matches(query, root, bufnr, range[1], range[3] + 1) + return M.iter_prepared_matches(query, params.root, params.source, params.start, params.stop) end function M.collect_group_results(bufnr, query_group, root, lang) |
