Bagaimana cara kerja parameter class_weight di scikit-learn?


116

Saya mengalami banyak masalah dalam memahami bagaimana class_weightparameter dalam Regresi Logistik scikit-learn beroperasi.

Situasi

Saya ingin menggunakan regresi logistik untuk melakukan klasifikasi biner pada kumpulan data yang sangat tidak seimbang. Kelas diberi label 0 (negatif) dan 1 (positif) dan data yang diamati memiliki rasio sekitar 19: 1 dengan mayoritas sampel memiliki hasil negatif.

Upaya Pertama: Menyiapkan Data Pelatihan Secara Manual

Saya membagi data yang saya miliki menjadi beberapa set terpisah untuk pelatihan dan pengujian (sekitar 80/20). Kemudian saya secara acak mengambil sampel data pelatihan dengan tangan untuk mendapatkan data pelatihan dalam proporsi yang berbeda dari 19: 1; dari 2: 1 -> 16: 1.

Saya kemudian melatih regresi logistik pada subset data pelatihan yang berbeda ini dan memplot recall (= TP / (TP + FN)) sebagai fungsi dari proporsi pelatihan yang berbeda. Tentu saja, penarikan dihitung pada sampel TEST terputus-putus yang memiliki proporsi teramati 19: 1. Perhatikan, meskipun saya melatih model yang berbeda pada data pelatihan yang berbeda, saya menghitung penarikan kembali untuk semuanya pada data pengujian (terputus-putus) yang sama.

Hasilnya seperti yang diharapkan: penarikan kembali sekitar 60% pada proporsi pelatihan 2: 1 dan jatuh agak cepat pada saat mencapai 16: 1. Ada beberapa proporsi 2: 1 -> 6: 1 di mana penarikan cukup di atas 5%.

Percobaan Kedua: Pencarian Grid

Selanjutnya, saya ingin menguji parameter regularisasi yang berbeda, jadi saya menggunakan GridSearchCV dan membuat kisi dari beberapa nilai Cparameter serta class_weightparameter. Untuk menerjemahkan proporsi n: m saya dari sampel pelatihan negatif: positif ke dalam kamus bahasa class_weightsaya pikir saya hanya menentukan beberapa kamus sebagai berikut:

{ 0:0.67, 1:0.33 } #expected 2:1
{ 0:0.75, 1:0.25 } #expected 3:1
{ 0:0.8, 1:0.2 }   #expected 4:1

dan saya juga termasuk Nonedan auto.

Kali ini hasilnya benar-benar aneh. Semua penarikan saya keluar kecil (<0,05) untuk setiap nilai class_weightkecuali auto. Jadi saya hanya bisa berasumsi bahwa pemahaman saya tentang cara mengatur class_weightkamus itu salah. Menariknya, class_weightnilai 'auto' dalam pencarian grid sekitar 59% untuk semua nilai C, dan saya kira itu seimbang dengan 1: 1?

Pertanyaan saya

  1. Bagaimana Anda menggunakan dengan benar class_weightuntuk mencapai keseimbangan yang berbeda dalam data pelatihan dari apa yang sebenarnya Anda berikan? Secara khusus, kamus apa yang saya berikan class_weightuntuk menggunakan proporsi n: m sampel pelatihan negatif: positif?

  2. Jika Anda meneruskan berbagai class_weightkamus ke GridSearchCV, selama validasi silang akankah itu menyeimbangkan kembali data lipatan pelatihan menurut kamus tetapi menggunakan proporsi sampel yang diberikan secara benar untuk menghitung fungsi penilaian saya pada lipatan tes? Ini penting karena metrik apa pun hanya berguna bagi saya jika berasal dari data dalam proporsi yang diamati.

  3. Apa autonilai class_weightdo sejauh proporsi? Saya membaca dokumentasinya dan saya berasumsi "menyeimbangkan data berbanding terbalik dengan frekuensinya" berarti menjadikannya 1: 1. Apakah ini benar? Jika tidak, dapatkah seseorang menjelaskan?


Saat seseorang menggunakan class_weight, fungsi kerugian akan dimodifikasi. Misalnya, alih-alih entropi silang, ia menjadi entropi silang berbobot. menujudatascience.com/…
prashanth

Jawaban:


123

