TensorFlow, mengapa ada 3 file setelah menyimpan model?


113

Setelah membaca dokumen , saya menyimpan model TensorFlow, berikut kode demo saya:

# Create some variables.
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, initialize the variables, do some work, save the
# variables to disk.
with tf.Session() as sess:
  sess.run(init_op)
  # Do some work with the model.
  ..
  # Save the variables to disk.
  save_path = saver.save(sess, "/tmp/model.ckpt")
  print("Model saved in file: %s" % save_path)

tapi setelah itu, saya temukan ada 3 file

model.ckpt.data-00000-of-00001
model.ckpt.index
model.ckpt.meta

Dan saya tidak dapat memulihkan model dengan memulihkan model.ckptfile, karena tidak ada file seperti itu. Ini kode saya

with tf.Session() as sess:
  # Restore variables from disk.
  saver.restore(sess, "/tmp/model.ckpt")

Jadi kenapa ada 3 file?


2
Apakah Anda mengetahui cara mengatasi ini? Bagaimana cara memuat model lagi (menggunakan Keras)?
rajkiran

Jawaban:


116

Coba ini:

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('/tmp/model.ckpt.meta')
    saver.restore(sess, "/tmp/model.ckpt")

Metode penyimpanan TensorFlow menyimpan tiga jenis file karena menyimpan struktur grafik secara terpisah dari nilai variabel . The .metaFile menggambarkan struktur grafik disimpan, sehingga Anda perlu mengimpor sebelum memulihkan pos pemeriksaan (jika tidak tidak tahu apa variabel nilai pos pemeriksaan disimpan sesuai dengan).

Atau, Anda dapat melakukan ini:

# Recreate the EXACT SAME variables
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")

...

# Now load the checkpoint variable values
with tf.Session() as sess:
    saver = tf.train.Saver()
    saver.restore(sess, "/tmp/model.ckpt")

Meskipun tidak ada file bernama model.ckpt, Anda masih merujuk ke pos pemeriksaan yang disimpan dengan nama itu saat memulihkannya. Dari saver.pykode sumber :

Pengguna hanya perlu berinteraksi dengan awalan yang ditentukan pengguna ... alih-alih nama jalur fisik apa pun.


1
jadi .index dan .data tidak digunakan? Lalu, kapan 2 file itu digunakan?
ajfbiw.s

26
@ ajfbiw.s .meta menyimpan struktur grafik, .data menyimpan nilai dari setiap variabel dalam grafik, .index mengidentifikasi checkpiont. Jadi dalam contoh di atas: import_meta_graph menggunakan .meta, dan saver.restore menggunakan .data dan .index
TK Bartel

Oh begitu. Terima kasih.
ajfbiw.s

1
Adakah kemungkinan Anda menyimpan model dengan versi TensorFlow yang berbeda dari yang Anda gunakan untuk memuatnya? ( github.com/tensorflow/tensorflow/issues/5639 )
TK Bartel

5
Apakah ada yang tahu apa artinya 00000dan 00001angka? dalam variables.data-?????-of-?????file
Ivan Talalaev

55
  • file meta : menjelaskan struktur grafik yang disimpan, termasuk GraphDef, SaverDef, dan sebagainya; kemudian terapkan tf.train.import_meta_graph('/tmp/model.ckpt.meta'), akan memulihkan Saverdan Graph.

  • file indeks : itu adalah tabel tetap string-string (tensorflow :: table :: Table). Setiap kunci adalah nama tensor dan nilainya adalah BundleEntryProto berseri. Setiap BundleEntryProto mendeskripsikan metadata tensor: file "data" mana yang berisi konten tensor, offset ke file tersebut, checksum, beberapa data tambahan, dll.

  • file data : ini adalah koleksi TensorBundle, simpan nilai semua variabel.


Saya telah mendapatkan file pb yang saya miliki untuk klasifikasi gambar. Bisakah saya menggunakannya untuk klasifikasi video waktu nyata?

Bisakah Anda memberi tahu saya, Menggunakan Keras 2, bagaimana cara memuat model jika disimpan sebagai 3 file?
rajkiran

5

Saya memulihkan embeddings kata yang terlatih dari tutorial tensorflow Word2Vec.

Jika Anda telah membuat beberapa pos pemeriksaan:

misalnya file yang dibuat terlihat seperti ini

model.ckpt-55695.data-00000-of-00001

model.ckpt-55695.index

model.ckpt-55695.meta

coba ini

def restore_session(self, session):
   saver = tf.train.import_meta_graph('./tmp/model.ckpt-55695.meta')
   saver.restore(session, './tmp/model.ckpt-55695')

saat memanggil restore_session ():

def test_word2vec():
   opts = Options()    
   with tf.Graph().as_default(), tf.Session() as session:
       with tf.device("/cpu:0"):            
           model = Word2Vec(opts, session)
           model.restore_session(session)
           model.get_embedding("assistance")

Apa yang dimaksud dengan "00000-of-00001" dalam "model.ckpt-55695.data-00000-of-00001"?
hafiz031

0

Jika Anda melatih CNN dengan putus sekolah, misalnya, Anda dapat melakukan ini:

def predict(image, model_name):
    """
    image -> single image, (width, height, channels)
    model_name -> model file that was saved without any extensions
    """
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph('./' + model_name + '.meta')
        saver.restore(sess, './' + model_name)
        # Substitute 'logits' with your model
        prediction = tf.argmax(logits, 1)
        # 'x' is what you defined it to be. In my case it is a batch of RGB images, that's why I add the extra dimension
        return prediction.eval(feed_dict={x: image[np.newaxis,:,:,:], keep_prob_dnn: 1.0})
Dengan menggunakan situs kami, Anda mengakui telah membaca dan memahami Kebijakan Cookie dan Kebijakan Privasi kami.
Licensed under cc by-sa 3.0 with attribution required.