Sebar plot di Pandas / Pyplot: Cara plot berdasarkan kategori


90

Saya mencoba membuat plot sebar sederhana di pyplot menggunakan objek Pandas DataFrame, tetapi ingin cara yang efisien untuk merencanakan dua variabel tetapi simbolnya ditentukan oleh kolom ketiga (kunci). Saya telah mencoba berbagai cara menggunakan df.groupby, tetapi tidak berhasil. Contoh skrip df ada di bawah. Ini mewarnai penanda menurut 'key1', tapi saya ingin melihat legenda dengan kategori 'key1'. Apakah saya dekat? Terima kasih.

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
df = pd.DataFrame(np.random.normal(10,1,30).reshape(10,3), index = pd.date_range('2010-01-01', freq = 'M', periods = 10), columns = ('one', 'two', 'three'))
df['key1'] = (4,4,4,6,6,6,8,8,8,8)
fig1 = plt.figure(1)
ax1 = fig1.add_subplot(111)
ax1.scatter(df['one'], df['two'], marker = 'o', c = df['key1'], alpha = 0.8)
plt.show()

Jawaban:


120

Anda dapat menggunakan scatteruntuk ini, tetapi itu membutuhkan nilai numerik untuk Anda key1, dan Anda tidak akan memiliki legenda, seperti yang Anda perhatikan.

Lebih baik hanya digunakan plotuntuk kategori terpisah seperti ini. Sebagai contoh:

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
np.random.seed(1974)

# Generate Data
num = 20
x, y = np.random.random((2, num))
labels = np.random.choice(['a', 'b', 'c'], num)
df = pd.DataFrame(dict(x=x, y=y, label=labels))

groups = df.groupby('label')

# Plot
fig, ax = plt.subplots()
ax.margins(0.05) # Optional, just adds 5% padding to the autoscaling
for name, group in groups:
    ax.plot(group.x, group.y, marker='o', linestyle='', ms=12, label=name)
ax.legend()

plt.show()

masukkan deskripsi gambar di sini

Jika Anda ingin hal-hal terlihat seperti pandasgaya default , perbarui saja rcParamsdengan lembar gaya pandas dan gunakan generator warnanya. (Saya juga sedikit mengubah legenda):

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
np.random.seed(1974)

# Generate Data
num = 20
x, y = np.random.random((2, num))
labels = np.random.choice(['a', 'b', 'c'], num)
df = pd.DataFrame(dict(x=x, y=y, label=labels))

groups = df.groupby('label')

# Plot
plt.rcParams.update(pd.tools.plotting.mpl_stylesheet)
colors = pd.tools.plotting._get_standard_colors(len(groups), color_type='random')

fig, ax = plt.subplots()
ax.set_color_cycle(colors)
ax.margins(0.05)
for name, group in groups:
    ax.plot(group.x, group.y, marker='o', linestyle='', ms=12, label=name)
ax.legend(numpoints=1, loc='upper left')

plt.show()

masukkan deskripsi gambar di sini


Mengapa dalam contoh RGB di atas simbol ditampilkan dua kali dalam legenda? Bagaimana cara menampilkannya hanya sekali?
Steve Schulist

1
@SteveSchulist - Gunakan ax.legend(numpoints=1)untuk menampilkan hanya satu penanda. Ada dua, seperti pada a Line2D, sering kali ada garis yang menghubungkan kedua penanda.
Joe Kington

Kode ini hanya berfungsi untuk saya setelah menambahkan plt.hold(True)setelah ax.plot()perintah. Tahu kenapa?
Yuval Atzmon

set_color_cycle() tidak lagi digunakan di matplotlib 1.5. Ada set_prop_cycle(), sekarang.
ale

52

Ini mudah dilakukan dengan Seaborn (pip install seaborn ) sebagai satu perjalanan

sns.scatterplot(x_vars="one", y_vars="two", data=df, hue="key1") :

import seaborn as sns
import pandas as pd
import numpy as np
np.random.seed(1974)

df = pd.DataFrame(
    np.random.normal(10, 1, 30).reshape(10, 3),
    index=pd.date_range('2010-01-01', freq='M', periods=10),
    columns=('one', 'two', 'three'))
df['key1'] = (4, 4, 4, 6, 6, 6, 8, 8, 8, 8)

sns.scatterplot(x="one", y="two", data=df, hue="key1")

masukkan deskripsi gambar di sini

Berikut adalah kerangka data untuk referensi:

masukkan deskripsi gambar di sini

Karena Anda memiliki tiga kolom variabel dalam data Anda, Anda mungkin ingin memplot semua dimensi berpasangan dengan:

sns.pairplot(vars=["one","two","three"], data=df, hue="key1")

masukkan deskripsi gambar di sini

https://rasbt.github.io/mlxtend/user_guide/plotting/category_scatter/ adalah opsi lain.


19

Dengan plt.scatter, saya hanya dapat memikirkan satu: menggunakan artis proxy:

df = pd.DataFrame(np.random.normal(10,1,30).reshape(10,3), index = pd.date_range('2010-01-01', freq = 'M', periods = 10), columns = ('one', 'two', 'three'))
df['key1'] = (4,4,4,6,6,6,8,8,8,8)
fig1 = plt.figure(1)
ax1 = fig1.add_subplot(111)
x=ax1.scatter(df['one'], df['two'], marker = 'o', c = df['key1'], alpha = 0.8)

