Hasilkan peta panas di MatPlotLib menggunakan set data sebar


187

Saya memiliki satu set data X, Y poin (sekitar 10k) yang mudah untuk plot sebagai sebar plot tetapi saya ingin mewakili sebagai peta panas.

Saya melihat contoh-contoh di MatPlotLib dan semuanya tampaknya sudah mulai dengan nilai sel heatmap untuk menghasilkan gambar.

Apakah ada metode yang mengubah sekelompok x, y, semuanya berbeda, menjadi peta panas (di mana zona dengan frekuensi x yang lebih tinggi, y akan menjadi "lebih hangat")?


Jawaban:


182

Jika Anda tidak ingin hexagon, Anda dapat menggunakan histogram2dfungsi numpy :

import numpy as np
import numpy.random
import matplotlib.pyplot as plt

# Generate some test data
x = np.random.randn(8873)
y = np.random.randn(8873)

heatmap, xedges, yedges = np.histogram2d(x, y, bins=50)
extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]

plt.clf()
plt.imshow(heatmap.T, extent=extent, origin='lower')
plt.show()

Ini menghasilkan peta panas 50x50. Jika Anda ingin, katakanlah, 512x384, Anda dapat melakukan bins=(512, 384)panggilan ke histogram2d.

Contoh: Contoh peta matplotlib panas


1
Saya tidak bermaksud menjadi idiot, tetapi bagaimana Anda benar-benar memiliki output ini ke file PNG / PDF bukannya hanya ditampilkan dalam sesi IPython interaktif? Saya mencoba untuk mendapatkan ini sebagai semacam axescontoh normal , di mana saya dapat menambahkan judul, label sumbu, dll dan kemudian melakukan yang normal savefig()seperti yang akan saya lakukan untuk plot matplotlib khas lainnya.
dapatkan keturunan

3
@ gotgenes: tidak plt.savefig('filename.png')berfungsi? Jika Anda ingin mendapatkan contoh sumbu, gunakan antarmuka berorientasi objek fig = plt.figure() ax = fig.gca() ax.imshow(...) fig.savefig(...)
Matplotlib

1
Memang terima kasih! Saya kira saya tidak sepenuhnya mengerti bahwa imshow()ada pada kategori fungsi yang sama dengan scatter(). Jujur saya tidak mengerti mengapa imshow()mengubah array 2d mengapung menjadi blok warna yang sesuai, sedangkan saya mengerti apa scatter()yang harus dilakukan dengan array seperti itu.
dapatkan

14
Peringatan tentang penggunaan imshow untuk memplot histogram 2d dari nilai x / y seperti ini: secara default, imshow memplot asal dari sudut kiri atas dan mentransposisi gambar. Apa yang akan saya lakukan untuk mendapatkan orientasi yang sama dengan sebaran plot adalahplt.imshow(heatmap.T, extent=extent, origin = 'lower')
Jamie

7
Bagi mereka yang ingin melakukan colorbar logaritmik lihat pertanyaan ini stackoverflow.com/questions/17201172/… dan cukup lakukanfrom matplotlib.colors import LogNorm plt.imshow(heatmap, norm=LogNorm()) plt.colorbar()
tommy.carstensen

109

Dalam leksikon Matplotlib , saya pikir Anda ingin plot hexbin .

Jika Anda tidak terbiasa dengan jenis plot ini, itu hanya histogram bivariat di mana xy-plane di-tellellated oleh kisi-kisi segi enam biasa.

Jadi dari histogram, Anda bisa menghitung jumlah titik yang jatuh di setiap segi enam, diskritkan wilayah plot sebagai satu set jendela , tetapkan setiap titik ke salah satu jendela ini; akhirnya, petakan windows ke array warna , dan Anda punya diagram hexbin.

Meskipun lebih jarang digunakan daripada misalnya, lingkaran, atau kuadrat, hexagon adalah pilihan yang lebih baik untuk geometri wadah binning intuitif:

  • segi enam memiliki simetri tetangga terdekat (mis., kotak persegi tidak, misalnya, jarak dari titik di perbatasan kotak ke titik di dalam kotak itu tidak sama di mana-mana) dan

  • hexagon adalah n-poligon tertinggi yang memberikan tessellation bidang reguler (yaitu, Anda dapat dengan aman memodel ulang lantai dapur Anda dengan ubin berbentuk heksagonal karena Anda tidak akan memiliki ruang kosong antara ubin ketika Anda selesai - tidak berlaku untuk semua lainnya yang lebih tinggi-n, n> = 7, poligon).

