Rata-rata ROC untuk validasi silang 10 kali lipat dengan estimasi probabilitas


15

Saya berencana untuk menggunakan validasi silang 10 kali lipat bertingkat yang diulang (sekitar 10 kali) pada sekitar 10.000 kasus menggunakan algoritma pembelajaran mesin. Setiap kali repetisi akan dilakukan dengan seed acak berbeda.

Dalam proses ini saya membuat 10 contoh estimasi probabilitas untuk setiap kasus. 1 instance estimasi probabilitas untuk masing-masing dari 10 pengulangan validasi silang 10 kali lipat

Bisakah saya rata-rata 10 probabilitas untuk setiap kasus dan kemudian membuat kurva ROC rata-rata baru (mewakili hasil CV 10 kali lipat berulang), yang dapat dibandingkan dengan kurva ROC lainnya dengan perbandingan berpasangan?

Jawaban:


13

Dari deskripsi Anda tampaknya masuk akal: tidak hanya Anda dapat menghitung kurva ROC rata-rata, tetapi juga varians di sekitarnya untuk membangun interval kepercayaan. Seharusnya memberi Anda gagasan tentang seberapa stabil model Anda.

Misalnya, seperti ini:

masukkan deskripsi gambar di sini

Di sini saya menempatkan kurva ROC individu serta kurva rata-rata dan interval kepercayaan. Ada area di mana kurva setuju, jadi kami memiliki varian kurang, dan ada area di mana mereka tidak setuju.

Untuk CV berulang, Anda bisa mengulanginya beberapa kali dan mendapatkan rata-rata total di semua lipatan individu:

masukkan deskripsi gambar di sini

Ini sangat mirip dengan gambar sebelumnya, tetapi memberikan perkiraan yang lebih stabil (yaitu dapat diandalkan) dari mean dan varians.

Berikut kode untuk mendapatkan plot:

import matplotlib.pyplot as plt
import numpy as np
from scipy import interp

from sklearn.datasets import make_classification
from sklearn.cross_validation import KFold
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_curve

X, y = make_classification(n_samples=500, random_state=100, flip_y=0.3)

kf = KFold(n=len(y), n_folds=10)

tprs = []
base_fpr = np.linspace(0, 1, 101)

plt.figure(figsize=(5, 5))

for i, (train, test) in enumerate(kf):
    model = LogisticRegression().fit(X[train], y[train])
    y_score = model.predict_proba(X[test])
    fpr, tpr, _ = roc_curve(y[test], y_score[:, 1])

    plt.plot(fpr, tpr, 'b', alpha=0.15)
    tpr = interp(base_fpr, fpr, tpr)
    tpr[0] = 0.0
    tprs.append(tpr)

tprs = np.array(tprs)
mean_tprs = tprs.mean(axis=0)
std = tprs.std(axis=0)

tprs_upper = np.minimum(mean_tprs + std, 1)
tprs_lower = mean_tprs - std


plt.plot(base_fpr, mean_tprs, 'b')
plt.fill_between(base_fpr, tprs_lower, tprs_upper, color='grey', alpha=0.3)

plt.plot([0, 1], [0, 1],'r--')
plt.xlim([-0.01, 1.01])
plt.ylim([-0.01, 1.01])
plt.ylabel('True Positive Rate')
plt.xlabel('False Positive Rate')
plt.axes().set_aspect('equal', 'datalim')
plt.show()

Untuk CV berulang:

idx = np.arange(0, len(y))

for j in np.random.randint(0, high=10000, size=10):
    np.random.shuffle(idx)
    kf = KFold(n=len(y), n_folds=10, random_state=j)

    for i, (train, test) in enumerate(kf):
        model = LogisticRegression().fit(X[idx][train], y[idx][train])
        y_score = model.predict_proba(X[idx][test])
        fpr, tpr, _ = roc_curve(y[idx][test], y_score[:, 1])

        plt.plot(fpr, tpr, 'b', alpha=0.05)
        tpr = interp(base_fpr, fpr, tpr)
        tpr[0] = 0.0
        tprs.append(tpr)

Sumber inspirasi: http://scikit-learn.org/stable/auto_examples/model_selection/plot_roc_crossval.html


3

Itu tidak benar untuk probabilitas rata-rata karena itu tidak akan mewakili prediksi yang Anda coba validasi dan melibatkan kontaminasi antar sampel validasi.

Perhatikan bahwa 100 pengulangan validasi silang 10 kali lipat mungkin diperlukan untuk mencapai presisi yang memadai. Atau gunakan bootstrap optimisme Efron-Gong yang membutuhkan lebih sedikit iterasi untuk presisi yang sama (lihat misalnya fungsi rmspaket R validate).

c


Bisakah Anda menjelaskan lebih lanjut mengapa rata-rata tidak benar?
DataD'oh

Sudah dinyatakan. Anda perlu memvalidasi ukuran yang akan Anda gunakan di lapangan.
Frank Harrell
Dengan menggunakan situs kami, Anda mengakui telah membaca dan memahami Kebijakan Cookie dan Kebijakan Privasi kami.
Licensed under cc by-sa 3.0 with attribution required.