Cara terbaik untuk menyimpan model yang terlatih di PyTorch?


192

Saya mencari cara alternatif untuk menyelamatkan model yang terlatih di PyTorch. Sejauh ini, saya telah menemukan dua alternatif.

  1. torch.save () untuk menyimpan model dan torch.load () untuk memuat model.
  2. model.state_dict () untuk menyimpan model yang terlatih dan model.load_state_dict () untuk memuat model yang disimpan.

Saya telah menemukan diskusi ini di mana pendekatan 2 direkomendasikan daripada pendekatan 1.

Pertanyaan saya adalah, mengapa pendekatan kedua lebih disukai? Apakah hanya karena modul torch.nn memiliki dua fungsi tersebut dan kami didorong untuk menggunakannya?


2
Saya pikir itu karena torch.save () menyimpan semua variabel perantara juga, seperti output menengah untuk digunakan kembali propagasi. Tetapi Anda hanya perlu menyimpan parameter model, seperti berat / bias dll. Kadang-kadang yang pertama bisa jauh lebih besar dari yang terakhir.
Dawei Yang

2
Saya menguji torch.save(model, f)dan torch.save(model.state_dict(), f). File yang disimpan memiliki ukuran yang sama. Sekarang saya bingung. Juga, saya menemukan menggunakan acar untuk menyimpan model.state_dict () sangat lambat. Saya pikir cara terbaik adalah menggunakan torch.save(model.state_dict(), f)karena Anda menangani pembuatan model, dan obor menangani pemuatan bobot model, sehingga menghilangkan kemungkinan masalah. Referensi: mendiskusikan.pytorch.org/t/saving-torch-models/838/4
Dawei Yang

Sepertinya PyTorch telah membahas ini sedikit lebih eksplisit di bagian tutorialnya — ada banyak info bagus di sana yang tidak tercantum dalam jawaban di sini, termasuk menyimpan lebih dari satu model sekaligus dan model awal yang hangat.
whlteXbread

apa yang salah dengan menggunakan pickle?
Charlie Parker

1
@CharlieParker torch.save didasarkan pada acar. Berikut ini adalah dari tutorial yang ditautkan di atas: "[torch.save] akan menyimpan seluruh modul menggunakan modul acar Python. Kerugian dari pendekatan ini adalah bahwa data berseri terikat dengan kelas-kelas tertentu dan struktur direktori yang tepat digunakan ketika model disimpan. Alasan untuk ini adalah karena acar tidak menyimpan kelas model itu sendiri. Sebaliknya, itu menyimpan path ke file yang berisi kelas, yang digunakan selama waktu buka. Karena ini, kode Anda dapat pecah dengan berbagai cara ketika digunakan dalam proyek lain atau setelah refactor. "
David Miller

Jawaban:


214

Saya telah menemukan halaman ini di repo github mereka, saya hanya akan menempelkan konten di sini.


Pendekatan yang disarankan untuk menyimpan model

Ada dua pendekatan utama untuk membuat cerita bersambung dan mengembalikan model.

Yang pertama (disarankan) hanya menyimpan dan memuat parameter model:

torch.save(the_model.state_dict(), PATH)

Kemudian nanti:

the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

Yang kedua menyimpan dan memuat seluruh model:

torch.save(the_model, PATH)

Kemudian nanti:

the_model = torch.load(PATH)

Namun dalam kasus ini, data berseri terikat dengan kelas-kelas spesifik dan struktur direktori yang tepat digunakan, sehingga dapat pecah dengan berbagai cara ketika digunakan dalam proyek lain, atau setelah beberapa refaktor serius.


8
Menurut @smth, diskusikan.pytorch.org/t/saving-and-loading-a-model-in-pytorch/… model memuat ulang untuk melatih model secara default. jadi perlu secara manual memanggil the_model.eval () setelah memuat, jika Anda memuatnya untuk inferensi, bukan melanjutkan pelatihan.
WillZ

metode kedua memberikan kesalahan stackoverflow.com/questions/53798009/... pada windows 10. tidak dapat menyelesaikannya
Gulzar

Apakah ada opsi untuk menyimpan tanpa perlu akses untuk kelas model?
Michael D

Dengan pendekatan itu, bagaimana Anda melacak * args dan ** kwargs yang perlu Anda lewati untuk load case?
Mariano Kamp

apa yang salah dengan menggunakan pickle?
Charlie Parker

