summaryrefslogtreecommitdiff
path: root/lua/telescope/algos/fzy.lua
diff options
context:
space:
mode:
authorswarn <swarn@users.noreply.github.com>2020-10-20 19:14:07 -0500
committerGitHub <noreply@github.com>2020-10-20 20:14:07 -0400
commit7eda4e80f9fa0b16b2030e81528f17bdaf118041 (patch)
tree28194d027fcbafcbea399f9bdd983975efc968a6 /lua/telescope/algos/fzy.lua
parent14f834b754844f9fba49b3c032753529553507fb (diff)
feat: Add a sorter using the fzy algorithm (#184)
* Add a sorter using the fzy algorithm * Reformat fzy.lua Also, update author attribution. * Remove constansts from fzy module Replace a few of the useful ones with getter functions that make it clear they're not modifiable. * Change names of fzy constant getters * fixup: some small nit picks Co-authored-by: TJ DeVries <devries.timothyj@gmail.com>
Diffstat (limited to 'lua/telescope/algos/fzy.lua')
-rw-r--r--lua/telescope/algos/fzy.lua191
1 files changed, 191 insertions, 0 deletions
diff --git a/lua/telescope/algos/fzy.lua b/lua/telescope/algos/fzy.lua
new file mode 100644
index 0000000..2b4b8ec
--- /dev/null
+++ b/lua/telescope/algos/fzy.lua
@@ -0,0 +1,191 @@
+-- The fzy matching algorithm
+--
+-- by Seth Warn <https://github.com/swarn>
+-- a lua port of John Hawthorn's fzy <https://github.com/jhawthorn/fzy>
+--
+-- > fzy tries to find the result the user intended. It does this by favouring
+-- > matches on consecutive letters and starts of words. This allows matching
+-- > using acronyms or different parts of the path. - J Hawthorn
+
+local path = require('telescope.path')
+
+local SCORE_GAP_LEADING = -0.005 -- penalty per haystack character skipped before the first needle character matches
+local SCORE_GAP_TRAILING = -0.005 -- per-character gap penalty used while scoring the last needle character
+local SCORE_GAP_INNER = -0.01 -- per-character gap penalty used while scoring every other needle character
+local SCORE_MATCH_CONSECUTIVE = 1.0 -- bonus when a match directly follows the previous needle character's match
+local SCORE_MATCH_SLASH = 0.9 -- bonus for a match right after the path separator (or at the start of the haystack)
+local SCORE_MATCH_WORD = 0.8 -- bonus for a match right after "-", "_" or " "
+local SCORE_MATCH_CAPITAL = 0.7 -- bonus for an uppercase letter that follows a lowercase letter
+local SCORE_MATCH_DOT = 0.6 -- bonus for a match right after "."
+local SCORE_MAX = math.huge -- score reported by fzy.score for exact matches
+local SCORE_MIN = -math.huge -- score for empty/over-long input; also marks "no match here" cells in the score tables
+local MATCH_MAX_LENGTH = 1024 -- needles or haystacks longer than this are not scored
+
+local fzy = {} -- module table; returned at the end of the file
+
+--- Check whether every character of `needle` occurs in `haystack`, in order.
+--
+-- The comparison ignores case (both strings are lower-cased first) and the
+-- characters do not need to be adjacent.
+-- @param needle string: the characters to look for
+-- @param haystack string: the candidate text
+-- @return boolean: true when `needle` is a subsequence of `haystack`
+function fzy.has_match(needle, haystack)
+  local lowered_needle = string.lower(needle)
+  local lowered_haystack = string.lower(haystack)
+
+  -- `cursor` is the first haystack index that is still available for matching.
+  local cursor = 1
+  for idx = 1, #lowered_needle do
+    -- Plain (non-pattern) search so characters like "." or "%" are literal.
+    local found = lowered_haystack:find(lowered_needle:sub(idx, idx), cursor, true)
+    if found == nil then
+      return false
+    end
+    cursor = found + 1
+  end
+
+  return true
+end
+
+-- Returns the matched text (truthy) when `c` contains a lowercase letter
+-- according to the "%l" character class, or nil otherwise.
+local function is_lower(c)
+  local hit = c:match("%l")
+  return hit
+end
+
+-- Returns the matched text (truthy) when `c` contains an uppercase letter
+-- according to the "%u" character class, or nil otherwise.
+local function is_upper(c)
+  local hit = c:match("%u")
+  return hit
+end
+
+-- Build the per-position match bonus table for `haystack`.
+--
+-- The bonus at index i depends on the character just before it: a path
+-- separator, a word delimiter ("-", "_", " "), a dot, or a lowercase letter
+-- followed by an uppercase one. The first character is treated as if it
+-- followed a path separator. Every other position gets a bonus of 0.
+-- @param haystack string: the candidate text
+-- @return table: array of numbers, one entry per haystack character
+local function precompute_bonus(haystack)
+  local bonuses = {}
+  local separator = path.separator
+
+  local previous = separator
+  for idx = 1, string.len(haystack) do
+    local current = haystack:sub(idx, idx)
+
+    local bonus = 0
+    if previous == separator then
+      bonus = SCORE_MATCH_SLASH
+    elseif previous == "-" or previous == "_" or previous == " " then
+      bonus = SCORE_MATCH_WORD
+    elseif previous == "." then
+      bonus = SCORE_MATCH_DOT
+    elseif is_lower(previous) and is_upper(current) then
+      bonus = SCORE_MATCH_CAPITAL
+    end
+
+    bonuses[idx] = bonus
+    previous = current
+  end
+
+  return bonuses
+end
+
+-- Fill the two score tables used by `fzy.score` and `fzy.positions`.
+--
+-- Both tables are indexed [needle index][haystack index]:
+--   D[i][j] is the score of an alignment in which needle character i is
+--           matched at haystack position j (SCORE_MIN when that is impossible).
+--   M[i][j] is the running best score for needle characters 1..i over
+--           haystack positions 1..j, with gap penalties applied.
+-- @param needle string: the characters to look for
+-- @param haystack string: the candidate text
+-- @param D table: empty table, filled in place
+-- @param M table: empty table, filled in place
+local function compute(needle, haystack, D, M)
+  local bonus = precompute_bonus(haystack)
+  local needle_len = string.len(needle)
+  local haystack_len = string.len(haystack)
+  local needle_lc = string.lower(needle)
+  local haystack_lc = string.lower(haystack)
+
+  -- Split the lower-cased haystack into single characters once, so the inner
+  -- loop does a table lookup instead of a substring call.
+  local hay_chars = {}
+  for col = 1, haystack_len do
+    hay_chars[col] = haystack_lc:sub(col, col)
+  end
+
+  for row = 1, needle_len do
+    local d_row, m_row = {}, {}
+    D[row] = d_row
+    M[row] = m_row
+    -- Previous rows; nil for the first row, where they are never read.
+    local d_above, m_above = D[row - 1], M[row - 1]
+
+    local gap = SCORE_GAP_INNER
+    if row == needle_len then
+      gap = SCORE_GAP_TRAILING
+    end
+
+    local wanted = needle_lc:sub(row, row)
+    local running = SCORE_MIN
+
+    for col = 1, haystack_len do
+      local cell = SCORE_MIN
+      if wanted == hay_chars[col] then
+        if row == 1 then
+          -- First needle character: pay for the skipped leading characters.
+          cell = ((col - 1) * SCORE_GAP_LEADING) + bonus[col]
+        elseif col > 1 then
+          -- Either resume after a gap (position bonus) or extend a run of
+          -- consecutive matches.
+          cell = math.max(
+            m_above[col - 1] + bonus[col],
+            d_above[col - 1] + SCORE_MATCH_CONSECUTIVE
+          )
+        end
+        running = math.max(cell, running + gap)
+      else
+        running = running + gap
+      end
+      d_row[col] = cell
+      m_row[col] = running
+    end
+  end
+end
+
+--- Score how well `needle` fuzzy-matches `haystack`; higher is better.
+--
+-- Meant to be used on pairs for which `fzy.has_match(needle, haystack)` is
+-- true.
+-- @param needle string: the characters to look for
+-- @param haystack string: the candidate text
+-- @return number: SCORE_MIN when either string is empty or longer than
+--   MATCH_MAX_LENGTH, or when the strings have equal length but differ;
+--   SCORE_MAX when they are equal ignoring case; otherwise the best
+--   alignment score.
+function fzy.score(needle, haystack)
+  local n = string.len(needle)
+  local m = string.len(haystack)
+
+  if n == 0 or m == 0 or m > MATCH_MAX_LENGTH or n > MATCH_MAX_LENGTH then
+    return SCORE_MIN
+  elseif n == m then
+    -- With equal lengths the needle can only be a subsequence of the haystack
+    -- when both are the same string ignoring case. Check that explicitly so
+    -- unrelated strings of equal length are not reported as exact matches.
+    if string.lower(needle) == string.lower(haystack) then
+      return SCORE_MAX
+    end
+    return SCORE_MIN
+  else
+    local D = {}
+    local M = {}
+    compute(needle, haystack, D, M)
+    return M[n][m]
+  end
+end
+
+--- Find which haystack positions the best alignment assigns to the needle.
+--
+-- Meant to be used on pairs for which `fzy.has_match(needle, haystack)` is
+-- true.
+-- @param needle string: the characters to look for
+-- @param haystack string: the candidate text
+-- @return table: `positions[i]` is the 1-based haystack index matched by
+--   needle character i. Empty when either string is empty or longer than
+--   MATCH_MAX_LENGTH, or when the strings have equal length but differ.
+function fzy.positions(needle, haystack)
+  local n = string.len(needle)
+  local m = string.len(haystack)
+
+  if n == 0 or m == 0 or m > MATCH_MAX_LENGTH or n > MATCH_MAX_LENGTH then
+    return {}
+  elseif n == m then
+    -- With equal lengths only a case-insensitive exact match is possible;
+    -- anything else has no positions to report.
+    if string.lower(needle) ~= string.lower(haystack) then
+      return {}
+    end
+    local consecutive = {}
+    for i = 1, n do
+      consecutive[i] = i
+    end
+    return consecutive
+  end
+
+  local D = {}
+  local M = {}
+  compute(needle, haystack, D, M)
+
+  -- Walk the tables backwards from the last needle character, choosing for
+  -- each one the right-most haystack position that is part of the best score.
+  local positions = {}
+  local match_required = false
+  local j = m
+  for i = n, 1, -1 do
+    while j >= 1 do
+      -- A cell is taken when it is a real match and either it produced the
+      -- running best score, or the previously chosen cell was scored as a
+      -- consecutive match and therefore needs this adjacent position.
+      if D[i][j] ~= SCORE_MIN and (match_required or D[i][j] == M[i][j]) then
+        -- If this score came from extending a consecutive run, the previous
+        -- needle character must be matched at the position just before.
+        match_required = (i ~= 1) and (j ~= 1) and (
+          M[i][j] == D[i - 1][j - 1] + SCORE_MATCH_CONSECUTIVE)
+        positions[i] = j
+        j = j - 1
+        break
+      else
+        j = j - 1
+      end
+    end
+  end
+
+  return positions
+end
+
+-- The lowest possible score. When string a or b is empty or too long,
+-- `fzy.score(a, b) == fzy.get_score_min()`.
+function fzy.get_score_min()
+  local lowest = SCORE_MIN
+  return lowest
+end
+
+-- The highest possible score. For exact matches,
+-- `fzy.score(s, s) == fzy.get_score_max()`.
+function fzy.get_score_max()
+  local highest = SCORE_MAX
+  return highest
+end
+
+-- A lower bound for the scores of real matches. For all strings a and b that
+--  - do not score `fzy.get_score_min()` or `fzy.get_score_max()`, and
+--  - are matched, such that `fzy.has_match(a, b) == true`,
+-- `fzy.score(a, b) > fzy.get_score_floor()` will be true.
+function fzy.get_score_floor()
+  local gap_count = MATCH_MAX_LENGTH + 1
+  return gap_count * SCORE_GAP_INNER
+end
+
+return fzy