Liu Er Daren (刘二大人) on Bilibili: Datasets and Data Loading, Lecture 8
2022-07-06 05:33:00 【宁然也】
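Previously, y_pred = model(x_data) fed all of the data through the model at once.
To train in mini-batches instead, we need a few concepts: epoch (one forward and backward pass over all training samples), batch size (the number of samples used in one forward and backward pass), and iteration (one pass over one batch; an epoch takes the number of samples divided by the batch size, rounded up, iterations).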
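The relationship between the three terms is plain arithmetic. A minimal sketch, assuming the last, incomplete batch is kept (which is DataLoader's default, drop_last=False); the function name iterations_per_epoch is hypothetical:

```python
import math


def iterations_per_epoch(num_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Number of batches (iterations) needed to see every sample once."""
    if drop_last:
        # the incomplete final batch is discarded
        return num_samples // batch_size
    # the incomplete final batch is kept as a smaller batch
    return math.ceil(num_samples / batch_size)


print(iterations_per_epoch(10000, 1000))               # 10
print(iterations_per_epoch(100, 32))                   # 4
print(iterations_per_epoch(100, 32, drop_last=True))   # 3
```

With 100 samples and a batch size of 32, one epoch is three full batches plus one batch of 4.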
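import torch
from torch.utils.data import Dataset     # Dataset is an abstract class; we must subclass it
from torch.utils.data import DataLoader  # DataLoader loads data from a Dataset in batches

# template only: fill in the three methods (the full DiabetesDataset follows below)
class DiabetesDataset(Dataset):
    def __init__(self):
        pass                       # read or prepare the data here
    def __getitem__(self, index):  # supports indexing: dataset[index]
        pass
    def __len__(self):             # returns the length of the dataset
        pass

dataset = DiabetesDataset()  # construct a DiabetesDataset object
train_loader = DataLoader(dataset=dataset,
                          batch_size=32,
                          shuffle=True,
                          num_workers=2)  # initialize the loader parameters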
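To see what DataLoader actually produces from a Dataset subclass, here is a minimal sketch using an in-memory, randomly generated dataset; the class name ToyDataset and the helper batch_shapes are hypothetical. It sets num_workers=0 so loading happens in the main process. With num_workers > 0, worker processes are started, and on platforms that spawn rather than fork (Windows, and macOS by default) the loading loop must sit under an if __name__ == "__main__": guard.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    """Hypothetical in-memory dataset: n samples, 8 features, 1 binary label."""

    def __init__(self, n=100):
        self.x = torch.randn(n, 8)
        self.y = (torch.rand(n, 1) > 0.5).float()

    def __getitem__(self, index):
        # DataLoader asks for one sample at a time by integer index
        return self.x[index], self.y[index]

    def __len__(self):
        # DataLoader uses the length to generate the (shuffled) indices
        return self.x.shape[0]


def batch_shapes(loader):
    """Shapes of the feature tensors in every batch of one epoch."""
    return [tuple(xb.shape) for xb, _ in loader]


if __name__ == "__main__":
    loader = DataLoader(ToyDataset(100), batch_size=32, shuffle=True, num_workers=0)
    # the default collate function stacks 32 samples of shape (8,) into (32, 8);
    # 100 = 3 * 32 + 4, so the last batch holds the remaining 4 samples
    print(batch_shapes(loader))  # [(32, 8), (32, 8), (32, 8), (4, 8)]
```

Shuffling changes which samples land in which batch, but not the batch sizes.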
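import numpy as np
import torch
import matplotlib.pyplot as plt
# Dataset is an abstract class: subclass it and implement __getitem__ and __len__
from torch.utils.data import Dataset
# DataLoader is a concrete class that draws (optionally shuffled) mini-batches from a Dataset
from torch.utils.data import DataLoader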
class LogisticRegressionModel(torch.nn.Module):
    def __init__(self):
        super(LogisticRegressionModel, self).__init__()
        # three linear layers: 8 inputs -> 6 -> 4 -> 1 output
        self.lay1 = torch.nn.Linear(8, 6)
        self.lay2 = torch.nn.Linear(6, 4)
        self.lay3 = torch.nn.Linear(4, 1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        x = self.sigmoid(self.lay1(x))
        x = self.sigmoid(self.lay2(x))
        x = self.sigmoid(self.lay3(x))  # output in (0, 1): predicted probability
        return x
class DiabetesDataset(Dataset):
    def __init__(self, filepath):
        # the file is small, so load all of it into memory at once
        xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)
        self.len = xy.shape[0]                       # number of samples
        self.x_data = torch.from_numpy(xy[:, :-1])   # every column but the last: features
        self.y_data = torch.from_numpy(xy[:, [-1]])  # last column, kept 2-D: shape (N, 1)

    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        return self.len
dataset = DiabetesDataset("./datasets/diabetes.csv.gz")
train_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True)
model = LogisticRegressionModel()
criterion = torch.nn.BCELoss(reduction='mean')
optimizer = torch.optim.SGD(model.parameters(), lr=0.005)

epoch_list = []
loss_list = []
for epoch in range(100):
    epoch_loss = 0.0
    for i, data in enumerate(train_loader, 0):
        # 1. load a mini-batch
        inputs, labels = data
        # 2. forward pass
        y_pred = model(inputs)
        loss = criterion(y_pred, labels)
        epoch_loss += loss.item()
        optimizer.zero_grad()
        # 3. backward pass
        loss.backward()
        # 4. update the weights
        optimizer.step()
    # record one point per epoch: the mean loss over that epoch's mini-batches
    epoch_list.append(epoch)
    loss_list.append(epoch_loss / len(train_loader))

plt.plot(epoch_list, loss_list)
plt.xlabel("epoch")
plt.ylabel("loss")
plt.show()
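The line self.y_data = torch.from_numpy(xy[:, [-1]]) indexes the last column with a list rather than a plain -1. A small sketch on a hypothetical 3-row array shows why: the list index keeps a 2-D (N, 1) label, which matches the (N, 1) output of the model, while BCELoss expects the prediction and the target to have the same shape.

```python
import numpy as np
import torch

# hypothetical data: 3 rows; the first 3 columns are features, the last is the label
xy = np.arange(12, dtype=np.float32).reshape(3, 4)

x = torch.from_numpy(xy[:, :-1])       # features, shape (3, 3)
y_col = torch.from_numpy(xy[:, [-1]])  # list index keeps the column axis: (3, 1)
y_flat = torch.from_numpy(xy[:, -1])   # integer index drops it: (3,)

print(x.shape, y_col.shape, y_flat.shape)
# torch.Size([3, 3]) torch.Size([3, 1]) torch.Size([3])
```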
Loading the MNIST dataset
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms

# ToTensor() turns each PIL image into a float tensor of shape (1, 28, 28) with values in [0, 1]
train_dataset = datasets.MNIST(root='./datasets/mnist', train=True,
                               transform=transforms.ToTensor(),
                               download=True)
test_dataset = datasets.MNIST(root='./datasets/mnist', train=False,
                              transform=transforms.ToTensor(),
                              download=True)
# shuffle the training set every epoch; keep the test set in a fixed order
train_loader = DataLoader(dataset=train_dataset, batch_size=32,
                          shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=32,
                         shuffle=False)
for batch_idx, (inputs, target) in enumerate(test_loader):
    pass  # inputs: (batch, 1, 28, 28) images; target: (batch,) digit labels
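The loop body above is left empty. A minimal sketch of what could go inside it during training, under stated assumptions: to stay self-contained and avoid a download, it uses random tensors shaped like ToTensor() MNIST output instead of the real dataset, and the one-layer classifier, the learning rate, and the helper name train_one_epoch are all hypothetical choices, not the lecture's model. Because MNIST has 10 classes with integer labels, it uses CrossEntropyLoss on raw logits instead of BCELoss.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# hypothetical stand-in for MNIST: 64 random "images" shaped (1, 28, 28), labels 0-9
images = torch.rand(64, 1, 28, 28)
labels = torch.randint(0, 10, (64,))
fake_loader = DataLoader(TensorDataset(images, labels), batch_size=32, shuffle=True)

# assumed minimal classifier: flatten the 28x28 pixels, one linear layer to 10 logits
model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(28 * 28, 10))
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)


def train_one_epoch(loader):
    """Run one pass over the loader and return the loss of every batch."""
    losses = []
    for batch_idx, (inputs, target) in enumerate(loader):
        outputs = model(inputs)            # (batch, 10) logits
        loss = criterion(outputs, target)  # target: (batch,) class indices
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return losses


losses = train_one_epoch(fake_loader)
print(len(losses))  # 2 (64 samples / batch size 32)
```

Swapping fake_loader for train_loader should run the same loop over the real MNIST training set, one batch of 32 images at a time.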