当前位置:网站首页>20.nn.Module
20.nn.Module
2022-07-31 13:13:00 【派大星的最爱海绵宝宝】
目录
所有网络层次的父类,当我们实现自己的层时必须继承这个类。
pytorch官方已经写好这个类了。
在初始化init中完成自己要定义的逻辑,并且在forward中完成一个计算图的前向构建的过程。
使用
nn.Linear
nn.BatchNorm2d
nn.Conv2d
这些类都继承自nn.Module
多层嵌套使用nn.module
使用nn.module好处
1.基本层
在nn.module中提供了大量的基本层
linear
relu
sigmod
conv2d
convtranspose2d
dropout
etc.
直接调用初始化函数,在调用.方法来调用forward函数,就可以使用这个层的功能
2.container
最核心的功能。也就是nn.sequential功能
使用这个容器,在这里不仅可以调用自己写的函数,也可以调用pytorch自带的。使用self.net可以把里面的函数依次执行。
self.net=nn.Sequential(
nn.Conv2d(1,32,5,1,1),
nn.MaxPool2d(2,2),
nn.RelU(True),
nn.BatchNorm2d(32),
nn.Conv2d(32,64,3,1,1),
nn.RelU(True),
nn.BatchNorm2d(128)
)
3.parameters
使用nn.module,会对网络的参数进行有效的管理,不需要额外管理参数。
net=nn.Sequential(nn.Linear(8,6),nn.Linear(4,8))
print(list(net.parameters())[0].shape)
print(list(net.parameters())[1].shape)
print(list(net.parameters())[2].shape)
print(list(net.parameters())[3].shape)
print(list(net.named_parameters())[0])
print(list(net.named_parameters())[1])
print(dict(net.named_parameters()).items())
[0]是weight,[1]是bias。对于w来说,输入的时候顺序是channel-out,channel-in。
两种形式,一个是不带名字。还有一个pytorch自带生成的名字,第一个是名字,第二个是类型。
4.modules
对内部的module使用很方便
modules是每个module直接子节点和子节点的所有子节点,即所有节点。
children是指直接子节点。
basicnet有两个参数,relu没有参数,linear两个参数。
child0是net自己本身。child1就是sequential。
5.to(device)
把一个类的所有内部tensor和操作转移到gpu或者cpu。
device=torch.device('cuda')
net=Net()
net.to(device)
cuda是cpu,net.to()可以转移到c或者gpu上,会返回一个net引用,这个net引用跟原来的net引用(net=Net())一模一样。
a.to返回的是a-gpu,这两个a是不一样的,前面的a是cpu上的reference,后面的a是gpu上的。
6.save and load
check pond就是网络的一个中间状态,每隔一段时间就会保存。
net.load_state_dict(torch.load('ckpt.mdl'))
# train...
torch.save(net.state_dict(),'ckpt.dml')
使用当前的类(net)提供的state_dict(),即把当前所有的状态返回,再使用torch.save方法把它保存到文件(copy.mdl)中。这是save操作。
在模型开始的时候,检查一下有没有checkpond。如果有就加载,首先在文件中使用load方法把它加载为pytorch类,再使用load_state_dict类,把这个数据加载到module,把网络中的参数的值都初始化为我们train好的值。这样就不需要重新初始化并且从0开始train。
7.train/test
一键进行train和test状态的切换。
# train...
net.train()
# test
net.eval()
8.执行自己的类
nn.relu是一个class,F.relu是一个functional,只有class才能写到nn.sequential中。
class Flatten(nn.Module):
def __init__(self):
super(Flatten, self).__init__()
def forward(self,input):
return input.view(input.size[0],-1)
class TestNet(nn.Module):
def __init__(self):
super(TestNet, self).__init__()
self.net=nn.Sequential(
nn.Conv2d(1,16,stride=1,padding=1),
nn.MaxPool2d(2,2),
Flatten(),
nn.Linear(1*14*14,10)
)
def forward(self,x):
return self.net(x)
class MyLinear(nn.Module):
def __init__(self,outp,inp):
super(MyLinear, self).__init__()
# requires_grad=True
self.w=nn.Parameter(torch.randn(outp,inp))
self.b=nn.Parameter(torch.randn(outp))
def forward(self,x):
x = [email protected].w.t()+self.b
return x
我们需要打平,在flatten类中的,view,语言把第一个纬度保留,-1表示所有纬度的合在一起。这个类使用很广泛。
nn.Parameter可作为一个安装器可以把你的tensor做一个包装,只要所有的tensor经过它包装后,tensor就会自动加到模型的nn.parameter()方法中,可以自动的被SGD优化。如果不使用nn.Parameter,也需要设置grad。
nn.Parameter自动设置你的tensor需要梯度。
forward中一定要返回。
边栏推荐
- Usage of += in C#
- Introduction to using NPM
- TensorRT安装及使用教程「建议收藏」
- C# control ListView usage
- Selenium IDE for Selenium Automation Testing
- IDEA找不到Database解决方法
- networkx绘制度分布
- Four ways to clear the float and its principle understanding
- PHP Serialization: eval
- PyQt5 rapid development and actual combat 9.7 Automated testing of UI layer
猜你喜欢

centos7安装mysql5.7步骤(图解版)

爱可可AI前沿推介(7.31)
![LRU缓存[线性表 -> 链表 -> hash定位 -> 双向链表]](/img/ad/dd80541514d6fedde8c730218fdf5a.png)
LRU缓存[线性表 -> 链表 -> hash定位 -> 双向链表]

365天挑战LeetCode1000题——Day 044 最大层内元素和 层次遍历

The operator,

networkx绘制度分布

Spark学习:为Spark Sql添加自定义优化规则

Introduction to using NPM

Install the latest pytorch gpu version

ECCV2022: Recursion on Transformer without adding parameters and less computation!
随机推荐
IDEA的database使用教程(使用mysql数据库)
使用openssl命令生成证书和对应的私钥,私钥签名,公钥验签
How to quickly split and merge cell data in Excel
Edge Cloud Explained in Simple Depth | 4. Lifecycle Management
MATLAB | 我也做了一套绘图配色可视化模板
阿里三面:MQ 消息丢失、重复、积压问题,怎么解决?
生产力工具和插件
求一份常见Oracle故障模拟场景
Using SQL Server FOR XML and FOR JSON syntax on other RDBMSs with jOOQ
Spark学习:为Spark Sql添加自定义优化规则
Introduction to the PartImageNet Semantic Part Segmentation dataset
golang八股文整理(持续搬运)
golang-gin-优雅重启
C#Assembly的使用
PyQt5 rapid development and actual combat 9.7 Automated testing of UI layer
ERROR 2003 (HY000) Can‘t connect to MySQL server on ‘localhost3306‘ (10061)解决办法
ICML2022 | 面向自监督图表示学习的全粒度自语义传播
alert(1) (haozi.me)靶场练习
CentOS7 —— yum安装mysql
CentOS7 安装MySQL 图文详细教程