Paper: Optuna: A Next-generation Hyperparameter Optimization Framework – Akiba et al., 2019
Hyperparameter Search With Optuna: Part 1 – Scikit-learn Classification and Ensembling
Hyperparameter Search With Optuna: Part 2 – XGBoost Classification and Ensembling
Hyperparameter Search With Optuna: Part 3 – Keras (CNN) Classification and Ensembling
Hyperparameter Search (And Pruning) With Optuna: Part 4 – Keras (CNN) Classification and Ensembling
1. Introduction
In addition to using the Tree-structured Parzen Estimator algorithm via Optuna to find hyperparameters for a Keras CNN on the MNIST handwritten digits classification problem, we add asynchronous successive halving, a pruning algorithm, to halt training when preliminary results are unpromising.
2. Asynchronous Successive Halving
Successive Halving is a bandit-based algorithm for identifying the best among multiple configurations. Optuna's SuccessiveHalvingPruner implements an asynchronous version of it; see the Asynchronous Successive Halving paper for a detailed description.
As applied to a CNN, this means that after a certain number of epochs, if the error metric does not meet a threshold (see the references above), training is stopped (pruned) for that set of hyperparameters. This strategy allows Optuna to sample a greater number of sets of hyperparameters in a given amount of computation time.
To implement pruning, we make the following changes to the code used in Hyperparameter Search With Optuna: Part 3 – Keras (CNN) Classification and Ensembling.
First, we import:
from optuna.integration import KerasPruningCallback
Next, in class Objective(object), we add a new callback to the callbacks list:
callbacks_list = [
    EarlyStopping(monitor='val_loss', patience=self.early_stop),
    ReduceLROnPlateau(monitor='val_loss', factor=0.1,
                      patience=self.learn_rate_epochs,
                      verbose=0, mode='auto', min_lr=1.0e-6),
    ModelCheckpoint(filepath=fn, monitor='val_loss',
                    save_best_only=True),
    KerasPruningCallback(trial, 'val_loss')]
Finally, in the main code we have:
study = optuna.create_study(
    direction=optimizer_direction,
    sampler=TPESampler(n_startup_trials=number_of_random_points),
    pruner=optuna.pruners.SuccessiveHalvingPruner(
        min_resource='auto', reduction_factor=4,
        min_early_stopping_rate=0))
‘auto’ means that Optuna uses a heuristic, based on the first trials that run to completion, to determine how many training epochs a trial must report before it becomes eligible for pruning. Alternatively, this parameter can be set explicitly by the user. See the Optuna documentation for definitions of the other parameters.
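For intuition about when pruning decisions occur: in the ASHA schedule, a trial's intermediate value is compared against other trials at budgets of min_resource * reduction_factor ** (min_early_stopping_rate + rung). The rung_budgets helper below is our own illustration of that formula, not part of Optuna:

```python
def rung_budgets(min_resource, reduction_factor=4,
                 min_early_stopping_rate=0, n_rungs=3):
    """Budgets (epochs here) at which a trial is compared against
    others under the ASHA rung schedule:
    min_resource * reduction_factor ** (min_early_stopping_rate + rung)."""
    return [min_resource * reduction_factor ** (min_early_stopping_rate + rung)
            for rung in range(n_rungs)]

# With min_resource=1 and reduction_factor=4, the first three rungs
# fall at epochs 1, 4, and 16.
print(rung_budgets(1))  # [1, 4, 16]
```

This is why the results table below has columns for completed_rung_0 through completed_rung_2: they record the validation loss each trial reported at those comparison points.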
Finally, in the make_final_predictions() function, the code is identical to the non-pruned version, except that we now make use of the ‘state’ column of the DataFrame in which Optuna stores trial results. When pruning is enabled, ‘state’ can have a value of COMPLETE or PRUNED, and we use only those saved CNN models whose trials ran to completion.
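A minimal sketch of that filtering step. The toy DataFrame below stands in for the one returned by Optuna's study.trials_dataframe(); the column names follow the results table in the next section, but the values are invented for illustration:

```python
import pandas as pd

# Illustrative stand-in for study.trials_dataframe().
trials = pd.DataFrame({
    'number': [0, 1, 2, 3],
    'value': [0.029, 0.046, 0.035, 0.028],
    'state': ['COMPLETE', 'PRUNED', 'PRUNED', 'COMPLETE'],
})

# Keep only trials that ran to completion, then sort by
# validation loss (the 'value' column).
complete = (trials[trials['state'] == 'COMPLETE']
            .sort_values('value')
            .reset_index(drop=True))
print(complete['number'].tolist())  # [3, 0]
```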
3. Results
Below is the DataFrame from the Optuna study, sorted by the ‘value’ column (the validation loss) and truncated to the 25 best results.
Optuna Pruned CNN Results DataFrame
number | value | params_batch_size | params_dense_nodes_divisor | params_drop_out | params_kernel_size | params_num_cnn_blocks | params_num_dense_nodes | params_num_filters | system_attrs__number | system_attrs_completed_rung_0 | system_attrs_completed_rung_1 | system_attrs_completed_rung_2 | state |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|
79 | 0.0275506863 | 32 | 4 | 0.05 | 3 | 4 | 512 | 64 | 79 | 0.0458960985 | 0.0410975967 | 0.032053134 | COMPLETE |
51 | 0.0287339926 | 32 | 4 | 0.05 | 2 | 4 | 512 | 64 | 51 | 0.0529623508 | 0.0386455245 | 0.0294012611 | COMPLETE |
0 | 0.0289913212 | 128 | 4 | 0.15 | 2 | 4 | 128 | 48 | 0 | | | | COMPLETE |
90 | 0.0291640001 | 32 | 4 | 0.05 | 3 | 4 | 512 | 48 | 90 | 0.0561855177 | 0.0393893917 | 0.0307672351 | COMPLETE |
91 | 0.0293711314 | 32 | 4 | 0.05 | 3 | 4 | 512 | 48 | 91 | 0.0489851257 | 0.0369895449 | 0.0297974418 | COMPLETE |
92 | 0.0301389344 | 32 | 4 | 0.05 | 3 | 4 | 512 | 48 | 92 | 0.0488571736 | 0.0400173946 | 0.0301389344 | COMPLETE |
55 | 0.0303010894 | 32 | 4 | 0.05 | 2 | 4 | 512 | 64 | 55 | 0.062571248 | 0.0446030207 | 0.033166684 | COMPLETE |
56 | 0.0322642285 | 32 | 4 | 0.05 | 2 | 4 | 512 | 64 | 56 | 0.0550761841 | 0.0418971828 | 0.03356296 | COMPLETE |
88 | 0.0332499588 | 32 | 4 | 0.05 | 3 | 4 | 512 | 64 | 88 | 0.0566676402 | 0.0436420241 | 0.0332499588 | COMPLETE |
48 | 0.0337947764 | 32 | 4 | 0.05 | 2 | 4 | 512 | 64 | 48 | 0.0641317141 | 0.0439486502 | 0.0368314922 | COMPLETE |
78 | 0.0346027303 | 32 | 4 | 0.05 | 2 | 4 | 512 | 64 | 78 | 0.0548431493 | 0.0360887715 | 0.0346027303 | PRUNED |
54 | 0.0346579468 | 32 | 4 | 0.05 | 2 | 4 | 512 | 64 | 54 | 0.0561321469 | 0.0453210574 | 0.0363167872 | COMPLETE |
80 | 0.0362471782 | 32 | 4 | 0.05 | 3 | 4 | 512 | 64 | 80 | 0.0526494365 | 0.0376668779 | 0.0362471782 | PRUNED |
27 | 0.0371584418 | 32 | 4 | 0.05 | 2 | 4 | 1024 | 64 | 27 | 0.0675621782 | 0.0418939642 | 0.0379544506 | COMPLETE |
4 | 0.0374717352 | 32 | 2 | 0.3 | 3 | 3 | 512 | 64 | 4 | 0.0636372029 | 0.0423244475 | 0.0381203787 | COMPLETE |
72 | 0.0379917268 | 32 | 4 | 0.05 | 3 | 4 | 512 | 64 | 72 | 0.0455215652 | 0.0407943215 | 0.0379917268 | PRUNED |
2 | 0.0420262766 | 96 | 4 | 0.15 | 2 | 3 | 1024 | 32 | 2 | 0.098249785 | 0.0520678783 | 0.0511838313 | COMPLETE |
43 | 0.0420647348 | 32 | 4 | 0.15 | 2 | 3 | 1024 | 48 | 43 | 0.067462517 | 0.0459376596 | 0.0420647348 | PRUNED |
57 | 0.0429445988 | 32 | 4 | 0.05 | 2 | 4 | 512 | 64 | 57 | 0.0606231133 | 0.0436749889 | 0.0429445988 | PRUNED |
35 | 0.043128442 | 32 | 4 | 0.2 | 2 | 3 | 512 | 64 | 35 | 0.0688032493 | 0.0447511463 | 0.043128442 | PRUNED |
84 | 0.0442503235 | 32 | 2 | 0.05 | 3 | 4 | 64 | 64 | 84 | 0.0544564733 | 0.0427888586 | 0.0442503235 | PRUNED |
83 | 0.0445057592 | 32 | 4 | 0.05 | 3 | 4 | 512 | 64 | 83 | 0.0572899426 | 0.0445057592 | | PRUNED |
58 | 0.0460092295 | 32 | 4 | 0.05 | 2 | 4 | 512 | 64 | 58 | 0.0624135739 | 0.0460092295 | | PRUNED |
82 | 0.0462410353 | 32 | 4 | 0.05 | 3 | 4 | 512 | 64 | 82 | 0.050994581 | 0.0462410353 | | PRUNED |
59 | 0.0467597407 | 32 | 4 | 0.05 | 2 | 4 | 512 | 64 | 59 | 0.0487022117 | 0.0467597407 | | PRUNED |
To create the final result, we set a minimum validation-loss threshold of 0.04, so that only the 13 best models whose trials ran to completion were used in the ensemble. We then averaged the resulting class probabilities and used plurality voting to obtain the final class predictions.
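As a sketch of the ensembling step, assuming the soft-voting reading of the above (average each model's class-probability matrix, then take the per-sample argmax); the ensemble_predict name and the toy probability matrices are ours, for illustration:

```python
import numpy as np

def ensemble_predict(prob_list):
    """Soft-voting ensemble: average the class-probability matrices
    (rows = samples, columns = classes) from several models, then
    pick the class with the highest average probability per sample."""
    mean_probs = np.mean(prob_list, axis=0)
    return np.argmax(mean_probs, axis=1)

# Two toy models, three samples, three classes.
model_a = np.array([[0.7, 0.2, 0.1], [0.1, 0.8, 0.1], [0.3, 0.3, 0.4]])
model_b = np.array([[0.6, 0.3, 0.1], [0.2, 0.2, 0.6], [0.1, 0.2, 0.7]])
print(ensemble_predict([model_a, model_b]).tolist())  # [0, 1, 2]
```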
balanced accuracy score = 0.99403
accuracy score = 0.9941
[NOTE: For some unknown reason, the confusion matrix images were cut off at the top and bottom. Below each image we have included the cut-off values.]
Top left = 978, bottom right = 996.
Top left = 0.9939, bottom right = 0.9950.
Top left = 0.9980, bottom right = 0.9871.
4. Code
See Hyperparameter Search With Optuna: Part 3 – Keras (CNN) Classification and Ensembling and the changes mentioned above.