TensorFlow menyimpan / memuat grafik dari file


98

Dari apa yang telah saya kumpulkan sejauh ini, ada beberapa cara berbeda untuk membuang grafik TensorFlow ke dalam file dan kemudian memuatnya ke program lain, tetapi saya belum dapat menemukan contoh / informasi yang jelas tentang cara kerjanya. Yang sudah saya ketahui adalah ini:

  1. Simpan variabel model ke dalam file checkpoint (.ckpt) menggunakan a tf.train.Saver()dan pulihkan nanti ( sumber )
  2. Simpan model ke dalam file .pb dan muat kembali menggunakan tf.train.write_graph()dan tf.import_graph_def()( sumber )
  3. Memuat model dari file .pb, melatihnya kembali, dan membuangnya ke file .pb baru menggunakan Bazel ( sumber )
  4. Bekukan grafik untuk menyimpan grafik dan bobot ( sumber )
  5. Gunakan as_graph_def()untuk menyimpan model, dan untuk bobot / variabel, petakan mereka menjadi konstanta ( sumber )

Namun, saya belum dapat menjawab beberapa pertanyaan tentang metode berbeda ini:

  1. Mengenai file checkpoint, apakah mereka hanya menyimpan bobot model yang terlatih? Bisakah file checkpoint dimuat ke program baru, dan digunakan untuk menjalankan model, atau apakah mereka hanya berfungsi sebagai cara untuk menyimpan bobot dalam model pada waktu / tahap tertentu?
  2. Mengenai tf.train.write_graph(), apakah bobot / variabel juga disimpan?
  3. Terkait Bazel, dapatkah itu hanya menyimpan ke / memuat dari file .pb untuk pelatihan ulang? Apakah ada perintah Bazel sederhana hanya untuk membuang grafik ke dalam .pb?
  4. Mengenai pembekuan, dapatkah grafik beku dimuat dengan menggunakan tf.import_graph_def()?
  5. Demo Android untuk TensorFlow dimuat dalam model Inception Google dari file .pb. Jika saya ingin mengganti file .pb saya sendiri, bagaimana cara melakukannya? Apakah saya perlu mengubah kode / metode asli?
  6. Secara umum, apa sebenarnya perbedaan antara semua metode ini? Atau secara lebih luas, apa perbedaan antara as_graph_def()/.ckpt/.pb?

Singkatnya, yang saya cari adalah metode untuk menyimpan grafik (seperti dalam, berbagai operasi dan semacamnya) dan bobot / variabelnya ke dalam file, yang kemudian dapat digunakan untuk memuat grafik dan bobot ke program lain , untuk digunakan (tidak harus melanjutkan / pelatihan ulang).

Dokumentasi tentang topik ini tidak terlalu mudah, jadi jawaban / informasi apa pun akan sangat dihargai.


2
API terbaru / terlengkap adalah grafik meta, yang akan memberi Anda cara untuk menyimpan ketiganya sekaligus - 1) grafik 2) nilai parameter 3) koleksi: tensorflow.org/versions/r0.10/how_tos/meta_graph/ index.html
Yaroslav Bulatov

Jawaban:


80

Ada banyak cara untuk mendekati masalah penyimpanan model di TensorFlow, yang bisa membuatnya sedikit membingungkan. Mengambil setiap sub-pertanyaan Anda secara bergantian:

  1. File checkpoint (dihasilkan misalnya dengan memanggil saver.save()sebuah tf.train.Saverobjek) hanya berisi bobot, dan variabel lain yang didefinisikan dalam program yang sama. Untuk menggunakannya di program lain, Anda harus membuat ulang struktur grafik terkait (misalnya dengan menjalankan kode untuk membuatnya lagi, atau memanggil tf.import_graph_def()), yang memberi tahu TensorFlow apa yang harus dilakukan dengan bobot tersebut. Perhatikan bahwa pemanggilan saver.save()juga menghasilkan file yang berisi a MetaGraphDef, yang berisi grafik dan detail tentang cara mengaitkan bobot dari checkpoint dengan grafik itu. Lihat tutorial untuk lebih jelasnya.

  2. tf.train.write_graph()hanya menulis struktur grafik; bukan bobotnya.

  3. Bazel tidak terkait dengan membaca atau menulis grafik TensorFlow. (Mungkin saya salah memahami pertanyaan Anda: silakan klarifikasi dalam komentar.)

  4. Grafik beku dapat dimuat menggunakan tf.import_graph_def(). Dalam kasus ini, bobot (biasanya) disematkan dalam grafik, jadi Anda tidak perlu memuat checkpoint terpisah.

  5. Perubahan utamanya adalah memperbarui nama tensor yang dimasukkan ke dalam model, dan nama tensor yang diambil dari model. Dalam demo Android TensorFlow, ini akan sesuai dengan string inputNamedan outputNameyang diteruskan ke TensorFlowClassifier.initializeTensorFlow().

  6. Ini GraphDefadalah struktur program, yang biasanya tidak berubah selama proses pelatihan. Checkpoint adalah snapshot dari status proses pelatihan, yang biasanya berubah di setiap langkah proses pelatihan. Akibatnya, TensorFlow menggunakan format penyimpanan yang berbeda untuk jenis data ini, dan API tingkat rendah menyediakan berbagai cara untuk menyimpan dan memuatnya. Pustaka tingkat yang lebih tinggi, seperti MetaGraphDefpustaka, Keras , dan skflow membangun mekanisme ini untuk menyediakan cara yang lebih nyaman untuk menyimpan dan memulihkan seluruh model.


Apakah ini berarti bahwa dokumentasi C ++ API berbohong, ketika dikatakan bahwa Anda dapat memuat grafik yang disimpan dengan tf.train.write_graph()dan kemudian menjalankannya?
mnicky

2
Dokumentasi C ++ API tidak berbohong, tetapi beberapa detail hilang. Detail terpenting adalah, selain GraphDefdisimpan oleh tf.train.write_graph(), Anda juga perlu mengingat nama tensor yang ingin Anda beri makan dan ambil saat menjalankan grafik (item 5 di atas).
Tuan

@ mrry: Saya mencoba menggunakan contoh tensorflows DeepDream. tetapi tampaknya itu membutuhkan model terlatih dalam format pb! Saya menjalankan contoh Cifar10, tetapi itu hanya membuat pos pemeriksaan! Saya tidak dapat menemukan file pb atau apapun! bagaimana cara mengubah pos pemeriksaan saya ke format pb yang digunakan contoh deepdream?
Rika

2
@ Coderx7 Saya rasa Anda tidak dapat mengubah .ckpt menjadi .pb karena pos pemeriksaan hanya berisi bobot dan variabel dan tidak tahu apa-apa tentang struktur grafik
davidivad

1
apakah ada kode sederhana untuk memuat file .pb dan kemudian menjalankannya?
Kong

1

Anda dapat mencoba kode berikut:

with tf.gfile.FastGFile('model/frozen_inference_graph.pb', "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    g_in = tf.import_graph_def(graph_def, name="")
sess = tf.Session(graph=g_in)
Dengan menggunakan situs kami, Anda mengakui telah membaca dan memahami Kebijakan Cookie dan Kebijakan Privasi kami.
Licensed under cc by-sa 3.0 with attribution required.