XGBoost typically combines a large number of regression trees with a small learning rate. In this setting, trees added early are significant and trees added late are unimportant.
Vinayak and Gilad-Bachrach proposed a new method to add dropout techniques from the deep neural net community to boosted trees, and reported better results in some situations.
This is an introduction to the new tree booster, dart.
Rashmi Korlakai Vinayak and Ran Gilad-Bachrach. “DART: Dropouts meet Multiple Additive Regression Trees.” JMLR Workshop and Conference Proceedings (AISTATS), 2015.
Because of the randomness introduced in the training, expect the following few differences:

- Training can be slower than gbtree because the random dropout prevents usage of the prediction buffer.
- Early stopping might not be stable, due to the randomness.

The booster dart inherits the gbtree booster, so it supports all parameters that gbtree does, such as eta, gamma and max_depth.
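For illustration, a dart configuration can reuse any gbtree parameter alongside the dart-specific ones described below (a minimal sketch; the parameter values are placeholders, not recommendations):

import xgboost as xgb

# dart accepts every gbtree parameter plus its own dropout parameters;
# the values below are illustrative only.
param = {'booster': 'dart',   # switch from the default gbtree booster
         'eta': 0.1,          # gbtree parameter: shrinkage (learning rate)
         'gamma': 0.0,        # gbtree parameter: minimum loss reduction to split
         'max_depth': 6,      # gbtree parameter: maximum tree depth
         'rate_drop': 0.1}    # dart parameter: dropout rate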
Additional parameters are noted below:
- sample_type: type of sampling algorithm.
  - uniform: (default) dropped trees are selected uniformly.
  - weighted: dropped trees are selected in proportion to weight.
- normalize_type: type of normalization algorithm (see the sketch after this list).
  - tree: (default) new trees have the same weight as each of the dropped trees.
  - forest: new trees have the same weight as the sum of the dropped trees (forest).
- rate_drop: dropout rate (the fraction of previous trees to drop).
- skip_drop: probability of skipping the dropout procedure during a boosting iteration.
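The two normalization modes differ in how the new tree and the dropped trees are rescaled when the new tree is committed. The sketch below mirrors the scaling factors described in XGBoost's parameter documentation, where k is the number of dropped trees and lr is the learning rate; dart_scale_factors is a hypothetical helper written for illustration and is not part of the xgboost API:

def dart_scale_factors(k, lr, normalize_type='tree'):
    """Return (factor applied to each dropped tree, weight of the new tree).

    Hypothetical helper for illustration; assumes k >= 1 trees were dropped.
    """
    if normalize_type == 'tree':
        # The new tree gets the same weight as each dropped tree:
        # dropped trees are scaled by k / (k + lr),
        # the new tree is weighted by 1 / (k + lr).
        return k / (k + lr), 1.0 / (k + lr)
    else:  # 'forest'
        # The new tree gets the same weight as the sum of the dropped trees:
        # both are scaled by 1 / (1 + lr).
        return 1.0 / (1.0 + lr), 1.0 / (1.0 + lr)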
import xgboost as xgb
# read in data
dtrain = xgb.DMatrix('demo/data/agaricus.txt.train')
dtest = xgb.DMatrix('demo/data/agaricus.txt.test')
# specify parameters via map
param = {'booster': 'dart',
         'max_depth': 5, 'learning_rate': 0.1,
         'objective': 'binary:logistic', 'silent': True,
         'sample_type': 'uniform',
         'normalize_type': 'tree',
         'rate_drop': 0.1,
         'skip_drop': 0.5}
num_round = 50
bst = xgb.train(param, dtrain, num_round)
# make prediction
# ntree_limit must not be 0
preds = bst.predict(dtest, ntree_limit=num_round)
Note
Specify ntree_limit when predicting with test sets
By default, bst.predict() will perform dropouts on trees. To obtain
correct results on test sets, disable dropouts by specifying
a nonzero value for ntree_limit.
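To make the note concrete (reusing bst, dtest and num_round from the example above): a call without ntree_limit applies dropout at prediction time and may return different values on repeated calls, while a nonzero ntree_limit uses all trees deterministically.

# Dropout is applied at prediction time too, so repeated calls
# without ntree_limit may give different results.
preds_stochastic = bst.predict(dtest)
# A nonzero ntree_limit disables dropout and uses all trees.
preds_deterministic = bst.predict(dtest, ntree_limit=num_round)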