Saving and loading PyTorch models
2022-07-23 09:40:00 【Mick..】
Model saving and loading
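1. Save and load only the model parameters
torch.save(model.state_dict(), PATH)  # save the model's parameters to PATH (conventionally with a .pt or .pth suffix)
model = TheModelClass(*args, **kwargs)  # re-create a model with the same architecture (TheModelClass is a placeholder for your model class)
model.load_state_dict(torch.load(PATH))  # load the saved parameters into it
model.eval()  # switch to evaluation mode before running inference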
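To make the round trip concrete, here is a minimal sketch that saves a state_dict to a temporary file and loads it into a freshly built model. It assumes a small nn.Sequential stand-in for the model; the helper make_model and the file name model_weights.pt are illustrative choices, not part of the original code.

```python
import os
import tempfile

import torch
import torch.nn as nn


# Hypothetical helper: builds the same architecture every time it is called
def make_model():
    return nn.Sequential(nn.Linear(1, 3), nn.ReLU(), nn.Linear(3, 1))


model = make_model()

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model_weights.pt")  # hypothetical file name
    torch.save(model.state_dict(), path)          # only the parameters are written

    # Loading requires building the same architecture first
    restored = make_model()
    restored.load_state_dict(torch.load(path, map_location="cpu"))
    restored.eval()  # inference mode before predicting

# Every parameter tensor of the restored model matches the original
same = all(
    torch.equal(a, b)
    for a, b in zip(model.state_dict().values(), restored.state_dict().values())
)
print(same)  # True
```

Because only tensors are stored, the file stays small and does not depend on where the model class lives, which is why this is the commonly recommended approach.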
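2. Save and load the entire model
torch.save(model, PATH)
model = torch.load(PATH, weights_only=False)  # needed on PyTorch >= 2.6, where torch.load defaults to weights_only=True
This method saves the whole model object, so there is no need to instantiate the model again before loading it. Because the object is serialized with pickle, however, the class that defines the model (such as Net below) must still be importable when the file is loaded.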
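A minimal sketch of saving and reloading a whole module, assuming PyTorch 1.13 or newer (where torch.load accepts the weights_only argument). It uses nn.Sequential rather than a custom class so the pickled object only refers to classes from torch itself; the file name whole_model.pt is hypothetical.

```python
import os
import tempfile

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(1, 3), nn.ReLU(), nn.Linear(3, 1))

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "whole_model.pt")  # hypothetical file name
    torch.save(model, path)  # pickles the module object itself

    # weights_only=False is required to unpickle a full module on PyTorch >= 2.6;
    # only do this for files you trust, since unpickling can run arbitrary code
    loaded = torch.load(path, map_location="cpu", weights_only=False)

print(type(loaded).__name__)  # Sequential
```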
Define network structure
A very simple network is defined here: two fully connected (linear) layers with a ReLU activation in between.
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.layer1 = nn.Linear(1, 3)  # linear layer: 1 input -> 3 hidden units
        self.layer2 = nn.Linear(3, 1)  # linear layer: 3 hidden units -> 1 output

    def forward(self, x):
        x = self.layer1(x)
        x = torch.relu(x)  # ReLU activation function
        x = self.layer2(x)
        return x
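The parameter names in a state_dict come from the attribute names in the model, which is why load_state_dict needs the same architecture on the receiving side. A quick sketch that lists them for Net (redefined here with a condensed forward so the snippet runs on its own):

```python
import torch
import torch.nn as nn


class Net(nn.Module):  # same two-layer network as above
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(1, 3)
        self.layer2 = nn.Linear(3, 1)

    def forward(self, x):
        return self.layer2(torch.relu(self.layer1(x)))


net = Net()
# Keys are "<attribute name>.<parameter name>", in registration order
shapes = {name: tuple(t.shape) for name, t in net.state_dict().items()}
print(shapes)
# {'layer1.weight': (3, 1), 'layer1.bias': (3,), 'layer2.weight': (1, 3), 'layer2.bias': (1,)}
```

Renaming layer1 to, say, fc1 would change the keys, and loading an old file would then fail with missing and unexpected keys.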
Training the neural network
import torch
import torch.nn as nn
import torch.optim as optim

num_epochs = 2000
# Learning rate is set to 0.01
learning_rate = 0.01

# Hypothetical toy data for illustration: fit y = x^2 on [-1, 1]
samples = torch.linspace(-1, 1, 20).unsqueeze(1)  # shape (20, 1)
targets = samples ** 2                             # shape (20, 1)

# Create a model (Net is the class defined above)
model = Net()
# The optimizer updates the network weights; lr is the learning rate, weight_decay adds L2 regularization
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate, weight_decay=1e-5)
criterion = nn.MSELoss()  # loss function: mean squared error

for epoch in range(num_epochs):          # train for num_epochs epochs
    model.train()                        # put the model in training mode
    for x, y in zip(samples, targets):   # iterate over the samples one at a time
        optimizer.zero_grad()            # clear old gradients before computing new ones
        output = model(x)
        loss = criterion(output, y)
        loss.backward()                  # compute gradients of the loss w.r.t. the weights
        optimizer.step()                 # update the network weights
    if (epoch + 1) % 10 == 0:            # print the current status every ten epochs
        print("Epoch {} / {}, loss {:.4f}".format(epoch + 1, num_epochs, loss.item()))
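The zero_grad() call in the loop matters because PyTorch accumulates gradients: each backward() adds to .grad instead of overwriting it. The sketch below shows this on a single hypothetical parameter w, using an SGD optimizer only for its zero_grad() method.

```python
import torch

w = torch.tensor([2.0], requires_grad=True)
opt = torch.optim.SGD([w], lr=0.1)  # used here only for zero_grad()

(3 * w).sum().backward()
first = w.grad.clone()        # d(3w)/dw = 3

(3 * w).sum().backward()      # no zero_grad() in between
accumulated = w.grad.clone()  # gradients add up: 3 + 3 = 6

opt.zero_grad()               # clear gradients, as the training loop does each step
(3 * w).sum().backward()
after_reset = w.grad.clone()  # back to 3

print(first.item(), accumulated.item(), after_reset.item())  # 3.0 6.0 3.0
```

Forgetting zero_grad() in the training loop would therefore make every step use the sum of all previous gradients.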
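Saving a PyTorch checkpoint
torch.save() can also save dictionary data, so a checkpoint can bundle the model weights together with training state such as the last loss.
import os
import torch

def save_checkpoint(state, directory):  # state: the model weights and training state; directory: where the checkpoint is saved
    os.makedirs(directory, exist_ok=True)  # create the directory if it does not exist yet
    file_name = os.path.join(directory, 'last.pth')
    torch.save(state, file_name)  # torch.save serializes the whole dictionary to disk

save_checkpoint({'loss': loss.item(), 'state_dict': model.state_dict()}, 'checkpoints')  # 'checkpoints' is an example directory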
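To resume from such a checkpoint, the dictionary is loaded and each entry is handed back to its owner. The sketch below assumes a hypothetical load_checkpoint helper, additionally stores the optimizer's state_dict under an 'optimizer' key (a common extension, not something the code above saves), and uses a placeholder loss value of 0.5.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim


def save_checkpoint(state, directory):
    os.makedirs(directory, exist_ok=True)
    torch.save(state, os.path.join(directory, "last.pth"))


def load_checkpoint(directory, model, optimizer=None):  # hypothetical helper
    state = torch.load(os.path.join(directory, "last.pth"), map_location="cpu")
    model.load_state_dict(state["state_dict"])
    if optimizer is not None and "optimizer" in state:
        optimizer.load_state_dict(state["optimizer"])  # restores e.g. Adam's moment estimates
    return state


model = nn.Sequential(nn.Linear(1, 3), nn.ReLU(), nn.Linear(3, 1))
optimizer = optim.Adam(model.parameters(), lr=0.01)

with tempfile.TemporaryDirectory() as tmp:
    save_checkpoint(
        {"loss": 0.5,  # placeholder loss value
         "state_dict": model.state_dict(),
         "optimizer": optimizer.state_dict()},
        tmp,
    )
    new_model = nn.Sequential(nn.Linear(1, 3), nn.ReLU(), nn.Linear(3, 1))
    new_optimizer = optim.Adam(new_model.parameters(), lr=0.01)
    state = load_checkpoint(tmp, new_model, new_optimizer)

print(state["loss"])  # 0.5
```

Saving the optimizer state alongside the weights lets training continue where it stopped; with only the weights, Adam would restart with fresh moment estimates.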