Bingkai Data Pandas ke DMatrix


14

Saya mencoba menjalankan xgboost di scikit belajar. Dan saya hanya menggunakan Panda untuk memuat data ke dalam dataframe. Bagaimana saya bisa menggunakan panda df dengan xgboost. Saya bingung dengan rutin DMatrix yang diperlukan untuk menjalankan xgboost algo.

Jawaban:


21

Anda dapat menggunakan metode dataframe .valuesuntuk mengakses data mentah setelah Anda memanipulasi kolom sesuai kebutuhan.

Misalnya

train = pd.read_csv("train.csv")
target = train['target']
train = train.drop(['ID','target'],axis=1)
test = pd.read_csv("test.csv")
test = test.drop(['ID'],axis=1)

xgtrain = xgb.DMatrix(train.values, target.values)
xgtest = xgb.DMatrix(test.values)

Jelas Anda mungkin perlu mengubah kolom mana yang Anda jatuhkan atau gunakan sebagai target pelatihan. Di atas adalah untuk kompetisi Kaggle, jadi tidak ada data target untuk xgtest(itu ditahan oleh penyelenggara).


Ketika mencoba dengan cara ini xgb.DMatrix(X_train.values, y_train.values)saya melihatTypeError: can not initialize DMatrix from dict
javadba

@javadba: Ini pasti bekerja pada tahun 2016 di mcahine saya! Saya tidak dapat menguji ini saat ini karena saya tidak dapat menginstal xgboost. Mungkin beberapa kode perpustakaan telah berubah. Kemungkinan besar ada sesuatu yang berbeda dengan situasi Anda. Saya menemukan stackoverflow.com/questions/35402461/... tetapi itu hanya menyarankan Anda untuk melakukan persis apa yang dilakukan jawaban ini (yaitu penggunaan .values)
Neil Slater


7

Anda sekarang dapat menggunakan Pandas DataFrames langsung dengan XGBoost. Pasti berfungsi dengan xgboost 0.81.

Misalnya di mana X_train, X_val, y_train, dan y_val adalah DataFrames:

import xgboost as xgb

mod = xgb.XGBRegressor(
    gamma=1,                 
    learning_rate=0.01,
    max_depth=3,
    n_estimators=10000,                                                                    
    subsample=0.8,
    random_state=34
) 

mod.fit(X_train, y_train)
predictions = mod.predict(X_val)
rmse = sqrt(mean_squared_error(y_val, predictions))
print("score: {0:,.0f}".format(rmse))
Dengan menggunakan situs kami, Anda mengakui telah membaca dan memahami Kebijakan Cookie dan Kebijakan Privasi kami.
Licensed under cc by-sa 3.0 with attribution required.