当前位置:网站首页>pytorch 模型保存的完整例子+pytorch 模型保存只保存可訓練參數嗎?是(+解决方案)
pytorch 模型保存的完整例子+pytorch 模型保存只保存可訓練參數嗎?是(+解决方案)
2022-07-02 19:54:00 【FakeOccupational】
測試使用的是一個 linear model(線性模型),還有更多的問題,例如:PyTorch 模型保存只保存可訓練參數嗎?
save模型
# 導入包
import glob
import os
import torch
import matplotlib.pyplot as plt
import random #用於數據迭代器生成隨機數據
# Build a toy two-class dataset: x1 belongs to class 0, x2 to class 1.
n_data = torch.ones(50, 2)  # template tensor that fixes the sample shape
x1 = torch.normal(2 * n_data, 1)  # class-0 samples, mean 2 / std 1, shape (50, 2)
y1 = torch.zeros(50)  # class-0 labels, shape (50,)
x2 = torch.normal(-2 * n_data, 1)  # class-1 samples, mean -2 / std 1, shape (50, 2)
y2 = torch.ones(50)  # class-1 labels, shape (50,)
# Concatenate both classes along dim 0 (torch.cat merges the tensors) and
# make sure features and labels are float tensors.
x = torch.cat((x1, x2), 0).type(torch.FloatTensor)
y = torch.cat((y1, y2), 0).type(torch.FloatTensor)

# Visualise the dataset, coloured by label.
plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=y.data.numpy(), s=100, lw=0, cmap='RdYlGn')
plt.show()


# Data reading.
def data_iter(batch_size, x, y):
    """Yield ``(features, labels)`` mini-batches in a random order.

    Args:
        batch_size: Number of samples per batch; the last batch may be
            smaller when ``len(x)`` is not a multiple of it.
        x: Feature tensor indexed along dim 0.
        y: Label tensor indexed along dim 0, aligned with ``x``.

    Yields:
        Tuples ``(x_batch, y_batch)`` selected with the same indices.
    """
    num_examples = len(x)
    indices = list(range(num_examples))
    random.shuffle(indices)  # samples are read in a random order
    for i in range(0, num_examples, batch_size):
        # Slicing clamps at the end of the list, so the final batch is
        # simply shorter when fewer than batch_size samples remain.
        j = torch.LongTensor(indices[i:i + batch_size])
        yield x.index_select(0, j), y.index_select(0, j)
#############################################################################################################
def saver(model_state_dict, optimizer_state_dict, model_path, epoch, max_to_save=30):
    """Save a training checkpoint and prune old checkpoints.

    The checkpoint is written to ``model_path + 'h' + str(epoch)`` and holds
    the two state dicts under the keys ``"model_state_dict"`` and
    ``"optimizer_state_dict"``.

    Args:
        model_state_dict: Value stored under ``"model_state_dict"``.
        optimizer_state_dict: Value stored under ``"optimizer_state_dict"``.
        model_path: Path prefix of the checkpoint files (e.g. ``"./"``).
        epoch: Epoch number appended to the file name.
        max_to_save: Maximum number of checkpoint files kept on disk.
    """
    save_path = model_path + 'h' + str(epoch)
    prefix_name = os.path.basename(model_path + 'h')
    save_name = os.path.basename(save_path)

    # Only files named "<prefix>h<number>" are checkpoints. Globbing
    # model_path + '*' would match (and possibly delete) unrelated files.
    checkpoints = []
    for candidate in glob.glob(glob.escape(model_path + 'h') + '*'):
        name = os.path.basename(candidate)
        suffix = name[len(prefix_name):]
        # The file about to be overwritten does not add to the total.
        if suffix.isdecimal() and name != save_name:
            checkpoints.append((int(suffix), candidate))

    # Sort by epoch number; a plain string sort would put "h10" before "h2".
    checkpoints.sort()
    while checkpoints and len(checkpoints) >= max_to_save:
        os.remove(checkpoints.pop(0)[1])

    state_dict = {
        "model_state_dict": model_state_dict,
        "optimizer_state_dict": optimizer_state_dict,
    }
    torch.save(state_dict, save_path)
    # Report the path that was actually written.
    print('models {} save successfully!'.format(save_path))
