Plot matriks korelasi menggunakan panda


212

Saya memiliki kumpulan data dengan sejumlah besar fitur, jadi menganalisis matriks korelasi menjadi sangat sulit. Saya ingin memplot matriks korelasi yang kita dapatkan menggunakan dataframe.corr()fungsi dari panda library. Apakah ada fungsi bawaan yang disediakan oleh panda library untuk mem-plot matriks ini?


Jawaban terkait dapat ditemukan di sini Membuat peta panas dari panda DataFrame
joelostblom

Jawaban:


292

Anda dapat menggunakan pyplot.matshow() dari matplotlib:

import matplotlib.pyplot as plt

plt.matshow(dataframe.corr())
plt.show()

Edit:

Dalam komentar adalah permintaan untuk bagaimana mengubah label centang sumbu. Berikut ini adalah versi deluxe yang digambar pada ukuran figur yang lebih besar, memiliki label sumbu yang cocok dengan kerangka data, dan legenda colorbar untuk menafsirkan skala warna.

Saya termasuk cara menyesuaikan ukuran dan rotasi label, dan saya menggunakan rasio angka yang membuat colorbar dan gambar utama keluar sama tingginya.

f = plt.figure(figsize=(19, 15))
plt.matshow(df.corr(), fignum=f.number)
plt.xticks(range(df.shape[1]), df.columns, fontsize=14, rotation=45)
plt.yticks(range(df.shape[1]), df.columns, fontsize=14)
cb = plt.colorbar()
cb.ax.tick_params(labelsize=14)
plt.title('Correlation Matrix', fontsize=16);

contoh plot korelasi


1
Saya pasti melewatkan sesuatu:AttributeError: 'module' object has no attribute 'matshow'
Tom Russell

1
@ TomRussell Apakah Anda melakukannya import matplotlib.pyplot as plt?
joelostblom

1
Saya ingin berpikir saya melakukannya! :-)
Tom Russell

7
apakah Anda tahu cara menampilkan nama kolom yang sebenarnya di plot?
WebQube

2
@ Cecilia Saya telah menyelesaikan masalah ini dengan mengubah parameter rotasi menjadi 90
ikbel benabdessamad

182

Jika tujuan utama Anda adalah memvisualisasikan matriks korelasi, alih-alih membuat plot sendiri, pandas opsi penataan yang nyaman adalah solusi bawaan yang layak:

import pandas as pd
import numpy as np

rs = np.random.RandomState(0)
df = pd.DataFrame(rs.rand(10, 10))
corr = df.corr()
corr.style.background_gradient(cmap='coolwarm')
# 'RdBu_r' & 'BrBG' are other good diverging colormaps

masukkan deskripsi gambar di sini

Perhatikan bahwa ini harus di backend yang mendukung rendering HTML, seperti Notebook JupyterLab. (Teks cahaya otomatis pada latar belakang gelap berasal dari PR yang ada dan bukan versi terbaru yang dirilis, pandas0,23).


Styling

Anda dapat dengan mudah membatasi ketepatan digit:

corr.style.background_gradient(cmap='coolwarm').set_precision(2)

masukkan deskripsi gambar di sini

Atau singkirkan digitnya sepenuhnya jika Anda lebih suka matriks tanpa anotasi:

corr.style.background_gradient(cmap='coolwarm').set_properties(**{'font-size': '0pt'})

masukkan deskripsi gambar di sini

Dokumentasi penataan gaya juga mencakup petunjuk gaya yang lebih maju, seperti cara mengubah tampilan sel yang dituju penunjuk tetikus. Untuk menyimpan hasil, Anda dapat mengembalikan HTML dengan menambahkanrender() metode dan kemudian menulisnya ke file (atau hanya mengambil tangkapan layar untuk keperluan yang kurang formal).


Perbandingan waktu

Dalam pengujian saya, style.background_gradient()4x lebih cepat dari plt.matshow()dan 120x lebih cepat daripada sns.heatmap()dengan matriks 10x10. Sayangnya itu tidak skala juga plt.matshow(): keduanya membutuhkan waktu yang sama untuk matriks 100x100, dan plt.matshow()10x lebih cepat untuk matriks 1000x1000.


