// texcraft_stdext/algorithms/spellcheck.rs

//! Spell checking using Levenshtein distance
//!
//! This module contains a simple spell check facility in the form of a [find_close_words]
//! function. This function accepts a word and a dictionary, which is a list of valid words.
//! It finds the words in the dictionary that are closest to the original word, where "closest"
//! is defined using
//! [Levenshtein distance](https://en.wikipedia.org/wiki/Levenshtein_distance).
//!
//! The result of the search is a list of [WordDiffs](WordDiff).
//! This data structure describes how the search word can be changed into the dictionary word using a combination
//! of keep, add, subtract and modify [DiffOps](DiffOp).
//!
//! ## Implementation notes
//!
//! The Levenshtein distance calculation is implemented using the standard dynamic programming
//! algorithm, whose time and space complexity are both O(n^2) for words of length about n
//! (O(nm) in general). Note the space complexity could be O(n) if we only cared about the
//! minimal distance between the strings, but we want to return the full [WordDiff].
//! We keep only O(n) intermediate results during the calculation, but each of them is a
//! [WordDiff] of length O(n), so we end up with O(n^2) space.
//!
//! To see why only a linear number of intermediate results is needed,
//! let `n = a.len()` and `m = b.len()` and consider the `(n+1) x (m+1)` matrix `X` where `X[i][j]`
//! is the comparison between `a[:i]` and `b[:j]`.
//! The notation `a[:i]` means the first `i` characters of `a` (indices `0` up to, but not
//! including, `i`), so `a[:0]` is the empty string; below, `a[i]` denotes the last character
//! of `a[:i]`, and likewise for `b`. The solution is `X[n][m]`. The recursive relation is:
//!
//! ```text
//! X[i][j] = {
//!     X[i-1][j-1] if a[i] == b[j]  // append the (keep the character a[i]) op
//!                                  //  to the best WordDiff for (a[:i-1], b[:j-1])
//!     1 + min (
//!         X[i-1][j],    // append the (subtract the character a[i]) op
//!                       //  to the best WordDiff for (a[:i-1], b[:j])
//!         X[i][j-1],    // append the (add the character b[j]) op
//!                       //  to the best WordDiff for (a[:i], b[:j-1])
//!         X[i-1][j-1],  // append the (modify the character a[i] to b[j]) op
//!                       //  to the best WordDiff for (a[:i-1], b[:j-1])
//!     ) otherwise
//! }
//! ```
//!
//! The base cases are `X[0][0] = 0` (two empty strings), `X[0][j] = j` (add every character
//! of `b[:j]`) and `X[i][0] = i` (subtract every character of `a[:i]`).
//!
//! We calculate the matrix row by row, iterating over the `i` variable in the outer loop and
//! the `j` variable in the inner loop:
//! ```
//! # let a = Vec::<i64>::new();
//! # let b = Vec::<i64>::new();
//! for i in 0..(a.len() + 1) {
//!     for j in 0..(b.len() + 1) {
//!         // calculate X[i][j] using X[i-1][j], X[i][j-1] and X[i-1][j-1]
//!     }
//! }
//! ```
//!
//! With this calculation order, we observe that `X[i][j]` only depends on the last `m+2` elements of `X` that have
//! been calculated. Specifically,
//!
//! - `X[i][j-1]` was calculated in the previous iteration of the `j` loop, so `1` iteration before.
//! - `X[i-1][j]` was calculated in the previous iteration of the `i` loop with the same `j` index, so `m+1` iterations before
//!   because the `j` variable takes `m+1` values.
//! - `X[i-1][j-1]` was calculated `m+2` iterations before.
//!
//! So, we don't need to store the full `X` matrix at all: we just need to store the last `m+2` elements that were calculated.
//! We use a circular buffer of size `m+2` to do this.

use crate::collections::circularbuffer::CircularBuffer;
use std::ops::Index;