################################################################################################################
import torch.nn as nn
import torch.optim as optim
class net(nn.Module):
    """Tiny classifier: one ``Linear(2, 1)`` layer followed by ``ReLU``."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        linear = nn.Linear(2, 1)
        activation = nn.ReLU()
        # Stored under the attribute name "net", so state_dict keys start
        # with "net.".
        self.net = nn.Sequential(linear, activation)

    def forward(self, x):
        """Apply the linear layer and ReLU to ``x``."""
        output = self.net(x)
        return output
def loss(y_hat, y):
    """Return the element-wise halved squared error between ``y_hat`` and ``y``.

    ``y`` is viewed with the size of ``y_hat`` before subtracting, so the
    result has the same shape as ``y_hat``.
    """
    target = y.view(y_hat.size())
    residual = y_hat - target
    return residual ** 2 / 2
def accuracy(y_hat, y):  #@save
    """Return the fraction of predictions that match the labels.

    A prediction counts as class 1 when its score is greater than 0.5 and
    as class 0 otherwise.

    Args:
        y_hat: Predicted scores with the same number of elements as ``y``
            (for example shape ``(N, 1)`` or ``(N,)``).
        y: Ground-truth labels (0 or 1).

    Returns:
        The accuracy as a Python float.
    """
    # Threshold the raw scores first, then cast to the label dtype, and
    # align the shape with y so (N, 1) scores do not broadcast to (N, N).
    predicted = (y_hat.reshape(y.shape) > 0.5).type(y.dtype)
    # Count exact matches. Summing signed differences would let a false
    # positive cancel a false negative.
    return float((predicted == y).type(torch.float32).mean())
# NOTE(review): lr is defined but the optimizer below is built with 1e-4;
# confirm which learning rate is intended.
lr = 0.03
num_epochs = 3  # number of passes over the dataset
batch_size = 10  # mini-batch size
model = net()
params = list(model.parameters())
optimizer = torch.optim.Adam(params, 1e-4)
# NOTE(review): the original indentation was lost; printing and saving are
# placed once per epoch here - confirm against the intended script.
for epoch in range(num_epochs):
    for X, y_train in data_iter(batch_size, x, y):
        optimizer.zero_grad()
        l = loss(model(X), y_train).sum()  # summed loss over the mini-batch
        # Every forward pass builds a new graph, so it does not need to be
        # retained after backward().
        l.backward()
        optimizer.step()
    print(l)
    saver(model.state_dict(), optimizer.state_dict(), "./", epoch + 1, max_to_save=100)
load模型
# 導入包
import glob
import os
import torch
import matplotlib.pyplot as plt
import random #用於數據迭代器生成隨機數據
# Build a toy two-class dataset: x1 belongs to class 0, x2 to class 1.
n_data = torch.ones(50, 2)  # template tensor that fixes the sample shape
x1 = torch.normal(2 * n_data, 1)  # class-0 samples, mean 2 / std 1, shape (50, 2)
y1 = torch.zeros(50)  # class-0 labels, shape (50,)
x2 = torch.normal(-2 * n_data, 1)  # class-1 samples, mean -2 / std 1, shape (50, 2)
y2 = torch.ones(50)  # class-1 labels, shape (50,)
# Concatenate both classes along dim 0 (torch.cat merges the tensors) and
# make sure features and labels are float tensors.
x = torch.cat((x1, x2), 0).type(torch.FloatTensor)
y = torch.cat((y1, y2), 0).type(torch.FloatTensor)

# Visualise the dataset, coloured by label.
plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=y.data.numpy(), s=100, lw=0, cmap='RdYlGn')
plt.show()


# Data reading.
def data_iter(batch_size, x, y):
    """Yield ``(features, labels)`` mini-batches in a random order.

    Args:
        batch_size: Number of samples per batch; the last batch may be
            smaller when ``len(x)`` is not a multiple of it.
        x: Feature tensor indexed along dim 0.
        y: Label tensor indexed along dim 0, aligned with ``x``.

    Yields:
        Tuples ``(x_batch, y_batch)`` selected with the same indices.
    """
    num_examples = len(x)
    indices = list(range(num_examples))
    random.shuffle(indices)  # samples are read in a random order
    for i in range(0, num_examples, batch_size):
        # Slicing clamps at the end of the list, so the final batch is
        # simply shorter when fewer than batch_size samples remain.
        j = torch.LongTensor(indices[i:i + batch_size])
        yield x.index_select(0, j), y.index_select(0, j)
#############################################################################################################
def saver(model_state_dict, optimizer_state_dict, model_path, epoch, max_to_save=30):
    """Save a training checkpoint and prune old checkpoints.

    The checkpoint is written to ``model_path + 'h' + str(epoch)`` and holds
    the two state dicts under the keys ``"model_state_dict"`` and
    ``"optimizer_state_dict"``.

    Args:
        model_state_dict: Value stored under ``"model_state_dict"``.
        optimizer_state_dict: Value stored under ``"optimizer_state_dict"``.
        model_path: Path prefix of the checkpoint files (e.g. ``"./"``).
        epoch: Epoch number appended to the file name.
        max_to_save: Maximum number of checkpoint files kept on disk.
    """
    save_path = model_path + 'h' + str(epoch)
    prefix_name = os.path.basename(model_path + 'h')
    save_name = os.path.basename(save_path)

    # Only files named "<prefix>h<number>" are checkpoints. Globbing
    # model_path + '*' would match (and possibly delete) unrelated files.
    checkpoints = []
    for candidate in glob.glob(glob.escape(model_path + 'h') + '*'):
        name = os.path.basename(candidate)
        suffix = name[len(prefix_name):]
        # The file about to be overwritten does not add to the total.
        if suffix.isdecimal() and name != save_name:
            checkpoints.append((int(suffix), candidate))

    # Sort by epoch number; a plain string sort would put "h10" before "h2".
    checkpoints.sort()
    while checkpoints and len(checkpoints) >= max_to_save:
        os.remove(checkpoints.pop(0)[1])

    state_dict = {
        "model_state_dict": model_state_dict,
        "optimizer_state_dict": optimizer_state_dict,
    }
    torch.save(state_dict, save_path)
    # Report the path that was actually written.
    print('models {} save successfully!'.format(save_path))
################################################################################################################
import torch.nn as nn
import torch.optim as optim
class net(nn.Module):
    """Tiny classifier: one ``Linear(2, 1)`` layer followed by ``ReLU``."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        linear = nn.Linear(2, 1)
        activation = nn.ReLU()
        # Stored under the attribute name "net", so state_dict keys start
        # with "net.".
        self.net = nn.Sequential(linear, activation)

    def forward(self, x):
        """Apply the linear layer and ReLU to ``x``."""
        output = self.net(x)
        return output
