Module gbdt::gradient_boost
source · [−]Expand description
This module implements the gradient boosting decision tree (GBDT) algorithm. It depends on the following modules:
-
gbdt::config::Config: Config is needed to configure the gbdt algorithm.
-
gbdt::decision_tree: DecisionTree is used for training and predicting.
-
rand: This external crate (not part of Rust's standard library) is used to randomly select the data or features used in a single training iteration when data_sample_ratio or feature_sample_ratio is less than 1.0.
Example
use gbdt::config::Config;
use gbdt::decision_tree::{Data, DataVec};
use gbdt::gradient_boost::GBDT;
// Configure the algorithm: 3 features, depth-2 trees, leaves of at least
// one sample, squared-error loss, and two boosting iterations.
let mut config = Config::new();
config.set_feature_size(3);
config.set_max_depth(2);
config.set_min_leaf_size(1);
config.set_loss("SquaredError");
config.set_iterations(2);
// Build the booster from the configuration.
let mut gbdt = GBDT::new(&config);
// Four training samples. `data1` and `data2` are reused as test inputs
// below, so they are cloned here; `data3` and `data4` are moved in.
let data1 = Data::new_training_data(vec![1.0, 2.0, 3.0], 1.0, 1.0, None);
let data2 = Data::new_training_data(vec![1.1, 2.1, 3.1], 1.0, 1.0, None);
let data3 = Data::new_training_data(vec![2.0, 2.0, 1.0], 1.0, 2.0, None);
let data4 = Data::new_training_data(vec![2.0, 2.3, 1.2], 1.0, 0.0, None);
let mut training_data: DataVec = vec![data1.clone(), data2.clone(), data3, data4];
// Train the decision trees.
gbdt.fit(&mut training_data);
// The test set reuses the first two training samples and rebuilds the last
// two feature vectors as unlabeled test data.
let test_data: DataVec = vec![
    data1,
    data2,
    Data::new_test_data(vec![2.0, 2.0, 1.0], None),
    Data::new_test_data(vec![2.0, 2.3, 1.2], None),
];
let predictions = gbdt.predict(&test_data);
println!("{:?}", predictions);
// output:
// [1.0, 1.0, 2.0, 0.0]
Structs
GBDT: The gradient boosting decision tree.