当前位置:网站首页>pytorch-线性回归
pytorch-线性回归
2022-07-05 11:34:00 【我渊啊我渊啊】
1、导入表格数据
filename = "./data.csv"
data = pd.read_csv(filename)
features = data.iloc[:,1:]
labels = data.iloc[:,0]
2、转成Tensor形式
''' DataFrame ----> Tensor '''
features = torch.tensor(features.values, dtype=torch.float32)
labels = torch.tensor(labels.values, dtype=torch.float32)
labels = torch.reshape(labels,(-1,1))
3、生成迭代器,分批次读取数据
from torch.utils import data
def load_array(data_arrays , batch_size , is_train = True):
dataset = data.TensorDataset(*data_arrays)
return data.DataLoader(dataset , batch_size , shuffle = is_train)
batch_size = 10
data_iter = load_array((features,labels),batch_size)
next(iter(data_iter))
4、定义神经网络
class LinearRegression(nn.Module):
def __init__(self):
super(LinearRegression, self).__init__()
self.layer1 = nn.Linear(2, 1)
def forward(self, x):
return self.layer1(x)
5、训练网络
model = LinearRegression()
gpu = torch.device('cuda')
mse_loss = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.003)
Loss = []
epochs = 1000
def train():
for i in range(epochs):
for X , y in data_iter:
y_hat = model(X) # 计算模型输出结果
loss = mse_loss(y_hat, y) # 损失函数
loss_numpy = loss.detach().numpy()
Loss.append(loss_numpy)
optimizer.zero_grad() # 梯度清零
loss.backward() # 计算权值
optimizer.step() # 修改权值
print(i, loss.item(), sep='\t')
train() # 训练
for parameter in model.parameters():
print(parameter)
plt.plot(Loss)
6 整体代码
import matplotlib.pyplot as plt
import pandas as pd
import torch
from torch import nn, optim
from torch.utils import data
class LinearRegression (nn.Module):
def __init__(self, feature_nums):
super (LinearRegression, self).__init__ ()
self.layer1 = nn.Linear (feature_nums, 1)
def forward(self, x):
return self.layer1 (x)
def load_array(data_arrays, batch_size, is_train=True):
dataset = data.TensorDataset (*data_arrays)
return data.DataLoader (dataset, batch_size, shuffle=is_train)
def Linear_Regression_pytorch(filepath, feature_nums, batch_size, learning_rate, epochs):
filename = filepath
data = pd.read_csv (filename)
features = data.iloc[:, 1:]
labels = data.iloc[:, 0]
features = torch.tensor (features.values)
features = torch.tensor (features, dtype=torch.float32)
labels = torch.tensor (labels.values)
labels = torch.tensor (labels, dtype=torch.float32)
labels = torch.reshape (labels, (-1, 1))
batch_size = batch_size
data_iter = load_array ((features, labels), batch_size)
next (iter (data_iter))
model = LinearRegression (feature_nums)
gpu = torch.device ('cuda')
mse_loss = nn.MSELoss ()
optimizer = optim.Adam (model.parameters (), lr=learning_rate)
Loss = []
epochs = epochs
for i in range (epochs):
for X, y in data_iter:
y_hat = model (X) # 计算模型输出结果
loss = mse_loss (y_hat, y) # 损失函数
loss_numpy = loss.detach ().numpy ()
optimizer.zero_grad () # 梯度清零
loss.backward () # 计算权值
optimizer.step () # 修改权值
print (i, loss.item (), sep='\t')
Loss.append (loss.item ())
for parameter in model.parameters ():
print (parameter)
plt.plot (Loss)
plt.title("Loss")
plt.show ()
if __name__ == "__main__":
Linear_Regression_pytorch (filepath="data.csv",
feature_nums=2,
batch_size=10,
learning_rate=0.05,
epochs=1000
)
边栏推荐
- AutoCAD -- mask command, how to use CAD to locally enlarge drawings
- 边缘计算如何与物联网结合在一起?
- redis主从模式
- Implementation of array hash function in PHP
- Acid transaction theory
- Solve the problem of slow access to foreign public static resources
- [LeetCode] Wildcard Matching 外卡匹配
- sklearn模型整理
- Redis如何实现多可用区?
- -26374 and -26377 errors during coneroller execution
猜你喜欢
中非 钻石副石怎么镶嵌,才能既安全又好看?
IPv6与IPv4的区别 网信办等三部推进IPv6规模部署
COMSOL -- 3D casual painting -- sweeping
COMSOL -- three-dimensional graphics random drawing -- rotation
13.(地图数据篇)百度坐标(BD09)、国测局坐标(火星坐标,GCJ02)、和WGS84坐标系之间的转换
11. (map data section) how to download and use OSM data
12. (map data) cesium city building map
龙蜥社区第九次运营委员会会议顺利召开
Idea set the number of open file windows
Redis集群的重定向
随机推荐
COMSOL -- establishment of geometric model -- establishment of two-dimensional graphics
go语言学习笔记-初识Go语言
Cron表达式(七子表达式)
阻止浏览器后退操作
POJ 3176 cow bowling (DP | memory search)
comsol--三维图形随便画----回转
基于Lucene3.5.0怎样从TokenStream获得Token
紫光展锐全球首个5G R17 IoT NTN卫星物联网上星实测完成
How to get a token from tokenstream based on Lucene 3.5.0
Advanced technology management - what is the physical, mental and mental strength of managers
Spark Tuning (I): from HQL to code
Error assembling WAR: webxml attribute is required (or pre-existing WEB-INF/web.xml if executing in
边缘计算如何与物联网结合在一起?
Unity xlua monoproxy mono proxy class
高校毕业求职难?“百日千万”网络招聘活动解决你的难题
MySQL 巨坑:update 更新慎用影响行数做判断!!!
SET XACT_ABORT ON
分类TAB商品流多目标排序模型的演进
管理多个Instagram帐户防关联小技巧大分享
技术管理进阶——什么是管理者之体力、脑力、心力