PyTorch - Saving and Loading Models
2022-07-28 07:14:00 【SpikeKing】
Interview questions:
What does a PyTorch state_dict contain?
In how many ways can a PyTorch model be saved? How does a checkpoint differ from the other ways, and what does it usually store?
SAVING AND LOADING MODELS FOR INFERENCE IN PYTORCH
There are two ways to save a model:
- Save the state_dict. state_dict() is defined on nn.Module (in torch.nn.modules.module), the parent class of all layers and models. Optimizers are not Modules, but they provide their own state_dict() as well.
- The state_dict stores the parameters and the buffers. For example, the running statistics of a BatchNorm layer are buffers.
- Save the entire model.
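To see what a state_dict actually holds, the sketch below builds a tiny model (an assumption for illustration, not the article's Net) with one Linear layer and one BatchNorm1d layer. Parameters (weights and biases) and buffers (BatchNorm's running statistics) both appear as named tensors, while an optimizer's state_dict has the two keys 'state' and 'param_groups'.

```python
import torch
import torch.nn as nn
import torch.optim as optim

# A tiny illustrative model: one learnable layer plus one layer that has buffers
model = nn.Sequential(nn.Linear(4, 3), nn.BatchNorm1d(3))

# Module state_dict: parameters and buffers, keyed by "<submodule>.<name>"
for name, tensor in model.state_dict().items():
    print(name, tuple(tensor.shape))

# Parameters only (what the optimizer updates)
param_names = [name for name, _ in model.named_parameters()]
# Buffers only (saved in the state_dict, but not trained by the optimizer)
buffer_names = [name for name, _ in model.named_buffers()]
print("parameters:", param_names)
print("buffers:", buffer_names)

# Optimizer state_dict: per-parameter state (e.g. momentum) plus hyperparameter groups
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
print(list(optimizer.state_dict().keys()))  # ['state', 'param_groups']
```

This is also a compact answer to the first interview question: a module's state_dict is an ordered mapping from names to parameter and buffer tensors, with no information about the architecture itself.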
Net inherits from nn.Module: __init__ creates the layers, forward connects them and takes the input x, and the model is instantiated with net = Net().
The optimizer is created with optim.SGD. Its first argument is the model's parameters, net.parameters(), which returns the parameters of the module and of all its child modules.
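The recursion into child modules can be checked directly. This sketch uses a small nested nn.Sequential (an illustrative assumption, not the article's Net) and lists what parameters() and named_parameters() return.

```python
import torch.nn as nn
import torch.optim as optim

# Nested modules: the second element is itself a container module
model = nn.Sequential(
    nn.Linear(3, 2),
    nn.Sequential(nn.Linear(2, 1)),
)

# named_parameters()/parameters() recurse into child modules by default
names = [name for name, _ in model.named_parameters()]
total = sum(p.numel() for p in model.parameters())
print(names)  # ['0.weight', '0.bias', '1.0.weight', '1.0.bias']
print(total)  # 11  (2*3 + 2 for the first layer, 1*2 + 1 for the nested one)

# The optimizer therefore receives every parameter, including the nested ones
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
print(len(optimizer.param_groups[0]["params"]))  # 4
```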
torch.save(net.state_dict(), PATH) saves only the parameters, not the model structure (the graph). A common practice is to record the model name, epoch, training loss and evaluation loss in the file name.
To load, instantiate Net first, read the dictionary with torch.load(PATH), and pass it to load_state_dict().
In short:
- save -> state_dict
- load -> load_state_dict
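A minimal round trip of this save -> state_dict / load -> load_state_dict pattern is sketched below. The helper make_model is hypothetical and stands in for Net, and an in-memory io.BytesIO buffer replaces the file path so the example leaves nothing on disk. Because the state_dict holds no structure, the architecture has to be rebuilt in code before loading.

```python
import io

import torch
import torch.nn as nn

torch.manual_seed(0)

def make_model():
    # Hypothetical stand-in for Net(); the architecture must be rebuilt in code
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

trained = make_model()

# save -> state_dict (an in-memory buffer stands in for PATH)
buffer = io.BytesIO()
torch.save(trained.state_dict(), buffer)
buffer.seek(0)

# load -> load_state_dict into a freshly constructed instance
restored = make_model()
result = restored.load_state_dict(torch.load(buffer))
print(result.missing_keys, result.unexpected_keys)  # both empty when the keys match

# With identical weights, both models produce identical outputs
trained.eval()
restored.eval()
x = torch.randn(5, 4)
with torch.no_grad():
    same = torch.equal(trained(x), restored(x))
print(same)
```

By default load_state_dict runs with strict=True, so a key mismatch between the file and the rebuilt architecture raises an error instead of silently loading a partial model.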
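Call eval() before inference. It sets training to False on the module and all of its submodules, which switches layers such as Dropout and BatchNorm to their inference behavior. Note that eval() does not stop gradient tracking and does not change requires_grad; use torch.no_grad() for that.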
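The difference between eval() and torch.no_grad() can be observed on a small model containing Dropout and BatchNorm (an illustrative assumption). The flags are captured before switching back with train().

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5), nn.BatchNorm1d(4))

model.eval()  # sets .training = False on this module and every submodule
eval_mode_all_off = all(not m.training for m in model.modules())
grads_still_enabled = all(p.requires_grad for p in model.parameters())
print(eval_mode_all_off)    # True
print(grads_still_enabled)  # True: eval() leaves requires_grad untouched

x = torch.randn(8, 4)
out_eval = model(x)
print(out_eval.requires_grad)   # True: autograd still records the graph

with torch.no_grad():           # this, not eval(), turns off gradient tracking
    out_infer = model(x)
print(out_infer.requires_grad)  # False

# In eval mode Dropout is a no-op and BatchNorm uses running statistics,
# so repeated forward passes give identical results
print(torch.equal(out_eval, out_infer))  # True

model.train()  # switch back before resuming training
```

For pure inference, the usual combination is model.eval() together with a torch.no_grad() block.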
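torch.save(net, PATH) saves the entire model (structure and parameters) with pickle, and torch.load(PATH) returns a model that can be used directly. Because pickle stores a reference to the class rather than its code, the Net class definition must still be importable when loading.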
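A minimal sketch of saving and loading a whole model, using nn.Sequential (an assumption, so the example does not depend on a custom class) and an in-memory buffer instead of a file. Since PyTorch 2.6, torch.load defaults to weights_only=True, which refuses arbitrary pickled objects, so loading a full module needs weights_only=False passed explicitly. That option runs pickle, so it should only be used on files from a trusted source.

```python
import io

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 2))

buffer = io.BytesIO()
torch.save(model, buffer)  # pickles the module object itself
buffer.seek(0)

# weights_only=False is required to unpickle a full nn.Module
loaded = torch.load(buffer, weights_only=False)
loaded.eval()
print(type(loaded).__name__)  # Sequential
```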
import torch
import torch.nn as nn
import torch.nn.functional as F  # needed for F.relu in forward()
import torch.optim as optim

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # __init__ creates the layers
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        # forward connects the layers; expects 3x32x32 inputs
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()
print(net)

# net.parameters() includes the parameters of all child modules
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
# Specify a path
PATH = "state_dict_model.pt"
# Save
torch.save(net.state_dict(), PATH)
# Load
model = Net()
model.load_state_dict(torch.load(PATH))
model.eval()
# Specify a path
PATH = "entire_model.pt"
# Save
torch.save(net, PATH)
# Load
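# PyTorch >= 2.6 defaults to weights_only=True, which rejects full modules;
# only load whole-model files from a trusted source
model = torch.load(PATH, weights_only=False)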
model.eval()
SAVING AND LOADING A GENERAL CHECKPOINT IN PYTORCH
To save a checkpoint, call torch.save(dict, PATH) during training, for example whenever epoch % 5 == 0.
The usual entries are epoch, model_state_dict, optimizer_state_dict and loss, which together capture the information needed to resume training.
Load the checkpoint with torch.load(PATH) and restore each entry:
- model.load_state_dict()
- optimizer.load_state_dict()
- epoch
- loss
During training, prefer saving in this checkpoint format.
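The save-every-5-epochs pattern and resuming from it can be sketched as follows. The toy linear-regression model, the random data and the file name checkpoint.pt are illustrative assumptions; the checkpoint keys follow the article.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

# Toy setup (illustrative only): a linear model fitted to random data
model = nn.Linear(4, 1)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
loss_fn = nn.MSELoss()
x, y = torch.randn(32, 4), torch.randn(32, 1)

ckpt_path = os.path.join(tempfile.mkdtemp(), "checkpoint.pt")

for epoch in range(1, 11):
    model.train()
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()

    # Save a general checkpoint every 5 epochs
    if epoch % 5 == 0:
        torch.save({
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "loss": loss.item(),  # a Python float, not a graph-attached tensor
        }, ckpt_path)

# Resume: rebuild the objects first, then restore their state
resumed_model = nn.Linear(4, 1)
resumed_optimizer = optim.SGD(resumed_model.parameters(), lr=0.01, momentum=0.9)

checkpoint = torch.load(ckpt_path)
resumed_model.load_state_dict(checkpoint["model_state_dict"])
resumed_optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
start_epoch = checkpoint["epoch"] + 1
last_loss = checkpoint["loss"]

resumed_model.train()  # continue training from start_epoch
```

Restoring the optimizer matters here because SGD with momentum keeps a per-parameter momentum buffer; without it, the resumed run would not continue exactly where the saved one stopped.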
# Additional information
EPOCH = 5
PATH = "model.pt"
LOSS = 0.4
torch.save({
    'epoch': EPOCH,
    'model_state_dict': net.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': LOSS,
}, PATH)
model = Net()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)  # must wrap the new model's parameters
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
model.eval()
# - or -
model.train()
SAVING AND LOADING MULTIPLE MODELS IN ONE FILE USING PYTORCH
As with a single-model checkpoint, put the state_dicts of all models and optimizers into one large dictionary, save it, then load it back and restore each entry.
PATH = "model.pt"
torch.save({
'modelA_state_dict': netA.state_dict(),
'modelB_state_dict': netB.state_dict(),
'optimizerA_state_dict': optimizerA.state_dict(),
'optimizerB_state_dict': optimizerB.state_dict(),
}, PATH)
modelA = Net()
modelB = Net()
optimModelA = optim.SGD(modelA.parameters(), lr=0.001, momentum=0.9)
optimModelB = optim.SGD(modelB.parameters(), lr=0.001, momentum=0.9)
checkpoint = torch.load(PATH)
modelA.load_state_dict(checkpoint['modelA_state_dict'])
modelB.load_state_dict(checkpoint['modelB_state_dict'])
optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])
modelA.eval()
modelB.eval()
# - or -
modelA.train()
modelB.train()
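The same one-big-dictionary idea extends to any number of models by building the dictionary in a loop. The helpers save_bundle and load_bundle below are hypothetical (not part of PyTorch), the two toy models stand in for something like a generator/discriminator pair, and an in-memory buffer replaces the file path.

```python
import io

import torch
import torch.nn as nn
import torch.optim as optim

def save_bundle(models, optimizers, f):
    """Hypothetical helper: save several models and optimizers into one file."""
    payload = {f"{name}_state_dict": m.state_dict() for name, m in models.items()}
    payload.update({f"{name}_optimizer_state_dict": o.state_dict()
                    for name, o in optimizers.items()})
    torch.save(payload, f)

def load_bundle(models, optimizers, f):
    """Hypothetical helper: restore state into already-constructed objects."""
    checkpoint = torch.load(f)
    for name, m in models.items():
        m.load_state_dict(checkpoint[f"{name}_state_dict"])
    for name, o in optimizers.items():
        o.load_state_dict(checkpoint[f"{name}_optimizer_state_dict"])

def build():
    # Two toy models (illustrative only) and one optimizer per model
    ms = {"modelA": nn.Linear(4, 2), "modelB": nn.Linear(2, 1)}
    opts = {k: optim.SGD(m.parameters(), lr=0.001, momentum=0.9) for k, m in ms.items()}
    return ms, opts

models, optimizers = build()
buffer = io.BytesIO()
save_bundle(models, optimizers, buffer)
buffer.seek(0)

restored, restored_opts = build()
load_bundle(restored, restored_opts, buffer)
```

Prefixing every key with the model's name keeps the entries from colliding, which is the same role the modelA_/modelB_ prefixes play in the code above.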
Use a Docker container to create the environment:

seaborn: https://seaborn.pydata.org/
Common software:

Free GPU resources: Colaboratory