def loss(y_hat, y):
    """Return the element-wise halved squared error between ``y_hat`` and ``y``.

    ``y`` is viewed with the size of ``y_hat`` before subtracting, so the
    result has the same shape as ``y_hat``.
    """
    target = y.view(y_hat.size())
    residual = y_hat - target
    return residual ** 2 / 2
def accuracy(y_hat, y):  #@save
    """Return the fraction of predictions that match the labels.

    A prediction counts as class 1 when its score is greater than 0.5 and
    as class 0 otherwise.

    Args:
        y_hat: Predicted scores with the same number of elements as ``y``
            (for example shape ``(N, 1)`` or ``(N,)``).
        y: Ground-truth labels (0 or 1).

    Returns:
        The accuracy as a Python float.
    """
    # Threshold the raw scores first, then cast to the label dtype, and
    # align the shape with y so (N, 1) scores do not broadcast to (N, N).
    predicted = (y_hat.reshape(y.shape) > 0.5).type(y.dtype)
    # Count exact matches. Summing signed differences would let a false
    # positive cancel a false negative.
    return float((predicted == y).type(torch.float32).mean())
# Build the model and optimizer that the checkpoint will be loaded into.
model = net()
params = list(model.parameters())
# NOTE(review): lr below is not passed to the optimizer, which is built
# with 1e-4 - confirm which learning rate is intended.
optimizer = torch.optim.Adam(params, lr=1e-4)

# Hyper-parameters kept from the training script.
lr = 0.03
num_epochs = 3  # number of passes over the dataset
batch_size = 10  # mini-batch size
# for epoch in range(num_epochs):
# for X, y_train in data_iter(batch_size, x, y):
# optimizer.zero_grad()
# l = loss(model(X), y_train).sum() # l是有關小批量X和y的損失
# l.backward(retain_graph=True)
# optimizer.step()
# print(l)
# saver(model.state_dict(), optimizer.state_dict(), "./", epoch + 1, max_to_save=100)
def loader(model_path, map_location=None):
    """Load a checkpoint and return its two state dicts.

    Args:
        model_path: Path of the checkpoint file.
        map_location: Forwarded to ``torch.load``; pass e.g. ``"cpu"`` to
            load a checkpoint on a different device than it was saved on.
            ``None`` keeps ``torch.load``'s default behaviour.

    Returns:
        A tuple ``(model_state_dict, optimizer_state_dict)``.

    Raises:
        KeyError: If the checkpoint lacks one of the expected keys.
    """
    # SECURITY: torch.load is pickle-based; only load checkpoint files from
    # a trusted source.
    state_dict = torch.load(model_path, map_location=map_location)
    expected_keys = ("model_state_dict", "optimizer_state_dict")
    missing = [key for key in expected_keys if key not in state_dict]
    if missing:
        raise KeyError(
            "checkpoint {!r} is missing keys: {}".format(model_path, ", ".join(missing))
        )
    model_state_dict = state_dict["model_state_dict"]
    optimizer_state_dict = state_dict["optimizer_state_dict"]
    return model_state_dict, optimizer_state_dict
