summaryrefslogtreecommitdiff
path: root/lua/nvim-treesitter/indent.lua
blob: ebf7695de42906d97539f2cc6163e823c9a8f116 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
local parsers = require "nvim-treesitter.parsers"
local queries = require "nvim-treesitter.query"
local tsutils = require "nvim-treesitter.ts_utils"

local M = {}

-- TODO(kiyan): move this in tsutils and document it
-- TODO(kiyan): move this in tsutils and document it
--- Find the node that starts on 0-based row `lnum` somewhere below `root`.
-- The children of the current parent are scanned in order:
-- * A child whose start row equals `lnum` is returned immediately.
-- * Otherwise the search moves into the first child that has children of its
--   own and spans `lnum` (starts before it and ends on or after it).
-- Once the search has moved into such a child, its result is final: later
-- siblings of that child are never examined.
-- @param root node whose descendants are searched
-- @tparam number lnum 0-based row
-- @return the matching node, or nil when none is found
local function get_node_at_line(root, lnum)
  local parent = root
  while parent do
    local descend_into = nil
    for child in parent:iter_children() do
      local start_row, _, end_row = child:range()
      if start_row == lnum then
        return child
      end
      if child:child_count() > 0 and start_row < lnum and lnum <= end_row then
        descend_into = child
        break
      end
    end
    -- nil when no child spanned the row, which ends the search.
    parent = descend_into
  end
end

--- Turn a node into the string key used by the capture sets built in
-- get_indents.
-- @param node a node, or nil
-- @return tostring(node), or nil when `node` is nil/false
local function node_fmt(node)
  if node then
    return tostring(node)
  end
  return nil
end

--- Collect the nodes captured by the "indents" query below `root`.
-- Returns a table with four sets, `indents`, `branches`, `returns` and
-- `ignores`. Each set maps tostring(node) -> true for the nodes captured as
-- @indent.node, @branch.node, @return.node and @ignore.node respectively.
-- NOTE(review): the memoization key below is built only from bufnr and lang;
-- `root` is ignored. Different roots of the same language in one buffer
-- (e.g. several injected regions) therefore presumably share one cached
-- result until memoize_by_buf_tick invalidates it. Confirm against ts_utils.
local get_indents = tsutils.memoize_by_buf_tick(function(bufnr, root, lang)
  -- Result field -> capture name, in a fixed lookup order.
  local capture_fields = {
    { "indents", "@indent.node" },
    { "branches", "@branch.node" },
    { "returns", "@return.node" },
    { "ignores", "@ignore.node" },
  }

  local result = {}
  for _, pair in ipairs(capture_fields) do
    local field, capture = pair[1], pair[2]
    local matches = queries.get_capture_matches(bufnr, capture, "indents", root, lang)
    local set = {}
    for _, captured in ipairs(matches or {}) do
      set[tostring(captured)] = true
    end
    result[field] = set
  end
  return result
end, {
  -- Memoize by bufnr and lang together.
  key = function(bufnr, _, lang)
    return tostring(bufnr) .. "_" .. lang
  end,
})

