当前位置:网站首页>深度学习框架pytorch快速开发与实战chapter3
深度学习框架pytorch快速开发与实战chapter3
2022-08-02 14:02:00 【weixin_50862344】
报错
问题1
出现这种情况的主要原因是环境中有两个libiomp5md.dll文件
①一般环境都会放在conda文件夹的envs下
②或者是直接cmd下输入以下代码
conda info --envs
找到相应环境然后去掉下面那个就行
问题2:data[0]报错
以下代码均以修正!!!
将data[0]改成item()就可以了
if (epoch+1) % 5 == 0:
# 修改后
print ('Epoch [%d/%d], Loss: %.4f' %(epoch+1, num_epochs, loss.item()))
#修改前:
#% (epoch + 1, num_epochs, loss.data[0]))
(一) 线性回归
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from torch.autograd import Variable
# Hyper Parameters
input_size = 1
output_size = 1
num_epochs = 1000
learning_rate = 0.001
x_train = np.array([[2.3], [4.4], [3.7], [6.1], [7.3], [2.1],[5.6], [7.7], [8.7], [4.1],
[6.7], [6.1], [7.5], [2.1], [7.2],
[5.6], [5.7], [7.7], [3.1]], dtype=np.float32)
#xtrain生成矩阵数据
y_train = np.array([[3.7], [4.76], [4.], [7.1], [8.6], [3.5],[5.4], [7.6], [7.9], [5.3],
[7.3], [7.5], [8.5], [3.2], [8.7],
[6.4], [6.6], [7.9], [5.3]], dtype=np.float32)
plt.figure()
#画图散点图
plt.scatter(x_train,y_train)
plt.xlabel('x_train')
#x轴名称
plt.ylabel('y_train')
#y轴名称
#显示图片
plt.show()
# Linear Regression Model
class LinearRegression(nn.Module):
def __init__(self, input_size, output_size):
super(LinearRegression, self).__init__()
self.linear = nn.Linear(input_size, output_size)
def forward(self, x):
out = self.linear(x)
return out
model = LinearRegression(input_size, output_size)
# Loss and Optimizer
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
# Train the Model
for epoch in range(num_epochs):
# Convert numpy array to torch Variable
inputs = Variable(torch.from_numpy(x_train))
targets = Variable(torch.from_numpy(y_train))
# Forward + Backward + Optimize
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
if (epoch+1) % 5 == 0:
# 修改后
print ('Epoch [%d/%d], Loss: %.4f' %(epoch+1, num_epochs, loss.item()))
#修改前:
#% (epoch + 1, num_epochs, loss.data[0]))
# Plot the graph
model.eval()
predicted = model(Variable(torch.from_numpy(x_train))).data.numpy()
plt.plot(x_train, y_train, 'ro')
plt.plot(x_train, predicted, label='predict')
plt.legend()
plt.show()
总结一下
- 准备数据集
- 定义模型
- 定义损失函数和优化函数
- 开始训练
optimizer.zero_grad() #梯度归零
loss.backward() #反向传播
optimizer.step() #更新参数
绘制散点图基本步骤
- 激活一个绘图环境
plt.figure()
为了创建一个图形figure,或者激活一个已经存在的图形figure
- 绘制散点图
plt.scatter()函数用于生成一个scatter散点图
- x轴,y轴名称
- 显示图片
模型构建
class LinearRegression(nn.Module):
def __init__(self, input_size, output_size):
super(LinearRegression, self).__init__()
为什么要super(…).init()
为了是子类初始化之后也能继承父类的属性
损失函数
nn.MSELoss()
看这个博主的!!!介绍了常用的损失函数
优化函数
Variable
(二)逻辑回归
import torch
import torch.nn as nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torch.autograd import Variable
# Hyper Parameters
input_size = 784
num_classes = 10
num_epochs = 10
batch_size = 50
learning_rate = 0.001
# MNIST Dataset (Images and Labels)
train_dataset = dsets.MNIST(root='./data',
train=True,
transform=transforms.ToTensor(),
download=True)
test_dataset = dsets.MNIST(root='./data',
train=False,
transform=transforms.ToTensor())
# Dataset Loader (Input Pipline)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
batch_size=batch_size,
shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
batch_size=batch_size,
shuffle=False)
# Model
class LogisticRegression(nn.Module):
def __init__(self, input_size, num_classes):
super(LogisticRegression, self).__init__()
self.linear = nn.Linear(input_size, num_classes)
def forward(self, x):
out = self.linear(x)
return out
model = LogisticRegression(input_size, num_classes)
# Loss and Optimizer
# Softmax is internally computed.
# Set parameters to be updated.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
# Training the Model
for epoch in range(num_epochs):
for i, (images, labels) in enumerate(train_loader):
images = Variable(images.view(-1, 28*28))
labels = Variable(labels)
# Forward + Backward + Optimize
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
if (i+1) % 100 == 0:
print ('Epoch: [%d/%d], Step: [%d/%d], Loss: %.4f'
% (epoch+1, num_epochs, i+1, len(train_dataset)//batch_size, loss.item()))
# Test the Model
correct = 0
total = 0
for images, labels in test_loader:
images = Variable(images.view(-1, 28*28))
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum()
print('Accuracy of the model on the 10000 test images: %d %%' % (100 * correct / total))
# Save the Model
torch.save(model.state_dict(), 'model.pkl')
第一次下载会先下载文件会下载到当前文件目录下的data文件夹中
然后开始训练(大概2-3min)
浅析一下代码
其实逻辑上和线性回归差不了多少,无非是改了一下损失函数(换成了交叉熵)和优化函数(换成了随机梯度下降)
transforms.ToTensor():接受PIL Image或numpy.ndarray格式,修改通道顺序,修改数据类型最后归一化(除255)
参考表述
边栏推荐
- 网络安全第四次作业
- The world's largest Apache open source foundation is how it works?
- 【Tensorflow】AttributeError: module 'keras.backend' has no attribute 'tf'
- 文件加密软件有哪些?保障你的文件安全
- RowBounds[通俗易懂]
- 网页设计(新手入门)[通俗易懂]
- 政策利空对行情没有长期影响,牛市仍将继续 2021-05-19
- 第三单元 视图层
- You can't accept 60% slump, there is no eligible for gain of 6000% in 2021-05-27
- 不精确微分/不完全微分(Inexact differential/Imperfect differential)
猜你喜欢
苏州大学:从 PostgreSQL 到 TDengine
专访|带着问题去学习,Apache DolphinScheduler 王福政
瑞吉外卖笔记——第10讲Swagger
Mysql's case the when you how to use
shell脚本“画画”
About the development forecast of the market outlook?2021-05-23
The bad policy has no long-term impact on the market, and the bull market will continue 2021-05-19
Sentinel源码(六)ParamFlowSlot热点参数限流
第十四单元 视图集及路由
世界上最大的开源基金会 Apache 是如何运作的?
随机推荐
第十二单元 关联序列化处理
HALCON: 对象(object)从声明(declaration)到结束(finalization)
【Tensorflow】AttributeError: module 'keras.backend' has no attribute 'tf'
What is the difference between web testing and app testing?
drf源码分析与全局捕获异常
logback源码阅读(一)获取ILoggerFactory、Logger
Configure zabbix auto-discovery and auto-registration.
Cloin 控制台乱码
C# 编译错误:Compiler Error CS1044
jwt(json web token)
如何解决mysql服务无法启动1069
RHCE第一天作业
史上最全!47个“数字化转型”常见术语合集,看完秒懂~
C language improvement (3)
Linux:CentOS 7 安装MySQL5.7
此次519暴跌的几点感触 2021-05-21
面试官:可以谈谈乐观锁和悲观锁吗
RKMPP API安装使用总结
Detailed explanation of ORACLE expdp/impdp
第十单元 前后连调