ccm=x.get_cmap()
circles=[Line2D(range(1), range(1), color='w', marker='o', markersize=10, markerfacecolor=item) for item in ccm((array([4,6,8])-4.0)/4)]
leg = plt.legend(circles, ['4','6','8'], loc = "center left", bbox_to_anchor = (1, 0.5), numpoints = 1)

Dan hasilnya adalah:

masukkan deskripsi gambar di sini


10

Anda dapat menggunakan df.plot.scatter, dan meneruskan array ke c = argumen yang menentukan warna setiap titik:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
df = pd.DataFrame(np.random.normal(10,1,30).reshape(10,3), index = pd.date_range('2010-01-01', freq = 'M', periods = 10), columns = ('one', 'two', 'three'))
df['key1'] = (4,4,4,6,6,6,8,8,8,8)
colors = np.where(df["key1"]==4,'r','-')
colors[df["key1"]==6] = 'g'
colors[df["key1"]==8] = 'b'
print(colors)
df.plot.scatter(x="one",y="two",c=colors)
plt.show()

masukkan deskripsi gambar di sini


4

Anda juga dapat mencoba Altair atau ggpot yang berfokus pada visualisasi deklaratif.

import numpy as np
import pandas as pd
np.random.seed(1974)

# Generate Data
num = 20
x, y = np.random.random((2, num))
labels = np.random.choice(['a', 'b', 'c'], num)
df = pd.DataFrame(dict(x=x, y=y, label=labels))

Kode Altair

from altair import Chart
c = Chart(df)
c.mark_circle().encode(x='x', y='y', color='label')

masukkan deskripsi gambar di sini

kode ggplot

from ggplot import *
ggplot(aes(x='x', y='y', color='label'), data=df) +\
geom_point(size=50) +\
theme_bw()

masukkan deskripsi gambar di sini


4

Dari matplotlib 3.1 dan seterusnya, Anda dapat menggunakan .legend_elements(). Contoh ditampilkan dalam pembuatan legenda otomatis . Keuntungannya adalah satu panggilan pencar dapat digunakan.

Pada kasus ini:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

df = pd.DataFrame(np.random.normal(10,1,30).reshape(10,3), 
                  index = pd.date_range('2010-01-01', freq = 'M', periods = 10), 
                  columns = ('one', 'two', 'three'))
df['key1'] = (4,4,4,6,6,6,8,8,8,8)


fig, ax = plt.subplots()
sc = ax.scatter(df['one'], df['two'], marker = 'o', c = df['key1'], alpha = 0.8)
ax.legend(*sc.legend_elements())
plt.show()

masukkan deskripsi gambar di sini

Jika kunci tidak langsung diberikan sebagai angka, maka akan terlihat seperti

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

df = pd.DataFrame(np.random.normal(10,1,30).reshape(10,3), 
                  index = pd.date_range('2010-01-01', freq = 'M', periods = 10), 
                  columns = ('one', 'two', 'three'))
df['key1'] = list("AAABBBCCCC")

labels, index = np.unique(df["key1"], return_inverse=True)

fig, ax = plt.subplots()
sc = ax.scatter(df['one'], df['two'], marker = 'o', c = index, alpha = 0.8)
ax.legend(sc.legend_elements()[0], labels)
plt.show()

masukkan deskripsi gambar di sini


Saya mendapat pesan kesalahan yang mengatakan objek 'PathCollection' tidak memiliki atribut 'legends_elements'. Kode saya adalah sebagai berikut. fig, ax = plt.subplots(1, 1, figsize = (4,4)) scat = ax.scatter(rand_jitter(important_dataframe["workout_type_int"], jitter = 0.04), important_dataframe["distance"], c = color_list, marker = 'o', alpha = 0.9) print(scat.legends_elements()) #ax.legend(*scat.legend_elements())
Nandish Patel

1
@NandishPatel Periksa kalimat pertama dari jawaban ini. Juga pastikan untuk tidak bingung legends_elementsdan legend_elements.
ImportanceOfBeingErnest

Ya terima kasih. Itu salah ketik (legenda / legenda). Saya sedang mengerjakan sesuatu sejak 6 jam terakhir sehingga versi Matplotlib tidak terpikir oleh saya. Saya pikir saya menggunakan yang terbaru. Saya bingung bahwa dokumentasi mengatakan ada metode seperti itu tetapi kode memberikan kesalahan. Terima kasih lagi. Saya bisa tidur sekarang.
Nandish Patel

2

Ini agak Hacky, tapi Anda bisa menggunakan one1sebagai Float64Indexuntuk melakukan segala sesuatu dalam satu pergi:

df.set_index('one').sort_index().groupby('key1')['two'].plot(style='--o', legend=True)

masukkan deskripsi gambar di sini

Perhatikan bahwa mulai 0.20.3, pengurutan indeks diperlukan , dan legenda adalah agak miring .


1

seaborn memiliki fungsi pembungkus scatterplotyang melakukannya dengan lebih efisien.

sns.scatterplot(data = df, x = 'one', y = 'two', data =  'key1'])
Dengan menggunakan situs kami, Anda mengakui telah membaca dan memahami Kebijakan Cookie dan Kebijakan Privasi kami.
Licensed under cc by-sa 3.0 with attribution required.