PyTorch Study Notes (4): Overfitting and Convolutional Neural Networks (CNN)
2022-07-28 19:42:00 【狸狸Arina】
1. Overfitting
1.1 Overfitting and Underfitting
- Underfitting
  - The model is less complex than the underlying data, so its expressive capacity is insufficient;
  - Performance is poor on both the training and test sets;

- Overfitting
  - The model is more complex than the underlying data;
  - Training performance is good, but test performance is poor, i.e. the model generalizes badly;

- Summary

1.2 Train-Val-Test Split
- val (validation set): used to select the model and tune its hyperparameters;
- test (test set): used to judge the model's generalization; it must not take part in training or model selection, and in practice it is often not released at all;
- train (training set): used to fit the model;
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

batch_size = 200
learning_rate = 0.01
epochs = 10

train_db = datasets.MNIST('../data', train=True, download=True,
                          transform=transforms.Compose([
                              transforms.ToTensor(),
                              transforms.Normalize((0.1307,), (0.3081,))
                          ]))
train_loader = torch.utils.data.DataLoader(
    train_db, batch_size=batch_size, shuffle=True)

test_db = datasets.MNIST('../data', train=False,
                         transform=transforms.Compose([
                             transforms.ToTensor(),
                             transforms.Normalize((0.1307,), (0.3081,))
                         ]))
test_loader = torch.utils.data.DataLoader(
    test_db, batch_size=batch_size, shuffle=True)

print('train:', len(train_db), 'test:', len(test_db))

# Split the 60k training set into a 50k train set and a 10k validation set
train_db, val_db = torch.utils.data.random_split(train_db, [50000, 10000])
print('db1:', len(train_db), 'db2:', len(val_db))
train_loader = torch.utils.data.DataLoader(
    train_db, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(
    val_db, batch_size=batch_size, shuffle=True)
1.3 K-fold cross-validation
- Merge the train and val data and split it into K folds; in turn, take the i-th fold (i ≤ K) as the validation set and train on the remaining K−1 folds, as in the sketch below;
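A minimal sketch of the idea, assuming the MNIST setup from the snippets above and k = 6 folds (ConcatDataset stitches the remaining folds back together):

import torch
from torch.utils.data import DataLoader, random_split, ConcatDataset
from torchvision import datasets, transforms

k = 6  # assumed number of folds; 60000 / 6 = 10000 samples per fold
full_db = datasets.MNIST('../data', train=True, download=True,
                         transform=transforms.ToTensor())
folds = random_split(full_db, [len(full_db) // k] * k)
for i in range(k):
    val_db = folds[i]                                    # the i-th fold validates
    train_db = ConcatDataset(folds[:i] + folds[i + 1:])  # the rest trains
    train_loader = DataLoader(train_db, batch_size=200, shuffle=True)
    val_loader = DataLoader(val_db, batch_size=200)
    # ... train for some epochs on train_loader, evaluate on val_loader ...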

1.4 Regularization
1.4.1 What regularization does
- Forces the parameter norm toward 0, which smooths the prediction curve: the leading polynomial terms stay relatively large (good fit) while the higher-order terms shrink, so the model effectively degenerates to a lower-degree one; this reduces model complexity;

1.4.2 L1 Regularization
- PyTorch has no built-in class or function for L1 regularization, so the penalty has to be added to the loss by hand, as in the sketch below;
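A hand-rolled sketch; the toy model, the dummy batch, and lambda_l1 are illustrative assumptions:

import torch
import torch.nn as nn

net = nn.Linear(784, 10)            # toy model standing in for a real network
criteon = nn.CrossEntropyLoss()
x = torch.randn(8, 784)
target = torch.randint(0, 10, (8,))

lambda_l1 = 0.01                    # assumed regularization strength
logits = net(x)
l1_norm = sum(p.abs().sum() for p in net.parameters())
loss = criteon(logits, target) + lambda_l1 * l1_norm  # add the L1 term by hand
loss.backward()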


1.4.3 L2 Regularization
- The optimizers in torch.optim support L2 regularization out of the box;
- The weight_decay argument of optim.SGD() plays the role of the λ in the L2 penalty;
device = torch.device('cuda:0')
net = MLP().to(device)  # MLP: a user-defined model, assumed to be defined elsewhere
# weight_decay acts as the λ of the L2 penalty on the weights
optimizer = optim.SGD(net.parameters(), lr=learning_rate, weight_decay=0.01)
criteon = nn.CrossEntropyLoss().to(device)
1.5 Momentum and Learning Rate Decay
1.5.1 Momentum


1.5.2 Learning rate decay
- Start training with a relatively large learning rate and then decay it step by step; the sketch below combines momentum with a step decay schedule;
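A combined sketch, assuming a toy model and a dummy loss; momentum is SGD's built-in argument and StepLR is one of several schedulers in torch.optim.lr_scheduler:

import torch
import torch.nn as nn
import torch.optim as optim

net = nn.Linear(784, 10)  # toy model
# momentum=0.9 adds an exponentially averaged velocity term to the SGD update
optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9)
# StepLR multiplies the learning rate by gamma every step_size epochs
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

for epoch in range(100):
    loss = net(torch.randn(8, 784)).sum()  # dummy loss for illustration
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()  # once per epoch, after optimizer.step()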

1.6 Early Stopping and Dropout
1.6.1 Early Stopping
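A hedged sketch of the idea: checkpoint the model at its best validation accuracy and stop after patience epochs without improvement. train_one_epoch and evaluate are hypothetical helpers, not library functions:

import torch

best_acc, bad_epochs, patience = 0.0, 0, 5   # patience is an assumed hyperparameter
for epoch in range(100):
    train_one_epoch(net, train_loader)       # hypothetical helper
    acc = evaluate(net, val_loader)          # hypothetical helper: validation accuracy
    if acc > best_acc:
        best_acc, bad_epochs = acc, 0
        torch.save(net.state_dict(), 'best.pth')  # checkpoint the best model so far
    else:
        bad_epochs += 1
        if bad_epochs >= patience:
            break  # validation accuracy has stopped improving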

1.6.2 Dropout
- Dropout is applied only during training and must be switched off at test time; see the sketch below;
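A minimal sketch; note that PyTorch's nn.Dropout(p) takes the drop probability (TensorFlow's keep_prob is the complement):

import torch
import torch.nn as nn

net = nn.Sequential(
    nn.Linear(784, 200),
    nn.Dropout(0.5),  # p is the probability of DROPPING each activation
    nn.ReLU(inplace=True),
    nn.Linear(200, 10),
)
net.train()                           # training mode: dropout is active
out_train = net(torch.randn(4, 784))
net.eval()                            # evaluation mode: dropout becomes the identity
out_test = net(torch.randn(4, 784))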


2. Convolutional Neural Networks (CNN)
2.1 nn.Conv2d
import torch
import torch.nn as nn

x = torch.randn(2, 1, 28, 28)
layer1 = nn.Conv2d(1, 3, kernel_size=3, stride=1, padding=0)
layer2 = nn.Conv2d(1, 3, kernel_size=3, stride=1, padding=1)
layer3 = nn.Conv2d(1, 3, kernel_size=3, stride=3, padding=1)
out1 = layer1.forward(x)
out2 = layer2.forward(x)
out3 = layer3.forward(x)
out = layer1(x)  # invokes __call__, which runs hooks and then self.forward(); this is the recommended usage
print(out1.shape)
print(out2.shape)
print(out3.shape)
print(out.shape)
print(layer1.weight.shape)  # weight: [out_channels, in_channels, kernel_h, kernel_w]
print(layer1.bias.shape)    # bias: [out_channels]
'''
torch.Size([2, 3, 26, 26])
torch.Size([2, 3, 28, 28])
torch.Size([2, 3, 10, 10])
torch.Size([2, 3, 26, 26])
torch.Size([3, 1, 3, 3])
torch.Size([3])
'''
2.2 F.conv2d
import torch
import torch.nn.functional as F
x = torch.randn(2,1,28,28)
w = torch.rand(16, 1, 3, 3)  # weight: [out_channels, in_channels, kernel_h, kernel_w]
b = torch.rand(16)           # bias: [out_channels]
out1 = F.conv2d(x, w, b, stride=1, padding=0)
out2 = F.conv2d(x, w, b, stride=1, padding=1)
out3 = F.conv2d(x, w, b, stride=2, padding=1)
print(out1.shape)
print(out2.shape)
print(out3.shape)
'''
torch.Size([2, 16, 26, 26])
torch.Size([2, 16, 28, 28])
torch.Size([2, 16, 14, 14])
'''
2.3 Pooling and Sampling
2.3.1 Downsampling
- Max Pooling

- Avg pooling

import torch
import torch.nn as nn
import torch.nn.functional as F
x = torch.randn(2,1,28,28)
layer1 = nn.MaxPool2d(2, stride=2)
layer2 = nn.AvgPool2d(2, stride=2)
out1 = layer1(x)
out2 = layer2(x)
print(out1.shape)
print(out2.shape)
out3 = F.max_pool2d(x, kernel_size=2, stride=2)
out4 = F.avg_pool2d(x, kernel_size=2, stride=2)
print(out3.shape)
print(out4.shape)
'''
torch.Size([2, 1, 14, 14])
torch.Size([2, 1, 14, 14])
torch.Size([2, 1, 14, 14])
torch.Size([2, 1, 14, 14])
'''
2.3.2 Upsampling
- F.interpolate
import torch
import torch.nn as nn
import torch.nn.functional as F
x = torch.randn(2,1,5,5)
out1 = F.interpolate(x, scale_factor=2, mode='nearest')
out2 = F.interpolate(x, scale_factor=3, mode='nearest')
print(out1.shape)
print(out2.shape)
'''
torch.Size([2, 1, 10, 10])
torch.Size([2, 1, 15, 15])
'''
2.4 Batch Norm

2.4.1 The BN computation
- For each channel, subtract the mean and divide by the standard deviation estimated over that channel, pushing the activations toward a standard normal N(0, 1); then scale by γ and shift by β, i.e. y = γ·(x − μ)/√(σ² + ε) + β, which yields an N(β, γ²) distribution;
- γ and β are learnable parameters and receive gradients; the sketch below verifies the per-channel computation by hand;
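A sketch checking the per-channel math manually, assuming training mode and the default initialization γ = 1, β = 0:

import torch
import torch.nn as nn

x = torch.randn(10, 16, 784)
layer = nn.BatchNorm1d(16)
out = layer(x)

mean = x.mean(dim=(0, 2), keepdim=True)                # statistics over batch and spatial dims
var = x.var(dim=(0, 2), unbiased=False, keepdim=True)  # biased variance, as BN uses
manual = (x - mean) / torch.sqrt(var + layer.eps)
print(torch.allclose(out, manual, atol=1e-5))          # True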

2.4.2 nn.BatchNorm1d
import torch
import torch.nn as nn

x = torch.randn(10, 16, 784)  # [batch, channels, H*W]
layer = nn.BatchNorm1d(16)    # num_features = number of channels
out = layer(x)
print(layer.running_mean)
print(layer.running_var)
# nn.BatchNorm1d.running_mean / running_var are running estimates of the mean and
# variance, updated with momentum at every forward pass in training mode
# (not just the statistics of the current batch);
# nn.BatchNorm1d.weight is γ;
# nn.BatchNorm1d.bias is β;
'''
tensor([-0.0004,  0.0008,  0.0012,  0.0005, -0.0004, -0.0002,  0.0022, -0.0008,
         0.0015,  0.0004,  0.0003, -0.0002,  0.0006, -0.0022,  0.0014,  0.0003])
tensor([1.0006, 0.9976, 0.9995, 0.9994, 0.9999, 1.0003, 0.9987, 1.0002,
        1.0002, 1.0013, 0.9999, 1.0033, 0.9980, 0.9984, 1.0010, 0.9991])
'''
2.4.3 nn.BatchNorm2d
import torch
import torch.nn as nn
import torch.nn.functional as F
x = torch.randn(10,16,28,28)
layer = nn.BatchNorm2d(16)
out = layer(x)
print(layer.running_mean.shape)
print(layer.running_var.shape)
print(layer.weight.shape)
print(layer.bias.shape)
print(vars(layer))
'''
torch.Size([16])
torch.Size([16])
torch.Size([16])
torch.Size([16])
{'training': True,
 '_parameters': OrderedDict([
     ('weight', Parameter containing:
         tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], requires_grad=True)),
     ('bias', Parameter containing:
         tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], requires_grad=True))]),
 '_buffers': OrderedDict([
     ('running_mean', tensor([ 1.8079e-03, -2.5026e-04, -1.4939e-03, -6.3105e-04,  6.8191e-04,
          1.0124e-04, -2.2681e-03,  2.3068e-03,  8.3116e-04, -1.3128e-03,
         -2.1227e-05,  3.3749e-04,  2.6776e-03,  8.1243e-04,  1.2933e-03,
         -5.1558e-04])),
     ('running_var', tensor([0.9990, 1.0008, 1.0006, 0.9975, 1.0025, 0.9980, 1.0019, 1.0009,
         1.0009, 1.0007, 0.9982, 1.0010, 1.0012, 1.0017, 1.0007, 1.0021])),
     ('num_batches_tracked', tensor(1))]),
 '_non_persistent_buffers_set': set(),
 '_backward_hooks': OrderedDict(),
 '_is_full_backward_hook': None,
 '_forward_hooks': OrderedDict(),
 '_forward_pre_hooks': OrderedDict(),
 '_state_dict_hooks': OrderedDict(),
 '_load_state_dict_pre_hooks': OrderedDict(),
 '_modules': OrderedDict(),
 'num_features': 16, 'eps': 1e-05, 'momentum': 0.1, 'affine': True, 'track_running_stats': True}
'''
2.5 Classic Convolutional Networks
2.5.1 LeNet-5 1998
- The pioneering CNN for handwritten digit recognition; the convolutional line of work dates back to the 1980s, and the LeNet-5 variant was published in 1998;
- 5 layers in total: 2 conv + 3 fully connected;

2.5.2 AlexNet 2012
- 8 layers in total: 5 conv + 3 fully connected;
- Trained on two GTX 580 GPUs, with the convolutions split into two parameter groups whose outputs are then merged;
- Uses pooling, ReLU, and Dropout;

2.5.3 VGG 2014
- Six network configurations in total;
- Stacks of small kernels replace a single large kernel with no loss of accuracy and less computation;
- 1x1 convolutions: very cheap, and able to change the number of channels of the output feature map;
- 11 to 19 layers, deeper than AlexNet;

2.5.4 GoogLeNet 2014
- 22 layers in total; applies several kernel sizes in parallel within the same layer (the Inception module);

2.5.5 ResNet 2016
- Simply stacking more layers does not improve performance: as the network gets deeper, gradients vanish and the early layers stop receiving meaningful updates;

- The bottleneck block first reduces the number of channels with a 1x1 convolution, then applies a 3x3 convolution, and finally restores the channel count with another 1x1 convolution so it matches the residual input; this cuts the computation and makes much deeper stacks feasible, as in the sketch below;
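A sketch of such a bottleneck block; the channel sizes are illustrative, and real ResNet blocks also handle stride and projection shortcuts:

import torch
import torch.nn as nn

class Bottleneck(nn.Module):
    """1x1 reduce -> 3x3 -> 1x1 expand, plus an identity shortcut."""
    def __init__(self, channels, mid_channels):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(channels, mid_channels, kernel_size=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, mid_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, channels, kernel_size=1),
            nn.BatchNorm2d(channels),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.net(x) + x)  # residual connection

x = torch.randn(2, 256, 14, 14)
block = Bottleneck(256, 64)
print(block(x).shape)  # torch.Size([2, 256, 14, 14])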

2.5.6 DenseNet

2.6 nn.Module
- The parent class of all network layer classes; to implement a custom network you must subclass it;
- A Module can be nested inside another Module;
2.6.1 The Module container nn.Sequential

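A minimal sketch of the container: the layers are applied in registration order, with no explicit forward() to write:

import torch
import torch.nn as nn

net = nn.Sequential(
    nn.Linear(784, 200),
    nn.ReLU(inplace=True),
    nn.Linear(200, 10),
)
x = torch.randn(4, 784)
print(net(x).shape)  # torch.Size([4, 10])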
2.6.2 Managing Module parameters

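A minimal sketch: named_parameters() exposes every registered parameter, with names derived from the position inside the Sequential:

import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(4, 3), nn.Linear(3, 2))
for name, p in net.named_parameters():
    print(name, p.shape)
# 0.weight torch.Size([3, 4])
# 0.bias   torch.Size([3])
# 1.weight torch.Size([2, 3])
# 1.bias   torch.Size([2])
print(sum(p.numel() for p in net.parameters()))  # 23 trainable scalars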
2.6.3 Managing the Modules inside a Module
import torch
import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(4, 3)

    def forward(self, x):
        return self.net(x)

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            MyModule(),
            nn.ReLU(inplace=True),
            nn.Linear(3, 2)
        )

    def forward(self, x):
        return self.net(x)

model = Net()
# print(dict(model.named_parameters()).items())
print('modules........')
print(list(model.modules()))       # all descendant modules, including model itself
print(len(list(model.modules())))
print('children.......')
print(list(model.children()))      # direct children only
print(len(list(model.children())))
'''
modules........
[Net(
  (net): Sequential(
    (0): MyModule(
      (net): Linear(in_features=4, out_features=3, bias=True)
    )
    (1): ReLU(inplace=True)
    (2): Linear(in_features=3, out_features=2, bias=True)
  )
), Sequential(
  (0): MyModule(
    (net): Linear(in_features=4, out_features=3, bias=True)
  )
  (1): ReLU(inplace=True)
  (2): Linear(in_features=3, out_features=2, bias=True)
), MyModule(
  (net): Linear(in_features=4, out_features=3, bias=True)
), Linear(in_features=4, out_features=3, bias=True), ReLU(inplace=True), Linear(in_features=3, out_features=2, bias=True)]
6
children.......
[Sequential(
  (0): MyModule(
    (net): Linear(in_features=4, out_features=3, bias=True)
  )
  (1): ReLU(inplace=True)
  (2): Linear(in_features=3, out_features=2, bias=True)
)]
1
'''
2.6.4 Moving a Module between devices

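A minimal sketch of moving a module and its inputs; note the asymmetry between Module.to() and Tensor.to():

import torch
import torch.nn as nn

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
net = nn.Linear(4, 2)
net = net.to(device)  # Module.to() moves the parameters in place and returns the same module
x = torch.randn(3, 4)
x = x.to(device)      # Tensor.to() returns a NEW tensor, so reassignment is required
print(net(x).shape)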
2.6.5 Saving and loading Module parameters

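A minimal sketch using state_dict(); 'ckpt.pth' is an arbitrary file name:

import torch
import torch.nn as nn

net = nn.Linear(4, 2)
torch.save(net.state_dict(), 'ckpt.pth')       # save only the parameter dict, not the class
net2 = nn.Linear(4, 2)                         # rebuild a module with the same shapes
net2.load_state_dict(torch.load('ckpt.pth'))   # weights restored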
2.6.6 Switching between train and test modes

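A minimal sketch: switching the mode flips the behaviour of layers such as Dropout and BatchNorm:

import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(4, 4), nn.Dropout(0.5))
net.train()          # training mode: Dropout active, BatchNorm uses batch statistics
print(net.training)  # True
net.eval()           # evaluation mode: Dropout off, BatchNorm uses running statistics
print(net.training)  # False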
2.6.7 Implementing custom layers

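A sketch of a hand-rolled linear layer (MyLinear is illustrative); wrapping tensors in nn.Parameter is what registers them with parameters() and hence with the optimizer:

import torch
import torch.nn as nn

class MyLinear(nn.Module):
    """A hand-rolled linear layer; nn.Parameter registers w and b for training."""
    def __init__(self, inp, outp):
        super().__init__()
        self.w = nn.Parameter(torch.randn(outp, inp))
        self.b = nn.Parameter(torch.zeros(outp))

    def forward(self, x):
        return x @ self.w.t() + self.b

layer = MyLinear(4, 3)
print(layer(torch.randn(2, 4)).shape)  # torch.Size([2, 3])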
2.7 Data Augmentation
2.7.1 Flip


2.7.2 Rotate


2.7.3 Scale


2.7.4 Crop Part


2.7.5 Noise

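A sketch stitching the augmentations above into one pipeline using standard torchvision transforms; the noise injection via Lambda is hand-rolled, and all parameter values are illustrative assumptions:

import torch
from torchvision import transforms

augment = transforms.Compose([
    transforms.RandomHorizontalFlip(),                            # Flip
    transforms.RandomRotation(15),                                # Rotate by up to ±15 degrees
    transforms.Resize([32, 32]),                                  # Scale
    transforms.RandomCrop([28, 28]),                              # Crop part
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x + 0.01 * torch.randn_like(x)),  # Noise
])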