66/// Find words in the provided dictionary that are close to the search word.
67///
68/// The return value is an ordered list corresponding to every word in the dictionary, with
69/// the closest matches first.
70pub fn find_close_words(dictionary: &[&str], word: &str) -> Vec<WordDiff> {
71    // TODO: accept a generic iterator
72    let size_hint = dictionary.len();
73    //size_hint() {
74    //   (s, None) => s,
75    //   (_, Some(s)) => s,
76    //};
77    let mut comparisons = Vec::with_capacity(size_hint);
78    for valid_word in dictionary {
79        let comparison = levenshtein_distance(word, valid_word); // word);
80        comparisons.push(comparison);
81    }
82    comparisons.sort_by(|a, b| a.distance.cmp(&b.distance));
83    comparisons
84}
85
86fn levenshtein_distance(a: &str, b: &str) -> WordDiff {
87    let a: Vec<char> = a.chars().collect();
88    let b: Vec<char> = b.chars().collect();
89
90    let m = b.len();
91    let mut c = CircularBuffer::new(m + 2);
92    let idx_modify = m + 1;
93    let idx_subtract = m;
94    let idx_add = 0;
95
96    // This is comparing two empty strings - i.e., a[:0] and b[:0]
97    c.push(WordDiff {
98        distance: 0,
99        ops: Vec::with_capacity(std::cmp::max(a.len(), b.len())),
100    });
101
102    for b_j in &b {
103        // Here we are comparing an empty a string (i.e., a[:0]) with b[:j+1].
104        // There is only one possible action: append (add b[j]) to the diff for b[:j]
105        let cmp = c.clone_to_front(idx_add);
106        cmp.ops.push(DiffOp::Add(*b_j));
107        cmp.distance += 1;
108    }
109
110    for a_i in &a {
111        // Here we are comparing a[:i+1] with an empty b string (i.e., b[:0])
112        // There is only one possible action: append (subtract a[i]) to the diff for a[:i]
113        //let i_subtract = (idx_to_set + 1) % c.len();
114        let cmp = c.clone_to_front(idx_subtract);
115        cmp.ops.push(DiffOp::Subtract(*a_i));
116        cmp.distance += 1;
117
118        for b_j in &b {
119            // Here we are comparing a[:i+1] with a b[:j+1]
120            let (idx_to_clone, diff, distance_delta) = if a_i == b_j {
121                (idx_modify, DiffOp::Keep(*a_i), 0)
122            } else {
123                let cost_modify = c.index(idx_modify).distance;
124                let cost_add = c.index(idx_add).distance;
125                let cost_subtract = c.index(idx_subtract).distance;
126                if cost_modify <= std::cmp::min(cost_subtract, cost_add) {
127                    (idx_modify, DiffOp::Modify(*a_i, *b_j), 1)
128                } else if cost_subtract <= std::cmp::min(cost_modify, cost_add) {
129                    (idx_subtract, DiffOp::Subtract(*a_i), 1)
130                } else {
131                    (idx_add, DiffOp::Add(*b_j), 1)
132                }
133            };
134
135            let cmp = c.clone_to_front(idx_to_clone);
136            cmp.ops.push(diff);
137            cmp.distance += distance_delta;
138        }
139    }
140    c.index(0).clone()
141}
142
/// A single edit operation in a [WordDiff].
///
/// Each operation consumes a character from the left word (the word being checked), from
/// the right word (the dictionary word), or from both.
#[derive(Debug, Eq, PartialEq, Copy, Clone)]
pub enum DiffOp {
    /// The character is the same in both words and is kept.
    Keep(char),
    /// The character appears only in the right word and is added.
    Add(char),
    /// The character appears only in the left word and is removed.
    Subtract(char),
    /// The left character (first field) is replaced by the right character (second field).
    Modify(char, char),
}

impl DiffOp {
    /// Returns the operation that undoes this one, i.e. the matching op in a diff from the
    /// right word back to the left word.
    #[cfg(test)]
    fn invert(&self) -> DiffOp {
        match *self {
            DiffOp::Keep(c) => DiffOp::Keep(c),
            DiffOp::Add(c) => DiffOp::Subtract(c),
            DiffOp::Subtract(c) => DiffOp::Add(c),
            DiffOp::Modify(from, to) => DiffOp::Modify(to, from),
        }
    }

    /// Levenshtein cost of this operation: 0 for a keep, 1 for anything else.
    #[cfg(test)]
    fn distance(&self) -> usize {
        match self {
            DiffOp::Keep(_) => 0,
            DiffOp::Add(_) | DiffOp::Subtract(_) | DiffOp::Modify(_, _) => 1,
        }
    }

    /// The character this operation contributes to the left word, if any.
    fn left(&self) -> Option<char> {
        match *self {
            DiffOp::Keep(c) | DiffOp::Subtract(c) | DiffOp::Modify(c, _) => Some(c),
            DiffOp::Add(_) => None,
        }
    }

    /// The character this operation contributes to the right word, if any.
    fn right(&self) -> Option<char> {
        match *self {
            DiffOp::Keep(c) | DiffOp::Add(c) | DiffOp::Modify(_, c) => Some(c),
            DiffOp::Subtract(_) => None,
        }
    }
}

