Paper: Optuna: A Next-generation Hyperparameter Optimization Framework – Akiba et al 2019
Hyperparameter Search With Optuna: Part 1 – Scikit-learn Classification and Ensembling
Hyperparameter Search With Optuna: Part 2 – XGBoost Classification and Ensembling
Hyperparameter Search With Optuna: Part 3 – Keras (CNN) Classification and Ensembling
1. Introduction
In addition to using the Tree-structured Parzen Estimator (TPE) algorithm via Optuna to find hyperparameters for XGBoost on the MNIST handwritten digits classification problem, we add Asynchronous Successive Halving, a pruning algorithm, to halt training when preliminary results are unpromising.
2. Asynchronous Successive Halving
Successive Halving is a bandit-based algorithm for identifying the best configuration among many. Optuna's SuccessiveHalvingPruner implements an asynchronous version of Successive Halving; see the Asynchronous Successive Halving paper for a detailed description.
As applied to XGBoost, this means that after a certain number of boosting rounds, if the error metric is not competitive with that of the other trials that have reached the same point (see the references above), training is stopped (pruned) for that set of hyperparameters. This strategy lets Optuna sample a greater number of hyperparameter sets in a given amount of computation time.
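To make the idea concrete, here is a toy sketch (not Optuna's actual implementation) of the decision Asynchronous Successive Halving makes each time a trial reaches a rung: the trial continues only if its intermediate error is within the best 1/reduction_factor fraction of the trials seen so far at that rung, and the decision is made immediately rather than waiting for the remaining trials.

def asha_should_prune(value, rung_values, reduction_factor=4):
    # value:       intermediate error (e.g. eval-mlogloss) of the current trial at this rung
    # rung_values: intermediate errors of all trials that have already reached this rung
    if not rung_values:
        return False  # nothing to compare against yet, let the trial continue
    k = max(1, len(rung_values) // reduction_factor)  # size of the promoted fraction
    cutoff = sorted(rung_values)[k - 1]               # k-th smallest error so far
    return value > cutoff                             # prune if not among the best 1/reduction_factor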
To implement pruning, we make the following changes to the code used in Hyperparameter Search With Optuna: Part 2 – XGBoost Classification and Ensembling.
First, in class Objective(object), we add:
# the callback reports the validation metric to Optuna after every boosting round
# so that unpromising trials can be pruned
prune_error = 'eval-' + dictionary_single_params['eval_metric']  # prune_error = 'eval-mlogloss'
pruning_callback = optuna.integration.XGBoostPruningCallback(trial, prune_error)
xgb_model = xgb.train(params=dictionary_single_params, dtrain=self.dtrain,
                      evals=watchlist, num_boost_round=self.maximum_boosting_rounds,
                      early_stopping_rounds=self.early_stop_rounds,
                      verbose_eval=False, callbacks=[pruning_callback])
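For context, here is a condensed, hypothetical sketch of how that fragment sits inside the objective; the search-space ranges and helper attributes (self.dtrain, self.dvalid, watchlist) are placeholders rather than the exact code from Part 2.

import optuna
import xgboost as xgb

class Objective(object):
    def __init__(self, dtrain, dvalid, maximum_boosting_rounds, early_stop_rounds):
        self.dtrain = dtrain
        self.dvalid = dvalid
        self.maximum_boosting_rounds = maximum_boosting_rounds
        self.early_stop_rounds = early_stop_rounds

    def __call__(self, trial):
        # illustrative search space; the real ranges are defined in Part 2
        dictionary_single_params = {
            'objective': 'multi:softprob', 'num_class': 10, 'eval_metric': 'mlogloss',
            'eta': trial.suggest_uniform('eta', 0.1, 1.0),
            'max_depth': trial.suggest_int('max_depth', 3, 25),
            'colsample_bytree': trial.suggest_uniform('colsample_bytree', 0.5, 1.0),
            'subsample': trial.suggest_uniform('subsample', 0.5, 1.0)}
        watchlist = [(self.dvalid, 'eval')]
        # the callback reports eval-mlogloss to the pruner after every boosting round
        prune_error = 'eval-' + dictionary_single_params['eval_metric']
        pruning_callback = optuna.integration.XGBoostPruningCallback(trial, prune_error)
        xgb_model = xgb.train(params=dictionary_single_params, dtrain=self.dtrain,
                              evals=watchlist, num_boost_round=self.maximum_boosting_rounds,
                              early_stopping_rounds=self.early_stop_rounds,
                              verbose_eval=False, callbacks=[pruning_callback])
        return xgb_model.best_score  # validation mlogloss at the best iteration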
Next, in the main code we have:
study = optuna.create_study(direction=optimizer_direction,
                            sampler=TPESampler(n_startup_trials=number_of_random_points),
                            pruner=optuna.pruners.SuccessiveHalvingPruner(min_resource='auto',
                                                                          reduction_factor=4,
                                                                          min_early_stopping_rate=0))
‘auto’ means that Optuna uses a heuristic to determine how many boosting rounds a trial must complete before pruning is considered; the heuristic is based on the number of steps taken by the first trial that runs to completion. Alternatively, this parameter can be set explicitly by the user. See the Optuna documentation for definitions of the other parameters.
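As a concrete illustration of how these parameters interact, assume the standard Asynchronous Successive Halving schedule in which rung r is reached after min_resource * reduction_factor ** (min_early_stopping_rate + r) reported steps; the value of 9 below is only a hypothetical stand-in for whatever ‘auto’ resolved to in this run.

min_resource = 9              # hypothetical value; 'auto' chooses the real one
reduction_factor = 4
min_early_stopping_rate = 0

for rung in range(4):
    rounds = min_resource * reduction_factor ** (min_early_stopping_rate + rung)
    print('pruning decision for rung', rung, 'after', rounds, 'boosting rounds')
# rung 0 ->   9 boosting rounds
# rung 1 ->  36 boosting rounds
# rung 2 -> 144 boosting rounds
# rung 3 -> 576 boosting rounds

This rung structure is why the results DataFrame in Section 3 has columns system_attrs_completed_rung_0 through system_attrs_completed_rung_3: each holds the intermediate log loss the trial reported when it cleared that rung.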
Finally, in the make_final_predictions() function, the code is identical to the non-pruned version, except that the column labeled ‘state’ in the DataFrame Optuna uses to store trial results now matters. When pruning is enabled, ‘state’ can have a value of COMPLETE or PRUNED, and we only use the saved XGBoost models from trials that ran to completion.
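Below is a minimal sketch of that filtering step, assuming the trials DataFrame shown in Section 3; the function name and threshold are illustrative, not the actual make_final_predictions() code.

def select_completed_trials(study, loss_threshold=0.166, n_keep=25):
    # keep only trials that ran to completion, i.e. were not pruned
    df = study.trials_dataframe()
    df = df[df['state'] == 'COMPLETE']
    # apply the loss threshold and keep the best few by multiclass log loss
    df = df[df['value'] < loss_threshold].sort_values(by='value').head(n_keep)
    return df['number'].tolist()  # trial numbers of the saved XGBoost models to ensemble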
3. Results
Below is the DataFrame from the Optuna study, sorted by the ‘value’ column (the multiclass log loss) and truncated to the 25 best results.
Optuna Results DataFrame
number | value | params_colsample_bytree | params_eta | params_max_bin | params_max_depth | params_reg_alpha | params_reg_lambda | params_subsample | system_attrs__number | system_attrs_completed_rung_0 | system_attrs_completed_rung_1 | system_attrs_completed_rung_2 | system_attrs_completed_rung_3 | state |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
615 | 0.1611 | 0.85 | 0.6607 | 25 | 14 | 0 | 3 | 0.9 | 615 | 0.4575 | 0.2357 | 0.1705 | | COMPLETE |
609 | 0.1616 | 0.85 | 0.6547 | 25 | 16 | 0 | 3 | 0.9 | 609 | 0.4572 | 0.2320 | 0.1690 | | COMPLETE |
278 | 0.1623 | 0.85 | 0.6469 | 25 | 12 | 0 | 3 | 0.9 | 278 | 0.4634 | 0.2386 | 0.1701 | | COMPLETE |
830 | 0.1626 | 0.85 | 0.6838 | 25 | 14 | 0 | 3 | 0.9 | 830 | 0.4495 | 0.2323 | 0.1710 | | COMPLETE |
1360 | 0.1627 | 0.85 | 0.6812 | 25 | 14 | 0 | 3 | 0.9 | 1360 | 0.4506 | 0.2353 | 0.1713 | | COMPLETE |
819 | 0.1631 | 0.85 | 0.6687 | 25 | 14 | 0 | 3 | 0.9 | 819 | 0.4539 | 0.2346 | 0.1705 | | COMPLETE |
789 | 0.1634 | 0.85 | 0.6811 | 25 | 14 | 0 | 3 | 0.9 | 789 | 0.4505 | 0.2354 | 0.1726 | 0.1640 | COMPLETE |
998 | 0.1634 | 0.85 | 0.6767 | 25 | 14 | 0 | 3 | 0.9 | 998 | 0.4512 | 0.2359 | 0.1730 | | COMPLETE |
309 | 0.1635 | 0.85 | 0.6787 | 25 | 11 | 0 | 3 | 0.9 | 309 | 0.4496 | 0.2332 | 0.1711 | | COMPLETE |
365 | 0.1639 | 0.85 | 0.6768 | 25 | 11 | 0 | 3 | 0.9 | 365 | 0.4504 | 0.2350 | 0.1716 | | COMPLETE |
835 | 0.1639 | 0.85 | 0.6823 | 25 | 14 | 0 | 3 | 0.9 | 835 | 0.4501 | 0.2348 | 0.1732 | | COMPLETE |
745 | 0.1640 | 0.85 | 0.6641 | 25 | 15 | 0 | 3 | 0.9 | 745 | 0.4545 | 0.2366 | 0.1738 | | COMPLETE |
596 | 0.1640 | 0.85 | 0.6730 | 25 | 14 | 0 | 3 | 0.9 | 596 | 0.4532 | 0.2357 | 0.1722 | | COMPLETE |
1261 | 0.1641 | 0.9 | 0.6837 | 25 | 14 | 0 | 3 | 0.9 | 1261 | 0.4509 | 0.2344 | 0.1730 | | COMPLETE |
277 | 0.1641 | 0.85 | 0.6513 | 25 | 12 | 0 | 3 | 0.9 | 277 | 0.4614 | 0.2380 | 0.1722 | | COMPLETE |
1098 | 0.1642 | 0.85 | 0.6788 | 25 | 14 | 0 | 3 | 0.9 | 1098 | 0.4506 | 0.2360 | 0.1712 | | COMPLETE |
623 | 0.1642 | 0.85 | 0.6761 | 25 | 14 | 0 | 3 | 0.9 | 623 | 0.4527 | 0.2367 | 0.1724 | | COMPLETE |
1149 | 0.1642 | 0.85 | 0.6935 | 25 | 11 | 0 | 3 | 0.9 | 1149 | 0.4434 | 0.2323 | 0.1734 | | COMPLETE |
656 | 0.1642 | 0.85 | 0.6672 | 25 | 20 | 0 | 3 | 0.9 | 656 | 0.4534 | 0.2337 | 0.1709 | | COMPLETE |
165 | 0.1643 | 0.85 | 0.6910 | 25 | 10 | 0 | 3 | 0.9 | 165 | 0.4444 | 0.2335 | 0.1729 | | COMPLETE |
276 | 0.1643 | 0.85 | 0.6590 | 25 | 12 | 0 | 3 | 0.9 | 276 | 0.4588 | 0.2379 | 0.1748 | | COMPLETE |
625 | 0.1643 | 0.85 | 0.6826 | 25 | 14 | 0 | 3 | 0.9 | 625 | 0.4499 | 0.2346 | 0.1717 | | COMPLETE |
593 | 0.1646 | 0.85 | 0.6769 | 25 | 13 | 0 | 3 | 0.9 | 593 | 0.4505 | 0.2374 | 0.1743 | | COMPLETE |
361 | 0.1646 | 0.55 | 0.7114 | 25 | 11 | 0 | 3 | 0.9 | 361 | 0.4459 | 0.2360 | 0.1746 | | COMPLETE |
460 | 0.1647 | 0.85 | 0.6690 | 25 | 13 | 0 | 3 | 0.9 | 460 | 0.4536 | 0.2359 | 0.1726 | | COMPLETE |
To create the final result, we applied a loss threshold of 0.166 and used only the 25 best models from trials that ran to completion. We then averaged their predicted class probabilities and used plurality voting to obtain the final class predictions.
balanced accuracy score = 0.9551
accuracy score = 0.9553
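As a rough sketch of the ensembling step described above, the helper below is hypothetical (not the code from make_final_predictions()); it takes one (n_samples, 10) class-probability array per kept model.

import numpy as np

def ensemble_predictions(list_of_probability_arrays, num_classes=10):
    # average the per-model class-probability matrices
    mean_probabilities = np.mean(np.stack(list_of_probability_arrays), axis=0)
    # plurality vote over each model's hard (argmax) predictions
    per_model_preds = np.stack([p.argmax(axis=1) for p in list_of_probability_arrays])
    voted_classes = np.apply_along_axis(
        lambda column: np.bincount(column, minlength=num_classes).argmax(),
        axis=0, arr=per_model_preds)
    return mean_probabilities, voted_classes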
4. Code
See Hyperparameter Search With Optuna: Part 2 – XGBoost Classification and Ensembling and the changes mentioned above.