Deep Learning Framework PyTorch: Rapid Development and Practice, Chapter 3
2022-08-02 14:18:00 【weixin_50862344】
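Errors
Problem 1
This error occurs because the environment contains two copies of libiomp5md.dll.
① Environments are usually placed in the envs folder of the conda installation directory.
② Alternatively, list them by running the following command in cmd:
conda info --envs
Find the corresponding environment folder and delete the duplicate libiomp5md.dll.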
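If it is unclear where the two copies are, a short Python script can list every libiomp5md.dll under an environment folder. find_files is a hypothetical helper written for illustration and is not a conda command. The path in env_dir is a placeholder assumption, so replace it with the path that conda info --envs prints.

```python
import os


def find_files(root, filename):
    """Return the sorted full paths of every file named `filename` under `root`."""
    matches = []
    for dirpath, _dirnames, filenames in os.walk(root):
        if filename in filenames:
            matches.append(os.path.join(dirpath, filename))
    return sorted(matches)


if __name__ == "__main__":
    # Placeholder path (assumption): use one printed by `conda info --envs`.
    env_dir = r"C:\Users\me\anaconda3\envs\pytorch"
    for path in find_files(env_dir, "libiomp5md.dll"):
        print(path)
```

If two paths are printed, those are the duplicates described above.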
Problem 2: loss.data[0] raises an error
The code below has already been corrected.
Replacing data[0] with item() fixes it.
if (epoch+1) % 5 == 0:
    # after the fix
    print('Epoch [%d/%d], Loss: %.4f' % (epoch+1, num_epochs, loss.item()))
    # before the fix:
    # % (epoch + 1, num_epochs, loss.data[0]))
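The change is needed because, starting with PyTorch 0.4, a loss such as nn.MSELoss() returns a 0-dim (scalar) tensor. A 0-dim tensor has no index 0, so loss.data[0] raises an IndexError. loss.item() returns a plain Python number instead. The minimal check below uses made-up outputs and targets:

```python
import torch
import torch.nn as nn

criterion = nn.MSELoss()
outputs = torch.tensor([[1.0], [2.0]])  # made-up predictions
targets = torch.tensor([[1.5], [2.5]])  # made-up targets
loss = criterion(outputs, targets)      # mean of (0.5**2, 0.5**2)

print(loss.dim())   # 0: the loss is a 0-dim tensor
print(loss.item())  # plain Python float

try:
    loss.data[0]    # the old PyTorch 0.3 idiom
    indexing_failed = False
except IndexError as err:
    indexing_failed = True
    print("IndexError:", err)
```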
(1) Linear Regression
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from torch.autograd import Variable
# Hyper Parameters
input_size = 1
output_size = 1
num_epochs = 1000
learning_rate = 0.001
x_train = np.array([[2.3], [4.4], [3.7], [6.1], [7.3], [2.1],[5.6], [7.7], [8.7], [4.1],
[6.7], [6.1], [7.5], [2.1], [7.2],
[5.6], [5.7], [7.7], [3.1]], dtype=np.float32)
# x_train holds the inputs as a (19, 1) float32 matrix
y_train = np.array([[3.7], [4.76], [4.], [7.1], [8.6], [3.5],[5.4], [7.6], [7.9], [5.3],
[7.3], [7.5], [8.5], [3.2], [8.7],
[6.4], [6.6], [7.9], [5.3]], dtype=np.float32)
plt.figure()
# draw a scatter plot
plt.scatter(x_train, y_train)
plt.xlabel('x_train')  # x-axis label
plt.ylabel('y_train')  # y-axis label
# show the figure
plt.show()
# Linear Regression Model
class LinearRegression(nn.Module):
    def __init__(self, input_size, output_size):
        super(LinearRegression, self).__init__()
        self.linear = nn.Linear(input_size, output_size)
    def forward(self, x):
        out = self.linear(x)
        return out
model = LinearRegression(input_size, output_size)
# Loss and Optimizer
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
# Train the Model
for epoch in range(num_epochs):
    # Convert numpy array to torch Variable
    inputs = Variable(torch.from_numpy(x_train))
    targets = Variable(torch.from_numpy(y_train))
    # Forward + Backward + Optimize
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()
    if (epoch+1) % 5 == 0:
        # after the fix
        print('Epoch [%d/%d], Loss: %.4f' % (epoch+1, num_epochs, loss.item()))
        # before the fix:
        # % (epoch + 1, num_epochs, loss.data[0]))
# Plot the graph
model.eval()
predicted = model(Variable(torch.from_numpy(x_train))).data.numpy()
plt.plot(x_train, y_train, 'ro')
plt.plot(x_train, predicted, label='predict')
plt.legend()
plt.show()
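As a sanity check, the best straight line for this data can also be computed in closed form with ordinary least squares. Minimizing MSE for a linear model is exactly a least-squares problem, so the weight and bias learned by SGD should move toward these values as training continues. With lr=0.001 and 1000 epochs they may not match exactly. The sketch below uses np.polyfit on the same data:

```python
import numpy as np

x_train = np.array([2.3, 4.4, 3.7, 6.1, 7.3, 2.1, 5.6, 7.7, 8.7, 4.1,
                    6.7, 6.1, 7.5, 2.1, 7.2, 5.6, 5.7, 7.7, 3.1], dtype=np.float32)
y_train = np.array([3.7, 4.76, 4.0, 7.1, 8.6, 3.5, 5.4, 7.6, 7.9, 5.3,
                    7.3, 7.5, 8.5, 3.2, 8.7, 6.4, 6.6, 7.9, 5.3], dtype=np.float32)

# Degree-1 polynomial fit: [slope, intercept] minimizing the squared error
slope, intercept = np.polyfit(x_train, y_train, 1)
print("closed-form fit: y = %.4f * x + %.4f" % (slope, intercept))

# After training, compare with the PyTorch model's parameters:
#   model.linear.weight.item(), model.linear.bias.item()
```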
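Summary
- Prepare the dataset
- Define the model
- Define the loss function and the optimizer
- Train
optimizer.zero_grad()  # reset the gradients to zero
loss.backward()        # backpropagation
optimizer.step()       # update the parameters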
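zero_grad() is needed because PyTorch accumulates gradients: every backward() call adds into .grad instead of overwriting it. The toy example below uses a single made-up parameter w to show this:

```python
import torch

w = torch.tensor(2.0, requires_grad=True)  # toy parameter
x = torch.tensor(3.0)

(w * x).backward()           # d(w*x)/dw = x = 3
first = w.grad.item()

(w * x).backward()           # no reset: the new gradient is added to the old one
accumulated = w.grad.item()

w.grad.zero_()               # manual reset of this one parameter's gradient
(w * x).backward()
after_reset = w.grad.item()
print(first, accumulated, after_reset)
```

optimizer.zero_grad() performs this reset for every parameter the optimizer manages. In recent PyTorch versions it sets .grad to None by default instead of filling it with zeros, and the effect on the next backward() is the same.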
Basic steps to draw a scatter plot
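- Create or activate a figure
plt.figure()
This creates a new figure, or activates an existing one when given the number of a figure that already exists.
- Draw the scatter plot
plt.scatter() draws a scatter plot of the y values against the x values.
- Set the x- and y-axis labels: plt.xlabel(), plt.ylabel()
- Show the figure: plt.show()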
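You can check the "activate an existing figure" behaviour by passing a figure number to plt.figure(). If a figure with that number already exists, it becomes the current figure again and no new one is created. This sketch assumes the non-interactive Agg backend so it runs without a display, and uses the first four points of x_train/y_train:

```python
import matplotlib
matplotlib.use("Agg")  # non-GUI backend (assumption), so no window is needed
import matplotlib.pyplot as plt

x = [2.3, 4.4, 3.7, 6.1]   # first four points of x_train
y = [3.7, 4.76, 4.0, 7.1]  # first four points of y_train

fig1 = plt.figure(1)       # create figure 1; it becomes the current figure
plt.scatter(x, y)          # draw the scatter plot into figure 1
plt.xlabel("x_train")      # x-axis label
plt.ylabel("y_train")      # y-axis label

fig2 = plt.figure(2)       # create figure 2; it is now the current figure
again = plt.figure(1)      # figure 1 exists, so it is re-activated, not recreated
print(again is fig1, plt.gcf() is fig1)
# In an interactive session, plt.show() would now display the figures.
```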
Building the model
class LinearRegression(nn.Module):
    def __init__(self, input_size, output_size):
        super(LinearRegression, self).__init__()
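Why call super(…).__init__()?
It runs nn.Module's own constructor, which sets up the internal bookkeeping the subclass relies on (the registries of parameters and submodules). Without this call the subclass inherits nn.Module's methods but not this state.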
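The two toy modules below show what goes wrong without the super call (the class names are made up for illustration). Because nn.Module.__init__ never ran, assigning an nn.Linear as an attribute raises an AttributeError:

```python
import torch.nn as nn


class WithSuper(nn.Module):
    def __init__(self):
        super(WithSuper, self).__init__()  # sets up nn.Module's internal registries
        self.linear = nn.Linear(1, 1)      # registered as a submodule


class WithoutSuper(nn.Module):
    def __init__(self):
        # nn.Module.__init__ is never called here
        self.linear = nn.Linear(1, 1)


n_params = len(list(WithSuper().parameters()))  # weight and bias
print(n_params)

try:
    WithoutSuper()
    failed = False
except AttributeError as err:
    failed = True
    print("AttributeError:", err)
```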
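Loss function
nn.MSELoss()
It computes the mean squared error between the model outputs and the targets. For other common loss functions, the author recommends another blogger's introductory article.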
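With the default reduction='mean', the loss written out is (1/n) Σ (ŷᵢ − yᵢ)². The small check below compares nn.MSELoss() with that formula on made-up values:

```python
import torch
import torch.nn as nn

outputs = torch.tensor([[2.0], [4.0], [6.0]])  # made-up predictions
targets = torch.tensor([[1.0], [5.0], [6.0]])  # made-up targets

mse = nn.MSELoss()(outputs, targets)
manual = ((outputs - targets) ** 2).mean()  # (1 + 1 + 0) / 3
print(mse.item(), manual.item())
```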
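Optimizer: torch.optim.SGD(model.parameters(), lr=learning_rate), i.e. stochastic gradient descent over all parameters returned by model.parameters().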
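Plain SGD (no momentum, no weight decay, which matches how the code above creates it) updates each parameter as p ← p − lr · p.grad. The check below compares optimizer.step() with that formula on a throwaway nn.Linear(1, 1) with made-up data:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lin = nn.Linear(1, 1)
lr = 0.1
optimizer = torch.optim.SGD(lin.parameters(), lr=lr)

x = torch.tensor([[2.0]])  # made-up input
y = torch.tensor([[1.0]])  # made-up target

optimizer.zero_grad()
loss = nn.MSELoss()(lin(x), y)
loss.backward()

# What plain SGD should produce, computed by hand before the step
expected = [(p - lr * p.grad).detach().clone() for p in lin.parameters()]
optimizer.step()

same = all(torch.allclose(p, e) for p, e in zip(lin.parameters(), expected))
print(same)
```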
Variable: torch.autograd.Variable, used in the code above to wrap the input tensors.
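Since PyTorch 0.4, Variable and Tensor have been merged. Variable(t) still runs but returns an ordinary tensor, and requires_grad can be set on tensors directly, so the Variable(...) wrappers in this chapter's code are optional. A quick check:

```python
import torch
from torch.autograd import Variable

t = torch.ones(3)
v = Variable(t)                                  # legacy wrapper, kept for old code
print(isinstance(v, torch.Tensor))               # the result is an ordinary tensor

a = Variable(torch.ones(3), requires_grad=True)  # old style
b = torch.ones(3, requires_grad=True)            # modern equivalent
print(a.requires_grad, b.requires_grad)
```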
(2) Logistic Regression
import torch
import torch.nn as nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torch.autograd import Variable
# Hyper Parameters
input_size = 784
num_classes = 10
num_epochs = 10
batch_size = 50
learning_rate = 0.001
# MNIST Dataset (Images and Labels)
train_dataset = dsets.MNIST(root='./data',
                            train=True,
                            transform=transforms.ToTensor(),
                            download=True)
test_dataset = dsets.MNIST(root='./data',
                           train=False,
                           transform=transforms.ToTensor())
# Dataset Loader (Input Pipeline)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False)
# Model
class LogisticRegression(nn.Module):
    def __init__(self, input_size, num_classes):
        super(LogisticRegression, self).__init__()
        self.linear = nn.Linear(input_size, num_classes)
    def forward(self, x):
        out = self.linear(x)
        return out
model = LogisticRegression(input_size, num_classes)
# Loss and Optimizer
# Softmax is internally computed.
# Set parameters to be updated.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
# Training the Model
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        images = Variable(images.view(-1, 28*28))
        labels = Variable(labels)
        # Forward + Backward + Optimize
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        if (i+1) % 100 == 0:
            print('Epoch: [%d/%d], Step: [%d/%d], Loss: %.4f'
                  % (epoch+1, num_epochs, i+1, len(train_dataset)//batch_size, loss.item()))
# Test the Model
correct = 0
total = 0
with torch.no_grad():  # gradients are not needed for evaluation
    for images, labels in test_loader:
        images = Variable(images.view(-1, 28*28))
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()  # plain Python int
print('Accuracy of the model on the 10000 test images: %d %%' % (100 * correct / total))
# Save the Model
torch.save(model.state_dict(), 'model.pkl')
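torch.save(model.state_dict(), ...) stores only the parameters and not the class. To reuse the model, rebuild the architecture first and then load the weights into it. The sketch below demonstrates this with an untrained model and a temporary file (both assumptions made for the demo) in place of the trained model and 'model.pkl':

```python
import os
import tempfile

import torch
import torch.nn as nn


class LogisticRegression(nn.Module):
    def __init__(self, input_size, num_classes):
        super(LogisticRegression, self).__init__()
        self.linear = nn.Linear(input_size, num_classes)

    def forward(self, x):
        return self.linear(x)


model = LogisticRegression(784, 10)
path = os.path.join(tempfile.mkdtemp(), "model.pkl")  # temporary location for this demo
torch.save(model.state_dict(), path)                  # parameters only

restored = LogisticRegression(784, 10)                # rebuild the architecture first
restored.load_state_dict(torch.load(path))            # then load the saved weights
restored.eval()

x = torch.rand(1, 784)  # a fake flattened 28x28 image
with torch.no_grad():
    same = torch.equal(model(x), restored(x))
print(same)
```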
On the first run, MNIST is downloaded into a data folder in the current working directory (root='./data').
Training then starts (about 2-3 min).
Let's analyze the code
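The logic is essentially the same as in linear regression. The main change is the loss function: cross-entropy (nn.CrossEntropyLoss) instead of mean squared error. The data pipeline also changes. MNIST is loaded in mini-batches through a DataLoader, and each 28x28 image is flattened into a 784-dimensional vector. The model then outputs 10 class scores. The optimizer is the same stochastic gradient descent (torch.optim.SGD) used for linear regression.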
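The code comment "Softmax is internally computed." refers to how nn.CrossEntropyLoss works. The loss applies log-softmax itself and then takes the negative log-likelihood of the correct class, which is why the model returns raw scores (logits) with no softmax layer. The check below uses made-up logits:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.1, 0.2, 3.0]])  # made-up scores: 2 samples, 3 classes
labels = torch.tensor([0, 2])             # correct class index for each sample

ce = nn.CrossEntropyLoss()(logits, labels)
manual = F.nll_loss(F.log_softmax(logits, dim=1), labels)
print(torch.allclose(ce, manual))
```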
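transforms.ToTensor() accepts a PIL Image or a numpy.ndarray (H x W x C). It reorders the channels to C x H x W and converts the data to float. For uint8 input it also scales the values to [0, 1] by dividing by 255.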
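The numpy sketch below illustrates those steps on an H x W x C uint8 array. to_tensor_like is a hypothetical helper, not part of torchvision. The real ToTensor also accepts PIL images directly and returns a torch tensor rather than a numpy array.

```python
import numpy as np


def to_tensor_like(img):
    """Mimic transforms.ToTensor() on a uint8 H x W x C array (hypothetical helper)."""
    chw = np.transpose(img, (2, 0, 1))     # H x W x C  ->  C x H x W
    return chw.astype(np.float32) / 255.0  # uint8 [0, 255] -> float32 [0, 1]


img = np.full((28, 28, 1), 255, dtype=np.uint8)  # a fake all-white 28x28 grayscale image
img[0, 0, 0] = 0                                 # one black pixel
out = to_tensor_like(img)
print(out.shape, out.dtype, out.min(), out.max())
```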
(The description of ToTensor() above is adapted from a reference source.)