A simple linear regression demo in PyTorch
2022-07-06 09:16:00 【想成为风筝】
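Simple linear regression with PyTorch

The demo fits a single nn.Linear layer to 11 noise-free points sampled from y = 2x + 1, training it with Adam on the mean squared error.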
```python
import numpy as np

# Training data: x = 0, 1, ..., 10 and y = 2x + 1
x_values = [i for i in range(11)]
x_train = np.array(x_values, dtype=np.float32)
x_train = x_train.reshape(-1, 1)  # shape (11, 1): 11 samples with 1 feature each
print(x_train.shape)

y_values = [2 * i + 1 for i in x_values]
y_train = np.array(y_values, dtype=np.float32)
y_train = y_train.reshape(-1, 1)  # shape (11, 1)
print(y_train.shape)
```
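Because the data are generated without noise from y = 2x + 1, the parameters that training should approach are known in advance. As a cross-check, here is a minimal plain-Python sketch (not part of the original demo; the helper name `least_squares` is hypothetical) that computes the closed-form least-squares fit on the same 11 points:

```python
# Closed-form ordinary least squares for y ≈ w*x + b (plain Python, no PyTorch).
def least_squares(xs, ys):
    """Return (w, b) minimizing the mean squared error of w*x + b against ys."""
    n = len(xs)
    x_mean = sum(xs) / n
    y_mean = sum(ys) / n
    # Slope = covariance(x, y) / variance(x); the 1/n factors cancel.
    cov_xy = sum((x - x_mean) * (y - y_mean) for x, y in zip(xs, ys))
    var_x = sum((x - x_mean) ** 2 for x in xs)
    w = cov_xy / var_x
    # The fitted line passes through the point of means.
    b = y_mean - w * x_mean
    return w, b


x_values = [i for i in range(11)]
y_values = [2 * i + 1 for i in x_values]
w, b = least_squares(x_values, y_values)
print(w, b)  # 2.0 1.0
```

Gradient-based training only approximates these values, so the learned weight and bias should end up close to, not necessarily exactly equal to, 2 and 1.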
```python
import torch
import torch.nn as nn

# Use the first GPU if one is available, otherwise the CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

class LinearRegressionModel(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        # A single fully connected layer: out = x @ W.T + b
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        out = self.linear(x)
        return out

input_dim = 1
output_dim = 1
model = LinearRegressionModel(input_dim, output_dim)
model.to(device)
```
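With input_dim = output_dim = 1 the model is just y = w·x + b, with one weight and one bias. nn.Linear stores its weight as an (out_features, in_features) matrix and computes x Wᵀ + b for every row of the batch. The sketch below reproduces that forward pass in plain Python so the shapes are visible; `linear_forward` is a hypothetical helper, and the weight and bias are picked by hand rather than taken from a trained model.

```python
# Plain-Python sketch of the computation nn.Linear performs: y = x @ W^T + b.
# weight has shape (out_features, in_features) and bias has shape (out_features,).
def linear_forward(x_batch, weight, bias):
    """x_batch: list of rows, each of length in_features."""
    out = []
    for row in x_batch:
        # One output per output feature j: dot(row, weight[j]) + bias[j]
        out.append([sum(xi * wji for xi, wji in zip(row, w_row)) + b_j
                    for w_row, b_j in zip(weight, bias)])
    return out


# input_dim = output_dim = 1, so weight is 1x1 and bias has one entry.
weight = [[2.0]]
bias = [1.0]
x_batch = [[0.0], [1.0], [2.0]]  # shape (3, 1), like x_train.reshape(-1, 1)
print(linear_forward(x_batch, weight, bias))  # [[1.0], [3.0], [5.0]]
```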
```python
# Loss function (mean squared error) and the Adam optimizer with learning rate 0.01
losses = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
```
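torch.optim.Adam keeps, for every parameter, a running average of the gradient and of the squared gradient, and scales each step by their ratio. The sketch below writes out that update for one scalar parameter, following the standard Adam rule with PyTorch's default betas and eps (weight decay and AMSGrad, both off by default, are omitted). `adam_step` is a hypothetical helper for illustration, not a PyTorch API.

```python
import math

# Plain-Python sketch of one Adam update for a single scalar parameter,
# using PyTorch's default hyperparameters (betas=(0.9, 0.999), eps=1e-8).
def adam_step(param, grad, m, v, t, lr=0.01, beta1=0.9, beta2=0.999, eps=1e-8):
    """Return (new_param, new_m, new_v); t is the 1-based step count."""
    m = beta1 * m + (1 - beta1) * grad          # running mean of gradients
    v = beta2 * v + (1 - beta2) * grad * grad   # running mean of squared gradients
    m_hat = m / (1 - beta1 ** t)                # bias correction for the zero init
    v_hat = v / (1 - beta2 ** t)
    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v


# On the first step m_hat == grad and v_hat == grad**2, so the parameter
# moves by almost exactly lr against the sign of the gradient,
# however large or small the gradient is.
p1, _, _ = adam_step(0.0, grad=-150.0, m=0.0, v=0.0, t=1)
p2, _, _ = adam_step(0.0, grad=-0.5, m=0.0, v=0.0, t=1)
print(p1, p2)
```

This per-parameter normalization is why, early in training, Adam with lr=0.01 moves w and b by roughly 0.01 per step each, even though the gradient of this loss with respect to w is much larger than the one with respect to b.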
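```python
epochs = 1000
# The whole dataset is used as a single batch, so convert it to tensors once
inputs = torch.from_numpy(x_train).to(device)
targets = torch.from_numpy(y_train).to(device)

for epoch in range(1, epochs + 1):
    optimizer.zero_grad()          # clear gradients from the previous step
    out = model(inputs)            # forward pass
    loss = losses(out, targets)    # mean squared error
    loss.backward()                # compute gradients for w and b
    optimizer.step()               # Adam update
    if epoch % 50 == 0:
        print('epoch {}, loss {}'.format(epoch, loss.item()))
```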
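Each pass of the training loop runs the forward pass, computes the mean squared error, lets loss.backward() fill in the gradients, and applies an optimizer step. The sketch below spells those four steps out by hand for the one-weight, one-bias model, using analytically derived gradients. It is an illustration under stated assumptions: it swaps Adam for plain gradient descent so that the update is a single line, starts from w = b = 0 instead of a random initialization, and uses a hypothetical helper name `train_gd`.

```python
# Plain-Python sketch of what each training iteration computes for y ≈ w*x + b:
# forward pass, MSE loss, gradients (what loss.backward() computes), update.
# NOTE: plain gradient descent stands in for Adam to keep the update readable.
def train_gd(xs, ys, lr=0.01, epochs=2000):
    w, b = 0.0, 0.0  # assumed fixed start; nn.Linear initializes randomly
    n = len(xs)
    history = []
    for _ in range(epochs):
        preds = [w * x + b for x in xs]                 # forward pass
        errors = [p - y for p, y in zip(preds, ys)]
        loss = sum(e * e for e in errors) / n           # MSE with mean reduction
        # d(loss)/dw = (2/n) * sum(e * x),  d(loss)/db = (2/n) * sum(e)
        grad_w = 2.0 * sum(e * x for e, x in zip(errors, xs)) / n
        grad_b = 2.0 * sum(errors) / n
        w -= lr * grad_w                                # the optimizer step
        b -= lr * grad_b
        history.append(loss)
    return w, b, history


x_values = [float(i) for i in range(11)]
y_values = [2.0 * x + 1.0 for x in x_values]
w, b, history = train_gd(x_values, y_values)
print(round(w, 3), round(b, 3))
```

For this data the largest curvature of the loss is about 71, so plain gradient descent is stable for any step size below roughly 2/71 ≈ 0.028; with lr = 0.01, w and b approach 2 and 1, the values used to generate the data.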
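```python
# Predict on the training inputs. No gradients are needed for inference, the
# input must be on the same device as the model, and the result must be moved
# back to the CPU before converting it to a NumPy array.
with torch.no_grad():
    predicted = model(torch.from_numpy(x_train).to(device)).cpu().numpy()
print(predicted)
```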
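```python
# # Save only the model parameters (weight w and bias b)
# torch.save(model.state_dict(), 'model.pkl')
# # Load them back into an existing LinearRegressionModel instance
# model.load_state_dict(torch.load('model.pkl', map_location=device))
```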
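The commented-out lines save only the model's parameters (its state_dict, here the single weight w and bias b) rather than the whole Python object, so loading requires constructing LinearRegressionModel first and then calling load_state_dict on it. A plain-Python analogy of that workflow follows; the class `TinyLinearModel` and the JSON file format are illustrative assumptions, since torch.save actually writes a pickle-based format.

```python
import json
import os
import tempfile

# Plain-Python analogy for the state_dict workflow: save only the learned
# parameters under their names, rebuild the model, then load them back.
# JSON stands in for torch.save/torch.load here; PyTorch does not use JSON.
class TinyLinearModel:
    def __init__(self):
        self.weight = 0.0
        self.bias = 0.0

    def state_dict(self):
        return {"weight": self.weight, "bias": self.bias}

    def load_state_dict(self, state):
        self.weight = state["weight"]
        self.bias = state["bias"]

    def __call__(self, x):
        return self.weight * x + self.bias


trained = TinyLinearModel()
trained.weight, trained.bias = 2.0, 1.0  # pretend these came from training

path = os.path.join(tempfile.mkdtemp(), "model.json")
with open(path, "w") as f:
    json.dump(trained.state_dict(), f)

restored = TinyLinearModel()  # same architecture, fresh parameters
with open(path) as f:
    restored.load_state_dict(json.load(f))
print(restored(10.0))  # 21.0
```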