Apa itu aktivasi GELU?


18

Saya menggunakan kertas BERT yang menggunakan GELU (Gaussian Error Linear Unit) yang menyatakan persamaan sebagai yang

GELU(x)=xP(Xx)=xΦ(x).
0.5x(1+tanh[2/π(x+0.044715x3)])
Bisakah Anda menyederhanakan persamaan dan menjelaskan bagaimana persamaannya telah diperkirakan.

Jawaban:


18

Fungsi GELU

Kita dapat memperluas distribusi kumulatif dariN(0,1) , yaitu , sebagai berikut: Φ(x)

GELU(x):=xP(Xx)=xΦ(x)=0.5x(1+erf(x2))

Perhatikan bahwa ini adalah definisi , bukan persamaan (atau hubungan). Penulis telah memberikan beberapa pembenaran untuk proposal ini, misalnya analogi stokastik , namun secara matematis, ini hanyalah sebuah definisi.

Berikut ini alur cerita GELU:

Perkiraan Tanh

Untuk jenis perkiraan numerik ini, ide kuncinya adalah menemukan fungsi yang serupa (terutama berdasarkan pengalaman), membuat parameterisasi, dan kemudian memasangnya ke satu set poin dari fungsi aslinya.

Mengetahui bahwa sangat dekat denganerf(x)tanh(x)

dan turunan pertama dari bertepatan dengan di , yang merupakan , kami melanjutkan untuk mencocokkan (atau dengan lebih banyak istilah) ke sekumpulan titik .erf(x2)tanh(2πx)x=02π

tanh(2π(x+Sebuahx2+bx3+cx4+dx5))
(xsaya,erf(xsaya2))

Saya telah memasang fungsi ini ke 20 sampel antara ( menggunakan situs ini ), dan berikut adalah koefisiennya:(-1.5,1.5)

Dengan mengatur ,Sebuah=c=d=0b diperkirakan 0,04495641 . Dengan lebih banyak sampel dari rentang yang lebih luas (situs itu hanya diizinkan 20), koefisien b akan lebih dekat dengan kertas . Akhirnya kita dapatkan0,0447150,04495641b0,044715

GELU(x)=xΦ(x)=0,5x(1+erf(x2))0,5x(1+tanh(2π(x+0,044715x3)))

dengan mean squared error 10-8 untuk x[-10,10] .

Perhatikan bahwa jika kami tidak memanfaatkan hubungan antara derivatif pertama, istilah 2π akan dimasukkan dalam parameter sebagai berikut

0.5x(1+tanh(0.797885x+0.035677x3))
yang kurang indah (kurang analitis, lebih numerik)!

Memanfaatkan paritas

erff(x)=f(x)tanhpol(x)tanhx

erf(x)tanh(pol(x))=tanh(pol(x))=tanh(pol(x))erf(x)

x2x40.23x20x2

Perkiraan sigmoid

erf(x)2(σ(x)12)104x[10,10]

Berikut adalah kode Python untuk menghasilkan titik data, menyesuaikan fungsi, dan menghitung rata-rata kesalahan kuadrat:

import math
import numpy as np
import scipy.optimize as optimize


def tahn(xs, a):
    return [math.tanh(math.sqrt(2 / math.pi) * (x + a * x**3)) for x in xs]


def sigmoid(xs, a):
    return [2 * (1 / (1 + math.exp(-a * x)) - 0.5) for x in xs]


print_points = 0
np.random.seed(123)
# xs = [-2, -1, -.9, -.7, 0.6, -.5, -.4, -.3, -0.2, -.1, 0,
#       .1, 0.2, .3, .4, .5, 0.6, .7, .9, 2]
# xs = np.concatenate((np.arange(-1, 1, 0.2), np.arange(-4, 4, 0.8)))
# xs = np.concatenate((np.arange(-2, 2, 0.5), np.arange(-8, 8, 1.6)))
xs = np.arange(-10, 10, 0.001)
erfs = np.array([math.erf(x/math.sqrt(2)) for x in xs])
ys = np.array([0.5 * x * (1 + math.erf(x/math.sqrt(2))) for x in xs])

# Fit tanh and sigmoid curves to erf points
tanh_popt, _ = optimize.curve_fit(tahn, xs, erfs)
print('Tanh fit: a=%5.5f' % tuple(tanh_popt))

sig_popt, _ = optimize.curve_fit(sigmoid, xs, erfs)
print('Sigmoid fit: a=%5.5f' % tuple(sig_popt))

# curves used in https://mycurvefit.com:
# 1. sinh(sqrt(2/3.141593)*(x+a*x^2+b*x^3+c*x^4+d*x^5))/cosh(sqrt(2/3.141593)*(x+a*x^2+b*x^3+c*x^4+d*x^5))
# 2. sinh(sqrt(2/3.141593)*(x+b*x^3))/cosh(sqrt(2/3.141593)*(x+b*x^3))
y_paper_tanh = np.array([0.5 * x * (1 + math.tanh(math.sqrt(2/math.pi)*(x + 0.044715 * x**3))) for x in xs])
tanh_error_paper = (np.square(ys - y_paper_tanh)).mean()
y_alt_tanh = np.array([0.5 * x * (1 + math.tanh(math.sqrt(2/math.pi)*(x + tanh_popt[0] * x**3))) for x in xs])
tanh_error_alt = (np.square(ys - y_alt_tanh)).mean()

# curve used in https://mycurvefit.com:
# 1. 2*(1/(1+2.718281828459^(-(a*x))) - 0.5)
y_paper_sigmoid = np.array([x * (1 / (1 + math.exp(-1.702 * x))) for x in xs])
sigmoid_error_paper = (np.square(ys - y_paper_sigmoid)).mean()
y_alt_sigmoid = np.array([x * (1 / (1 + math.exp(-sig_popt[0] * x))) for x in xs])
sigmoid_error_alt = (np.square(ys - y_alt_sigmoid)).mean()

print('Paper tanh error:', tanh_error_paper)
print('Alternative tanh error:', tanh_error_alt)
print('Paper sigmoid error:', sigmoid_error_paper)
print('Alternative sigmoid error:', sigmoid_error_alt)

if print_points == 1:
    print(len(xs))
    for x, erf in zip(xs, erfs):
        print(x, erf)

Keluaran:

Tanh fit: a=0.04485
Sigmoid fit: a=1.70099
Paper tanh error: 2.4329173471294176e-08
Alternative tanh error: 2.698034519269613e-08
Paper sigmoid error: 5.6479106346814546e-05
Alternative sigmoid error: 5.704246564663601e-05

2
Mengapa dibutuhkan pendekatan? Tidak bisakah mereka menggunakan fungsi erf saja?
SebiSebi

8

Φ(x)=12erfc(x2)=12(1+erf(x2))
erf
erf(x2)tanh(2π(x+ax3))
a0.044715

x[1,1]x

tanh(x)=xx33+o(x3)
erf(x)=2π(xx33)+o(x3).
tanh(2π(x+Sebuahx3))=2π(x+(Sebuah-23π)x3)+Hai(x3)
erf(x2)=2π(x-x36)+Hai(x3).
x3
Sebuah0,04553992412
0,044715

Dengan menggunakan situs kami, Anda mengakui telah membaca dan memahami Kebijakan Cookie dan Kebijakan Privasi kami.
Licensed under cc by-sa 3.0 with attribution required.