( Matplotlib menggunakan istilah heksbin plot; demikian juga (AFAIK) semua perpustakaan petak untuk R ; masih saya tidak tahu apakah ini istilah yang diterima secara umum untuk plot jenis ini, meskipun saya curiga ada kemungkinan bahwa hexbin pendek. untuk heksagonal binning , yang menggambarkan langkah penting dalam menyiapkan data untuk ditampilkan.)


from matplotlib import pyplot as PLT
from matplotlib import cm as CM
from matplotlib import mlab as ML
import numpy as NP

n = 1e5
x = y = NP.linspace(-5, 5, 100)
X, Y = NP.meshgrid(x, y)
Z1 = ML.bivariate_normal(X, Y, 2, 2, 0, 0)
Z2 = ML.bivariate_normal(X, Y, 4, 1, 1, 1)
ZD = Z2 - Z1
x = X.ravel()
y = Y.ravel()
z = ZD.ravel()
gridsize=30
PLT.subplot(111)

# if 'bins=None', then color of each hexagon corresponds directly to its count
# 'C' is optional--it maps values to x-y coordinates; if 'C' is None (default) then 
# the result is a pure 2D histogram 

PLT.hexbin(x, y, C=z, gridsize=gridsize, cmap=CM.jet, bins=None)
PLT.axis([x.min(), x.max(), y.min(), y.max()])

cb = PLT.colorbar()
cb.set_label('mean value')
PLT.show()   

masukkan deskripsi gambar di sini


Apa artinya "hexagon memiliki simetri tetangga terdekat"? Anda mengatakan bahwa "jarak dari titik di perbatasan kotak dan titik di dalam kotak itu tidak sama di mana-mana" tetapi jarak ke apa?
Jaan

9
Untuk segi enam, jarak dari pusat ke titik yang menghubungkan dua sisi juga lebih panjang daripada dari pusat ke tengah sisi, hanya rasionya lebih kecil (2 / sqrt (3) ≈ 1,15 untuk hexagon vs sqrt (2) ≈ 1,41 untuk persegi). Satu-satunya bentuk di mana jarak dari pusat ke setiap titik di perbatasan sama adalah lingkaran.
Jaan

5
@ Jaan Untuk segi enam, setiap tetangga pada jarak yang sama. Tidak ada masalah dengan 8-lingkungan atau 4-lingkungan. Tidak ada tetangga diagonal, hanya satu jenis tetangga.
isarandi

@doug Bagaimana Anda memilih gridsize=parameter. Saya ingin memilihnya sedemikian rupa, sehingga segi enam hanya menyentuh tanpa tumpang tindih. Saya perhatikan bahwa gridsize=100akan menghasilkan hexagon yang lebih kecil, tetapi bagaimana cara memilih nilai yang tepat?
Alexander Cska

40

Sunting: Untuk perkiraan jawaban Alejandro yang lebih baik, lihat di bawah.

Saya tahu ini adalah pertanyaan lama, tetapi ingin menambahkan sesuatu ke server Alejandro: Jika Anda ingin gambar yang dihaluskan tanpa menggunakan py-sphviewer, Anda dapat menggunakan np.histogram2ddan menerapkan filter gaussian (dari scipy.ndimage.filters) ke peta panas:

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from scipy.ndimage.filters import gaussian_filter


def myplot(x, y, s, bins=1000):
    heatmap, xedges, yedges = np.histogram2d(x, y, bins=bins)
    heatmap = gaussian_filter(heatmap, sigma=s)

    extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]
    return heatmap.T, extent


fig, axs = plt.subplots(2, 2)

# Generate some test data
x = np.random.randn(1000)
y = np.random.randn(1000)

sigmas = [0, 16, 32, 64]

for ax, s in zip(axs.flatten(), sigmas):
    if s == 0:
        ax.plot(x, y, 'k.', markersize=5)
        ax.set_title("Scatter plot")
    else:
        img, extent = myplot(x, y, s)
        ax.imshow(img, extent=extent, origin='lower', cmap=cm.jet)
        ax.set_title("Smoothing with  $\sigma$ = %d" % s)

plt.show()

Menghasilkan:

Keluarkan gambar

