Deep Learning Framework PyTorch: Rapid Development and Practice, Chapter 3
2022-08-02 14:18:00 【weixin_50862344】
Errors
Problem 1

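The main cause is that the environment contains two copies of libiomp5md.dll (the Intel OpenMP runtime library). To fix it, first locate the environment:
① Conda environments are usually stored in the envs folder of the conda installation directory;
② or run the following command in cmd to list them:
conda info --envs
Find the corresponding environment and delete the duplicate libiomp5md.dll in it.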
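If it is unclear where the two copies live, a short Python script can list them. This is a sketch that assumes it is run with the Python interpreter of the affected conda environment, so that `sys.prefix` points at that environment's folder; `find_files` is a hypothetical helper, not part of any library. It only prints paths and deletes nothing.

```python
import os
import sys


def find_files(root, filename):
    """Return every path under `root` whose file name equals `filename` (case-insensitive)."""
    hits = []
    for dirpath, _dirnames, filenames in os.walk(root):
        for name in filenames:
            if name.lower() == filename.lower():
                hits.append(os.path.join(dirpath, name))
    return hits


if __name__ == "__main__":
    # Assumption: this interpreter belongs to the affected conda environment,
    # so sys.prefix is that environment's folder (e.g. ...\envs\<env-name>).
    for path in find_files(sys.prefix, "libiomp5md.dll"):
        print(path)
```

If two paths are printed, those are the two copies described above.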
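Problem 2: data[0] raises an error
The code below has already been corrected.
Simply replace loss.data[0] with loss.item(). Since PyTorch 0.4 the loss is a 0-dimensional tensor, which cannot be indexed.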
if (epoch+1) % 5 == 0:
    # After the fix
    print('Epoch [%d/%d], Loss: %.4f' % (epoch+1, num_epochs, loss.item()))
    # Before the fix:
    # print('Epoch [%d/%d], Loss: %.4f' % (epoch+1, num_epochs, loss.data[0]))
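A minimal sketch of why the old idiom fails, assuming PyTorch 0.4 or later: the loss is a 0-dim (scalar) tensor, so there is no index 0 to read, while `.item()` converts it to a plain Python number. The tensors here are made-up toy values.

```python
import torch

# nn.MSELoss with the default reduction='mean' returns a 0-dim tensor
loss = torch.nn.MSELoss()(torch.tensor([1.0, 2.0]), torch.tensor([1.0, 4.0]))
print(loss.dim())   # 0 -> no dimension to index into
print(loss.item())  # a plain Python float: (0**2 + 2**2) / 2

try:
    loss.data[0]    # the old (pre-0.4) idiom
    old_idiom_failed = False
except IndexError as err:
    old_idiom_failed = True
    print("IndexError:", err)
```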
(1) Linear Regression
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from torch.autograd import Variable

# Hyper Parameters
input_size = 1
output_size = 1
num_epochs = 1000
learning_rate = 0.001

# x_train: training inputs as a 19 x 1 float32 matrix
x_train = np.array([[2.3], [4.4], [3.7], [6.1], [7.3], [2.1], [5.6], [7.7], [8.7], [4.1],
                    [6.7], [6.1], [7.5], [2.1], [7.2],
                    [5.6], [5.7], [7.7], [3.1]], dtype=np.float32)
y_train = np.array([[3.7], [4.76], [4.], [7.1], [8.6], [3.5], [5.4], [7.6], [7.9], [5.3],
                    [7.3], [7.5], [8.5], [3.2], [8.7],
                    [6.4], [6.6], [7.9], [5.3]], dtype=np.float32)

plt.figure()
# Draw a scatter plot
plt.scatter(x_train, y_train)
plt.xlabel('x_train')  # x-axis label
plt.ylabel('y_train')  # y-axis label
plt.show()             # display the figure

# Linear Regression Model
class LinearRegression(nn.Module):
    def __init__(self, input_size, output_size):
        super(LinearRegression, self).__init__()
        self.linear = nn.Linear(input_size, output_size)

    def forward(self, x):
        out = self.linear(x)
        return out

model = LinearRegression(input_size, output_size)

# Loss and Optimizer
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

# Train the Model
for epoch in range(num_epochs):
    # Convert numpy array to torch Variable
    inputs = Variable(torch.from_numpy(x_train))
    targets = Variable(torch.from_numpy(y_train))

    # Forward + Backward + Optimize
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()

    if (epoch+1) % 5 == 0:
        # After the fix
        print('Epoch [%d/%d], Loss: %.4f' % (epoch+1, num_epochs, loss.item()))
        # Before the fix:
        # print('Epoch [%d/%d], Loss: %.4f' % (epoch+1, num_epochs, loss.data[0]))

# Plot the graph
model.eval()
predicted = model(Variable(torch.from_numpy(x_train))).data.numpy()
plt.plot(x_train, y_train, 'ro')
plt.plot(x_train, predicted, label='predict')
plt.legend()
plt.show()
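Summary
- Prepare the dataset
- Define the model
- Define the loss function and the optimizer
- Train
optimizer.zero_grad()  # reset the gradients to zero
loss.backward()        # backpropagation: compute the gradients
optimizer.step()       # update the parameters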
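To see why `optimizer.zero_grad()` comes first, here is a toy sketch with a single made-up parameter `w` and the loss 2·w (whose gradient is always 2). `backward()` adds new gradients to whatever is already stored in `.grad`, so skipping `zero_grad()` would mix gradients from different iterations.

```python
import torch

# One trainable scalar parameter w, starting at 1.0
w = torch.tensor(1.0, requires_grad=True)
optimizer = torch.optim.SGD([w], lr=0.1)

# Toy loss L(w) = 2 * w, so dL/dw = 2 on every call
(2 * w).backward()
first = w.grad.item()        # 2.0
(2 * w).backward()           # no zero_grad() in between
accumulated = w.grad.item()  # 4.0: the new gradient was ADDED to the old one

optimizer.zero_grad()        # clears .grad (None on PyTorch >= 2.0, zeros on older versions)
(2 * w).backward()
fresh = w.grad.item()        # 2.0 again
optimizer.step()             # w <- w - lr * grad = 1.0 - 0.1 * 2.0

print(first, accumulated, fresh, w.item())
```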
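Basic steps for drawing a scatter plot
- Create a drawing canvas: plt.figure() creates a new figure, or activates an existing one when given its number
- Draw the scatter plot: plt.scatter() produces a scatter plot
- Set the x-axis and y-axis labels: plt.xlabel() and plt.ylabel()
- Display the figure: plt.show()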
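Model construction
class LinearRegression(nn.Module):
    def __init__(self, input_size, output_size):
        super(LinearRegression, self).__init__()
Why call super(...).__init__()?
So that the parent class nn.Module runs its own initialization before the subclass adds layers. nn.Module.__init__ sets up the internal bookkeeping (for example, the dictionaries that register parameters and submodules); without it, assigning self.linear fails and the layer's parameters would never be registered with the model.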
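A minimal sketch showing what happens when the call is left out (the class names `Good` and `Bad` are hypothetical). In Python 3, `super().__init__()` is equivalent to the `super(LinearRegression, self).__init__()` form used above.

```python
import torch.nn as nn


class Good(nn.Module):
    def __init__(self):
        super().__init__()              # nn.Module sets up its registries first
        self.linear = nn.Linear(1, 1)   # registered as a submodule


class Bad(nn.Module):
    def __init__(self):
        # super().__init__() deliberately omitted
        self.linear = nn.Linear(1, 1)   # nn.Module.__setattr__ raises here


n_params = len(list(Good().parameters()))
print(n_params)  # weight and bias of the Linear layer

try:
    Bad()
    bad_failed = False
except AttributeError as err:
    bad_failed = True
    print("AttributeError:", err)
```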
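Loss function
nn.MSELoss()
nn.MSELoss() computes the mean squared error between the model outputs and the targets. Check out this blogger's post, which introduces the common loss functions.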
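As a quick check of what `nn.MSELoss()` computes with its default `reduction='mean'`, the sketch below compares it with the formula (1/n) * sum((output_i - target_i)^2) written out by hand; the tensors are toy values.

```python
import torch
import torch.nn as nn

outputs = torch.tensor([[2.5], [0.0], [2.0]])
targets = torch.tensor([[3.0], [-0.5], [2.0]])

criterion = nn.MSELoss()                    # reduction='mean' by default
loss = criterion(outputs, targets)
manual = ((outputs - targets) ** 2).mean()  # (1/n) * sum of squared errors

print(loss.item(), manual.item())  # both (0.25 + 0.25 + 0) / 3
```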
Optimizer
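Both examples in this chapter use `torch.optim.SGD`. Without momentum or weight decay, one `optimizer.step()` performs the update p <- p - lr * p.grad for every parameter p. The sketch below checks this on a toy `nn.Linear(1, 1)` with randomly generated data; the data, seed, and learning rate are arbitrary choices.

```python
import torch

torch.manual_seed(0)
x = torch.randn(8, 1)
y = 3 * x + 1

lin_a = torch.nn.Linear(1, 1)
lin_b = torch.nn.Linear(1, 1)
lin_b.load_state_dict(lin_a.state_dict())  # identical starting weights

lr = 0.01
loss_fn = torch.nn.MSELoss()

# (a) one step with torch.optim.SGD
opt = torch.optim.SGD(lin_a.parameters(), lr=lr)
opt.zero_grad()
loss_fn(lin_a(x), y).backward()
opt.step()

# (b) the same step written out by hand: p <- p - lr * p.grad
loss_fn(lin_b(x), y).backward()
with torch.no_grad():
    for p in lin_b.parameters():
        p -= lr * p.grad

same = all(torch.allclose(pa, pb) for pa, pb in zip(lin_a.parameters(), lin_b.parameters()))
print(same)
```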
Variable
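Since PyTorch 0.4, `Variable` and `Tensor` have been merged: `Variable(tensor)` still runs but simply returns a tensor, so the wrapping in the code above is no longer necessary, and gradient tracking is requested with `requires_grad=True` on the tensor itself. A small sketch, assuming a recent PyTorch version; the arrays are toy values.

```python
import numpy as np
import torch
from torch.autograd import Variable

x_np = np.array([[1.0], [2.0]], dtype=np.float32)

old_style = Variable(torch.from_numpy(x_np))  # legacy wrapper, as in the book's code
new_style = torch.from_numpy(x_np)            # modern equivalent: a plain tensor

print(isinstance(old_style, torch.Tensor), torch.equal(old_style, new_style))

# Gradient tracking is now a property of the tensor itself
w = torch.tensor([2.0], requires_grad=True)
(w * 3).sum().backward()
print(w.grad)  # d(3w)/dw = 3
```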
(2) Logistic Regression
import torch
import torch.nn as nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torch.autograd import Variable

# Hyper Parameters
input_size = 784
num_classes = 10
num_epochs = 10
batch_size = 50
learning_rate = 0.001

# MNIST Dataset (Images and Labels)
train_dataset = dsets.MNIST(root='./data',
                            train=True,
                            transform=transforms.ToTensor(),
                            download=True)
test_dataset = dsets.MNIST(root='./data',
                           train=False,
                           transform=transforms.ToTensor())

# Dataset Loader (Input Pipeline)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False)

# Model
class LogisticRegression(nn.Module):
    def __init__(self, input_size, num_classes):
        super(LogisticRegression, self).__init__()
        self.linear = nn.Linear(input_size, num_classes)

    def forward(self, x):
        out = self.linear(x)
        return out

model = LogisticRegression(input_size, num_classes)

# Loss and Optimizer
# Softmax is internally computed.
# Set parameters to be updated.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

# Training the Model
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        images = Variable(images.view(-1, 28*28))
        labels = Variable(labels)

        # Forward + Backward + Optimize
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        if (i+1) % 100 == 0:
            print('Epoch: [%d/%d], Step: [%d/%d], Loss: %.4f'
                  % (epoch+1, num_epochs, i+1, len(train_dataset)//batch_size, loss.item()))

# Test the Model
model.eval()
correct = 0
total = 0
with torch.no_grad():  # no gradients are needed for evaluation
    for images, labels in test_loader:
        images = Variable(images.view(-1, 28*28))
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()  # .item() keeps correct a Python int

print('Accuracy of the model on the 10000 test images: %d %%' % (100 * correct / total))

# Save the Model
torch.save(model.state_dict(), 'model.pkl')
On the first run, the MNIST files are downloaded into a data folder under the current working directory (root='./data' is a relative path).
Training then starts and takes roughly 2-3 minutes.
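Let's analyze the code.
In fact, the logic is not much different from linear regression. The main change is the loss function, which is now cross-entropy (nn.CrossEntropyLoss) instead of mean squared error, and the model outputs 10 scores (one per digit class) instead of a single value. The optimizer is still torch.optim.SGD, but it is now fed mini-batches of 50 images from a DataLoader rather than the whole dataset at once, so here it really is stochastic gradient descent.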
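The comment `# Softmax is internally computed.` in the code refers to the fact that `nn.CrossEntropyLoss` combines log-softmax and negative log-likelihood, which is why the model's `forward` returns raw scores without a softmax layer. The sketch below checks this on random toy logits; the batch size of 4 and the labels are arbitrary.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10)           # raw scores: a batch of 4 "images", 10 classes
labels = torch.tensor([3, 0, 9, 1])   # class indices, like MNIST labels

ce = nn.CrossEntropyLoss()(logits, labels)

# Same thing in two explicit steps: log-softmax, then the mean negative log-probability of the true class
log_probs = F.log_softmax(logits, dim=1)
manual = -log_probs[torch.arange(4), labels].mean()

print(ce.item(), manual.item())
```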
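transforms.ToTensor() accepts a PIL Image or a numpy.ndarray in H x W x C layout, reorders the channels to C x H x W, converts the data to a float tensor, and finally normalizes the values to [0, 1] by dividing by 255 (for uint8 input).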