Penghematan

Ada beberapa cara yang mungkin untuk menyimpan kerangka data bergaya:

  • Kembalikan HTML dengan menambahkan render()metode dan kemudian tulis output ke file.
  • Simpan sebagai .xslx file dengan pemformatan bersyarat dengan menambahkan to_excel()metode.
  • Gabungkan dengan imgkit untuk menyimpan bitmap
  • Ambil tangkapan layar (untuk keperluan yang kurang formal).

Perbarui untuk panda> = 0,24

Dengan mengatur axis=None, sekarang dimungkinkan untuk menghitung warna berdasarkan seluruh matriks daripada per kolom atau per baris:

corr.style.background_gradient(cmap='coolwarm', axis=None)

masukkan deskripsi gambar di sini


2
Jika ada cara untuk mengekspor adalah sebagai gambar, itu akan sangat bagus!
Kristada673

1
Terima kasih! Anda pasti membutuhkan palet divergenimport seaborn as sns corr = df.corr() cm = sns.light_palette("green", as_cmap=True) cm = sns.diverging_palette(220, 20, sep=20, as_cmap=True) corr.style.background_gradient(cmap=cm).set_precision(2)
stallingOne

1
@ installing One Good point, saya seharusnya tidak memasukkan nilai negatif dalam contoh, saya mungkin mengubahnya nanti. Hanya untuk referensi bagi orang yang membaca ini, Anda tidak perlu membuat cmap divergen khusus dengan seaborn (walaupun yang di komentar di atas terlihat cukup apik), Anda juga dapat menggunakan cmaps divergen bawaan dari matplotlib, mis corr.style.background_gradient(cmap='coolwarm'). Saat ini tidak ada cara untuk memusatkan cmap pada nilai tertentu, yang dapat menjadi ide bagus dengan cmaps yang berbeda.
joelostblom

1
@rovyko Apakah Anda menggunakan panda> = 0.24.0?
joelostblom

2
Plot ini sangat bagus secara visual, tetapi pertanyaan @ Kristada673 cukup relevan, bagaimana Anda akan mengekspornya?
Erfan

89

Coba fungsi ini, yang juga menampilkan nama variabel untuk matriks korelasi:

def plot_corr(df,size=10):
    '''Function plots a graphical correlation matrix for each pair of columns in the dataframe.

    Input:
        df: pandas DataFrame
        size: vertical and horizontal size of the plot'''

    corr = df.corr()
    fig, ax = plt.subplots(figsize=(size, size))
    ax.matshow(corr)
    plt.xticks(range(len(corr.columns)), corr.columns);
    plt.yticks(range(len(corr.columns)), corr.columns);

6
plt.xticks(range(len(corr.columns)), corr.columns, rotation='vertical')jika Anda ingin orientasi vertikal nama kolom pada sumbu x
nishant

Hal grafis lain, tetapi menambahkan plt.tight_layout()mungkin juga berguna untuk nama kolom yang panjang.
user3017048

86

Versi peta panas Seaborn:

import seaborn as sns
corr = dataframe.corr()
sns.heatmap(corr, 
            xticklabels=corr.columns.values,
            yticklabels=corr.columns.values)

9
Seaborn heatmap sangat bagus tetapi kinerjanya buruk pada matriks besar. Metode matshow dari matplotlib jauh lebih cepat.
anilbey

3
Seaborn dapat secara otomatis menyimpulkan ticklabels dari nama kolom.
Tulio Casagrande

80

Anda dapat mengamati hubungan antara fitur baik dengan menggambar peta panas dari seaborn atau sebar matriks dari panda.

Matriks Sebar:

pd.scatter_matrix(dataframe, alpha = 0.3, figsize = (14,8), diagonal = 'kde');

Jika Anda juga ingin memvisualisasikan kemiringan masing-masing fitur - gunakan pasangan seaborn.

sns.pairplot(dataframe)

Sns Heatmap:

import seaborn as sns

