PyTorch Learning Notes (3): Stochastic Gradient Descent, Neural Networks, and Fully Connected Layers
2022-07-28 19:42:00 【狸狸Arina】
1. Stochastic Gradient Descent
1.1 Activation Functions and Their Gradients
1.1.1 Sigmoid / Logistic
The sigmoid function squashes any real input into (0, 1): sigmoid(x) = 1 / (1 + e^(-x)). Its gradient is sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x)), which goes to 0 for large |x|, so the function saturates at both ends, as the output below shows.
import torch
a = torch.linspace(-100,100,100,requires_grad=True)
b = torch.sigmoid(a)
print(b)
''' tensor([0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 6.8349e-39, 5.1533e-38, 3.8855e-37, 2.9296e-36, 2.2089e-35, 1.6655e-34, 1.2557e-33, 9.4681e-33, 7.1388e-32, 5.3825e-31, 4.0584e-30, 3.0599e-29, 2.3072e-28, 1.7396e-27, 1.3116e-26, 9.8893e-26, 7.4564e-25, 5.6220e-24, 4.2389e-23, 3.1961e-22, 2.4098e-21, 1.8169e-20, 1.3699e-19, 1.0329e-18, 7.7881e-18, 5.8721e-17, 4.4274e-16, 3.3382e-15, 2.5170e-14, 1.8978e-13, 1.4309e-12, 1.0789e-11, 8.1345e-11, 6.1333e-10, 4.6244e-09, 3.4867e-08, 2.6289e-07, 1.9822e-06, 1.4945e-05, 1.1267e-04, 8.4891e-04, 6.3653e-03, 4.6075e-02, 2.6696e-01, 7.3304e-01, 9.5392e-01, 9.9363e-01, 9.9915e-01, 9.9989e-01, 9.9999e-01, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00], grad_fn=<SigmoidBackward>) '''
1.1.2 Tanh
Tanh rescales sigmoid to the range (-1, 1): tanh(x) = (e^x - e^(-x)) / (e^x + e^(-x)) = 2 * sigmoid(2x) - 1, and its gradient is tanh'(x) = 1 - tanh(x)^2.
import torch
a = torch.linspace(-1,1,10,requires_grad=True)
b = torch.tanh(a)
print(b)
''' tensor([-0.7616, -0.6514, -0.5047, -0.3215, -0.1107, 0.1107, 0.3215, 0.5047, 0.6514, 0.7616], grad_fn=<TanhBackward>) '''
1.1.3 ReLU
ReLU(x) = max(0, x). Its gradient is 1 for x > 0 and 0 for x < 0, which makes it cheap to compute and helps avoid vanishing gradients in deep networks.
import torch
import torch.nn.functional as F
a = torch.linspace(-1,1,10,requires_grad=True)
b = torch.relu(a)
c = F.relu(a)
print(b)
print(c)
''' tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1111, 0.3333, 0.5556, 0.7778, 1.0000], grad_fn=<ReluBackward0>) tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1111, 0.3333, 0.5556, 0.7778, 1.0000], grad_fn=<ReluBackward0>) '''
1.2 Loss Functions and Their Gradients
1.2.1 MSE
Mean squared error sums the squared difference between prediction and label: MSE = sum((y_hat - y)^2) (PyTorch's F.mse_loss averages over all elements by default). Its gradient with respect to the prediction is 2 * (y_hat - y).
1.2.2 Computing Gradients with autograd.grad()
import torch
import torch.nn.functional as F
x = torch.ones(1)
w = torch.full([1], 2., requires_grad=True) # length-1 tensor filled with the value 2
mse = F.mse_loss(torch.ones(1), x*w) # F.mse_loss(input, target): input is the prediction, target the label; the order is swapped here, which is harmless since MSE is symmetric
print(mse)
print(torch.autograd.grad(mse, [w]))
''' tensor(1., grad_fn=<MseLossBackward>) (tensor([2.]),) '''
1.2.3 Computing Gradients with loss.backward()
import torch
import torch.nn.functional as F
x = torch.ones(1)
w = torch.full([1], 2., requires_grad=True) # length-1 tensor filled with the value 2
mse = F.mse_loss(torch.ones(1), x*w) # F.mse_loss(input, target): input is the prediction, target the label; the order is swapped here, which is harmless since MSE is symmetric
print(mse)
mse.backward()
print(w.grad)
''' tensor(1., grad_fn=<MseLossBackward>) tensor([2.]) '''
1.2.4 Softmax
import torch
import torch.nn.functional as F
a = torch.rand(3)
a.requires_grad_()
print(a)
p = F.softmax(a, dim = 0)
print(p)
print(torch.autograd.grad(p[0], [a], retain_graph=True)) # retain_graph=True keeps the dynamic graph alive, so we can call grad() or backward() on it repeatedly
print(torch.autograd.grad(p[1], [a], retain_graph=True))
print(torch.autograd.grad(p[2], [a], retain_graph=True))
''' tensor([0.8659, 0.0540, 0.4153], requires_grad=True) tensor([0.4805, 0.2133, 0.3062], grad_fn=<SoftmaxBackward>) (tensor([ 0.2496, -0.1025, -0.1471]),) (tensor([-0.1025, 0.1678, -0.0653]),) (tensor([-0.1471, -0.0653, 0.2124]),) '''
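The pattern in these gradients follows the softmax Jacobian: dp_i/da_j = p_i * (1 - p_i) when i = j, and -p_i * p_j when i != j. That is why each diagonal entry above is positive and every off-diagonal entry is negative.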
2. Neural Networks and Fully Connected Layers
2.1 Entropy
Cross entropy measures how far a predicted distribution q is from the true distribution p: H(p, q) = -sum over x of p(x) * log q(x). With one-hot classification labels this reduces to -log q(correct class), so confident wrong predictions are punished heavily.
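In PyTorch, F.cross_entropy / nn.CrossEntropyLoss apply log_softmax and nll_loss internally, so they must be fed raw logits rather than softmax outputs. A minimal sketch (with made-up logits and labels) verifying the equivalence:
import torch
import torch.nn.functional as F
logits = torch.rand(4, 10) # hypothetical raw network outputs for a batch of 4
target = torch.tensor([1, 0, 3, 9]) # hypothetical class indices
loss1 = F.cross_entropy(logits, target) # log_softmax + nll_loss in one call
loss2 = F.nll_loss(F.log_softmax(logits, dim=1), target)
print(loss1, loss2) # the two values match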
2.2 Multi-class Classification
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
w1, b1 = torch.randn(200, 784, requires_grad=True), \
         torch.zeros(200, requires_grad=True)
w2, b2 = torch.randn(200, 200, requires_grad=True), \
         torch.zeros(200, requires_grad=True)
w3, b3 = torch.randn(10, 200, requires_grad=True), \
         torch.zeros(10, requires_grad=True)
torch.nn.init.kaiming_normal_(w1)
torch.nn.init.kaiming_normal_(w2)
torch.nn.init.kaiming_normal_(w3)
def forward(x):
    x = x @ w1.t() + b1
    x = F.relu(x)
    x = x @ w2.t() + b2
    x = F.relu(x)
    x = x @ w3.t() + b3
    x = F.relu(x) # this final activation is optional: CrossEntropyLoss takes raw logits
    return x
optimizer = torch.optim.SGD([w1,b1,w2,b2,w3,b3], lr = 1e-2)
criteon = nn.CrossEntropyLoss()
epoches = 10
batch_size = 200
mnist_train = datasets.MNIST('../data', train=True, download=True, transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))]))
mnist_val = datasets.MNIST('../data', train=False, download=True, transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))]))
train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(mnist_val, batch_size=batch_size, shuffle=False)
for epoch in range(epoches):
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.view(-1, 28*28)          # flatten images to (N, 784)
        logits = forward(data)
        loss = criteon(logits, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('Train epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx*len(data), len(train_loader.dataset),
                100. * batch_idx/len(train_loader), loss.item()))
    test_loss = 0
    total_correct = 0
    for data, target in val_loader:
        data = data.view(-1, 28*28)
        logits = forward(data)               # (N, classes)
        loss = criteon(logits, target)
        test_loss += loss.item()
        pred = logits.argmax(dim=1)
        total_correct += pred.eq(target).sum().item()
    test_loss /= len(val_loader.dataset)
    accuracy = total_correct / len(val_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, total_correct, len(val_loader.dataset), 100. * accuracy))
''' ... Train epoch: 7 [0/60000 (0%)] Loss: 0.188797 Train epoch: 7 [20000/60000 (33%)] Loss: 0.157730 Train epoch: 7 [40000/60000 (67%)] Loss: 0.153730 Test set: Average loss: 0.0008, Accuracy: 9513/10000 (95%) Train epoch: 8 [0/60000 (0%)] Loss: 0.242635 Train epoch: 8 [20000/60000 (33%)] Loss: 0.092858 Train epoch: 8 [40000/60000 (67%)] Loss: 0.165861 Test set: Average loss: 0.0008, Accuracy: 9540/10000 (95%) Train epoch: 9 [0/60000 (0%)] Loss: 0.099372 Train epoch: 9 [20000/60000 (33%)] Loss: 0.118166 Train epoch: 9 [40000/60000 (67%)] Loss: 0.155070 Test set: Average loss: 0.0007, Accuracy: 9556/10000 (96%) '''
2.3 Fully Connected Layers
import torch
import torch.nn as nn

class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(784, 200),
            nn.ReLU(inplace=True),
            nn.Linear(200, 200),
            nn.ReLU(inplace=True),
            nn.Linear(200, 10),
            nn.ReLU(inplace=True),  # optional, as in 2.2: CrossEntropyLoss takes raw logits
        )

    def forward(self, x):
        return self.model(x)

mlp = MLP()
optimizer = torch.optim.SGD(mlp.parameters(), lr=1e-2)
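To plug this module into the training loop from section 2.2 (a minimal sketch, reusing criteon and train_loader from that section), only the forward call changes:
for data, target in train_loader:
    data = data.view(-1, 28*28)
    logits = mlp(data)              # the nn.Module replaces the manual forward()
    loss = criteon(logits, target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()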
2.4 Activation Functions and GPU Acceleration
2.4.1 Leaky ReLU
Leaky ReLU keeps a small slope for negative inputs instead of zeroing them: LeakyReLU(x) = x for x >= 0 and a*x for x < 0 (PyTorch's default negative_slope a is 0.01). This prevents "dead" neurons whose gradient is always 0.
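A minimal sketch on the same toy input used for the earlier activations:
import torch
import torch.nn.functional as F
a = torch.linspace(-1, 1, 10)
b = F.leaky_relu(a, negative_slope=0.02)  # negative inputs are scaled by 0.02 instead of clipped to 0
print(b)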
2.4.2 SELU
SELU scales an exponential linear unit by fixed constants: SELU(x) = scale * (x if x > 0 else alpha * (e^x - 1)), with alpha ≈ 1.6733 and scale ≈ 1.0507, chosen so that activations tend to self-normalize across layers.
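A minimal sketch:
import torch
import torch.nn.functional as F
a = torch.linspace(-1, 1, 10)
b = F.selu(a)  # the alpha/scale constants are built in
print(b)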
2.4.3 softplus
Softplus is a smooth approximation of ReLU: softplus(x) = (1/beta) * log(1 + e^(beta*x)), with beta defaulting to 1. Unlike ReLU it is differentiable everywhere, including at 0.
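A minimal sketch:
import torch
import torch.nn.functional as F
a = torch.linspace(-1, 1, 10)
b = F.softplus(a)  # log(1 + exp(x)) with the default beta=1
print(b)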
2.4.4 GPU Acceleration
Moving computation to the GPU only requires sending the module, the loss, and each batch to the same device with .to(device); gradients are then computed on the GPU automatically.
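A minimal sketch, reusing the MLP from 2.3 (the .to(device) calls are the only change to the earlier loops):
import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
mlp = MLP().to(device)                            # move the parameters to the GPU
criteon = torch.nn.CrossEntropyLoss().to(device)
# inside the training loop, each batch must be moved as well:
# data, target = data.to(device), target.to(device)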