Perbedaan antara numpy dot () dan perkalian matriks Python 3.5+ @


119

Saya baru saja pindah ke Python 3.5 dan melihat operator perkalian matriks baru (@) terkadang berperilaku berbeda dari operator numpy dot . Misalnya, untuk array 3d:

import numpy as np

a = np.random.rand(8,13,13)
b = np.random.rand(8,13,13)
c = a @ b  # Python 3.5+
d = np.dot(a, b)

The @Operator mengembalikan array bentuk:

c.shape
(8, 13, 13)

sementara np.dot()fungsinya kembali:

d.shape
(8, 13, 8, 13)

Bagaimana cara mereproduksi hasil yang sama dengan numpy dot? Apakah ada perbedaan signifikan lainnya?


5
Anda tidak bisa mendapatkan hasil itu dari titik. Saya pikir orang pada umumnya setuju bahwa penanganan dot pada input berdimensi tinggi adalah keputusan desain yang salah.
user2357112 mendukung Monica

Mengapa mereka tidak mengimplementasikan matmulfungsi tersebut bertahun-tahun yang lalu? @sebagai operator infix baru, tetapi fungsinya bekerja dengan baik tanpanya.
hpaulj

Jawaban:


140

The @Operator menyebut array __matmul__metode, tidak dot. Metode ini juga ada di API sebagai fungsinya np.matmul.

>>> a = np.random.rand(8,13,13)
>>> b = np.random.rand(8,13,13)
>>> np.matmul(a, b).shape
(8, 13, 13)

Dari dokumentasi:

matmulberbeda dotdalam dua hal penting.

  • Perkalian dengan skalar tidak diperbolehkan.
  • Tumpukan matriks disiarkan bersama seolah-olah matriks tersebut adalah elemen.

Poin terakhir memperjelas bahwa dotdan matmulmetode berperilaku berbeda ketika melewati array 3D (atau dimensi yang lebih tinggi). Mengutip dari dokumentasi lagi:

Untuk matmul:

Jika salah satu argumennya adalah ND, N> 2, argumen tersebut akan diperlakukan sebagai tumpukan matriks yang berada di dua indeks terakhir dan disiarkan sesuai dengan itu.

Untuk np.dot:

Untuk larik 2-D ekuivalen dengan perkalian matriks, dan untuk larik 1-D menjadi hasil kali dalam vektor (tanpa konjugasi kompleks). Untuk dimensi N, ini adalah hasil penjumlahan dari sumbu terakhir a dan detik ke akhir dari b


13
Kebingungan di sini mungkin karena catatan rilis, yang secara langsung menyamakan simbol "@" dengan fungsi titik () dari numpy dalam kode contoh.
Alex K

13

Jawaban oleh @ajcr menjelaskan bagaimana dotdan matmul(dipanggil oleh @simbol) berbeda. Dengan melihat contoh sederhana, kita dapat melihat dengan jelas bagaimana keduanya berperilaku berbeda saat beroperasi pada 'tumpukan matriks' atau tensor.

Untuk memperjelas perbedaan, ambil larik 4x4 dan kembalikan dotproduk dan matmulproduk dengan 'tumpukan matriks' atau tensor 3x4x2.

import numpy as np
fourbyfour = np.array([
                       [1,2,3,4],
                       [3,2,1,4],
                       [5,4,6,7],
                       [11,12,13,14]
                      ])


threebyfourbytwo = np.array([
                             [[2,3],[11,9],[32,21],[28,17]],
                             [[2,3],[1,9],[3,21],[28,7]],
                             [[2,3],[1,9],[3,21],[28,7]],
                            ])

print('4x4*3x4x2 dot:\n {}\n'.format(np.dot(fourbyfour,twobyfourbythree)))
print('4x4*3x4x2 matmul:\n {}\n'.format(np.matmul(fourbyfour,twobyfourbythree)))

Produk dari setiap operasi muncul di bawah. Perhatikan bagaimana perkalian titiknya,

... hasil penjumlahan dari sumbu terakhir a dan detik ke akhir dari b

dan bagaimana produk matriks dibentuk dengan menyiarkan matriks bersama-sama.

4x4*3x4x2 dot:
 [[[232 152]
  [125 112]
  [125 112]]

 [[172 116]
  [123  76]
  [123  76]]

 [[442 296]
  [228 226]
  [228 226]]

 [[962 652]
  [465 512]
  [465 512]]]

4x4*3x4x2 matmul:
 [[[232 152]
  [172 116]
  [442 296]
  [962 652]]

 [[125 112]
  [123  76]
  [228 226]
  [465 512]]

 [[125 112]
  [123  76]
  [228 226]
  [465 512]]]

2
titik (a, b) [i, j, k, m] = jumlah (a [i, j ,:] * b [k,:, m]) ------- seperti dokumentasi mengatakan: itu adalah jumlah produk di atas sumbu terakhir dari a dan sumbu kedua hingga terakhir dari b:
Ronak Agrawal

