Bilibili 刘二大人 (Liu Er Da Ren): Datasets and Data Loading, Lecture 8
2022-07-06 05:33:00 【宁然也】
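Writing `y_pred = model(x_data)` feeds all of the data to the model in a single forward pass.
To train in mini-batches instead, a few concepts need to be understood first.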
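The concepts are the standard ones: an *epoch* is one full forward and backward pass over every training sample; the *batch size* is the number of samples in one forward/backward pass; an *iteration* is one such pass, i.e. one `optimizer.step()`. One epoch therefore takes roughly N / batch_size iterations. The minimal sketch below only does that arithmetic; the function names and the 10,000-sample, batch-size-1000 figures are illustrative assumptions, not values from the diabetes data.

```python
def iterations_per_epoch(num_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Mini-batches (iterations) needed for one epoch, i.e. to see every sample once."""
    if drop_last:
        return num_samples // batch_size                  # incomplete last batch is discarded
    return (num_samples + batch_size - 1) // batch_size   # ceiling division: last batch may be smaller


def total_updates(num_samples: int, batch_size: int, num_epochs: int) -> int:
    """One optimizer.step() per iteration, repeated for every epoch."""
    return num_epochs * iterations_per_epoch(num_samples, batch_size)


print(iterations_per_epoch(10000, 1000))   # 10
print(iterations_per_epoch(10050, 1000))   # 11
print(total_updates(10000, 1000, 100))     # 1000
```

The `drop_last` flag mirrors the `DataLoader` argument of the same name, which decides whether a short final batch is kept.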
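```python
import torch
from torch.utils.data import Dataset     # Dataset is an abstract class; subclass it
from torch.utils.data import DataLoader  # DataLoader loads data from a Dataset in batches
```

A custom dataset subclasses `Dataset` and implements two magic methods:

```python
class DiabetesDataset(Dataset):
    def __init__(self):
        pass  # implemented in the full example below

    def __getitem__(self, index):  # lets the dataset be indexed: dataset[index]
        pass

    def __len__(self):             # returns the length of the dataset
        pass
```

Construct a `DiabetesDataset` object, then initialize a `DataLoader` with its parameters:

```python
dataset = DiabetesDataset()  # construct a DiabetesDataset object
train_loader = DataLoader(dataset=dataset,
                          batch_size=32,   # samples per mini-batch
                          shuffle=True,    # reshuffle the data every epoch
                          num_workers=2)   # 2 loader processes; on Windows run under if __name__ == '__main__':
```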
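To see what `__getitem__` and `__len__` buy you, here is a simplified pure-Python sketch of the index-and-batch loop that a loader performs. `ToyDataset` and `simple_loader` are hypothetical names; the real `DataLoader` additionally handles worker processes and collates samples into tensors, none of which is modeled here.

```python
import random
from typing import Iterator, List, Optional, Sequence, Tuple


class ToyDataset:
    """Hypothetical map-style dataset: indexable and sized, like a torch Dataset."""

    def __init__(self, xs: Sequence[float], ys: Sequence[int]) -> None:
        if len(xs) != len(ys):
            raise ValueError("xs and ys must have the same length")
        self.xs, self.ys = list(xs), list(ys)

    def __getitem__(self, index: int) -> Tuple[float, int]:
        return self.xs[index], self.ys[index]  # one (feature, label) sample

    def __len__(self) -> int:
        return len(self.xs)


def simple_loader(dataset, batch_size: int, shuffle: bool = False,
                  rng: Optional[random.Random] = None) -> Iterator[Tuple[List[float], List[int]]]:
    """Simplified sketch of mini-batching: order the indices, slice them, collate."""
    indices = list(range(len(dataset)))        # relies on __len__
    if shuffle:
        (rng if rng is not None else random.Random()).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        samples = [dataset[i] for i in indices[start:start + batch_size]]  # relies on __getitem__
        xs, ys = zip(*samples)                 # list of pairs -> pair of lists
        yield list(xs), list(ys)


ds = ToyDataset([float(i) for i in range(10)], [i % 2 for i in range(10)])
for epoch in range(2):
    # calling the loader again each epoch draws a fresh shuffle
    for xs, ys in simple_loader(ds, batch_size=4, shuffle=True):
        print(epoch, xs, ys)
```

Because the sketch only needs `len(dataset)` and `dataset[i]`, any object implementing those two methods can be batched this way, which is why `Dataset` subclasses must provide them.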
```python
import numpy as np
import torch
import matplotlib.pyplot as plt
# Dataset is an abstract class: subclass it and implement __getitem__ and __len__
from torch.utils.data import Dataset
# DataLoader is a concrete class that batches and shuffles samples from a Dataset
from torch.utils.data import DataLoader


class LogisticRegressionModel(torch.nn.Module):
    def __init__(self):
        super(LogisticRegressionModel, self).__init__()
        # three linear layers: input dimension 8 -> 6 -> 4 -> output dimension 1
        self.lay1 = torch.nn.Linear(8, 6)
        self.lay2 = torch.nn.Linear(6, 4)
        self.lay3 = torch.nn.Linear(4, 1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        x = self.sigmoid(self.lay1(x))
        x = self.sigmoid(self.lay2(x))
        x = self.sigmoid(self.lay3(x))
        return x


class DiabetesDataset(Dataset):
    def __init__(self, filepath):
        # np.loadtxt decompresses .gz files transparently
        xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)
        self.len = xy.shape[0]                       # number of samples
        self.x_data = torch.from_numpy(xy[:, :-1])   # every column except the last: features
        self.y_data = torch.from_numpy(xy[:, [-1]])  # last column as labels, kept 2-D: (N, 1)

    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        return self.len


dataset = DiabetesDataset("./datasets/diabetes.csv.gz")
train_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True)

model = LogisticRegressionModel()
criterion = torch.nn.BCELoss(reduction='mean')
optimizer = torch.optim.SGD(model.parameters(), lr=0.005)

epoch_list = []
loss_list = []
for epoch in range(100):
    epoch_loss = 0.0
    for i, data in enumerate(train_loader, 0):
        # 1. Prepare data: one mini-batch of inputs and labels
        inputs, label = data
        # 2. Forward
        y_pred = model(inputs)
        loss = criterion(y_pred, label)
        epoch_loss += loss.item()
        optimizer.zero_grad()
        # 3. Backward
        loss.backward()
        # 4. Update
        optimizer.step()
    # record the mean batch loss once per epoch, so each epoch is one point on the plot
    epoch_list.append(epoch)
    loss_list.append(epoch_loss / len(train_loader))

plt.plot(epoch_list, loss_list)
plt.xlabel("epoch")
plt.ylabel("loss")
plt.show()
```
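One detail in `DiabetesDataset.__init__` is easy to miss: `xy[:, [-1]]` indexes the last column with a list rather than a scalar. The list keeps the column axis, so the labels have shape (N, 1), matching the model's (batch, 1) output; `BCELoss` expects the target to have the same shape as its input. A quick NumPy check on a made-up 3×4 array (the values are arbitrary):

```python
import numpy as np

# hypothetical rows: 3 feature columns followed by a 0/1 label column
xy = np.array([[0.1, 0.2, 0.3, 1.0],
               [0.4, 0.5, 0.6, 0.0],
               [0.7, 0.8, 0.9, 1.0]], dtype=np.float32)

features = xy[:, :-1]   # all columns except the last
label_2d = xy[:, [-1]]  # list index keeps the column axis
label_1d = xy[:, -1]    # scalar index drops it

print(features.shape, label_2d.shape, label_1d.shape)  # (3, 3) (3, 1) (3,)
```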

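Loading the MNIST dataset

torchvision already provides MNIST as a ready-made `Dataset`, so it can be passed straight to `DataLoader`:

```python
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms

# ToTensor turns each 28x28 grayscale image into a float tensor of shape (1, 28, 28) in [0, 1]
train_dataset = datasets.MNIST(root='./datasets/mnist', train=True,
                               transform=transforms.ToTensor(),
                               download=True)
test_dataset = datasets.MNIST(root='./datasets/mnist', train=False,
                              transform=transforms.ToTensor(),
                              download=True)

train_loader = DataLoader(dataset=train_dataset, batch_size=32,
                          shuffle=True)    # reshuffle the training set every epoch
test_loader = DataLoader(dataset=test_dataset, batch_size=32,
                         shuffle=False)    # keep the test order fixed

for batch_idx, (inputs, target) in enumerate(test_loader):
    pass  # evaluation code goes here
```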
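With `batch_size=32`, each batch from these loaders holds images of shape (32, 1, 28, 28) and labels of shape (32,), except that a split whose size is not a multiple of 32 ends with a smaller batch. MNIST has 60,000 training and 10,000 test images, so this can be checked with plain arithmetic. `batch_sizes` is a hypothetical helper; `drop_last=True` (a real `DataLoader` argument, default `False`) would discard the short final batch instead.

```python
from typing import List


def batch_sizes(num_samples: int, batch_size: int, drop_last: bool = False) -> List[int]:
    """Sizes of the batches a sequential loader yields over num_samples items."""
    full, remainder = divmod(num_samples, batch_size)
    sizes = [batch_size] * full
    if remainder and not drop_last:
        sizes.append(remainder)  # smaller final batch
    return sizes


train = batch_sizes(60000, 32)  # MNIST training split
test = batch_sizes(10000, 32)   # MNIST test split
print(len(train), train[-1])    # 1875 32
print(len(test), test[-1])      # 313 16
print(len(batch_sizes(10000, 32, drop_last=True)))  # 312
```

So `enumerate(test_loader)` runs 313 times, and the last batch contains only 16 images; evaluation code that assumes exactly 32 samples per batch would miscount.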