Complete example of saving a PyTorch model + does PyTorch model saving only save trainable parameters? Yes (+ solution)
2022-07-02 19:55:00 【FakeOccupational】
The test uses a linear model. The question explored here: does PyTorch model saving only save trainable parameters?
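For context: a module's state_dict() is an ordered dictionary of the tensors PyTorch serializes, namely its registered parameters and buffers. A minimal sketch (plain PyTorch, nothing specific to this article's model):

import torch.nn as nn

lin = nn.Linear(2, 1)
print(lin.state_dict())  # OrderedDict containing the 'weight' and 'bias' tensors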
Save the model
# Import packages
import glob
import os
import torch
import matplotlib.pyplot as plt
import random  # used by the data iterator to shuffle samples
# Generate the dataset: x1 is class 0, x2 is class 1
n_data = torch.ones(50, 2)            # base shape of the data
x1 = torch.normal(2 * n_data, 1)      # shape=(50, 2)
y1 = torch.zeros(50)                  # class 0, shape=(50,)
x2 = torch.normal(-2 * n_data, 1)     # shape=(50, 2)
y2 = torch.ones(50)                   # class 1, shape=(50,)
# Note: x and y must take the form below (torch.cat concatenates the data)
x = torch.cat((x1, x2), 0).type(torch.FloatTensor)
y = torch.cat((y1, y2), 0).type(torch.FloatTensor)
# Visualize the dataset
plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=y.data.numpy(), s=100, lw=0, cmap='RdYlGn')
plt.show()
# Data iterator:
def data_iter(batch_size, x, y):
    num_examples = len(x)
    indices = list(range(num_examples))
    random.shuffle(indices)  # read samples in random order
    for i in range(0, num_examples, batch_size):
        j = torch.LongTensor(indices[i: min(i + batch_size, num_examples)])  # the last batch may be smaller
        yield x.index_select(0, j), y.index_select(0, j)
#############################################################################################################
def saver(model_state_dict, optimizer_state_dict, model_path, epoch, max_to_save=30):
    # Keep at most max_to_save checkpoints; delete the oldest (by name) when over the limit
    total_models = glob.glob(model_path + '*')
    if len(total_models) >= max_to_save:
        total_models.sort()
        os.remove(total_models[0])
    state_dict = {}
    state_dict["model_state_dict"] = model_state_dict
    state_dict["optimizer_state_dict"] = optimizer_state_dict
    torch.save(state_dict, model_path + 'h' + str(epoch))
    print('model {} saved successfully!'.format(model_path + 'h' + str(epoch)))
################################################################################################################
import torch.nn as nn
import torch.optim as optim

class net(nn.Module):
    def __init__(self, **kwargs):
        super(net, self).__init__(**kwargs)
        self.net = nn.Sequential(nn.Linear(2, 1), nn.ReLU())

    def forward(self, x):
        return self.net(x)

def loss(y_hat, y):
    return (y_hat - y.view(y_hat.size())) ** 2 / 2

def accuracy(y_hat, y):  #@save
    """Compute the fraction of correct predictions."""
    cmp = y_hat.type(y.dtype) > 0.5           # greater than 0.5 -> class 1
    result = cmp.type(y.dtype).view(y.size())
    return float((result == y).sum()) / len(y)
lr = 0.03        # learning rate (note: the optimizer below uses 1e-4 instead)
num_epochs = 3   # number of epochs
batch_size = 10  # batch size
model = net()
params = list(model.parameters())
optimizer = torch.optim.Adam(params, 1e-4)
for epoch in range(num_epochs):
    for X, y_train in data_iter(batch_size, x, y):
        optimizer.zero_grad()
        l = loss(model(X), y_train).sum()  # l is the loss on the minibatch X, y_train
        l.backward(retain_graph=True)
        optimizer.step()
    print(l)
    saver(model.state_dict(), optimizer.state_dict(), "./", epoch + 1, max_to_save=100)
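To sanity-check what actually lands on disk, here is a minimal sketch (assuming the checkpoint file ./h1 written by the loop above) that reloads the file and lists the saved keys:

import torch

checkpoint = torch.load("./h1")  # checkpoint written by saver() at the end of epoch 1
print(checkpoint.keys())                      # dict_keys(['model_state_dict', 'optimizer_state_dict'])
print(checkpoint["model_state_dict"].keys())  # odict_keys(['net.0.weight', 'net.0.bias'])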
Load the model
# Import packages
import glob
import os
import torch
import matplotlib.pyplot as plt
import random  # used by the data iterator to shuffle samples
# Generate the dataset: x1 is class 0, x2 is class 1
n_data = torch.ones(50, 2)            # base shape of the data
x1 = torch.normal(2 * n_data, 1)      # shape=(50, 2)
y1 = torch.zeros(50)                  # class 0, shape=(50,)
x2 = torch.normal(-2 * n_data, 1)     # shape=(50, 2)
y2 = torch.ones(50)                   # class 1, shape=(50,)
# Note: x and y must take the form below (torch.cat concatenates the data)
x = torch.cat((x1, x2), 0).type(torch.FloatTensor)
y = torch.cat((y1, y2), 0).type(torch.FloatTensor)
# Visualize the dataset
plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=y.data.numpy(), s=100, lw=0, cmap='RdYlGn')
plt.show()
# Data iterator:
def data_iter(batch_size, x, y):
    num_examples = len(x)
    indices = list(range(num_examples))
    random.shuffle(indices)  # read samples in random order
    for i in range(0, num_examples, batch_size):
        j = torch.LongTensor(indices[i: min(i + batch_size, num_examples)])  # the last batch may be smaller
        yield x.index_select(0, j), y.index_select(0, j)
#############################################################################################################
def saver(model_state_dict, optimizer_state_dict, model_path, epoch, max_to_save=30):
    # Keep at most max_to_save checkpoints; delete the oldest (by name) when over the limit
    total_models = glob.glob(model_path + '*')
    if len(total_models) >= max_to_save:
        total_models.sort()
        os.remove(total_models[0])
    state_dict = {}
    state_dict["model_state_dict"] = model_state_dict
    state_dict["optimizer_state_dict"] = optimizer_state_dict
    torch.save(state_dict, model_path + 'h' + str(epoch))
    print('model {} saved successfully!'.format(model_path + 'h' + str(epoch)))
################################################################################################################
import torch.nn as nn
import torch.optim as optim

class net(nn.Module):
    def __init__(self, **kwargs):
        super(net, self).__init__(**kwargs)
        self.net = nn.Sequential(nn.Linear(2, 1), nn.ReLU())

    def forward(self, x):
        return self.net(x)

def loss(y_hat, y):
    return (y_hat - y.view(y_hat.size())) ** 2 / 2

def accuracy(y_hat, y):  #@save
    """Compute the fraction of correct predictions."""
    cmp = y_hat.type(y.dtype) > 0.5           # greater than 0.5 -> class 1
    result = cmp.type(y.dtype).view(y.size())
    return float((result == y).sum()) / len(y)
lr = 0.03        # learning rate (note: the optimizer below uses 1e-4 instead)
num_epochs = 3   # number of epochs
batch_size = 10  # batch size
model = net()
params = list(model.parameters())
optimizer = torch.optim.Adam(params, 1e-4)
# for epoch in range(num_epochs):
#     for X, y_train in data_iter(batch_size, x, y):
#         optimizer.zero_grad()
#         l = loss(model(X), y_train).sum()  # l is the loss on the minibatch X, y_train
#         l.backward(retain_graph=True)
#         optimizer.step()
#     print(l)
#     saver(model.state_dict(), optimizer.state_dict(), "./", epoch + 1, max_to_save=100)
def loader(model_path):
    state_dict = torch.load(model_path)
    model_state_dict = state_dict["model_state_dict"]
    optimizer_state_dict = state_dict["optimizer_state_dict"]
    return model_state_dict, optimizer_state_dict

model_state_dict, optimizer_state_dict = loader("h1")
model.load_state_dict(model_state_dict)
optimizer.load_state_dict(optimizer_state_dict)
print('pretrained model loaded!')
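With the weights restored, a quick sketch (reusing the x, y tensors and the accuracy() helper defined above) confirms the loaded model behaves like the one that was saved:

with torch.no_grad():  # no gradients needed for evaluation
    y_hat = model(x)
print('accuracy:', accuracy(y_hat, y))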
Does PyTorch model saving only save trainable parameters? Yes. More precisely, state_dict() contains only registered parameters and buffers; a plain tensor attribute such as self.notrain below is neither, so it is silently left out of the checkpoint:
class net(nn.Module):
    def __init__(self, **kwargs):
        super(net, self).__init__(**kwargs)
        self.net = nn.Sequential(nn.Linear(2, 1), nn.ReLU())
        self.notrain = torch.rand((64, 64), dtype=torch.float)  # plain tensor attribute: NOT in state_dict

    def forward(self, x):
        return self.net(x)
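A quick check makes the problem visible (a minimal sketch reusing the class above): the notrain tensor never shows up among the saved keys:

m = net()
print(m.state_dict().keys())  # odict_keys(['net.0.weight', 'net.0.bias']) -- no 'notrain' key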

Solution
class net(nn.Module):
    def __init__(self, **kwargs):
        super(net, self).__init__(**kwargs)
        self.net = nn.Sequential(nn.Linear(2, 1), nn.ReLU())
        # self.notrain = torch.rand((64, 64), dtype=torch.float)
        self.notrain = torch.nn.Parameter(torch.ones(64, 64))  # registered as a parameter -> saved in state_dict

    def forward(self, x):
        return self.net(x)

for epoch in range(num_epochs):
    for X, y_train in data_iter(batch_size, x, y):
        optimizer.zero_grad()
        l = loss(model(X), y_train).sum()  # l is the loss on the minibatch X, y_train
        l.backward(retain_graph=True)
        optimizer.step()
    print(l)
    model.notrain.data = model.notrain.data + 2  # update the extra tensor by hand; it is now checkpointed too
    saver(model.state_dict(), optimizer.state_dict(), "./", epoch + 1, max_to_save=100)
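Wrapping the tensor in nn.Parameter does get it saved, but it also marks it trainable (requires_grad=True) and exposes it through model.parameters(). If the tensor should be checkpointed without being trainable, register_buffer is the standard alternative; a minimal sketch (the class name NetWithBuffer is made up for illustration):

import torch
import torch.nn as nn

class NetWithBuffer(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(2, 1), nn.ReLU())
        # A registered buffer is saved in state_dict and moved by .to(device),
        # but is not a parameter and gets no gradients
        self.register_buffer('notrain', torch.ones(64, 64))

    def forward(self, x):
        return self.net(x)

m = NetWithBuffer()
print('notrain' in m.state_dict())                   # True: buffers are checkpointed
print(any(p is m.notrain for p in m.parameters()))   # False: not trainable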
Reference and more
PyTorch DataLoader bug: random mask / random data selection bug