Hands-On Deep Learning: Dropout and Its Code Implementation
2022-06-12 08:14:00 【Orange acridine 21】
1. Dropout
Motivation: a good model should be robust to perturbations of its input data.
Training on noise-injected data is equivalent to Tikhonov regularization.
Dropout takes this idea one step further: instead of adding noise to the input, it injects noise between layers.
Dropout is typically applied to the outputs of hidden fully connected layers.
The dropout probability is a hyperparameter that controls model complexity.
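Concretely, with dropout probability p, each intermediate activation h is replaced by a random variable h' (this is the standard "inverted dropout" formulation, which the code below implements):

$$ h' = \begin{cases} 0 & \text{with probability } p \\ \dfrac{h}{1-p} & \text{otherwise} \end{cases} \qquad\Rightarrow\qquad \mathbb{E}[h'] = h $$

Rescaling the surviving activations by 1/(1-p) keeps the expected value unchanged, so no extra correction is needed at test time.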
2. Implementing Dropout from Scratch
from matplotlib import pyplot as plt
import torch
import torch.nn as nn
import numpy as np
import sys
sys.path.append("..")
import d2lzh_pytorch as d2l

def dropout(X, drop_prob):
    X = X.float()
    assert 0 <= drop_prob <= 1
    keep_prob = 1 - drop_prob
    # when keep_prob == 0, every element is dropped
    if keep_prob == 0:
        return torch.zeros_like(X)
    # Bernoulli mask: torch.rand draws uniformly from [0, 1), so each
    # element is kept with probability keep_prob
    mask = (torch.rand(X.shape) < keep_prob).float()
    # rescale the survivors by 1/keep_prob so the expectation is unchanged
    return mask * X / keep_prob

X = torch.arange(16).view(2, 8)
print(dropout(X, 0))    # nothing dropped: returns X unchanged
print(dropout(X, 0.5))  # roughly half the entries zeroed, the rest doubled
print(dropout(X, 1.0))  # everything dropped: all zeros
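As a quick sanity check (my addition, not part of the original code), the sample mean of many dropout draws should be close to X, confirming that the rescaling keeps the estimator unbiased:

# hypothetical check: average 10,000 dropout samples of X
torch.manual_seed(0)
avg = sum(dropout(X, 0.5) for _ in range(10000)) / 10000
print(avg)  # each entry should be close to the corresponding entry of X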
"""
Define model parameters
"""
num_inputs, num_outputs, num_hiddens1, num_hiddens2 = 784, 10, 256, 256
W1 = torch.tensor(np.random.normal(0, 0.01, size=(num_inputs, num_hiddens1)), dtype=torch.float, requires_grad=True)
b1 = torch.zeros(num_hiddens1, requires_grad=True, dtype=torch.float)
W2 = torch.tensor(np.random.normal(0, 0.01, size=(num_hiddens1, num_hiddens2)), dtype=torch.float, requires_grad=True)
b2 = torch.zeros(num_hiddens2, requires_grad=True, dtype=torch.float)
W3 = torch.tensor(np.random.normal(0, 0.01, size=(num_hiddens2, num_outputs)), dtype=torch.float, requires_grad=True)
b3 = torch.zeros(num_outputs, requires_grad=True, dtype=torch.float)
params = [W1, b1, W2, b2, W3, b3]
"""
Defining models
"""
drop_prob1, drop_prob2 = 0.2, 0.5

def net(X, is_training=True):
    X = X.view(-1, num_inputs)
    H1 = (torch.matmul(X, W1) + b1).relu()
    if is_training:  # apply dropout only during training
        H1 = dropout(H1, drop_prob1)  # dropout after the first fully connected layer
    H2 = (torch.matmul(H1, W2) + b2).relu()
    if is_training:
        H2 = dropout(H2, drop_prob2)  # dropout after the second fully connected layer
    return torch.matmul(H2, W3) + b3
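A quick shape check (my addition, with a made-up random batch) shows how the is_training flag is used; the book's evaluation helper detects this flag and passes is_training=False so that dropout is skipped at test time:

X_demo = torch.randn(4, num_inputs)          # hypothetical dummy batch
print(net(X_demo).shape)                     # torch.Size([4, 10]), dropout active
print(net(X_demo, is_training=False).shape)  # same shape, dropout skipped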
"""
Training and testing models
"""
num_epochs, lr, batch_size = 5, 100.0, 256  # lr looks large; see the note below
loss = torch.nn.CrossEntropyLoss()
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size, params, lr)
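The learning rate of 100.0 is not a typo: when params and lr are passed, d2l.train_ch3 uses the book's own mini-batch SGD, which divides each gradient by the batch size. A minimal sketch of that update rule (paraphrased from my understanding of d2lzh_pytorch; check your local copy of the package if in doubt):

def sgd(params, lr, batch_size):
    # the gradient is scaled by 1/batch_size, so lr=100.0 with
    # batch_size=256 is an effective step size of about 0.39
    for param in params:
        param.data -= lr * param.grad / batch_size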
3. Concise Implementation of Dropout
from matplotlib import pyplot as plt
import torch
import torch.nn as nn
import numpy as np
import sys
sys.path.append("..")
import d2lzh_pytorch as d2l
"""
Define model parameters
"""
num_inputs, num_outputs, num_hiddens1, num_hiddens2 = 784, 10, 256, 256
drop_prob1, drop_prob2 = 0.2, 0.5

net = nn.Sequential(
    d2l.FlattenLayer(),
    nn.Linear(num_inputs, num_hiddens1),
    nn.ReLU(),
    nn.Dropout(drop_prob1),   # dropout after the first hidden layer
    nn.Linear(num_hiddens1, num_hiddens2),
    nn.ReLU(),
    nn.Dropout(drop_prob2),   # dropout after the second hidden layer
    nn.Linear(num_hiddens2, num_outputs)
)

for param in net.parameters():
    nn.init.normal_(param, mean=0, std=0.01)
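Unlike the from-scratch version, nn.Dropout keys off the module's training state: net.eval() turns the dropout layers into identity maps, and net.train() re-enables them. A small demonstration (my addition, with a made-up input):

net.eval()                          # dropout layers become identity maps
X_demo = torch.randn(1, 1, 28, 28)  # hypothetical Fashion-MNIST-shaped input
print(net(X_demo))                  # deterministic: repeated calls agree
net.train()                         # re-enable dropout before training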
"""
Train and test models
"""
num_epochs, batch_size = 5, 256
loss = torch.nn.CrossEntropyLoss()
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
optimizer = torch.optim.SGD(net.parameters(), lr=0.5)
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs,
              batch_size, None, None, optimizer)
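After training, switch to evaluation mode before predicting so the Dropout layers are bypassed. A hypothetical inference snippet:

net.eval()
X_batch, y_batch = next(iter(test_iter))  # one mini-batch of test images
preds = net(X_batch).argmax(dim=1)        # predicted class per image
print(preds[:10])
print(y_batch[:10])                       # compare with the true labels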