diff options
Diffstat (limited to 'lua/blink/cmp/fuzzy')
| -rw-r--r-- | lua/blink/cmp/fuzzy/download/files.lua | 181 | ||||
| -rw-r--r-- | lua/blink/cmp/fuzzy/download/git.lua | 70 | ||||
| -rw-r--r-- | lua/blink/cmp/fuzzy/download/init.lua | 170 | ||||
| -rw-r--r-- | lua/blink/cmp/fuzzy/download/system.lua | 74 | ||||
| -rw-r--r-- | lua/blink/cmp/fuzzy/frecency.rs | 153 | ||||
| -rw-r--r-- | lua/blink/cmp/fuzzy/fuzzy.rs | 183 | ||||
| -rw-r--r-- | lua/blink/cmp/fuzzy/init.lua | 118 | ||||
| -rw-r--r-- | lua/blink/cmp/fuzzy/keyword.rs | 84 | ||||
| -rw-r--r-- | lua/blink/cmp/fuzzy/lib.rs | 156 | ||||
| -rw-r--r-- | lua/blink/cmp/fuzzy/lsp_item.rs | 46 | ||||
| -rw-r--r-- | lua/blink/cmp/fuzzy/rust.lua | 20 | ||||
| -rw-r--r-- | lua/blink/cmp/fuzzy/sort.lua | 48 |
12 files changed, 1303 insertions, 0 deletions
diff --git a/lua/blink/cmp/fuzzy/download/files.lua b/lua/blink/cmp/fuzzy/download/files.lua new file mode 100644 index 0000000..28bbac5 --- /dev/null +++ b/lua/blink/cmp/fuzzy/download/files.lua @@ -0,0 +1,181 @@ +local async = require('blink.cmp.lib.async') +local utils = require('blink.cmp.lib.utils') +local system = require('blink.cmp.fuzzy.download.system') + +local function get_lib_extension() + if jit.os:lower() == 'mac' or jit.os:lower() == 'osx' then return '.dylib' end + if jit.os:lower() == 'windows' then return '.dll' end + return '.so' +end + +local current_file_dir = debug.getinfo(1).source:match('@?(.*/)') +local current_file_dir_parts = vim.split(current_file_dir, '/') +local root_dir = table.concat(utils.slice(current_file_dir_parts, 1, #current_file_dir_parts - 6), '/') +local lib_folder = root_dir .. '/target/release' +local lib_filename = 'libblink_cmp_fuzzy' .. get_lib_extension() +local lib_path = lib_folder .. '/' .. lib_filename +local checksum_filename = lib_filename .. '.sha256' +local checksum_path = lib_path .. '.sha256' +local version_path = lib_folder .. '/version' + +local files = { + get_lib_extension = get_lib_extension, + root_dir = root_dir, + lib_folder = lib_folder, + lib_filename = lib_filename, + lib_path = lib_path, + checksum_path = checksum_path, + checksum_filename = checksum_filename, + version_path = version_path, +} + +--- Checksums --- + +function files.get_checksum() + return files.read_file(files.checksum_path):map(function(checksum) return vim.split(checksum, ' ')[1] end) +end + +function files.get_checksum_for_file(path) + return async.task.new(function(resolve, reject) + local os = system.get_info() + local args + if os == 'linux' then + args = { 'sha256sum', path } + elseif os == 'mac' or os == 'osx' then + args = { 'shasum', '-a', '256', path } + elseif os == 'windows' then + args = { 'certutil', '-hashfile', path, 'SHA256' } + end + + vim.system(args, {}, function(out) + if out.code ~= 0 then return reject('Failed to calculate checksum of pre-built binary: ' .. out.stderr) end + + local stdout = out.stdout or '' + if os == 'windows' then stdout = vim.split(stdout, '\r\n')[2] end + -- We get an output like 'sha256sum filename' on most systems, so we grab just the checksum + return resolve(vim.split(stdout, ' ')[1]) + end) + end) +end + +function files.verify_checksum() + return async.task + .await_all({ files.get_checksum(), files.get_checksum_for_file(files.lib_path) }) + :map(function(checksums) + assert(#checksums == 2, 'Expected 2 checksums, got ' .. #checksums) + assert(checksums[1] and checksums[2], 'Expected checksums to be non-nil') + assert( + checksums[1] == checksums[2], + 'Checksum of pre-built binary does not match. Expected "' .. checksums[1] .. '", got "' .. checksums[2] .. '"' + ) + end) +end + +--- Prebuilt binary --- + +function files.get_version() + return files + .read_file(files.version_path) + :map(function(version) + if #version == 40 then + return { sha = version } + else + return { tag = version } + end + end) + :catch(function() return {} end) +end + +--- @param version string +--- @return blink.cmp.Task +function files.set_version(version) + return files + .create_dir(files.root_dir .. '/target') + :map(function() return files.create_dir(files.lib_folder) end) + :map(function() return files.write_file(files.version_path, version) end) +end + +--- Filesystem helpers --- + +--- @param path string +--- @return blink.cmp.Task +function files.read_file(path) + return async.task.new(function(resolve, reject) + vim.uv.fs_open(path, 'r', 438, function(open_err, fd) + if open_err or fd == nil then return reject(open_err or 'Unknown error') end + vim.uv.fs_read(fd, 1024, 0, function(read_err, data) + vim.uv.fs_close(fd, function() end) + if read_err or data == nil then return reject(read_err or 'Unknown error') end + return resolve(data) + end) + end) + end) +end + +--- @param path string +--- @param data string +--- @return blink.cmp.Task +function files.write_file(path, data) + return async.task.new(function(resolve, reject) + vim.uv.fs_open(path, 'w', 438, function(open_err, fd) + if open_err or fd == nil then return reject(open_err or 'Unknown error') end + vim.uv.fs_write(fd, data, 0, function(write_err) + vim.uv.fs_close(fd, function() end) + if write_err then return reject(write_err) end + return resolve() + end) + end) + end) +end + +--- @param path string +--- @return blink.cmp.Task +function files.exists(path) + return async.task.new(function(resolve) + vim.uv.fs_stat(path, function(err) resolve(not err) end) + end) +end + +--- @param path string +--- @return blink.cmp.Task +function files.stat(path) + return async.task.new(function(resolve, reject) + vim.uv.fs_stat(path, function(err, stat) + if err then return reject(err) end + resolve(stat) + end) + end) +end + +--- @param path string +--- @return blink.cmp.Task +function files.create_dir(path) + return files + .stat(path) + :map(function(stat) return stat.type == 'directory' end) + :catch(function() return false end) + :map(function(exists) + if exists then return end + + return async.task.new(function(resolve, reject) + vim.uv.fs_mkdir(path, 511, function(err) + if err then return reject(err) end + resolve() + end) + end) + end) +end + +--- Renames a file +--- @param old_path string +--- @param new_path string +function files.rename(old_path, new_path) + return async.task.new(function(resolve, reject) + vim.uv.fs_rename(old_path, new_path, function(err) + if err then return reject(err) end + resolve() + end) + end) +end + +return files diff --git a/lua/blink/cmp/fuzzy/download/git.lua b/lua/blink/cmp/fuzzy/download/git.lua new file mode 100644 index 0000000..63a7646 --- /dev/null +++ b/lua/blink/cmp/fuzzy/download/git.lua @@ -0,0 +1,70 @@ +local async = require('blink.cmp.lib.async') +local files = require('blink.cmp.fuzzy.download.files') +local git = {} + +function git.get_version() + return async.task.await_all({ git.get_tag(), git.get_sha() }):map( + function(results) + return { + tag = results[1], + sha = results[2], + } + end + ) +end + +function git.get_tag() + return async.task.new(function(resolve, reject) + -- If repo_dir is nil, no git reposiory is found, similar to `out.code == 128` + local repo_dir = vim.fs.root(files.root_dir, '.git') + if not repo_dir then resolve() end + + vim.system({ + 'git', + '--git-dir', + vim.fs.joinpath(repo_dir, '.git'), + '--work-tree', + repo_dir, + 'describe', + '--tags', + '--exact-match', + }, { cwd = files.root_dir }, function(out) + if out.code == 128 then return resolve() end + if out.code ~= 0 then + return reject('While getting git tag, git exited with code ' .. out.code .. ': ' .. out.stderr) + end + + local lines = vim.split(out.stdout, '\n') + if not lines[1] then return reject('Expected atleast 1 line of output from git describe') end + return resolve(lines[1]) + end) + end) +end + +function git.get_sha() + return async.task.new(function(resolve, reject) + -- If repo_dir is nil, no git reposiory is found, similar to `out.code == 128` + local repo_dir = vim.fs.root(files.root_dir, '.git') + if not repo_dir then resolve() end + + vim.system({ + 'git', + '--git-dir', + vim.fs.joinpath(repo_dir, '.git'), + '--work-tree', + repo_dir, + 'rev-parse', + 'HEAD', + }, { cwd = files.root_dir }, function(out) + if out.code == 128 then return resolve() end + if out.code ~= 0 then + return reject('While getting git sha, git exited with code ' .. out.code .. ': ' .. out.stderr) + end + + local sha = vim.split(out.stdout, '\n')[1] + return resolve(sha) + end) + end) +end + +return git diff --git a/lua/blink/cmp/fuzzy/download/init.lua b/lua/blink/cmp/fuzzy/download/init.lua new file mode 100644 index 0000000..085becd --- /dev/null +++ b/lua/blink/cmp/fuzzy/download/init.lua @@ -0,0 +1,170 @@ +local download_config = require('blink.cmp.config').fuzzy.prebuilt_binaries +local async = require('blink.cmp.lib.async') +local git = require('blink.cmp.fuzzy.download.git') +local files = require('blink.cmp.fuzzy.download.files') +local system = require('blink.cmp.fuzzy.download.system') + +local download = {} + +--- @param callback fun(err: string | nil) +function download.ensure_downloaded(callback) + callback = vim.schedule_wrap(callback) + + if not download_config.download then return callback() end + + async.task + .await_all({ git.get_version(), files.get_version() }) + :map(function(results) + return { + git = results[1], + current = results[2], + } + end) + :map(function(version) + local target_git_tag = download_config.force_version or version.git.tag + + -- not built locally, not on a git tag, error + assert( + version.current.sha ~= nil or target_git_tag ~= nil, + "\nDetected an out of date or missing fuzzy matching library. Can't download from github due to not being on a git tag and no `fuzzy.prebuilt_binaries.force_version` is set." + .. '\nEither run `cargo build --release` via your package manager, switch to a git tag, or set `fuzzy.prebuilt_binaries.force_version` in config. ' + .. 'See the docs for more info.' + ) + + -- built locally, ignore + if + not download_config.force_version + and ( + version.current.sha == version.git.sha + or version.current.sha ~= nil and download_config.ignore_version_mismatch + ) + then + return + end + + -- built locally but outdated and not on a git tag, error + if + not download_config.force_version + and version.current.sha ~= nil + and version.current.sha ~= version.git.sha + then + assert( + target_git_tag or download_config.ignore_version_mismatch, + "\nFound an outdated version of the fuzzy matching library, but can't download from github due to not being on a git tag. " + .. '\n!! FOR DEVELOPERS !!, set `fuzzy.prebuilt_binaries.ignore_version_mismatch = true` in config. ' + .. '\n!! FOR USERS !!, either run `cargo build --release` via your package manager, switch to a git tag, or set `fuzzy.prebuilt_binaries.force_version` in config. ' + .. 'See the docs for more info.' + ) + if not download_config.ignore_version_mismatch then + vim.schedule( + function() + vim.notify( + '[blink.cmp]: Found an outdated version of the fuzzy matching library built locally', + vim.log.levels.INFO, + { title = 'blink.cmp' } + ) + end + ) + end + end + + -- already downloaded and the correct version, just verify the checksum, and re-download if checksum fails + if version.current.tag ~= nil and version.current.tag == target_git_tag then + return files.verify_checksum():catch(function(err) + vim.schedule(function() + vim.notify(err, vim.log.levels.WARN, { title = 'blink.cmp' }) + vim.notify( + '[blink.cmp]: Pre-built binary failed checksum verification, re-downloading', + vim.log.levels.WARN, + { title = 'blink.cmp' } + ) + end) + return download.download(target_git_tag) + end) + end + + -- unknown state + if not target_git_tag then error('Unknown error while getting pre-built binary. Consider re-installing') end + + -- download as per usual + vim.schedule( + function() vim.notify('[blink.cmp]: Downloading pre-built binary', vim.log.levels.INFO, { title = 'blink.cmp' }) end + ) + return download.download(target_git_tag) + end) + :map(function() callback() end) + :catch(function(err) callback(err) end) +end + +function download.download(version) + -- NOTE: we set the version to 'v0.0.0' to avoid a failure causing the pre-built binary being marked as locally built + return files + .set_version('v0.0.0') + :map(function() return download.from_github(version) end) + :map(function() return files.verify_checksum() end) + :map(function() return files.set_version(version) end) +end + +--- @param tag string +--- @return blink.cmp.Task +function download.from_github(tag) + return system.get_triple():map(function(system_triple) + if not system_triple then + return error( + 'Your system is not supported by pre-built binaries. You must run cargo build --release via your package manager with rust nightly. See the README for more info.' + ) + end + + local base_url = 'https://github.com/saghen/blink.cmp/releases/download/' .. tag .. '/' + local library_url = base_url .. system_triple .. files.get_lib_extension() + local checksum_url = base_url .. system_triple .. files.get_lib_extension() .. '.sha256' + + return async + .task + .await_all({ + download.download_file(library_url, files.lib_filename .. '.tmp'), + download.download_file(checksum_url, files.checksum_filename), + }) + -- Mac caches the library in the kernel, so updating in place causes a crash + -- We instead write to a temporary file and rename it, as mentioned in: + -- https://developer.apple.com/documentation/security/updating-mac-software + :map( + function() + return files.rename( + files.lib_folder .. '/' .. files.lib_filename .. '.tmp', + files.lib_folder .. '/' .. files.lib_filename + ) + end + ) + end) +end + +--- @param url string +--- @param filename string +--- @return blink.cmp.Task +function download.download_file(url, filename) + return async.task.new(function(resolve, reject) + local args = { 'curl' } + vim.list_extend(args, download_config.extra_curl_args) + vim.list_extend(args, { + '--fail', -- Fail on 4xx/5xx + '--location', -- Follow redirects + '--silent', -- Don't show progress + '--show-error', -- Show errors, even though we're using --silent + '--create-dirs', + '--output', + files.lib_folder .. '/' .. filename, + url, + }) + + vim.system(args, {}, function(out) + if out.code ~= 0 then + reject('Failed to download ' .. filename .. 'for pre-built binaries: ' .. out.stderr) + else + resolve() + end + end) + end) +end + +return download diff --git a/lua/blink/cmp/fuzzy/download/system.lua b/lua/blink/cmp/fuzzy/download/system.lua new file mode 100644 index 0000000..7b83a0f --- /dev/null +++ b/lua/blink/cmp/fuzzy/download/system.lua @@ -0,0 +1,74 @@ +local download_config = require('blink.cmp.config').fuzzy.prebuilt_binaries +local async = require('blink.cmp.lib.async') +local system = {} + +system.triples = { + mac = { + arm = 'aarch64-apple-darwin', + x64 = 'x86_64-apple-darwin', + }, + windows = { + x64 = 'x86_64-pc-windows-msvc', + }, + linux = { + android = 'aarch64-linux-android', + arm = function(libc) return 'aarch64-unknown-linux-' .. libc end, + x64 = function(libc) return 'x86_64-unknown-linux-' .. libc end, + }, +} + +--- Gets the operating system and architecture of the current system +--- @return string, string +function system.get_info() + local os = jit.os:lower() + if os == 'osx' then os = 'mac' end + local arch = jit.arch:lower():match('arm') and 'arm' or jit.arch:lower():match('x64') and 'x64' or nil + return os, arch +end + +--- Gets the system triple for the current system +--- I.e. `x86_64-unknown-linux-gnu` or `aarch64-apple-darwin` +--- @return blink.cmp.Task +function system.get_triple() + return async.task.new(function(resolve) + if download_config.force_system_triple then return resolve(download_config.force_system_triple) end + + local os, arch = system.get_info() + local triples = system.triples[os] + + if os == 'linux' then + if vim.fn.has('android') == 1 then return resolve(triples.android) end + + vim.uv.fs_stat('/etc/alpine-release', function(err, is_alpine) + local libc = (not err and is_alpine) and 'musl' or 'gnu' + local triple = triples[arch] + return resolve(triple and type(triple) == 'function' and triple(libc) or triple) + end) + else + return resolve(triples[arch]) + end + end) +end + +--- Same as `system.get_triple` but synchronous +--- @see system.get_triple +--- @return string | nil +function system.get_triple_sync() + if download_config.force_system_triple then return download_config.force_system_triple end + + local os, arch = system.get_info() + local triples = system.triples[os] + + if os == 'linux' then + if vim.fn.has('android') == 1 then return triples.android end + + local success, is_alpine = pcall(vim.uv.fs_stat, '/etc/alpine-release') + local libc = (success and is_alpine) and 'musl' or 'gnu' + local triple = triples[arch] + return triple and type(triple) == 'function' and triple(libc) or triple + else + return triples[arch] + end +end + +return system diff --git a/lua/blink/cmp/fuzzy/frecency.rs b/lua/blink/cmp/fuzzy/frecency.rs new file mode 100644 index 0000000..a672598 --- /dev/null +++ b/lua/blink/cmp/fuzzy/frecency.rs @@ -0,0 +1,153 @@ +use crate::lsp_item::LspItem; +use heed::{types::*, EnvFlags}; +use heed::{Database, Env, EnvOpenOptions}; +use mlua::Result as LuaResult; +use serde::{Deserialize, Serialize}; +use std::fs; +use std::time::{SystemTime, UNIX_EPOCH}; + +#[derive(Clone, Serialize, Deserialize)] +struct CompletionItemKey { + label: String, + kind: u32, + source_id: String, +} + +impl From<&LspItem> for CompletionItemKey { + fn from(item: &LspItem) -> Self { + Self { + label: item.label.clone(), + kind: item.kind, + source_id: item.source_id.clone(), + } + } +} + +#[derive(Debug)] +pub struct FrecencyTracker { + env: Env, + db: Database<SerdeBincode<CompletionItemKey>, SerdeBincode<Vec<u64>>>, + access_thresholds: Vec<(f64, u64)>, +} + +impl FrecencyTracker { + pub fn new(db_path: &str, use_unsafe_no_lock: bool) -> LuaResult<Self> { + fs::create_dir_all(db_path).map_err(|err| { + mlua::Error::RuntimeError( + "Failed to create frecency database directory: ".to_string() + &err.to_string(), + ) + })?; + let env = unsafe { + let mut opts = EnvOpenOptions::new(); + if use_unsafe_no_lock { + opts.flags(EnvFlags::NO_LOCK | EnvFlags::NO_SYNC | EnvFlags::NO_META_SYNC); + } + opts.open(db_path).map_err(|err| { + mlua::Error::RuntimeError( + "Failed to open frecency database: ".to_string() + &err.to_string(), + ) + })? + }; + env.clear_stale_readers().map_err(|err| { + mlua::Error::RuntimeError( + "Failed to clear stale readers for frecency database: ".to_string() + + &err.to_string(), + ) + })?; + + // we will open the default unnamed database + let mut wtxn = env.write_txn().map_err(|err| { + mlua::Error::RuntimeError( + "Failed to open write transaction for frecency database: ".to_string() + + &err.to_string(), + ) + })?; + let db = env.create_database(&mut wtxn, None).map_err(|err| { + mlua::Error::RuntimeError( + "Failed to create frecency database: ".to_string() + &err.to_string(), + ) + })?; + + let access_thresholds = [ + (1., 1000 * 60 * 2), // 2 minutes + (0.2, 1000 * 60 * 60), // 1 hour + (0.1, 1000 * 60 * 60 * 24), // 1 day + (0.05, 1000 * 60 * 60 * 24 * 7), // 1 week + ] + .to_vec(); + + Ok(FrecencyTracker { + env: env.clone(), + db, + access_thresholds, + }) + } + + fn get_accesses(&self, item: &LspItem) -> LuaResult<Option<Vec<u64>>> { + let rtxn = self.env.read_txn().map_err(|err| { + mlua::Error::RuntimeError( + "Failed to start read transaction for frecency database: ".to_string() + + &err.to_string(), + ) + })?; + self.db + .get(&rtxn, &CompletionItemKey::from(item)) + .map_err(|err| { + mlua::Error::RuntimeError( + "Failed to read from frecency database: ".to_string() + &err.to_string(), + ) + }) + } + + fn get_now(&self) -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs() + } + + pub fn access(&mut self, item: &LspItem) -> LuaResult<()> { + let mut wtxn = self.env.write_txn().map_err(|err| { + mlua::Error::RuntimeError( + "Failed to start write transaction for frecency database: ".to_string() + + &err.to_string(), + ) + })?; + + let mut accesses = self.get_accesses(item)?.unwrap_or_default(); + accesses.push(self.get_now()); + + self.db + .put(&mut wtxn, &CompletionItemKey::from(item), &accesses) + .map_err(|err| { + mlua::Error::RuntimeError( + "Failed to write to frecency database: ".to_string() + &err.to_string(), + ) + })?; + + wtxn.commit().map_err(|err| { + mlua::Error::RuntimeError( + "Failed to commit write transaction for frecency database: ".to_string() + + &err.to_string(), + ) + })?; + + Ok(()) + } + + pub fn get_score(&self, item: &LspItem) -> i64 { + let accesses = self.get_accesses(item).unwrap_or(None).unwrap_or_default(); + let now = self.get_now(); + let mut score = 0.0; + 'outer: for access in &accesses { + let duration_since = now - access; + for (rank, threshold_duration_since) in &self.access_thresholds { + if duration_since < *threshold_duration_since { + score += rank; + continue 'outer; + } + } + } + score.min(4.) as i64 + } +} diff --git a/lua/blink/cmp/fuzzy/fuzzy.rs b/lua/blink/cmp/fuzzy/fuzzy.rs new file mode 100644 index 0000000..b09bb99 --- /dev/null +++ b/lua/blink/cmp/fuzzy/fuzzy.rs @@ -0,0 +1,183 @@ +// TODO: refactor this heresy + +use crate::frecency::FrecencyTracker; +use crate::keyword; +use crate::lsp_item::LspItem; +use mlua::prelude::*; +use mlua::FromLua; +use mlua::Lua; +use std::collections::HashMap; +use std::collections::HashSet; + +#[derive(Clone, Hash)] +pub struct FuzzyOptions { + match_suffix: bool, + use_typo_resistance: bool, + use_frecency: bool, + use_proximity: bool, + nearby_words: Option<Vec<String>>, + min_score: u16, +} + +impl FromLua for FuzzyOptions { + fn from_lua(value: LuaValue, _lua: &'_ Lua) -> LuaResult<Self> { + if let Some(tab) = value.as_table() { + let match_suffix: bool = tab.get("match_suffix").unwrap_or_default(); + let use_typo_resistance: bool = tab.get("use_typo_resistance").unwrap_or_default(); + let use_frecency: bool = tab.get("use_frecency").unwrap_or_default(); + let use_proximity: bool = tab.get("use_proximity").unwrap_or_default(); + let nearby_words: Option<Vec<String>> = tab.get("nearby_words").ok(); + let min_score: u16 = tab.get("min_score").unwrap_or_default(); + + Ok(FuzzyOptions { + match_suffix, + use_typo_resistance, + use_frecency, + use_proximity, + nearby_words, + min_score, + }) + } else { + Err(mlua::Error::FromLuaConversionError { + from: "LuaValue", + to: "FuzzyOptions".to_string(), + message: None, + }) + } + } +} + +fn group_by_needle( + line: &str, + cursor_col: usize, + haystack: &[String], + match_suffix: bool, +) -> HashMap<String, Vec<(usize, String)>> { + let mut items_by_needle: HashMap<String, Vec<(usize, String)>> = HashMap::new(); + for (idx, item_text) in haystack.iter().enumerate() { + let needle = keyword::guess_keyword_from_item(item_text, line, cursor_col, match_suffix); + let entry = items_by_needle.entry(needle).or_default(); + entry.push((idx, item_text.to_string())); + } + items_by_needle +} + +pub fn fuzzy( + line: &str, + cursor_col: usize, + haystack: &[LspItem], + frecency: &FrecencyTracker, + opts: FuzzyOptions, +) -> (Vec<i32>, Vec<u32>) { + let haystack_labels = haystack + .iter() + .map(|s| s.filter_text.clone().unwrap_or(s.label.clone())) + .collect::<Vec<_>>(); + let options = frizbee::Options { + prefilter: !opts.use_typo_resistance, + min_score: opts.min_score, + stable_sort: false, + ..Default::default() + }; + + // Items may have different fuzzy matching ranges, so we split them up by needle + let mut matches = group_by_needle(line, cursor_col, &haystack_labels, opts.match_suffix) + .into_iter() + // Match on each needle and combine + .flat_map(|(needle, haystack)| { + let mut matches = frizbee::match_list( + &needle, + &haystack + .iter() + .map(|(_, str)| str.as_str()) + .collect::<Vec<_>>(), + options, + ); + for mtch in matches.iter_mut() { + mtch.index_in_haystack = haystack[mtch.index_in_haystack].0; + } + matches + }) + .collect::<Vec<_>>(); + + matches.sort_by_key(|mtch| mtch.index_in_haystack); + for (idx, mtch) in matches.iter_mut().enumerate() { + mtch.index = idx; + } + + // Get the score for each match, adding score_offset, frecency and proximity bonus + let nearby_words: HashSet<String> = HashSet::from_iter(opts.nearby_words.unwrap_or_default()); + let match_scores = matches + .iter() + .map(|mtch| { + let frecency_score = if opts.use_frecency { + frecency.get_score(&haystack[mtch.index_in_haystack]) as i32 + } else { + 0 + }; + let nearby_words_score = if opts.use_proximity { + nearby_words + .get(&haystack_labels[mtch.index_in_haystack]) + .map(|_| 2) + .unwrap_or(0) + } else { + 0 + }; + let score_offset = haystack[mtch.index_in_haystack].score_offset; + + (mtch.score as i32) + frecency_score + nearby_words_score + score_offset + }) + .collect::<Vec<_>>(); + + // Find the highest score and filter out matches that are unreasonably lower than it + if opts.use_typo_resistance { + let max_score = matches.iter().map(|mtch| mtch.score).max().unwrap_or(0); + let secondary_min_score = max_score.max(16) - 16; + matches = matches + .into_iter() + .filter(|mtch| mtch.score >= secondary_min_score) + .collect::<Vec<_>>(); + } + + // Return scores and indices + ( + matches + .iter() + .map(|mtch| match_scores[mtch.index]) + .collect::<Vec<_>>(), + matches + .iter() + .map(|mtch| mtch.index_in_haystack as u32) + .collect::<Vec<_>>(), + ) +} + +pub fn fuzzy_matched_indices( + line: &str, + cursor_col: usize, + haystack: &[String], + match_suffix: bool, +) -> Vec<Vec<usize>> { + let mut matches = group_by_needle(line, cursor_col, haystack, match_suffix) + .into_iter() + .flat_map(|(needle, haystack)| { + frizbee::match_list_for_matched_indices( + &needle, + &haystack + .iter() + .map(|(_, str)| str.as_str()) + .collect::<Vec<_>>(), + ) + .into_iter() + .enumerate() + .map(|(idx, matched_indices)| (haystack[idx].0, matched_indices)) + .collect::<Vec<_>>() + }) + .collect::<Vec<_>>(); + matches.sort_by_key(|mtch| mtch.0); + + matches + .into_iter() + .map(|(_, matched_indices)| matched_indices) + .collect::<Vec<_>>() +} diff --git a/lua/blink/cmp/fuzzy/init.lua b/lua/blink/cmp/fuzzy/init.lua new file mode 100644 index 0000000..ad4db03 --- /dev/null +++ b/lua/blink/cmp/fuzzy/init.lua @@ -0,0 +1,118 @@ +local config = require('blink.cmp.config') + +--- @class blink.cmp.Fuzzy +local fuzzy = { + rust = require('blink.cmp.fuzzy.rust'), + haystacks_by_provider_cache = {}, + has_init_db = false, +} + +function fuzzy.init_db() + if fuzzy.has_init_db then return end + + fuzzy.rust.init_db(vim.fn.stdpath('data') .. '/blink/cmp/fuzzy.db', config.use_unsafe_no_lock) + + vim.api.nvim_create_autocmd('VimLeavePre', { + callback = fuzzy.rust.destroy_db, + }) + + fuzzy.has_init_db = true +end + +---@param item blink.cmp.CompletionItem +function fuzzy.access(item) + fuzzy.init_db() + + -- writing to the db takes ~10ms, so schedule writes in another thread + vim.uv + .new_work(function(itm, cpath) + package.cpath = cpath + require('blink.cmp.fuzzy.rust').access(vim.mpack.decode(itm)) + end, function() end) + :queue(vim.mpack.encode(item), package.cpath) +end + +---@param lines string +function fuzzy.get_words(lines) return fuzzy.rust.get_words(lines) end + +--- @param line string +--- @param cursor_col number +--- @param haystack string[] +--- @param range blink.cmp.CompletionKeywordRange +function fuzzy.fuzzy_matched_indices(line, cursor_col, haystack, range) + return fuzzy.rust.fuzzy_matched_indices(line, cursor_col, haystack, range == 'full') +end + +--- @param line string +--- @param cursor_col number +--- @param haystacks_by_provider table<string, blink.cmp.CompletionItem[]> +--- @param range blink.cmp.CompletionKeywordRange +--- @return blink.cmp.CompletionItem[] +function fuzzy.fuzzy(line, cursor_col, haystacks_by_provider, range) + fuzzy.init_db() + + for provider_id, haystack in pairs(haystacks_by_provider) do + -- set the provider items once since Lua <-> Rust takes the majority of the time + if fuzzy.haystacks_by_provider_cache[provider_id] ~= haystack then + fuzzy.haystacks_by_provider_cache[provider_id] = haystack + fuzzy.rust.set_provider_items(provider_id, haystack) + end + end + + -- get the nearby words + local cursor_row = vim.api.nvim_win_get_cursor(0)[1] + local start_row = math.max(0, cursor_row - 30) + local end_row = math.min(cursor_row + 30, vim.api.nvim_buf_line_count(0)) + local nearby_text = table.concat(vim.api.nvim_buf_get_lines(0, start_row, end_row, false), '\n') + local nearby_words = #nearby_text < 10000 and fuzzy.rust.get_words(nearby_text) or {} + + local keyword_start_col, keyword_end_col = + require('blink.cmp.fuzzy').get_keyword_range(line, cursor_col, config.completion.keyword.range) + local keyword_length = keyword_end_col - keyword_start_col + + local filtered_items = {} + for provider_id, haystack in pairs(haystacks_by_provider) do + -- perform fuzzy search + local scores, matched_indices = fuzzy.rust.fuzzy(line, cursor_col, provider_id, { + -- each matching char is worth 7 points (+ 1 for matching capitalization) + -- and it receives a bonus for capitalization, delimiter and prefix + -- so this should generally be good + -- TODO: make this configurable + -- TODO: instead of a min score, set X number of allowed typos + min_score = config.fuzzy.use_typo_resistance and (6 * keyword_length) or 0, + use_typo_resistance = config.fuzzy.use_typo_resistance, + use_frecency = config.fuzzy.use_frecency and keyword_length > 0, + use_proximity = config.fuzzy.use_proximity and keyword_length > 0, + sorts = config.fuzzy.sorts, + nearby_words = nearby_words, + match_suffix = range == 'full', + }) + + for idx, item_index in ipairs(matched_indices) do + local item = haystack[item_index + 1] + item.score = scores[idx] + table.insert(filtered_items, item) + end + end + + return require('blink.cmp.fuzzy.sort').sort(filtered_items, config.fuzzy.sorts) +end + +--- @param line string +--- @param col number +--- @param range? blink.cmp.CompletionKeywordRange +--- @return number, number +function fuzzy.get_keyword_range(line, col, range) + return require('blink.cmp.fuzzy.rust').get_keyword_range(line, col, range == 'full') +end + +--- @param item blink.cmp.CompletionItem +--- @param line string +--- @param col number +--- @param range blink.cmp.CompletionKeywordRange +--- @return number, number +function fuzzy.guess_edit_range(item, line, col, range) + return require('blink.cmp.fuzzy.rust').guess_edit_range(item, line, col, range == 'full') +end + +return fuzzy diff --git a/lua/blink/cmp/fuzzy/keyword.rs b/lua/blink/cmp/fuzzy/keyword.rs new file mode 100644 index 0000000..13d5020 --- /dev/null +++ b/lua/blink/cmp/fuzzy/keyword.rs @@ -0,0 +1,84 @@ +use lazy_static::lazy_static; +use regex::Regex; + +lazy_static! { + static ref BACKWARD_REGEX: Regex = Regex::new(r"[\p{L}0-9_][\p{L}0-9_\\-]*$").unwrap(); + static ref FORWARD_REGEX: Regex = Regex::new(r"^[\p{L}0-9_\\-]+").unwrap(); +} + +/// Given a line and cursor position, returns the start and end indices of the keyword +pub fn get_keyword_range(line: &str, col: usize, match_suffix: bool) -> (usize, usize) { + let before_match_start = BACKWARD_REGEX + .find(&line[0..col.min(line.len())]) + .map(|m| m.start()); + if !match_suffix { + return (before_match_start.unwrap_or(col), col); + } + + let after_match_end = FORWARD_REGEX + .find(&line[col.min(line.len())..]) + .map(|m| m.end() + col); + ( + before_match_start.unwrap_or(col), + after_match_end.unwrap_or(col), + ) +} + +/// Given a string, guesses the start and end indices in the line for the specific item +/// 1. Get the keyword range (alphanumeric, underscore, hyphen) on the line and end of the item +/// text +/// 2. Check if the suffix of the item text matches the suffix of the line text, if so, include the +/// suffix in the range +/// +/// Example: +/// line: example/str/trim +/// item: str/trim +/// matches on: str/trim +/// +/// line: example/trim +/// item: str/trim +/// matches on: trim +/// +/// TODO: +/// line: ' +/// item: 'tabline' +/// matches on: ' +pub fn guess_keyword_range_from_item( + item_text: &str, + line: &str, + cursor_col: usize, + match_suffix: bool, +) -> (usize, usize) { + let line_range = get_keyword_range(line, cursor_col, match_suffix); + let text_range = get_keyword_range(item_text, item_text.len(), false); + + let line_prefix = line.chars().take(line_range.0).collect::<String>(); + let text_prefix = item_text.chars().take(text_range.0).collect::<String>(); + if line_prefix.ends_with(&text_prefix) { + return (line_range.0 - text_prefix.len(), line_range.1); + } + + line_range +} + +pub fn guess_keyword_from_item( + item_text: &str, + line: &str, + cursor_col: usize, + match_suffix: bool, +) -> String { + let (start, end) = guess_keyword_range_from_item(item_text, line, cursor_col, match_suffix); + line[start..end].to_string() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_get_keyword_range_unicode() { + let line = "'вest'"; + let col = line.len() - 1; + assert_eq!(get_keyword_range(line, col, false), (1, line.len() - 1)); + } +} diff --git a/lua/blink/cmp/fuzzy/lib.rs b/lua/blink/cmp/fuzzy/lib.rs new file mode 100644 index 0000000..99b74ad --- /dev/null +++ b/lua/blink/cmp/fuzzy/lib.rs @@ -0,0 +1,156 @@ +use crate::frecency::FrecencyTracker; +use crate::fuzzy::FuzzyOptions; +use crate::lsp_item::LspItem; +use lazy_static::lazy_static; +use mlua::prelude::*; +use regex::Regex; +use std::collections::{HashMap, HashSet}; +use std::sync::RwLock; + +mod frecency; +mod fuzzy; +mod keyword; +mod lsp_item; + +lazy_static! { + static ref REGEX: Regex = Regex::new(r"\p{L}[\p{L}0-9_\\-]{2,}").unwrap(); + static ref FRECENCY: RwLock<Option<FrecencyTracker>> = RwLock::new(None); + static ref HAYSTACKS_BY_PROVIDER: RwLock<HashMap<String, Vec<LspItem>>> = + RwLock::new(HashMap::new()); +} + +pub fn init_db(_: &Lua, (db_path, use_unsafe_no_lock): (String, bool)) -> LuaResult<bool> { + let mut frecency = FRECENCY.write().map_err(|_| { + mlua::Error::RuntimeError("Failed to acquire lock for frecency".to_string()) + })?; + if frecency.is_some() { + return Ok(false); + } + *frecency = Some(FrecencyTracker::new(&db_path, use_unsafe_no_lock)?); + Ok(true) +} + +pub fn destroy_db(_: &Lua, _: ()) -> LuaResult<bool> { + let frecency = FRECENCY.write().map_err(|_| { + mlua::Error::RuntimeError("Failed to acquire lock for frecency".to_string()) + })?; + drop(frecency); + + let mut frecency = FRECENCY.write().map_err(|_| { + mlua::Error::RuntimeError("Failed to acquire lock for frecency".to_string()) + })?; + *frecency = None; + + Ok(true) +} + +pub fn access(_: &Lua, item: LspItem) -> LuaResult<bool> { + let mut frecency_handle = FRECENCY.write().map_err(|_| { + mlua::Error::RuntimeError("Failed to acquire lock for frecency".to_string()) + })?; + let frecency = frecency_handle.as_mut().ok_or_else(|| { + mlua::Error::RuntimeError("Attempted to use frencecy before initialization".to_string()) + })?; + frecency.access(&item)?; + Ok(true) +} + +pub fn set_provider_items( + _: &Lua, + (provider_id, items): (String, Vec<LspItem>), +) -> LuaResult<bool> { + let mut items_by_provider = HAYSTACKS_BY_PROVIDER.write().map_err(|_| { + mlua::Error::RuntimeError("Failed to acquire lock for items by provider".to_string()) + })?; + items_by_provider.insert(provider_id, items); + Ok(true) +} + +pub fn fuzzy( + _lua: &Lua, + (line, cursor_col, provider_id, opts): (String, usize, String, FuzzyOptions), +) -> LuaResult<(Vec<i32>, Vec<u32>)> { + let mut frecency_handle = FRECENCY.write().map_err(|_| { + mlua::Error::RuntimeError("Failed to acquire lock for frecency".to_string()) + })?; + let frecency = frecency_handle.as_mut().ok_or_else(|| { + mlua::Error::RuntimeError("Attempted to use frencecy before initialization".to_string()) + })?; + + let haystacks_by_provider = HAYSTACKS_BY_PROVIDER.read().map_err(|_| { + mlua::Error::RuntimeError("Failed to acquire lock for items by provider".to_string()) + })?; + let haystack = haystacks_by_provider.get(&provider_id).ok_or_else(|| { + mlua::Error::RuntimeError(format!( + "Attempted to fuzzy match for provider {} before setting the provider's items", + provider_id + )) + })?; + + Ok(fuzzy::fuzzy(&line, cursor_col, haystack, frecency, opts)) +} + +pub fn fuzzy_matched_indices( + _lua: &Lua, + (line, cursor_col, haystack, match_suffix): (String, usize, Vec<String>, bool), +) -> LuaResult<Vec<Vec<usize>>> { + Ok(fuzzy::fuzzy_matched_indices( + &line, + cursor_col, + &haystack, + match_suffix, + )) +} + +pub fn get_keyword_range( + _lua: &Lua, + (line, col, match_suffix): (String, usize, bool), +) -> LuaResult<(usize, usize)> { + Ok(keyword::get_keyword_range(&line, col, match_suffix)) +} + +pub fn guess_edit_range( + _lua: &Lua, + (item, line, cursor_col, match_suffix): (LspItem, String, usize, bool), +) -> LuaResult<(usize, usize)> { + // TODO: take the max range from insert_text and filter_text + Ok(keyword::guess_keyword_range_from_item( + item.insert_text.as_ref().unwrap_or(&item.label), + &line, + cursor_col, + match_suffix, + )) +} + +pub fn get_words(_: &Lua, text: String) -> LuaResult<Vec<String>> { + Ok(REGEX + .find_iter(&text) + .map(|m| m.as_str().to_string()) + .filter(|s| s.len() < 512) + .collect::<HashSet<String>>() + .into_iter() + .collect()) +} + +// NOTE: skip_memory_check greatly improves performance +// https://github.com/mlua-rs/mlua/issues/318 +#[mlua::lua_module(skip_memory_check)] +fn blink_cmp_fuzzy(lua: &Lua) -> LuaResult<LuaTable> { + let exports = lua.create_table()?; + exports.set("init_db", lua.create_function(init_db)?)?; + exports.set("destroy_db", lua.create_function(destroy_db)?)?; + exports.set("access", lua.create_function(access)?)?; + exports.set( + "set_provider_items", + lua.create_function(set_provider_items)?, + )?; + exports.set("fuzzy", lua.create_function(fuzzy)?)?; + exports.set( + "fuzzy_matched_indices", + lua.create_function(fuzzy_matched_indices)?, + )?; + exports.set("get_keyword_range", lua.create_function(get_keyword_range)?)?; + exports.set("guess_edit_range", lua.create_function(guess_edit_range)?)?; + exports.set("get_words", lua.create_function(get_words)?)?; + Ok(exports) +} diff --git a/lua/blink/cmp/fuzzy/lsp_item.rs b/lua/blink/cmp/fuzzy/lsp_item.rs new file mode 100644 index 0000000..a24669e --- /dev/null +++ b/lua/blink/cmp/fuzzy/lsp_item.rs @@ -0,0 +1,46 @@ +use mlua::prelude::*; + +#[derive(Debug)] +pub struct LspItem { + pub label: String, + pub filter_text: Option<String>, + pub sort_text: Option<String>, + pub insert_text: Option<String>, + pub kind: u32, + pub score_offset: i32, + pub source_id: String, +} + +impl FromLua for LspItem { + fn from_lua(value: LuaValue, _: &Lua) -> LuaResult<Self> { + if let Some(tab) = value.as_table() { + let label = tab.get("label").unwrap_or_default(); + let filter_text = tab.get("filterText").ok(); + let sort_text = tab.get("sortText").ok(); + let insert_text = tab + .get::<LuaTable>("textEdit") + .and_then(|text_edit| text_edit.get("newText")) + .ok() + .or_else(|| tab.get("insertText").ok()); + let kind = tab.get("kind").unwrap_or_default(); + let score_offset = tab.get("score_offset").unwrap_or(0); + let source_id = tab.get("source_id").unwrap_or_default(); + + Ok(LspItem { + label, + filter_text, + sort_text, + insert_text, + kind, + score_offset, + source_id, + }) + } else { + Err(mlua::Error::FromLuaConversionError { + from: "LuaValue", + to: "LspItem".to_string(), + message: None, + }) + } + } +} diff --git a/lua/blink/cmp/fuzzy/rust.lua b/lua/blink/cmp/fuzzy/rust.lua new file mode 100644 index 0000000..e2374cf --- /dev/null +++ b/lua/blink/cmp/fuzzy/rust.lua @@ -0,0 +1,20 @@ +--- @return string +local function get_lib_extension() + if jit.os:lower() == 'mac' or jit.os:lower() == 'osx' then return '.dylib' end + if jit.os:lower() == 'windows' then return '.dll' end + return '.so' +end + +-- search for the lib in the /target/release directory with and without the lib prefix +-- since MSVC doesn't include the prefix +package.cpath = package.cpath + .. ';' + .. debug.getinfo(1).source:match('@?(.*/)') + .. '../../../../target/release/lib?' + .. get_lib_extension() + .. ';' + .. debug.getinfo(1).source:match('@?(.*/)') + .. '../../../../target/release/?' + .. get_lib_extension() + +return require('blink_cmp_fuzzy') diff --git a/lua/blink/cmp/fuzzy/sort.lua b/lua/blink/cmp/fuzzy/sort.lua new file mode 100644 index 0000000..ec32ac3 --- /dev/null +++ b/lua/blink/cmp/fuzzy/sort.lua @@ -0,0 +1,48 @@ +local sort = {} + +--- @param list blink.cmp.CompletionItem[] +--- @param funcs ("label" | "sort_text" | "kind" | "score" | blink.cmp.SortFunction)[] +--- @return blink.cmp.CompletionItem[] +function sort.sort(list, funcs) + local sorting_funcs = vim.tbl_map( + function(name_or_func) return type(name_or_func) == 'string' and sort[name_or_func] or name_or_func end, + funcs + ) + table.sort(list, function(a, b) + for _, sorting_func in ipairs(sorting_funcs) do + local result = sorting_func(a, b) + if result ~= nil then return result end + end + end) + return list +end + +function sort.score(a, b) + if a.score == b.score then return end + return a.score > b.score +end + +function sort.kind(a, b) + if a.kind == b.kind then return end + return a.kind < b.kind +end + +function sort.sort_text(a, b) + if a.sortText == b.sortText or a.sortText == nil or b.sortText == nil then return end + return a.sortText < b.sortText +end + +function sort.label(a, b) + local _, entry1_under = a.label:find('^_+') + local _, entry2_under = b.label:find('^_+') + entry1_under = entry1_under or 0 + entry2_under = entry2_under or 0 + if entry1_under > entry2_under then + return false + elseif entry1_under < entry2_under then + return true + end + return a.label < b.label +end + +return sort |
