Prediksi Probabilitas Hutan Acak vs suara terbanyak


10

Scikit belajar tampaknya menggunakan prediksi probabilistik alih-alih suara mayoritas untuk teknik agregasi model tanpa penjelasan mengapa (1.9.2.1. Hutan Acak).

Apakah ada penjelasan yang jelas mengapa? Lebih lanjut apakah ada makalah yang bagus atau ulasan artikel untuk berbagai teknik agregasi model yang dapat digunakan untuk mengantongi Hutan Acak?

Terima kasih!

Jawaban:


10

Pertanyaan seperti itu selalu dijawab dengan melihat kode, jika Anda fasih menggunakan Python.

RandomForestClassifier.predict, setidaknya dalam versi saat ini 0.16.1, memprediksi kelas dengan estimasi probabilitas tertinggi, seperti yang diberikan oleh predict_proba. ( baris ini )

Dokumentasi untuk predict_probamengatakan:

Probabilitas kelas yang diprediksi dari sampel input dihitung sebagai probabilitas probabilitas kelas yang diprediksi dari pohon-pohon di hutan. Probabilitas kelas satu pohon adalah sebagian kecil sampel dari kelas yang sama dalam daun.

Perbedaan dari metode asli mungkin hanya supaya predictmemberikan prediksi konsisten predict_proba. Hasilnya kadang-kadang disebut "voting lunak", daripada suara mayoritas "keras" yang digunakan dalam koran Breiman asli. Saya tidak dapat dalam pencarian cepat menemukan perbandingan yang tepat dari kinerja kedua metode, tetapi mereka berdua tampaknya cukup masuk akal dalam situasi ini.

The predictdokumentasi di terbaik cukup menyesatkan; Saya telah mengirimkan permintaan penarikan untuk memperbaikinya.

Jika Anda ingin melakukan prediksi suara mayoritas, inilah fungsi untuk melakukannya. Sebut seperti predict_majvote(clf, X)suka daripada clf.predict(X). (Berdasarkan predict_proba; hanya diuji ringan, tapi saya pikir itu harus berhasil.)

from scipy.stats import mode
from sklearn.ensemble.forest import _partition_estimators, _parallel_helper
from sklearn.tree._tree import DTYPE
from sklearn.externals.joblib import Parallel, delayed
from sklearn.utils import check_array
from sklearn.utils.validation import check_is_fitted

def predict_majvote(forest, X):
    """Predict class for X.

    Uses majority voting, rather than the soft voting scheme
    used by RandomForestClassifier.predict.

    Parameters
    ----------
    X : array-like or sparse matrix of shape = [n_samples, n_features]
        The input samples. Internally, it will be converted to
        ``dtype=np.float32`` and if a sparse matrix is provided
        to a sparse ``csr_matrix``.
    Returns
    -------
    y : array of shape = [n_samples] or [n_samples, n_outputs]
        The predicted classes.
    """
    check_is_fitted(forest, 'n_outputs_')

    # Check data
    X = check_array(X, dtype=DTYPE, accept_sparse="csr")

    # Assign chunk of trees to jobs
    n_jobs, n_trees, starts = _partition_estimators(forest.n_estimators,
                                                    forest.n_jobs)

    # Parallel loop
    all_preds = Parallel(n_jobs=n_jobs, verbose=forest.verbose,
                         backend="threading")(
        delayed(_parallel_helper)(e, 'predict', X, check_input=False)
        for e in forest.estimators_)

    # Reduce
    modes, counts = mode(all_preds, axis=0)

    if forest.n_outputs_ == 1:
        return forest.classes_.take(modes[0], axis=0)
    else:
        n_samples = all_preds[0].shape[0]
        preds = np.zeros((n_samples, forest.n_outputs_),
                         dtype=forest.classes_.dtype)
        for k in range(forest.n_outputs_):
            preds[:, k] = forest.classes_[k].take(modes[:, k], axis=0)
        return preds

Pada kasus sintetis bodoh yang saya coba, prediksi setuju dengan predictmetode ini setiap waktu.


Jawaban yang bagus, Dougal! Terima kasih telah meluangkan waktu untuk menjelaskan ini dengan cermat. Harap pertimbangkan juga untuk menumpuk overflow dan menjawab pertanyaan ini di sana .
user1745038

1
Ada juga makalah, di sini , yang membahas prediksi probabilistik.
user1745038
Dengan menggunakan situs kami, Anda mengakui telah membaca dan memahami Kebijakan Cookie dan Kebijakan Privasi kami.
Licensed under cc by-sa 3.0 with attribution required.