Plot sebar dan s = 16 diplot di atas satu sama lain untuk Agape Gal'lo (klik untuk tampilan yang lebih baik):

Di atas satu sama lain


Satu perbedaan yang saya perhatikan dengan pendekatan filter gaussian saya dan pendekatan Alejandro adalah bahwa metodenya menunjukkan struktur lokal jauh lebih baik daripada milik saya. Oleh karena itu saya menerapkan metode tetangga terdekat yang sederhana di tingkat piksel. Metode ini menghitung untuk setiap piksel jumlah terbalik dari jarak ntitik terdekat dalam data. Metode ini pada resolusi tinggi komputasi yang cukup mahal dan saya pikir ada cara yang lebih cepat, jadi beri tahu saya jika Anda memiliki perbaikan.

Pembaruan: Seperti yang saya duga, ada metode yang jauh lebih cepat menggunakan Scipy's scipy.cKDTree. Lihat jawaban Gabriel untuk implementasinya.

Bagaimanapun, ini kode saya:

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm


def data_coord2view_coord(p, vlen, pmin, pmax):
    dp = pmax - pmin
    dv = (p - pmin) / dp * vlen
    return dv


def nearest_neighbours(xs, ys, reso, n_neighbours):
    im = np.zeros([reso, reso])
    extent = [np.min(xs), np.max(xs), np.min(ys), np.max(ys)]

    xv = data_coord2view_coord(xs, reso, extent[0], extent[1])
    yv = data_coord2view_coord(ys, reso, extent[2], extent[3])
    for x in range(reso):
        for y in range(reso):
            xp = (xv - x)
            yp = (yv - y)

            d = np.sqrt(xp**2 + yp**2)

            im[y][x] = 1 / np.sum(d[np.argpartition(d.ravel(), n_neighbours)[:n_neighbours]])

    return im, extent


n = 1000
xs = np.random.randn(n)
ys = np.random.randn(n)
resolution = 250

fig, axes = plt.subplots(2, 2)

for ax, neighbours in zip(axes.flatten(), [0, 16, 32, 64]):
    if neighbours == 0:
        ax.plot(xs, ys, 'k.', markersize=2)
        ax.set_aspect('equal')
        ax.set_title("Scatter Plot")
    else:
        im, extent = nearest_neighbours(xs, ys, resolution, neighbours)
        ax.imshow(im, origin='lower', extent=extent, cmap=cm.jet)
        ax.set_title("Smoothing over %d neighbours" % neighbours)
        ax.set_xlim(extent[0], extent[1])
        ax.set_ylim(extent[2], extent[3])
plt.show()

Hasil:

Smoothing Neighbor Terdekat


1
Suka ini. Grafik sebaik jawaban Alejandro, tetapi tidak ada paket baru yang diperlukan.
Nathan Clement

Sangat bagus ! Tetapi Anda menghasilkan offset dengan metode ini. Anda dapat melihat ini dengan membandingkan grafik pencar normal dengan yang berwarna. Bisakah Anda menambahkan sesuatu untuk memperbaikinya? Atau hanya untuk memindahkan grafik dengan nilai x dan y?
Agape Gal'lo

1
Agape Gal'lo, apa maksudmu dengan offset? Jika Anda memplotnya di atas satu sama lain, mereka cocok (lihat edit posting saya). Mungkin Anda menunda karena lebar hamburan tidak sama persis dengan tiga lainnya.
Jurgy

Terima kasih banyak untuk merencanakan grafik hanya untuk saya! Saya mengerti kesalahan saya: Saya telah memodifikasi "sejauh" untuk mendefinisikan batas x dan y. Saya sekarang memahaminya memodifikasi asal grafik. Lalu, saya punya pertanyaan terakhir: bagaimana saya bisa memperluas batas grafik, bahkan untuk area di mana tidak ada data yang ada? Misalnya, antara -5 hingga +5 untuk x dan y.
Agape Gal'lo

1
Katakan Anda ingin sumbu x bergerak dari -5 ke 5 dan sumbu y dari -3 ke 4; dalam myplotfungsi, menambahkan rangeparameter ke np.histogram2d: np.histogram2d(x, y, bins=bins, range=[[-5, 5], [-3, 4]])dan dalam untuk loop mengatur x dan y lim sumbu: ax.set_xlim([-5, 5]) ax.set_ylim([-3, 4]). Selain itu, secara default, imshowmenjaga rasio aspek identik dengan rasio sumbu Anda (jadi dalam contoh saya rasio 10: 7), tetapi jika Anda ingin agar sesuai dengan jendela plot Anda, tambahkan parameter aspect='auto'ke imshow.
Jurgy