Pertama, mungkin tidak baik untuk hanya mengingat saja. Anda dapat dengan mudah mencapai perolehan kembali 100% dengan mengklasifikasikan semuanya sebagai kelas positif. Saya biasanya menyarankan menggunakan AUC untuk memilih parameter, dan kemudian menemukan ambang untuk titik operasi (katakanlah tingkat presisi tertentu) yang Anda minati.

Untuk cara class_weightkerjanya: Ini menghukum kesalahan dalam sampel class[i]dengan class_weight[i]daripada 1. Jadi bobot kelas yang lebih tinggi berarti Anda ingin lebih menekankan pada kelas. Dari apa yang Anda katakan tampaknya kelas 0 19 kali lebih sering daripada kelas 1. Jadi Anda harus meningkatkan class_weightkelas 1 relatif terhadap kelas 0, katakan {0: .1, 1: .9}. Jika class_weighttidak berjumlah 1, itu pada dasarnya akan mengubah parameter regularisasi.

Untuk cara class_weight="auto"kerjanya, Anda bisa melihat pembahasan ini . Dalam versi dev, Anda dapat menggunakan class_weight="balanced", yang lebih mudah dipahami: ini pada dasarnya berarti mereplikasi kelas yang lebih kecil hingga Anda memiliki sampel sebanyak di kelas yang lebih besar, tetapi secara implisit.


1
Terima kasih! Pertanyaan singkat: Saya menyebutkan penarikan untuk kejelasan dan sebenarnya saya mencoba memutuskan AUC mana yang akan digunakan sebagai ukuran saya. Pemahaman saya adalah bahwa saya harus memaksimalkan area di bawah kurva KOP atau area di bawah kurva recall vs. presisi untuk menemukan parameter. Setelah memilih parameter dengan cara ini, saya yakin saya memilih ambang batas untuk klasifikasi dengan menggeser sepanjang kurva. Apakah ini yang kamu maksud? Jika demikian, mana dari dua kurva yang paling masuk akal untuk dilihat jika tujuan saya adalah untuk menangkap TP sebanyak mungkin? Juga, terima kasih atas pekerjaan dan kontribusi Anda untuk scikit-learn !!!
kilgoretrout

1
Saya pikir menggunakan ROC akan menjadi cara yang lebih standar untuk digunakan, tetapi menurut saya tidak akan ada perbedaan besar. Anda memang membutuhkan beberapa kriteria untuk memilih titik pada kurva.
Andreas Mueller

3
@MiNdFrEaK Menurut saya yang dimaksud Andrew adalah penaksir mereplikasi sampel di kelas minoritas, sehingga sampel dari kelas yang berbeda seimbang. Ini hanya oversampling secara implisit.
Shawn TIAN

8
@MiNdFrEaK dan Shawn Tian: Pengklasifikasi berbasis SV tidak menghasilkan lebih banyak sampel dari kelas yang lebih kecil saat Anda menggunakan 'balanced'. Ini benar-benar menghukum kesalahan yang dibuat di kelas yang lebih kecil. Mengatakan sebaliknya adalah kesalahan dan menyesatkan, terutama dalam kumpulan data besar ketika Anda tidak mampu membuat lebih banyak sampel. Jawaban ini harus diedit.
Pablo Rivas

4
scikit-learn.org/dev/glossary.html#term-class-weight Bobot kelas akan digunakan secara berbeda bergantung pada algoritme: untuk model linier (seperti SVM linier atau regresi logistik), bobot kelas akan mengubah fungsi kerugian sebesar pembobotan kerugian setiap sampel dengan bobot kelasnya. Untuk algoritme berbasis pohon, bobot kelas akan digunakan untuk menilai ulang kriteria pemisahan. Namun perlu dicatat bahwa rebalancing ini tidak memperhitungkan bobot sampel di setiap kelas.
prashanth

2

Jawaban pertama bagus untuk memahami cara kerjanya. Tetapi saya ingin memahami bagaimana saya harus menggunakannya dalam praktik.