144

Itu tergantung pada apa yang ingin Anda lakukan.

Kasus # 1: Simpan model untuk menggunakannya sendiri untuk inferensi : Anda menyimpan model, Anda mengembalikannya, dan kemudian Anda mengubah model ke mode evaluasi. Ini dilakukan karena Anda biasanya memiliki BatchNormdan Dropoutlapisan yang secara default dalam mode kereta di konstruksi:

torch.save(model.state_dict(), filepath)

#Later to restore:
model.load_state_dict(torch.load(filepath))
model.eval()

Kasus # 2: Simpan model untuk melanjutkan pelatihan nanti : Jika Anda perlu terus melatih model yang akan Anda simpan, Anda perlu menyimpan lebih dari sekadar model. Anda juga perlu menyimpan status pengoptimal, zaman, skor, dll. Anda akan melakukannya seperti ini:

state = {
    'epoch': epoch,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    ...
}
torch.save(state, filepath)

Untuk melanjutkan pelatihan Anda akan melakukan hal-hal seperti:, state = torch.load(filepath)dan kemudian, untuk mengembalikan keadaan setiap objek individu, sesuatu seperti ini:

model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])

Karena Anda melanjutkan pelatihan, JANGAN menelepon model.eval()begitu Anda mengembalikan negara saat memuat.

Kasus # 3: Model yang akan digunakan oleh orang lain tanpa akses ke kode Anda : Di Tensorflow Anda dapat membuat .pbfile yang mendefinisikan arsitektur dan bobot model. Ini sangat berguna, khususnya saat menggunakan Tensorflow serve. Cara yang setara untuk melakukan ini di Pytorch adalah:

torch.save(model, filepath)

# Then later:
model = torch.load(filepath)

Cara ini masih belum menjadi bukti dan karena pytorch masih mengalami banyak perubahan, saya tidak akan merekomendasikannya.


1
Apakah ada file yang direkomendasikan berakhiran untuk 3 kasus? Atau apakah selalu .pth?
Verena Haunschmid

1
Dalam Kasus # 3 torch.loadmengembalikan hanya sebuah OrderedDict. Bagaimana Anda mendapatkan model untuk membuat prediksi?
Alber8295

Hai, Boleh saya tahu bagaimana melakukan "Kasus # 2: Simpan model untuk melanjutkan pelatihan nanti"? Saya berhasil memuat pos pemeriksaan ke model, maka saya tidak dapat menjalankan atau melanjutkan untuk melatih model seperti "model.to (perangkat) model = train_model_epoch (model, kriteria, pengoptimal, sched, zaman)"
dnez

1
Hai, untuk kasus satu yang untuk inferensi, dalam dokumen pytorch resmi mengatakan bahwa harus menyimpan state_dict optimizer untuk inferensi atau menyelesaikan pelatihan. "Saat menyimpan pos pemeriksaan umum, yang akan digunakan untuk inferensi atau melanjutkan pelatihan, Anda harus menyimpan lebih dari sekadar state_dict model. Penting juga menyimpan state_dict optimizer, karena ini berisi buffer dan parameter yang diperbarui saat model melatih . "
Mohammed Awney

1
Dalam kasus # 3, kelas model harus didefinisikan di suatu tempat.
Michael D

12

The acar alat perpustakaan Python protokol biner untuk serialisasi dan de-serialisasi objek Python.

Ketika Anda import torch(atau ketika Anda menggunakan PyTorch) itu akan import pickleuntuk Anda dan Anda tidak perlu menelepon pickle.dump()dan pickle.load()langsung, yang merupakan metode untuk menyimpan dan memuat objek.

Bahkan, torch.save()dan torch.load()akan membungkus pickle.dump()dan pickle.load()untuk Anda.

Sebuah state_dictjawaban lain yang disebutkan layak hanya beberapa catatan lagi.

Apa state_dictyang kita miliki di dalam PyTorch? Sebenarnya ada dua state_dicts.

Model PyTorch torch.nn.Modulememiliki model.parameters()panggilan untuk mendapatkan parameter yang dapat dipelajari (w dan b). Parameter yang dapat dipelajari ini, setelah ditetapkan secara acak, akan diperbarui seiring waktu seperti yang kita pelajari. Parameter yang bisa dipelajari adalah yang pertama state_dict.