31

Alih-alih menggunakan np.hist2d, yang secara umum menghasilkan histogram yang sangat jelek, saya ingin mendaur ulang py-sphviewer , paket python untuk rendering simulasi partikel menggunakan kernel smoothing adaptive dan yang dapat dengan mudah dipasang dari pip (lihat dokumentasi halaman web). Pertimbangkan kode berikut, yang didasarkan pada contoh:

import numpy as np
import numpy.random
import matplotlib.pyplot as plt
import sphviewer as sph

def myplot(x, y, nb=32, xsize=500, ysize=500):   
    xmin = np.min(x)
    xmax = np.max(x)
    ymin = np.min(y)
    ymax = np.max(y)

    x0 = (xmin+xmax)/2.
    y0 = (ymin+ymax)/2.

    pos = np.zeros([3, len(x)])
    pos[0,:] = x
    pos[1,:] = y
    w = np.ones(len(x))

    P = sph.Particles(pos, w, nb=nb)
    S = sph.Scene(P)
    S.update_camera(r='infinity', x=x0, y=y0, z=0, 
                    xsize=xsize, ysize=ysize)
    R = sph.Render(S)
    R.set_logscale()
    img = R.get_image()
    extent = R.get_extent()
    for i, j in zip(xrange(4), [x0,x0,y0,y0]):
        extent[i] += j
    print extent
    return img, extent

fig = plt.figure(1, figsize=(10,10))
ax1 = fig.add_subplot(221)
ax2 = fig.add_subplot(222)
ax3 = fig.add_subplot(223)
ax4 = fig.add_subplot(224)


# Generate some test data
x = np.random.randn(1000)
y = np.random.randn(1000)

#Plotting a regular scatter plot
ax1.plot(x,y,'k.', markersize=5)
ax1.set_xlim(-3,3)
ax1.set_ylim(-3,3)

heatmap_16, extent_16 = myplot(x,y, nb=16)
heatmap_32, extent_32 = myplot(x,y, nb=32)
heatmap_64, extent_64 = myplot(x,y, nb=64)

ax2.imshow(heatmap_16, extent=extent_16, origin='lower', aspect='auto')
ax2.set_title("Smoothing over 16 neighbors")

ax3.imshow(heatmap_32, extent=extent_32, origin='lower', aspect='auto')
ax3.set_title("Smoothing over 32 neighbors")

#Make the heatmap using a smoothing over 64 neighbors
ax4.imshow(heatmap_64, extent=extent_64, origin='lower', aspect='auto')
ax4.set_title("Smoothing over 64 neighbors")

plt.show()

yang menghasilkan gambar berikut:

masukkan deskripsi gambar di sini

Seperti yang Anda lihat, gambar terlihat cukup bagus, dan kami dapat mengidentifikasi berbagai substruktur di atasnya. Gambar-gambar ini dibangun menyebarkan bobot yang diberikan untuk setiap titik dalam domain tertentu, ditentukan oleh panjang smoothing, yang pada gilirannya diberikan oleh jarak ke tetangga nb lebih dekat (saya telah memilih 16, 32 dan 64 untuk contoh). Jadi, daerah dengan kepadatan lebih tinggi biasanya tersebar di daerah yang lebih kecil dibandingkan dengan daerah dengan kepadatan lebih rendah.

Fungsi myplot hanyalah fungsi yang sangat sederhana yang saya tulis untuk memberikan data x, y ke py-sphviewer untuk melakukan keajaiban.


2
Sebuah komentar untuk siapa saja yang mencoba menginstal py-sphviewer di OSX: Saya mengalami cukup banyak kesulitan, lihat: github.com/alejandrobll/py-sphviewer/issues/3
Sam Finnigan

Sayang sekali itu tidak bekerja dengan python3. Itu menginstal, tetapi kemudian crash ketika Anda mencoba menggunakannya ...
Fábio Dias

1
@Fabio Dias, Versi terbaru (1.1.x) sekarang bekerja dengan Python 3.
Alejandro

29

Jika Anda menggunakan 1.2.x

import numpy as np
import matplotlib.pyplot as plt

