当前位置:网站首页>20.nn.Module
20.nn.Module
2022-07-31 13:13:00 【派大星的最爱海绵宝宝】
目录
所有网络层次的父类,当我们实现自己的层时必须继承这个类。
pytorch官方已经写好这个类了。
在初始化init中完成自己要定义的逻辑,并且在forward中完成一个计算图的前向构建的过程。
使用
nn.Linear
nn.BatchNorm2d
nn.Conv2d
这些类都继承自nn.Module
多层嵌套使用nn.module
使用nn.module好处
1.基本层
在nn.module中提供了大量的基本层
linear
relu
sigmod
conv2d
convtranspose2d
dropout
etc.
直接调用初始化函数,在调用.方法来调用forward函数,就可以使用这个层的功能
2.container
最核心的功能。也就是nn.sequential功能
使用这个容器,在这里不仅可以调用自己写的函数,也可以调用pytorch自带的。使用self.net可以把里面的函数依次执行。
self.net=nn.Sequential(
nn.Conv2d(1,32,5,1,1),
nn.MaxPool2d(2,2),
nn.RelU(True),
nn.BatchNorm2d(32),
nn.Conv2d(32,64,3,1,1),
nn.RelU(True),
nn.BatchNorm2d(128)
)
3.parameters
使用nn.module,会对网络的参数进行有效的管理,不需要额外管理参数。
net=nn.Sequential(nn.Linear(8,6),nn.Linear(4,8))
print(list(net.parameters())[0].shape)
print(list(net.parameters())[1].shape)
print(list(net.parameters())[2].shape)
print(list(net.parameters())[3].shape)
print(list(net.named_parameters())[0])
print(list(net.named_parameters())[1])
print(dict(net.named_parameters()).items())
[0]是weight,[1]是bias。对于w来说,输入的时候顺序是channel-out,channel-in。
两种形式,一个是不带名字。还有一个pytorch自带生成的名字,第一个是名字,第二个是类型。
4.modules
对内部的module使用很方便
modules是每个module直接子节点和子节点的所有子节点,即所有节点。
children是指直接子节点。
basicnet有两个参数,relu没有参数,linear两个参数。
child0是net自己本身。child1就是sequential。
5.to(device)
把一个类的所有内部tensor和操作转移到gpu或者cpu。
device=torch.device('cuda')
net=Net()
net.to(device)
cuda是cpu,net.to()可以转移到c或者gpu上,会返回一个net引用,这个net引用跟原来的net引用(net=Net())一模一样。
a.to返回的是a-gpu,这两个a是不一样的,前面的a是cpu上的reference,后面的a是gpu上的。
6.save and load
check pond就是网络的一个中间状态,每隔一段时间就会保存。
net.load_state_dict(torch.load('ckpt.mdl'))
# train...
torch.save(net.state_dict(),'ckpt.dml')
使用当前的类(net)提供的state_dict(),即把当前所有的状态返回,再使用torch.save方法把它保存到文件(copy.mdl)中。这是save操作。
在模型开始的时候,检查一下有没有checkpond。如果有就加载,首先在文件中使用load方法把它加载为pytorch类,再使用load_state_dict类,把这个数据加载到module,把网络中的参数的值都初始化为我们train好的值。这样就不需要重新初始化并且从0开始train。
7.train/test
一键进行train和test状态的切换。
# train...
net.train()
# test
net.eval()
8.执行自己的类
nn.relu是一个class,F.relu是一个functional,只有class才能写到nn.sequential中。
class Flatten(nn.Module):
def __init__(self):
super(Flatten, self).__init__()
def forward(self,input):
return input.view(input.size[0],-1)
class TestNet(nn.Module):
def __init__(self):
super(TestNet, self).__init__()
self.net=nn.Sequential(
nn.Conv2d(1,16,stride=1,padding=1),
nn.MaxPool2d(2,2),
Flatten(),
nn.Linear(1*14*14,10)
)
def forward(self,x):
return self.net(x)
class MyLinear(nn.Module):
def __init__(self,outp,inp):
super(MyLinear, self).__init__()
# requires_grad=True
self.w=nn.Parameter(torch.randn(outp,inp))
self.b=nn.Parameter(torch.randn(outp))
def forward(self,x):
x = [email protected].w.t()+self.b
return x
我们需要打平,在flatten类中的,view,语言把第一个纬度保留,-1表示所有纬度的合在一起。这个类使用很广泛。
nn.Parameter可作为一个安装器可以把你的tensor做一个包装,只要所有的tensor经过它包装后,tensor就会自动加到模型的nn.parameter()方法中,可以自动的被SGD优化。如果不使用nn.Parameter,也需要设置grad。
nn.Parameter自动设置你的tensor需要梯度。
forward中一定要返回。
边栏推荐
- Productivity Tools and Plugins
- Two methods of NameNode failure handling
- IDEA can't find the Database solution
- 文本相似度计算(中英文)详解实战
- Grab the tail of gold, silver and silver, unlock the programmer interview "Artifact of Brushing Questions"
- golang-gin-优雅重启
- ECCV2022:在Transformer上进行递归,不增参数,计算量还少!
- 中望3D 2023正式发布,设计仿真制造一体化缩短产品开发周期
- Network layer key protocol - IP protocol
- Optimization of five data submission methods
猜你喜欢
ERROR 2003 (HY000) Can‘t connect to MySQL server on ‘localhost3306‘ (10061)解决办法
Ali on three sides: MQ message loss, repetition, backlog problem, how to solve?
抓住金三银四的尾巴,解锁程序员面试《刷题神器》
ECCV 2022 | 机器人的交互感知与物体操作
C# using NumericUpDown control
Detailed explanation of network protocols and related technologies
Spark学习:为Spark Sql添加自定义优化规则
NameNode (NN) 和SecondaryNameNode (2NN)工作机制
Network layer key protocol - IP protocol
Reasons and solutions for Invalid bound statement (not found)
随机推荐
C#使用NumericUpDown控件
FastAPI 封装一个通用的response
PyQt5快速开发与实战10.2 复利计算 && 10.3 刷新博客点击量
matlab as(assert dominance)
网络层重点协议——IP协议
How IDEA runs web programs
C#控件ListView用法
FastAPI encapsulates a generic response
Using SQL Server FOR XML and FOR JSON syntax on other RDBMSs with jOOQ
抓住金三银四的尾巴,解锁程序员面试《刷题神器》
基本语法(一)
Golang - gin - pprof - use and safety
Spark学习:为Spark Sql添加自定义优化规则
golang-gin - graceful restart
JSP中如何借助response对象实现页面跳转呢?
网络协议及相关技术详解
centos7安装mysql5.7
C# using NumericUpDown control
图像大面积缺失,也能逼真修复,新模型CM-GAN兼顾全局结构和纹理细节
Flutter keyboard visibility