PyTorch: Linear Regression
2022-07-05 11:34:00 【我渊啊我渊啊】
1. Import the tabular data
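```python
import pandas as pd

filename = "./data.csv"
df = pd.read_csv(filename)
features = df.iloc[:, 1:]   # every column except the first is a feature
labels = df.iloc[:, 0]      # the first column is the label
```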
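The post does not show data.csv itself; the slicing above assumes the label sits in the first column and the features fill the remaining columns. A minimal sketch with a small, made-up in-memory CSV (the column names y, x1, x2 and the values are hypothetical) makes that layout concrete:

```python
import io

import pandas as pd

# Hypothetical CSV contents: label in column 0, two features after it.
csv_text = """y,x1,x2
4.2,0.0,0.0
6.2,1.0,0.0
0.8,0.0,1.0
"""

df = pd.read_csv(io.StringIO(csv_text))
features = df.iloc[:, 1:]   # every column except the first -> features
labels = df.iloc[:, 0]      # first column -> label

print(features.shape)  # (3, 2)
print(labels.shape)    # (3,)
```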
2. Convert the data to tensors
```python
import torch

# DataFrame -> Tensor
features = torch.tensor(features.values, dtype=torch.float32)
labels = torch.tensor(labels.values, dtype=torch.float32)
labels = torch.reshape(labels, (-1, 1))  # shape (N,) -> (N, 1) to match the model output
```
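Reshaping the labels to (N, 1) matters because nn.Linear(2, 1) returns predictions of shape (N, 1). If the labels stayed at shape (N,), subtracting them from the predictions would broadcast to an (N, N) matrix, so MSELoss would average the wrong quantity (PyTorch emits a UserWarning about mismatched sizes in that case). A small sketch with made-up values:

```python
import torch
from torch import nn

y_hat = torch.zeros(4, 1)                 # stand-in model output: shape (N, 1)
y_flat = torch.tensor([1., 2., 3., 4.])   # labels before reshape: shape (N,)
y_col = torch.reshape(y_flat, (-1, 1))    # labels after reshape: shape (N, 1)

print(y_flat.shape, y_col.shape)
# Broadcasting (N, 1) against (N,) silently produces an (N, N) difference matrix.
print((y_hat - y_flat).shape)
# With matching shapes the loss is the mean of (0 - y)^2 = (1 + 4 + 9 + 16) / 4.
print(nn.MSELoss()(y_hat, y_col).item())  # 7.5
```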
3. Build an iterator that reads the data in mini-batches
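```python
from torch.utils import data

def load_array(data_arrays, batch_size, is_train=True):
    """Wrap the tensors in a dataset and return a DataLoader over mini-batches."""
    dataset = data.TensorDataset(*data_arrays)
    return data.DataLoader(dataset, batch_size, shuffle=is_train)  # shuffle only when training

batch_size = 10
data_iter = load_array((features, labels), batch_size)
next(iter(data_iter))  # peek at the first (features, labels) batch
```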
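To see what load_array yields, the sketch below feeds it 25 random samples (random tensors standing in for the CSV data) with batch_size=10. Because DataLoader's drop_last defaults to False, the final batch keeps the 5 leftover samples rather than being discarded:

```python
import torch
from torch.utils import data

def load_array(data_arrays, batch_size, is_train=True):
    """Wrap tensors in a TensorDataset and return a mini-batch DataLoader."""
    dataset = data.TensorDataset(*data_arrays)
    return data.DataLoader(dataset, batch_size, shuffle=is_train)

# 25 made-up samples with 2 features each; labels already shaped (N, 1).
features = torch.randn(25, 2)
labels = torch.randn(25, 1)

data_iter = load_array((features, labels), batch_size=10)
batch_sizes = [X.shape[0] for X, y in data_iter]
print(batch_sizes)  # [10, 10, 5]: the last batch holds the remainder

X, y = next(iter(data_iter))
print(X.shape, y.shape)  # one full batch of features and labels
```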
4. Define the network
```python
from torch import nn

class LinearRegression(nn.Module):
    def __init__(self):
        super(LinearRegression, self).__init__()
        self.layer1 = nn.Linear(2, 1)  # 2 input features -> 1 output

    def forward(self, x):
        return self.layer1(x)
```
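A single nn.Linear layer is exactly the linear-regression model: it stores a weight of shape (out_features, in_features) and a bias, and computes X @ W^T + b. The sketch below checks the model's output against that formula on random inputs (the seed and input values are arbitrary):

```python
import torch
from torch import nn

class LinearRegression(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(2, 1)

    def forward(self, x):
        return self.layer1(x)

torch.manual_seed(0)
model = LinearRegression()
X = torch.randn(4, 2)

# weight has shape (1, 2) and bias has shape (1,)
W = model.layer1.weight
b = model.layer1.bias
manual = X @ W.T + b   # the linear-regression formula written out by hand

print(W.shape, b.shape)
print(torch.allclose(model(X), manual, atol=1e-6))  # True
```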
5. Train the network
```python
import matplotlib.pyplot as plt
from torch import optim

model = LinearRegression()
mse_loss = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.003)
Loss = []
epochs = 1000

def train():
    for i in range(epochs):
        for X, y in data_iter:
            y_hat = model(X)             # forward pass: model output
            loss = mse_loss(y_hat, y)    # loss function
            Loss.append(loss.item())     # record the loss of every mini-batch
            optimizer.zero_grad()        # clear gradients left from the previous step
            loss.backward()              # compute gradients
            optimizer.step()             # update the weights
        print(i, loss.item(), sep='\t')  # loss of the last batch in this epoch

train()  # train
for parameter in model.parameters():
    print(parameter)
plt.plot(Loss)
plt.show()
```
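Since data.csv is not included, one way to check that the loop works is to train it on synthetic data whose true parameters are known. The sketch below assumes y = 2*x1 - 3.4*x2 + 4.2 plus small Gaussian noise (values chosen only for illustration), writes the model as a bare nn.Linear(2, 1) (functionally the same as the LinearRegression class above), and uses plain SGD instead of Adam to keep the run short. It also selects the GPU only when torch.cuda.is_available() reports one and moves every batch to the model's device, which is what a torch.device('cuda') object is needed for:

```python
import torch
from torch import nn, optim
from torch.utils import data

torch.manual_seed(0)

# Synthetic data (illustrative assumption): y = 2*x1 - 3.4*x2 + 4.2 + noise
true_w = torch.tensor([[2.0], [-3.4]])
true_b = 4.2
features = torch.randn(1000, 2)
labels = features @ true_w + true_b + 0.01 * torch.randn(1000, 1)

data_iter = data.DataLoader(data.TensorDataset(features, labels),
                            batch_size=10, shuffle=True)

# Use the GPU only if one is actually available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Linear(2, 1).to(device)
mse_loss = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.03)

epoch_losses = []
for epoch in range(10):
    for X, y in data_iter:
        X, y = X.to(device), y.to(device)  # batches must live on the model's device
        loss = mse_loss(model(X), y)
        optimizer.zero_grad()   # clear gradients from the previous step
        loss.backward()         # gradients of the loss w.r.t. weight and bias
        optimizer.step()        # update weight and bias
    # Evaluate the loss on the whole dataset once per epoch (no gradient tracking).
    with torch.no_grad():
        full_loss = mse_loss(model(features.to(device)), labels.to(device))
        epoch_losses.append(full_loss.item())

print(model.weight.data, model.bias.data)  # should be close to true_w.T and true_b
```

If the learned weight and bias land near the values used to generate the data, the data pipeline, loss and update steps are wired correctly.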
6. Complete code
```python
import matplotlib.pyplot as plt
import pandas as pd
import torch
from torch import nn, optim
from torch.utils import data


class LinearRegression(nn.Module):
    def __init__(self, feature_nums):
        super(LinearRegression, self).__init__()
        self.layer1 = nn.Linear(feature_nums, 1)  # feature_nums inputs -> 1 output

    def forward(self, x):
        return self.layer1(x)


def load_array(data_arrays, batch_size, is_train=True):
    """Wrap the tensors in a dataset and return a DataLoader over mini-batches."""
    dataset = data.TensorDataset(*data_arrays)
    return data.DataLoader(dataset, batch_size, shuffle=is_train)


def Linear_Regression_pytorch(filepath, feature_nums, batch_size, learning_rate, epochs):
    # Load the CSV: the first column is the label, the rest are features.
    # Named df so it does not shadow the torch.utils.data module.
    df = pd.read_csv(filepath)
    features = torch.tensor(df.iloc[:, 1:].values, dtype=torch.float32)
    labels = torch.tensor(df.iloc[:, 0].values, dtype=torch.float32)
    labels = torch.reshape(labels, (-1, 1))  # (N,) -> (N, 1)

    data_iter = load_array((features, labels), batch_size)

    model = LinearRegression(feature_nums)
    mse_loss = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    Loss = []
    for i in range(epochs):
        for X, y in data_iter:
            y_hat = model(X)           # forward pass: model output
            loss = mse_loss(y_hat, y)  # loss function
            optimizer.zero_grad()      # clear gradients left from the previous step
            loss.backward()            # compute gradients
            optimizer.step()           # update the weights
        print(i, loss.item(), sep='\t')  # loss of the last batch in this epoch
        Loss.append(loss.item())         # record one loss value per epoch

    for parameter in model.parameters():
        print(parameter)
    plt.plot(Loss)
    plt.title("Loss")
    plt.show()
    return model


if __name__ == "__main__":
    Linear_Regression_pytorch(filepath="data.csv",
                              feature_nums=2,
                              batch_size=10,
                              learning_rate=0.05,
                              epochs=1000)
```
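Linear regression also has a closed-form least-squares solution, which gives a reference point for the parameters printed by the training loop. A minimal sketch using torch.linalg.lstsq on synthetic data (an assumption; for the real data.csv one would pass the features and labels tensors built in step 2 instead):

```python
import torch

torch.manual_seed(0)

# Synthetic data (illustrative assumption): y = 2*x1 - 3.4*x2 + 4.2 + noise
features = torch.randn(1000, 2)
labels = features @ torch.tensor([[2.0], [-3.4]]) + 4.2 + 0.01 * torch.randn(1000, 1)

# Append a column of ones so the bias is estimated as a third coefficient.
A = torch.cat([features, torch.ones(features.shape[0], 1)], dim=1)
solution = torch.linalg.lstsq(A, labels).solution  # shape (3, 1)

w_hat = solution[:2].T   # comparable to model.layer1.weight, shape (1, 2)
b_hat = solution[2]      # comparable to model.layer1.bias, shape (1,)
print(w_hat, b_hat)
```

A trained model whose weight and bias are far from this solution usually points to too few epochs or a learning rate that is too large or too small.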