Optimization using Optuna¶
This article walks through an example of finding the best hyperparameters for training a random forest classifier with the Optuna library.
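Optuna ships separately from scikit-learn and Seaborn; assuming a standard pip-based environment (and a Jupyter notebook, as used here), the dependencies can be installed with a cell like the following:
In [ ]:
# Install the libraries used in this notebook (assumes pip is available)
!pip install optuna seaborn scikit-learn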
In [10]:
import optuna
import seaborn as sns
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
Load and split the data¶
In [4]:
# Load the iris dataset from Seaborn
iris = sns.load_dataset('iris')
# Split the dataset into features and labels
X = iris.drop('species', axis=1)
y = iris['species']
# Split the dataset into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
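Note that the split above changes on every run. As a sketch, fixing random_state (the value 42 here is an arbitrary choice) and stratifying by the label makes the experiment reproducible and keeps class proportions balanced:
In [ ]:
# Reproducible, class-balanced split (random_state value is arbitrary)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=42)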
Optuna objective function¶
In [5]:
def objective(trial):
    """
    Objective function for Optuna.
    """
    # Define the range of hyperparameters to search over
    n_estimators = trial.suggest_int('n_estimators', 10, 100)
    max_depth = trial.suggest_int('max_depth', 2, 10)
    min_samples_split = trial.suggest_int('min_samples_split', 2, 10)
    min_samples_leaf = trial.suggest_int('min_samples_leaf', 1, 10)

    # Create a random forest classifier with the specified hyperparameters
    clf = RandomForestClassifier(n_estimators=n_estimators,
                                 max_depth=max_depth,
                                 min_samples_split=min_samples_split,
                                 min_samples_leaf=min_samples_leaf)

    # Fit the classifier to the training data
    clf.fit(X_train, y_train)

    # Evaluate the classifier on the testing data and report the score back to Optuna
    score = clf.score(X_test, y_test)
    return score
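Since the score above comes from a single train/test split, it can be noisy for a dataset as small as iris. A possible variant (objective_cv is a hypothetical name, not part of the code above) averages a 5-fold cross-validated accuracy on the training set instead:
In [ ]:
from sklearn.model_selection import cross_val_score

def objective_cv(trial):
    """Alternative objective: mean 5-fold cross-validated accuracy."""
    params = {
        'n_estimators': trial.suggest_int('n_estimators', 10, 100),
        'max_depth': trial.suggest_int('max_depth', 2, 10),
        'min_samples_split': trial.suggest_int('min_samples_split', 2, 10),
        'min_samples_leaf': trial.suggest_int('min_samples_leaf', 1, 10),
    }
    clf = RandomForestClassifier(**params)
    # Averaging over folds is less sensitive to one lucky or unlucky split
    return cross_val_score(clf, X_train, y_train, cv=5).mean()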
Run study example¶
In [ ]:
# Define the Optuna study
study = optuna.create_study(direction='maximize')
# Run the study
study.optimize(objective, n_trials=100)
# Print the best hyperparameters found by Optuna
print(f"Best hyperparameters: {study.best_params}")
# Create a final classifier using the best hyperparameters
best_clf = RandomForestClassifier(**study.best_params)
# Fit the final classifier to the training data
best_clf.fit(X_train, y_train)
# Evaluate the final classifier on the testing data
final_score = best_clf.score(X_test, y_test)
print(f"Test accuracy of the final model: {final_score:.3f}")
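Optuna records every trial in the study object. As a sketch, the search history can be inspected as a pandas DataFrame, or plotted with the built-in visualization module (which requires plotly to be installed):
In [ ]:
# One row per trial; hyperparameters appear as 'params_*' columns
trials_df = study.trials_dataframe()
print(trials_df[['number', 'value', 'params_n_estimators', 'params_max_depth']].head())

# Interactive plot of the best score found so far versus trial number
optuna.visualization.plot_optimization_history(study).show()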
Print best hyperparameters¶
In [9]:
print(f"Best hyperparameters: {study.best_params}")
Best hyperparameters: {'n_estimators': 94, 'max_depth': 9, 'min_samples_split': 9, 'min_samples_leaf': 1}
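Besides best_params, the study also exposes the best score and the corresponding trial, for example:
In [ ]:
print(f"Best accuracy during the search: {study.best_value:.3f}")
print(f"Found at trial number: {study.best_trial.number}")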