当前位置:网站首页>pytorch 模型保存的完整例子+pytorch 模型保存只保存可训练参数吗?是(+解决方案)
pytorch 模型保存的完整例子+pytorch 模型保存只保存可训练参数吗?是(+解决方案)
2022-07-02 19:02:00 【FakeOccupational】
测试使用的是一个liner model,还有更多的问题。pytorch 模型保存只保存可训练参数吗?
save模型
# 导入包
import glob
import os
import torch
import matplotlib.pyplot as plt
import random #用于数据迭代器生成随机数据
# 生成数据集 x1类别0,x2类别1
n_data = torch.ones(50, 2) # 数据的基本形态
x1 = torch.normal(2 * n_data, 1) # shape=(50, 2)
y1 = torch.zeros(50) # 类型0 shape=(50, 1)
x2 = torch.normal(-2 * n_data, 1) # shape=(50, 2)
y2 = torch.ones(50) # 类型1 shape=(50, 1)
# 注意 x, y 数据的数据形式一定要像下面一样(torch.cat是合并数据)
x = torch.cat((x1, x2), 0).type(torch.FloatTensor) y = torch.cat((y1, y2), 0).type(torch.FloatTensor) # 数据集可视化 plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=y.data.numpy(), s=100, lw=0, cmap='RdYlGn') plt.show() # 数据读取: def data_iter(batch_size, x, y): num_examples = len(x) indices = list(range(num_examples))
random.shuffle(indices) # 样本的读取顺序是随机的
for i in range(0, num_examples, batch_size):
j = torch.LongTensor(indices[i: min(i + batch_size, num_examples)]) #最后一次可能不足一个batch
yield x.index_select(0, j), y.index_select(0, j)
#############################################################################################################
def saver(model_state_dict, optimizer_state_dict, model_path, epoch, max_to_save=30):
total_models = glob.glob(model_path + '*')
if len(total_models) >= max_to_save:
total_models.sort()
os.remove(total_models[0])
state_dict = {
}
state_dict["model_state_dict"] = model_state_dict
state_dict["optimizer_state_dict"] = optimizer_state_dict
torch.save(state_dict, model_path + 'h' + str(epoch))
print('models {} save successfully!'.format(model_path + 'hahaha' + str(epoch)))
################################################################################################################
import torch.nn as nn
import torch.optim as optim
class net(nn.Module):
def __init__(self, **kwargs):
super(net, self).__init__(**kwargs)
self.net = nn.Sequential(nn.Linear(2, 1), nn.ReLU())
def forward(self, x):
return self.net(x)
def loss(y_hat, y):
return (y_hat - y.view(y_hat.size())) ** 2 / 2
def accuracy(y_hat, y): #@save
"""计算预测正确的数量。"""
cmp = y_hat.type(y.dtype) > 0.5 # 大于0.5类别1
result=cmp.type(y.dtype)
acc = 1-float(((result-y).sum())/ len(y))
return acc;
lr = 0.03
num_epochs = 3 # 迭代次数
batch_size = 10 # 批量大小
model = net()
params = list(model.parameters())
optimizer = torch.optim.Adam(params, 1e-4)
for epoch in range(num_epochs):
for X, y_train in data_iter(batch_size, x, y):
optimizer.zero_grad()
l = loss(model(X), y_train).sum() # l是有关小批量X和y的损失
l.backward(retain_graph=True)
optimizer.step()
print(l)
saver(model.state_dict(), optimizer.state_dict(), "./", epoch + 1, max_to_save=100)
load模型
# 导入包
import glob
import os
import torch
import matplotlib.pyplot as plt
import random #用于数据迭代器生成随机数据
# 生成数据集 x1类别0,x2类别1
n_data = torch.ones(50, 2) # 数据的基本形态
x1 = torch.normal(2 * n_data, 1) # shape=(50, 2)
y1 = torch.zeros(50) # 类型0 shape=(50, 1)
x2 = torch.normal(-2 * n_data, 1) # shape=(50, 2)
y2 = torch.ones(50) # 类型1 shape=(50, 1)
# 注意 x, y 数据的数据形式一定要像下面一样(torch.cat是合并数据)
x = torch.cat((x1, x2), 0).type(torch.FloatTensor) y = torch.cat((y1, y2), 0).type(torch.FloatTensor) # 数据集可视化 plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=y.data.numpy(), s=100, lw=0, cmap='RdYlGn') plt.show() # 数据读取: def data_iter(batch_size, x, y): num_examples = len(x) indices = list(range(num_examples))
random.shuffle(indices) # 样本的读取顺序是随机的
for i in range(0, num_examples, batch_size):
j = torch.LongTensor(indices[i: min(i + batch_size, num_examples)]) #最后一次可能不足一个batch
yield x.index_select(0, j), y.index_select(0, j)
#############################################################################################################
def saver(model_state_dict, optimizer_state_dict, model_path, epoch, max_to_save=30):
total_models = glob.glob(model_path + '*')
if len(total_models) >= max_to_save:
total_models.sort()
os.remove(total_models[0])
state_dict = {
}
state_dict["model_state_dict"] = model_state_dict
state_dict["optimizer_state_dict"] = optimizer_state_dict
torch.save(state_dict, model_path + 'h' + str(epoch))
print('models {} save successfully!'.format(model_path + 'hahaha' + str(epoch)))
################################################################################################################
import torch.nn as nn
import torch.optim as optim
class net(nn.Module):
def __init__(self, **kwargs):
super(net, self).__init__(**kwargs)
self.net = nn.Sequential(nn.Linear(2, 1), nn.ReLU())
def forward(self, x):
return self.net(x)
def loss(y_hat, y):
return (y_hat - y.view(y_hat.size())) ** 2 / 2
def accuracy(y_hat, y): #@save
"""计算预测正确的数量。"""
cmp = y_hat.type(y.dtype) > 0.5 # 大于0.5类别1
result=cmp.type(y.dtype)
acc = 1-float(((result-y).sum())/ len(y))
return acc;
lr = 0.03
num_epochs = 3 # 迭代次数
batch_size = 10 # 批量大小
model = net()
params = list(model.parameters())
optimizer = torch.optim.Adam(params, 1e-4)
# for epoch in range(num_epochs):
# for X, y_train in data_iter(batch_size, x, y):
# optimizer.zero_grad()
# l = loss(model(X), y_train).sum() # l是有关小批量X和y的损失
# l.backward(retain_graph=True)
# optimizer.step()
# print(l)
# saver(model.state_dict(), optimizer.state_dict(), "./", epoch + 1, max_to_save=100)
def loader(model_path):
state_dict = torch.load(model_path)
model_state_dict = state_dict["model_state_dict"]
optimizer_state_dict = state_dict["optimizer_state_dict"]
return model_state_dict, optimizer_state_dict
model_state_dict, optimizer_state_dict = loader("h1")
model.load_state_dict(model_state_dict)
optimizer.load_state_dict(optimizer_state_dict)
print('pretrained models loaded!')
pytorch 模型保存只保存可训练参数吗?是
class net(nn.Module):
def __init__(self, **kwargs):
super(net, self).__init__(**kwargs)
self.net = nn.Sequential(nn.Linear(2, 1), nn.ReLU())
self.notrain= torch.rand((64, 64), dtype=torch.float)
def forward(self, x):
return self.net(x)
解决方案
class net(nn.Module):
def __init__(self, **kwargs):
super(net, self).__init__(**kwargs)
self.net = nn.Sequential(nn.Linear(2, 1), nn.ReLU())
# self.notrain = torch.rand((64, 64), dtype=torch.float)
self.notrain = torch.nn.Parameter(torch.ones(64, 64))
def forward(self, x):
return self.net(x)
for epoch in range(num_epochs):
for X, y_train in data_iter(batch_size, x, y):
optimizer.zero_grad()
l = loss(model(X), y_train).sum() # l是有关小批量X和y的损失
l.backward(retain_graph=True)
optimizer.step()
print(l)
model.notrain.data = model.notrain.data+2
saver(model.state_dict(), optimizer.state_dict(), "./", epoch + 1, max_to_save=100)
参考与更多
边栏推荐
- What are the benefits of multi terminal applet development? Covering Baidu applet, Tiktok applet, wechat applet development, and seizing the multi platform traffic dividend
- 450 Shenxin Mianjing 1
- Notes on hardware design of kt148a voice chip IC
- 多端小程序开发有什么好处?覆盖百度小程序抖音小程序微信小程序开发,抢占多平台流量红利
- Design and implementation of ks004 based on SSH address book system
- At compilation environment setup -win
- 蓝牙芯片ble是什么,以及该如何选型,后续技术发展的路径是什么
- Build a master-slave mode cluster redis
- How to avoid duplicate data in gaobingfa?
- 勵志!大凉山小夥全獎直博!論文致謝看哭網友
猜你喜欢
KS004 基于SSH通讯录系统设计与实现
嵌入式(PLD) 系列,EPF10K50RC240-3N 可编程逻辑器件
为什么我对流程情有独钟?
Introduction to program ape (XII) -- data storage
Windows2008r2 installing php7.4.30 requires localsystem to start the application pool, otherwise 500 error fastcgi process exits unexpectedly
定了,就是它!
Automatically generate VGg image annotation file
[daily question] 241 Design priorities for operational expressions
Zabbix5 client installation and configuration
Shardingsphere jdbc5.1.2 about select last_ INSERT_ ID () I found that there was still a routing problem
随机推荐
AcWing 340. Solution to communication line problem (binary + double ended queue BFS for the shortest circuit)
JASMINER X4 1U deep disassembly reveals the secret behind high efficiency and power saving
API文档工具knife4j使用详解
at编译环境搭建-win
功能、作用、效能、功用、效用、功效
AcWing 1127. Sweet butter solution (shortest path SPFA)
Chapter 7 - class foundation
开始练习书法
台湾SSS鑫创SSS1700替代Cmedia CM6533 24bit 96KHZ USB音频编解码芯片
AcWing 1129. 热浪 题解(最短路—spfa)
[ERP software] what are the dangers of the secondary development of ERP system?
测试人员如何做不漏测?这7点就够了
在消费互联网时代,诞生了为数不多的头部平台的话
从20s优化到500ms,我用了这三招
AcWing 1125. 牛的旅行 题解(最短路、直径)
AcWing 340. 通信线路 题解(二分+双端队列BFS求最短路)
checklistbox控件用法总结
编写完10万行代码,我发了篇长文吐槽Rust
解决方案:VS2017 无法打开源文件 stdio.h main.h 等头文件[通俗易懂]
Development skills of rxjs observable custom operator