1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
use linalg::{Matrix, BaseMatrix};
use learning::error::Error;
use super::{KNearest, KNearestSearch, get_distances, dist};
/// Exhaustive k-nearest-neighbour search over a stored point set.
///
/// Every query compares the query point against every stored row, so no
/// auxiliary index structure is maintained.
///
/// `data` holds the matrix passed to `build` (one point per row). It is
/// `None` until `build` has been called; `search` reports an untrained
/// error in that state.
#[derive(Debug, Default)]
pub struct BruteForce {
    data: Option<Matrix<f64>>,
}

impl BruteForce {
    /// Creates an empty (untrained) searcher; equivalent to `BruteForce::default()`.
    pub fn new() -> Self {
        Self::default()
    }
}
impl KNearestSearch for BruteForce {
fn build(&mut self, data: Matrix<f64>) {
self.data = Some(data);
}
fn search(&self, point: &[f64], k: usize) -> Result<(Vec<usize>, Vec<f64>), Error> {
if let Some(ref data) = self.data {
let indices: Vec<usize> = (0..k).collect();
let distances = get_distances(data, point, &indices);
let mut query = KNearest::new(k, indices, distances);
let mut current_dist = query.dist();
let mut i = k;
for row in data.row_iter().skip(k) {
let d = dist(point, row.raw_slice());
if d < current_dist {
current_dist = query.add(i, d);
}
i += 1;
}
Ok(query.get_results())
} else {
Err(Error::new_untrained())
}
}
}
#[cfg(test)]
mod tests {
    use linalg::Matrix;
    use super::super::KNearestSearch;
    use super::BruteForce;

    /// Query point shared by the tests below.
    const POINT: [f64; 2] = [3., 4.9];

    /// Builds a `BruteForce` searcher over five 2-D points.
    fn fixture() -> BruteForce {
        let m = Matrix::new(5, 2, vec![1., 2.,
                                       8., 0.,
                                       6., 10.,
                                       3., 6.,
                                       0., 3.]);
        let mut b = BruteForce::new();
        b.build(m);
        b
    }

    /// Asserts element-wise equality of distances within a small tolerance,
    /// so the tests do not hinge on the last bit of floating-point rounding.
    fn assert_dists_close(actual: &[f64], expected: &[f64]) {
        assert_eq!(actual.len(), expected.len());
        for (a, e) in actual.iter().zip(expected) {
            assert!((a - e).abs() < 1e-12, "distance {} != expected {}", a, e);
        }
    }

    #[test]
    fn test_bruteforce_search() {
        let b = fixture();

        let (ind, dist) = b.search(&POINT, 1).unwrap();
        assert_eq!(ind, vec![3]);
        assert_dists_close(&dist, &[1.0999999999999996]);

        let (ind, dist) = b.search(&POINT, 2).unwrap();
        assert_eq!(ind, vec![3, 0]);
        assert_dists_close(&dist, &[1.0999999999999996, 3.5227829907617076]);

        let (ind, dist) = b.search(&POINT, 3).unwrap();
        assert_eq!(ind, vec![3, 0, 4]);
        assert_dists_close(&dist,
                           &[1.0999999999999996, 3.5227829907617076, 3.551056180912941]);
    }

    #[test]
    fn test_bruteforce_search_all_rows() {
        // k equal to the row count: the seed set covers every row and the
        // scan over the remaining rows is empty.
        let b = fixture();
        let (mut ind, dist) = b.search(&POINT, 5).unwrap();
        assert_eq!(dist.len(), 5);
        ind.sort();
        assert_eq!(ind, vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_bruteforce_untrained() {
        let b = BruteForce::new();
        let e = b.search(&POINT, 1);
        assert!(e.is_err());
    }
}