PyTorch: a simple linear regression demo
2022-07-06 12:00:00 【Want to be a kite】
Implementing simple linear regression in PyTorch
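```python
import numpy as np

# Training data: 11 points that lie exactly on the line y = 2x + 1
x_values = [i for i in range(11)]
x_train = np.array(x_values, dtype=np.float32)
x_train = x_train.reshape(-1, 1)  # shape (11, 1): one feature per sample
print(x_train.shape)

y_values = [2 * i + 1 for i in x_values]
y_train = np.array(y_values, dtype=np.float32)
y_train = y_train.reshape(-1, 1)  # shape (11, 1)
print(y_train.shape)
```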
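Because every training point lies exactly on y = 2x + 1, the parameters that training should recover can be checked in closed form. The sketch below applies the ordinary least-squares formulas for a single feature, w = Σ(x - x̄)(y - ȳ) / Σ(x - x̄)² and b = ȳ - w·x̄, in plain Python; `fit_line` is a hypothetical helper, not part of the original demo.

```python
def fit_line(xs, ys):
    """Ordinary least squares for y ≈ w*x + b with a single feature."""
    n = len(xs)
    x_mean = sum(xs) / n
    y_mean = sum(ys) / n
    # Covariance of x and y divided by the variance of x gives the slope
    sxy = sum((x - x_mean) * (y - y_mean) for x, y in zip(xs, ys))
    sxx = sum((x - x_mean) ** 2 for x in xs)
    w = sxy / sxx
    b = y_mean - w * x_mean
    return w, b


# Same points as x_values / y_values above
x_values = [float(i) for i in range(11)]
y_values = [2 * x + 1 for x in x_values]
w, b = fit_line(x_values, y_values)
print(w, b)  # 2.0 1.0
```

These are the values the gradient-based training loop should approach.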
```python
import torch
import torch.nn as nn

# Use the first GPU if one is available, otherwise fall back to the CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


class LinearRegressionModel(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(LinearRegressionModel, self).__init__()
        # A single fully connected layer: out = x @ W.T + b
        self.Linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        out = self.Linear(x)
        return out


input_dim = 1
output_dim = 1
model = LinearRegressionModel(input_dim, output_dim)
model.to(device)
```
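`nn.Linear(1, 1)` holds a weight of shape (out_features, in_features) = (1, 1) and a bias of shape (1,), and its forward pass computes out = x W^T + b for every row of the batch. As a rough illustration, the plain-Python sketch below spells out that arithmetic for a batch shaped like `x_train`; the function `linear_forward` and the parameter values are hypothetical, chosen to match the target line y = 2x + 1.

```python
def linear_forward(x_rows, weight, bias):
    """Compute x @ weight.T + bias for a batch given as a list of rows.

    weight has shape (out_features, in_features) and bias has length
    out_features, mirroring how nn.Linear lays out its parameters.
    """
    out = []
    for row in x_rows:
        out.append([
            sum(w_ij * x_j for w_ij, x_j in zip(w_row, row)) + b_i
            for w_row, b_i in zip(weight, bias)
        ])
    return out


# Hypothetical parameter values: the line the model should converge to
weight = [[2.0]]  # shape (1, 1)
bias = [1.0]      # shape (1,)
x_batch = [[0.0], [3.0], [10.0]]  # shape (3, 1), like x_train
print(linear_forward(x_batch, weight, bias))  # [[1.0], [7.0], [21.0]]
```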
```python
losses = nn.MSELoss()  # mean squared error between predictions and targets
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
epochs = 1000
```
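With its default `reduction='mean'`, `nn.MSELoss` averages the squared errors, so for this model L(w, b) = (1/N) Σ (w·x_i + b - y_i)². Calling `loss.backward()` fills each parameter's `.grad` with ∂L/∂w = (2/N) Σ (ŷ_i - y_i)·x_i and ∂L/∂b = (2/N) Σ (ŷ_i - y_i). A small sketch, assuming plain Python floats in place of tensors, derives these by hand and checks them against central finite differences; `mse` and `mse_grad` are hypothetical helper names.

```python
def mse(w, b, xs, ys):
    """Mean squared error of the line w*x + b (reduction='mean')."""
    return sum((w * x + b - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


def mse_grad(w, b, xs, ys):
    """Analytic gradient of mse with respect to (w, b)."""
    n = len(xs)
    residuals = [w * x + b - y for x, y in zip(xs, ys)]
    dw = 2.0 / n * sum(r * x for r, x in zip(residuals, xs))
    db = 2.0 / n * sum(residuals)
    return dw, db


xs = [float(i) for i in range(11)]
ys = [2 * x + 1 for x in xs]

# Compare the analytic gradient with central finite differences at an arbitrary point
w0, b0, h = 0.5, -0.3, 1e-6
dw, db = mse_grad(w0, b0, xs, ys)
dw_num = (mse(w0 + h, b0, xs, ys) - mse(w0 - h, b0, xs, ys)) / (2 * h)
db_num = (mse(w0, b0 + h, xs, ys) - mse(w0, b0 - h, xs, ys)) / (2 * h)
print(dw, dw_num)
print(db, db_num)
```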
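```python
# Convert the training data to tensors on the chosen device once, before the loop
inputs = torch.from_numpy(x_train).to(device)
outputs = torch.from_numpy(y_train).to(device)

for epoch in range(1, epochs + 1):
    optimizer.zero_grad()        # clear gradients left over from the previous step
    out = model(inputs)          # forward pass: predictions for all 11 points
    loss = losses(out, outputs)  # mean squared error against the targets
    loss.backward()              # backpropagation: compute d(loss)/dw and d(loss)/db
    optimizer.step()             # Adam update of w and b
    if epoch % 50 == 0:
        print('epoch {}, loss {}'.format(epoch, loss.item()))
```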
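Each iteration of the loop performs the same four steps: forward pass, loss, backward pass, and an optimizer update. To make the arithmetic visible, here is a pure-Python sketch of that loop for the 11 training points, using the hand-derived MSE gradient and the Adam update rule from Kingma and Ba with PyTorch's default `betas=(0.9, 0.999)` and `eps=1e-8`, plus the demo's `lr=0.01` and 1000 epochs. It assumes w = b = 0 at the start, whereas `nn.Linear` initializes its parameters randomly, and it uses Python floats rather than float32, so the printed losses will not match a real PyTorch run; `train_adam` is a hypothetical name.

```python
import math


def train_adam(xs, ys, lr=0.01, epochs=1000, betas=(0.9, 0.999), eps=1e-8):
    """Fit y ≈ w*x + b by minimizing mean squared error with Adam."""
    beta1, beta2 = betas
    params = [0.0, 0.0]  # [w, b]; assumed zero init (PyTorch uses random init)
    m = [0.0, 0.0]       # first-moment (running mean) estimates of the gradient
    v = [0.0, 0.0]       # second-moment (running uncentered variance) estimates
    n = len(xs)
    history = []
    for t in range(1, epochs + 1):
        w, b = params
        # Forward pass and loss
        residuals = [w * x + b - y for x, y in zip(xs, ys)]
        loss = sum(r * r for r in residuals) / n
        history.append(loss)
        # Backward pass: gradient of the mean squared error w.r.t. w and b
        grads = [2.0 / n * sum(r * x for r, x in zip(residuals, xs)),
                 2.0 / n * sum(residuals)]
        # Optimizer step: Adam update with bias correction
        for i, g in enumerate(grads):
            m[i] = beta1 * m[i] + (1 - beta1) * g
            v[i] = beta2 * v[i] + (1 - beta2) * g * g
            m_hat = m[i] / (1 - beta1 ** t)
            v_hat = v[i] / (1 - beta2 ** t)
            params[i] -= lr * m_hat / (math.sqrt(v_hat) + eps)
        if t % 50 == 0:
            print('epoch {}, loss {}'.format(t, loss))
    return params, history


xs = [float(i) for i in range(11)]
ys = [2 * x + 1 for x in xs]
(w, b), history = train_adam(xs, ys)
print(w, b)
```

Because Adam divides each step by a running estimate of the gradient's magnitude, the slope and intercept move at comparable rates even though their raw gradients differ by almost an order of magnitude. The loss ends far below its starting value of 161, though after only 1000 epochs w and b may still differ noticeably from 2 and 1.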
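```python
# Prediction: no gradients are needed, and the result must be on the CPU for .numpy()
with torch.no_grad():
    predicted = model(torch.from_numpy(x_train).to(device)).cpu().numpy()
print(predicted)
```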
```python
# # Save the model parameters (w and b)
# torch.save(model.state_dict(), 'model.pkl')
# # Load the parameters back into a model with the same structure
# model.load_state_dict(torch.load('model.pkl'))
```
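`model.state_dict()` returns an ordered mapping from parameter names to tensors; for this model the keys are `'Linear.weight'` and `'Linear.bias'`, named after the `self.Linear` attribute. Saving only this mapping means the `LinearRegressionModel` class must be instantiated again before `load_state_dict` can copy the values in. The sketch below mimics that round trip in plain Python with `pickle` and a temporary file; it is an analogy only, not `torch.save` (which writes its own archive format), and the parameter values are hypothetical.

```python
import os
import pickle
import tempfile
from collections import OrderedDict

# Hypothetical values standing in for model.state_dict()
state = OrderedDict([('Linear.weight', [[2.0]]), ('Linear.bias', [1.0])])

path = os.path.join(tempfile.mkdtemp(), 'model.pkl')
with open(path, 'wb') as f:
    pickle.dump(state, f)       # save only the parameters, not the model class

with open(path, 'rb') as f:
    restored = pickle.load(f)   # read the name -> value mapping back

# A freshly constructed model would then copy these values into its own parameters
print(list(restored.keys()))  # ['Linear.weight', 'Linear.bias']
print(restored == state)      # True
```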