Tangkapan yang bagus Namun, ini adalah 3x4x2. Cara lain untuk membangun matriks adalah dengan a = np.arange(24).reshape(3, 4, 2)membuat array dengan dimensi 3x4x2.
Nathan

8

FYI saja, @dan numpy-nya setara dotdan matmulsemuanya kira-kira sama cepatnya. (Plot dibuat dengan perfplot , proyek saya.)

masukkan deskripsi gambar di sini

Kode untuk mereproduksi plot:

import perfplot
import numpy


def setup(n):
    A = numpy.random.rand(n, n)
    x = numpy.random.rand(n)
    return A, x


def at(data):
    A, x = data
    return A @ x


def numpy_dot(data):
    A, x = data
    return numpy.dot(A, x)


def numpy_matmul(data):
    A, x = data
    return numpy.matmul(A, x)


perfplot.show(
    setup=setup,
    kernels=[at, numpy_dot, numpy_matmul],
    n_range=[2 ** k for k in range(12)],
    logx=True,
    logy=True,
)

7

Dalam matematika, menurut saya titik di numpy lebih masuk akal

titik (a, b) _ {i, j, k, a, b, c} =rumus

karena memberikan perkalian titik jika a dan b adalah vektor, atau perkalian matriks jika a dan b adalah matriks


Adapun operasi matmul dalam numpy terdiri dari bagian-bagian hasil titik , dan dapat didefinisikan sebagai

> matmul (a, b) _ {i, j, k, c} =rumus

Jadi, Anda dapat melihat bahwa matmul (a, b) mengembalikan array dengan bentuk kecil, yang memiliki konsumsi memori lebih kecil dan lebih masuk akal dalam aplikasi. Secara khusus, menggabungkan dengan penyiaran , Anda bisa mendapatkan

matmul (a, b) _ {i, j, k, l} =rumus

sebagai contoh.


Dari dua definisi di atas, Anda dapat melihat persyaratan untuk menggunakan kedua operasi tersebut. Asumsikan a.shape = (s1, s2, s3, s4) and b.shape = (t1, t2, t3, t4)

  • Untuk menggunakan titik (a, b) yang Anda butuhkan

    1. t3 = s4 ;
  • Untuk menggunakan matmul (a, b) yang Anda butuhkan

    1. t3 = s4
    2. t2 = s2 , atau salah satu dari t2 dan s2 adalah 1
    3. t1 = s1 , atau salah satu dari t1 dan s1 adalah 1

Gunakan potongan kode berikut untuk meyakinkan diri Anda sendiri.

Contoh kode

import numpy as np
for it in xrange(10000):
    a = np.random.rand(5,6,2,4)
    b = np.random.rand(6,4,3)
    c = np.matmul(a,b)
    d = np.dot(a,b)
    #print 'c shape: ', c.shape,'d shape:', d.shape

    for i in range(5):
        for j in range(6):
            for k in range(2):
                for l in range(3):
                    if not c[i,j,k,l] == d[i,j,k,j,l]:
                        print it,i,j,k,l,c[i,j,k,l]==d[i,j,k,j,l] #you will not see them

np.matmuljuga memberikan perkalian titik pada vektor dan produk matriks pada matriks.
Subhaneil Lahiri

2

Berikut adalah perbandingan dengan np.einsumuntuk menunjukkan bagaimana indeks diproyeksikan

np.allclose(np.einsum('ijk,ijk->ijk', a,b), a*b)        # True 
np.allclose(np.einsum('ijk,ikl->ijl', a,b), a@b)        # True
np.allclose(np.einsum('ijk,lkm->ijlm',a,b), a.dot(b))   # True

0

Pengalaman saya dengan MATMUL dan DOT

Saya terus-menerus mendapatkan "ValueError: Bentuk nilai yang diteruskan adalah (200, 1), indeks menyiratkan (200, 3)" saat mencoba menggunakan MATMUL. Saya ingin solusi cepat dan menemukan DOT memberikan fungsi yang sama. Saya tidak mendapatkan kesalahan apa pun saat menggunakan DOT. Saya mendapatkan jawaban yang benar

dengan MATMUL

X.shape
>>>(200, 3)

type(X)

>>>pandas.core.frame.DataFrame

w

>>>array([0.37454012, 0.95071431, 0.73199394])

YY = np.matmul(X,w)

>>>  ValueError: Shape of passed values is (200, 1), indices imply (200, 3)"

dengan DOT

YY = np.dot(X,w)
# no error message
YY
>>>array([ 2.59206877,  1.06842193,  2.18533396,  2.11366346,  0.28505879, 

YY.shape

>>> (200, )
Dengan menggunakan situs kami, Anda mengakui telah membaca dan memahami Kebijakan Cookie dan Kebijakan Privasi kami.
Licensed under cc by-sa 3.0 with attribution required.