Bagaimana cara mendaftar semua operasi yang digunakan di Tensorflow SavedModel?


10

Jika saya menyimpan model saya menggunakan tensorflow.saved_model.savefungsi dalam format SavedModel, bagaimana saya bisa mengambil Tensorflow Ops mana yang digunakan dalam model ini sesudahnya. Karena model dapat dipulihkan, operasi ini disimpan dalam grafik, tebakan saya ada di saved_model.pbfile. Jika saya memuat protobuf ini (jadi bukan seluruh model) bagian perpustakaan dari protobuf mencantumkan ini, tetapi ini tidak didokumentasikan dan ditandai sebagai fitur eksperimental untuk saat ini. Model yang dibuat di Tensorflow 1.x tidak akan memiliki bagian ini.

Jadi apa cara cepat dan andal untuk mengambil daftar Operasi yang digunakan (Suka MatchingFilesatau WriteFile) dari model dalam format SavedModel?

Saat ini aku bisa membekukan semuanya, seperti tensorflowjs-converter halnya. Karena mereka juga memeriksa Operasi yang didukung. Ini saat ini tidak berfungsi ketika LSTM ada dalam model, lihat di sini . Apakah ada cara yang lebih baik untuk melakukan ini, karena Ops pasti ada di sana?

Contoh model:

class FileReader(tf.Module):

@tf.function(input_signature=[tf.TensorSpec(name='filename', shape=[None], dtype=tf.string)])
def read_disk(self, file_name):
    input_scalar = tf.reshape(file_name, [])
    output = tf.io.read_file(input_scalar)
    return tf.stack([output], name='content')

file_reader = FileReader()

tf.saved_model.save(file_reader, 'file_reader')

Diharapkan dalam output semua Ops, mengandung dalam hal ini setidaknya:

  • ReadFile seperti yang dijelaskan di sini
  • ...

1
Sulit untuk mengatakan dengan tepat apa yang Anda inginkan, apakah saved_model.pbitu tf.GraphDef, atau SavedModelpesan protobuf? Jika Anda memiliki tf.GraphDefdisebut gd, Anda bisa mendapatkan daftar ops digunakan dengan sorted(set(n.op for n in gd.node)). Jika Anda memiliki model yang dimuat, Anda dapat melakukannya sorted(set(op.type for op in tf.get_default_graph().get_operations())). Jika itu adalah SavedModel, Anda dapat memperolehnya tf.GraphDef(mis saved_model.meta_graphs[0].graph_def.).
jdehesa

Saya ingin mengambil ops dari SavedModel yang tersimpan. Jadi memang, opsi terakhir yang Anda gambarkan. Apa saved_modelvariabel dalam contoh terakhir Anda? Hasil tf.saved_model.load('/path/to/model')atau memuat protobuf dari file Saved_model.pb.
sampers

Jawaban:


1

Jika saved_model.pbada SavedModelpesan protobuf, maka Anda mendapatkan operasi langsung dari sana. Katakanlah kita membuat model sebagai berikut:

import tensorflow as tf

class FileReader(tf.Module):
    @tf.function(input_signature=[tf.TensorSpec(name='filename', shape=[None], dtype=tf.string)])
    def read_disk(self, file_name):
        input_scalar = tf.reshape(file_name, [])
        output = tf.io.read_file(input_scalar)
        return tf.stack([output], name='content')

file_reader = FileReader()
tf.saved_model.save(file_reader, 'tmp')

Kami sekarang dapat menemukan operasi yang digunakan oleh model seperti ini:

from tensorflow.core.protobuf.saved_model_pb2 import SavedModel

saved_model = SavedModel()
with open('tmp/saved_model.pb', 'rb') as f:
    saved_model.ParseFromString(f.read())
model_op_names = set()
# Iterate over every metagraph in case there is more than one
for meta_graph in saved_model.meta_graphs:
    # Add operations in the graph definition
    model_op_names.update(node.op for node in meta_graph.graph_def.node)
    # Go through the functions in the graph definition
    for func in meta_graph.graph_def.library.function:
        # Add operations in each function
        model_op_names.update(node.op for node in func.node_def)
# Convert to list, sorted if you want
model_op_names = sorted(model_op_names)
print(*model_op_names, sep='\n')
# Const
# Identity
# MergeV2Checkpoints
# NoOp
# Pack
# PartitionedCall
# Placeholder
# ReadFile
# Reshape
# RestoreV2
# SaveV2
# ShardedFilename
# StatefulPartitionedCall
# StringJoin

Saya mencoba sesuatu seperti ini, tetapi sayangnya ini tidak seperti yang saya harapkan: Katakanlah saya memiliki model yang melakukan ini: input_scalar = tf.reshape(file_name, []) output = tf.io.read_file(input_scalar) return tf.stack([output], name='content')Kemudian ReadFile Op seperti yang tercantum di sini ada di sana, tetapi tidak dicetak.
sampers

1
@amper Saya telah mengedit jawaban dengan contoh seperti yang Anda sarankan. Saya mendapatkan ReadFileoperasi di output. Apakah mungkin bahwa, dalam kasus Anda yang sebenarnya, operasi itu tidak antara input dan output dari model yang disimpan? Dalam hal ini saya pikir mungkin akan dipangkas.
jdehesa

Memang dengan model yang diberikan itu berhasil. Sayangnya untuk modul yang dibuat di tf2, tidak. Jika saya membuat tf.Module dengan 1 fungsi dengan anotasi file_nameargumen @tf.function, berisi panggilan yang saya sebutkan di komentar saya sebelumnya, itu memberikan daftar berikut:Const, NoOp, PartitionedCall, Placeholder, StatefulPartitionedCall
sampers

menambahkan model ke pertanyaan saya
sampers

@amper Saya telah memperbarui jawaban saya. Saya menggunakan TF 1.x sebelumnya, saya tidak terbiasa dengan perubahan pada objek definisi grafik di TF 2.x, saya pikir jawabannya sekarang mencakup semua yang ada di model yang disimpan. Saya pikir operasi yang sesuai dengan fungsi Python yang Anda tulis berada di saved_model.meta_graphs[0].graph_def.library.function[0]( node_defkoleksi dalam objek fungsi itu).
jdehesa
Dengan menggunakan situs kami, Anda mengakui telah membaca dan memahami Kebijakan Cookie dan Kebijakan Privasi kami.
Licensed under cc by-sa 3.0 with attribution required.