summaryrefslogtreecommitdiff
path: root/lua/nvim-treesitter/query_predicates.lua
diff options
context:
space:
mode:
authorStephan Seitz <stephan.seitz@fau.de>2020-07-20 17:48:58 +0200
committerThomas Vigouroux <39092278+vigoux@users.noreply.github.com>2020-07-27 10:15:33 +0200
commit5462fc92cbb6e94d93a7e20d15f81f68d918f71d (patch)
treebad50530f61969a2f2691bcd487b4343a2925aa6 /lua/nvim-treesitter/query_predicates.lua
parent6f01384cb2d60db4b5990085422711c0a764ed7f (diff)
Add predicates module
Diffstat (limited to 'lua/nvim-treesitter/query_predicates.lua')
-rw-r--r--lua/nvim-treesitter/query_predicates.lua58
1 files changed, 58 insertions, 0 deletions
diff --git a/lua/nvim-treesitter/query_predicates.lua b/lua/nvim-treesitter/query_predicates.lua
new file mode 100644
index 00000000..00daa50d
--- /dev/null
+++ b/lua/nvim-treesitter/query_predicates.lua
@@ -0,0 +1,58 @@
+local utils = require'nvim-treesitter.utils'
+
+local M = {}
+
+local function get_nth_child(node, n)
+ if node:named_child_count() > n then
+ return node:named_child(n)
+ end
+end
+
+local function get_node(query, match, pred_item)
+ return utils.get_at_path(match, query.captures[pred_item]..'.node')
+end
+
+function M.check_predicate(query, match, pred)
+ local check_function = M[string.gsub('check_'..pred[1], "%?$", '')]
+ if check_function then
+ return check_function(query, match, pred)
+ else
+ return true
+ end
+end
+
+function M.check_negated_predicate(query, match, pred)
+ local check_function = M[string.gsub('check_'..string.sub(pred[1], #"not-" + 1), "%?$", '')]
+ if check_function then
+ return not check_function(query, match, pred)
+ else
+ return true
+ end
+end
+
+function M.check_first(query, match, pred)
+ if #pred ~= 2 then error("first? must have exactly one argument!") end
+ local node = get_node(query, match, pred[2])
+ if node and node:parent() then
+ return get_nth_child(node:parent(), 0) == node
+ end
+end
+
+function M.check_last(query, match, pred)
+ if #pred ~= 2 then error("first? must have exactly one argument!") end
+ local node = get_node(query, match, pred[2])
+ if node and node:parent() then
+ local num_children = node:parent():named_child_count()
+ return get_nth_child(node:parent(), num_children - 1) == node
+ end
+end
+
+function M.check_nth(query, match, pred)
+ if #pred ~= 3 then error("nth? must have exactly two arguments!") end
+ local node = get_node(query, match, pred[2])
+ if node and node:parent() then
+ return get_nth_child(node:parent(), pred[3] - 1) == node
+ end
+end
+
+return M