Parameter "stratify" dari metode "train_test_split" (scikit Learn)


94

Saya mencoba menggunakan train_test_splitdari paket scikit Learn, tetapi saya mengalami masalah dengan parameter stratify. Selanjutnya kodenya:

from sklearn import cross_validation, datasets 

X = iris.data[:,:2]
y = iris.target

cross_validation.train_test_split(X,y,stratify=y)

Namun, saya terus mendapatkan masalah berikut:

raise TypeError("Invalid parameters passed: %s" % str(options))
TypeError: Invalid parameters passed: {'stratify': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])}

Apakah seseorang tahu apa yang sedang terjadi? Di bawah ini adalah dokumentasi fungsinya.

[...]

stratify : array-like atau None (default-nya adalah None)

Jika bukan None, data akan dibagi secara bertingkat, menggunakan ini sebagai larik label.

Baru di versi 0.17: stratify splitting

[...]


Tidak, semua terpecahkan.
Daneel Olivaw

Jawaban:


58

Scikit-Learn hanya memberi tahu Anda bahwa ia tidak mengenali argumen "bertingkat", bukan karena Anda menggunakannya secara tidak benar. Ini karena parameter ditambahkan pada versi 0.17 seperti yang ditunjukkan dalam dokumentasi yang Anda kutip.

Jadi, Anda hanya perlu memperbarui Scikit-Learn.


Saya mendapatkan kesalahan yang sama, meskipun saya memiliki versi 0.21.2 dari scikit-learn. scikit-learn 0.21.2 py37h2a6a0b8_0 conda-forge
Kareem Jeiroudi

326

stratifyParameter ini melakukan pemisahan sehingga proporsi nilai dalam sampel yang dihasilkan akan sama dengan proporsi nilai yang diberikan pada parameter stratify.

Misalnya, jika variabel yadalah variabel kategorikal biner dengan nilai 0dan 1terdapat 25% dari nol dan 75% dari satu, stratify=yakan memastikan bahwa pemisahan acak Anda memiliki 25% dari 0dan 75% dari 1.


117
Ini tidak benar-benar menjawab pertanyaan tetapi sangat berguna untuk hanya memahami cara kerjanya. Terima kasih banyak.
Reed Jessen

6
Saya masih kesulitan untuk memahami, mengapa stratifikasi ini diperlukan: Jika ada kelas yang tidak seimbang dalam data, bukankah itu akan dipertahankan rata-rata saat melakukan pemisahan data secara acak?
Holger Brandl

14
@HolgerBrandl itu akan dipertahankan rata-rata; dengan stratifikasi, itu pasti akan dipertahankan.
Yonatan

7
@HolgerBrandl dengan kumpulan data yang sangat kecil atau sangat tidak seimbang, sangat mungkin pemisahan acak dapat sepenuhnya menghilangkan kelas dari salah satu pemisahan.
cddt

1
@HolgerBrandl Pertanyaan bagus! Mungkin kita bisa menambahkan itu dulu, Anda harus membagi menjadi set pelatihan dan pengujian menggunakan stratify. Kemudian kedua, untuk memperbaiki ketidakseimbangan, Anda akhirnya perlu menjalankan oversampling atau undersampling di set pelatihan. Banyak pengklasifikasi Sklearn memiliki parameter yang disebut bobot kelas yang dapat Anda atur menjadi seimbang. Terakhir, Anda juga dapat menggunakan metrik yang lebih sesuai daripada akurasi untuk set data yang tidak seimbang. Coba, F1 atau area di bawah ROC.
Claude DAPAT

62

Untuk masa depan saya yang datang ke sini melalui Google:

train_test_splitsekarang masuk model_selection, maka:

from sklearn.model_selection import train_test_split

# given:
# features: xs
# ground truth: ys

x_train, x_test, y_train, y_test = train_test_split(xs, ys,
                                                    test_size=0.33,
                                                    random_state=0,
                                                    stratify=ys)

adalah cara menggunakannya. Pengaturan random_statediinginkan untuk reproduktifitas.


Ini harus menjadi jawabannya :) Terima kasih
SwimBikeRun

15

Dalam konteks ini, stratifikasi berarti bahwa metode train_test_split mengembalikan subset pelatihan dan pengujian yang memiliki proporsi label kelas yang sama dengan kumpulan data masukan.


3

Coba jalankan kode ini, "berfungsi":

from sklearn import cross_validation, datasets 

iris = datasets.load_iris()

X = iris.data[:,:2]
y = iris.target

x_train, x_test, y_train, y_test = cross_validation.train_test_split(X,y,train_size=.8, stratify=y)

y_test

array([0, 0, 0, 0, 2, 2, 1, 0, 1, 2, 2, 0, 0, 1, 0, 1, 1, 2, 1, 2, 0, 2, 2,
       1, 2, 1, 1, 0, 2, 1])

@ user5767535 Seperti yang Anda lihat, ini berfungsi pada mesin Ubuntu saya, dengan sklearnversi '0.17', distribusi Anaconda untuk Python 3,5. Saya hanya dapat menyarankan untuk memeriksa sekali lagi jika Anda memasukkan kode dengan benar dan memperbarui perangkat lunak Anda.
Sergey Bushmanov

2
@ user5767535 BTW, "Baru dalam versi 0.17: pemisahan bertingkat" membuat saya hampir yakin bahwa Anda harus memperbarui sklearn...
Sergey Bushmanov
Dengan menggunakan situs kami, Anda mengakui telah membaca dan memahami Kebijakan Cookie dan Kebijakan Privasi kami.
Licensed under cc by-sa 3.0 with attribution required.