Function rusty_machine::analysis::cross_validation::k_fold_validate
pub fn k_fold_validate<M, S>(
    model: &mut M,
    inputs: &Matrix<f64>,
    targets: &Matrix<f64>,
    k: usize,
    score: S
) -> LearningResult<Vec<f64>>
where
    S: Fn(&Matrix<f64>, &Matrix<f64>) -> f64,
    M: SupModel<Matrix<f64>, Matrix<f64>>,
Randomly splits the inputs into k ‘folds’. For each fold, a model is trained on all inputs except those in the fold and then tested on the data in the fold. Returns one score per fold.
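To make the splitting step concrete, here is a minimal, hypothetical sketch of partitioning sample indices into k held-out groups. It is not the crate's internal implementation (k_fold_validate shuffles the sample order randomly); the fold_indices helper is introduced purely for illustration.

// Hypothetical helper, for illustration only: assign sample indices to k
// held-out groups round-robin. The real k_fold_validate shuffles the
// samples randomly before forming the folds.
fn fold_indices(n_samples: usize, k: usize) -> Vec<Vec<usize>> {
    let mut folds = vec![Vec::new(); k];
    for i in 0..n_samples {
        // Sample i is tested in fold i % k and used for training in all others.
        folds[i % k].push(i);
    }
    folds
}

// With 5 samples and k = 2, the held-out index sets are [0, 2, 4] and [1, 3].
assert_eq!(fold_indices(5, 2), vec![vec![0, 2, 4], vec![1, 3]]);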
Arguments
- model - Used to train and predict for each fold.
- inputs - All input samples.
- targets - All targets.
- k - Number of folds to use.
- score - Used to compare the outputs for each fold to the targets. Higher scores are better. See the analysis::score module for examples, or the sketch below for a hand-written score.
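The score argument accepts any function or closure matching the S bound in the signature above. As a hedged illustration, the hypothetical neg_mean_squared_error below (not part of rusty_machine; it assumes Matrix::data() exposes the elements in row-major order) negates the mean squared error so that higher values are still better:

use rusty_machine::linalg::Matrix;

// Hypothetical custom score, for illustration only: the negated mean squared
// error between predicted outputs and targets, so that a higher value is better.
fn neg_mean_squared_error(outputs: &Matrix<f64>, targets: &Matrix<f64>) -> f64 {
    let n = outputs.data().len() as f64;
    let sum_sq: f64 = outputs.data().iter()
        .zip(targets.data().iter())
        .map(|(o, t)| (o - t) * (o - t))
        .sum();
    -sum_sq / n
}

Such a function could be passed in place of row_accuracy in the example below.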
Examples
use rusty_machine::analysis::cross_validation::k_fold_validate;
use rusty_machine::analysis::score::row_accuracy;
use rusty_machine::learning::naive_bayes::{NaiveBayes, Bernoulli};
use rusty_machine::linalg::{BaseMatrix, Matrix};
let inputs = Matrix::new(3, 2, vec![1.0, 1.1,
                                    5.2, 4.3,
                                    6.2, 7.3]);
let targets = Matrix::new(3, 3, vec![1.0, 0.0, 0.0,
                                     0.0, 0.0, 1.0,
                                     0.0, 0.0, 1.0]);
let mut model = NaiveBayes::<Bernoulli>::new();
let accuracy_per_fold: Vec<f64> = k_fold_validate(
    &mut model,
    &inputs,
    &targets,
    3,
    // Score each fold by the fraction of test samples where
    // the model's prediction equals the target.
    row_accuracy
).unwrap();
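The result contains one score per fold. Continuing the example above, a common follow-up is to average the per-fold accuracies into a single cross-validated estimate:

// Continuing from the example: collapse the per-fold scores into
// a single mean accuracy across the 3 folds.
let mean_accuracy = accuracy_per_fold.iter().sum::<f64>() / accuracy_per_fold.len() as f64;
// row_accuracy returns a fraction of matching rows, so the mean stays in [0, 1].
assert!(mean_accuracy >= 0.0 && mean_accuracy <= 1.0);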