--- Compute the tree-sitter indentation for line `lnum` of the current buffer.
-- NOTE(review): presumably reached through the `nvim_treesitter#indent()`
-- 'indentexpr' that M.attach installs. Confirm in the autoload script.
-- @tparam number lnum 1-based line number. It is converted with `lnum - 1`
--   for the 0-based tree-sitter APIs.
-- @treturn number the indent in columns, or:
--   * -1 when there is no parser or `lnum` is nil;
--   * -1 when the line lies inside an @ignore node (per :h 'indentexpr',
--     -1 keeps the current indent);
--   * 0 when no syntax-tree root is found.
function M.get_indent(lnum)
  local parser = parsers.get_parser()
  if not parser or not lnum then
    return -1
  end

  -- get_root_for_position is 0-based.
  local root, _, lang_tree = tsutils.get_root_for_position(lnum - 1, 0, parser)

  -- Not likely, but just in case...
  if not root then
    return 0
  end

  -- Capture sets from the language's "indents" query, keyed by node_fmt(node).
  local q = get_indents(vim.api.nvim_get_current_buf(), root, lang_tree:lang())
  -- The node starting on this line, if any. It is nil when no node starts
  -- here, e.g. on a blank line.
  local node = get_node_at_line(root, lnum - 1)

  local indent = 0
  -- One indentation level, in columns.
  local indent_size = vim.fn.shiftwidth()

  -- to get correct indentation when we land on an empty line (for instance by typing `o`), we try
  -- to use indentation of previous nonblank line, this solves the issue also for languages that
  -- do not use @branch after blocks (e.g. Python)
  if not node then
    local prevnonblank = vim.fn.prevnonblank(lnum)
    -- prevnonblank(lnum) == lnum means the current line is itself non-blank.
    -- This heuristic only applies to blank lines.
    if prevnonblank ~= lnum then
      local prev_node = get_node_at_line(root, prevnonblank - 1)
      -- get previous node in any case to avoid erroring
      -- Keep stepping back over earlier non-blank lines until one has a node
      -- starting on it, or the top of the buffer is reached.
      while not prev_node and prevnonblank - 1 > 0 do
        prevnonblank = vim.fn.prevnonblank(prevnonblank - 1)
        prev_node = get_node_at_line(root, prevnonblank - 1)
      end

      -- nodes can be marked @return to prevent using them
      if prev_node and not q.returns[node_fmt(prev_node)] then
        -- 0-based start/end rows of the previous node.
        local row = prev_node:start()
        local end_row = prev_node:end_()

        -- if the previous node is being constructed (like function() `o` in lua), or line is inside the node
        -- we indent one more from the start of node, else we indent default
        -- NOTE: this doesn't work for python which behave strangely
        -- NOTE(review): this compares the 1-based `lnum` with the 0-based
        -- `end_row`. It is therefore true only when the current line comes
        -- before the node's last row. Confirm this is intended.
        if prev_node:has_error() or lnum <= end_row then
          -- indent() takes a 1-based line, hence row + 1.
          return vim.fn.indent(row + 1) + indent_size
        end
        return vim.fn.indent(row + 1)
      end
    end
  end

  -- if the prevnonblank fails (prev_node wraps our line) we need to fall back to taking
  -- the first child of the node that wraps the current line, or the wrapper itself
  -- This fallback is reached when:
  --   * no node starts on this line and the line is non-blank, or
  --   * no previous node was found, or
  --   * the previous node is an @return node.
  -- `wrapper` is the smallest node covering this row. NOTE(review): end
  -- column -1 presumably means "to the end of the line"; confirm. A wrapper
  -- that is itself an @indent node (other than the root) starts the count at
  -- one level.
  if not node then
    local wrapper = root:descendant_for_range(lnum - 1, 0, lnum - 1, -1)
    node = wrapper:child(0) or wrapper
    if q.indents[node_fmt(wrapper)] ~= nil and wrapper ~= root then
      indent = indent_size
    end
  end

  -- @branch nodes do not get their own level: climb to the first ancestor
  -- that is not a @branch node.
  while node and q.branches[node_fmt(node)] do
    node = node:parent()
  end

  -- `first`: the starting node itself never contributes an indent level.
  -- `prev_row`: the row of the last node that contributed a level. Several
  -- @indent ancestors starting on the same row add only one level, and an
  -- ancestor starting on the starting node's row adds none.
  local first = true
  local prev_row = node:start()

  -- Walk up to the root, adding one level per qualifying @indent ancestor.
  while node do
    -- Do not indent if we are inside an @ignore block.
    -- If a node spans from L1,C1 to L2,C2, we know that lines where L1 < line <= L2 would
    -- have their indentations contained by the node.
    if q.ignores[node_fmt(node)] and node:start() < lnum - 1 and lnum - 1 <= node:end_() then
      return -1
    end

    -- do not indent the starting node, do not add multiple indent levels on single line
    local row = node:start()
    if not first and q.indents[node_fmt(node)] and prev_row ~= row then
      indent = indent + indent_size
      prev_row = row
    end

    node = node:parent()
    first = false
  end

  return indent
end

--- 'indentexpr' values saved by M.attach, keyed by buffer number, so that
-- M.detach can restore them.
local indent_funcs = {}

--- Resolve an optional buffer number: nil or 0 means the current buffer.
local function resolve_bufnr(bufnr)
  if bufnr == nil or bufnr == 0 then
    return vim.api.nvim_get_current_buf()
  end
  return bufnr
end

--- Enable tree-sitter indentation for a buffer.
-- Saves the buffer's current 'indentexpr' and sets it to
-- `nvim_treesitter#indent()`. Also registers a FileType autocommand for the
-- buffer's filetype, so that a later FileType event that resets the option
-- (e.g. an indent script) is followed by setting it again.
-- @tparam number|nil bufnr buffer to attach to (nil or 0: current buffer)
function M.attach(bufnr)
  bufnr = resolve_bufnr(bufnr)
  -- Operate on `bufnr` itself, not on whichever buffer happens to be current.
  indent_funcs[bufnr] = vim.api.nvim_buf_get_option(bufnr, "indentexpr")
  vim.api.nvim_buf_set_option(bufnr, "indentexpr", "nvim_treesitter#indent()")

  local filetype = vim.api.nvim_buf_get_option(bufnr, "filetype")
  if filetype == "" then
    -- An empty pattern would make the autocmd line malformed.
    return
  end

  -- attach runs once per buffer, so keep exactly one handler per filetype in a
  -- dedicated group. `autocmd!` drops the previous handler before re-adding
  -- it, which also keeps this handler last among the FileType autocommands.
  vim.api.nvim_command "augroup NvimTreesitterIndent"
  vim.api.nvim_command "augroup END"
  vim.api.nvim_command("autocmd! NvimTreesitterIndent FileType " .. filetype)
  vim.api.nvim_command(
    "autocmd NvimTreesitterIndent FileType " .. filetype .. " setlocal indentexpr=nvim_treesitter#indent()"
  )
end

--- Disable tree-sitter indentation for a buffer by restoring the
-- 'indentexpr' saved by M.attach.
-- NOTE(review): the FileType autocommand registered by M.attach applies to the
-- whole filetype and is not removed here. A later FileType event for this
-- filetype may therefore set the tree-sitter indentexpr again.
-- @tparam number|nil bufnr buffer to detach from (nil or 0: current buffer)
function M.detach(bufnr)
  bufnr = resolve_bufnr(bufnr)
  local saved = indent_funcs[bufnr]
  if saved == nil then
    -- Never attached (or already detached): nothing to restore.
    return
  end
  vim.api.nvim_buf_set_option(bufnr, "indentexpr", saved)
  indent_funcs[bufnr] = nil
end

return M