Module gbdt::decision_tree
This module implements a decision tree built on the simple binary tree in gbdt::binary_tree.
During training, nodes are split according to impurity.
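The module does not state which impurity measure it uses. As a minimal sketch, assuming weighted squared error (the natural choice for Loss::SquaredError), choosing a split on a single feature could look like the following. Sample, squared_error_impurity, best_split and example_samples are hypothetical names, not part of gbdt; the data are the four training samples from the example below.

```rust
/// One training sample for this sketch (hypothetical; not gbdt's Data type).
struct Sample {
    features: Vec<f64>,
    weight: f64,
    label: f64,
}

/// Weighted sum of squared deviations of the labels from their weighted mean.
fn squared_error_impurity(samples: &[&Sample]) -> f64 {
    let total_weight: f64 = samples.iter().map(|s| s.weight).sum();
    if total_weight == 0.0 {
        return 0.0; // an empty node contributes no impurity
    }
    let mean = samples.iter().map(|s| s.weight * s.label).sum::<f64>() / total_weight;
    samples
        .iter()
        .map(|s| s.weight * (s.label - mean).powi(2))
        .sum()
}

/// Tries each distinct value of `feature` as a threshold (samples with a smaller
/// value go left) and returns the (threshold, impurity) pair whose summed
/// left + right impurity is lowest, or None if the feature has a single value.
fn best_split(samples: &[Sample], feature: usize) -> Option<(f64, f64)> {
    let mut values: Vec<f64> = samples.iter().map(|s| s.features[feature]).collect();
    values.sort_by(|a, b| a.total_cmp(b));
    values.dedup();
    let mut best: Option<(f64, f64)> = None;
    // Skip the smallest value so that the left side is never empty.
    for &threshold in values.iter().skip(1) {
        let (left, right): (Vec<&Sample>, Vec<&Sample>) =
            samples.iter().partition(|s| s.features[feature] < threshold);
        let impurity = squared_error_impurity(&left) + squared_error_impurity(&right);
        if best.map_or(true, |(_, b)| impurity < b) {
            best = Some((threshold, impurity));
        }
    }
    best
}

/// The four training samples from the module example (all weights are 1.0).
fn example_samples() -> Vec<Sample> {
    let rows = vec![
        (vec![1.0, 2.0, 3.0], 2.0),
        (vec![1.1, 2.1, 3.1], 1.0),
        (vec![2.0, 2.0, 1.0], 0.5),
        (vec![2.0, 2.3, 1.2], 3.0),
    ];
    rows.into_iter()
        .map(|(features, label)| Sample { features, weight: 1.0, label })
        .collect()
}

fn main() {
    let data = example_samples();
    // Feature 0 takes the values 1.0, 1.1, 2.0, 2.0 across the four samples.
    println!("{:?}", best_split(&data, 0)); // Some((1.1, 3.5))
}
```

Here the threshold 1.1 isolates the first sample (label 2.0) and leaves labels 1.0, 0.5 and 3.0 on the right, for a total impurity of 3.5; the alternative threshold 2.0 scores 3.625. A real implementation would repeat this over every candidate feature and keep the best result.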
The following hyperparameters are supported:
- feature_size: the number of features. Training data and test data must have the same feature_size. (default = 1)
- max_depth: the maximum depth of the decision tree. The root node is in layer 0. (default = 2)
- min_leaf_size: the minimum number of samples required at a leaf node during training. (default = 1)
- loss: the loss function type. SquaredError, LogLikelyhood and LAD are supported. See config::Loss. (default = SquaredError)
- feature_sample_ratio: the fraction of features considered when splitting. When splitting a node, a randomly selected subset of feature_size * feature_sample_ratio features is used to calculate impurity. (default = 1.0)
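The hyperparameters above can be pictured with the following sketch. TreeParams, XorShift64, sample_features, may_split and split_allowed are hypothetical names, not gbdt's API. How gbdt rounds feature_size * feature_sample_ratio and how exactly it enforces min_leaf_size are assumptions here (the product is truncated and clamped to at least one feature; both children of a split must keep at least min_leaf_size samples). A tiny xorshift generator stands in for a real random number generator so the block needs no crates.

```rust
/// Hypothetical mirror of the hyperparameters above, with their documented defaults
/// (loss omitted for brevity). Not a type exported by gbdt.
#[derive(Debug, Clone)]
struct TreeParams {
    feature_size: usize,
    max_depth: u32,
    min_leaf_size: usize,
    feature_sample_ratio: f64,
}

impl Default for TreeParams {
    fn default() -> Self {
        TreeParams {
            feature_size: 1,
            max_depth: 2,
            min_leaf_size: 1,
            feature_sample_ratio: 1.0,
        }
    }
}

/// Minimal xorshift64 generator so the sketch needs no external crates (seed must be nonzero).
struct XorShift64(u64);

impl XorShift64 {
    fn next_u64(&mut self) -> u64 {
        let mut x = self.0;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.0 = x;
        x
    }
}

/// Picks feature_size * feature_sample_ratio distinct feature indices with a partial
/// Fisher-Yates shuffle. Truncating the product and keeping at least one feature are
/// assumptions of this sketch. Requires feature_size >= 1.
fn sample_features(p: &TreeParams, rng: &mut XorShift64) -> Vec<usize> {
    let k = ((p.feature_size as f64 * p.feature_sample_ratio) as usize).clamp(1, p.feature_size);
    let mut idx: Vec<usize> = (0..p.feature_size).collect();
    for i in 0..k {
        // Swap a randomly chosen not-yet-picked index into slot i (modulo bias ignored).
        let j = i + (rng.next_u64() % (p.feature_size - i) as u64) as usize;
        idx.swap(i, j);
    }
    idx.truncate(k);
    idx
}

/// The root is layer 0, so a node may be split only while its layer is below max_depth.
fn may_split(p: &TreeParams, layer: u32) -> bool {
    layer < p.max_depth
}

/// Assumed reading of min_leaf_size: both children of a split must keep at least that many samples.
fn split_allowed(p: &TreeParams, left_len: usize, right_len: usize) -> bool {
    left_len >= p.min_leaf_size && right_len >= p.min_leaf_size
}

fn main() {
    let params = TreeParams {
        feature_size: 10,
        feature_sample_ratio: 0.5,
        ..TreeParams::default()
    };
    let mut rng = XorShift64(0x9E37_79B9_7F4A_7C15);
    // Five distinct indices from 0..10; which ones depends on the seed.
    println!("{:?}", sample_features(&params, &mut rng));
    println!("{} {}", may_split(&params, 0), may_split(&params, 2)); // true false
}
```

With the default max_depth = 2, nodes in layers 0 and 1 can be split, so a tree has at most four leaves in layer 2. Drawing a fresh feature subset at every node, rather than once per tree, is also an assumption based on the wording "when splitting a node".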
Example
use gbdt::config::Loss;
use gbdt::decision_tree::{Data, DecisionTree, TrainingCache};
// set up training data
let data1 = Data::new_training_data(
vec![1.0, 2.0, 3.0],
1.0,
2.0,
None
);
let data2 = Data::new_training_data(
vec![1.1, 2.1, 3.1],
1.0,
1.0,
None
);
let data3 = Data::new_training_data(
vec![2.0, 2.0, 1.0],
1.0,
0.5,
None
);
let data4 = Data::new_training_data(
vec![2.0, 2.3, 1.2],
1.0,
3.0,
None,
);
let mut dv = Vec::new();
dv.push(data1.clone());
dv.push(data2.clone());
dv.push(data3.clone());
dv.push(data4.clone());
// train a decision tree
let mut tree = DecisionTree::new();
tree.set_feature_size(3);
tree.set_max_depth(2);
tree.set_min_leaf_size(1);
tree.set_loss(Loss::SquaredError);
let mut cache = TrainingCache::get_cache(3, &dv, 2);
tree.fit(&dv, &mut cache);
// set up the test data
let mut dv = Vec::new();
dv.push(data1.clone());
dv.push(data2.clone());
dv.push(Data::new_test_data(
vec![2.0, 2.0, 1.0],
None));
dv.push(Data::new_test_data(
vec![2.0, 2.3, 1.2],
Some(3.0)));
// run inference on the test data with the decision tree
println!("{:?}", tree.predict(&dv));
// output:
// [2.0, 0.75, 0.75, 3.0]
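The printed output is consistent with each leaf predicting the weighted mean of the training labels that reach it, which is the constant minimizing weighted squared error: data2 (label 1.0) and data3 (label 0.5) appear to share a leaf, giving (1.0 + 0.5) / 2 = 0.75, while data1 and data4 each return their own label. The hypothetical leaf_value function below only checks that arithmetic; it is a sketch, not gbdt's code.

```rust
/// Weighted mean of the (weight, label) pairs reaching a leaf: the value minimizing
/// weighted squared error, assumed here to be what a SquaredError leaf predicts.
fn leaf_value(samples: &[(f64, f64)]) -> Option<f64> {
    let total_weight: f64 = samples.iter().map(|&(w, _)| w).sum();
    if total_weight == 0.0 {
        return None; // no weighted samples, so no defined leaf value
    }
    Some(samples.iter().map(|&(w, y)| w * y).sum::<f64>() / total_weight)
}

fn main() {
    // (weight, label) of data2 and data3 from the example above.
    println!("{:?}", leaf_value(&[(1.0, 1.0), (1.0, 0.5)])); // Some(0.75)
    // A leaf holding only data1 or only data4 returns that sample's label.
    println!("{:?}", leaf_value(&[(1.0, 2.0)])); // Some(2.0)
    println!("{:?}", leaf_value(&[(1.0, 3.0)])); // Some(3.0)
}
```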
Structs
Data: a training sample or a test sample. Call new_training_data to generate a training sample, and call new_test_data to generate a test sample.