Deep Learning Framework PyTorch: Rapid Development and Practice, Chapter 3
2022-08-02 14:18:00 【weixin_50862344】
Errors
Problem 1

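The main cause of this error is that the environment contains two copies of libiomp5md.dll. To find the environment:
① Conda environments are usually located in the envs folder of the conda installation directory.
② Alternatively, run the following command in cmd to list the environments and their paths:
conda info --envs
Locate the environment in question and delete one of the two copies of the file (the author removed the second one found).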
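Before deleting anything, it helps to list every copy of the DLL inside the environment folder. The sketch below does this. `find_dll_copies` is a hypothetical helper rather than a conda or PyTorch tool, and the environment path in the usage comment is only a placeholder.

```python
import os

def find_dll_copies(env_dir, name="libiomp5md.dll"):
    """Return every file under env_dir named `name` (case-insensitive), sorted.

    os.walk silently yields nothing for a missing directory, so a wrong path
    simply returns an empty list.
    """
    matches = []
    for dirpath, _dirnames, filenames in os.walk(env_dir):
        for fname in filenames:
            if fname.lower() == name.lower():
                matches.append(os.path.join(dirpath, fname))
    return sorted(matches)

# Usage (placeholder path, take the real one from `conda info --envs`):
# for path in find_dll_copies(r"C:\path\to\anaconda3\envs\your_env"):
#     print(path)
```

If this prints two paths, you have confirmed the duplicate before removing either file.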
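Problem 2: error on data[0]
The code in this post has already been corrected.
Replace loss.data[0] with loss.item(). In current PyTorch versions the loss is a 0-dim (scalar) tensor, so indexing it with [0] raises an IndexError. .item() instead returns its value as a plain Python number.
if (epoch+1) % 5 == 0:
    # after the fix
    print('Epoch [%d/%d], Loss: %.4f' % (epoch+1, num_epochs, loss.item()))
    # before the fix:
    # print('Epoch [%d/%d], Loss: %.4f' % (epoch + 1, num_epochs, loss.data[0]))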
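The small sketch below reproduces the problem. An averaged loss is a 0-dim tensor, `[0]` cannot index it, and `.item()` returns the number. The toy tensors are made up for illustration.

```python
import torch

# A loss computed with the default reduction="mean" is a 0-dim (scalar) tensor.
loss = torch.nn.MSELoss()(torch.tensor([1.0, 2.0]), torch.tensor([1.0, 4.0]))
print(loss.dim())  # 0

# The old idiom loss.data[0] tries to index a 0-dim tensor, which raises IndexError.
try:
    loss.data[0]
    raised = False
except IndexError as err:
    raised = True
    print("IndexError:", err)

# .item() converts a one-element tensor into a Python float.
value = loss.item()
print(type(value), value)
```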
(1) Linear Regression
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from torch.autograd import Variable

# Hyper Parameters
input_size = 1
output_size = 1
num_epochs = 1000
learning_rate = 0.001

# x_train: 19 samples with one feature each (a 19x1 float32 matrix)
x_train = np.array([[2.3], [4.4], [3.7], [6.1], [7.3], [2.1], [5.6], [7.7], [8.7], [4.1],
                    [6.7], [6.1], [7.5], [2.1], [7.2],
                    [5.6], [5.7], [7.7], [3.1]], dtype=np.float32)
y_train = np.array([[3.7], [4.76], [4.], [7.1], [8.6], [3.5], [5.4], [7.6], [7.9], [5.3],
                    [7.3], [7.5], [8.5], [3.2], [8.7],
                    [6.4], [6.6], [7.9], [5.3]], dtype=np.float32)

plt.figure()
# draw a scatter plot
plt.scatter(x_train, y_train)
plt.xlabel('x_train')  # x-axis label
plt.ylabel('y_train')  # y-axis label
plt.show()             # display the figure

# Linear Regression Model
class LinearRegression(nn.Module):
    def __init__(self, input_size, output_size):
        super(LinearRegression, self).__init__()
        self.linear = nn.Linear(input_size, output_size)

    def forward(self, x):
        out = self.linear(x)
        return out

model = LinearRegression(input_size, output_size)

# Loss and Optimizer
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

# Train the Model
for epoch in range(num_epochs):
    # Convert numpy array to torch Variable
    inputs = Variable(torch.from_numpy(x_train))
    targets = Variable(torch.from_numpy(y_train))

    # Forward + Backward + Optimize
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()

    if (epoch+1) % 5 == 0:
        # after the fix
        print('Epoch [%d/%d], Loss: %.4f' % (epoch+1, num_epochs, loss.item()))
        # before the fix:
        # print('Epoch [%d/%d], Loss: %.4f' % (epoch + 1, num_epochs, loss.data[0]))

# Plot the graph
model.eval()
predicted = model(Variable(torch.from_numpy(x_train))).data.numpy()
plt.plot(x_train, y_train, 'ro', label='original data')
plt.plot(x_train, predicted, label='predict')
plt.legend()
plt.show()
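A quick way to check the trained line is to compare it with the exact least-squares fit, which minimizes the same mean squared error in closed form. The sketch below uses NumPy's `lstsq`. `least_squares_line` is a hypothetical helper, not part of the original code. With lr=0.001 and 1000 epochs, SGD may not have fully converged, so expect the two fits to be close rather than identical.

```python
import numpy as np

def least_squares_line(x, y):
    """Return (slope, intercept) minimizing mean((slope * x + intercept - y) ** 2)."""
    x = np.asarray(x, dtype=np.float64).ravel()
    y = np.asarray(y, dtype=np.float64).ravel()
    A = np.column_stack([x, np.ones_like(x)])  # design matrix with columns [x, 1]
    (slope, intercept), *_ = np.linalg.lstsq(A, y, rcond=None)
    return slope, intercept

# Usage with the arrays and model from the script above:
# slope, intercept = least_squares_line(x_train, y_train)
# print(slope, intercept)
# print(model.linear.weight.item(), model.linear.bias.item())
```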
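To sum up:
- Prepare the dataset
- Define the model
- Define the loss function and the optimizer
- Train the model:
optimizer.zero_grad()  # reset the gradients to zero
loss.backward()        # backpropagation
optimizer.step()       # update the parameters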
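Why zero the gradients first? In PyTorch, `backward()` adds new gradients to whatever is already stored in `.grad`. The one-parameter toy model below is an illustrative assumption, not taken from the book. Calling `.grad.zero_()` is roughly what `optimizer.zero_grad()` does for each parameter, although recent PyTorch versions set the gradient to `None` by default instead of filling it with zeros.

```python
import torch

w = torch.tensor(3.0, requires_grad=True)
x, y = 2.0, 1.0

def compute_loss():
    # squared error of the one-parameter model w * x
    return (w * x - y) ** 2

compute_loss().backward()
first = w.grad.item()        # d/dw (wx - y)^2 = 2x(wx - y) = 2 * 2 * 5 = 20

compute_loss().backward()
accumulated = w.grad.item()  # 40: the second gradient was added to the first

w.grad.zero_()               # reset, as zero_grad() does before each step
compute_loss().backward()
fresh = w.grad.item()        # 20 again
print(first, accumulated, fresh)
```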
Basic steps to draw a scatter plot
- Activate a drawing surface: plt.figure() creates a new figure, or activates an existing one
- Draw the scatter plot: plt.scatter() generates a scatter plot
- Label the x and y axes: plt.xlabel() and plt.ylabel()
- Show the figure: plt.show()
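The same four steps can also be written with Matplotlib's object-oriented interface. The sketch below swaps in two choices of its own: it saves to a file instead of calling `plt.show()`, and it selects the non-interactive Agg backend so it runs without a display. The file name `scatter.png` is arbitrary, and the points are the first five samples of x_train and y_train.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend: no window needed
import matplotlib.pyplot as plt
import numpy as np

x = np.array([2.3, 4.4, 3.7, 6.1, 7.3])
y = np.array([3.7, 4.76, 4.0, 7.1, 8.6])

fig, ax = plt.subplots()    # step 1: create a figure and an Axes to draw on
ax.scatter(x, y)            # step 2: scatter plot
ax.set_xlabel("x_train")    # step 3: axis labels
ax.set_ylabel("y_train")
fig.savefig("scatter.png")  # step 4: write the figure to disk instead of showing it
plt.close(fig)
```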
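Building the model
class LinearRegression(nn.Module):
    def __init__(self, input_size, output_size):
        super(LinearRegression, self).__init__()
Why call super(…).__init__()?
It runs the initializer of the parent class nn.Module, so the subclass inherits the parent's attributes once it is initialized. nn.Module sets up its internal bookkeeping for parameters and submodules there. Without that call, assigning a layer such as self.linear in the subclass fails.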
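The sketch below shows what happens when the call is left out. Both classes are illustrative, not from the book. In Python 3, `super().__init__()` is the shorter equivalent of `super(ClassName, self).__init__()`.

```python
import torch.nn as nn

class Broken(nn.Module):
    def __init__(self):
        # super().__init__() is missing, so nn.Module's registries do not exist yet
        self.linear = nn.Linear(1, 1)

class Fixed(nn.Module):
    def __init__(self):
        super().__init__()  # same as super(Fixed, self).__init__()
        self.linear = nn.Linear(1, 1)

try:
    Broken()
    broken_failed = False
except AttributeError as err:
    broken_failed = True
    print("AttributeError:", err)

n_params = len(list(Fixed().parameters()))  # weight and bias of nn.Linear
print(n_params)
```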
Loss function

nn.MSELoss()
nn.MSELoss() computes the mean squared error between the model outputs and the targets. For an introduction to common loss functions, see this blogger's post.
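With the default `reduction="mean"`, `nn.MSELoss` computes $\frac{1}{n}\sum_i (\hat y_i - y_i)^2$. The sketch below checks this against a manual computation on three made-up values.

```python
import torch
import torch.nn as nn

pred = torch.tensor([[2.5], [0.0], [2.0]])
target = torch.tensor([[3.0], [-0.5], [2.0]])

mse = nn.MSELoss()(pred, target)        # default reduction="mean"
manual = ((pred - target) ** 2).mean()  # (0.25 + 0.25 + 0) / 3
print(mse.item(), manual.item())
```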
Optimizer
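Both models in this chapter use `torch.optim.SGD`. Without momentum or weight decay, its update is simply $\theta \leftarrow \theta - \text{lr} \cdot \nabla_\theta L$. The sketch below checks one `step()` against that formula, using a toy parameter and the loss $\sum p^2$ as illustrative choices.

```python
import torch

lr = 0.1
p = torch.nn.Parameter(torch.tensor([1.0, -2.0]))
opt = torch.optim.SGD([p], lr=lr)

loss = (p ** 2).sum()  # gradient is 2p = [2.0, -4.0]
opt.zero_grad()
loss.backward()
expected = (p - lr * p.grad).detach()  # plain SGD update computed by hand
opt.step()
print(p.detach(), expected)            # both approximately [0.8, -1.6]
```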
Variable
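Since PyTorch 0.4, `Variable` and `Tensor` have been merged. `Variable(t)` still works for old code, but it just returns a tensor, and autograd is switched on with `requires_grad`. A small sketch:

```python
import torch
from torch.autograd import Variable

t = torch.ones(2, 2)
v = Variable(t)                        # legacy wrapper: the result is an ordinary tensor
v2 = Variable(t, requires_grad=True)   # legacy way to request gradients
t2 = torch.ones(2, 2, requires_grad=True)  # modern equivalent without Variable
print(type(v), isinstance(v, torch.Tensor), v2.requires_grad, t2.requires_grad)
```

In new code, `torch.from_numpy(x_train)` can therefore be passed to the model directly without wrapping it in `Variable`.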
(2) Logistic Regression
import torch
import torch.nn as nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torch.autograd import Variable

# Hyper Parameters
input_size = 784
num_classes = 10
num_epochs = 10
batch_size = 50
learning_rate = 0.001

# MNIST Dataset (Images and Labels)
train_dataset = dsets.MNIST(root='./data',
                            train=True,
                            transform=transforms.ToTensor(),
                            download=True)
test_dataset = dsets.MNIST(root='./data',
                           train=False,
                           transform=transforms.ToTensor())

# Dataset Loader (Input Pipeline)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False)

# Model
class LogisticRegression(nn.Module):
    def __init__(self, input_size, num_classes):
        super(LogisticRegression, self).__init__()
        self.linear = nn.Linear(input_size, num_classes)

    def forward(self, x):
        out = self.linear(x)
        return out

model = LogisticRegression(input_size, num_classes)

# Loss and Optimizer
# Softmax is computed internally by CrossEntropyLoss.
# Set parameters to be updated.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

# Training the Model
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        images = Variable(images.view(-1, 28*28))  # flatten each 1x28x28 image to 784 values
        labels = Variable(labels)

        # Forward + Backward + Optimize
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        if (i+1) % 100 == 0:
            print('Epoch: [%d/%d], Step: [%d/%d], Loss: %.4f'
                  % (epoch+1, num_epochs, i+1, len(train_dataset)//batch_size, loss.item()))

# Test the Model
model.eval()
correct = 0
total = 0
with torch.no_grad():  # gradients are not needed for evaluation
    for images, labels in test_loader:
        images = images.view(-1, 28*28)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)           # highest score = predicted class
        total += labels.size(0)
        correct += (predicted == labels).sum().item()  # accumulate as a Python int
print('Accuracy of the model on the 10000 test images: %d %%' % (100 * correct / total))

# Save the Model
torch.save(model.state_dict(), 'model.pkl')
The first run downloads the MNIST files into a data folder in the current directory.
Training then starts (about 2-3 min).
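The script ends by saving only the weights (`state_dict`) to model.pkl. To reuse them, rebuild the same class and load the weights back in. The sketch below writes to the system temp directory instead of the current folder so it can run on its own. It uses the same `LogisticRegression` definition as above.

```python
import os
import tempfile
import torch
import torch.nn as nn

class LogisticRegression(nn.Module):
    def __init__(self, input_size, num_classes):
        super().__init__()
        self.linear = nn.Linear(input_size, num_classes)

    def forward(self, x):
        return self.linear(x)

model = LogisticRegression(784, 10)
path = os.path.join(tempfile.gettempdir(), "model.pkl")  # the article saves to ./model.pkl
torch.save(model.state_dict(), path)

restored = LogisticRegression(784, 10)        # same architecture, fresh random weights
restored.load_state_dict(torch.load(path))    # copy the saved weights in
restored.eval()                               # switch to inference mode before testing
```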
Let's analyze the code
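The logic is much the same as in linear regression. The main change is the loss function, which is now cross entropy (nn.CrossEntropyLoss). The optimizer is still stochastic gradient descent (torch.optim.SGD), as before. The remaining differences are that each 28×28 image is flattened into a 784-dimensional input and the linear layer outputs 10 class scores instead of one value.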
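The comment "Softmax is computed internally" in the code means that `nn.CrossEntropyLoss` applies log-softmax to the raw scores and then takes the negative log-likelihood of the correct class. This is why the model's `forward` returns plain linear outputs. The sketch below checks that equivalence on random scores standing in for a batch of 4 images.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10)           # raw scores: 4 samples, 10 classes
labels = torch.tensor([3, 0, 9, 1])   # correct class index for each sample

ce = nn.CrossEntropyLoss()(logits, labels)
manual = F.nll_loss(F.log_softmax(logits, dim=1), labels)
print(ce.item(), manual.item())
```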
transforms.ToTensor() accepts a PIL Image or a numpy.ndarray. It reorders the dimensions from H×W×C to C×H×W, converts the data to a float tensor, and finally scales it to [0, 1] (dividing by 255).
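To make these steps concrete, here is a NumPy approximation of what `ToTensor()` does to an 8-bit image. `to_tensor_like` is a hypothetical helper, not the torchvision implementation. One detail to keep in mind: the real transform only divides by 255 for 8-bit (uint8) input, while this sketch assumes uint8 throughout.

```python
import numpy as np

def to_tensor_like(img):
    """Approximate transforms.ToTensor() for a uint8 image of shape HxW or HxWxC."""
    if img.ndim == 2:                      # grayscale H x W (like MNIST): add a channel axis
        img = img[:, :, None]
    chw = np.transpose(img, (2, 0, 1))     # H x W x C  ->  C x H x W
    return chw.astype(np.float32) / 255.0  # [0, 255] integers -> [0.0, 1.0] floats

img = np.array([[0, 128], [255, 64]], dtype=np.uint8)  # a tiny 2x2 grayscale image
out = to_tensor_like(img)
print(out.shape)  # (1, 2, 2)
```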