191#[derive(Debug, Eq, PartialEq)]
192pub struct WordDiff {
193    distance: usize,
194    ops: Vec<DiffOp>,
195}
196
197impl WordDiff {
198    pub fn left(&self) -> String {
199        self.ops.iter().filter_map(DiffOp::left).collect()
200    }
201
202    pub fn right(&self) -> String {
203        self.ops.iter().filter_map(DiffOp::right).collect()
204    }
205}
206
207impl Clone for WordDiff {
208    fn clone(&self) -> Self {
209        let mut cmp = WordDiff {
210            distance: self.distance,
211            ops: Vec::with_capacity(self.ops.capacity()),
212        };
213        cmp.ops.clone_from(&self.ops);
214        cmp
215    }
216
217    fn clone_from(&mut self, source: &Self) {
218        self.distance = source.distance;
219        self.ops.clear();
220        for diff in source.ops.iter() {
221            self.ops.push(*diff);
222        }
223    }
224}
225
#[cfg(test)]
mod tests {
    use super::*;

    /// Generates one test per `(name, a, b, ops)` tuple. Each test checks that
    /// `levenshtein_distance(a, b)` returns exactly `ops` (with the matching total distance),
    /// and that `levenshtein_distance(b, a)` returns the inverted ops.
    macro_rules! levenshtein_tests {
        ($( ($name: ident, $a: expr, $b: expr, $i: expr),)+) => {
            $(
            #[test]
            fn $name() {
                let ops: Vec<DiffOp> = $i;
                let distance = ops.iter().map(DiffOp::distance).sum();
                let want = WordDiff { distance, ops };
                assert_eq!(levenshtein_distance($a, $b), want);

                let want_inverse = WordDiff {
                    distance: want.distance,
                    ops: want.ops.iter().map(DiffOp::invert).collect(),
                };
                assert_eq!(levenshtein_distance($b, $a), want_inverse);
            }
            )+
        };
    }

    levenshtein_tests![
        (case_1, "", "", Vec::<DiffOp>::new()),
        (case_2, "a", "", vec![DiffOp::Subtract('a')]),
        (case_3, "a", "a", vec![DiffOp::Keep('a')]),
        (case_4, "a", "b", vec![DiffOp::Modify('a', 'b')]),
        (
            case_5,
            "aa",
            "a",
            vec![DiffOp::Subtract('a'), DiffOp::Keep('a')]
        ),
        (
            case_6,
            "aa",
            "ab",
            vec![DiffOp::Keep('a'), DiffOp::Modify('a', 'b')]
        ),
        (
            case_7,
            "abb",
            "acbb",
            vec![
                DiffOp::Keep('a'),
                DiffOp::Add('c'),
                DiffOp::Keep('b'),
                DiffOp::Keep('b'),
            ]
        ),
        (
            case_8,
            "aabb",
            "abb",
            vec![
                DiffOp::Subtract('a'),
                DiffOp::Keep('a'),
                DiffOp::Keep('b'),
                DiffOp::Keep('b'),
            ]
        ),
        (
            case_9,
            "james",
            "laura",
            vec![
                DiffOp::Modify('j', 'l'),
                DiffOp::Keep('a'),
                DiffOp::Modify('m', 'u'),
                DiffOp::Modify('e', 'r'),
                DiffOp::Modify('s', 'a'),
            ]
        ),
        (
            case_10,
            "ab12345e",
            "a12345de",
            vec![
                DiffOp::Keep('a'),
                DiffOp::Subtract('b'),
                DiffOp::Keep('1'),
                DiffOp::Keep('2'),
                DiffOp::Keep('3'),
                DiffOp::Keep('4'),
                DiffOp::Keep('5'),
                DiffOp::Add('d'),
                DiffOp::Keep('e'),
            ]
        ),
        (
            // Multi-byte characters are compared as whole `char`s, not bytes.
            case_11,
            "café",
            "cafe",
            vec![
                DiffOp::Keep('c'),
                DiffOp::Keep('a'),
                DiffOp::Keep('f'),
                DiffOp::Modify('é', 'e'),
            ]
        ),
    ];

    #[test]
    fn find_close_words_test() {
        let dictionary = vec!["james", "laura", "mint"];
        let word = "janes";
        let result = find_close_words(&dictionary, word);

        // Check the length first so a short result fails with a clear message
        // instead of an index-out-of-bounds panic.
        assert_eq!(result.len(), 3);
        assert_eq!(result[0].right(), "james");
        assert_eq!(result[1].right(), "laura");
        assert_eq!(result[2].right(), "mint");
        assert_eq!(result[0].distance, 1);
        for diff in &result {
            assert_eq!(diff.left(), word);
        }
    }

    #[test]
    fn find_close_words_empty_dictionary() {
        assert!(find_close_words(&[], "word").is_empty());
    }
}