PyTorch: Linear Regression
2022-07-05 11:34:00 【我渊啊我渊啊】
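1. Import the tabular data
The first column of data.csv holds the label; the remaining columns are the features.
```python
import pandas as pd

filename = "./data.csv"
df = pd.read_csv(filename)  # named df so it does not clash with torch.utils.data in step 3
features = df.iloc[:, 1:]   # every column except the first: input features
labels = df.iloc[:, 0]      # first column: target values
```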
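If you do not have the author's data.csv, the sketch below writes a synthetic file with the same layout (label in the first column, two features after it). The true weights `[2.0, -3.4]`, the bias `4.2`, the noise level and the sample count are assumptions chosen for illustration, not the author's data.

```python
import numpy as np
import pandas as pd

# Hypothetical ground truth, used only to generate example data
true_w = np.array([2.0, -3.4])
true_b = 4.2

rng = np.random.default_rng(0)
X = rng.normal(0.0, 1.0, size=(1000, 2))
y = X @ true_w + true_b + rng.normal(0.0, 0.01, size=1000)

# Label first, features after it, matching iloc[:, 0] / iloc[:, 1:]
df = pd.DataFrame({"y": y, "x1": X[:, 0], "x2": X[:, 1]})
df.to_csv("data.csv", index=False)

check = pd.read_csv("data.csv")
print(check.shape)  # (1000, 3)
```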
2. Convert the data to tensors
```python
import torch

# DataFrame ----> Tensor
features = torch.tensor(features.values, dtype=torch.float32)
labels = torch.tensor(labels.values, dtype=torch.float32)
labels = torch.reshape(labels, (-1, 1))  # shape (N,) -> (N, 1), matching the model output
```
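Why reshape the labels to `(-1, 1)`? The model outputs shape `(N, 1)`; if the labels stay `(N,)`, `nn.MSELoss` broadcasts the pair to `(N, N)` and averages the wrong differences, and PyTorch only emits a warning. A small demonstration with four hypothetical labels and a perfect prediction:

```python
import warnings

import torch
from torch import nn

labels_flat = torch.arange(4.0)                   # labels as read from the CSV: shape (4,)
labels_col = torch.reshape(labels_flat, (-1, 1))  # after reshape: shape (4, 1)
y_hat = labels_col.clone()                        # a "perfect" model output: shape (4, 1)

mse_loss = nn.MSELoss()
with warnings.catch_warnings():
    warnings.simplefilter("ignore")               # silence MSELoss's size-mismatch warning
    wrong = mse_loss(y_hat, labels_flat)          # (4, 1) vs (4,) broadcasts to (4, 4)
right = mse_loss(y_hat, labels_col)

print(wrong.item(), right.item())  # 2.5 0.0
```

With the reshape, a perfect prediction has zero loss; without it, the loss is 2.5 even though every prediction is exact.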
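3. Build an iterator that reads the data in mini-batches
```python
from torch.utils import data

def load_array(data_arrays, batch_size, is_train=True):
    """Wrap the tensors in a DataLoader; shuffle only when training."""
    dataset = data.TensorDataset(*data_arrays)
    return data.DataLoader(dataset, batch_size, shuffle=is_train)

batch_size = 10
data_iter = load_array((features, labels), batch_size)
next(iter(data_iter))  # in a notebook, displays the first mini-batch: [features, labels]
```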
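To see what `load_array` yields, the sketch below feeds it 25 hypothetical samples. With `batch_size=10`, the `DataLoader` produces `ceil(25 / 10) = 3` batches, and the last one holds only 5 samples because `drop_last` defaults to `False`. With `shuffle=True` the sample order changes on every pass, but the batch shapes do not.

```python
import math

import torch
from torch.utils import data

def load_array(data_arrays, batch_size, is_train=True):
    dataset = data.TensorDataset(*data_arrays)
    return data.DataLoader(dataset, batch_size, shuffle=is_train)

# Hypothetical data: 25 samples with 2 features each
features = torch.randn(25, 2)
labels = torch.randn(25, 1)

data_iter = load_array((features, labels), batch_size=10)
batch_shapes = [(tuple(X.shape), tuple(y.shape)) for X, y in data_iter]

print(len(data_iter))  # 3
print(batch_shapes)    # [((10, 2), (10, 1)), ((10, 2), (10, 1)), ((5, 2), (5, 1))]
```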
4. Define the neural network
```python
from torch import nn

class LinearRegression(nn.Module):
    def __init__(self):
        super(LinearRegression, self).__init__()
        self.layer1 = nn.Linear(2, 1)  # 2 input features -> 1 output

    def forward(self, x):
        return self.layer1(x)
```
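`nn.Linear(2, 1)` stores a weight of shape `(1, 2)` and a bias of shape `(1,)`, and its forward pass computes `X @ W.T + b`. The check below (random inputs, arbitrary seed) confirms that the module output matches that formula:

```python
import torch
from torch import nn

class LinearRegression(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(2, 1)

    def forward(self, x):
        return self.layer1(x)

torch.manual_seed(0)
model = LinearRegression()
X = torch.randn(5, 2)

w = model.layer1.weight  # shape (1, 2)
b = model.layer1.bias    # shape (1,)
manual = X @ w.T + b     # y = X W^T + b, shape (5, 1)

print(model(X).shape)                    # torch.Size([5, 1])
print(torch.allclose(model(X), manual))  # True
```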
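5. Train the network
```python
import matplotlib.pyplot as plt
from torch import optim

model = LinearRegression()
gpu = torch.device('cuda')  # defined but never used: the model and data stay on the CPU
mse_loss = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.003)
Loss = []
epochs = 1000

def train():
    for i in range(epochs):
        for X, y in data_iter:
            y_hat = model(X)           # compute the model output
            loss = mse_loss(y_hat, y)  # loss function
            Loss.append(loss.item())   # record the loss of every mini-batch
            optimizer.zero_grad()      # clear the old gradients
            loss.backward()            # compute the gradients
            optimizer.step()           # update the weights
        print(i, loss.item(), sep='\t')  # loss of the last mini-batch in this epoch

train()  # train
for parameter in model.parameters():
    print(parameter)
plt.plot(Loss)
plt.show()
```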
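The training code above creates `gpu = torch.device('cuda')` but never moves the model or the batches there, so everything runs on the CPU. A minimal sketch of device-aware training follows. It falls back to the CPU when CUDA is unavailable; the noiseless data (`y = 2*x1 - 3.4*x2 + 4.2`), the bare `nn.Linear` in place of the `LinearRegression` class, `lr=0.05` and 100 epochs are all assumptions chosen for illustration.

```python
import torch
from torch import nn, optim
from torch.utils import data

# Use the GPU if one is present, otherwise the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical noiseless data: y = 2*x1 - 3.4*x2 + 4.2
torch.manual_seed(0)
features = torch.randn(200, 2)
labels = features @ torch.tensor([[2.0], [-3.4]]) + 4.2

data_iter = data.DataLoader(data.TensorDataset(features, labels), batch_size=10, shuffle=True)

model = nn.Linear(2, 1).to(device)  # move the parameters to the device
mse_loss = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.05)

for epoch in range(100):
    for X, y in data_iter:
        X, y = X.to(device), y.to(device)  # batches must be on the same device as the model
        loss = mse_loss(model(X), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

print(model.weight.data, model.bias.data)
```

The optimizer must be created after `.to(device)` so that it holds references to the moved parameters; creating it first and moving the model afterwards also works for `nn.Module.to`, but the order shown is the conventional one.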
6. Complete code
```python
import matplotlib.pyplot as plt
import pandas as pd
import torch
from torch import nn, optim
from torch.utils import data


class LinearRegression(nn.Module):
    def __init__(self, feature_nums):
        super(LinearRegression, self).__init__()
        self.layer1 = nn.Linear(feature_nums, 1)  # feature_nums inputs -> 1 output

    def forward(self, x):
        return self.layer1(x)


def load_array(data_arrays, batch_size, is_train=True):
    dataset = data.TensorDataset(*data_arrays)
    return data.DataLoader(dataset, batch_size, shuffle=is_train)


def Linear_Regression_pytorch(filepath, feature_nums, batch_size, learning_rate, epochs):
    # Named df so it does not shadow the torch.utils.data module used by load_array
    df = pd.read_csv(filepath)
    features = df.iloc[:, 1:]  # all columns except the first
    labels = df.iloc[:, 0]     # first column is the label

    # Convert DataFrame -> float32 tensor in one step
    # (calling torch.tensor on an existing tensor only triggers a copy warning)
    features = torch.tensor(features.values, dtype=torch.float32)
    labels = torch.tensor(labels.values, dtype=torch.float32)
    labels = torch.reshape(labels, (-1, 1))

    data_iter = load_array((features, labels), batch_size)

    model = LinearRegression(feature_nums)
    mse_loss = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    Loss = []
    for i in range(epochs):
        for X, y in data_iter:
            y_hat = model(X)           # compute the model output
            loss = mse_loss(y_hat, y)  # loss function
            optimizer.zero_grad()      # clear the old gradients
            loss.backward()            # compute the gradients
            optimizer.step()           # update the weights
        print(i, loss.item(), sep='\t')
        Loss.append(loss.item())       # loss of the last mini-batch in each epoch

    for parameter in model.parameters():
        print(parameter)
    plt.plot(Loss)
    plt.title("Loss")
    plt.show()


if __name__ == "__main__":
    Linear_Regression_pytorch(filepath="data.csv",
                              feature_nums=2,
                              batch_size=10,
                              learning_rate=0.05,
                              epochs=1000
                              )
```
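Because the loss is mean squared error and the model is linear, the parameters the script prints can be cross-checked against the closed-form least-squares solution. The sketch below appends a column of ones to the features so the bias is solved together with the weights, then calls `torch.linalg.lstsq`. The synthetic data (weights `[2.0, -3.4]`, bias `4.2`, small Gaussian noise) is an assumption standing in for data.csv; on the real file, build `features` and `labels` as in step 2 instead.

```python
import torch

# Hypothetical data standing in for data.csv: label = 2*x1 - 3.4*x2 + 4.2 + noise
torch.manual_seed(0)
features = torch.randn(1000, 2)
labels = features @ torch.tensor([[2.0], [-3.4]]) + 4.2 + 0.01 * torch.randn(1000, 1)

# Append a column of ones so the bias is solved together with the weights
A = torch.cat([features, torch.ones(features.shape[0], 1)], dim=1)  # shape (1000, 3)
solution = torch.linalg.lstsq(A, labels).solution                   # shape (3, 1)

w_closed = solution[:2].T  # compare with model.layer1.weight, shape (1, 2)
b_closed = solution[2]     # compare with model.layer1.bias, shape (1,)
print(w_closed, b_closed)
```

If gradient training has converged, the printed `layer1.weight` and `layer1.bias` should be close to `w_closed` and `b_closed`.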