diff options
Diffstat (limited to 'lua/blink/cmp/lib/async.lua')
| -rw-r--r-- | lua/blink/cmp/lib/async.lua | 217 |
1 files changed, 217 insertions, 0 deletions
diff --git a/lua/blink/cmp/lib/async.lua b/lua/blink/cmp/lib/async.lua new file mode 100644 index 0000000..b9c39ac --- /dev/null +++ b/lua/blink/cmp/lib/async.lua @@ -0,0 +1,217 @@ +--- Allows chaining of async operations without callback hell +--- +--- @class blink.cmp.Task +--- @field status blink.cmp.TaskStatus +--- @field result any | nil +--- @field error any | nil +--- @field new fun(fn: fun(resolve: fun(result: any), reject: fun(err: any))): blink.cmp.Task +--- +--- @field cancel fun(self: blink.cmp.Task) +--- @field map fun(self: blink.cmp.Task, fn: fun(result: any): blink.cmp.Task | any): blink.cmp.Task +--- @field catch fun(self: blink.cmp.Task, fn: fun(err: any): blink.cmp.Task | any): blink.cmp.Task +--- +--- @field on_completion fun(self: blink.cmp.Task, cb: fun(result: any)) +--- @field on_failure fun(self: blink.cmp.Task, cb: fun(err: any)) +--- @field on_cancel fun(self: blink.cmp.Task, cb: fun()) +--- @field _completion_cbs function[] +--- @field _failure_cbs function[] +--- @field _cancel_cbs function[] +--- @field _cancel? fun() +local task = { + __task = true, +} + +---@enum blink.cmp.TaskStatus +local STATUS = { + RUNNING = 1, + COMPLETED = 2, + FAILED = 3, + CANCELLED = 4, +} + +function task.new(fn) + local self = setmetatable({}, { __index = task }) + self.status = STATUS.RUNNING + self._completion_cbs = {} + self._failure_cbs = {} + self._cancel_cbs = {} + self.result = nil + self.error = nil + + local resolve = function(result) + if self.status ~= STATUS.RUNNING then return end + + self.status = STATUS.COMPLETED + self.result = result + + for _, cb in ipairs(self._completion_cbs) do + cb(result) + end + end + + local reject = function(err) + if self.status ~= STATUS.RUNNING then return end + + self.status = STATUS.FAILED + self.error = err + + for _, cb in ipairs(self._failure_cbs) do + cb(err) + end + end + + local success, cancel_fn_or_err = pcall(function() return fn(resolve, reject) end) + + if not success then + reject(cancel_fn_or_err) + elseif type(cancel_fn_or_err) == 'function' then + self._cancel = cancel_fn_or_err + end + + return self +end + +function task:cancel() + if self.status ~= STATUS.RUNNING then return end + self.status = STATUS.CANCELLED + + if self._cancel ~= nil then self._cancel() end + for _, cb in ipairs(self._cancel_cbs) do + cb() + end +end + +--- mappings + +function task:map(fn) + local chained_task + chained_task = task.new(function(resolve, reject) + self:on_completion(function(result) + local success, mapped_result = pcall(fn, result) + if not success then + reject(mapped_result) + return + end + + if type(mapped_result) == 'table' and mapped_result.__task then + mapped_result:on_completion(resolve) + mapped_result:on_failure(reject) + mapped_result:on_cancel(function() chained_task:cancel() end) + return + end + resolve(mapped_result) + end) + self:on_failure(reject) + self:on_cancel(function() chained_task:cancel() end) + return function() chained_task:cancel() end + end) + return chained_task +end + +function task:catch(fn) + local chained_task + chained_task = task.new(function(resolve, reject) + self:on_completion(resolve) + self:on_failure(function(err) + local success, mapped_err = pcall(fn, err) + if not success then + reject(mapped_err) + return + end + + if type(mapped_err) == 'table' and mapped_err.__task then + mapped_err:on_completion(resolve) + mapped_err:on_failure(reject) + mapped_err:on_cancel(function() chained_task:cancel() end) + return + end + resolve(mapped_err) + end) + self:on_cancel(function() chained_task:cancel() end) + return function() chained_task:cancel() end + end) + return chained_task +end + +--- events + +function task:on_completion(cb) + if self.status == STATUS.COMPLETED then + cb(self.result) + elseif self.status == STATUS.RUNNING then + table.insert(self._completion_cbs, cb) + end + return self +end + +function task:on_failure(cb) + if self.status == STATUS.FAILED then + cb(self.error) + elseif self.status == STATUS.RUNNING then + table.insert(self._failure_cbs, cb) + end + return self +end + +function task:on_cancel(cb) + if self.status == STATUS.CANCELLED then + cb() + elseif self.status == STATUS.RUNNING then + table.insert(self._cancel_cbs, cb) + end + return self +end + +--- utils + +function task.await_all(tasks) + if #tasks == 0 then + return task.new(function(resolve) resolve({}) end) + end + + local all_task + all_task = task.new(function(resolve, reject) + local results = {} + local has_resolved = {} + + local function resolve_if_completed() + -- we can't check #results directly because a table like + -- { [2] = { ... } } has a length of 2 + for i = 1, #tasks do + if has_resolved[i] == nil then return end + end + resolve(results) + end + + for idx, task in ipairs(tasks) do + task:on_completion(function(result) + results[idx] = result + has_resolved[idx] = true + resolve_if_completed() + end) + task:on_failure(function(err) + reject(err) + for _, task in ipairs(tasks) do + task:cancel() + end + end) + task:on_cancel(function() + for _, sub_task in ipairs(tasks) do + sub_task:cancel() + end + if all_task == nil then + vim.schedule(function() all_task:cancel() end) + else + all_task:cancel() + end + end) + end + end) + return all_task +end + +function task.empty() + return task.new(function(resolve) resolve() end) +end + +return { task = task, STATUS = STATUS } |
