当前位置:网站首页>pytorch(网络模型)
pytorch(网络模型)
2022-06-26 05:30:00 【月屯】
神经网络基础 nn.Module

官网
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
    """Two stacked 5x5 convolutions, each followed by a ReLU."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        """Return ``relu(conv2(relu(conv1(x))))``."""
        x = F.relu(self.conv1(x))  # convolution followed by a non-linearity
        return F.relu(self.conv2(x))
练习
import torch
import torch.nn as nn
import torch.nn.functional as F
class Dun(nn.Module):
    """Minimal practice module whose forward pass simply adds 1 to its input."""

    def __init__(self):
        super().__init__()
        # Registered but not used by forward(); kept from the example above.
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        """Return ``x + 1``."""
        return x + 1
# Build the module and feed it a scalar tensor.
dun = Dun()
x = torch.tensor(1.0)  # wrap the Python float in a tensor
output = dun(x)  # calling the module invokes forward()
print(output)
卷积层


import torch
# 5x5 input image and 3x3 kernel used to demonstrate F.conv2d.
input = torch.tensor(
    [
        [1, 2, 0, 3, 1],
        [0, 1, 2, 3, 1],
        [1, 2, 1, 0, 0],
        [5, 2, 3, 1, 1],
        [2, 1, 0, 1, 1],
    ]
)
kernel = torch.tensor(
    [
        [1, 2, 1],
        [0, 1, 0],
        [2, 1, 0],
    ]
)
print(input.shape)  # shapes before reshaping
print(kernel.shape)

# Reshape both tensors to 4-D, as F.conv2d is called with 4-D tensors below.
input = torch.reshape(input, (1, 1, 5, 5))
kernel = torch.reshape(kernel, (1, 1, 3, 3))
print(input)
print(kernel)
print(input.shape)
print(kernel.shape)

# Convolution with stride 1
out = F.conv2d(input, kernel, stride=1)
print(out)

# Convolution with stride 2
out = F.conv2d(input, kernel, stride=2)
print(out)

# Stride 1 with one ring of zero padding
out = F.conv2d(input, kernel, stride=1, padding=1)
print(out)
输出channel是2时

import torch
import torchvision
from torch import nn
from torch.nn import Conv2d
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
# CIFAR-10 test split with images converted to tensors, served in batches of 64.
dataset = torchvision.datasets.CIFAR10(
    root="./data_set_test",
    train=False,
    transform=torchvision.transforms.ToTensor(),
    download=True,
)
dataloader = DataLoader(dataset, batch_size=64)
# Convolution module
class Dun(nn.Module):
    """Single 3x3 convolution mapping 3 input channels to 6 output channels."""

    def __init__(self):
        super().__init__()
        self.conv1 = Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=0)

    def forward(self, x):
        """Return ``conv1(x)``."""
        return self.conv1(x)
dun = Dun()
print(dun)

writer = SummaryWriter("./logs")
step = 0
# Run every batch through the convolution and log inputs and outputs to TensorBoard.
for data in dataloader:
    img, target = data
    output = dun(img)
    print(img.shape)
    print(output.shape)
    writer.add_images("input", img, step)
    # Reshape so each logged image has 3 channels; -1 lets torch infer the
    # batch dimension from the remaining sizes.
    output = torch.reshape(output, (-1, 3, 30, 30))
    writer.add_images("output", output, step)
    step += 1
writer.close()


池化层
作用:就像高清视频换成低清视频



import torch
from torch import nn
from torch.nn import MaxPool2d
# 5x5 float input for the max-pooling demo (built as float32 explicitly).
input = torch.tensor(
    [
        [1, 2, 0, 3, 1],
        [0, 1, 2, 3, 1],
        [1, 2, 1, 0, 0],
        [5, 2, 3, 1, 1],
        [2, 1, 0, 1, 1],
    ],
    dtype=torch.float32,
)
# Reshape to 4-D; -1 lets torch infer the leading dimension.
input = torch.reshape(input, (-1, 1, 5, 5))
class Dun(nn.Module):
    """3x3 max pooling that also keeps partial windows at the border."""

    def __init__(self):
        super().__init__()
        # ceil_mode=True rounds the output size up, so windows that only
        # partly overlap the input still produce a value.
        self.maxpool = MaxPool2d(kernel_size=3, ceil_mode=True)

    def forward(self, input):
        """Return the max-pooled ``input``.

        The parameter was previously misspelled ``inut``, so the body read the
        module-level ``input`` instead of the argument passed in.
        """
        return self.maxpool(input)
# Pool the sample input and show the result.
dun = Dun()
out = dun(input)
print(out)

图片处理
import torch
import torchvision
from torch import nn
from torch.nn import MaxPool2d
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
# CIFAR-10 test split with images converted to tensors, served in batches of 64.
dataset = torchvision.datasets.CIFAR10(
    root="./data_set_test",
    train=False,
    transform=torchvision.transforms.ToTensor(),
    download=True,
)
dataloader = DataLoader(dataset, batch_size=64)
class Dun(nn.Module):
    """3x3 max pooling that also keeps partial windows at the border."""

    def __init__(self):
        super().__init__()
        # ceil_mode=True rounds the output size up, so windows that only
        # partly overlap the input still produce a value.
        self.maxpool = MaxPool2d(kernel_size=3, ceil_mode=True)

    def forward(self, input):
        """Return the max-pooled ``input``."""
        return self.maxpool(input)
dun = Dun()
step = 0
writer = SummaryWriter("logs")
# Log each batch before and after pooling so the two can be compared in TensorBoard.
for data in dataloader:
    img, target = data
    writer.add_images("input", img, step)
    output = dun(img)
    writer.add_images("output", output, step)
    step += 1
writer.close()

非线性激活
非线性变换目的是引入非线性特征,可以更好地处理信息
ReLU

import torch
from torch import nn
from torch.nn import ReLU
# 2x2 sample containing negative values, reshaped to 4-D for the ReLU demo.
input = torch.tensor(
    [
        [1, -0.5],
        [-1, 3],
    ]
)
input = torch.reshape(input, (-1, 1, 2, 2))  # -1: leading dimension is inferred
print(input.shape)
class Dun(nn.Module):
    """Applies an element-wise ReLU to its input."""

    def __init__(self):
        super().__init__()
        self.relu1 = ReLU()

    def forward(self, input):
        """Return ``relu1(input)``."""
        return self.relu1(input)
# Apply the module to the sample input and show the result.
dun = Dun()
output = dun(input)
print(output)

sigmoid

import torch
import torchvision
from torch import nn
from torch.nn import ReLU, Sigmoid
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
# CIFAR-10 test split with images converted to tensors, served in batches of 64.
dataset = torchvision.datasets.CIFAR10(
    root="./data_set_test",
    train=False,
    download=True,
    transform=torchvision.transforms.ToTensor(),
)
dataloader = DataLoader(dataset, batch_size=64)
class Dun(nn.Module):
    """Applies an element-wise sigmoid to its input."""

    def __init__(self):
        super().__init__()
        self.relu1 = ReLU()  # registered but not used by forward()
        self.sigmoid = Sigmoid()

    def forward(self, input):
        """Return ``sigmoid(input)``."""
        return self.sigmoid(input)
dun = Dun()
writer = SummaryWriter("./logs")
step = 0
# Log each batch before and after the activation so the two can be compared in TensorBoard.
for data in dataloader:
    img, target = data
    writer.add_images("input", img, global_step=step)
    output = dun(img)
    writer.add_images("output", output, global_step=step)
    step += 1
writer.close()

线性层


import torch
import torchvision
from torch import nn
from torch.nn import Linear
from torch.utils.data import DataLoader
# CIFAR-10 test split with images converted to tensors, served in batches of 64.
dataset = torchvision.datasets.CIFAR10(
    "./data_set_test",
    train=False,
    transform=torchvision.transforms.ToTensor(),
    download=True,
)
# drop_last=True: the code below flattens a whole batch into one vector for
# Linear(196608, 10), where 196608 = 64 * 3 * 32 * 32. A smaller final batch
# would flatten to fewer values and make the linear layer raise a shape error.
dataloader = DataLoader(dataset, batch_size=64, drop_last=True)
class Dun(nn.Module):
    """Single fully connected layer mapping 196608 features to 10 outputs."""

    def __init__(self):
        super().__init__()
        self.linear = Linear(196608, 10)

    def forward(self, input):
        """Return ``linear(input)``; the last dimension of ``input`` must be 196608."""
        return self.linear(input)
dun = Dun()
for data in dataloader:
    img, target = data
    print(img.shape)
    # Alternative to the flatten below: input = torch.reshape(img, (1, 1, 1, -1))
    # Collapse the whole batch into a single 1-D vector. The linear layer
    # expects 196608 values, so this only works for full 64-image batches
    # (e.g. a loader built with drop_last=True).
    input = torch.flatten(img)
    print(input.shape)
    output = dun(input)
    print(output.shape)

归一化层(Normalization Layers)
加快神经网络的训练速度
# BatchNorm2d over 100 channels, with learnable affine parameters
m = nn.BatchNorm2d(100)
# Same layer without learnable affine parameters (replaces the one above)
m = nn.BatchNorm2d(num_features=100, affine=False)

# Normalize a random batch: 20 samples, 100 channels, 35x45 spatial size.
input = torch.randn(20, 100, 35, 45)
output = m(input)
其他层有Recurrent Layers、Transformer Layers、Linear Layers等
简单的网络模型

import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.tensorboard import SummaryWriter
class Dun(nn.Module):
    """Small CNN: three conv + max-pool stages, then two linear layers giving 10 outputs."""

    def __init__(self):
        super().__init__()
        # Approach 2: bundle all layers into one Sequential container.
        self.model1 = Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10),
        )
        # Approach 1 (equivalent layers, each registered as its own attribute):
        # self.conv1 = Conv2d(3, 32, 5, padding=2)
        # self.maxpool1 = MaxPool2d(2)
        # self.conv2 = Conv2d(32, 32, 5, padding=2)
        # self.maxpool2 = MaxPool2d(2)
        # self.conv3 = Conv2d(32, 64, 5, padding=2)
        # self.maxpool3 = MaxPool2d(2)
        # self.flatten = Flatten()
        # self.linear1 = Linear(1024, 64)
        # self.linear2 = Linear(64, 10)

    def forward(self, x):
        """Run ``x`` through the sequential stack and return the result."""
        x = self.model1(x)
        return x
dun = Dun()

# Smoke test: push an all-ones batch through the network and print the output shape.
input = torch.ones((64, 3, 32, 32))
print(dun(input).shape)

# Record the model graph for TensorBoard.
writer = SummaryWriter("./logs")
writer.add_graph(dun, input)
writer.close()


loss function
L1Loss、MSELoss
import torch
from torch.nn import L1Loss
from torch import nn
# Prediction and target vectors for the loss demo ("targrt" typo fixed to "target").
input = torch.tensor([1, 2, 3], dtype=torch.float32)
target = torch.tensor([1, 2, 5], dtype=torch.float32)

# reduction may be "sum" or "mean"; the default is "mean".
loss = L1Loss(reduction="sum")
print(loss(input, target))

loss_mse = nn.MSELoss()
print(loss_mse(input, target))

CROSSENTROPYLOSS

import torch
from torch.nn import L1Loss
from torch import nn
# Scores for three classes and the index of the correct class.
x = torch.tensor([0.1, 0.2, 0.3])
y = torch.tensor([1])
# Reshape the scores to (1, 3): one sample, three classes.
x = torch.reshape(x, (1, 3))

loss_cross = nn.CrossEntropyLoss()
print(loss_cross(x, y))

使用
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader
# CIFAR-10 test split with images converted to tensors, served one image at a time.
dataset = torchvision.datasets.CIFAR10(
    root="./data_set_test",
    train=False,
    transform=torchvision.transforms.ToTensor(),
    download=True,
)
dataloader = DataLoader(dataset, batch_size=1)
# Classification network
class Dun(nn.Module):
    """Small CNN: three conv + max-pool stages, then two linear layers giving 10 outputs."""

    def __init__(self):
        super().__init__()
        self.model1 = Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10),
        )

    def forward(self, x):
        """Run ``x`` through the sequential stack and return the result."""
        x = self.model1(x)
        return x
dun = Dun()
loss = nn.CrossEntropyLoss()
for data in dataloader:
    img, target = data
    output = dun(img)
    print(output)
    print(target)
    result_loss = loss(output, target)  # loss between prediction and label
    # print(result_loss)
    result_loss.backward()  # back-propagate to compute gradients
    print("ok")
优化器
import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader
# CIFAR-10 test split with images converted to tensors, served one image at a time.
dataset = torchvision.datasets.CIFAR10(
    root="./data_set_test",
    train=False,
    transform=torchvision.transforms.ToTensor(),
    download=True,
)
dataloader = DataLoader(dataset, batch_size=1)
# Classification network
class Dun(nn.Module):
    """Small CNN: three conv + max-pool stages, then two linear layers giving 10 outputs."""

    def __init__(self):
        super().__init__()
        self.model1 = Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10),
        )

    def forward(self, x):
        """Run ``x`` through the sequential stack and return the result."""
        x = self.model1(x)
        return x
dun = Dun()  # classification network instance
loss = nn.CrossEntropyLoss()  # loss function
optim = torch.optim.SGD(dun.parameters(), lr=0.01)  # optimizer

# Train for 20 epochs
for epoch in range(20):
    running_loss = 0.0
    for data in dataloader:
        img, target = data
        output = dun(img)
        result_loss = loss(output, target)  # loss between prediction and label
        optim.zero_grad()  # clear gradients left over from the previous step
        # print(result_loss)
        result_loss.backward()  # back-propagate to compute each parameter's gradient
        optim.step()  # update the parameters
        # .item() accumulates a plain Python float instead of a tensor.
        running_loss += result_loss.item()
    print(running_loss)

网络模型保存和读取
模型保存
import torch
import torchvision
from torch import nn
vgg16 = torchvision.models.vgg16(pretrained=False)

# Method 1: save the whole model (structure and parameters).
torch.save(vgg16, "vgg16_method1.pth")

# Method 2: save only the parameters, as a state dict.
torch.save(vgg16.state_dict(), "vgg16_method2.pth")
# Pitfall: a custom model saved whole (method 1) depends on this class definition.
class Dun(nn.Module):
    """Single 3x3 convolution mapping 3 input channels to 64 output channels."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)

    def forward(self, x):
        """Return ``conv1(x)``."""
        return self.conv1(x)
