summaryrefslogtreecommitdiff
path: root/lua/nvim-treesitter/query_predicates.lua
blob: e010e6dd83e295289ba52f1646f04c66aa281af8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
local utils = require'nvim-treesitter.utils'
local ts_utils = require'nvim-treesitter.ts_utils'

-- Module table. Predicate handlers are stored here keyed by predicate name
-- (e.g. M['first?']); check_predicate / check_negated_predicate look handlers
-- up in this table by name.
local M = {}

--- Return the n-th *named* child of `node` (0-based), or nothing when `n` is
-- out of range.
-- @param node tree node exposing named_child_count() and named_child()
-- @param n 0-based index into the node's named children
-- @return the named child at index `n`, or no value when out of range
local function get_nth_child(node, n)
  -- Reject negative indices explicitly: callers compute `count - 1` and
  -- `pred[3] - 1`, which can go below zero, and `count > n` alone would pass
  -- such an index straight through to named_child().
  if n >= 0 and n < node:named_child_count() then
    return node:named_child(n)
  end
end

--- Resolve a capture referenced by a predicate to the node stored in `match`.
-- The capture's name is read from `query.captures[pred_item]`, and the node is
-- fetched from `match` via the path "<capture name>.node".
-- @param query query object whose `captures` table maps ids to capture names
-- @param match match table searched with utils.get_at_path
-- @param pred_item capture id taken from the predicate list
-- @return whatever utils.get_at_path yields for that path
local function get_node(query, match, pred_item)
  local path = query.captures[pred_item] .. '.node'
  return utils.get_at_path(match, path)
end