x = np.random.randn(100000)
y = np.random.randn(100000)
plt.hist2d(x,y,bins=100)
plt.show()

gaussian_2d_heat_map


17

Seaborn sekarang memiliki fungsi jointplot yang seharusnya bekerja dengan baik di sini:

import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

# Generate some test data
x = np.random.randn(8873)
y = np.random.randn(8873)

sns.jointplot(x=x, y=y, kind='hex')
plt.show()

gambar demo


Sederhana, cantik, dan bermanfaat secara analitis.
ryanjdillon

@wordsbagaimana cara Anda membuat data 600k dibaca secara visual menggunakan ini? (cara mengubah ukuran)
nrmb

Saya tidak begitu yakin apa yang Anda maksud; mungkin lebih baik Anda mengajukan pertanyaan terpisah dan menautkannya di sini. Maksudmu mengubah ukuran seluruh ara? Pertama buat angka dengan fig = plt.figure(figsize=(12, 12)), kemudian dapatkan sumbu saat ini ax=plt.gca(), kemudian tambahkan argumen ax=axke jointplotfungsi.
kata

@words Selanjutnya, tolong jawab pertanyaan ini: stackoverflow.com/questions/50997662/… terima kasih
ebrahimi

4

dan pertanyaan awal adalah ... bagaimana cara mengubah nilai sebar ke nilai kotak, kan? histogram2dtidak menghitung frekuensi per sel, namun, jika Anda memiliki data lain per sel dari hanya frekuensi, Anda akan memerlukan beberapa pekerjaan tambahan untuk dilakukan.

x = data_x # between -10 and 4, log-gamma of an svc
y = data_y # between -4 and 11, log-C of an svc
z = data_z #between 0 and 0.78, f1-values from a difficult dataset

Jadi, saya punya dataset dengan hasil-Z untuk koordinat X dan Y. Namun, saya menghitung beberapa poin di luar bidang minat (kesenjangan besar), dan banyak poin di bidang minat kecil.

Ya di sini menjadi lebih sulit tetapi juga lebih menyenangkan. Beberapa perpustakaan (maaf):

from matplotlib import pyplot as plt
from matplotlib import cm
import numpy as np
from scipy.interpolate import griddata

pyplot adalah mesin grafis saya hari ini, cm adalah berbagai peta warna dengan beberapa pilihan initeresting. numpy untuk perhitungan, dan data grid untuk melampirkan nilai ke jaringan tetap.

Yang terakhir ini penting terutama karena frekuensi titik xy tidak terdistribusi secara merata dalam data saya. Pertama, mari kita mulai dengan beberapa batasan yang cocok dengan data saya dan ukuran kisi yang berubah-ubah. Data asli memiliki titik data juga di luar batas x dan y.

#determine grid boundaries
gridsize = 500
x_min = -8
x_max = 2.5
y_min = -2
y_max = 7

Jadi kami telah mendefinisikan kisi dengan 500 piksel antara nilai min dan maks x dan y.

Dalam data saya, ada lebih dari 500 nilai yang tersedia di bidang minat tinggi; sedangkan di bidang berbunga rendah, tidak ada bahkan 200 nilai dalam total grid; antara batas-batas grafis x_mindan x_maxbahkan ada lebih sedikit.

Jadi untuk mendapatkan gambar yang bagus, tugasnya adalah untuk mendapatkan rata-rata nilai bunga tinggi dan mengisi celah di tempat lain.

Saya mendefinisikan grid saya sekarang. Untuk setiap pasangan xx-yy, saya ingin memiliki warna.

xx = np.linspace(x_min, x_max, gridsize) # array of x values
yy = np.linspace(y_min, y_max, gridsize) # array of y values
grid = np.array(np.meshgrid(xx, yy.T))
grid = grid.reshape(2, grid.shape[1]*grid.shape[2]).T

Kenapa bentuknya aneh? scipy.griddata menginginkan bentuk (n, D).

Griddata menghitung satu nilai per titik di grid, dengan metode yang telah ditentukan. Saya memilih "terdekat" - titik grid kosong akan diisi dengan nilai dari tetangga terdekat. Ini terlihat seolah-olah area dengan informasi yang lebih sedikit memiliki sel yang lebih besar (bahkan jika bukan itu masalahnya). Seseorang dapat memilih untuk menginterpolasi "linear", maka area dengan informasi yang lebih sedikit terlihat kurang tajam. Soal rasa, kok.

