Bilibili 刘二大人: Datasets and Data Loading, Lecture 8
2022-07-06 05:33:00 【宁然也】
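Calling y_pred = model(x_data) pushes the whole dataset through the model in one step (full-batch training).
To train in mini-batches instead, a few concepts are needed:
- Epoch: one forward and backward pass over all training samples.
- Batch size: the number of training samples used in one forward and backward pass.
- Iteration: one forward and backward pass on a single batch. One epoch takes (number of samples / batch size) iterations, rounded up if the last batch is incomplete.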
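The relationship between these quantities is simple arithmetic. The sketch below counts iterations per epoch. The sample counts are made-up numbers. The drop_last flag mirrors the DataLoader argument of the same name, which discards an incomplete final batch (its default is False).

```python
import math


def iterations_per_epoch(num_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Number of mini-batches (iterations) needed to go over the data once."""
    if drop_last:
        # The incomplete final batch is discarded
        return num_samples // batch_size
    # The incomplete final batch is still used, so round up
    return math.ceil(num_samples / batch_size)


# Hypothetical sizes, for illustration only
print(iterations_per_epoch(10000, 1000))               # 10
print(iterations_per_epoch(1000, 32))                  # 32 (31 full batches + 1 batch of 8)
print(iterations_per_epoch(1000, 32, drop_last=True))  # 31
```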
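A custom dataset subclasses Dataset and implements __getitem__ (return one sample by index) and __len__ (return the dataset size). DataLoader then wraps the dataset. A template:

import torch
from torch.utils.data import Dataset     # Dataset is an abstract class: subclass it
from torch.utils.data import DataLoader  # DataLoader loads the data in mini-batches

class DiabetesDataset(Dataset):
    def __init__(self):
        pass

    def __getitem__(self, index):
        # Return the sample at position index, so that dataset[index] works
        pass

    def __len__(self):
        # Return the number of samples in the dataset
        pass

dataset = DiabetesDataset()  # construct a DiabetesDataset object
train_loader = DataLoader(dataset=dataset,  # initialize the loader's parameters
                          batch_size=32,
                          shuffle=True,
                          num_workers=2)

The parameters work as follows:
- batch_size: the size of each mini-batch.
- shuffle=True: reshuffles the samples every epoch.
- num_workers=2: reads the batches in two worker subprocesses.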
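With num_workers > 0, the worker processes are started with spawn on some platforms (the default on Windows). Under spawn, each worker re-imports the main script. For this reason the loader should be created and iterated under if __name__ == "__main__":. The sketch below uses a hypothetical ToyDataset of random data, not the diabetes file. It shows both the guard and the batch shapes that DataLoader produces by stacking individual samples.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    """Hypothetical dataset: n samples with 8 random features and a 0/1 label."""

    def __init__(self, n: int = 100):
        self.x = torch.randn(n, 8)
        self.y = (self.x.sum(dim=1, keepdim=True) > 0).float()  # shape (n, 1)

    def __getitem__(self, index):
        # One (features, label) pair; DataLoader stacks these into batches
        return self.x[index], self.y[index]

    def __len__(self):
        return self.x.shape[0]


def main():
    loader = DataLoader(ToyDataset(), batch_size=32, shuffle=True, num_workers=2)
    for inputs, labels in loader:
        # 100 samples -> three batches of 32 and a final batch of 4
        print(inputs.shape, labels.shape)


# Worker processes may re-import this module, so only run the loading code
# when the file is executed directly
if __name__ == "__main__":
    main()
```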
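import numpy as np
import torch
import matplotlib.pyplot as plt
# Dataset is an abstract class: subclass it and implement __getitem__ and __len__
from torch.utils.data import Dataset
# DataLoader is a concrete class that batches, shuffles and loads samples from a Dataset
from torch.utils.data import DataLoader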
class LogisticRegressionModel(torch.nn.Module):
    def __init__(self):
        super(LogisticRegressionModel, self).__init__()
        # Three linear layers: 8 input features -> 6 -> 4 -> 1 output
        self.lay1 = torch.nn.Linear(8, 6)
        self.lay2 = torch.nn.Linear(6, 4)
        self.lay3 = torch.nn.Linear(4, 1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        x = self.sigmoid(self.lay1(x))
        x = self.sigmoid(self.lay2(x))
        x = self.sigmoid(self.lay3(x))  # output in (0, 1): predicted probability
        return x
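class DiabetesDataset(Dataset):
    def __init__(self, filepath):
        # The data set is small, so it is loaded into memory once (np.loadtxt also reads .gz files)
        xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)
        self.len = xy.shape[0]                       # number of samples (rows)
        self.x_data = torch.from_numpy(xy[:, :-1])   # features: every column but the last
        self.y_data = torch.from_numpy(xy[:, [-1]])  # label: last column, shape (N, 1)

    def __getitem__(self, index):
        # Return one sample so that dataset[index] works
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        return self.len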
dataset = DiabetesDataset("./datasets/diabetes.csv.gz")
train_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True)
model = LogisticRegressionModel()
criterion = torch.nn.BCELoss(reduction='mean')
optimizer = torch.optim.SGD(model.parameters(), lr=0.005)
epoch_list = []
loss_list = []
for epoch in range(100):
    epoch_loss = 0.0
    for i, data in enumerate(train_loader, 0):
        # 1. Load a mini-batch
        inputs, labels = data
        # 2. Forward pass
        y_pred = model(inputs)
        loss = criterion(y_pred, labels)
        epoch_loss += loss.item()
        # 3. Backward pass
        optimizer.zero_grad()
        loss.backward()
        # 4. Update the parameters
        optimizer.step()
    # Record one point per epoch: the mean loss over its mini-batches
    epoch_list.append(epoch)
    loss_list.append(epoch_loss / len(train_loader))

plt.plot(epoch_list, loss_list)
plt.xlabel("epoch")
plt.ylabel("loss")
plt.show()
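The class above slices the label with xy[:, [-1]] rather than xy[:, -1]. Indexing with a list keeps the label as a column of shape (N, 1), which matches the model's (N, 1) output. BCELoss expects the target to have the same shape as the prediction.

The sketch below checks those shapes without the real file. It writes a synthetic CSV with the same layout (8 feature columns and a 0/1 label), with random values standing in for diabetes.csv.gz. It then loads the CSV through the same dataset class.

```python
import os
import tempfile

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class DiabetesDataset(Dataset):
    # Same structure as the class above: everything is loaded in __init__
    def __init__(self, filepath):
        xy = np.loadtxt(filepath, delimiter=",", dtype=np.float32)
        self.len = xy.shape[0]
        self.x_data = torch.from_numpy(xy[:, :-1])   # all columns but the last
        self.y_data = torch.from_numpy(xy[:, [-1]])  # last column, kept 2-D

    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        return self.len


# Synthetic stand-in data: 100 rows, 8 random features + 1 random binary label
rng = np.random.default_rng(0)
fake = np.hstack([rng.random((100, 8)), rng.integers(0, 2, (100, 1))]).astype(np.float32)
path = os.path.join(tempfile.mkdtemp(), "fake_diabetes.csv")
np.savetxt(path, fake, delimiter=",")

dataset = DiabetesDataset(path)
inputs, labels = next(iter(DataLoader(dataset, batch_size=32, shuffle=True)))
print(inputs.shape, labels.shape)              # torch.Size([32, 8]) torch.Size([32, 1])
print(fake[:, -1].shape, fake[:, [-1]].shape)  # (100,) (100, 1)
```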

Loading the MNIST dataset
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# transforms.ToTensor() turns each 28x28 grayscale image into a 1x28x28 float tensor in [0, 1]
train_dataset = datasets.MNIST(root='./datasets/mnist', train=True,
                               transform=transforms.ToTensor(),
                               download=True)
test_dataset = datasets.MNIST(root='./datasets/mnist', train=False,
                              transform=transforms.ToTensor(),
                              download=True)
# Shuffle the training set every epoch; the test set can keep its order
train_loader = DataLoader(dataset=train_dataset, batch_size=32,
                          shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=32,
                         shuffle=False)
for batch_idx, (inputs, target) in enumerate(test_loader):
    pass  # per-batch code goes here
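A typical body for the loop over test_loader is an evaluation pass. It switches the model to eval mode, disables gradient tracking, and counts correct predictions. Shuffling brings no benefit here, which is why the test loader uses shuffle=False.

The sketch below is not from the lecture. To stay runnable offline, it replaces MNIST with random 1x28x28 tensors and labels 0-9 wrapped in TensorDataset. The classifier is a hypothetical untrained flatten-plus-linear model, so the printed accuracy is only near chance level.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def evaluate(model: torch.nn.Module, loader: DataLoader) -> float:
    """Return the classification accuracy of model over every batch in loader."""
    model.eval()                # put layers such as dropout into evaluation mode
    correct, total = 0, 0
    with torch.no_grad():       # gradients are not needed for evaluation
        for batch_idx, (inputs, target) in enumerate(loader):
            outputs = model(inputs)            # (batch, 10) class scores
            predicted = outputs.argmax(dim=1)  # index of the highest score
            correct += (predicted == target).sum().item()
            total += target.size(0)
    return correct / total


# Stand-in for MNIST: 100 random "images" of shape 1x28x28 with labels 0-9
images = torch.rand(100, 1, 28, 28)
labels = torch.randint(0, 10, (100,))
test_loader = DataLoader(TensorDataset(images, labels), batch_size=32, shuffle=False)

# Hypothetical untrained classifier: flatten 28*28 pixels -> 10 class scores
model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(28 * 28, 10))
print(f"accuracy: {evaluate(model, test_loader):.2f}")
```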