PyTorch implementation of a regression model
2022-06-12 06:06:00 【Singing under the hedge】
Implementing a regression model in PyTorch
1. Code
```python
import torch
import torch.nn.functional as F  # used by Method 1
import matplotlib.pyplot as plt

# Build the dataset
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)  # x data (tensor), shape (100, 1)
y = x.pow(2) + 0.2 * torch.rand(x.size())              # noisy y data (tensor), shape (100, 1)

# Build the neural network
# Method 1: subclass torch.nn.Module
# class Net(torch.nn.Module):
#     def __init__(self, n_feature, n_hidden, n_output):
#         super(Net, self).__init__()  # call the parent class __init__
#         # Define each layer
#         self.hidden = torch.nn.Linear(n_feature, n_hidden)  # hidden layer (linear)
#         self.output = torch.nn.Linear(n_hidden, n_output)   # output layer (linear)
#
#     def forward(self, x):
#         x = F.relu(self.hidden(x))  # ReLU activation on the hidden layer
#         x = self.output(x)          # output value
#         return x
# net = Net(n_feature=1, n_hidden=10, n_output=1)

# Method 2: torch.nn.Sequential
net = torch.nn.Sequential(
    torch.nn.Linear(1, 10),
    torch.nn.ReLU(),
    torch.nn.Linear(10, 1),
)

# Visualization: interactive mode so the figure updates during training
plt.ion()
plt.show()

# Train the network
# torch.optim.SGD over all of net's parameters with learning rate 0.2;
# every step uses all 100 points, so this is effectively full-batch gradient descent
optimizer = torch.optim.SGD(net.parameters(), lr=0.2)
loss_func = torch.nn.MSELoss()  # loss function (mean squared error)

for t in range(100):
    pre_y = net(x)              # feed the data to net and get predictions
    loss = loss_func(pre_y, y)  # compute the loss
    optimizer.zero_grad()       # clear gradients left over from the previous step
    loss.backward()             # backpropagate the error
    optimizer.step()            # apply the gradient update to net's parameters

    # Plot every 5 steps (pre_y and loss are from before this step's update)
    if t % 5 == 0:
        plt.cla()
        plt.scatter(x.numpy(), y.numpy())
        plt.plot(x.numpy(), pre_y.detach().numpy(), 'r-', lw=5)  # red fitted curve
        plt.text(0.5, 0, 'Loss=%.4f' % loss.item(), fontdict={'size': 20, 'color': 'red'})
        plt.pause(0.1)

plt.ioff()  # leave interactive mode
plt.show()  # keep the final figure open
```
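Method 1 and Method 2 above describe the same 1-10-1 network, and both expect input of shape (N, 1), which is why the dataset is built with `torch.unsqueeze(..., dim=1)`. The following is a minimal sketch (not from the original post) that checks this: it builds both versions, copies the Sequential model's weights into the class-based one, and compares their outputs. `count_params` and `copy_weights` are hypothetical helpers written for this illustration; the copy is positional because the parameter names differ (`hidden.weight` vs `0.weight`).

```python
import torch
import torch.nn.functional as F


class Net(torch.nn.Module):
    """Method 1 from the post: an explicit nn.Module subclass."""

    def __init__(self, n_feature, n_hidden, n_output):
        super().__init__()
        self.hidden = torch.nn.Linear(n_feature, n_hidden)
        self.output = torch.nn.Linear(n_hidden, n_output)

    def forward(self, x):
        x = F.relu(self.hidden(x))
        return self.output(x)


def count_params(model):
    # Hypothetical helper: total number of trainable scalars in the model
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def copy_weights(src, dst):
    # Hypothetical helper: copy parameters in registration order, since the
    # names differ ('0.weight' in Sequential vs 'hidden.weight' in Net)
    with torch.no_grad():
        for p_src, p_dst in zip(src.parameters(), dst.parameters()):
            p_dst.copy_(p_src)


net1 = Net(n_feature=1, n_hidden=10, n_output=1)
net2 = torch.nn.Sequential(
    torch.nn.Linear(1, 10),
    torch.nn.ReLU(),
    torch.nn.Linear(10, 1),
)

raw = torch.linspace(-1, 1, 100)   # shape (100,)
x = torch.unsqueeze(raw, dim=1)    # shape (100, 1): 100 samples, 1 feature each

copy_weights(net2, net1)
same_output = torch.allclose(net1(x), net2(x))

print(count_params(net1), count_params(net2), same_output)  # 31 31 True
```

Because the two models compute the same function, the choice is mostly about style: `Sequential` is shorter for a plain stack of layers, while a `Module` subclass lets `forward` contain arbitrary logic.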
2. Results
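The effect can also be checked without an interactive window (for example on a headless server). A minimal sketch, assuming the same data, network and hyperparameters as the script above, plus a fixed `torch.manual_seed(0)` that the original does not use: drop the plotting, record the loss at every step, and confirm that it falls.

```python
import torch

torch.manual_seed(0)  # assumption: fixed seed so the run is repeatable

# Same data as the post: y = x^2 plus uniform noise in [0, 0.2)
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y = x.pow(2) + 0.2 * torch.rand(x.size())

net = torch.nn.Sequential(
    torch.nn.Linear(1, 10),
    torch.nn.ReLU(),
    torch.nn.Linear(10, 1),
)
optimizer = torch.optim.SGD(net.parameters(), lr=0.2)
loss_func = torch.nn.MSELoss()

losses = []
for t in range(100):
    pre_y = net(x)
    loss = loss_func(pre_y, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())  # record the loss instead of plotting it
    if t % 5 == 0:
        print(f"step {t:3d}  loss={loss.item():.4f}")

# Evaluate the trained network without tracking gradients
with torch.no_grad():
    final_loss = loss_func(net(x), y).item()
print(f"final loss={final_loss:.4f}")
```

Because `pre_y` is computed before `optimizer.step()`, each recorded value is the loss of the parameters at the start of that step; `final_loss` is evaluated after the last update.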
