Dieses Beispiel vergleicht die Parametersuche von HalvingGridSearchCV und GridSearchCV.

from time import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from sklearn.svm import SVC
from sklearn import datasets
from sklearn.model_selection import GridSearchCV
from sklearn.experimental import enable_halving_search_cv  # noqa
from sklearn.model_selection import HalvingGridSearchCV


print(__doc__)

Wir definieren zunächst den Parameterraum für an SVC Schätzer und berechnen Sie die Zeit, die zum Trainieren von a . erforderlich ist HalvingGridSearchCV Instanz sowie als GridSearchCV Beispiel.

rng = np.random.RandomState(0)
X, y = datasets.make_classification(n_samples=1000, random_state=rng)

gammas = [1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7]
Cs = [1, 10, 100, 1e3, 1e4, 1e5]
param_grid = {'gamma': gammas, 'C': Cs}

clf = SVC(random_state=rng)

tic = time()
gsh = HalvingGridSearchCV(estimator=clf, param_grid=param_grid, factor=2,
                          random_state=rng)
gsh.fit(X, y)
gsh_time = time() - tic

tic = time()
gs = GridSearchCV(estimator=clf, param_grid=param_grid)
gs.fit(X, y)
gs_time = time() - tic

Wir zeichnen jetzt Heatmaps für beide Suchschätzer.

def make_heatmap(ax, gs, is_sh=False, make_cbar=False):
    """Helper to make a heatmap."""
    results = pd.DataFrame.from_dict(gs.cv_results_)
    results['params_str'] = results.params.apply(str)
    if is_sh:
        # SH dataframe: get mean_test_score values for the highest iter
        scores_matrix = results.sort_values('iter').pivot_table(
                index='param_gamma', columns='param_C',
                values='mean_test_score', aggfunc='last'
        )
    else:
        scores_matrix = results.pivot(index='param_gamma', columns='param_C',
                                      values='mean_test_score')

    im = ax.imshow(scores_matrix)

    ax.set_xticks(np.arange(len(Cs)))
    ax.set_xticklabels(['{:.0E}'.format(x) for x in Cs])
    ax.set_xlabel('C', fontsize=15)

    ax.set_yticks(np.arange(len(gammas)))
    ax.set_yticklabels(['{:.0E}'.format(x) for x in gammas])
    ax.set_ylabel('gamma', fontsize=15)

    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
             rotation_mode="anchor")

    if is_sh:
        iterations = results.pivot_table(index='param_gamma',
                                         columns='param_C', values='iter',
                                         aggfunc='max').values
        for i in range(len(gammas)):
            for j in range(len(Cs)):
                ax.text(j, i, iterations[i, j],
                        ha="center", va="center", color="w", fontsize=20)

    if make_cbar:
        fig.subplots_adjust(right=0.8)
        cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
        fig.colorbar(im, cax=cbar_ax)
        cbar_ax.set_ylabel('mean_test_score', rotation=-90, va="bottom",
                           fontsize=15)


fig, axes = plt.subplots(ncols=2, sharey=True)
ax1, ax2 = axes

make_heatmap(ax1, gsh, is_sh=True)
make_heatmap(ax2, gs, make_cbar=True)

ax1.set_title('Successive Halvingntime = {:.3f}s'.format(gsh_time),
              fontsize=15)
ax2.set_title('GridSearchntime = {:.3f}s'.format(gs_time), fontsize=15)

plt.show()

Aufeinanderfolgende Halbierungszeit = 2,731 s, GridSearch-Zeit = 13,964 s

Die Heatmaps zeigen das mittlere Testergebnis der Parameterkombinationen für an SVC Beispiel. Die HalvingGridSearchCV zeigt auch die Iteration, bei der die Kombinationen zuletzt verwendet wurden. Die mit gekennzeichneten Kombinationen 0 wurden nur bei der ersten Iteration ausgewertet, während diejenigen mit 5 sind die Parameterkombinationen, die als die besten angesehen werden.

Wir können sehen, dass die HalvingGridSearchCV Klasse ist in der Lage, Parameterkombinationen zu finden, die genauso genau sind wie GridSearchCV, in viel kürzerer Zeit.

Gesamtlaufzeit des Skripts: ( 0 Minuten 17.141 Sekunden)

Startordner

Download Python source code: plot_successive_halving_heatmap.py

Download Jupyter notebook: plot_successive_halving_heatmap.ipynb