summaryrefslogtreecommitdiff
path: root/lua
diff options
context:
space:
mode:
authorTJ DeVries <devries.timothyj@gmail.com>2020-11-16 10:58:30 -0500
committerGitHub <noreply@github.com>2020-11-16 10:58:30 -0500
commit985856946e30a7d93eb3b8aac6b5b5d7d589a768 (patch)
tree0e334d5dd1808f6508c869322699825d2b576130 /lua
parentad7280e0b99ecd9f78ac6c70b4e5b49ed5e632f8 (diff)
feat: Allow overriding actions from mappings (#248)
Diffstat (limited to 'lua')
-rw-r--r--lua/telescope/actions/init.lua (renamed from lua/telescope/actions.lua)64
-rw-r--r--lua/telescope/actions/mt.lua96
-rw-r--r--lua/telescope/builtin.lua21
-rw-r--r--lua/telescope/pickers.lua5
-rw-r--r--lua/telescope/state.lua9
-rw-r--r--lua/tests/automated/action_spec.lua164
6 files changed, 301 insertions, 58 deletions
diff --git a/lua/telescope/actions.lua b/lua/telescope/actions/init.lua
index a35ffee..4bc230a 100644
--- a/lua/telescope/actions.lua
+++ b/lua/telescope/actions/init.lua
@@ -6,43 +6,14 @@ local log = require('telescope.log')
local path = require('telescope.path')
local state = require('telescope.state')
+local transform_mod = require('telescope.actions.mt').transform_mod
+
local actions = setmetatable({}, {
__index = function(_, k)
error("Actions does not have a value: " .. tostring(k))
end
})
-local action_mt = {
- __call = function(t, ...)
- local values = {}
- for _, v in ipairs(t) do
- local result = {v(...)}
- for _, res in ipairs(result) do
- table.insert(values, res)
- end
- end
-
- return unpack(values)
- end,
-
- __add = function(lhs, rhs)
- local new_actions = {}
- for _, v in ipairs(lhs) do
- table.insert(new_actions, v)
- end
-
- for _, v in ipairs(rhs) do
- table.insert(new_actions, v)
- end
-
- return setmetatable(new_actions, getmetatable(lhs))
- end
-}
-
-local transform_action = function(a)
- return setmetatable({a}, action_mt)
-end
-
--- Get the current picker object for the prompt
function actions.get_current_picker(prompt_bufnr)
return state.get_status(prompt_bufnr).picker
@@ -68,8 +39,8 @@ function actions.add_selection(prompt_bufnr)
end
--- Get the current entry
-function actions.get_selected_entry(prompt_bufnr)
- return actions.get_current_picker(prompt_bufnr):get_selection()
+function actions.get_selected_entry()
+ return state.get_global_key('selected_entry')
end
function actions.preview_scrolling_up(prompt_bufnr)
@@ -81,7 +52,7 @@ function actions.preview_scrolling_down(prompt_bufnr)
end
-- TODO: It seems sometimes we get bad styling.
-local function goto_file_selection(prompt_bufnr, command)
+function actions._goto_file_selection(prompt_bufnr, command)
local entry = actions.get_selected_entry(prompt_bufnr)
if not entry then
@@ -95,7 +66,7 @@ local function goto_file_selection(prompt_bufnr, command)
-- TODO: Check for off-by-one
row = entry.row or entry.lnum
col = entry.col
- else
+ elseif not entry.bufnr then
-- TODO: Might want to remove this and force people
-- to put stuff into `filename`
local value = entry.value
@@ -124,11 +95,11 @@ local function goto_file_selection(prompt_bufnr, command)
actions.close(prompt_bufnr)
- filename = path.normalize(filename, vim.fn.getcwd())
-
if entry_bufnr then
vim.cmd(string.format(":%s #%d", command, entry_bufnr))
else
+ filename = path.normalize(filename, vim.fn.getcwd())
+
local bufnr = vim.api.nvim_get_current_buf()
if filename ~= vim.api.nvim_buf_get_name(bufnr) then
vim.cmd(string.format(":%s %s", command, filename))
@@ -151,19 +122,19 @@ function actions.center(_)
end
function actions.goto_file_selection_edit(prompt_bufnr)
- goto_file_selection(prompt_bufnr, "edit")
+ actions._goto_file_selection(prompt_bufnr, "edit")
end
function actions.goto_file_selection_split(prompt_bufnr)
- goto_file_selection(prompt_bufnr, "new")
+ actions._goto_file_selection(prompt_bufnr, "new")
end
function actions.goto_file_selection_vsplit(prompt_bufnr)
- goto_file_selection(prompt_bufnr, "vnew")
+ actions._goto_file_selection(prompt_bufnr, "vnew")
end
function actions.goto_file_selection_tabedit(prompt_bufnr)
- goto_file_selection(prompt_bufnr, "tabedit")
+ actions._goto_file_selection(prompt_bufnr, "tabedit")
end
function actions.close_pum(_)
@@ -218,10 +189,9 @@ actions.insert_value = function(prompt_bufnr)
return entry.value
end
-for k, v in pairs(actions) do
- actions[k] = transform_action(v)
-end
-
-actions._transform_action = transform_action
-
+-- ==================================================
+-- Transforms modules and sets the corect metatables.
+-- ==================================================
+actions = transform_mod(actions)
return actions
+
diff --git a/lua/telescope/actions/mt.lua b/lua/telescope/actions/mt.lua
new file mode 100644
index 0000000..909e7bb
--- /dev/null
+++ b/lua/telescope/actions/mt.lua
@@ -0,0 +1,96 @@
+
+local action_mt = {}
+
+action_mt.create = function(mod)
+ local mt = {
+ __call = function(t, ...)
+ local values = {}
+ for _, v in ipairs(t) do
+ local func = t._replacements[v] or mod[v]
+
+ if t._pre[v] then
+ t._pre[v](...)
+ end
+
+ local result = {func(...)}
+ for _, res in ipairs(result) do
+ table.insert(values, res)
+ end
+
+ if t._post[v] then
+ t._post[v](...)
+ end
+ end
+
+ return unpack(values)
+ end,
+
+ __add = function(lhs, rhs)
+ local new_actions = {}
+ for _, v in ipairs(lhs) do
+ table.insert(new_actions, v)
+ end
+
+ for _, v in ipairs(rhs) do
+ table.insert(new_actions, v)
+ end
+
+ return setmetatable(new_actions, getmetatable(lhs))
+ end,
+
+ _pre = {},
+ _replacements = {},
+ _post = {},
+ }
+
+ mt.__index = mt
+
+ mt.clear = function()
+ mt._pre = {}
+ mt._replacements = {}
+ mt._post = {}
+ end
+
+ --- Replace the reference to the function with a new one temporarily
+ function mt:replace(v)
+ assert(#self == 1, "Cannot replace an already combined action")
+
+ local action_name = self[1]
+ mt._replacements[action_name] = v
+ end
+
+ function mt:enhance(opts)
+ assert(#self == 1, "Cannot enhance already combined actions")
+
+ local action_name = self[1]
+ if opts.pre then
+ mt._pre[action_name] = opts.pre
+ end
+
+ if opts.post then
+ mt._post[action_name] = opts.post
+ end
+ end
+
+ return mt
+end
+
+action_mt.transform = function(k, mt)
+ return setmetatable({k}, mt)
+end
+
+action_mt.transform_mod = function(mod)
+ local mt = action_mt.create(mod)
+
+ local redirect = {}
+
+ for k, _ in pairs(mod) do
+ redirect[k] = action_mt.transform(k, mt)
+ end
+
+ redirect._clear = mt.clear
+
+ return redirect
+end
+
+return action_mt
diff --git a/lua/telescope/builtin.lua b/lua/telescope/builtin.lua
index c792f54..4102906 100644
--- a/lua/telescope/builtin.lua
+++ b/lua/telescope/builtin.lua
@@ -804,12 +804,15 @@ builtin.current_buffer_fuzzy_find = function(opts)
table.insert(lines_with_numbers, {k, v})
end
+ local bufnr = vim.api.nvim_get_current_buf()
+
pickers.new(opts, {
prompt_title = 'Current Buffer Fuzzy',
finder = finders.new_table {
results = lines_with_numbers,
entry_maker = function(enumerated_line)
return {
+ bufnr = bufnr,
display = enumerated_line[2],
ordinal = enumerated_line[2],
@@ -818,17 +821,13 @@ builtin.current_buffer_fuzzy_find = function(opts)
end
},
sorter = sorters.get_generic_fuzzy_sorter(),
- attach_mappings = function(prompt_bufnr, map)
- local goto_line = function()
- local selection = actions.get_selected_entry(prompt_bufnr)
- actions.close(prompt_bufnr)
-
- vim.api.nvim_win_set_cursor(0, {selection.lnum, 0})
- vim.cmd [[stopinsert]]
- end
-
- map('n', '<CR>', goto_line)
- map('i', '<CR>', goto_line)
+ attach_mappings = function(prompt_bufnr)
+ actions._goto_file_selection:enhance {
+ post = vim.schedule_wrap(function()
+ local selection = actions.get_selected_entry(prompt_bufnr)
+ vim.api.nvim_win_set_cursor(0, {selection.lnum, 0})
+ end),
+ }
return true
end
diff --git a/lua/telescope/pickers.lua b/lua/telescope/pickers.lua
index d3bba7c..ac31348 100644
--- a/lua/telescope/pickers.lua
+++ b/lua/telescope/pickers.lua
@@ -61,6 +61,9 @@ function Picker:new(opts)
error("layout_strategy and get_window_options are not compatible keys")
end
+ -- Reset actions for any replaced / enhanced actions.
+ actions._clear()
+
local layout_strategy = get_default(opts.layout_strategy, config.values.layout_strategy)
return setmetatable({
@@ -708,6 +711,8 @@ function Picker:set_selection(row)
local status = state.get_status(self.prompt_bufnr)
local results_bufnr = status.results_bufnr
+ state.set_global_key("selected_entry", entry)
+
if not vim.api.nvim_buf_is_valid(results_bufnr) then
return
end
diff --git a/lua/telescope/state.lua b/lua/telescope/state.lua
index a014a0d..6a06eb1 100644
--- a/lua/telescope/state.lua
+++ b/lua/telescope/state.lua
@@ -1,12 +1,21 @@
local state = {}
TelescopeGlobalState = TelescopeGlobalState or {}
+TelescopeGlobalState.global = TelescopeGlobalState.global or {}
--- Set the status for a particular prompt bufnr
function state.set_status(prompt_bufnr, status)
TelescopeGlobalState[prompt_bufnr] = status
end
+function state.set_global_key(key, value)
+ TelescopeGlobalState.global[key] = value
+end
+
+function state.get_global_key(key)
+ return TelescopeGlobalState.global[key]
+end
+
function state.get_status(prompt_bufnr)
return TelescopeGlobalState[prompt_bufnr] or {}
end
diff --git a/lua/tests/automated/action_spec.lua b/lua/tests/automated/action_spec.lua
new file mode 100644
index 0000000..e85cb3c
--- /dev/null
+++ b/lua/tests/automated/action_spec.lua
@@ -0,0 +1,164 @@
+require('plenary.test_harness'):setup_busted()
+
+local transform_mod = require('telescope.actions.mt').transform_mod
+
+local eq = function(a, b)
+ assert.are.same(a, b)
+end
+
+describe('actions', function()
+ it('should allow creating custom actions', function()
+ local a = transform_mod {
+ x = function() return 5 end,
+ }
+
+
+ eq(5, a.x())
+ end)
+
+ it('allows adding actions', function()
+ local a = transform_mod {
+ x = function() return "x" end,
+ y = function() return "y" end,
+ }
+
+ local x_plus_y = a.x + a.y
+
+ eq({"x", "y"}, {x_plus_y()})
+ end)
+
+ it('ignores nils from added actions', function()
+ local a = transform_mod {
+ x = function() return "x" end,
+ y = function() return "y" end,
+ nil_maker = function() return nil end,
+ }
+
+ local x_plus_y = a.x + a.nil_maker + a.y
+
+ eq({"x", "y"}, {x_plus_y()})
+ end)
+
+ it('allows overriding an action', function()
+ local a = transform_mod {
+ x = function() return "x" end,
+ y = function() return "y" end,
+ }
+
+ -- actions.file_goto_selection_edit:replace(...)
+ a.x:replace(function() return "foo" end)
+ eq("foo", a.x())
+
+ a._clear()
+ eq("x", a.x())
+ end)
+
+ it('enhance.pre', function()
+ local a = transform_mod {
+ x = function() return "x" end,
+ y = function() return "y" end,
+ }
+
+ local called_pre = false
+
+ a.y:enhance {
+ pre = function()
+ called_pre = true
+ end,
+ }
+ eq("y", a.y())
+ eq(true, called_pre)
+ end)
+
+ it('enhance.post', function()
+ local a = transform_mod {
+ x = function() return "x" end,
+ y = function() return "y" end,
+ }
+
+ local called_post = false
+
+ a.y:enhance {
+ post = function()
+ called_post = true
+ end,
+ }
+ eq("y", a.y())
+ eq(true, called_post)
+ end)
+
+ it('can call both', function()
+ local a = transform_mod {
+ x = function() return "x" end,
+ y = function() return "y" end,
+ }
+
+ local called_count = 0
+ local count_inc = function()
+ called_count = called_count + 1
+ end
+
+ a.y:enhance {
+ pre = count_inc,
+ post = count_inc,
+ }
+
+ eq("y", a.y())
+ eq(2, called_count)
+ end)
+
+ it('can call both even when combined', function()
+ local a = transform_mod {
+ x = function() return "x" end,
+ y = function() return "y" end,
+ }
+
+ local called_count = 0
+ local count_inc = function()
+ called_count = called_count + 1
+ end
+
+ a.y:enhance {
+ pre = count_inc,
+ post = count_inc,
+ }
+
+ a.x:enhance {
+ post = count_inc
+ }
+
+ local x_plus_y = a.x + a.y
+ x_plus_y()
+
+ eq(3, called_count)
+ end)
+
+ it('clears enhance', function()
+ local a = transform_mod {
+ x = function() return "x" end,
+ y = function() return "y" end,
+ }
+
+ local called_post = false
+
+ a.y:enhance {
+ post = function()
+ called_post = true
+ end,
+ }
+
+ a._clear()
+
+ eq("y", a.y())
+ eq(false, called_post)
+ end)
+
+ it('handles passing arguments', function()
+ local a = transform_mod {
+ x = function(bufnr) return string.format("bufnr: %s") end,
+ }
+
+ a.x:replace(function(bufnr) return string.format("modified: %s", bufnr) end)
+ eq("modified: 5", a.x(5))
+ end)
+end)