Pytorch linear regression
2022-07-05 11:42:00 【My abyss, my abyss】
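1. Import the tabular data
The first column of data.csv holds the label and the remaining columns hold the features.
import pandas as pd
filename = "./data.csv"
df = pd.read_csv(filename)
features = df.iloc[:, 1:]  # every column except the first: features
labels = df.iloc[:, 0]     # first column: label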
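The post does not ship data.csv. To try the code, a minimal sketch like the following can write a synthetic file in the same layout, with the label first and two feature columns. The column names `y`, `x1`, `x2`, the true weights, the bias and the noise level are all hypothetical values chosen for illustration.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
n = 1000
true_w = np.array([2.0, -3.4])  # hypothetical true weights
true_b = 4.2                    # hypothetical true bias

# y = X w + b + small Gaussian noise
X = rng.normal(size=(n, 2))
y = X @ true_w + true_b + rng.normal(scale=0.01, size=n)

# Label in the first column, features after it, as the loading code expects
df = pd.DataFrame({"y": y, "x1": X[:, 0], "x2": X[:, 1]})
df.to_csv("data.csv", index=False)
```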
2. Convert to tensors
import torch
# DataFrame -> Tensor
features = torch.tensor(features.values, dtype=torch.float32)
labels = torch.tensor(labels.values, dtype=torch.float32)
labels = torch.reshape(labels, (-1, 1))  # shape (n,) -> (n, 1) to match the model output
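The reshape to `(n, 1)` matters because the model outputs a column. If the labels stayed 1-D, `(n, 1) - (n,)` would broadcast to an `(n, n)` matrix and compare every prediction with every label. The small hand-made tensors below are illustrative only. `nn.MSELoss` broadcasts the same way when the shapes differ, and PyTorch emits a warning about the size mismatch.

```python
import torch

y_hat = torch.tensor([[1.0], [2.0], [3.0], [4.0]])  # model output, shape (4, 1)
y_flat = torch.tensor([1.0, 2.0, 3.0, 4.0])          # labels before reshape, shape (4,)
y_col = torch.reshape(y_flat, (-1, 1))                # labels after reshape, shape (4, 1)

# (4, 1) - (4,) broadcasts to (4, 4): every prediction meets every label
wrong = ((y_hat - y_flat) ** 2).mean()
# (4, 1) - (4, 1) is element-wise, which is what MSE should do
right = ((y_hat - y_col) ** 2).mean()

print((y_hat - y_flat).shape, wrong.item(), right.item())
# torch.Size([4, 4]) 2.5 0.0
```

The predictions are perfect, yet the broadcast version reports a nonzero loss.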
3. Create a data iterator that reads the data in batches
from torch.utils import data
def load_array(data_arrays, batch_size, is_train=True):
    """Wrap tensors in a DataLoader; shuffle only when training."""
    dataset = data.TensorDataset(*data_arrays)
    return data.DataLoader(dataset, batch_size, shuffle=is_train)
batch_size = 10
data_iter = load_array((features, labels), batch_size)
X, y = next(iter(data_iter))  # peek at the first batch
print(X.shape, y.shape)
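The following sketch shows what the `DataLoader` yields when the sample count is not a multiple of `batch_size`. It uses 25 random samples, an arbitrary size chosen for illustration. With the default `drop_last=False`, the last batch is simply smaller.

```python
import torch
from torch.utils import data

features = torch.randn(25, 2)  # 25 samples, 2 features (illustrative)
labels = torch.randn(25, 1)

dataset = data.TensorDataset(features, labels)
loader = data.DataLoader(dataset, batch_size=10, shuffle=True)

# Shuffling changes which samples land in each batch, not the batch sizes
batch_shapes = [(tuple(X.shape), tuple(y.shape)) for X, y in loader]
print(batch_shapes)
# [((10, 2), (10, 1)), ((10, 2), (10, 1)), ((5, 2), (5, 1))]
```

Passing `drop_last=True` to `DataLoader` would discard the incomplete final batch instead.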
4. Define the neural network
from torch import nn
class LinearRegression(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(2, 1)  # 2 input features -> 1 output
    def forward(self, x):
        return self.layer1(x)
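A single `nn.Linear(2, 1)` layer is exactly the linear-regression model $\hat{y} = x W^\top + b$. The sketch below checks this by comparing the layer's output with the formula computed by hand from its `weight` (shape `(1, 2)`) and `bias` (shape `(1,)`). The random input is illustrative.

```python
import torch
from torch import nn

torch.manual_seed(0)
layer = nn.Linear(2, 1)
x = torch.randn(4, 2)  # 4 samples, 2 features

# Same computation the layer performs: x @ W^T + b
manual = x @ layer.weight.T + layer.bias

print(layer.weight.shape, layer.bias.shape)  # torch.Size([1, 2]) torch.Size([1])
```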
5. Train the network
import matplotlib.pyplot as plt
from torch import optim
model = LinearRegression()
mse_loss = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.003)
Loss = []  # loss of every mini-batch, for plotting
epochs = 1000
def train():
    for i in range(epochs):
        for X, y in data_iter:
            y_hat = model(X)            # forward pass: model output
            loss = mse_loss(y_hat, y)   # mean squared error
            Loss.append(loss.item())    # store as a plain Python float
            optimizer.zero_grad()       # clear gradients left from the previous step
            loss.backward()             # compute gradients of the loss w.r.t. the weights
            optimizer.step()            # update the weights
        print(i, loss.item(), sep='\t')  # loss of the last mini-batch in this epoch
train()  # start training
for parameter in model.parameters():
    print(parameter)
plt.plot(Loss)
plt.show()
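Linear regression also has a closed-form least-squares solution, which gives a reference point for the parameters printed after training. The sketch below is a cross-check, not part of the original method. It appends a column of ones so the bias becomes a third coefficient, then solves with `torch.linalg.lstsq`. The true weights and noiseless synthetic data are hypothetical. On real data, `w_hat` and `b_hat` would be compared with `model.layer1.weight` and `model.layer1.bias`.

```python
import torch

torch.manual_seed(0)
true_w = torch.tensor([[2.0], [-3.4]])  # hypothetical true weights, shape (2, 1)
true_b = 4.2                            # hypothetical true bias
features = torch.randn(100, 2)
labels = features @ true_w + true_b     # noiseless labels, shape (100, 1)

# Design matrix [x1, x2, 1] so the bias is estimated alongside the weights
A = torch.cat([features, torch.ones(features.shape[0], 1)], dim=1)
solution = torch.linalg.lstsq(A, labels).solution  # shape (3, 1)

w_hat, b_hat = solution[:2], solution[2]
print(w_hat.flatten(), b_hat.item())
```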
6. The complete code
import matplotlib.pyplot as plt
import pandas as pd
import torch
from torch import nn, optim
from torch.utils import data
class LinearRegression(nn.Module):
    def __init__(self, feature_nums):
        super().__init__()
        self.layer1 = nn.Linear(feature_nums, 1)
    def forward(self, x):
        return self.layer1(x)
def load_array(data_arrays, batch_size, is_train=True):
    dataset = data.TensorDataset(*data_arrays)
    return data.DataLoader(dataset, batch_size, shuffle=is_train)
def Linear_Regression_pytorch(filepath, feature_nums, batch_size, learning_rate, epochs):
    # Read the CSV: first column = label, remaining columns = features
    df = pd.read_csv(filepath)
    features = torch.tensor(df.iloc[:, 1:].values, dtype=torch.float32)
    labels = torch.tensor(df.iloc[:, 0].values, dtype=torch.float32).reshape(-1, 1)
    data_iter = load_array((features, labels), batch_size)
    model = LinearRegression(feature_nums)
    mse_loss = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    Loss = []  # loss of every mini-batch
    for i in range(epochs):
        for X, y in data_iter:
            y_hat = model(X)            # forward pass: model output
            loss = mse_loss(y_hat, y)   # mean squared error
            optimizer.zero_grad()       # clear old gradients
            loss.backward()             # compute gradients
            optimizer.step()            # update the weights
            Loss.append(loss.item())
        print(i, loss.item(), sep='\t')  # last mini-batch loss of this epoch
    for parameter in model.parameters():
        print(parameter)
    plt.plot(Loss)
    plt.title("Loss")
    plt.show()
if __name__ == "__main__":
    Linear_Regression_pytorch(filepath="data.csv",
                              feature_nums=2,
                              batch_size=10,
                              learning_rate=0.05,
                              epochs=1000)
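The listing trains on the CPU. To use a GPU when one is available, the model and every batch must be moved to the same device. The sketch below shows a single training step under that assumption. `train_step` is a hypothetical helper, and a bare `nn.Linear(2, 1)` stands in for the `LinearRegression` class.

```python
import torch
from torch import nn, optim

# Fall back to the CPU when CUDA is not available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = nn.Linear(2, 1).to(device)  # parameters now live on `device`
optimizer = optim.Adam(model.parameters(), lr=0.05)
mse_loss = nn.MSELoss()

def train_step(X, y):
    """Hypothetical helper: one optimisation step on a single batch."""
    X, y = X.to(device), y.to(device)  # batch must be on the model's device
    loss = mse_loss(model(X), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

X = torch.randn(10, 2)  # illustrative batch
y = torch.randn(10, 1)
first = train_step(X, y)
```

Creating the optimizer after `.to(device)` keeps its parameter references on the device the model uses.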