Saya melatih jaringan saraf convolutional sederhana untuk regresi, di mana tugasnya adalah untuk memprediksi lokasi (x, y) kotak dalam gambar, misalnya:
Output dari jaringan memiliki dua node, satu untuk x, dan satu untuk y. Sisa dari jaringan adalah jaringan saraf convolutional standar. Hilangnya adalah standar kuadrat kesalahan antara posisi kotak yang diprediksi, dan posisi kebenaran dasar. Saya melatih 10.000 gambar ini, dan memvalidasi pada tahun 2000.
Masalah yang saya alami, adalah bahwa bahkan setelah pelatihan yang signifikan, kerugiannya tidak benar-benar berkurang. Setelah mengamati output dari jaringan, saya perhatikan bahwa jaringan cenderung nilai output mendekati nol, untuk kedua node output. Dengan demikian, prediksi lokasi kotak selalu menjadi pusat gambar. Ada beberapa penyimpangan dalam prediksi, tetapi selalu di sekitar nol. Di bawah ini menunjukkan kerugian:
Saya telah menjalankan ini selama lebih banyak zaman daripada yang ditunjukkan dalam grafik ini, dan kerugiannya tidak pernah berkurang. Menariknya di sini, kerugian sebenarnya meningkat pada satu titik.
Jadi, tampaknya jaringan hanya memprediksi rata-rata data pelatihan, daripada mempelajari kecocokan. Ada ide mengapa ini bisa terjadi? Saya menggunakan Adam sebagai pengoptimal, dengan tingkat pembelajaran awal 0,01, dan relus sebagai aktivasi
Jika Anda tertarik pada beberapa kode saya (Keras), di bawah ini:
# Create the model
model = Sequential()
model.add(Convolution2D(32, 5, 5, border_mode='same', subsample=(2, 2), activation='relu', input_shape=(3, image_width, image_height)))
model.add(Convolution2D(64, 5, 5, border_mode='same', subsample=(2, 2), activation='relu'))
model.add(Convolution2D(128, 5, 5, border_mode='same', subsample=(2, 2), activation='relu'))
model.add(Flatten())
model.add(Dense(100, activation='relu'))
model.add(Dense(2, activation='linear'))
# Compile the model
adam = Adam(lr=0.01, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=0.0)
model.compile(loss='mean_squared_error', optimizer=adam)
# Fit the model
model.fit(images, targets, batch_size=128, nb_epoch=1000, verbose=1, callbacks=[plot_callback], validation_split=0.2, shuffle=True)