当前位置:网站首页>Pytorch学习笔记(二)神经网络的使用
Pytorch学习笔记(二)神经网络的使用
2022-07-26 13:29:00 【小胡今天有变强吗】
神经网络的基本骨架——nn.Module的使用
import torch
from torch import nn


class Tudui(nn.Module):
    """Smallest possible nn.Module: forward() returns its input plus 1."""

    def __init__(self) -> None:
        # No sub-layers to register; only the base class needs setting up.
        super().__init__()

    def forward(self, input):
        # Invoking the instance (tudui(x)) routes through nn.Module.__call__ to here.
        return input + 1


tudui = Tudui()
x = torch.tensor(1)
output = tudui(x)
print(output)
卷积操作
import torch
import torch.nn.functional as F

# 5x5 single-channel "image" and a 3x3 kernel (the convolution weights).
# Both are built as float32: recent PyTorch releases raise
# "slow_conv2d_cpu not implemented for 'Long'" for integer tensors on CPU.
input = torch.tensor([[1, 2, 0, 3, 1],
                      [0, 1, 2, 3, 1],
                      [1, 2, 1, 0, 0],
                      [5, 2, 3, 1, 1],
                      [2, 1, 0, 1, 1]], dtype=torch.float32)
kernel = torch.tensor([[1, 2, 1],
                       [0, 1, 0],
                       [2, 1, 0]], dtype=torch.float32)

# F.conv2d takes the input as (batch, channels, height, width) and the
# weight as (out_channels, in_channels, kH, kW).
input = torch.reshape(input, (1, 1, 5, 5))
kernel = torch.reshape(kernel, (1, 1, 3, 3))
print(input.shape)
print(kernel.shape)

# stride=1, no padding -> 3x3 output.
output = F.conv2d(input, kernel, stride=1)
print(output)

# stride=2 -> the kernel jumps two pixels at a time -> 2x2 output.
output2 = F.conv2d(input, kernel, stride=2)
print(output2)

# padding=1 zero-pads every border by one pixel -> 5x5 output.
output3 = F.conv2d(input, kernel, stride=1, padding=1)
print(output3)

其中,input是输入图像,kernel是卷积核本身(即卷积的权重,而不仅是它的大小),stride是步长,padding是在输入四周补零的宽度。注意F.conv2d要求输入的形状为(batch, 通道数, 高, 宽),所以需要先用reshape把5×5的矩阵变成(1, 1, 5, 5)。
神经网络-卷积层
import torch
import torchvision
from torch import nn
from torch.nn import Conv2d
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# CIFAR10 test split; ToTensor turns each image into a float tensor of shape (3, 32, 32).
dataset = torchvision.datasets.CIFAR10("../data_conv2d", train=False,
                                       transform=torchvision.transforms.ToTensor(),
                                       download=True)
dataloader = DataLoader(dataset, batch_size=64)


class Tudui(nn.Module):
    """One 3x3 convolution mapping 3 input channels to 6 output channels (no padding)."""

    def __init__(self):
        super(Tudui, self).__init__()
        self.conv1 = Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=0)

    def forward(self, x):
        x = self.conv1(x)
        return x


tudui = Tudui()
writer = SummaryWriter("../logs")
step = 0
for data in dataloader:
    imgs, targets = data
    output = tudui(imgs)
    print(imgs.shape)
    print(output.shape)
    # torch.Size([64, 3, 32, 32])
    writer.add_images("input", imgs, step)
    # torch.Size([64, 6, 30, 30]) -> [xxx, 3, 30, 30]
    # add_images expects 3-channel (RGB) images by default, so the 6 output
    # channels are folded into extra batch entries (batch size doubles).
    output = torch.reshape(output, (-1, 3, 30, 30))
    writer.add_images("output", output, step)
    step += 1
# Flush buffered events to disk and release the event file.
writer.close()

神经网络-最大池化的使用
import torch
import torchvision
from torch import nn
from torch.nn import MaxPool2d
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

dataset = torchvision.datasets.CIFAR10(
    "../data",
    train=False,
    download=True,
    transform=torchvision.transforms.ToTensor(),
)
dataloader = DataLoader(dataset, batch_size=64)


class Tudui(nn.Module):
    """A single 3x3 max-pooling layer (stride defaults to the kernel size).

    ceil_mode=True keeps the partial window at the right/bottom edge instead
    of discarding it, so a 32x32 image pools to 11x11 rather than 10x10.
    """

    def __init__(self):
        super(Tudui, self).__init__()
        self.maxpool1 = MaxPool2d(kernel_size=3, ceil_mode=True)

    def forward(self, input):
        return self.maxpool1(input)


tudui = Tudui()
writer = SummaryWriter("../logs_maxpool")
step = 0
for data in dataloader:
    imgs, targets = data
    # Log each batch before and after pooling so the two can be compared.
    writer.add_images("input", imgs, step)
    output = tudui(imgs)
    writer.add_images("output", output, step)
    step = step + 1
writer.close()

神经网络-非线性激活
import torch
from torch import nn
from torch.nn import ReLU

# 2x2 example containing negative values, reshaped to (N=1, C=1, H=2, W=2).
input = torch.reshape(torch.tensor([[1, -0.5], [-1, 3]]), (-1, 1, 2, 2))
print(input.shape)


class Tudui(nn.Module):
    """Apply ReLU element-wise: negative entries become 0, the rest pass through."""

    def __init__(self):
        super(Tudui, self).__init__()
        self.relu1 = ReLU()

    def forward(self, input):
        return self.relu1(input)


tudui = Tudui()
output = tudui(input)
print(output)

线性层
import torch
import torchvision
from torch import nn
from torch.nn import Linear
from torch.utils.data import DataLoader

dataset = torchvision.datasets.CIFAR10(
    "../data", train=False, transform=torchvision.transforms.ToTensor(), download=True
)
# drop_last=True discards the final short batch, so every flattened batch has
# exactly 64 * 3 * 32 * 32 = 196608 elements, matching linear1's input size.
dataloader = DataLoader(dataset, batch_size=64, drop_last=True)


class Tudui(nn.Module):
    """A single fully connected layer: 196608 inputs -> 10 outputs."""

    def __init__(self):
        super(Tudui, self).__init__()
        self.linear1 = Linear(196608, 10)

    def forward(self, input):
        return self.linear1(input)


tudui = Tudui()
for data in dataloader:
    imgs, target = data
    print(imgs.shape)
    # Without start_dim, torch.flatten collapses the *whole* batch (including
    # the batch axis) into one 1-D vector.
    flat = torch.flatten(imgs)
    output = flat
    print(output.shape)
    output = tudui(flat)
    print(output.shape)
![线性层运行输出示例](/img/8c/abef5abaf51b0a39020df10b290f99.png)
搭建实战与Sequential的使用

import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear


class Tudui(nn.Module):
    """CIFAR10-style classifier: three conv+pool stages followed by two linear layers.

    Shape trace for an (N, 3, 32, 32) input:
        conv1 -> (N, 32, 32, 32), maxpool1 -> (N, 32, 16, 16)
        conv2 -> (N, 32, 16, 16), maxpool2 -> (N, 32, 8, 8)
        conv3 -> (N, 64, 8, 8),   maxpool3 -> (N, 64, 4, 4)
        flatten -> (N, 1024), linear1 -> (N, 64), linear2 -> (N, 10)
    """

    def __init__(self):
        super(Tudui, self).__init__()
        # A 5x5 kernel with padding=2 and stride 1 keeps height and width unchanged.
        self.conv1 = Conv2d(3, 32, 5, padding=2)
        self.maxpool1 = MaxPool2d(2)
        self.conv2 = Conv2d(32, 32, 5, padding=2)
        self.maxpool2 = MaxPool2d(2)
        self.conv3 = Conv2d(32, 64, 5, padding=2)
        self.maxpool3 = MaxPool2d(2)
        self.flatten = Flatten()
        # 64 channels * 4 * 4 spatial positions = 1024 features.
        self.linear1 = Linear(1024, 64)
        self.linear2 = Linear(64, 10)

    def forward(self, x):
        pipeline = (
            self.conv1, self.maxpool1,
            self.conv2, self.maxpool2,
            self.conv3, self.maxpool3,
            self.flatten,
            self.linear1, self.linear2,
        )
        for layer in pipeline:
            x = layer(x)
        return x


tudui = Tudui()
print(tudui)
input = torch.ones((64, 3, 32, 32))
output = tudui(input)
print(output.shape)
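注意padding=2是根据Conv2d的输出尺寸公式计算出来的(dilation默认为1):

$$H_{out} = \left\lfloor \frac{H_{in} + 2 \times \text{padding} - \text{dilation} \times (\text{kernel\_size} - 1) - 1}{\text{stride}} + 1 \right\rfloor$$

要让32×32的输入经过5×5卷积后仍为32×32,代入 $H_{in} = H_{out} = 32$、kernel_size $= 5$、stride $= 1$、dilation $= 1$,得 $32 = \frac{32 + 2 \times \text{padding} - 4 - 1}{1} + 1$,解得 padding $= 2$。宽度 $W$ 的计算方法相同。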

使用Sequential:
import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
# Needed for the graph logging at the bottom; previously missing (NameError).
from torch.utils.tensorboard import SummaryWriter


class Tudui(nn.Module):
    """CIFAR10-style classifier built with nn.Sequential.

    Three conv(5x5, padding=2) + maxpool(2) stages take an (N, 3, 32, 32)
    input down to (N, 64, 4, 4); Flatten gives (N, 1024), and two linear
    layers map that to (N, 10). Sequential runs its layers in order, so
    forward() is a single call.
    """

    def __init__(self):
        super(Tudui, self).__init__()
        self.model1 = Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10),
        )

    def forward(self, x):
        x = self.model1(x)
        return x


tudui = Tudui()
print(tudui)
input = torch.ones((64, 3, 32, 32))
output = tudui(input)
print(output.shape)

# Trace the model with `input` and write its graph for TensorBoard's GRAPHS tab.
writer = SummaryWriter("../logs_seq")
writer.add_graph(tudui, input)
writer.close()


在TensorBoard的GRAPHS页面中,双击模型中的模块可以展开查看其内部细节:
损失函数与反向传播
import torch
from torch.nn import L1Loss

# Predictions and targets, reshaped to (N=1, C=1, H=1, W=3).
inputs = torch.reshape(torch.tensor([1, 2, 3], dtype=torch.float32), (1, 1, 1, 3))
targets = torch.reshape(torch.tensor([1, 2, 5], dtype=torch.float32), (1, 1, 1, 3))

# reduction='sum' adds the absolute differences instead of averaging them:
# |1-1| + |2-2| + |3-5| = 2.
loss = L1Loss(reduction='sum')
result = loss(inputs, targets)
print(result)

import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

dataset = torchvision.datasets.CIFAR10(
    "../data",
    train=False,
    transform=torchvision.transforms.ToTensor(),
    download=True,
)
# batch_size=1: the loss below is computed for one image at a time.
dataloader = DataLoader(dataset, batch_size=1)


class Tudui(nn.Module):
    """CIFAR10 classifier: 3 conv/pool stages + 2 linear layers, 10 output scores."""

    def __init__(self):
        super(Tudui, self).__init__()
        self.model1 = Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10),
        )

    def forward(self, x):
        return self.model1(x)


# CrossEntropyLoss compares the raw (N, 10) scores with the class-index targets.
loss = nn.CrossEntropyLoss()
tudui = Tudui()
for data in dataloader:
    imgs, targets = data
    outputs = tudui(imgs)
    print(loss(outputs, targets))
![交叉熵损失运行输出示例](/img/67/cb2dbfb22f0c2d9e7694464f3ccfe1.png)
优化器
import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

dataset = torchvision.datasets.CIFAR10("../data", train=False, transform=torchvision.transforms.ToTensor(), download=True)
dataloader = DataLoader(dataset, batch_size=1)


class Tudui(nn.Module):
    """CIFAR10 classifier: 3 conv/pool stages + 2 linear layers, 10 output scores."""

    def __init__(self):
        super(Tudui, self).__init__()
        self.model1 = Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )

    def forward(self, x):
        x = self.model1(x)
        return x


loss = nn.CrossEntropyLoss()
tudui = Tudui()
# Plain SGD over every model parameter, learning rate 0.01.
optim = torch.optim.SGD(tudui.parameters(), lr=0.01)
for epoch in range(20):
    # Sum of the per-batch losses over one pass through the dataset.
    running_loss = 0.0
    for data in dataloader:
        imgs, targets = data
        outputs = tudui(imgs)
        result_loss = loss(outputs, targets)
        optim.zero_grad()       # clear gradients left over from the previous step
        result_loss.backward()  # compute gradients of the loss w.r.t. the parameters
        optim.step()            # apply the SGD update
        # .item() yields a plain Python float. Adding the tensor itself would
        # chain every step's autograd graph onto running_loss and keep all of
        # them alive until the end of the epoch.
        running_loss += result_loss.item()
    print(running_loss)
![优化器训练时每轮损失输出示例](/img/59/2305bafef741feb1542b196c7d7bd5.png)
参考资料
边栏推荐
- B+ tree (3) clustered index, secondary index -- MySQL from entry to proficiency (XV)
- Abstract factory and its improvement examples
- Codeforces Round #810 (Div. 2)【比赛记录】
- 估值15亿美元的独角兽被爆裁员,又一赛道遇冷?
- JS object assignment problem
- MVVM architecture encapsulation of kotlin series (kotlin+mvvm)
- Emotion analysis model based on Bert
- [flower carving hands-on] interesting and fun music visualization series small project (13) -- organic rod column lamp
- panic: Error 1045: Access denied for user ‘root‘@‘117.61.242.215‘ (using password: YES)
- 详解关系抽取模型 CasRel
猜你喜欢

解决远程主机无法连接mysql数据库的问题

《Kotlin系列》之MVVM架构封装(kotlin+mvvm)

Unicode file parsing methods and existing problems
Exploration on cache design optimization of community like business

MVVM architecture encapsulation of kotlin series (kotlin+mvvm)

AI theory knowledge map 1 Foundation

估值15亿美元的独角兽被爆裁员,又一赛道遇冷?

Probability theory and mathematical statistics

Team research and development from ants' foraging process (Reprint)

Hcip day 12 notes sorting (BGP Federation, routing rules)
随机推荐
B+树(3)聚簇索引,二级索引 --mysql从入门到精通(十五)
Leetcode 1523. count odd numbers within the interval
图扑 3D 可视化国风设计 | 科技与文化碰撞炫酷”火花“
Detailed relation extraction model casrel
Comparator (interface between comparable and comparator)
一文学透MySQL表的创建和约束
官宣!艾德韦宣集团与百度希壤达成深度共创合作
历时15年、拥有5亿用户的飞信,彻底死了
如何构建以客户为中心的产品蓝图:来自首席技术官的建议
MySQL data directory (1) -- database structure (24)
华为机考 ~ 偏移量实现字符串加密
WPS凭什么拒绝广告?
Basic sentence structure of English ----- origin
Leetcode 217. there are duplicate elements
Hcip day 11 comparison (BGP configuration and release)
7-25 0-1 backpack (50 points)
Huawei computer test ~ offset realizes string encryption
Click El dropdown item/@click.native
Familiarize you with the "phone book" of cloud network: DNS
详解关系抽取模型 CasRel