PyTorch Models
2022-06-28 12:51:00 【Gu_NN】
Basic Ways to Define a Model
PyTorch provides nn.Sequential(), nn.ModuleList(), and nn.ModuleDict() for composing multiple Modules into a model. They compare as follows:
| Sequential() | ModuleList() / ModuleDict() |
|---|---|
| Builds the network directly; the definition order is the connection order | The order of elements in the List/Dict does not determine their position in the network; forward() must specify how the layers connect |
| No way to feed extra external inputs into the middle of the model | Convenient when an intermediate layer needs information from earlier layers, e.g. the residual connections in ResNets |
Via nn.Sequential()
# Method 1:
import torch.nn as nn
net = nn.Sequential(
nn.Linear(784, 256),
nn.ReLU(),
nn.Linear(256, 10),
)
# Method 2:
import collections
net2 = nn.Sequential(collections.OrderedDict([
('fc1', nn.Linear(784, 256)),
('relu1', nn.ReLU()),
('fc2', nn.Linear(256, 10))
]))
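A quick sanity check (a minimal sketch; the 784/10 dimensions follow the examples above):
import torch
x = torch.randn(4, 784)  # a dummy batch of four flattened 28x28 inputs
print(net(x).shape)      # torch.Size([4, 10])
print(net2.fc1)          # layers named via the OrderedDict can be accessed by name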
Via nn.ModuleList() / nn.ModuleDict()
# List
class ModelList(nn.Module):
    def __init__(self):
        super().__init__()
        self.modulelist = nn.ModuleList([nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10)])
    def forward(self, x):
        for layer in self.modulelist:
            x = layer(x)
        return x
# Dict
class ModelDict(nn.Module):
    def __init__(self):
        super().__init__()
        self.moduledict = nn.ModuleDict({
            'linear': nn.Linear(784, 256),
            'act': nn.ReLU(),
            'output': nn.Linear(256, 10)
        })
    def forward(self, x):
        # iterating a ModuleDict yields its keys, so use .values() to get the layers
        for layer in self.moduledict.values():
            x = layer(x)
        return x
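Both containers register their layers as sub-modules (so parameters are tracked), but only forward() defines the execution order. A minimal sketch using the classes above:
import torch
m = ModelDict()
x = torch.randn(2, 784)
print(m(x).shape)                 # torch.Size([2, 10])
print(len(list(m.parameters())))  # 4: weight and bias of the two Linear layers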
Building Complex Models
For a large, complex model, first split it into functional blocks, then assemble the blocks into the full model. Take U-Net as an example.
The U-Net architecture (see the original figure, omitted here) can be divided into four kinds of blocks:
- the two convolutions inside each sub-block (double convolution)
- the downsampling connections between blocks on the left side, i.e. max pooling
- the upsampling connections between blocks on the right side
- the output layer
Building the Blocks
import torch
import torch.nn as nn
import torch.nn.functional as F
class DoubleConv(nn.Module):
"""(convolution => [BN] => ReLU) * 2"""
def __init__(self, in_channels, out_channels, mid_channels=None):
super().__init__()
if not mid_channels:
mid_channels = out_channels
self.double_conv = nn.Sequential(
nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(mid_channels),
nn.ReLU(inplace=True),
nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
def forward(self, x):
return self.double_conv(x)
class Down(nn.Module):
"""Downscaling with maxpool then double conv"""
def __init__(self, in_channels, out_channels):
super().__init__()
self.maxpool_conv = nn.Sequential(
nn.MaxPool2d(2),
DoubleConv(in_channels, out_channels)
)
def forward(self, x):
return self.maxpool_conv(x)
class Up(nn.Module):
"""Upscaling then double conv"""
def __init__(self, in_channels, out_channels, bilinear=True):
super().__init__()
# if bilinear, use the normal convolutions to reduce the number of channels
if bilinear:
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
else:
self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
self.conv = DoubleConv(in_channels, out_channels)
def forward(self, x1, x2):
x1 = self.up(x1)
# input is CHW
diffY = x2.size()[2] - x1.size()[2]
diffX = x2.size()[3] - x1.size()[3]
x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
diffY // 2, diffY - diffY // 2])
x = torch.cat([x2, x1], dim=1)
return self.conv(x)
class OutConv(nn.Module):
def __init__(self, in_channels, out_channels):
super(OutConv, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
def forward(self, x):
return self.conv(x)
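A minimal shape check for the blocks (a sketch with hypothetical sizes; the channel counts are chosen so that Up's in_channels equals the concatenated channel count, as in the full model below):
x = torch.randn(1, 64, 64, 64)        # (N, C, H, W)
y = Down(64, 64)(x)                   # max pooling halves H and W -> (1, 64, 32, 32)
z = Up(128, 64, bilinear=True)(y, x)  # upsample y, pad to match x, concat (64+64 channels), double conv
print(y.shape, z.shape)               # torch.Size([1, 64, 32, 32]) torch.Size([1, 64, 64, 64])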
Assembling the Model
class UNet(nn.Module):
def __init__(self, n_channels, n_classes, bilinear=True):
super(UNet, self).__init__()
self.n_channels = n_channels
self.n_classes = n_classes
self.bilinear = bilinear
self.inc = DoubleConv(n_channels, 64)
self.down1 = Down(64, 128)
self.down2 = Down(128, 256)
self.down3 = Down(256, 512)
factor = 2 if bilinear else 1
self.down4 = Down(512, 1024 // factor)
self.up1 = Up(1024, 512 // factor, bilinear)
self.up2 = Up(512, 256 // factor, bilinear)
self.up3 = Up(256, 128 // factor, bilinear)
self.up4 = Up(128, 64, bilinear)
self.outc = OutConv(64, n_classes)
def forward(self, x):
x1 = self.inc(x)
x2 = self.down1(x1)
x3 = self.down2(x2)
x4 = self.down3(x3)
x5 = self.down4(x4)
x = self.up1(x5, x4)
x = self.up2(x, x3)
x = self.up3(x, x2)
x = self.up4(x, x1)
logits = self.outc(x)
return logits
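A smoke test of the assembled model (a sketch; input height and width should be divisible by 16, since the encoder downsamples four times):
unet = UNet(n_channels=3, n_classes=2)
x = torch.randn(1, 3, 256, 256)
print(unet(x).shape)  # torch.Size([1, 2, 256, 256]): per-pixel class logits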
Modifying an Existing Model
Replacing a Layer
import torchvision.models as models
net = models.resnet50()
print(net)
# replace its fc layer
from collections import OrderedDict
classifier = nn.Sequential(OrderedDict([('fc1', nn.Linear(2048, 128)),
('relu1', nn.ReLU()),
('dropout1', nn.Dropout(0.5)),
('fc2', nn.Linear(128, 10)),
('output', nn.Softmax(dim=1))
]))
# replace the original fc layer with the classifier defined above
net.fc = classifier
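To confirm the replacement took effect, run a dummy batch through the modified network (a minimal sketch; resnet50's original fc maps the 2048-dim pooled features to 1000 classes, while the new classifier maps them to 10):
import torch
x = torch.randn(2, 3, 224, 224)
print(net(x).shape)  # torch.Size([2, 10]); outputs now come from the new classifier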
Adding an Input Variable
# define the modified model
class Model(nn.Module):
    def __init__(self, net):
        super(Model, self).__init__()
        # the original network (torchvision's resnet50 outputs a 1000-dim tensor)
        self.net = net
        # activation for the 1000-dim output
        self.relu = nn.ReLU()
        # dropout layer
        self.dropout = nn.Dropout(0.5)
        # fully connected layer mapping the 1000 + 1 concatenated dims to the 10-dim output
        self.fc_add = nn.Linear(1001, 10, bias=True)
        self.output = nn.Softmax(dim=1)
    def forward(self, x, add_variable):
        x = self.net(x)
        # concatenate the external variable after the activation and dropout layers;
        # unsqueeze(1) turns the (batch,) variable into (batch, 1) so it matches the
        # dimensions of the net output -- typical when add_variable is one scalar per sample
        x = torch.cat((self.dropout(self.relu(x)), add_variable.unsqueeze(1)), 1)
        x = self.fc_add(x)
        x = self.output(x)
        return x
# instantiate
model = Model(net).cuda()
# at training time
outputs = model(inputs, add_var)
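The shapes involved (a hypothetical sketch; inputs and add_var are placeholder names for an image batch and one extra scalar per sample):
inputs = torch.randn(4, 3, 224, 224).cuda()  # image batch
add_var = torch.randn(4).cuda()              # one scalar per sample; unsqueeze(1) makes it (4, 1)
outputs = model(inputs, add_var)             # net -> (4, 1000), concat -> (4, 1001), fc_add -> (4, 10)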
Adding an Output Variable
class Model(nn.Module):
def __init__(self, net):
super(Model, self).__init__()
self.net = net
self.relu = nn.ReLU()
self.dropout = nn.Dropout(0.5)
self.fc1 = nn.Linear(1000, 10, bias=True)
self.output = nn.Softmax(dim=1)
    def forward(self, x):
x1000 = self.net(x)
x10 = self.dropout(self.relu(x1000))
x10 = self.fc1(x10)
x10 = self.output(x10)
        return x10, x1000  # return the extra 1000-dim output as well
model = Model(net).cuda()
out10, out1000 = model(inputs)
Saving and Loading Models
PyTorch models are usually stored in one of three file formats: pkl, pt, or pth.
A PyTorch model has two parts: the model structure and the weights.
- structure: a class inheriting from nn.Module
- weights: a dictionary (keys are layer names, values are weight tensors)
Accordingly, there are two ways to store a model:
- store the structure together with the weights (the whole model)
- store the weights only
from torchvision import models
model = models.resnet152(pretrained=True)
# save the whole model
torch.save(model, save_dir)
# save the model weights
torch.save(model.state_dict(), save_dir)
When training with multiple GPUs, both saving and loading may happen on a single card or on multiple cards. A state dict saved from a multi-GPU (DataParallel) model carries an extra module. prefix in its parameter names compared to the single-card case, so loading such checkpoints needs extra care.
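The prefix difference can be seen directly by inspecting the state dict keys (a minimal sketch; wrapping with nn.DataParallel nests the model under a .module attribute, which is where the extra prefix comes from):
import torch.nn as nn
from torchvision import models
model = models.resnet152()
print(list(model.state_dict().keys())[0])    # 'conv1.weight'
wrapped = nn.DataParallel(model)
print(list(wrapped.state_dict().keys())[0])  # 'module.conv1.weight'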
Saving
Saving from a single GPU
import os
import torch
from torchvision import models
os.environ['CUDA_VISIBLE_DEVICES'] = '0' # replace with the GPU id you want to use
model = models.resnet152(pretrained=True)
model.cuda()
# save the whole model
torch.save(model, save_dir)
# save the model weights
torch.save(model.state_dict(), save_dir)
Saving from multiple GPUs
Simply wrap the model in nn.DataParallel to set up data-parallel training:
import os
import torch
from torchvision import models
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2' # replace with the GPU ids you want to use
model = models.resnet152(pretrained=True)
model = nn.DataParallel(model).cuda()
# save the whole model
torch.save(model, save_dir)
# save the model weights
torch.save(model.state_dict(), save_dir)
Loading
Loading on a single GPU
- model saved from a single GPU
import os
import torch
from torchvision import models
os.environ['CUDA_VISIBLE_DEVICES'] = '0' # replace with the GPU id you want to use
# load the whole model
loaded_model = torch.load(save_dir)
loaded_model.cuda()
# load the model weights
loaded_dict = torch.load(save_dir)
loaded_model = models.resnet152() # the model structure must be defined here
loaded_model.load_state_dict(loaded_dict)
loaded_model.cuda()
- model saved from multiple GPUs
import os
import torch
from torchvision import models
os.environ['CUDA_VISIBLE_DEVICES'] = '0' # replace with the GPU id you want to use
# load the whole model
loaded_model = torch.load(save_dir)
loaded_model = loaded_model.module # the difference: unwrap the DataParallel container
# load the model weights (recommended)
loaded_dict = torch.load(save_dir)
loaded_model = models.resnet152() # the model structure must be defined here
loaded_model = nn.DataParallel(loaded_model).cuda() # the difference: wrap so the module. keys match
loaded_model.load_state_dict(loaded_dict)
# load the model weights (alternative 1)
from collections import OrderedDict
loaded_dict = torch.load(save_dir)
# strip the module. prefix
new_state_dict = OrderedDict()
for k, v in loaded_dict.items():
    name = k[7:] # the 'module.' prefix is 7 characters, so slice it off
    new_state_dict[name] = v # copy each value under its stripped key
# the rest is the same as loading a single-GPU checkpoint
loaded_model = models.resnet152()
loaded_model.load_state_dict(new_state_dict)
loaded_model = loaded_model.cuda()
# load the model weights (alternative 2)
loaded_model = models.resnet152()
loaded_dict = torch.load(save_dir)
loaded_model.load_state_dict({
k.replace('module.', ''): v for k, v in loaded_dict.items()})
loaded_model = loaded_model.cuda()
Loading on multiple GPUs
- model saved from a single GPU
Simply wrap the loaded model in nn.DataParallel:
import os
import torch
from torchvision import models
os.environ['CUDA_VISIBLE_DEVICES'] = '1,2' # replace with the GPU ids you want to use
# load the whole model
loaded_model = torch.load(save_dir)
loaded_model = nn.DataParallel(loaded_model).cuda() # the difference
# load the model weights
loaded_dict = torch.load(save_dir)
loaded_model = models.resnet152() # the model structure must be defined here
loaded_model.load_state_dict(loaded_dict)
loaded_model = nn.DataParallel(loaded_model).cuda() # the difference
- model saved from multiple GPUs
It is best to save only the weights, in which case loading is the same as for a single-GPU checkpoint. If all you have is the whole saved model, something like the following works:
import os
import torch
from torchvision import models
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2' # replace with the GPU ids you want to use
loaded_whole_model = torch.load(save_dir) # a DataParallel-wrapped model
loaded_model = models.resnet152() # the model structure must be defined here
loaded_model.load_state_dict(loaded_whole_model.module.state_dict()) # unwrap to get clean key names
loaded_model = nn.DataParallel(loaded_model).cuda()
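A pattern that sidesteps most of the cases above (a suggested sketch, not from the original): always save the unwrapped state dict, and load it onto the CPU first via map_location before moving or wrapping the model:
# save: unwrap first if the model is DataParallel-wrapped
state = model.module.state_dict() if isinstance(model, nn.DataParallel) else model.state_dict()
torch.save(state, save_dir)
# load: keys are clean, so this works for single- and multi-GPU use alike
loaded_model = models.resnet152()
loaded_model.load_state_dict(torch.load(save_dir, map_location='cpu'))
loaded_model = nn.DataParallel(loaded_model).cuda()  # or just .cuda() for a single GPU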