This module implements a decision tree built on top of the simple binary tree in gbdt::binary_tree.

During training, nodes are split according to impurity.
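
For intuition, a split is chosen to maximize the decrease in impurity between the parent node and its two children. Below is a minimal, hypothetical sketch of a variance-style impurity for regression; it illustrates the idea only and is not the crate's actual implementation:

// Hypothetical sketch of variance-based impurity, for intuition only.
fn variance(labels: &[f32]) -> f32 {
    if labels.is_empty() {
        return 0.0;
    }
    let mean = labels.iter().sum::<f32>() / labels.len() as f32;
    labels.iter().map(|y| (y - mean).powi(2)).sum::<f32>() / labels.len() as f32
}

// Impurity decrease when a node's labels are split into `left` and `right`.
// The candidate split with the largest decrease wins.
fn impurity_decrease(left: &[f32], right: &[f32]) -> f32 {
    let parent: Vec<f32> = left.iter().chain(right.iter()).copied().collect();
    let n = parent.len() as f32;
    variance(&parent)
        - (left.len() as f32 / n) * variance(left)
        - (right.len() as f32 / n) * variance(right)
}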

The following hyperparameters are supported (a configuration sketch follows the list):

  1. feature_size: the number of features. Training data and test data must have the same feature_size. (default = 1)

  2. max_depth: the maximum depth of the decision tree. The root node is at layer 0. (default = 2)

  3. min_leaf_size: the minimum number of samples required to be at a leaf node during training. (default = 1)

  4. loss: the loss function type. SquaredError, LogLikelyhood and LAD are supported. See config::Loss. (default = SquaredError)

  5. feature_sample_ratio: the portion of features considered when splitting a node. A subset of the features (feature_size * feature_sample_ratio) is randomly selected to calculate impurity. (default = 1.0)
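
Each hyperparameter has a corresponding setter on DecisionTree. A minimal sketch; most of these setters also appear in the example below, while set_feature_sample_ratio is assumed here and should be checked against the DecisionTree docs for your version:

use gbdt::config::Loss;
use gbdt::decision_tree::DecisionTree;

let mut tree = DecisionTree::new();
tree.set_feature_size(3);           // 1. feature_size
tree.set_max_depth(2);              // 2. max_depth
tree.set_min_leaf_size(1);          // 3. min_leaf_size
tree.set_loss(Loss::SquaredError);  // 4. loss
tree.set_feature_sample_ratio(0.5); // 5. feature_sample_ratio (assumed setter)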

Example

use gbdt::config::Loss;
use gbdt::decision_tree::{Data, DecisionTree, TrainingCache};
// set up training data
let data1 = Data::new_training_data(
    vec![1.0, 2.0, 3.0],
    1.0,
    2.0,
    None
);
let data2 = Data::new_training_data(
    vec![1.1, 2.1, 3.1],
    1.0,
    1.0,
    None
);
let data3 = Data::new_training_data(
    vec![2.0, 2.0, 1.0],
    1.0,
    0.5,
    None
);
let data4 = Data::new_training_data(
    vec![2.0, 2.3, 1.2],
    1.0,
    3.0,
    None,
);

let mut dv = Vec::new();
dv.push(data1.clone());
dv.push(data2.clone());
dv.push(data3.clone());
dv.push(data4.clone());


// train a decision tree
let mut tree = DecisionTree::new();
tree.set_feature_size(3);
tree.set_max_depth(2);
tree.set_min_leaf_size(1);
tree.set_loss(Loss::SquaredError);
let mut cache = TrainingCache::get_cache(3, &dv, 2);
tree.fit(&dv, &mut cache);


// set up the test data
let mut dv = Vec::new();
dv.push(data1.clone());
dv.push(data2.clone());
dv.push(Data::new_test_data(
    vec![2.0, 2.0, 1.0],
    None));
dv.push(Data::new_test_data(
    vec![2.0, 2.3, 1.2],
    Some(3.0)));


// run inference on the test data with the decision tree
println!("{:?}", tree.predict(&dv));


// output:
// [2.0, 0.75, 0.75, 3.0]

With SquaredError loss, each leaf predicts the (weighted) mean label of the training samples that reach it: data1 and data4 get leaves of their own (2.0 and 3.0), while data2 (label 1.0) and data3 (label 0.5) evidently share a leaf, giving (1.0 + 0.5) / 2 = 0.75.

Structs

Data: A training sample or a test sample. Call new_training_data to generate a training sample, and new_test_data to generate a test sample.
DecisionTree: The decision tree.
TrainingCache: Caches the sort results and some intermediate calculation results.
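
A short sketch of the two Data constructors. The argument roles in the comments are inferred from the example above (each training sample there uses weight 1.0) and should be checked against the Data docs:

use gbdt::decision_tree::Data;

// Training sample: feature vector, sample weight, label, optional initial guess.
let train = Data::new_training_data(vec![1.0, 2.0, 3.0], 1.0, 2.0, None);

// Test sample: feature vector plus an optional ground-truth label,
// handy when comparing predictions against known targets.
let test = Data::new_test_data(vec![1.0, 2.0, 3.0], Some(2.0));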

Constants

Type Definitions

DataVec: The vector of the samples.
PredVec: The vector of the predicted values.
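
Assuming these aliases are DataVec and PredVec (as the example's tree.predict(&dv) call suggests), they show up in signatures like the following sketch:

use gbdt::decision_tree::{DataVec, DecisionTree, PredVec};

// Hypothetical helper: predict takes a &DataVec and returns a PredVec.
fn evaluate(tree: &DecisionTree, samples: &DataVec) -> PredVec {
    tree.predict(samples)
}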