# Save the custom model whole (method 1).
dun = Dun()
torch.save(dun, "dun_method1.pth")
模型加载
import torch
# Loading a model saved with method 1:
# model = torch.load("vgg16_method1.pth")
# print(model)

# Loading a file saved with method 2 gives only the parameter dictionary:
# model = torch.load("vgg16_method2.pth")
# print(model)

# To restore a full network from method 2, rebuild the architecture first
# and then load the saved parameters into it.
import torchvision
from torch import nn

vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))
print(vgg16)
# Pitfall: the model class must be written out here before a model saved
# whole (method 1) can be loaded.
class Dun(nn.Module):
    """Single 3x3 convolution mapping 3 input channels to 64 output channels."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)

    def forward(self, x):
        """Return ``conv1(x)``."""
        return self.conv1(x)
# Load the custom model that was saved whole (method 1).
# NOTE(review): torch.load unpickles the file, so only load trusted files.
model = torch.load("dun_method1.pth")
print(model)
边栏推荐
- Tp5.0 framework PDO connection MySQL error: too many connections solution
- DOM文档
- 转帖——不要迷失在技术的海洋中
- SOFA Weekly | 开源人—于雨、本周 QA、本周 Contributor
- 12 multithreading
- cartographer_local_trajectory_builder_2d
- How to rewrite a pseudo static URL created by zenpart
- 使用Jedis監聽Redis Stream 實現消息隊列功能
- Mongodb image configuration method
- Mysql 源码阅读(二)登录连接调试
猜你喜欢

cartographer_ local_ trajectory_ builder_ 2d

慢慢学JVM之缓存行和伪共享

The parameter field of the callback address of the payment interface is "notify_url", and an error occurs after encoding and decoding the signed special character URL (,,,,,)

DOM文档

How to ensure the efficiency and real-time of pushing large-scale group messages in mobile IM?

Leetcode114. 二叉树展开为链表

【红队】要想加入红队,需要做好哪些准备?

SOFA Weekly | 开源人—于雨、本周 QA、本周 Contributor

Ribbon负载均衡服务调用

Command line interface of alluxio
随机推荐
Ribbon负载均衡服务调用
Protocol selection of mobile IM system: UDP or TCP?
线程优先级
Replacing domestic image sources in openwrt for soft routing (take Alibaba cloud as an example)
MySQL数据库-01数据库概述
Thoughts triggered by the fact that app applications are installed on mobile phones and do not display icons
uniCloud云开发获取小程序用户openid
Summary of the 10th provincial Blue Bridge Cup
Positioning setting horizontal and vertical center (multiple methods)
售前分析
[MySQL] MySQL million level data paging query method and its optimization
The parameter field of the callback address of the payment interface is "notify_url", and an error occurs after encoding and decoding the signed special character URL (,,,,,)
cartographer_ backend_ constraint
The difference between get and post in small interview questions
Introduction to GUI programming to game practice (I)
ZigBee explain in simple terms lesson 2 hardware related and IO operation
Recursively traverse directory structure and tree presentation
How to rewrite a pseudo static URL created by zenpart
Internship May 29, 2019
基于SDN的DDoS攻击缓解