tl; dr Meskipun ini adalah kumpulan data klasifikasi gambar, ini tetap merupakan tugas yang sangat mudah , yang dengannya seseorang dapat dengan mudah menemukan pemetaan langsung dari input ke prediksi.
Menjawab:
Ini adalah pertanyaan yang sangat menarik dan berkat kesederhanaan regresi logistik Anda benar-benar dapat menemukan jawabannya.
Apa yang dilakukan regresi logistik adalah agar setiap gambar menerima input dan mengalikannya dengan bobot untuk menghasilkan prediksi. Yang menarik adalah karena pemetaan langsung antara input dan output (yaitu tidak ada lapisan tersembunyi), nilai setiap bobot sesuai dengan seberapa banyak masing-masing dari input diperhitungkan saat menghitung probabilitas setiap kelas. Sekarang, dengan mengambil bobot untuk setiap kelas dan membentuknya kembali menjadi (yaitu resolusi gambar), kita dapat mengetahui piksel apa yang paling penting untuk perhitungan setiap kelas .78478428×28
Perhatikan, sekali lagi, bahwa ini adalah bobotnya .
Sekarang lihat gambar di atas dan fokus pada dua digit pertama (yaitu nol dan satu). Bobot biru berarti bahwa intensitas piksel ini banyak berkontribusi untuk kelas itu dan nilai merah berarti memberi kontribusi negatif.
Sekarang bayangkan, bagaimana seseorang menggambar angka ? Dia menggambar bentuk melingkar yang kosong di antaranya. Itulah tepatnya yang diangkat oleh beban. Bahkan jika seseorang menggambar tengah gambar, itu dihitung negatif sebagai nol. Jadi untuk mengenali nol Anda tidak perlu beberapa filter canggih dan fitur tingkat tinggi. Anda bisa melihat lokasi piksel yang diambil dan menilai berdasarkan ini.0
Hal yang sama untuk . Itu selalu memiliki garis vertikal lurus di tengah gambar. Semua yang lain terhitung negatif.1
Sisa digitnya sedikit lebih rumit, tetapi dengan sedikit imajinasi Anda dapat melihat , , dan . Angka-angka lainnya sedikit lebih sulit, yang sebenarnya membatasi regresi logistik untuk mencapai tahun 90-an.2378
Melalui ini Anda dapat melihat bahwa regresi logistik memiliki peluang yang sangat baik untuk mendapatkan banyak gambar dengan benar dan itulah mengapa nilainya sangat tinggi.
Kode untuk mereproduksi gambar di atas sedikit bertanggal, tetapi di sini Anda mulai:
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
# Load MNIST:
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
# Create model
x = tf.placeholder(tf.float32, shape=(None, 784))
y = tf.placeholder(tf.float32, shape=(None, 10))
W = tf.Variable(tf.zeros((784,10)))
b = tf.Variable(tf.zeros((10)))
z = tf.matmul(x, W) + b
y_hat = tf.nn.softmax(z)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y * tf.log(y_hat), reduction_indices=[1]))
optimizer = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) #
correct_pred = tf.equal(tf.argmax(y_hat, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
# Train model
batch_size = 64
with tf.Session() as sess:
loss_tr, acc_tr, loss_ts, acc_ts = [], [], [], []
sess.run(tf.global_variables_initializer())
for step in range(1, 1001):
x_batch, y_batch = mnist.train.next_batch(batch_size)
sess.run(optimizer, feed_dict={x: x_batch, y: y_batch})
l_tr, a_tr = sess.run([cross_entropy, accuracy], feed_dict={x: x_batch, y: y_batch})
l_ts, a_ts = sess.run([cross_entropy, accuracy], feed_dict={x: mnist.test.images, y: mnist.test.labels})
loss_tr.append(l_tr)
acc_tr.append(a_tr)
loss_ts.append(l_ts)
acc_ts.append(a_ts)
weights = sess.run(W)
print('Test Accuracy =', sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels}))
# Plotting:
for i in range(10):
plt.subplot(2, 5, i+1)
weight = weights[:,i].reshape([28,28])
plt.title(i)
plt.imshow(weight, cmap='RdBu') # as noted by @Eric Duminil, cmap='gray' makes the numbers stand out more
frame1 = plt.gca()
frame1.axes.get_xaxis().set_visible(False)
frame1.axes.get_yaxis().set_visible(False)