points = np.array([x, y]).T # because griddata wants it that way
z_grid2 = griddata(points, z, grid, method='nearest')
# you get a 1D vector as result. Reshape to picture format!
z_grid2 = z_grid2.reshape(xx.shape[0], yy.shape[0])

Dan hop, kami serahkan ke matplotlib untuk menampilkan plot

fig = plt.figure(1, figsize=(10, 10))
ax1 = fig.add_subplot(111)
ax1.imshow(z_grid2, extent=[x_min, x_max,y_min, y_max,  ],
            origin='lower', cmap=cm.magma)
ax1.set_title("SVC: empty spots filled by nearest neighbours")
ax1.set_xlabel('log gamma')
ax1.set_ylabel('log C')
plt.show()

Di sekitar bagian runcing dari V-Shape, Anda tahu saya melakukan banyak perhitungan selama pencarian saya untuk sweet spot, sedangkan bagian yang kurang menarik hampir di tempat lain memiliki resolusi lebih rendah.

Heatmap dari SVC dalam resolusi tinggi


Bisakah Anda meningkatkan jawaban Anda untuk memiliki kode yang lengkap dan bisa dijalankan? Ini adalah metode menarik yang Anda berikan. Saya mencoba untuk lebih memahaminya saat ini. Saya tidak begitu mengerti mengapa ada bentuk V juga. Terima kasih.
ldmtwo

V-Shape berasal dari data saya. Ini adalah nilai f1 untuk SVM yang terlatih: Ini akan sedikit dalam teori SVM. Jika Anda memiliki C tinggi, itu termasuk semua poin Anda dalam perhitungan, memungkinkan rentang gamma yang lebih luas untuk bekerja. Gamma adalah kekakuan dari kurva yang memisahkan baik dan buruk. Kedua nilai tersebut harus diberikan kepada SVM (X dan Y dalam grafik saya); maka Anda mendapatkan hasilnya (Z dalam grafik saya). Di area terbaik, semoga Anda mencapai ketinggian yang berarti.
Anderas

percobaan kedua: V-Shape ada di data saya. Ini adalah nilai f1 untuk SVM: Jika Anda memiliki C tinggi, itu mencakup semua poin Anda dalam perhitungan, memungkinkan rentang gamma yang lebih luas untuk bekerja, tetapi membuat perhitungan menjadi lambat. Gamma adalah kekakuan dari kurva yang memisahkan baik dan buruk. Kedua nilai tersebut harus diberikan kepada SVM (X dan Y dalam grafik saya); maka Anda mendapatkan hasilnya (Z dalam grafik saya). Di area yang dioptimalkan Anda mendapatkan nilai tinggi, di tempat lain nilai rendah. Apa yang saya tunjukkan di sini dapat digunakan jika Anda memiliki nilai Z untuk beberapa (X, Y) dan banyak celah di tempat lain. Jika Anda memiliki (X, Y, Z) titik data, Anda dapat menggunakan kode saya.
Anderas

4

Inilah pendekatan tetangga terdekat Jurgy yang hebat tetapi diimplementasikan menggunakan scipy.cKDTree . Dalam pengujian saya sekitar 100x lebih cepat.

masukkan deskripsi gambar di sini

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from scipy.spatial import cKDTree


def data_coord2view_coord(p, resolution, pmin, pmax):
    dp = pmax - pmin
    dv = (p - pmin) / dp * resolution
    return dv


n = 1000
xs = np.random.randn(n)
ys = np.random.randn(n)

resolution = 250

extent = [np.min(xs), np.max(xs), np.min(ys), np.max(ys)]
xv = data_coord2view_coord(xs, resolution, extent[0], extent[1])
yv = data_coord2view_coord(ys, resolution, extent[2], extent[3])


def kNN2DDens(xv, yv, resolution, neighbours, dim=2):
    """
    """
    # Create the tree
    tree = cKDTree(np.array([xv, yv]).T)
    # Find the closest nnmax-1 neighbors (first entry is the point itself)
    grid = np.mgrid[0:resolution, 0:resolution].T.reshape(resolution**2, dim)
    dists = tree.query(grid, neighbours)
    # Inverse of the sum of distances to each grid point.
    inv_sum_dists = 1. / dists[0].sum(1)

    # Reshape
    im = inv_sum_dists.reshape(resolution, resolution)
    return im