RINGKASAN

  • untuk data yang cukup tidak seimbang TANPA derau, tidak banyak perbedaan dalam penerapan bobot kelas
  • untuk data cukup tidak seimbang DENGAN noise dan sangat tidak seimbang, lebih baik menerapkan bobot kelas
  • param class_weight="balanced"berfungsi dengan baik jika Anda tidak ingin mengoptimalkan secara manual
  • dengan class_weight="balanced"Anda menangkap lebih banyak peristiwa sebenarnya (ingatan TRUE lebih tinggi) tetapi juga Anda lebih cenderung mendapatkan peringatan palsu (presisi TRUE lebih rendah)
    • akibatnya, total% TRUE mungkin lebih tinggi dari yang sebenarnya karena semua positif palsu
    • AUC mungkin menyesatkan Anda di sini jika alarm palsu menjadi masalah
  • tidak perlu mengubah ambang batas keputusan menjadi% ketidakseimbangan, bahkan untuk ketidakseimbangan yang kuat, ok untuk mempertahankan 0,5 (atau sekitar itu tergantung pada apa yang Anda butuhkan)

NB

Hasilnya mungkin berbeda saat menggunakan RF atau GBM. sklearn tidak memiliki class_weight="balanced" GBM tetapi lightgbm memilikiLGBMClassifier(is_unbalance=False)

KODE

# scikit-learn==0.21.3
from sklearn import datasets
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score, classification_report
import numpy as np
import pandas as pd

# case: moderate imbalance
X, y = datasets.make_classification(n_samples=50*15, n_features=5, n_informative=2, n_redundant=0, random_state=1, weights=[0.8]) #,flip_y=0.1,class_sep=0.5)
np.mean(y) # 0.2

LogisticRegression(C=1e9).fit(X,y).predict(X).mean() # 0.184
(LogisticRegression(C=1e9).fit(X,y).predict_proba(X)[:,1]>0.5).mean() # 0.184 => same as first
LogisticRegression(C=1e9,class_weight={0:0.5,1:0.5}).fit(X,y).predict(X).mean() # 0.184 => same as first
LogisticRegression(C=1e9,class_weight={0:2,1:8}).fit(X,y).predict(X).mean() # 0.296 => seems to make things worse?
LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X).mean() # 0.292 => seems to make things worse?

roc_auc_score(y,LogisticRegression(C=1e9).fit(X,y).predict(X)) # 0.83
roc_auc_score(y,LogisticRegression(C=1e9,class_weight={0:2,1:8}).fit(X,y).predict(X)) # 0.86 => about the same
roc_auc_score(y,LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X)) # 0.86 => about the same

# case: strong imbalance
X, y = datasets.make_classification(n_samples=50*15, n_features=5, n_informative=2, n_redundant=0, random_state=1, weights=[0.95])
np.mean(y) # 0.06

LogisticRegression(C=1e9).fit(X,y).predict(X).mean() # 0.02
(LogisticRegression(C=1e9).fit(X,y).predict_proba(X)[:,1]>0.5).mean() # 0.02 => same as first
LogisticRegression(C=1e9,class_weight={0:0.5,1:0.5}).fit(X,y).predict(X).mean() # 0.02 => same as first
LogisticRegression(C=1e9,class_weight={0:1,1:20}).fit(X,y).predict(X).mean() # 0.25 => huh??
LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X).mean() # 0.22 => huh??
(LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict_proba(X)[:,1]>0.5).mean() # same as last

roc_auc_score(y,LogisticRegression(C=1e9).fit(X,y).predict(X)) # 0.64
roc_auc_score(y,LogisticRegression(C=1e9,class_weight={0:1,1:20}).fit(X,y).predict(X)) # 0.84 => much better
roc_auc_score(y,LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X)) # 0.85 => similar to manual
roc_auc_score(y,(LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict_proba(X)[:,1]>0.5).astype(int)) # same as last

print(classification_report(y,LogisticRegression(C=1e9).fit(X,y).predict(X)))
pd.crosstab(y,LogisticRegression(C=1e9).fit(X,y).predict(X),margins=True)
pd.crosstab(y,LogisticRegression(C=1e9).fit(X,y).predict(X),margins=True,normalize='index') # few prediced TRUE with only 28% TRUE recall and 86% TRUE precision so 6%*28%~=2%

print(classification_report(y,LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X)))
pd.crosstab(y,LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X),margins=True)
pd.crosstab(y,LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X),margins=True,normalize='index') # 88% TRUE recall but also lot of false positives with only 23% TRUE precision, making total predicted % TRUE > actual % TRUE
Dengan menggunakan situs kami, Anda mengakui telah membaca dan memahami Kebijakan Cookie dan Kebijakan Privasi kami.
Licensed under cc by-sa 3.0 with attribution required.