Hyperparameter optimization in neural networks is generally done heuristically, by varying individual parameters such as the learning rate, the batch size and the number of training steps. Sklearn automates this with GridSearchCV [1], which tries every combination of values in a parameter grid and scores each one by cross-validation.
Usually Sklearn's examples and documentation are spot on, and copy/pasting an example works with minimal changes. However, this wasn't quite the case when using skflow and sklearn together.
Below is an example of a TensorFlow neural network implemented in Skflow, with its hyperparameters optimized by sklearn's GridSearchCV:
import skflow
from sklearn.grid_search import GridSearchCV

# conv_model, outer_name, X_train and y_train are defined elsewhere
classifier = skflow.TensorFlowEstimator(model_fn=conv_model, n_classes=outer_name+1,
                                        batch_size=10)
# use a full grid over all parameters
param_grid = {"steps": [1000, 1500, 2000, 2500, 3000],
              "learning_rate": [0.01, 0.03, 0.05, 0.08],
              "batch_size": [8, 10, 12]}
# run grid search; scoring must be given explicitly for the Skflow estimator
grid_search = GridSearchCV(classifier, param_grid=param_grid, scoring='accuracy',
                           verbose=10, n_jobs=-1, cv=2)
grid_search.fit(X_train, y_train)
print(grid_search)
# summarize the results of the grid search
print(grid_search.best_score_)
print(grid_search.best_params_)
In the GridSearchCV call, note that, in contrast to the examples we normally run across, the scoring method needs to be specified manually (scoring='accuracy'), since Skflow doesn't provide one intrinsically. Also note that the n_jobs parameter is set to -1, which runs N jobs in parallel, with N being the number of processors on your host. Last, note that the cv parameter is set to 2, meaning that the cross-validation GridSearchCV uses to judge accuracy runs with 2 folds rather than the default 3. The right number of folds obviously varies on a case-by-case basis.
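To get a feel for how much work this grid implies, and why n_jobs and cv matter, here is a minimal sketch that counts the training runs using only the standard library. It reuses the parameter grid and the cv=2 setting from the example; the names cv_folds, candidates, n_candidates and n_fits are made up for this illustration and are not part of sklearn or Skflow.

```python
from itertools import product

# Same grid as in the example above.
param_grid = {
    "steps": [1000, 1500, 2000, 2500, 3000],
    "learning_rate": [0.01, 0.03, 0.05, 0.08],
    "batch_size": [8, 10, 12],
}
cv_folds = 2  # matches cv=2 in the GridSearchCV call

# Build every combination of values, one dict per candidate setting.
names = sorted(param_grid)
candidates = [
    dict(zip(names, values))
    for values in product(*(param_grid[name] for name in names))
]

n_candidates = len(candidates)
n_fits = n_candidates * cv_folds  # one training run per candidate per fold

print(n_candidates, "candidates,", n_fits, "cross-validation fits")
```

With 5 x 4 x 3 = 60 candidates and 2 folds, that is 120 training runs, plus one final refit of the best candidate on the full training set, which GridSearchCV performs by default. Raising cv to 3 would add another 60 runs, which is why the number of folds and the number of parallel jobs are worth choosing deliberately.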
References
[1] http://scikit-learn.org/stable/modules/generated/sklearn.grid_search.GridSearchCV.html