Yang kedua state_dictadalah dict state optimizer. Anda ingat bahwa pengoptimal digunakan untuk meningkatkan parameter yang dapat dipelajari. Tetapi optimizer state_dictsudah diperbaiki. Tidak ada yang bisa dipelajari di sana.

Karena state_dictobjek adalah kamus Python, mereka dapat dengan mudah disimpan, diperbarui, diubah, dan dipulihkan, menambahkan banyak modularitas untuk model dan pengoptimal PyTorch.

Mari kita buat model super sederhana untuk menjelaskan ini:

import torch
import torch.optim as optim

model = torch.nn.Linear(5, 2)

# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

print("Model weight:")    
print(model.weight)

print("Model bias:")    
print(model.bias)

print("---")
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])

Kode ini akan menampilkan yang berikut:

Model's state_dict:
weight   torch.Size([2, 5])
bias     torch.Size([2])
Model weight:
Parameter containing:
tensor([[ 0.1328,  0.1360,  0.1553, -0.1838, -0.0316],
        [ 0.0479,  0.1760,  0.1712,  0.2244,  0.1408]], requires_grad=True)
Model bias:
Parameter containing:
tensor([ 0.4112, -0.0733], requires_grad=True)
---
Optimizer's state_dict:
state    {}
param_groups     [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [140695321443856, 140695321443928]}]

Perhatikan ini adalah model minimal. Anda dapat mencoba menambahkan tumpukan berurutan

model = torch.nn.Sequential(
          torch.nn.Linear(D_in, H),
          torch.nn.Conv2d(A, B, C)
          torch.nn.Linear(H, D_out),
        )

Perhatikan bahwa hanya lapisan dengan parameter yang dapat dipelajari (lapisan konvolusional, lapisan linier, dll.) Dan buffer terdaftar (lapisan batchnorm) memiliki entri dalam model state_dict.

Hal-hal yang tidak dapat dipelajari, termasuk dalam objek pengoptimal state_dict, yang berisi informasi tentang status pengoptimal, serta hyperparameter yang digunakan.

Kisah selanjutnya sama; dalam fase inferensi (ini adalah fase ketika kita menggunakan model setelah pelatihan) untuk memprediksi; kami memprediksi berdasarkan parameter yang kami pelajari. Jadi untuk kesimpulan, kita hanya perlu menyimpan parameter model.state_dict().

torch.save(model.state_dict(), filepath)

Dan untuk menggunakan model.load_state_dict nanti (torch.load (filepath)) model.eval ()

Catatan: Jangan lupa baris terakhir model.eval()ini sangat penting setelah memuat model.

Juga jangan mencoba menyimpan torch.save(model.parameters(), filepath). Itu model.parameters()hanya objek generator.

Di sisi lain, torch.save(model, filepath)menyimpan objek model itu sendiri, tetapi perlu diingat model tidak memiliki pengoptimal state_dict. Periksa jawaban luar biasa lainnya oleh @Jadiel de Armas untuk menyimpan dict state optimizer.


Meskipun ini bukan solusi yang mudah, esensi masalahnya dianalisis secara mendalam! Suara positif.
Jason Young

7

Konvensi PyTorch yang umum adalah menyimpan model menggunakan ekstensi file .pt atau .pth.

Simpan / Muat Seluruh Model, Simpan:

path = "username/directory/lstmmodelgpu.pth"
torch.save(trainer, path)

Beban:

Kelas model harus didefinisikan di suatu tempat

model = torch.load(PATH)
model.eval()

4

Jika Anda ingin menyimpan model dan ingin melanjutkan pelatihan nanti:

GPU tunggal: Simpan:

state = {
        'epoch': epoch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
}
savepath='checkpoint.t7'
torch.save(state,savepath)

Beban:

checkpoint = torch.load('checkpoint.t7')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']

Banyak GPU: Simpan

state = {
        'epoch': epoch,
        'state_dict': model.module.state_dict(),
        'optimizer': optimizer.state_dict(),
}
savepath='checkpoint.t7'
torch.save(state,savepath)

Beban:

checkpoint = torch.load('checkpoint.t7')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']

#Don't call DataParallel before loading the model otherwise you will get an error

model = nn.DataParallel(model) #ignore the line if you want to load on Single GPU
Dengan menggunakan situs kami, Anda mengakui telah membaca dan memahami Kebijakan Cookie dan Kebijakan Privasi kami.
Licensed under cc by-sa 3.0 with attribution required.