Memahami ide numpy.einsum()
sangat mudah jika Anda memahaminya secara intuitif. Sebagai contoh, mari kita mulai dengan deskripsi sederhana yang melibatkan perkalian matriks .
Untuk menggunakannya numpy.einsum()
, yang harus Anda lakukan adalah meneruskan string subkrip yang disebut sebagai argumen, diikuti oleh array input Anda .
Katakanlah Anda memiliki dua array 2D, A
dan B
, dan Anda ingin melakukan perkalian matriks. Jadi, Anda lakukan:
np.einsum("ij, jk -> ik", A, B)
Di sini string subskrip ij
berhubungan dengan array A
sedangkan string subskrip jk
berhubungan dengan array B
. Juga, hal yang paling penting untuk dicatat di sini adalah bahwa jumlah karakter dalam setiap string subskrip harus sesuai dengan dimensi array. (yaitu dua karakter untuk array 2D, tiga karakter untuk array 3D, dan sebagainya.) Dan jika Anda mengulangi karakter di antara string subskrip ( j
dalam kasus kami), maka itu berarti Anda ingin ein
jumlah terjadi di sepanjang dimensi tersebut. Dengan demikian, jumlah tersebut akan dikurangi. (Yaitu dimensi itu akan hilang )
The String subscript setelah ini ->
, akan array yang dihasilkan kami. Jika Anda membiarkannya kosong, maka semuanya akan dijumlahkan dan nilai skalar dikembalikan sebagai hasilnya. Lain array yang dihasilkan akan memiliki dimensi sesuai dengan string subskrip . Dalam contoh kita, itu akan menjadi ik
. Ini intuitif karena kita tahu bahwa untuk perkalian matriks jumlah kolom dalam array A
harus cocok dengan jumlah baris dalam array B
yang merupakan apa yang terjadi di sini (yaitu kita menyandikan pengetahuan ini dengan mengulangi char j
dalam string subskrip )
Berikut adalah beberapa contoh yang menggambarkan penggunaan / kekuatan np.einsum()
dalam mengimplementasikan beberapa operasi tensor atau nd-array yang umum , secara ringkas.
Input
# a vector
In [197]: vec
Out[197]: array([0, 1, 2, 3])
# an array
In [198]: A
Out[198]:
array([[11, 12, 13, 14],
[21, 22, 23, 24],
[31, 32, 33, 34],
[41, 42, 43, 44]])
# another array
In [199]: B
Out[199]:
array([[1, 1, 1, 1],
[2, 2, 2, 2],
[3, 3, 3, 3],
[4, 4, 4, 4]])
1) Perkalian matriks (mirip dengan np.matmul(arr1, arr2)
)
In [200]: np.einsum("ij, jk -> ik", A, B)
Out[200]:
array([[130, 130, 130, 130],
[230, 230, 230, 230],
[330, 330, 330, 330],
[430, 430, 430, 430]])
2) Ekstrak elemen di sepanjang main-diagonal (mirip dengan np.diag(arr)
)
In [202]: np.einsum("ii -> i", A)
Out[202]: array([11, 22, 33, 44])
3) Produk Hadamard (yaitu produk elemen-bijaksana dari dua array) (mirip dengan arr1 * arr2
)
In [203]: np.einsum("ij, ij -> ij", A, B)
Out[203]:
array([[ 11, 12, 13, 14],
[ 42, 44, 46, 48],
[ 93, 96, 99, 102],
[164, 168, 172, 176]])
4) Elemen-bijaksana kuadrat (mirip dengan np.square(arr)
atau arr ** 2
)
In [210]: np.einsum("ij, ij -> ij", B, B)
Out[210]:
array([[ 1, 1, 1, 1],
[ 4, 4, 4, 4],
[ 9, 9, 9, 9],
[16, 16, 16, 16]])
5) Jejak (yaitu jumlah elemen main-diagonal) (mirip dengan np.trace(arr)
)
In [217]: np.einsum("ii -> ", A)
Out[217]: 110
6) Matriks transpose (mirip dengan np.transpose(arr)
)
In [221]: np.einsum("ij -> ji", A)
Out[221]:
array([[11, 21, 31, 41],
[12, 22, 32, 42],
[13, 23, 33, 43],
[14, 24, 34, 44]])
7) Produk Luar (dari vektor) (mirip dengan np.outer(vec1, vec2)
)
In [255]: np.einsum("i, j -> ij", vec, vec)
Out[255]:
array([[0, 0, 0, 0],
[0, 1, 2, 3],
[0, 2, 4, 6],
[0, 3, 6, 9]])
8) Produk Dalam (dari vektor) (mirip dengan np.inner(vec1, vec2)
)
In [256]: np.einsum("i, i -> ", vec, vec)
Out[256]: 14
9) Jumlah sepanjang sumbu 0 (mirip dengan np.sum(arr, axis=0)
)
In [260]: np.einsum("ij -> j", B)
Out[260]: array([10, 10, 10, 10])
10) Jumlahkan sepanjang sumbu 1 (mirip dengan np.sum(arr, axis=1)
)
In [261]: np.einsum("ij -> i", B)
Out[261]: array([ 4, 8, 12, 16])
11) Penggandaan Matriks Batch
In [287]: BM = np.stack((A, B), axis=0)
In [288]: BM
Out[288]:
array([[[11, 12, 13, 14],
[21, 22, 23, 24],
[31, 32, 33, 34],
[41, 42, 43, 44]],
[[ 1, 1, 1, 1],
[ 2, 2, 2, 2],
[ 3, 3, 3, 3],
[ 4, 4, 4, 4]]])
In [289]: BM.shape
Out[289]: (2, 4, 4)
# batch matrix multiply using einsum
In [292]: BMM = np.einsum("bij, bjk -> bik", BM, BM)
In [293]: BMM
Out[293]:
array([[[1350, 1400, 1450, 1500],
[2390, 2480, 2570, 2660],
[3430, 3560, 3690, 3820],
[4470, 4640, 4810, 4980]],
[[ 10, 10, 10, 10],
[ 20, 20, 20, 20],
[ 30, 30, 30, 30],
[ 40, 40, 40, 40]]])
In [294]: BMM.shape
Out[294]: (2, 4, 4)
12) Jumlah sepanjang sumbu 2 (mirip dengan np.sum(arr, axis=2)
)
In [330]: np.einsum("ijk -> ij", BM)
Out[330]:
array([[ 50, 90, 130, 170],
[ 4, 8, 12, 16]])
13) Jumlahkan semua elemen dalam array (mirip dengan np.sum(arr)
)
In [335]: np.einsum("ijk -> ", BM)
Out[335]: 480
14) Jumlah lebih dari beberapa sumbu (yaitu marginalisasi)
(mirip dengan np.sum(arr, axis=(axis0, axis1, axis2, axis3, axis4, axis6, axis7))
)
# 8D array
In [354]: R = np.random.standard_normal((3,5,4,6,8,2,7,9))
# marginalize out axis 5 (i.e. "n" here)
In [363]: esum = np.einsum("ijklmnop -> n", R)
# marginalize out axis 5 (i.e. sum over rest of the axes)
In [364]: nsum = np.sum(R, axis=(0,1,2,3,4,6,7))
In [365]: np.allclose(esum, nsum)
Out[365]: True
15) Produk Dot Ganda (mirip dengan np.sum (produk hadamard) lih. 3 )
In [772]: A
Out[772]:
array([[1, 2, 3],
[4, 2, 2],
[2, 3, 4]])
In [773]: B
Out[773]:
array([[1, 4, 7],
[2, 5, 8],
[3, 6, 9]])
In [774]: np.einsum("ij, ij -> ", A, B)
Out[774]: 124
16) penggandaan array 2D dan 3D
Penggandaan seperti itu bisa sangat berguna ketika menyelesaikan sistem persamaan linear ( Ax = b ) di mana Anda ingin memverifikasi hasilnya.
# inputs
In [115]: A = np.random.rand(3,3)
In [116]: b = np.random.rand(3, 4, 5)
# solve for x
In [117]: x = np.linalg.solve(A, b.reshape(b.shape[0], -1)).reshape(b.shape)
# 2D and 3D array multiplication :)
In [118]: Ax = np.einsum('ij, jkl', A, x)
# indeed the same!
In [119]: np.allclose(Ax, b)
Out[119]: True
Sebaliknya, jika seseorang harus menggunakan np.matmul()
verifikasi ini, kita harus melakukan beberapa reshape
operasi untuk mencapai hasil yang sama seperti:
# reshape 3D array `x` to 2D, perform matmul
# then reshape the resultant array to 3D
In [123]: Ax_matmul = np.matmul(A, x.reshape(x.shape[0], -1)).reshape(x.shape)
# indeed correct!
In [124]: np.allclose(Ax, Ax_matmul)
Out[124]: True
Bonus : Baca lebih banyak matematika di sini: Einstein-Summation dan pasti di sini: Tensor-Notation
(A * B)^T
, atau setaraB^T * A^T
.