# Read the checkpoint file "h1" from the current directory.
# NOTE(review): presumably the file the save script writes after epoch 1
# ("./" + "h" + "1") - confirm.
model_state_dict, optimizer_state_dict = loader("h1")
# Copy the saved weights and optimizer state into the freshly built objects.
model.load_state_dict(model_state_dict)
optimizer.load_state_dict(optimizer_state_dict)
print('pretrained models loaded!')
PyTorch 模型保存只保存可訓練參數嗎?基本上是:state_dict() 只包含 nn.Parameter 和已註冊的 buffer,像下面 self.notrain 這樣的普通 tensor 屬性不會被保存。
class net(nn.Module):
    """Linear classifier that also carries a plain tensor attribute.

    ``notrain`` is assigned as an ordinary tensor (not a Parameter and not
    a registered buffer), so it does not appear in ``state_dict()``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        linear = nn.Linear(2, 1)
        activation = nn.ReLU()
        self.net = nn.Sequential(linear, activation)
        extra = torch.rand((64, 64), dtype=torch.float)
        self.notrain = extra

    def forward(self, x):
        """Apply the linear layer and ReLU to ``x``; ``notrain`` is unused."""
        output = self.net(x)
        return output

解决方案
class net(nn.Module):
    """Linear classifier whose extra tensor is saved with the model.

    Wrapping ``notrain`` in ``nn.Parameter`` registers it on the module, so
    it is included in ``state_dict()``; a plain tensor attribute is not.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.net = nn.Sequential(nn.Linear(2, 1), nn.ReLU())
        # requires_grad=False keeps the tensor non-trainable, as its name
        # says, while it still appears in parameters() and state_dict().
        # For state that should not be a parameter at all,
        # nn.Module.register_buffer is the usual alternative.
        self.notrain = torch.nn.Parameter(torch.ones(64, 64), requires_grad=False)

    def forward(self, x):
        """Apply the linear layer and ReLU to ``x``; ``notrain`` is unused."""
        return self.net(x)
# NOTE(review): the original indentation was lost; printing, the notrain
# update and saving are placed once per epoch here - confirm.
for epoch in range(num_epochs):
    for X, y_train in data_iter(batch_size, x, y):
        optimizer.zero_grad()
        l = loss(model(X), y_train).sum()  # summed loss over the mini-batch
        # Every forward pass builds a new graph, so it does not need to be
        # retained after backward().
        l.backward()
        optimizer.step()
    print(l)
    # Change the extra tensor by hand so the saved value differs from its
    # initial value; done in place under no_grad instead of through .data.
    with torch.no_grad():
        model.notrain.add_(2)
    saver(model.state_dict(), optimizer.state_dict(), "./", epoch + 1, max_to_save=100)
參考與更多
边栏推荐
- Refactoring: improving the design of existing code (Part 2)
- Start practicing calligraphy
- NMF-matlab
- For (Auto A: b) and for (Auto & A: b) usage
- Kt148a voice chip instructions, hardware, protocols, common problems, and reference codes
- AcWing 1126. 最小花费 题解(最短路—dijkstra)
- Automatically generate VGg image annotation file
- 台湾SSS鑫创SSS1700替代Cmedia CM6533 24bit 96KHZ USB音频编解码芯片
- 451 implementation of memcpy, memmove and memset
- What is the Bluetooth chip ble, how to select it, and what is the path of subsequent technology development
猜你喜欢

八年测开经验,面试28K公司后,吐血整理出高频面试题和答案

浏览器缓存机制概述

API文档工具knife4j使用详解

数据湖(十二):Spark3.1.2与Iceberg0.12.1整合
In depth understanding of modern web browsers (I)

Yes, that's it!

Kt148a voice chip IC software reference code c language, first-line serial port

接口测试到底怎么做?看完这篇文章就能清晰明了

Introduction to program ape (XII) -- data storage

HDL design peripheral tools to reduce errors and help you take off!
随机推荐
Think about the huge changes caused by variables
Zabbix5 client installation and configuration
Use IDM to download Baidu online disk files (useful for personal testing) [easy to understand]
Design and implementation of ks004 based on SSH address book system
KT148A语音芯片ic的硬件设计注意事项
AcWing 1131. 拯救大兵瑞恩 题解(最短路)
KS004 基于SSH通讯录系统设计与实现
Automatic reading of simple books
C language linked list -- to be added
蓝牙芯片ble是什么,以及该如何选型,后续技术发展的路径是什么
RPD出品:Superpower Squad 保姆级攻略
勵志!大凉山小夥全獎直博!論文致謝看哭網友
中缀表达式转换为后缀表达式(C语言代码+详解)
VBScript详解(一)
AcWing 1134. Shortest circuit counting problem solution (shortest circuit)
自動生成VGG圖像注釋文件
Correspondence between PyTorch version, CUDA version and graphics card driver version
Detailed explanation of VBScript (I)
Postman download and installation
AcWing 342. Road and route problem solving (shortest path, topological sorting)