1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
//! Random-number utilities.
//!
//! This module provides the sampling and shuffling routines used
//! within the learning modules.

use rand::{Rng, thread_rng};

/// Draws `reservoir_size` elements from `pool` without replacement using
/// reservoir sampling (Algorithm R).
///
/// The first `reservoir_size` elements seed the reservoir. Every later element
/// (the `k`-th seen, 1-based) draws a random index in `0..k` and overwrites that
/// reservoir slot when the index falls inside the reservoir.
///
/// The returned elements are not shuffled: a seed element that survives keeps
/// its original index. Pass the result through [`fisher_yates`] if a random
/// order is also required.
///
/// # Panics
///
/// Panics if `reservoir_size` is greater than `pool.len()`.
///
/// # Examples
///
/// ```
/// use rusty_machine::learning::toolkit::rand_utils;
///
/// let pool = [1, 2, 3, 4];
/// let sample = rand_utils::reservoir_sample(&pool, 3);
///
/// assert_eq!(sample.len(), 3);
/// println!("{:?}", sample);
/// ```
pub fn reservoir_sample<T: Copy>(pool: &[T], reservoir_size: usize) -> Vec<T> {
    assert!(pool.len() >= reservoir_size,
            "Sample size is greater than total.");

    // Seed the reservoir with the prefix; the remainder is streamed through it.
    let (seed, rest) = pool.split_at(reservoir_size);
    let mut res = seed.to_vec();
    let mut rng = thread_rng();

    for (offset, &item) in rest.iter().enumerate() {
        // Number of elements seen so far, including `item`.
        let ele_seen = reservoir_size + offset + 1;
        let r = rng.gen_range(0..ele_seen);

        // Happens with probability reservoir_size / ele_seen.
        if r < reservoir_size {
            res[r] = item;
        }
    }

    res
}

/// Returns a random permutation of `arr` using the "inside-out" variant of the
/// Fisher-Yates shuffle.
///
/// The input slice is not modified; a new `Vec` of the same length is returned.
///
/// # Examples
///
/// ```
/// use rusty_machine::learning::toolkit::rand_utils;
///
/// // Collect the numbers 0..5
/// let a = (0..5).collect::<Vec<_>>();
///
/// // Perform a Fisher-Yates shuffle to get a random permutation
/// let permutation = rand_utils::fisher_yates(&a);
/// assert_eq!(permutation.len(), a.len());
/// ```
pub fn fisher_yates<T: Copy>(arr: &[T]) -> Vec<T> {
    let n = arr.len();
    let mut rng = thread_rng();

    let mut shuffled_arr = Vec::with_capacity(n);

    for (i, &item) in arr.iter().enumerate() {
        // Choose the slot in the active prefix `0..=i` that receives `item`.
        let j = rng.gen_range(0..i + 1);

        // Append `item` at position `i`, then swap it into slot `j`. This moves
        // the value previously at `j` to the end of the active prefix (a no-op
        // when `j == i`). Growing the vector with `push` keeps every element
        // initialized, unlike pre-sizing it with `set_len`, which would expose
        // uninitialized memory.
        shuffled_arr.push(item);
        shuffled_arr.swap(i, j);
    }

    shuffled_arr
}

/// Shuffles `arr` in place with the Fisher-Yates algorithm.
///
/// Each position `i`, from front to back, is filled by swapping in an element
/// picked at random from the not-yet-fixed suffix `arr[i..]`.
///
/// # Examples
///
/// ```
/// use rusty_machine::learning::toolkit::rand_utils;
///
/// // Collect the numbers 0..5
/// let mut a = (0..5).collect::<Vec<_>>();
///
/// // Permute the values in place with Fisher-Yates
/// rand_utils::in_place_fisher_yates(&mut a);
/// ```
pub fn in_place_fisher_yates<T>(arr: &mut [T]) {
    let len = arr.len();
    let mut generator = thread_rng();

    for start in 0..len {
        // Length of the suffix that has not been fixed yet (always >= 1).
        let remaining = len - start;
        let target = start + generator.gen_range(0..remaining);
        arr.swap(start, target);
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a sorted copy of `values`, so shuffled output can be compared
    /// against the original multiset of elements.
    fn sorted<T: Ord + Clone>(values: &[T]) -> Vec<T> {
        let mut out = values.to_vec();
        out.sort();
        out
    }

    #[test]
    fn test_reservoir_sample() {
        let a = vec![1, 2, 3, 4, 5, 6, 7];

        let b = reservoir_sample(&a, 3);

        assert_eq!(b.len(), 3);

        // Every sampled value must come from the pool.
        for val in b.iter() {
            assert!(a.contains(val));
        }

        // The pool has no duplicates, so sampling without replacement must not
        // repeat any value.
        let mut unique = b.clone();
        unique.sort();
        unique.dedup();
        assert_eq!(unique.len(), b.len());
    }

    #[test]
    fn test_reservoir_sample_whole_pool() {
        let a = vec![1, 2, 3, 4];

        // With nothing left to stream, the reservoir is the pool itself.
        assert_eq!(reservoir_sample(&a, a.len()), a);
    }

    #[test]
    fn test_reservoir_sample_empty_reservoir() {
        let a = vec![1, 2, 3];

        assert!(reservoir_sample(&a, 0).is_empty());
    }

    #[test]
    #[should_panic(expected = "Sample size is greater than total.")]
    fn test_reservoir_sample_too_large() {
        let a = vec![1, 2, 3];

        let _ = reservoir_sample(&a, 4);
    }

    #[test]
    fn test_fisher_yates() {
        let a = (0..10).collect::<Vec<_>>();

        let b = fisher_yates(&a);

        // A permutation has the same length and the same multiset of values.
        assert_eq!(b.len(), a.len());
        assert_eq!(sorted(&b), a);
    }

    #[test]
    fn test_fisher_yates_empty() {
        let a: Vec<i32> = Vec::new();

        assert!(fisher_yates(&a).is_empty());
    }

    #[test]
    fn test_in_place_fisher_yates() {
        let mut a = (0..10).collect::<Vec<_>>();

        in_place_fisher_yates(&mut a);

        assert_eq!(a.len(), 10);
        assert_eq!(sorted(&a), (0..10).collect::<Vec<_>>());
    }

    #[test]
    fn test_in_place_fisher_yates_empty_and_single() {
        let mut empty: Vec<i32> = Vec::new();
        in_place_fisher_yates(&mut empty);
        assert!(empty.is_empty());

        let mut single = vec![42];
        in_place_fisher_yates(&mut single);
        assert_eq!(single, vec![42]);
    }
}