Bilibili 刘二大人 - Datasets and Data Loading, Lecture 8
2022-07-06 05:33:00 【宁然也】
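y_pred = model(x_data) uses all of the data at once: every forward pass processes the entire dataset.
To process the data in mini-batches instead, a few concepts are needed first.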
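A minimal sketch of how these quantities relate, assuming the usual definitions: one epoch is a full pass over all training samples, the batch size is the number of samples used in one forward/backward pass, and the number of iterations is how many batches one epoch takes. The function name iterations_per_epoch and the sample counts are hypothetical; the rounding mirrors what len(DataLoader) reports for a map-style dataset (the final, smaller batch is kept unless drop_last=True).

```python
import math


def iterations_per_epoch(num_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Number of mini-batches (iterations) needed to cover one epoch."""
    if drop_last:
        # the incomplete final batch is discarded
        return num_samples // batch_size
    # the final, possibly smaller, batch is kept
    return math.ceil(num_samples / batch_size)


# hypothetical example: 10,000 samples with batch size 1,000 -> 10 iterations per epoch
print(iterations_per_epoch(10000, 1000))               # 10
print(iterations_per_epoch(759, 32))                   # 24: 23 full batches plus one batch of 23
print(iterations_per_epoch(759, 32, drop_last=True))   # 23
```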
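```python
import torch
from torch.utils.data import Dataset     # Dataset is an abstract class: subclass it
from torch.utils.data import DataLoader  # DataLoader loads the data in mini-batches

class DiabetesDataset(Dataset):
    def __init__(self):
        pass  # read or index the data here

    def __getitem__(self, index):
        pass  # return the sample at position `index`, e.g. dataset[index]

    def __len__(self):
        pass  # return the length of the dataset

dataset = DiabetesDataset()  # construct a DiabetesDataset object
# initialize the loader; num_workers=2 loads batches in 2 worker processes
# (on Windows, the code that iterates the loader must sit under if __name__ == '__main__':)
train_loader = DataLoader(dataset=dataset, batch_size=32,
                          shuffle=True, num_workers=2)
```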
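To see what the DataLoader actually hands to the training loop, here is a small self-contained sketch. ToyDataset and its data are hypothetical stand-ins for DiabetesDataset (10 samples, 8 features and a (1,)-shaped label each), and num_workers is left at its default of 0 so it runs anywhere. The default collate function stacks the per-sample tensors returned by __getitem__ along a new first dimension, and shuffle=True only reorders the samples: each one still appears exactly once per epoch.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class ToyDataset(Dataset):
    """Hypothetical in-memory dataset: n samples with 8 features and a 0/1 label."""

    def __init__(self, n=10):
        # row i holds the values 8*i .. 8*i+7, so x[i, 0] / 8 recovers the sample index
        self.x = torch.arange(n * 8, dtype=torch.float32).reshape(n, 8)
        self.y = (torch.arange(n) % 2).float().unsqueeze(1)  # shape (n, 1)

    def __getitem__(self, index):
        return self.x[index], self.y[index]

    def __len__(self):
        return len(self.x)


# without shuffling: batches of 4, 4 and the remaining 2 samples
loader = DataLoader(ToyDataset(), batch_size=4, shuffle=False)
batch_shapes = [(tuple(x.shape), tuple(y.shape)) for x, y in loader]
print(batch_shapes)  # [((4, 8), (4, 1)), ((4, 8), (4, 1)), ((2, 8), (2, 1))]

# with shuffling: the order changes, but every sample is still visited once per epoch
shuffled = DataLoader(ToyDataset(), batch_size=4, shuffle=True)
seen = torch.cat([x[:, 0] for x, _ in shuffled])
seen_ids = sorted((seen / 8).int().tolist())
print(seen_ids)  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
```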
```python
import numpy as np
import torch
import matplotlib.pyplot as plt
# Dataset is an abstract class: subclass it and implement __getitem__ and __len__
from torch.utils.data import Dataset
# DataLoader wraps a Dataset and yields (optionally shuffled) mini-batches
from torch.utils.data import DataLoader


class LogisticRegressionModel(torch.nn.Module):
    def __init__(self):
        super(LogisticRegressionModel, self).__init__()
        # input dimension 8, output dimension 6
        self.lay1 = torch.nn.Linear(8, 6)
        self.lay2 = torch.nn.Linear(6, 4)
        self.lay3 = torch.nn.Linear(4, 1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        x = self.sigmoid(self.lay1(x))
        x = self.sigmoid(self.lay2(x))
        x = self.sigmoid(self.lay3(x))
        return x


class DiabetesDataset(Dataset):
    def __init__(self, filepath):
        # load the whole CSV into memory (np.loadtxt also reads .gz files)
        xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)
        self.len = xy.shape[0]
        self.x_data = torch.from_numpy(xy[:, :-1])   # features: every column except the last
        self.y_data = torch.from_numpy(xy[:, [-1]])  # label: last column, kept as shape (N, 1)

    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        return self.len


dataset = DiabetesDataset("./datasets/diabetes.csv.gz")
train_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True)

model = LogisticRegressionModel()
criterion = torch.nn.BCELoss(reduction='mean')
optimizer = torch.optim.SGD(model.parameters(), lr=0.005)

epoch_list = []
loss_list = []
for epoch in range(100):
    epoch_loss = 0.0
    for i, data in enumerate(train_loader, 0):
        # 1. load a mini-batch
        inputs, label = data
        # 2. forward
        y_pred = model(inputs)
        loss = criterion(y_pred, label)
        epoch_loss += loss.item()
        optimizer.zero_grad()
        # 3. backward
        loss.backward()
        # 4. update
        optimizer.step()
    # record one point per epoch (the mean mini-batch loss) so the x-axis really is the epoch
    epoch_list.append(epoch)
    loss_list.append(epoch_loss / len(train_loader))

plt.plot(epoch_list, loss_list)
plt.xlabel("epoch")
plt.ylabel("loss")
plt.show()
```
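Why the label is sliced with xy[:, [-1]] rather than xy[:, -1]: a list index keeps the column axis, so each label batch has shape (batch, 1), the same shape as the model output, and BCELoss expects the target to have the same shape as its input. A small NumPy-only sketch on a made-up 3x4 array:

```python
import numpy as np

# made-up data: 3 rows, 3 feature columns + 1 label column
xy = np.arange(12, dtype=np.float32).reshape(3, 4)

features = xy[:, :-1]    # all columns except the last -> shape (3, 3)
labels_2d = xy[:, [-1]]  # list index keeps the column axis -> shape (3, 1)
labels_1d = xy[:, -1]    # scalar index drops it -> shape (3,)

print(features.shape, labels_2d.shape, labels_1d.shape)  # (3, 3) (3, 1) (3,)
```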
Loading the MNIST dataset
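```python
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms

# ToTensor turns each 28x28 grayscale image into a float tensor of shape (1, 28, 28) in [0, 1]
train_dataset = datasets.MNIST(root='./datasets/mnist', train=True,
                               transform=transforms.ToTensor(),
                               download=True)
test_dataset = datasets.MNIST(root='./datasets/mnist', train=False,
                              transform=transforms.ToTensor(),
                              download=True)

# pass the dataset instance (train_dataset), not the torchvision `datasets` module
train_loader = DataLoader(dataset=train_dataset, batch_size=32,
                          shuffle=True)
# the test set does not need shuffling
test_loader = DataLoader(dataset=test_dataset, batch_size=32,
                         shuffle=False)

for batch_idx, (inputs, target) in enumerate(test_loader):
    pass  # per-batch evaluation code goes here
```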
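Downloading MNIST is not always possible (for example offline), so the sketch below uses a TensorDataset of random tensors shaped like MNIST samples as a hypothetical stand-in: after transforms.ToTensor() each MNIST image is a float tensor of shape (1, 28, 28) and each target is an integer class label. It shows the shapes that inputs and target take inside the loop above; fake_images, fake_labels and fake_loader are illustrative names only.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# hypothetical stand-in for MNIST: 100 random 1x28x28 "images" with labels 0-9
fake_images = torch.rand(100, 1, 28, 28)
fake_labels = torch.randint(0, 10, (100,))
fake_loader = DataLoader(TensorDataset(fake_images, fake_labels),
                         batch_size=32, shuffle=False)

shapes = []
for batch_idx, (inputs, target) in enumerate(fake_loader):
    # inputs: (batch, channels, height, width); target: (batch,)
    shapes.append((tuple(inputs.shape), tuple(target.shape)))

print(len(fake_loader))  # 4 batches: 32 + 32 + 32 + 4
print(shapes[0])         # ((32, 1, 28, 28), (32,))
print(shapes[-1])        # ((4, 1, 28, 28), (4,))
```

With the real datasets the same arithmetic gives len(train_loader) == 1875 (60,000 training images / 32) and len(test_loader) == 313 (10,000 test images, with a last batch of 16).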