当前位置:网站首页>pytorch-线性回归
pytorch-线性回归
2022-07-05 11:34:00 【我渊啊我渊啊】
1、导入表格数据
filename = "./data.csv"
data = pd.read_csv(filename)
features = data.iloc[:,1:]
labels = data.iloc[:,0]
2、转成Tensor形式
''' DataFrame ----> Tensor '''
features = torch.tensor(features.values, dtype=torch.float32)
labels = torch.tensor(labels.values, dtype=torch.float32)
labels = torch.reshape(labels,(-1,1))
3、生成迭代器,分批次读取数据
from torch.utils import data
def load_array(data_arrays , batch_size , is_train = True):
dataset = data.TensorDataset(*data_arrays)
return data.DataLoader(dataset , batch_size , shuffle = is_train)
batch_size = 10
data_iter = load_array((features,labels),batch_size)
next(iter(data_iter))
4、定义神经网络
class LinearRegression(nn.Module):
def __init__(self):
super(LinearRegression, self).__init__()
self.layer1 = nn.Linear(2, 1)
def forward(self, x):
return self.layer1(x)
5、训练网络
model = LinearRegression()
gpu = torch.device('cuda')
mse_loss = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.003)
Loss = []
epochs = 1000
def train():
for i in range(epochs):
for X , y in data_iter:
y_hat = model(X) # 计算模型输出结果
loss = mse_loss(y_hat, y) # 损失函数
loss_numpy = loss.detach().numpy()
Loss.append(loss_numpy)
optimizer.zero_grad() # 梯度清零
loss.backward() # 计算权值
optimizer.step() # 修改权值
print(i, loss.item(), sep='\t')
train() # 训练
for parameter in model.parameters():
print(parameter)
plt.plot(Loss)
6 整体代码
import matplotlib.pyplot as plt
import pandas as pd
import torch
from torch import nn, optim
from torch.utils import data
class LinearRegression (nn.Module):
def __init__(self, feature_nums):
super (LinearRegression, self).__init__ ()
self.layer1 = nn.Linear (feature_nums, 1)
def forward(self, x):
return self.layer1 (x)
def load_array(data_arrays, batch_size, is_train=True):
dataset = data.TensorDataset (*data_arrays)
return data.DataLoader (dataset, batch_size, shuffle=is_train)
def Linear_Regression_pytorch(filepath, feature_nums, batch_size, learning_rate, epochs):
filename = filepath
data = pd.read_csv (filename)
features = data.iloc[:, 1:]
labels = data.iloc[:, 0]
features = torch.tensor (features.values)
features = torch.tensor (features, dtype=torch.float32)
labels = torch.tensor (labels.values)
labels = torch.tensor (labels, dtype=torch.float32)
labels = torch.reshape (labels, (-1, 1))
batch_size = batch_size
data_iter = load_array ((features, labels), batch_size)
next (iter (data_iter))
model = LinearRegression (feature_nums)
gpu = torch.device ('cuda')
mse_loss = nn.MSELoss ()
optimizer = optim.Adam (model.parameters (), lr=learning_rate)
Loss = []
epochs = epochs
for i in range (epochs):
for X, y in data_iter:
y_hat = model (X) # 计算模型输出结果
loss = mse_loss (y_hat, y) # 损失函数
loss_numpy = loss.detach ().numpy ()
optimizer.zero_grad () # 梯度清零
loss.backward () # 计算权值
optimizer.step () # 修改权值
print (i, loss.item (), sep='\t')
Loss.append (loss.item ())
for parameter in model.parameters ():
print (parameter)
plt.plot (Loss)
plt.title("Loss")
plt.show ()
if __name__ == "__main__":
Linear_Regression_pytorch (filepath="data.csv",
feature_nums=2,
batch_size=10,
learning_rate=0.05,
epochs=1000
)
边栏推荐
- Solve readobjectstart: expect {or N, but found n, error found in 1 byte of
- PHP中Array的hash函数实现
- Harbor镜像仓库搭建
- 【爬虫】charles unknown错误
- Differences between IPv6 and IPv4 three departments including the office of network information technology promote IPv6 scale deployment
- 871. Minimum Number of Refueling Stops
- 跨境电商是啥意思?主要是做什么的?业务模式有哪些?
- 以交互方式安装ESXi 6.0
- [LeetCode] Wildcard Matching 外卡匹配
- C language current savings account management system
猜你喜欢
一次生产环境redis内存占用居高不下问题排查
The ninth Operation Committee meeting of dragon lizard community was successfully held
11. (map data section) how to download and use OSM data
【无标题】
Idea set the number of open file windows
Redis集群(主从)脑裂及解决方案
7 大主题、9 位技术大咖!龙蜥大讲堂7月硬核直播预告抢先看,明天见
NFT 交易市场主要使用 ETH 本位进行交易的局面是如何形成的?
COMSOL -- establishment of 3D graphics
How did the situation that NFT trading market mainly uses eth standard for trading come into being?
随机推荐
程序员内卷和保持行业竞争力
Go language learning notes - analyze the first program
Idea set the number of open file windows
无密码身份验证如何保障用户隐私安全?
Open3D 网格(曲面)赋色
Programmers are involved and maintain industry competitiveness
Web API configuration custom route
How can edge computing be combined with the Internet of things?
爬虫(9) - Scrapy框架(1) | Scrapy 异步网络爬虫框架
C#实现WinForm DataGridView控件支持叠加数据绑定
An error is reported in the process of using gbase 8C database: 80000305, host IPS long to different cluster. How to solve it?
go语言学习笔记-分析第一个程序
石油化工企业安全生产智能化管控系统平台建设思考和建议
7 themes and 9 technology masters! Dragon Dragon lecture hall hard core live broadcast preview in July, see you tomorrow
Open3D 欧式聚类
sklearn模型整理
分类TAB商品流多目标排序模型的演进
redis主从模式
Spark Tuning (I): from HQL to code
Golang application topic - channel