Latihan Stratifikasi / Uji-split dalam scikit-learn


95

Saya perlu membagi data saya menjadi satu set pelatihan (75%) dan set pengujian (25%). Saat ini saya melakukannya dengan kode di bawah ini:

X, Xt, userInfo, userInfo_train = sklearn.cross_validation.train_test_split(X, userInfo)   

Namun, saya ingin membuat stratifikasi set data pelatihan saya. Bagaimana aku melakukan itu? Saya telah mempelajari StratifiedKFoldmetode ini, tetapi tidak membiarkan saya menentukan pembagian 75% / 25% dan hanya menyusun set data pelatihan.

Jawaban:


161

[pembaruan untuk 0,17]

Lihat dokumen dari sklearn.model_selection.train_test_split:

from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y,
                                                    stratify=y, 
                                                    test_size=0.25)

[/ perbarui untuk 0,17]

Ada permintaan tarik di sini . Tapi Anda bisa melakukan train, test = next(iter(StratifiedKFold(...))) dan menggunakan train dan test indeks jika Anda mau.


1
@AndreasMueller Adakah cara mudah untuk stratifikasi data regresi?
Yordania

3
@Jordan tidak ada yang diimplementasikan di scikit-learn. Saya tidak tahu cara standar. Kita bisa menggunakan persentil.
Andreas Mueller

@AndreasMueller Pernahkah Anda melihat perilaku di mana metode ini jauh lebih lambat daripada StratifiedShuffleSplit? Saya menggunakan kumpulan data MNIST.
Activatedgeek

@activatedgeek yang nampaknya sangat aneh, karena train_test_split (... stratify =) baru saja memanggil StratifiedShuffleSplit dan melakukan pemisahan pertama. Jangan ragu untuk membuka masalah di pelacak dengan contoh yang dapat direproduksi.
Andreas Mueller

@AndreasMueller Saya sebenarnya tidak membuka masalah karena saya memiliki perasaan yang kuat bahwa saya melakukan sesuatu yang salah (meskipun itu hanya 2 baris). Tetapi jika saya masih dapat mereproduksinya hari ini beberapa kali, saya akan melakukannya!
Activatedgeek

29

TL; DR: Gunakan StratifiedShuffleSplit dengantest_size=0.25

Scikit-learn menyediakan dua modul untuk Stratified Splitting:

  1. StratifiedKFold : Modul ini berguna sebagai operator validasi silang k-fold langsung: karena di dalamnya akan menyiapkan set n_foldspelatihan / pengujian sedemikian rupa sehingga kelas seimbang di keduanya.

Berikut beberapa kode (langsung dari dokumentasi di atas)

>>> skf = cross_validation.StratifiedKFold(y, n_folds=2) #2-fold cross validation
>>> len(skf)
2
>>> for train_index, test_index in skf:
...    print("TRAIN:", train_index, "TEST:", test_index)
...    X_train, X_test = X[train_index], X[test_index]
...    y_train, y_test = y[train_index], y[test_index]
...    #fit and predict with X_train/test. Use accuracy metrics to check validation performance
  1. StratifiedShuffleSplit : Modul ini membuat satu set pelatihan / pengujian yang memiliki kelas yang seimbang (bertingkat). Pada dasarnya, inilah yang Anda inginkan dengan file n_iter=1. Anda dapat menyebutkan ukuran tes di sini sama seperti ditrain_test_split

Kode:

>>> sss = StratifiedShuffleSplit(y, n_iter=1, test_size=0.5, random_state=0)
>>> len(sss)
1
>>> for train_index, test_index in sss:
...    print("TRAIN:", train_index, "TEST:", test_index)
...    X_train, X_test = X[train_index], X[test_index]
...    y_train, y_test = y[train_index], y[test_index]
>>> # fit and predict with your classifier using the above X/y train/test

5
Perhatikan bahwa pada 0.18.x, n_iterseharusnya n_splitsuntuk StratifiedShuffleSplit - dan ada API yang sedikit berbeda untuk itu: scikit-learn.org/stable/modules/generated/…
lollercoaster

2
Jika yadalah Seri Pandas, gunakany.iloc[train_index], y.iloc[test_index]
Owlright

1
@Owlright Saya mencoba menggunakan pandas dataframe dan indeks yang dikembalikan StratifiedShuffleSplit bukanlah indeks di dataframe. dataframe index: 2,3,5 the first split in sss:[(array([2, 1]), array([0]))]:(
Meghna Natraj

2
@ Tangy mengapa ini loop for? bukankah itu kasus ketika sebuah baris X_train, X_test = X[train_index], X[test_index]dipanggil itu menimpa X_traindan X_test? Lalu mengapa tidak hanya satu next(sss)?
Bartek Wójcik


13

Berikut adalah contoh untuk data berkelanjutan / regresi (hingga masalah ini di GitHub diselesaikan).

min = np.amin(y)
max = np.amax(y)

# 5 bins may be too few for larger datasets.
bins     = np.linspace(start=min, stop=max, num=5)
y_binned = np.digitize(y, bins, right=True)

X_train, X_test, y_train, y_test = train_test_split(
    X, 
    y, 
    stratify=y_binned
)
  • Dimana startmin dan stopmaksimal target kontinu Anda.
  • Jika Anda tidak mengaturnya right=Truemaka itu akan lebih atau kurang membuat nilai maksimal Anda menjadi bin terpisah dan pemisahan Anda akan selalu gagal karena terlalu sedikit sampel akan berada di bin tambahan itu.

6

Selain jawaban yang diterima oleh @Andreas Mueller, hanya ingin menambahkannya sebagai @tangy yang disebutkan di atas:

StratifiedShuffleSplit paling mirip dengan train_test_split ( stratify = y) dengan fitur tambahan:

  1. stratifikasi secara default
  2. dengan menentukan n_splits , ini berulang kali membagi data

0
#train_size is 1 - tst_size - vld_size
tst_size=0.15
vld_size=0.15

X_train_test, X_valid, y_train_test, y_valid = train_test_split(df.drop(y, axis=1), df.y, test_size = vld_size, random_state=13903) 

X_train_test_V=pd.DataFrame(X_train_test)
X_valid=pd.DataFrame(X_valid)

X_train, X_test, y_train, y_test = train_test_split(X_train_test, y_train_test, test_size=tst_size, random_state=13903)

0

Memperbarui jawaban @tangy dari atas ke versi scikit-learn: 0.23.2 ( dokumentasi StratifiedShuffleSplit ).

from sklearn.model_selection import StratifiedShuffleSplit

n_splits = 1  # We only want a single split in this case
sss = StratifiedShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=0)

for train_index, test_index in sss.split(X, y):
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]
Dengan menggunakan situs kami, Anda mengakui telah membaca dan memahami Kebijakan Cookie dan Kebijakan Privasi kami.
Licensed under cc by-sa 3.0 with attribution required.