当前位置:网站首页>20.nn.Module
20.nn.Module
2022-07-31 13:13:00 【派大星的最爱海绵宝宝】
目录
所有网络层次的父类,当我们实现自己的层时必须继承这个类。
pytorch官方已经写好这个类了。
在初始化init中完成自己要定义的逻辑,并且在forward中完成一个计算图的前向构建的过程。
使用
nn.Linear
nn.BatchNorm2d
nn.Conv2d
这些类都继承自nn.Module
多层嵌套使用nn.module
使用nn.module好处
1.基本层
在nn.module中提供了大量的基本层
linear
relu
sigmod
conv2d
convtranspose2d
dropout
etc.
直接调用初始化函数,在调用.方法来调用forward函数,就可以使用这个层的功能
2.container
最核心的功能。也就是nn.sequential功能
使用这个容器,在这里不仅可以调用自己写的函数,也可以调用pytorch自带的。使用self.net可以把里面的函数依次执行。
self.net=nn.Sequential(
nn.Conv2d(1,32,5,1,1),
nn.MaxPool2d(2,2),
nn.RelU(True),
nn.BatchNorm2d(32),
nn.Conv2d(32,64,3,1,1),
nn.RelU(True),
nn.BatchNorm2d(128)
)
3.parameters
使用nn.module,会对网络的参数进行有效的管理,不需要额外管理参数。
net=nn.Sequential(nn.Linear(8,6),nn.Linear(4,8))
print(list(net.parameters())[0].shape)
print(list(net.parameters())[1].shape)
print(list(net.parameters())[2].shape)
print(list(net.parameters())[3].shape)
print(list(net.named_parameters())[0])
print(list(net.named_parameters())[1])
print(dict(net.named_parameters()).items())
[0]是weight,[1]是bias。对于w来说,输入的时候顺序是channel-out,channel-in。
两种形式,一个是不带名字。还有一个pytorch自带生成的名字,第一个是名字,第二个是类型。
4.modules
对内部的module使用很方便
modules是每个module直接子节点和子节点的所有子节点,即所有节点。
children是指直接子节点。
basicnet有两个参数,relu没有参数,linear两个参数。
child0是net自己本身。child1就是sequential。
5.to(device)
把一个类的所有内部tensor和操作转移到gpu或者cpu。
device=torch.device('cuda')
net=Net()
net.to(device)
cuda是cpu,net.to()可以转移到c或者gpu上,会返回一个net引用,这个net引用跟原来的net引用(net=Net())一模一样。
a.to返回的是a-gpu,这两个a是不一样的,前面的a是cpu上的reference,后面的a是gpu上的。
6.save and load
check pond就是网络的一个中间状态,每隔一段时间就会保存。
net.load_state_dict(torch.load('ckpt.mdl'))
# train...
torch.save(net.state_dict(),'ckpt.dml')
使用当前的类(net)提供的state_dict(),即把当前所有的状态返回,再使用torch.save方法把它保存到文件(copy.mdl)中。这是save操作。
在模型开始的时候,检查一下有没有checkpond。如果有就加载,首先在文件中使用load方法把它加载为pytorch类,再使用load_state_dict类,把这个数据加载到module,把网络中的参数的值都初始化为我们train好的值。这样就不需要重新初始化并且从0开始train。
7.train/test
一键进行train和test状态的切换。
# train...
net.train()
# test
net.eval()
8.执行自己的类
nn.relu是一个class,F.relu是一个functional,只有class才能写到nn.sequential中。
class Flatten(nn.Module):
def __init__(self):
super(Flatten, self).__init__()
def forward(self,input):
return input.view(input.size[0],-1)
class TestNet(nn.Module):
def __init__(self):
super(TestNet, self).__init__()
self.net=nn.Sequential(
nn.Conv2d(1,16,stride=1,padding=1),
nn.MaxPool2d(2,2),
Flatten(),
nn.Linear(1*14*14,10)
)
def forward(self,x):
return self.net(x)
class MyLinear(nn.Module):
def __init__(self,outp,inp):
super(MyLinear, self).__init__()
# requires_grad=True
self.w=nn.Parameter(torch.randn(outp,inp))
self.b=nn.Parameter(torch.randn(outp))
def forward(self,x):
x = [email protected].w.t()+self.b
return x
我们需要打平,在flatten类中的,view,语言把第一个纬度保留,-1表示所有纬度的合在一起。这个类使用很广泛。
nn.Parameter可作为一个安装器可以把你的tensor做一个包装,只要所有的tensor经过它包装后,tensor就会自动加到模型的nn.parameter()方法中,可以自动的被SGD优化。如果不使用nn.Parameter,也需要设置grad。
nn.Parameter自动设置你的tensor需要梯度。
forward中一定要返回。
边栏推荐
- 尚硅谷–MySQL–基础篇(P1~P95)
- SAP 电商云 Spartacus UI 和 Accelerator UI 里的 ASM 模块
- Spark学习:为Spark Sql添加自定义优化规则
- IDEA版Postman插件Restful Fast Request,细节到位,功能好用
- sqlalchemy determines whether a field of type array has at least one consistent data with an array
- C# 中的Async 和 Await 的用法详解
- Solution for browser hijacking by hao360
- 查看Oracle数据库的用户名和密码
- C# control ListView usage
- MATLAB | 我也做了一套绘图配色可视化模板
猜你喜欢

Introduction to using NPM

Anaconda安装labelImg图像标注软件

报错IDEA Terminated with exit code 1

PHP Serialization: eval

PyQt5 rapid development and actual combat 10.2 compound interest calculation && 10.3 refresh blog clicks

C# using NumericUpDown control

Introduction to the PartImageNet Semantic Part Segmentation dataset

C#使用ComboBox控件

NameNode (NN) and SecondaryNameNode (2NN) working mechanism

ERROR 2003 (HY000) Can‘t connect to MySQL server on ‘localhost3306‘ (10061)
随机推荐
EXCEL如何快速拆分合并单元格数据
SAP message TK 248 solved
The use of C# control CheckBox
matlab as(assert dominance)
[CPU Design Practice] Simple Pipeline CPU Design
C#使用ComboBox控件
Ali on three sides: MQ message loss, repetition, backlog problem, how to solve?
ECCV2022: Recursion on Transformer without adding parameters and less computation!
ASM module in SAP Ecommerce Cloud Spartacus UI and Accelerator UI
C#控件StatusStrip使用
Solution for browser hijacking by hao360
CentOS7 installation MySQL graphic detailed tutorial
Six Stones Programming: No matter which function you think is useless, people who can use it will not be able to leave, so at least 99%
基于模糊预测与扩展卡尔曼滤波的野值剔除方法
IDEA连接MySQL数据库并执行SQL查询操作
Verilog——基于FPGA的贪吃蛇游戏(VGA显示)
PyQt5 rapid development and actual combat 9.7 Automated testing of UI layer
/run/NetworkManager占用空间过大
The latest complete code: Incremental training using the word2vec pre-training model (two loading methods corresponding to two saving methods) applicable to various versions of gensim
go中select语句