1.4 nn. Module neural network (II)
2022-07-29 03:22:00 【smiling0927】
Loss functions:
1. nn.MSELoss is used to compute the mean squared error, and nn.CrossEntropyLoss is used to compute the cross-entropy loss.
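Before the fuller example below, here is a minimal sketch of how these two criteria are typically called; the tensor shapes and values are illustrative assumptions, not part of the example that follows.
import torch as t
import torch.nn as nn

# MSELoss expects prediction and target with the same shape
mse = nn.MSELoss()
pred = t.randn(4, 10)                    # a batch of 4 predictions (assumed shape)
target = t.randn(4, 10)
print(mse(pred, target))                 # scalar loss tensor

# CrossEntropyLoss expects raw scores of shape (N, C) and class indices of shape (N,)
ce = nn.CrossEntropyLoss()
scores = t.randn(4, 10)                  # 4 samples, 10 classes
labels = t.LongTensor([1, 0, 3, 7])      # class indices in [0, 10)
print(ce(scores, labels))                # scalar loss tensor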
Example:
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch as t

class Net(nn.Module):
    def __init__(self):
        # A subclass of nn.Module must call the parent constructor in its own constructor.
        # The line below is equivalent to nn.Module.__init__(self)
        super(Net, self).__init__()
        # Convolution layer: '1' means the input image has a single channel,
        # '6' is the number of output channels, '5' means a 5*5 convolution kernel
        self.conv1 = nn.Conv2d(1, 6, 5)
        # Convolution layer
        self.conv2 = nn.Conv2d(6, 16, 5)
        # Affine (fully connected) layers, y = wx + b
        # A linear layer is generally defined as nn.Linear(in_features, out_features)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        # Convolution --> activation --> pooling
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # reshape, where '-1' means the size is inferred automatically.
        # This line usually appears in a model's forward function, right before the classifier.
        # The classifier is a simple nn.Linear() whose input and output are one-dimensional,
        # so x = x.view(x.size(0), -1) flattens the preceding multi-dimensional tensor into one dimension.
        # view() is similar to reshape: it changes the tensor's size. In x = x.view(batchsize, -1),
        # batchsize is the number of rows after the conversion, and -1 lets the function infer
        # the number of columns from the original tensor's data and the batchsize.
        x = x.view(x.size()[0], -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()
input = Variable(t.randn(1, 1, 32, 32))
out = net(input)
print(out)
net.zero_grad()
out.backward(t.randn(1, 10))

output = net(input)
target = Variable(t.randn(1, 10))  # hypothetical target (the book uses 1,2,...,10; see the note below)
criterion = nn.MSELoss()
loss = criterion(output, target)
print(loss)
out:
tensor([[-0.1351, -0.0720, -0.0460, 0.1184, -0.0810, 0.0441, -0.0008,
-0.0082, -0.1029, 0.0620]])
tensor(2.6400)

Note: if target = Variable(t.arange(1,11)) is used instead (the hypothetical target 1,2,3,4,5,6,7,8,9,10), the target is tensor([ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]), which is a vector, while output is a 1*10 matrix. The shapes therefore do not match and the following error is raised:
RuntimeError: input and target shapes do not match: input [1 x 10], target [10] at /opt/conda/conda-bld/pytorch_1524584710464/work/aten/src/THNN/generic/MSECriterion.c:13
The book and some blogs do show normal output for this snippet, so the point of difference (most likely how that PyTorch version handles the shape mismatch) still needs to be investigated.
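A minimal sketch of one way to avoid the mismatch, assuming the intent is still the hypothetical target 1..10: reshape the target to the same 1*10 shape as output (and make it floating point) before passing it to the criterion.
# Assumption: make the hypothetical target 1..10 match output's shape [1, 10]
target = Variable(t.arange(1, 11).float().view(1, -1))
criterion = nn.MSELoss()
loss = criterion(output, target)
print(loss)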
2. If you backpropagate from loss, you can inspect its computation graph through the grad_fn attribute. When loss.backward() is called, the graph is dynamically generated and automatically differentiated, and the gradients of the parameters in the computation graph are computed as well.
print(loss)
# Call .backward() and observe conv1.bias.grad before and after the call
print(loss.grad_fn)
print(loss.grad_fn.next_functions)
print(loss.grad_fn.next_functions[0][0].next_functions)
net.zero_grad()  # zero the gradients of all parameters
print('conv1.bias gradient before backpropagation')
print(net.conv1.bias.grad)
loss.backward()
print('conv1.bias gradient after backpropagation')
print(net.conv1.bias.grad)

out:
<MseLossBackward object at 0x7f80b5af9320>
((<AddmmBackward object at 0x7f80b5728cc0>, 0),)
((<ExpandBackward object at 0x7f80b5728cc0>, 0), (<ReluBackward object at 0x7f80b57009b0>, 0), (<TBackward object at 0x7f80f0ac7198>, 0))
conv1.bias gradient before backpropagation
tensor([ 0., 0., 0., 0., 0., 0.])
conv1.bias gradient after backpropagation
tensor(1.00000e-02 *
[-0.2564, 0.5056, 0.5424, 0.2758, -1.5066, -0.0828])
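As the net.zero_grad() call above suggests, gradients accumulate across backward passes, so they are normally cleared before each one. A minimal sketch, reusing net, input, target and criterion from above (the loop is only for illustration):
# Gradients accumulate: without zero_grad(), a second backward pass adds to the
# gradients left over from the first one.
for i in range(2):
    net.zero_grad()                  # comment this out and conv1.bias.grad roughly doubles
    output = net(input)              # rebuild the graph for each backward pass
    loss = criterion(output, target)
    loss.backward()
    print(net.conv1.bias.grad)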