fig, axes = plt.subplots(2, 2, figsize=(15, 15))
for ax, neighbours in zip(axes.flatten(), [0, 16, 32, 63]):

    if neighbours == 0:
        ax.plot(xs, ys, 'k.', markersize=5)
        ax.set_aspect('equal')
        ax.set_title("Scatter Plot")
    else:

        im = kNN2DDens(xv, yv, resolution, neighbours)

        ax.imshow(im, origin='lower', extent=extent, cmap=cm.Blues)
        ax.set_title("Smoothing over %d neighbours" % neighbours)
        ax.set_xlim(extent[0], extent[1])
        ax.set_ylim(extent[2], extent[3])

plt.savefig('new.png', dpi=150, bbox_inches='tight')

1
Saya tahu implementasi saya sangat tidak efisien tetapi tidak tahu tentang cKDTree. Sudah selesai dilakukan dengan baik! Saya akan merujuk Anda dalam jawaban saya.
Jurg

2

Buat array 2 dimensi yang sesuai dengan sel-sel di gambar akhir Anda, disebut say heatmap_cellsdan instantiate sebagai semua nol.

Pilih dua faktor penskalaan yang menentukan perbedaan antara setiap elemen array dalam unit nyata, untuk setiap dimensi, katakan x_scaledan y_scale. Pilih ini sedemikian rupa sehingga semua titik data Anda akan berada dalam batas array peta panas.

Untuk setiap titik data mentah dengan x_valuedan y_value:

heatmap_cells[floor(x_value/x_scale),floor(y_value/y_scale)]+=1


1

masukkan deskripsi gambar di sini

Inilah yang saya buat pada set 1 Juta titik dengan 3 kategori (berwarna Merah, Hijau, dan Biru). Berikut ini tautan ke repositori jika Anda ingin mencoba fungsinya. Github Repo

histplot(
    X,
    Y,
    labels,
    bins=2000,
    range=((-3,3),(-3,3)),
    normalize_each_label=True,
    colors = [
        [1,0,0],
        [0,1,0],
        [0,0,1]],
    gain=50)

0

Sangat mirip dengan jawaban @ Piti , tetapi menggunakan 1 panggilan alih-alih 2 untuk menghasilkan poin:

import numpy as np
import matplotlib.pyplot as plt

pts = 1000000
mean = [0.0, 0.0]
cov = [[1.0,0.0],[0.0,1.0]]

x,y = np.random.multivariate_normal(mean, cov, pts).T
plt.hist2d(x, y, bins=50, cmap=plt.cm.jet)
plt.show()

Keluaran:

2d_gaussian_heatmap


0

Saya khawatir saya sedikit terlambat ke pesta tetapi saya memiliki pertanyaan serupa beberapa waktu lalu. Jawaban yang diterima (oleh @ptomato) membantu saya, tetapi saya juga ingin memposting ini jika itu berguna untuk seseorang.


''' I wanted to create a heatmap resembling a football pitch which would show the different actions performed '''

import numpy as np
import matplotlib.pyplot as plt
import random

#fixing random state for reproducibility
np.random.seed(1234324)

fig = plt.figure(12)
ax1 = fig.add_subplot(121)
ax2 = fig.add_subplot(122)

#Ratio of the pitch with respect to UEFA standards 
hmap= np.full((6, 10), 0)
#print(hmap)

xlist = np.random.uniform(low=0.0, high=100.0, size=(20))
ylist = np.random.uniform(low=0.0, high =100.0, size =(20))

#UEFA Pitch Standards are 105m x 68m
xlist = (xlist/100)*10.5
ylist = (ylist/100)*6.5

ax1.scatter(xlist,ylist)

#int of the co-ordinates to populate the array
xlist_int = xlist.astype (int)
ylist_int = ylist.astype (int)

#print(xlist_int, ylist_int)

for i, j in zip(xlist_int, ylist_int):
    #this populates the array according to the x,y co-ordinate values it encounters 
    hmap[j][i]= hmap[j][i] + 1   

#Reversing the rows is necessary 
hmap = hmap[::-1]

#print(hmap)
im = ax2.imshow(hmap)

Inilah hasilnya masukkan deskripsi gambar di sini

Dengan menggunakan situs kami, Anda mengakui telah membaca dan memahami Kebijakan Cookie dan Kebijakan Privasi kami.
Licensed under cc by-sa 3.0 with attribution required.