f, ax = pl.subplots(figsize=(10, 8))
corr = dataframe.corr()
sns.heatmap(corr, mask=np.zeros_like(corr, dtype=np.bool), cmap=sns.diverging_palette(220, 10, as_cmap=True),
            square=True, ax=ax)

Outputnya akan berupa peta korelasi fitur. yaitu lihat contoh di bawah ini.

masukkan deskripsi gambar di sini

Korelasi antara bahan makanan dan deterjen tinggi. Demikian pula:

Pdoducts Dengan Korelasi Tinggi:
  1. Toko Kelontong dan Deterjen.
Produk Dengan Korelasi Menengah:
  1. Susu dan Bahan Makanan
  2. Milk and Detergents_Paper
Produk Dengan Korelasi Rendah:
  1. Susu dan Deli
  2. Beku dan segar.
  3. Beku dan Deli.

Dari Pairplots: Anda dapat mengamati serangkaian relasi yang sama dari pairplots atau scatter matrix. Tetapi dari sini kita dapat mengatakan apakah data terdistribusi secara normal atau tidak.

masukkan deskripsi gambar di sini

Catatan: Di atas adalah grafik yang sama yang diambil dari data, yang digunakan untuk menggambar peta panas.


3
Saya pikir itu harus .plt bukan .pl (jika ini merujuk ke matplotlib)
ghukill

2
@ ghukill Tidak perlu. Dia bisa menyebutnyafrom matplotlib import pyplot as pl
Jeru Luke

cara mengatur batas korelasi antara -1 hingga +1 selalu, dalam plot korelasi
debaonline4u

7

Anda dapat menggunakan metode imshow () dari matplotlib

import pandas as pd
import matplotlib.pyplot as plt
plt.style.use('ggplot')

plt.imshow(X.corr(), cmap=plt.cm.Reds, interpolation='nearest')
plt.colorbar()
tick_marks = [i for i in range(len(X.columns))]
plt.xticks(tick_marks, X.columns, rotation='vertical')
plt.yticks(tick_marks, X.columns)
plt.show()

5

Jika bingkai data dfAnda, Anda cukup menggunakan:

import matplotlib.pyplot as plt
import seaborn as sns

plt.figure(figsize=(15, 10))
sns.heatmap(df.corr(), annot=True)

3

grafis statmodels juga memberikan tampilan yang bagus dari matriks korelasi

import statsmodels.api as sm
import matplotlib.pyplot as plt

corr = dataframe.corr()
sm.graphics.plot_corr(corr, xnames=list(corr.columns))
plt.show()

3

Untuk kelengkapan, solusi paling sederhana yang saya tahu dengan seaborn pada akhir 2019, jika seseorang menggunakan Jupyter :

import seaborn as sns
sns.heatmap(dataframe.corr())

1

Bersamaan dengan metode lain juga baik untuk memiliki pairplot yang akan memberikan plot pencar untuk semua kasus-

import pandas as pd
import numpy as np
import seaborn as sns
rs = np.random.RandomState(0)
df = pd.DataFrame(rs.rand(10, 10))
sns.pairplot(df)

0

Bentuk matriks korelasi, dalam kasus saya zdf adalah kerangka data yang saya perlukan melakukan matriks korelasi.

corrMatrix =zdf.corr()
corrMatrix.to_csv('sm_zscaled_correlation_matrix.csv');
html = corrMatrix.style.background_gradient(cmap='RdBu').set_precision(2).render()

# Writing the output to a html file.
with open('test.html', 'w') as f:
   print('<!DOCTYPE html><html lang="en"><head><meta charset="UTF-8"><meta name="viewport" content="width=device-widthinitial-scale=1.0"><title>Document</title></head><style>table{word-break: break-all;}</style><body>' + html+'</body></html>', file=f)

Lalu kita bisa mengambil screenshot. atau konversi html ke file gambar.

Dengan menggunakan situs kami, Anda mengakui telah membaca dan memahami Kebijakan Cookie dan Kebijakan Privasi kami.
Licensed under cc by-sa 3.0 with attribution required.