PyTorch (Network Models)
2022-06-26 05:30:00 【月屯】
Neural Network Basics: nn.Module

From the official documentation
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))  # convolution followed by a non-linearity
        return F.relu(self.conv2(x))
Practice
import torch
import torch.nn as nn
import torch.nn.functional as F

class Dun(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)   # not used by forward in this toy example
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        return x + 1

dun = Dun()
x = torch.tensor(1.0)   # wrap the number in a tensor
output = dun(x)         # calling the module invokes forward()
print(output)           # tensor(2.)
Convolutional Layers


import torch
import torch.nn.functional as F

input = torch.tensor([[1, 2, 0, 3, 1],
                      [0, 1, 2, 3, 1],
                      [1, 2, 1, 0, 0],
                      [5, 2, 3, 1, 1],
                      [2, 1, 0, 1, 1]])
kernel = torch.tensor([[1, 2, 1],
                       [0, 1, 0],
                       [2, 1, 0]])
print(input.shape)   # torch.Size([5, 5])
print(kernel.shape)  # torch.Size([3, 3])

# reshape to (batch, channels, height, width)
input = torch.reshape(input, (1, 1, 5, 5))
kernel = torch.reshape(kernel, (1, 1, 3, 3))
print(input)
print(kernel)
print(input.shape)
print(kernel.shape)

# convolution
out = F.conv2d(input, kernel, stride=1)
print(out)
out = F.conv2d(input, kernel, stride=2)
print(out)

# padding
out = F.conv2d(input, kernel, stride=1, padding=1)
print(out)
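The spatial size of the output follows the usual convolution formula, floor((H + 2*padding - kernel) / stride) + 1. A small sketch of my own (conv_out_size is not a PyTorch function) to check it against the shapes printed above:

def conv_out_size(h, kernel, stride=1, padding=0):
    # floor((H + 2*padding - kernel) / stride) + 1
    return (h + 2 * padding - kernel) // stride + 1

print(conv_out_size(5, 3, stride=1))             # 3 -> output is 3x3
print(conv_out_size(5, 3, stride=2))             # 2 -> output is 2x2
print(conv_out_size(5, 3, stride=1, padding=1))  # 5 -> output is 5x5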
When out_channels is 2 (one kernel per output channel)

import torch
import torchvision
from torch import nn
from torch.nn import Conv2d
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

dataset = torchvision.datasets.CIFAR10("./data_set_test", train=False,
                                       transform=torchvision.transforms.ToTensor(), download=True)
dataloader = DataLoader(dataset, batch_size=64)

# convolution module
class Dun(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=0)

    def forward(self, x):
        return self.conv1(x)

dun = Dun()
print(dun)

writer = SummaryWriter("./logs")
step = 0
# run the convolution over the dataset
for data in dataloader:
    img, target = data
    output = dun(img)
    print(img.shape)      # e.g. torch.Size([64, 3, 32, 32])
    print(output.shape)   # e.g. torch.Size([64, 6, 30, 30])
    writer.add_images("input", img, step)
    # add_images expects 3-channel images, so split the 6 channels into twice as
    # many 3-channel images; -1 lets reshape infer that dimension automatically
    output = torch.reshape(output, (-1, 3, 30, 30))
    writer.add_images("output", output, step)
    step += 1
writer.close()


Pooling Layers
Purpose: downsampling, much like converting a high-definition video into a lower-resolution one.



import torch
from torch import nn
from torch.nn import MaxPool2d

input = torch.tensor([[1, 2, 0, 3, 1],
                      [0, 1, 2, 3, 1],
                      [1, 2, 1, 0, 0],
                      [5, 2, 3, 1, 1],
                      [2, 1, 0, 1, 1]], dtype=torch.float32)
input = torch.reshape(input, (-1, 1, 5, 5))

class Dun(nn.Module):
    def __init__(self):
        super().__init__()
        # ceil_mode=True keeps windows that only partially cover the input (ceiling);
        # ceil_mode=False (the default) drops them (floor)
        self.maxpool = MaxPool2d(kernel_size=3, ceil_mode=True)

    def forward(self, input):
        return self.maxpool(input)

dun = Dun()
out = dun(input)
print(out)
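For comparison, a minimal sketch (reusing input and MaxPool2d from above; the pool_ceil and pool_floor names are my own) showing how ceil_mode changes the output size on this 5x5 input:

pool_ceil = MaxPool2d(kernel_size=3, ceil_mode=True)
pool_floor = MaxPool2d(kernel_size=3, ceil_mode=False)
print(pool_ceil(input).shape)   # torch.Size([1, 1, 2, 2]) - partial windows kept
print(pool_floor(input).shape)  # torch.Size([1, 1, 1, 1]) - partial windows dropped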

Image processing
import torch
import torchvision
from torch import nn
from torch.nn import MaxPool2d
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

dataset = torchvision.datasets.CIFAR10("./data_set_test", train=False,
                                       transform=torchvision.transforms.ToTensor(), download=True)
dataloader = DataLoader(dataset, batch_size=64)

class Dun(nn.Module):
    def __init__(self):
        super().__init__()
        # ceil_mode=True keeps windows that only partially cover the input
        self.maxpool = MaxPool2d(kernel_size=3, ceil_mode=True)

    def forward(self, input):
        return self.maxpool(input)

dun = Dun()
step = 0
writer = SummaryWriter("logs")
for data in dataloader:
    img, target = data
    writer.add_images("input", img, step)
    output = dun(img)
    writer.add_images("output", output, step)
    step += 1
writer.close()

Non-linear Activations
A non-linear transformation introduces non-linearity into the network so it can model more complex information.
ReLU

import torch
from torch import nn
from torch.nn import ReLU

input = torch.tensor([[1, -0.5],
                      [-1, 3]])
input = torch.reshape(input, (-1, 1, 2, 2))
print(input.shape)  # torch.Size([1, 1, 2, 2])

class Dun(nn.Module):
    def __init__(self):
        super().__init__()
        self.relu1 = ReLU()

    def forward(self, input):
        return self.relu1(input)

dun = Dun()
output = dun(input)
print(output)  # negative entries become 0

Sigmoid

import torch
import torchvision
from torch import nn
from torch.nn import ReLU, Sigmoid
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

dataset = torchvision.datasets.CIFAR10("./data_set_test", train=False, download=True,
                                       transform=torchvision.transforms.ToTensor())
dataloader = DataLoader(dataset, batch_size=64)

class Dun(nn.Module):
    def __init__(self):
        super().__init__()
        self.relu1 = ReLU()
        self.sigmoid = Sigmoid()

    def forward(self, input):
        return self.sigmoid(input)

dun = Dun()
writer = SummaryWriter("./logs")
step = 0
for data in dataloader:
    img, target = data
    writer.add_images("input", img, global_step=step)
    output = dun(img)
    writer.add_images("output", output, global_step=step)
    step += 1
writer.close()

Linear Layers


import torch
import torchvision
from torch import nn
from torch.nn import Linear
from torch.utils.data import DataLoader

dataset = torchvision.datasets.CIFAR10("./data_set_test", train=False,
                                       transform=torchvision.transforms.ToTensor(), download=True)
# drop_last=True so every batch really has 64 images; otherwise the last,
# smaller batch would not flatten to 196608 features
dataloader = DataLoader(dataset, batch_size=64, drop_last=True)

class Dun(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = Linear(196608, 10)  # 196608 = 64 * 3 * 32 * 32

    def forward(self, input):
        return self.linear(input)

dun = Dun()
for data in dataloader:
    img, target = data
    print(img.shape)            # torch.Size([64, 3, 32, 32])
    # input = torch.reshape(img, (1, 1, 1, -1))
    input = torch.flatten(img)  # flatten to a single row; replaces the reshape above
    print(input.shape)          # torch.Size([196608])
    output = dun(input)
    print(output.shape)         # torch.Size([10])
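In practice the flattening is usually done per sample rather than over the whole batch, so the linear layer's input size does not depend on batch_size. A sketch of that variant, reusing the imports above (DunPerSample is a name of my own, not from the original post):

class DunPerSample(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()            # flattens everything except the batch dimension
        self.linear = Linear(3 * 32 * 32, 10)  # 3072 features per image

    def forward(self, x):
        return self.linear(self.flatten(x))

# output shape is (batch_size, 10) for any batch size
print(DunPerSample()(torch.ones(64, 3, 32, 32)).shape)  # torch.Size([64, 10])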

Normalization Layers
Normalization can speed up the training of a neural network.
import torch
from torch import nn

# With Learnable Parameters
m = nn.BatchNorm2d(100)
# Without Learnable Parameters
m = nn.BatchNorm2d(100, affine=False)
input = torch.randn(20, 100, 35, 45)
output = m(input)
Other layer types include Recurrent Layers, Transformer Layers, Linear Layers, and so on.
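They all follow the same construct-then-call pattern. As an illustration only (not part of the original post), a minimal nn.LSTM sketch with arbitrarily chosen sizes:

import torch
from torch import nn

lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=2)
x = torch.randn(5, 3, 10)    # (sequence length, batch, features)
out, (h_n, c_n) = lstm(x)
print(out.shape)             # torch.Size([5, 3, 20])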
A Simple Network Model

import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.tensorboard import SummaryWriter

class Dun(nn.Module):
    def __init__(self):
        super().__init__()
        # 2. the Sequential version
        self.model1 = Sequential(Conv2d(3, 32, 5, padding=2),
                                 MaxPool2d(2),
                                 Conv2d(32, 32, 5, padding=2),
                                 MaxPool2d(2),
                                 Conv2d(32, 64, 5, padding=2),
                                 MaxPool2d(2),
                                 Flatten(),
                                 Linear(1024, 64),
                                 Linear(64, 10))
        # 1. the layer-by-layer version
        # self.conv1 = Conv2d(3, 32, 5, padding=2)
        # self.maxpool1 = MaxPool2d(2)
        # self.conv2 = Conv2d(32, 32, 5, padding=2)
        # self.maxpool2 = MaxPool2d(2)
        # self.conv3 = Conv2d(32, 64, 5, padding=2)
        # self.maxpool3 = MaxPool2d(2)
        # self.flatten = Flatten()
        # self.linear1 = Linear(1024, 64)
        # self.linear2 = Linear(64, 10)

    def forward(self, x):
        x = self.model1(x)
        return x

dun = Dun()

# quick shape test
input = torch.ones((64, 3, 32, 32))
print(dun(input).shape)  # torch.Size([64, 10])

writer = SummaryWriter("./logs")
writer.add_graph(dun, input)
writer.close()
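Why Linear(1024, 64)? With padding=2 each 5x5 convolution preserves the 32x32 spatial size, and each MaxPool2d(2) halves it, so Flatten receives 64 channels of 4x4. A quick arithmetic check:

h = 32
for _ in range(3):   # three conv (size-preserving) + MaxPool2d(2) stages
    h = h // 2
print(64 * h * h)    # 1024 features enter the first Linear layer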


Loss Functions
L1Loss and MSELoss
import torch
from torch import nn
from torch.nn import L1Loss

input = torch.tensor([1, 2, 3], dtype=torch.float32)
target = torch.tensor([1, 2, 5], dtype=torch.float32)

loss = L1Loss(reduction="sum")   # reduction is "sum" or "mean"; the default is "mean"
print(loss(input, target))       # tensor(2.)

loss_mse = nn.MSELoss()
print(loss_mse(input, target))   # tensor(1.3333)
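As a quick sanity check of the reduction argument, the same numbers computed by hand (reusing input and target from above):

diff = (input - target).abs()
print(diff.sum())                       # tensor(2.)     -> L1Loss(reduction="sum")
print(diff.mean())                      # tensor(0.6667) -> L1Loss(reduction="mean"), the default
print(((input - target) ** 2).mean())   # tensor(1.3333) -> MSELoss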

CrossEntropyLoss

import torch
from torch import nn

x = torch.tensor([0.1, 0.2, 0.3])  # raw scores for 3 classes
y = torch.tensor([1])              # the target class index
x = torch.reshape(x, (1, 3))       # (batch, num_classes)
loss_cross = nn.CrossEntropyLoss()
print(loss_cross(x, y))            # tensor(1.1019)
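On raw scores, CrossEntropyLoss equals -x[class] + log(sum(exp(x))). Recomputing the value above by hand as a check:

manual = -x[0][1] + torch.log(torch.exp(x).sum())
print(manual)  # tensor(1.1019), matches loss_cross(x, y)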

Using a loss function with a network
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader

dataset = torchvision.datasets.CIFAR10("./data_set_test", train=False,
                                       transform=torchvision.transforms.ToTensor(), download=True)
dataloader = DataLoader(dataset=dataset, batch_size=1)

# classification network
class Dun(nn.Module):
    def __init__(self):
        super().__init__()
        self.model1 = Sequential(Conv2d(3, 32, 5, padding=2),
                                 MaxPool2d(2),
                                 Conv2d(32, 32, 5, padding=2),
                                 MaxPool2d(2),
                                 Conv2d(32, 64, 5, padding=2),
                                 MaxPool2d(2),
                                 Flatten(),
                                 Linear(1024, 64),
                                 Linear(64, 10))

    def forward(self, x):
        x = self.model1(x)
        return x

dun = Dun()
loss = nn.CrossEntropyLoss()
for data in dataloader:
    img, target = data
    output = dun(img)
    print(output)
    print(target)
    result_loss = loss(output, target)  # compute the loss
    # print(result_loss)
    result_loss.backward()              # backpropagation: gradients are computed here
    print("ok")
Optimizers
import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader

dataset = torchvision.datasets.CIFAR10("./data_set_test", train=False,
                                       transform=torchvision.transforms.ToTensor(), download=True)
dataloader = DataLoader(dataset=dataset, batch_size=1)

# classification network
class Dun(nn.Module):
    def __init__(self):
        super().__init__()
        self.model1 = Sequential(Conv2d(3, 32, 5, padding=2),
                                 MaxPool2d(2),
                                 Conv2d(32, 32, 5, padding=2),
                                 MaxPool2d(2),
                                 Conv2d(32, 64, 5, padding=2),
                                 MaxPool2d(2),
                                 Flatten(),
                                 Linear(1024, 64),
                                 Linear(64, 10))

    def forward(self, x):
        x = self.model1(x)
        return x

dun = Dun()                                          # classification network instance
loss = nn.CrossEntropyLoss()                         # loss function
optim = torch.optim.SGD(dun.parameters(), lr=0.01)   # optimizer

# train for 20 epochs
for epoch in range(20):
    running_loss = 0.0
    for data in dataloader:
        img, target = data
        output = dun(img)
        result_loss = loss(output, target)  # compute the loss
        optim.zero_grad()                   # clear the gradients from the previous step
        result_loss.backward()              # backpropagation: compute the gradient of every parameter
        optim.step()                        # update the parameters
        running_loss += result_loss.item()  # .item() keeps only the number, not the graph
    print(running_loss)

Saving and Loading Models
Saving a model
import torch
import torchvision
from torch import nn

vgg16 = torchvision.models.vgg16(pretrained=False)

# Method 1: save both the model structure and its parameters
torch.save(vgg16, "vgg16_method1.pth")

# Method 2: save only the parameters (as a state dict)
torch.save(vgg16.state_dict(), "vgg16_method2.pth")

# Trap: saving a custom model with method 1
class Dun(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)

    def forward(self, x):
        return self.conv1(x)

dun = Dun()
torch.save(dun, "dun_method1.pth")
Loading a model
import torch
import torchvision
from torch import nn

# Loading a file saved with method 1
# model = torch.load("vgg16_method1.pth")
# print(model)

# Loading a file saved with method 2 gives only the state dict
# model = torch.load("vgg16_method2.pth")
# print(model)

# To rebuild the full model from a method-2 file:
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))
print(vgg16)

# Trap: loading a method-1 file of a custom model
# requires the class definition to be available
class Dun(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)

    def forward(self, x):
        return self.conv1(x)

model = torch.load("dun_method1.pth")
print(model)