PyTorch Models
2022-06-28 12:51:00 【Gu_NN】
Basic Ways to Define a Model
PyTorch provides nn.Sequential(), nn.ModuleList(), and nn.ModuleDict() for composing multiple Modules into a model. Their differences are summarized below:
| nn.Sequential() | nn.ModuleList() / nn.ModuleDict() |
|---|---|
| Builds the network directly: the definition order is the connection order | The order of elements in the List/Dict does not determine their position in the network; forward() must specify how the layers connect |
| External inputs cannot be injected in the middle of the model | Convenient when a layer needs information from earlier layers, e.g. the residual connections in ResNets (see the sketch after the examples below) |
Using nn.Sequential()
```python
# Method 1: positional
import torch.nn as nn

net = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 10),
)

# Method 2: named layers via OrderedDict
import collections

net2 = nn.Sequential(collections.OrderedDict([
    ('fc1', nn.Linear(784, 256)),
    ('relu1', nn.ReLU()),
    ('fc2', nn.Linear(256, 10))
]))
```
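nn.Sequential implements forward() for you, so the model can be called directly; layers named via an OrderedDict are also accessible as attributes. A minimal sketch using the nets defined above (the input shape is an assumption for illustration):

```python
import torch

x = torch.rand(4, 784)     # dummy batch of 4 flattened 28x28 images
print(net(x).shape)        # torch.Size([4, 10])
print(net2.fc1)            # Linear(in_features=784, out_features=256, bias=True)
```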
Using nn.ModuleList() / nn.ModuleDict()
```python
# ModuleList
class ListModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.modulelist = nn.ModuleList([nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10)])

    def forward(self, x):
        # the connection order is specified here, not by the list itself
        for layer in self.modulelist:
            x = layer(x)
        return x

# ModuleDict
class DictModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.moduledict = nn.ModuleDict({
            'linear': nn.Linear(784, 256),
            'act': nn.ReLU(),
            'output': nn.Linear(256, 10)
        })

    def forward(self, x):
        # iterating an nn.ModuleDict directly yields keys (strings),
        # so iterate over .values() to get the layers themselves
        for layer in self.moduledict.values():
            x = layer(x)
        return x
```
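As the table above notes, ModuleList shines when forward() needs information from earlier layers. A minimal residual-style sketch (the width and depth are arbitrary illustrative choices, not a real ResNet):

```python
import torch
import torch.nn as nn

class TinyResidual(nn.Module):
    def __init__(self, width=64, depth=3):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(width, width) for _ in range(depth)])
        self.act = nn.ReLU()

    def forward(self, x):
        for layer in self.layers:
            x = self.act(layer(x)) + x   # skip connection: reuse the layer input
        return x

print(TinyResidual()(torch.rand(2, 64)).shape)   # torch.Size([2, 64])
```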
Building Complex Models
For large, complex models, it helps to split the model into blocks first and then assemble them. Take U-Net as an example.
The U-Net architecture can be divided into the following four blocks:
- the two convolutions inside each sub-block (double convolution)
- the downsampling connections between blocks on the left side, i.e. max pooling
- the upsampling connections between blocks on the right side
- the output layer
Building the blocks
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""
    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)

class Down(nn.Module):
    """Downscaling with maxpool then double conv"""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)

class Up(nn.Module):
    """Upscaling then double conv"""
    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is CHW
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)
```
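Before assembling the full network, each block can be sanity-checked on its own; a quick sketch with arbitrary sizes:

```python
block = Down(64, 128)
x = torch.rand(1, 64, 64, 64)
print(block(x).shape)   # torch.Size([1, 128, 32, 32]): spatial size halved, channels doubled
```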
Assembling the model
```python
class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=True):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits
```
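A quick sanity check of the assembled network, assuming a 1-channel input and 2 output classes (arbitrary illustrative sizes):

```python
unet = UNet(n_channels=1, n_classes=2)
x = torch.rand(1, 1, 256, 256)
print(unet(x).shape)   # torch.Size([1, 2, 256, 256])
```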
Modifying an Existing Model
Replacing a layer
```python
import torchvision.models as models

net = models.resnet50()
print(net)

# replace the fc layer
from collections import OrderedDict

classifier = nn.Sequential(OrderedDict([('fc1', nn.Linear(2048, 128)),
                                        ('relu1', nn.ReLU()),
                                        ('dropout1', nn.Dropout(0.5)),
                                        ('fc2', nn.Linear(128, 10)),
                                        ('output', nn.Softmax(dim=1))
                                        ]))
# assigning to net.fc replaces the original fc layer with the classifier defined above
net.fc = classifier
```
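A quick check that the replacement behaves as intended, using a dummy batch (224x224 is the usual input size for torchvision's ResNets):

```python
import torch

x = torch.rand(2, 3, 224, 224)
print(net(x).shape)   # torch.Size([2, 10]) -- the new classifier's output size
```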
Adding an extra input variable
```python
# define the modified model
class Model(nn.Module):
    def __init__(self, net):
        super(Model, self).__init__()
        # the original network, assumed to output a 1000-dim tensor
        # (e.g. an unmodified torchvision ResNet), hence 1000 + 1 = 1001 below
        self.net = net
        # activation applied to the backbone output
        self.relu = nn.ReLU()
        # dropout layer
        self.dropout = nn.Dropout(0.5)
        # fully connected layer mapping to the desired output dimension, 10
        self.fc_add = nn.Linear(1001, 10, bias=True)
        self.output = nn.Softmax(dim=1)

    def forward(self, x, add_variable):
        x = self.net(x)
        # concatenate the external variable after the activation and dropout layers;
        # unsqueeze(1) turns the (batch,) tensor into (batch, 1) so it matches the
        # net output for concatenation -- the usual case when add_variable is a
        # single scalar per sample
        x = torch.cat((self.dropout(self.relu(x)), add_variable.unsqueeze(1)), 1)
        x = self.fc_add(x)
        x = self.output(x)
        return x

# instantiate
model = Model(net).cuda()
# training / inference
outputs = model(inputs, add_var)
```
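A minimal sketch of the expected shapes, on CPU for simplicity (net is assumed to be an unmodified resnet50 outputting 1000 features):

```python
import torch
import torchvision.models as models

net = models.resnet50()
model = Model(net)
inputs = torch.rand(4, 3, 224, 224)
add_var = torch.rand(4)                 # one extra scalar per sample
print(model(inputs, add_var).shape)     # torch.Size([4, 10])
```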
Adding an extra output variable
```python
class Model(nn.Module):
    def __init__(self, net):
        super(Model, self).__init__()
        self.net = net
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
        self.fc1 = nn.Linear(1000, 10, bias=True)
        self.output = nn.Softmax(dim=1)

    def forward(self, x):
        x1000 = self.net(x)
        x10 = self.dropout(self.relu(x1000))
        x10 = self.fc1(x10)
        x10 = self.output(x10)
        return x10, x1000   # extra output: the intermediate 1000-dim features

model = Model(net).cuda()
out10, out1000 = model(inputs)
```
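Again a shape sketch on CPU (net assumed to be an unmodified resnet50):

```python
import torch
import torchvision.models as models

net = models.resnet50()
model = Model(net)
out10, out1000 = model(torch.rand(4, 3, 224, 224))
print(out10.shape, out1000.shape)   # torch.Size([4, 10]) torch.Size([4, 1000])
```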
Saving and Loading Models
PyTorch models are stored mainly in .pkl, .pt, or .pth format.
A PyTorch model consists of two parts: the model structure and the weights.
- structure: a class inheriting from nn.Module
- weights: a dictionary (keys are layer names, values are weight tensors)
Storage therefore comes in two forms:
- store the structure plus the weights (the whole model)
- store the weights only
```python
import torch
from torchvision import models

model = models.resnet152(pretrained=True)
save_dir = './resnet152.pth'   # any path of your choosing

# save the whole model
torch.save(model, save_dir)
# save the weights only (note: state_dict() is a method and must be called)
torch.save(model.state_dict(), save_dir)
```
When training with multiple GPUs in parallel, saving and loading each come in single-GPU and multi-GPU variants. Every key in a state dict saved from a DataParallel model carries an extra module. prefix, so loading a multi-GPU checkpoint takes a little more care.
Saving
Saving from a single GPU
```python
import os
import torch
from torchvision import models

os.environ['CUDA_VISIBLE_DEVICES'] = '0'   # replace with the GPU id you want to use
model = models.resnet152(pretrained=True)
model.cuda()

# save the whole model
torch.save(model, save_dir)
# save the weights only
torch.save(model.state_dict(), save_dir)
```
Saving from multiple GPUs
Wrap the model with nn.DataParallel for data-parallel training:
```python
import os
import torch
import torch.nn as nn
from torchvision import models

os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2'   # replace with the GPU ids you want to use
model = models.resnet152(pretrained=True)
model = nn.DataParallel(model).cuda()

# save the whole model
torch.save(model, save_dir)
# save the weights only
torch.save(model.state_dict(), save_dir)
```
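To see the module. prefix mentioned earlier, inspect the state dict keys of the wrapped model from the snippet above (requires at least one visible GPU):

```python
print(list(model.state_dict().keys())[0])          # 'module.conv1.weight'
print(list(model.module.state_dict().keys())[0])   # 'conv1.weight' (the unwrapped model)
```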
Loading
Loading on a single GPU
- Checkpoint saved from a single GPU
```python
import os
import torch
from torchvision import models

os.environ['CUDA_VISIBLE_DEVICES'] = '0'   # replace with the GPU id you want to use

# load the whole model
loaded_model = torch.load(save_dir)
loaded_model.cuda()

# load the weights only
loaded_dict = torch.load(save_dir)
loaded_model = models.resnet152()   # the model structure must be defined here
loaded_model.load_state_dict(loaded_dict)
loaded_model.cuda()
```
- Checkpoint saved from multiple GPUs
```python
import os
import torch
import torch.nn as nn
from torchvision import models

os.environ['CUDA_VISIBLE_DEVICES'] = '0'   # replace with the GPU id you want to use

# load the whole model
loaded_model = torch.load(save_dir)
loaded_model = loaded_model.module   # difference: unwrap the DataParallel wrapper

# load the weights only (recommended)
loaded_dict = torch.load(save_dir)
loaded_model = models.resnet152()    # the model structure must be defined here
loaded_model = nn.DataParallel(loaded_model).cuda()   # difference: wrap first so the keys match
loaded_model.load_state_dict(loaded_dict)

# load the weights only (alternative 1)
from collections import OrderedDict
loaded_dict = torch.load(save_dir)
# strip the module. prefix from every key
new_state_dict = OrderedDict()
for k, v in loaded_dict.items():
    name = k[7:]   # 'module.' is 7 characters, so k[7:] drops the prefix
    new_state_dict[name] = v
# from here on, identical to loading a single-GPU checkpoint
loaded_model = models.resnet152()
loaded_model.load_state_dict(new_state_dict)
loaded_model = loaded_model.cuda()

# load the weights only (alternative 2)
loaded_model = models.resnet152()
loaded_dict = torch.load(save_dir)
loaded_model.load_state_dict({
    k.replace('module.', ''): v for k, v in loaded_dict.items()})
loaded_model = loaded_model.cuda()
```
Loading on multiple GPUs
- Checkpoint saved from a single GPU
Wrap the loaded model with nn.DataParallel:
```python
import os
import torch
import torch.nn as nn
from torchvision import models

os.environ['CUDA_VISIBLE_DEVICES'] = '1,2'   # replace with the GPU ids you want to use

# load the whole model
loaded_model = torch.load(save_dir)
loaded_model = nn.DataParallel(loaded_model).cuda()   # difference: wrap with DataParallel

# load the weights only
loaded_dict = torch.load(save_dir)
loaded_model = models.resnet152()   # the model structure must be defined here
loaded_model.load_state_dict(loaded_dict)             # load before wrapping: no module. prefix yet
loaded_model = nn.DataParallel(loaded_model).cuda()   # difference: wrap with DataParallel
```
- Checkpoint saved from multiple GPUs
Saving only the weights is recommended; loading them is then no different from the cases above. If only the whole model was saved, the following code is needed:
```python
import os
import torch
import torch.nn as nn
from torchvision import models

os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2'   # replace with the GPU ids you want to use

loaded_whole_model = torch.load(save_dir)
loaded_model = models.resnet152()   # the model structure must be defined here
loaded_model.load_state_dict(loaded_whole_model.module.state_dict())   # unwrap, then copy the weights
loaded_model = nn.DataParallel(loaded_model).cuda()
```
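A related trick worth knowing (standard torch.load behaviour, not part of the workflow above): the map_location argument remaps a GPU-saved checkpoint onto the CPU or another device at load time. A sketch for a weights-only multi-GPU checkpoint:

```python
loaded_dict = torch.load(save_dir, map_location='cpu')   # remap all storages to CPU
loaded_model = models.resnet152()
loaded_model.load_state_dict(
    {k.replace('module.', ''): v for k, v in loaded_dict.items()})
```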