Paper: Optuna: A Next-generation Hyperparameter Optimization Framework – Akiba et al., 2019
Hyperparameter Search With Optuna: Part 1 – Scikit-learn Classification and Ensembling
Hyperparameter Search With Optuna: Part 2 – XGBoost Classification and Ensembling
Hyperparameter Search With Optuna: Part 3 – Keras (CNN) Classification and Ensembling
Hyperparameter Search (And Pruning) With Optuna: Part 4 – XGBoost Classification and Ensembling

  1. Introduction
  2. Asynchronous Successive Halving
  3. Results
  4. Code

1. Introduction

In addition to using the Tree-structured Parzen Estimator (TPE) algorithm via Optuna to find hyperparameters for a Keras CNN on the MNIST handwritten digits classification problem, we add Asynchronous Successive Halving, a pruning algorithm that halts training when preliminary results are unpromising.

2. Asynchronous Successive Halving

Successive Halving is a bandit-based algorithm for identifying the best among multiple configurations. Optuna's SuccessiveHalvingPruner implements an asynchronous version of Successive Halving; see the Asynchronous Successive Halving (ASHA) paper for a detailed description.

As applied to a CNN, this means that after a certain number of epochs, if the error metric does not meet a threshold (see the references above), training for that set of hyperparameters is halted (pruned). This strategy allows Optuna to evaluate a greater number of hyperparameter sets in a given amount of computation time.
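To make the mechanics concrete, here is a minimal sketch of the successive-halving promotion rule (illustrative only, not Optuna's actual implementation; the function name and arguments are hypothetical). The asynchronous variant applies this rule as each result arrives, rather than waiting for all trials at a rung to finish:

def should_continue(intermediate_loss, losses_at_rung, reduction_factor=4):
    # a trial survives a rung only if its intermediate loss lies in the
    # best 1/reduction_factor fraction of all losses recorded at that rung
    losses = sorted(losses_at_rung + [intermediate_loss])
    n_promoted = max(1, len(losses) // reduction_factor)
    return intermediate_loss <= losses[n_promoted - 1]

With reduction_factor=4, roughly a quarter of the trials at each rung are allowed to continue training.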

To implement pruning, we make the following changes to the code used in Hyperparameter Search With Optuna: Part 3 – Keras (CNN) Classification and Ensembling.

First, we import:

from optuna.integration import KerasPruningCallback

Next, in class Objective(object), we add a new callback to the callbacks list:

callbacks_list = [EarlyStopping(monitor='val_loss', patience=self.early_stop),
                  ReduceLROnPlateau(monitor='val_loss', factor=0.1, 
                                    patience=self.learn_rate_epochs, 
                                    verbose=0, mode='auto', min_lr=1.0e-6),
                  ModelCheckpoint(filepath=fn, monitor='val_loss', 
                                  save_best_only=True),
                  # new: reports val_loss to the trial after each epoch and
                  # stops training when the pruner decides to prune
                  KerasPruningCallback(trial, 'val_loss')]
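Roughly speaking, the callback reports the monitored metric to the trial at the end of every epoch and raises an exception when the pruner decides the trial is unpromising. A minimal sketch of the equivalent behavior (illustrative; the real implementation lives in optuna.integration):

import optuna
from keras.callbacks import Callback

class PruningCallbackSketch(Callback):
    # illustrative stand-in for optuna.integration.KerasPruningCallback
    def __init__(self, trial, monitor='val_loss'):
        super(PruningCallbackSketch, self).__init__()
        self.trial = trial
        self.monitor = monitor

    def on_epoch_end(self, epoch, logs=None):
        current = (logs or {}).get(self.monitor)
        if current is None:
            return
        # hand the intermediate value to the pruner ...
        self.trial.report(float(current), step=epoch)
        # ... and end training early if the pruner says to stop
        if self.trial.should_prune():
            raise optuna.TrialPruned('pruned at epoch {}'.format(epoch))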

Finally, in the main code we have:

study = optuna.create_study(direction=optimizer_direction,
                            sampler=TPESampler(n_startup_trials=number_of_random_points),
                            pruner=optuna.pruners.SuccessiveHalvingPruner(
                                min_resource='auto', reduction_factor=4,
                                min_early_stopping_rate=0))

‘auto’ means that Optuna uses a heuristic to determine how many training epochs a trial must complete before it becomes a candidate for pruning; the heuristic is based on the number of steps required by the first trial that runs to completion. Alternatively, this parameter can be set explicitly by the user. See the Optuna documentation for definitions of the other parameters.
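With the pruner attached, the study is run exactly as before; a minimal sketch, where objective is an instance of the Objective class and the trial count is illustrative:

# pruned trials stop early, so more hyperparameter sets fit in the same budget
study.optimize(objective, n_trials=100)

With reduction_factor=4 and min_early_stopping_rate=0, the rungs fall at min_resource, 4*min_resource, 16*min_resource, ... epochs, and the system_attrs_completed_rung_0/1/2 columns in the results below record the validation loss at the first three rungs.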

Finally, in the make_final_predictions() function, the code is identical to the non-pruned version, but the DataFrame in which Optuna stores trial results now has a ‘state’ column that we actually use. When pruning is enabled, ‘state’ can have a value of COMPLETE or PRUNED; we only ensemble those saved CNN models whose trials ran to completion.
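For example, the completed trials can be separated from the pruned ones as follows (a sketch; df_complete is a name introduced here, while the column names are those produced by study.trials_dataframe()):

df = study.trials_dataframe()

# keep only trials that ran to completion; halted trials have state 'PRUNED'
df_complete = df[df['state'] == 'COMPLETE']

# sort by validation loss so that the best models come first
df_complete = df_complete.sort_values(by='value', ascending=True)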

3. Results

Below is the DataFrame from the Optuna study, sorted by the ‘value’ column (the validation loss) and truncated to the 25 best results.

Optuna Pruned CNN Results DataFrame. Column abbreviations: batch = params_batch_size, div = params_dense_nodes_divisor, drop = params_drop_out, kern = params_kernel_size, blocks = params_num_cnn_blocks, nodes = params_num_dense_nodes, filt = params_num_filters. rung_0/1/2 are the system_attrs_completed_rung_0/1/2 columns, the intermediate validation losses recorded at each pruning rung; a '-' marks a rung the trial never reached.

number  value         batch  div  drop  kern  blocks  nodes  filt  rung_0        rung_1        rung_2        state
79      0.0275506863  32     4    0.05  3     4       512    64    0.0458960985  0.0410975967  0.032053134   COMPLETE
51      0.0287339926  32     4    0.05  2     4       512    64    0.0529623508  0.0386455245  0.0294012611  COMPLETE
0       0.0289913212  128    4    0.15  2     4       128    48    -             -             -             COMPLETE
90      0.0291640001  32     4    0.05  3     4       512    48    0.0561855177  0.0393893917  0.0307672351  COMPLETE
91      0.0293711314  32     4    0.05  3     4       512    48    0.0489851257  0.0369895449  0.0297974418  COMPLETE
92      0.0301389344  32     4    0.05  3     4       512    48    0.0488571736  0.0400173946  0.0301389344  COMPLETE
55      0.0303010894  32     4    0.05  2     4       512    64    0.062571248   0.0446030207  0.033166684   COMPLETE
56      0.0322642285  32     4    0.05  2     4       512    64    0.0550761841  0.0418971828  0.03356296    COMPLETE
88      0.0332499588  32     4    0.05  3     4       512    64    0.0566676402  0.0436420241  0.0332499588  COMPLETE
48      0.0337947764  32     4    0.05  2     4       512    64    0.0641317141  0.0439486502  0.0368314922  COMPLETE
78      0.0346027303  32     4    0.05  2     4       512    64    0.0548431493  0.0360887715  0.0346027303  PRUNED
54      0.0346579468  32     4    0.05  2     4       512    64    0.0561321469  0.0453210574  0.0363167872  COMPLETE
80      0.0362471782  32     4    0.05  3     4       512    64    0.0526494365  0.0376668779  0.0362471782  PRUNED
27      0.0371584418  32     4    0.05  2     4       1024   64    0.0675621782  0.0418939642  0.0379544506  COMPLETE
4       0.0374717352  32     2    0.3   3     3       512    64    0.0636372029  0.0423244475  0.0381203787  COMPLETE
72      0.0379917268  32     4    0.05  3     4       512    64    0.0455215652  0.0407943215  0.0379917268  PRUNED
2       0.0420262766  96     4    0.15  2     3       1024   32    0.098249785   0.0520678783  0.0511838313  COMPLETE
43      0.0420647348  32     4    0.15  2     3       1024   48    0.067462517   0.0459376596  0.0420647348  PRUNED
57      0.0429445988  32     4    0.05  2     4       512    64    0.0606231133  0.0436749889  0.0429445988  PRUNED
35      0.043128442   32     4    0.2   2     3       512    64    0.0688032493  0.0447511463  0.043128442   PRUNED
84      0.0442503235  32     2    0.05  3     4       64     64    0.0544564733  0.0427888586  0.0442503235  PRUNED
83      0.0445057592  32     4    0.05  3     4       512    64    0.0572899426  0.0445057592  -             PRUNED
58      0.0460092295  32     4    0.05  2     4       512    64    0.0624135739  0.0460092295  -             PRUNED
82      0.0462410353  32     4    0.05  3     4       512    64    0.050994581   0.0462410353  -             PRUNED
59      0.0467597407  32     4    0.05  2     4       512    64    0.0487022117  0.0467597407  -             PRUNED

To create the final result, we set a loss threshold of 0.04, so that only the 13 best models that ran to completion were used in the ensemble. We then averaged the resulting class probabilities and used plurality voting to obtain the final class predictions.
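A minimal sketch of this ensembling step, assuming the models were checkpointed to disk during the search (the file-name pattern, x_test, and y_test are illustrative; df_complete is from the sketch above):

import numpy as np
from keras.models import load_model
from sklearn.metrics import accuracy_score, balanced_accuracy_score

# completed trials below the loss threshold (13 models in this run)
selected = df_complete[df_complete['value'] < 0.04]['number']

# hypothetical checkpoint naming; one saved model per completed trial
models = [load_model('cnn_{}.h5'.format(n)) for n in selected]

# average the per-model class probabilities ...
mean_probabilities = np.mean([m.predict(x_test) for m in models], axis=0)

# ... and vote: the class with the largest average probability wins
y_pred = np.argmax(mean_probabilities, axis=1)

print('balanced accuracy score =', balanced_accuracy_score(y_test, y_pred))
print('accuracy score =', accuracy_score(y_test, y_pred))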

balanced accuracy score = 0.99403
accuracy score = 0.9941

[NOTE: For some unknown reason, the confusion matrix images were cut off at the top and bottom. Below each image we have included the cut-off values.]

Top left = 978, bottom right = 996.

Top left = 0.9939, bottom right = 0.9950.

Top left = 0.9980, bottom right = 0.9871.

4. Code

See Hyperparameter Search With Optuna: Part 3 – Keras (CNN) Classification and Ensembling and the changes mentioned above.