A complete example of PyTorch model saving + does PyTorch model saving only save trainable parameters? Yes (+ solution)
2022-07-02 19:55:00 【FakeOccupational】
The test uses a simple linear model. The question: does PyTorch model saving only save trainable parameters?
Saving the model
# Import packages
import glob
import os
import torch
import matplotlib.pyplot as plt
import random  # used by the data iterator to sample random indices

# Generate the dataset: x1 is class 0, x2 is class 1
n_data = torch.ones(50, 2)          # base tensor for the data
x1 = torch.normal(2 * n_data, 1)    # shape=(50, 2)
y1 = torch.zeros(50)                # class 0, shape=(50,)
x2 = torch.normal(-2 * n_data, 1)   # shape=(50, 2)
y2 = torch.ones(50)                 # class 1, shape=(50,)
# Note: x and y must take the form below (torch.cat concatenates the tensors)
x = torch.cat((x1, x2), 0).type(torch.FloatTensor)
y = torch.cat((y1, y2), 0).type(torch.FloatTensor)

# Visualize the dataset
plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=y.data.numpy(), s=100, lw=0, cmap='RdYlGn')
plt.show()

# Data iterator
def data_iter(batch_size, x, y):
    num_examples = len(x)
    indices = list(range(num_examples))
    random.shuffle(indices)  # read the samples in random order
    for i in range(0, num_examples, batch_size):
        j = torch.LongTensor(indices[i: min(i + batch_size, num_examples)])  # the last batch may be smaller
        yield x.index_select(0, j), y.index_select(0, j)
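# Hypothetical usage check: one draw from data_iter
# X_batch, y_batch = next(data_iter(10, x, y))
# print(X_batch.shape, y_batch.shape)  # torch.Size([10, 2]) torch.Size([10])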
#############################################################################################################
def saver(model_state_dict, optimizer_state_dict, model_path, epoch, max_to_save=30):
    total_models = glob.glob(model_path + 'h*')  # match only checkpoint files, not everything in the directory
    if len(total_models) >= max_to_save:
        total_models.sort()
        os.remove(total_models[0])  # drop the oldest checkpoint
    state_dict = {}
    state_dict["model_state_dict"] = model_state_dict
    state_dict["optimizer_state_dict"] = optimizer_state_dict
    torch.save(state_dict, model_path + 'h' + str(epoch))
    print('model {} saved successfully!'.format(model_path + 'h' + str(epoch)))
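# Caveat: glob results sort lexicographically, so 'h10' sorts before 'h2' and the
# rotation above can delete a newer checkpoint once epochs reach double digits.
# A sketch of a numeric sort that avoids this:
# import re
# oldest = min(glob.glob(model_path + 'h*'), key=lambda p: int(re.search(r'h(\d+)$', p).group(1)))
# os.remove(oldest)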
################################################################################################################
import torch.nn as nn
import torch.optim as optim
class net(nn.Module):
def __init__(self, **kwargs):
super(net, self).__init__(**kwargs)
self.net = nn.Sequential(nn.Linear(2, 1), nn.ReLU())
def forward(self, x):
return self.net(x)
def loss(y_hat, y):
    # elementwise squared error; the training loop sums it over the batch
    return (y_hat - y.view(y_hat.size())) ** 2 / 2

def accuracy(y_hat, y):  #@save
    """Compute the fraction of correct predictions."""
    cmp = y_hat.type(y.dtype) > 0.5  # outputs above 0.5 are predicted as class 1
    result = cmp.type(y.dtype).view(y.size())  # align shapes before comparing
    return float((result == y).sum()) / len(y)
lr = 1e-4        # learning rate
num_epochs = 3   # number of epochs
batch_size = 10  # batch size
model = net()
params = list(model.parameters())
optimizer = torch.optim.Adam(params, lr=lr)
for epoch in range(num_epochs):
    for X, y_train in data_iter(batch_size, x, y):
        optimizer.zero_grad()
        l = loss(model(X), y_train).sum()  # l is the loss on the mini-batch X, y_train
        l.backward()  # a fresh graph is built each iteration, so retain_graph is unnecessary
        optimizer.step()
    print(l)
    saver(model.state_dict(), optimizer.state_dict(), "./", epoch + 1, max_to_save=100)
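Before moving on, it is worth confirming what actually landed in the checkpoint. A minimal sanity check, assuming the first epoch was saved as ./h1:

ckpt = torch.load("./h1")
print(ckpt.keys())                      # dict_keys(['model_state_dict', 'optimizer_state_dict'])
print(ckpt["model_state_dict"].keys())  # odict_keys(['net.0.weight', 'net.0.bias'])

Only the Linear layer's weight and bias appear, which is exactly the behavior the rest of this post examines.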
Loading the model
# Import packages
import glob
import os
import torch
import matplotlib.pyplot as plt
import random  # used by the data iterator to sample random indices

# Generate the dataset: x1 is class 0, x2 is class 1
n_data = torch.ones(50, 2)          # base tensor for the data
x1 = torch.normal(2 * n_data, 1)    # shape=(50, 2)
y1 = torch.zeros(50)                # class 0, shape=(50,)
x2 = torch.normal(-2 * n_data, 1)   # shape=(50, 2)
y2 = torch.ones(50)                 # class 1, shape=(50,)
# Note: x and y must take the form below (torch.cat concatenates the tensors)
x = torch.cat((x1, x2), 0).type(torch.FloatTensor)
y = torch.cat((y1, y2), 0).type(torch.FloatTensor)

# Visualize the dataset
plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=y.data.numpy(), s=100, lw=0, cmap='RdYlGn')
plt.show()

# Data iterator
def data_iter(batch_size, x, y):
    num_examples = len(x)
    indices = list(range(num_examples))
    random.shuffle(indices)  # read the samples in random order
    for i in range(0, num_examples, batch_size):
        j = torch.LongTensor(indices[i: min(i + batch_size, num_examples)])  # the last batch may be smaller
        yield x.index_select(0, j), y.index_select(0, j)
#############################################################################################################
def saver(model_state_dict, optimizer_state_dict, model_path, epoch, max_to_save=30):
    total_models = glob.glob(model_path + 'h*')  # match only checkpoint files, not everything in the directory
    if len(total_models) >= max_to_save:
        total_models.sort()
        os.remove(total_models[0])  # drop the oldest checkpoint
    state_dict = {}
    state_dict["model_state_dict"] = model_state_dict
    state_dict["optimizer_state_dict"] = optimizer_state_dict
    torch.save(state_dict, model_path + 'h' + str(epoch))
    print('model {} saved successfully!'.format(model_path + 'h' + str(epoch)))
################################################################################################################
import torch.nn as nn
import torch.optim as optim
class net(nn.Module):
def __init__(self, **kwargs):
super(net, self).__init__(**kwargs)
self.net = nn.Sequential(nn.Linear(2, 1), nn.ReLU())
def forward(self, x):
return self.net(x)
def loss(y_hat, y):
    # elementwise squared error; the training loop sums it over the batch
    return (y_hat - y.view(y_hat.size())) ** 2 / 2

def accuracy(y_hat, y):  #@save
    """Compute the fraction of correct predictions."""
    cmp = y_hat.type(y.dtype) > 0.5  # outputs above 0.5 are predicted as class 1
    result = cmp.type(y.dtype).view(y.size())  # align shapes before comparing
    return float((result == y).sum()) / len(y)
lr = 1e-4        # learning rate
num_epochs = 3   # number of epochs
batch_size = 10  # batch size
model = net()
params = list(model.parameters())
optimizer = torch.optim.Adam(params, lr=lr)
# Training is skipped here; we only restore the saved state:
# for epoch in range(num_epochs):
#     for X, y_train in data_iter(batch_size, x, y):
#         optimizer.zero_grad()
#         l = loss(model(X), y_train).sum()  # l is the loss on the mini-batch X, y_train
#         l.backward()
#         optimizer.step()
#     print(l)
#     saver(model.state_dict(), optimizer.state_dict(), "./", epoch + 1, max_to_save=100)
def loader(model_path):
    # Load a checkpoint and return the two stored state dicts
    state_dict = torch.load(model_path)
    model_state_dict = state_dict["model_state_dict"]
    optimizer_state_dict = state_dict["optimizer_state_dict"]
    return model_state_dict, optimizer_state_dict

model_state_dict, optimizer_state_dict = loader("h1")  # checkpoint written after epoch 1
model.load_state_dict(model_state_dict)
optimizer.load_state_dict(optimizer_state_dict)
print('pretrained models loaded!')
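To double-check the restore, a quick hypothetical comparison of a restored weight against the checkpoint:

w_loaded = model.state_dict()["net.0.weight"]
print(torch.equal(w_loaded, model_state_dict["net.0.weight"]))  # True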
Does PyTorch model saving only save trainable parameters? Yes
class net(nn.Module):
    def __init__(self, **kwargs):
        super(net, self).__init__(**kwargs)
        self.net = nn.Sequential(nn.Linear(2, 1), nn.ReLU())
        # A plain tensor attribute is not registered with the module,
        # so state_dict() silently skips it and it is never saved
        self.notrain = torch.rand((64, 64), dtype=torch.float)

    def forward(self, x):
        return self.net(x)
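A quick way to see the problem (a minimal sketch): because self.notrain is a plain tensor rather than a registered parameter or buffer, it never shows up in the state dict.

m = net()
print(m.state_dict().keys())  # odict_keys(['net.0.weight', 'net.0.bias']), no 'notrain'

After a save/load round trip, notrain is therefore reinitialized by torch.rand instead of being restored.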
Solution
class net(nn.Module):
    def __init__(self, **kwargs):
        super(net, self).__init__(**kwargs)
        self.net = nn.Sequential(nn.Linear(2, 1), nn.ReLU())
        # self.notrain = torch.rand((64, 64), dtype=torch.float)
        # Wrapping the tensor in nn.Parameter registers it with the module,
        # so it now appears in state_dict() and gets saved
        self.notrain = torch.nn.Parameter(torch.ones(64, 64))

    def forward(self, x):
        return self.net(x)
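The same check now shows the tensor in the state dict (sketch):

m = net()
print("notrain" in m.state_dict())  # True

One side effect to keep in mind: nn.Parameter has requires_grad=True by default, so notrain also appears in model.parameters() and will be updated by the optimizer unless it is excluded or its requires_grad is set to False.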
for epoch in range(num_epochs):
    for X, y_train in data_iter(batch_size, x, y):
        optimizer.zero_grad()
        l = loss(model(X), y_train).sum()  # l is the loss on the mini-batch X, y_train
        l.backward()
        optimizer.step()
    print(l)
    model.notrain.data = model.notrain.data + 2  # update the extra parameter in place, outside autograd
    saver(model.state_dict(), optimizer.state_dict(), "./", epoch + 1, max_to_save=100)
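If the tensor should be saved with the model but never trained, register_buffer is the idiomatic alternative to nn.Parameter. A minimal sketch (the class name NetWithBuffer is hypothetical):

class NetWithBuffer(nn.Module):
    def __init__(self, **kwargs):
        super(NetWithBuffer, self).__init__(**kwargs)
        self.net = nn.Sequential(nn.Linear(2, 1), nn.ReLU())
        # Buffers appear in state_dict() but not in model.parameters(),
        # so they are saved and restored without ever being trained
        self.register_buffer("notrain", torch.ones(64, 64))

    def forward(self, x):
        return self.net(x)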