PyTorch Learning Notes -- Summary of Common Functions 2
2022-07-25 15:41:00 【whut_ L】
Catalog
1-- torch.nn.Sequential() function
2-- torch.flatten() function
3-- model.train() and model.eval() functions
1-- torch.nn.Sequential() function
Purpose: build a neural network by chaining modules together in sequence with Sequential.
Example: when Block is called, the modules listed inside the parentheses are executed in order.
import torch
x = torch.randn(1, 1, 64, 64)
Block = torch.nn.Sequential(
torch.nn.Conv2d(in_channels = 1, out_channels = 32, kernel_size = 3, stride = 1),
torch.nn.ReLU(),
torch.nn.Conv2d(in_channels = 32, out_channels = 32, kernel_size = 3, stride = 1),
torch.nn.ReLU()
)
result = Block(x)
print(result)

The above code is equivalent to:
import torch
Conv1 = torch.nn.Conv2d(in_channels = 1, out_channels = 32, kernel_size = 3, stride = 1)
Conv2 = torch.nn.Conv2d(in_channels = 32, out_channels = 32, kernel_size = 3, stride = 1)
Relu = torch.nn.ReLU()
x = torch.randn(1, 1, 64, 64)
print(x)
x = Conv1(x)
x = Relu(x)
x = Conv2(x)
result = Relu(x)
print(result)
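As a small additional sketch (not from the original post): Sequential also accepts an OrderedDict, so submodules get readable names instead of numeric indices.

```python
import torch
from collections import OrderedDict

# Same two-conv block as above, but with named submodules.
block = torch.nn.Sequential(OrderedDict([
    ("conv1", torch.nn.Conv2d(1, 32, kernel_size=3, stride=1)),
    ("relu1", torch.nn.ReLU()),
    ("conv2", torch.nn.Conv2d(32, 32, kernel_size=3, stride=1)),
    ("relu2", torch.nn.ReLU()),
]))

x = torch.randn(1, 1, 64, 64)
y = block(x)
print(y.shape)      # two 3x3 convs without padding: 64 -> 62 -> 60
print(block.conv1)  # submodules are reachable by name
```

Named submodules make it easier to inspect or replace individual layers later.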
2-- torch.flatten() function
Purpose: flattens a range of dimensions into a single dimension. Signature: torch.flatten(input, start_dim=0, end_dim=-1), where input is the input tensor, start_dim is the first dimension to flatten, and end_dim is the last dimension to flatten (inclusive).
Example:
import torch
x = torch.randn(2, 3, 3)
print(x)
result1 = torch.flatten(x, start_dim = 0, end_dim = 2)
print(result1)
result2 = torch.flatten(x, start_dim = 0, end_dim = 1)
print(result2)
result3 = torch.flatten(x, start_dim = 1, end_dim = 2)
print(result3)

Result:
tensor([[[ 0.3546, -0.8551, 2.3490],
[-0.0920, 0.0773, -0.4556],
[-1.6943, 1.4517, -0.0767]],
[[-0.6950, 0.4382, -1.2691],
[-0.0252, -0.4980, -0.5994],
[-0.2581, -0.2544, -0.6787]]]) #X
tensor([ 0.3546, -0.8551, 2.3490, -0.0920, 0.0773, -0.4556, -1.6943, 1.4517,
-0.0767, -0.6950, 0.4382, -1.2691, -0.0252, -0.4980, -0.5994, -0.2581,
-0.2544, -0.6787]) #result1
tensor([[ 0.3546, -0.8551, 2.3490],
[-0.0920, 0.0773, -0.4556],
[-1.6943, 1.4517, -0.0767],
[-0.6950, 0.4382, -1.2691],
[-0.0252, -0.4980, -0.5994],
[-0.2581, -0.2544, -0.6787]]) #result2
tensor([[ 0.3546, -0.8551, 2.3490, -0.0920, 0.0773, -0.4556, -1.6943, 1.4517,
-0.0767],
[-0.6950, 0.4382, -1.2691, -0.0252, -0.4980, -0.5994, -0.2581, -0.2544,
        -0.6787]]) #result3
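A brief sketch of the defaults (using the same 2x3x3 shape as above): with start_dim=0 and end_dim=-1, torch.flatten collapses every dimension, matching result1; start_dim=1 matches result3 and is the usual pattern between a conv stack and a fully connected layer. The nn.Flatten module is the Sequential-friendly form.

```python
import torch

x = torch.randn(2, 3, 3)

# Defaults (start_dim=0, end_dim=-1) flatten everything into one dimension.
print(torch.flatten(x).shape)                # torch.Size([18])

# Keep the batch dimension, flatten the rest (same as result3 above).
print(torch.flatten(x, start_dim=1).shape)   # torch.Size([2, 9])

# Module form; its default start_dim is 1, so it preserves the batch dim.
flat = torch.nn.Flatten()
print(flat(x).shape)                         # torch.Size([2, 9])
```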
3-- model.train() and model.eval() functions
model.train() puts the model in training mode; model.eval() puts the model in evaluation (test) mode.
Some layers behave differently in the two modes, e.g. Batch Normalization and Dropout: both are active in training mode and disabled in evaluation mode.
If the model contains BN (Batch Normalization) or Dropout layers, call model.train() before training. model.train() ensures that BN layers use the mean and variance of each batch, and that Dropout randomly keeps only part of the network connections while training and updating parameters.
Likewise, call model.eval() before testing. model.eval() ensures that BN layers use the running mean and variance accumulated over the training data, so the BN statistics stay fixed during testing. For Dropout, model.eval() uses all network connections, i.e. no neurons are randomly dropped.
After training on the training set, the resulting model is applied to test samples. Call model.eval() before running model(test); otherwise, merely feeding input through the model would update the BN running statistics even though no training is taking place. This is a consequence of the model containing BN and Dropout layers.
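The behavior described above can be sketched with a minimal model containing BN and Dropout (a hypothetical example network, not from the original post):

```python
import torch

# Tiny model whose behavior differs between train() and eval() modes.
model = torch.nn.Sequential(
    torch.nn.Linear(4, 4),
    torch.nn.BatchNorm1d(4),
    torch.nn.Dropout(p=0.5),
)

x = torch.randn(8, 4)

model.train()    # BN uses batch statistics; Dropout randomly zeroes activations
y_train = model(x)

model.eval()     # BN uses running statistics; Dropout acts as identity
with torch.no_grad():          # also skip autograd bookkeeping at inference
    y_eval1 = model(x)
    y_eval2 = model(x)

# In eval mode the forward pass is deterministic: two runs give the same output.
print(torch.equal(y_eval1, y_eval2))  # True
print(model.training)                 # False
```

Note that model.eval() only switches layer modes; wrapping inference in torch.no_grad() is a separate, complementary step that disables gradient tracking.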