texcraft_stdext/algorithms/spellcheck.rs
1//! Spell checking using Levenshtein distance
2//!
//! This module contains a simple spell check facility in the form of a [find_close_words]
//! function. This function accepts a word and a dictionary, which is a list of valid words.
//! It finds the words in the dictionary that are closest to the original word, where "closest"
6//! is defined using
7//! [Levenshtein distance](https://en.wikipedia.org/wiki/Levenshtein_distance).
8//!
9//! The result of the search is a list of [WordDiffs](WordDiff).
10//! This data structure describes how the initial word can be changed to the dictionary word using a combination
11//! of keep, add, subtract and modify [DiffOps](DiffOp).
12//!
13//! ## Implementation notes
14//!
//! The Levenshtein distance calculation is implemented using dynamic programming and uses
//! O(n^2) time and O(n^2) space, where n bounds the lengths of the two words. Note the space
//! complexity could be O(n) if we only cared about the minimal distance between the strings,
//! but we want to return the full [WordDiff]. We keep only O(n) table entries during the
//! calculation, but each entry holds a [WordDiff] of size O(n), so we end up with O(n^2) space.
19//!
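//! To see why O(n) table entries are enough,
//! let `n = a.len()` and `m = b.len()` and consider the `(n+1) x (m+1)` matrix `X` where `X[i][j]` is the
//! comparison between `a[:i]` and `b[:j]`.
//! The notation `a[:i]` means the first `i` characters of `a` (so `a[:0]` is the empty string), and in the
//! relation below `a[i]` denotes the `i`-th character of `a`, i.e. the last character of `a[:i]` (likewise for `b`).
//! The solution is `X[n][m]`. The base cases are `X[0][j] = j` (add every character of `b[:j]`) and
//! `X[i][0] = i` (subtract every character of `a[:i]`). For `i, j >= 1` the recursive relation is: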
25//!
//! ```text
//! X[i][j] = {
//!     X[i-1][j-1]      if a[i] == b[j]  // append the (keep the character a[i]) op
//!                                       // to the best WordDiff for (a[:i-1], b[:j-1])
//!     1 + min(
//!         X[i-1][j],    // append the (subtract the character a[i]) op
//!                       // to the best WordDiff for (a[:i-1], b[:j])
//!         X[i][j-1],    // append the (add the character b[j]) op
//!                       // to the best WordDiff for (a[:i], b[:j-1])
//!         X[i-1][j-1],  // append the (modify the character a[i] to b[j]) op
//!                       // to the best WordDiff for (a[:i-1], b[:j-1])
//!     )                otherwise
//! }
//! ```
40//!
//! We calculate the matrix row by row, with the `i` variable in the outer loop and the `j` variable in the inner loop:
//! ```
//! # let a = Vec::<i64>::new();
//! # let b = Vec::<i64>::new();
//! for i in 0..(a.len() + 1) {
//!     for j in 0..(b.len() + 1) {
//!         // calculate X[i][j] using X[i-1][j], X[i][j-1] and X[i-1][j-1]
//!     }
//! }
//! ```
//!
//! With this calculation order, we observe that `X[i][j]` only depends on the last `m+2` elements of `X` that have
//! been calculated. Specifically,
//!
//! - `X[i][j-1]` was calculated in the previous iteration of the `j` loop, so `1` iteration before.
//! - `X[i-1][j]` was calculated in the previous iteration of the `i` loop with the same `j` index, so `m+1` iterations
//!   before, because the `j` variable takes `m+1` values.
//! - `X[i-1][j-1]` was calculated `m+2` iterations before.
//!
//! So, we don't need to store the full `X` matrix at all: we just need to store the last `m+2` elements that were calculated.
//! We use a circular buffer of size `m+2` to do this.
62
63use crate::collections::circularbuffer::CircularBuffer;
64use std::ops::Index;
65
66/// Find words in the provided dictionary that are close to the search word.
67///
68/// The return value is an ordered list corresponding to every word in the dictionary, with
69/// the closest matches first.
70pub fn find_close_words(dictionary: &[&str], word: &str) -> Vec<WordDiff> {
71 // TODO: accept a generic iterator
72 let size_hint = dictionary.len();
73 //size_hint() {
74 // (s, None) => s,
75 // (_, Some(s)) => s,
76 //};
77 let mut comparisons = Vec::with_capacity(size_hint);
78 for valid_word in dictionary {
79 let comparison = levenshtein_distance(word, valid_word); // word);
80 comparisons.push(comparison);
81 }
82 comparisons.sort_by(|a, b| a.distance.cmp(&b.distance));
83 comparisons
84}
85
/// Compute the Levenshtein distance from `a` to `b`, together with one minimal
/// sequence of [DiffOp]s that turns `a` into `b`.
///
/// This implements the dynamic-programming relation from the module docs and
/// fills the table row by row. The outer loop walks the characters of `a` (the
/// `i` index) and the inner loop walks the characters of `b` (the `j` index).
/// Only the most recent `b.len() + 2` table entries are kept, in a [CircularBuffer].
///
/// Equal characters always produce a `Keep` from the diagonal predecessor. When
/// the characters differ and several predecessors share the minimal cost,
/// `Modify` is preferred over `Subtract`, and `Subtract` over `Add`.
fn levenshtein_distance(a: &str, b: &str) -> WordDiff {
    // Compare by `char` rather than by byte, so a multi-byte UTF-8 character is a
    // single unit in the diff.
    let a: Vec<char> = a.chars().collect();
    let b: Vec<char> = b.chars().collect();

    let m = b.len();
    // One table row has m + 1 entries; keeping m + 2 reaches all three
    // predecessors of the entry being computed.
    let mut c = CircularBuffer::new(m + 2);
    // Offsets of the predecessors of X[i][j]. NOTE(review): these assume that
    // `c.index(k)` is the entry pushed k pushes before the newest one (so index 0
    // is the newest), as the module docs describe. TODO confirm against CircularBuffer.
    let idx_modify = m + 1; // X[i-1][j-1]: computed m + 2 entries ago
    let idx_subtract = m; // X[i-1][j]: computed m + 1 entries ago (one row back)
    let idx_add = 0; // X[i][j-1]: the most recently computed entry

    // This is comparing two empty strings - i.e., a[:0] and b[:0]
    c.push(WordDiff {
        distance: 0,
        // Only an initial capacity hint; the Vec still grows if more ops are pushed.
        ops: Vec::with_capacity(std::cmp::max(a.len(), b.len())),
    });

    // Row i = 0. In the comments below, `i`/`j` are the 0-based positions of
    // `a_i`/`b_j`.
    for b_j in &b {
        // Here we are comparing an empty a string (i.e., a[:0]) with b[:j+1].
        // There is only one possible action: append (add b[j]) to the diff for b[:j]
        // `clone_to_front` appears to push a copy of the entry at the given offset
        // as the new newest entry and return a mutable reference to it; the op and
        // cost of the new cell are then applied in place.
        let cmp = c.clone_to_front(idx_add);
        cmp.ops.push(DiffOp::Add(*b_j));
        cmp.distance += 1;
    }

    for a_i in &a {
        // Here we are comparing a[:i+1] with an empty b string (i.e., b[:0])
        // There is only one possible action: append (subtract a[i]) to the diff for a[:i]
        // The diff for a[:i] vs b[:0] is the first entry of the previous row,
        // which sits m + 1 entries back.
        let cmp = c.clone_to_front(idx_subtract);
        cmp.ops.push(DiffOp::Subtract(*a_i));
        cmp.distance += 1;

        for b_j in &b {
            // Here we are comparing a[:i+1] with a b[:j+1]
            let (idx_to_clone, diff, distance_delta) = if a_i == b_j {
                // Matching characters: extend the diagonal predecessor at no cost.
                (idx_modify, DiffOp::Keep(*a_i), 0)
            } else {
                let cost_modify = c.index(idx_modify).distance;
                let cost_add = c.index(idx_add).distance;
                let cost_subtract = c.index(idx_subtract).distance;
                // Pick the cheapest predecessor; ties go to modify, then subtract,
                // then add.
                if cost_modify <= std::cmp::min(cost_subtract, cost_add) {
                    (idx_modify, DiffOp::Modify(*a_i, *b_j), 1)
                } else if cost_subtract <= std::cmp::min(cost_modify, cost_add) {
                    (idx_subtract, DiffOp::Subtract(*a_i), 1)
                } else {
                    (idx_add, DiffOp::Add(*b_j), 1)
                }
            };

            // Copy the chosen predecessor into the newest slot and extend it.
            let cmp = c.clone_to_front(idx_to_clone);
            cmp.ops.push(diff);
            cmp.distance += distance_delta;
        }
    }
    // The newest entry is the last cell computed, X[n][m]: the diff for the full
    // strings (X[0][m] when `a` is empty).
    c.index(0).clone()
}
142
/// A single step of a [WordDiff].
///
/// The "left" word is the one being transformed and the "right" word is the
/// result. `Keep`, `Subtract` and the first field of `Modify` carry a left
/// character. `Keep`, `Add` and the second field of `Modify` carry a right
/// character.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum DiffOp {
    /// The character is present, unchanged, in both words.
    Keep(char),
    /// The character is present only in the right word.
    Add(char),
    /// The character is present only in the left word.
    Subtract(char),
    /// The left character (first field) becomes the right character (second field).
    Modify(char, char),
}

impl DiffOp {
    /// The op describing the same step when the two words are swapped.
    #[cfg(test)]
    fn invert(&self) -> DiffOp {
        match *self {
            Self::Keep(ch) => Self::Keep(ch),
            Self::Add(ch) => Self::Subtract(ch),
            Self::Subtract(ch) => Self::Add(ch),
            Self::Modify(from, to) => Self::Modify(to, from),
        }
    }

    /// Cost of this step: zero for `Keep`, one for every other op.
    #[cfg(test)]
    fn distance(&self) -> usize {
        match self {
            Self::Keep(_) => 0,
            Self::Add(_) | Self::Subtract(_) | Self::Modify(..) => 1,
        }
    }

    /// The character this step contributes to the left word, if any.
    fn left(&self) -> Option<char> {
        match *self {
            Self::Keep(ch) | Self::Subtract(ch) | Self::Modify(ch, _) => Some(ch),
            Self::Add(_) => None,
        }
    }

    /// The character this step contributes to the right word, if any.
    fn right(&self) -> Option<char> {
        match *self {
            Self::Keep(ch) | Self::Add(ch) | Self::Modify(_, ch) => Some(ch),
            Self::Subtract(_) => None,
        }
    }
}
190
191#[derive(Debug, Eq, PartialEq)]
192pub struct WordDiff {
193 distance: usize,
194 ops: Vec<DiffOp>,
195}
196
197impl WordDiff {
198 pub fn left(&self) -> String {
199 self.ops.iter().filter_map(DiffOp::left).collect()
200 }
201
202 pub fn right(&self) -> String {
203 self.ops.iter().filter_map(DiffOp::right).collect()
204 }
205}
206
207impl Clone for WordDiff {
208 fn clone(&self) -> Self {
209 let mut cmp = WordDiff {
210 distance: self.distance,
211 ops: Vec::with_capacity(self.ops.capacity()),
212 };
213 cmp.ops.clone_from(&self.ops);
214 cmp
215 }
216
217 fn clone_from(&mut self, source: &Self) {
218 self.distance = source.distance;
219 self.ops.clear();
220 for diff in source.ops.iter() {
221 self.ops.push(*diff);
222 }
223 }
224}
225
#[cfg(test)]
mod tests {
    use super::*;

    // Generates one test per `(name, a, b, ops)` case. The diff from `a` to `b`
    // must be exactly `ops`, with a distance equal to the summed cost of `ops`.
    // The diff from `b` to `a` must be the same ops, each inverted.
    macro_rules! levenshtein_tests {
        ($( ($name:ident, $a:expr, $b:expr, $ops:expr) ),+ $(,)?) => {
            $(
                #[test]
                fn $name() {
                    let ops: Vec<DiffOp> = $ops;
                    let distance: usize = ops.iter().map(DiffOp::distance).sum();
                    let forward = WordDiff { distance, ops };
                    assert_eq!(levenshtein_distance($a, $b), forward);

                    let backward = WordDiff {
                        distance: forward.distance,
                        ops: forward.ops.iter().map(DiffOp::invert).collect(),
                    };
                    assert_eq!(levenshtein_distance($b, $a), backward);
                }
            )+
        };
    }

    levenshtein_tests![
        (case_1, "", "", Vec::<DiffOp>::new()),
        (case_2, "a", "", vec![DiffOp::Subtract('a')]),
        (case_3, "a", "a", vec![DiffOp::Keep('a')]),
        (case_4, "a", "b", vec![DiffOp::Modify('a', 'b')]),
        (
            case_5,
            "aa",
            "a",
            vec![DiffOp::Subtract('a'), DiffOp::Keep('a')]
        ),
        (
            case_6,
            "aa",
            "ab",
            vec![DiffOp::Keep('a'), DiffOp::Modify('a', 'b')]
        ),
        (
            case_7,
            "abb",
            "acbb",
            vec![
                DiffOp::Keep('a'),
                DiffOp::Add('c'),
                DiffOp::Keep('b'),
                DiffOp::Keep('b'),
            ]
        ),
        (
            case_8,
            "aabb",
            "abb",
            vec![
                DiffOp::Subtract('a'),
                DiffOp::Keep('a'),
                DiffOp::Keep('b'),
                DiffOp::Keep('b'),
            ]
        ),
        (
            case_9,
            "james",
            "laura",
            vec![
                DiffOp::Modify('j', 'l'),
                DiffOp::Keep('a'),
                DiffOp::Modify('m', 'u'),
                DiffOp::Modify('e', 'r'),
                DiffOp::Modify('s', 'a'),
            ]
        ),
        (
            case_10,
            "ab12345e",
            "a12345de",
            vec![
                DiffOp::Keep('a'),
                DiffOp::Subtract('b'),
                DiffOp::Keep('1'),
                DiffOp::Keep('2'),
                DiffOp::Keep('3'),
                DiffOp::Keep('4'),
                DiffOp::Keep('5'),
                DiffOp::Add('d'),
                DiffOp::Keep('e'),
            ]
        ),
    ];

    #[test]
    fn find_close_words_test() {
        let dictionary = ["james", "laura", "mint"];
        let result = find_close_words(&dictionary, "janes");

        assert_eq!(result.len(), 3);
        let ranked: Vec<String> = result.iter().map(WordDiff::right).collect();
        assert_eq!(ranked, ["james", "laura", "mint"]);
    }
}