--- Build a handler for the `adjacent?` family of predicates.
-- The returned handler takes `(query, match, pred)`:
--   * pred[1] is the predicate name (used in error messages),
--   * pred[2] is the capture id of the node under test,
--   * pred[3..] are candidates for the node that follows it: a number is a
--     capture id (compared against that captured node), a string is a node
--     type (compared against the following node's type()).
-- The "following node" is whatever ts_utils.get_next_node returns.
-- @param match_successive_nodes when true, first skip over the run of
--   following nodes whose type equals the captured node's type, and compare
--   against the first node after that run
-- @return handler returning true if any candidate matches; true when the
--   capture is absent; false when there is no following node
-- @raise error when fewer than two arguments are given
local function create_adjacent_predicate(match_successive_nodes)
  return function(query, match, pred)
    if #pred < 3 then
      -- Name the actual predicate (adjacent?, adjacent-block?, not-...) instead
      -- of hard-coding one of them.
      error(tostring(pred[1]) .. " must have at least two arguments!")
    end
    local node = get_node(query, match, pred[2])
    if not node then return true end

    local adjacent_node = ts_utils.get_next_node(node)

    if match_successive_nodes then
      -- Skip past the run of following nodes that share the captured node's
      -- type; compare against the first node after that run.
      while adjacent_node and adjacent_node:type() == node:type() do
        node = adjacent_node
        adjacent_node = ts_utils.get_next_node(node)
      end
    end

    if not adjacent_node then return false end

    -- Iterate pred[3..] directly rather than via the Lua 5.1-only global
    -- `unpack`, which does not exist in Lua 5.2+.
    for i = 3, #pred do
      local candidate = pred[i]
      if type(candidate) == "number" then
        if get_node(query, match, candidate) == adjacent_node then
          return true
        end
      elseif type(candidate) == "string" then
        if adjacent_node:type() == candidate then
          return true
        end
      end
    end

    return false
  end
end

--- Evaluate one (non-negated) predicate against a match.
-- The handler is looked up in M under the predicate's name, pred[1]. A
-- predicate with no registered handler is treated as passing.
-- @param query passed through unchanged to the handler
-- @param match passed through unchanged to the handler
-- @param pred predicate list; pred[1] is its name, the rest its arguments
-- @return the handler's result(s), or true when no handler is registered
function M.check_predicate(query, match, pred)
  local handler = M[pred[1]]
  if not handler then
    return true
  end
  return handler(query, match, pred)
end

--- Evaluate a negated predicate (e.g. "not-first?") against a match.
-- The base handler name is pred[1] with its first #"not-" characters dropped.
-- The prefix itself is not verified here. The base handler's result is
-- inverted. When no base handler is registered, the predicate passes.
-- @param query passed through unchanged to the base handler
-- @param match passed through unchanged to the base handler
-- @param pred predicate list; pred[1] is the negated name, the rest arguments
-- @return boolean negation of the base handler, or true when none exists
function M.check_negated_predicate(query, match, pred)
  local base_name = string.sub(pred[1], #"not-" + 1)
  local handler = M[base_name]
  if not handler then
    return true
  end
  return not handler(query, match, pred)
end

--- `first?` predicate: is the captured node the first *named* child of its
-- parent?
-- Usage: `(#first? @capture)`.
-- @return true/false when the node and its parent exist; no value (falsy)
--   when the capture is absent or has no parent
-- @raise error unless exactly one argument is given
M['first?'] = function(query, match, pred)
  if #pred ~= 2 then
    error("first? must have exactly one argument!")
  end
  local node = get_node(query, match, pred[2])
  local parent = node and node:parent()
  if not parent then
    return
  end
  return get_nth_child(parent, 0) == node
end

--- `last?` predicate: is the captured node the last *named* child of its
-- parent?
-- Usage: `(#last? @capture)`.
-- @return true/false when the node and its parent exist; no value (falsy)
--   when the capture is absent or has no parent
-- @raise error unless exactly one argument is given
M['last?'] = function (query, match, pred)
  if #pred ~= 2 then error("last? must have exactly one argument!") end
  local node = get_node(query, match, pred[2])
  local parent = node and node:parent()
  if not parent then return end

  local num_children = parent:named_child_count()
  -- The parent can have no named children at all (e.g. when the capture is
  -- an anonymous node); there is then no last named child to compare with,
  -- so don't ask for index -1.
  if num_children == 0 then return false end
  return get_nth_child(parent, num_children - 1) == node
end

 --- `nth?` predicate: is the captured node the n-th (1-based) *named* child
-- of its parent?
-- Usage: `(#nth? @capture n)`.
-- @return true/false when the node and its parent exist; no value (falsy)
--   when the capture is absent or has no parent
-- @raise error unless exactly two arguments are given, or when `n` is not a
--   number
M['nth?'] = function(query, match, pred)
  if #pred ~= 3 then error("nth? must have exactly two arguments!") end
  local node = get_node(query, match, pred[2])
  local parent = node and node:parent()
  if not parent then return end

  -- Predicate entries can be strings (the adjacency predicates handle both
  -- strings and numbers), so convert explicitly instead of relying on
  -- implicit string->number coercion in `pred[3] - 1`.
  local n = tonumber(pred[3])
  if not n then
    error("nth? expects a number as its second argument!")
  end
  -- Positions are 1-based; nothing can sit at position 0 or below.
  if n < 1 then return false end
  return get_nth_child(parent, n - 1) == node
end

--- `has-ancestor?` predicate: does any strict ancestor of the captured node
-- have one of the given types?
-- Usage: `(#has-ancestor? @capture type1 [type2 ...])`.
-- Accepting several types mirrors `adjacent?`, which takes several
-- candidates. With a single type the behavior is unchanged.
-- @return true if an ancestor's type() is one of pred[3..]; true when the
--   capture is absent; false otherwise
-- @raise error when fewer than two arguments are given
M['has-ancestor?'] = function(query, match, pred)
  if #pred < 3 then error("has-ancestor? must have at least two arguments!") end
  local node = get_node(query, match, pred[2])
  if not node then return true end

  -- Set of accepted ancestor types for O(1) lookup per ancestor.
  local ancestor_types = {}
  for i = 3, #pred do
    ancestor_types[pred[i]] = true
  end

  local ancestor = node:parent()
  while ancestor do
    if ancestor_types[ancestor:type()] then
      return true
    end
    ancestor = ancestor:parent()
  end
  return false
end

-- Register both adjacency predicates. `adjacent?` compares against the node
-- returned by ts_utils.get_next_node; `adjacent-block?` first skips over the
-- following nodes that share the captured node's type.
local adjacent_variants = {
  ['adjacent?'] = false,
  ['adjacent-block?'] = true,
}
for predicate_name, skip_same_type_run in pairs(adjacent_variants) do
  M[predicate_name] = create_adjacent_predicate(skip_same_type_run)
end

return M