Bilibili 刘二大人: Dataset and Data Loading, Lecture 8
2022-07-06 05:33:00 【宁然也】
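In the earlier lectures, y_pred = model(x_data) fed the entire dataset to the model in a single pass.
To train in mini-batches instead, a few terms are needed:
- Epoch: one forward and backward pass over all training samples.
- Batch size: the number of samples used in one forward and backward pass.
- Iteration: one forward and backward pass over a single batch; one epoch takes N / batch size iterations (rounded up) for N samples.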
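The three terms are linked by simple arithmetic: with N samples, one epoch takes ceil(N / batch size) iterations, and the final batch is smaller than the others unless it is dropped. A minimal sketch of that bookkeeping follows; the helper names are hypothetical, and the drop_last flag only mirrors the DataLoader parameter of the same name.

```python
import math


def iterations_per_epoch(num_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Number of mini-batches (iterations) needed to see every sample once."""
    if drop_last:
        # The incomplete final batch is discarded.
        return num_samples // batch_size
    # The final batch may be smaller than batch_size.
    return math.ceil(num_samples / batch_size)


def batch_sizes(num_samples: int, batch_size: int, drop_last: bool = False) -> list:
    """Sizes of the batches produced in one epoch, in order."""
    full, rest = divmod(num_samples, batch_size)
    sizes = [batch_size] * full
    if rest and not drop_last:
        sizes.append(rest)
    return sizes


# Hypothetical example: 10,000 samples with batch_size = 32
print(iterations_per_epoch(10_000, 32))                  # 313
print(batch_sizes(10_000, 32)[-1])                       # 16 samples in the last batch
print(iterations_per_epoch(10_000, 32, drop_last=True))  # 312
```

Over E epochs the model therefore takes E times that many optimizer steps, which is why a smaller batch size means more (but noisier) updates per epoch.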
import torch
from torch.utils.data import Dataset     # Dataset is an abstract class; subclass it
from torch.utils.data import DataLoader  # DataLoader loads data from a Dataset in mini-batches

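A subclass of Dataset implements two magic methods:
- __getitem__(self, index): returns the sample at the given index, so that dataset[index] works.
- __len__(self): returns the length of the dataset, so that len(dataset) works.
dataset = DiabetesDataset(filepath) then constructs a DiabetesDataset object (the full code below passes the path of the CSV file).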
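To see the two methods in isolation, here is a minimal Dataset subclass. It is a hypothetical ToyDataset that generates random tensors instead of reading the diabetes CSV, so it runs without any file; only the __getitem__ / __len__ structure corresponds to DiabetesDataset.

```python
import torch
from torch.utils.data import Dataset


class ToyDataset(Dataset):
    """Hypothetical in-memory dataset: label is 1 when the features sum to a positive number."""

    def __init__(self, num_samples: int = 10, num_features: int = 8):
        g = torch.Generator().manual_seed(0)  # fixed seed so the data is reproducible
        self.x_data = torch.randn(num_samples, num_features, generator=g)
        # keepdim=True gives labels of shape (N, 1), matching a model output of shape (N, 1).
        self.y_data = (self.x_data.sum(dim=1, keepdim=True) > 0).float()

    def __getitem__(self, index):
        # Return one (features, label) pair; DataLoader calls this with integer indices.
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        # Number of samples; DataLoader uses it to generate indices 0 .. len - 1.
        return self.x_data.shape[0]


dataset = ToyDataset()
x, y = dataset[3]
print(len(dataset), x.shape, y.shape)  # 10 torch.Size([8]) torch.Size([1])
```

Because all data is loaded in __init__, indexing is cheap; for datasets too large for memory, __init__ would typically store only file paths and __getitem__ would read one sample at a time.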
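The DataLoader is then initialized with its parameters:
train_loader = DataLoader(dataset=dataset,
                          batch_size=32,   # samples per mini-batch
                          shuffle=True,    # reshuffle the data at every epoch
                          num_workers=2)   # number of worker processes that load data in parallel
On Windows, num_workers > 0 requires the training code to sit under if __name__ == '__main__':, because the worker processes are started with spawn and re-import the script.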
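A sketch of what the DataLoader actually yields, assuming random stand-in data wrapped in torch.utils.data.TensorDataset (the data and the main function are hypothetical). The default collate function stacks the per-sample (features, label) tuples into one inputs tensor and one labels tensor per batch, and the __main__ guard is what keeps num_workers > 0 safe under the spawn start method.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def main(num_workers: int = 2):
    # Hypothetical data: 100 samples, 8 features, binary labels of shape (N, 1).
    g = torch.Generator().manual_seed(0)
    x = torch.randn(100, 8, generator=g)
    y = torch.randint(0, 2, (100, 1), generator=g).float()
    dataset = TensorDataset(x, y)  # built-in Dataset that indexes the tensors along dim 0

    train_loader = DataLoader(dataset=dataset,
                              batch_size=32,
                              shuffle=True,              # new random order every epoch
                              num_workers=num_workers)   # 0 means load in the main process

    shapes = []
    for inputs, labels in train_loader:
        # Samples are stacked along a new first (batch) dimension.
        shapes.append((tuple(inputs.shape), tuple(labels.shape)))
    return len(train_loader), shapes


# With spawn (the default on Windows and macOS), each worker re-imports this module;
# the guard stops the workers from re-running the loading/training code.
if __name__ == "__main__":
    num_batches, shapes = main()
    print(num_batches)  # 4
    print(shapes[-1])   # ((4, 8), (4, 1)): the last batch holds the 100 - 96 leftover samples
```

Shuffling changes which samples land in each batch but not the batch sizes, so the last batch always has 4 samples here.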
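Complete training code for the diabetes dataset:
import numpy as np
import torch
import matplotlib.pyplot as plt
# Dataset is an abstract class
from torch.utils.data import Dataset
# DataLoader wraps a Dataset and yields mini-batches
from torch.utils.data import DataLoader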
class LogisticRegressionModel(torch.nn.Module):
    def __init__(self):
        super(LogisticRegressionModel, self).__init__()
        # Input dimension 8, output dimension 6
        self.lay1 = torch.nn.Linear(8, 6)
        self.lay2 = torch.nn.Linear(6, 4)
        self.lay3 = torch.nn.Linear(4, 1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        x = self.sigmoid(self.lay1(x))
        x = self.sigmoid(self.lay2(x))
        x = self.sigmoid(self.lay3(x))
        return x
class DiabetesDataset(Dataset):
    def __init__(self, filepath):
        # np.loadtxt reads .gz files transparently; each row holds 8 features followed by 1 label
        xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)
        self.len = xy.shape[0]                       # number of samples
        self.x_data = torch.from_numpy(xy[:, :-1])   # features, shape (N, 8)
        self.y_data = torch.from_numpy(xy[:, [-1]])  # labels, shape (N, 1)

    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        return self.len
dataset = DiabetesDataset("./datasets/diabetes.csv.gz")
train_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True)
model = LogisticRegressionModel()
criterion = torch.nn.BCELoss(reduction='mean')
optimizer = torch.optim.SGD(model.parameters(), lr=0.005)

epoch_list = []
loss_list = []
for epoch in range(100):
    epoch_loss = 0.0
    for i, data in enumerate(train_loader, 0):
        # 1. Load a mini-batch
        inputs, labels = data
        # 2. Forward pass
        y_pred = model(inputs)
        loss = criterion(y_pred, labels)
        epoch_loss += loss.item()
        # 3. Backward pass
        optimizer.zero_grad()
        loss.backward()
        # 4. Update the parameters
        optimizer.step()
    # Record the mean batch loss once per epoch, so the plot has one point per epoch
    epoch_list.append(epoch)
    loss_list.append(epoch_loss / len(train_loader))

plt.plot(epoch_list, loss_list)
plt.xlabel("epoch")
plt.ylabel("loss")
plt.show()
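In DiabetesDataset the labels are sliced with xy[:, [-1]] rather than xy[:, -1]. The difference is the shape: a list index keeps the column axis, so the labels come out as (N, 1) and line up with the model output of shape (batch, 1) in BCELoss. The sketch below assumes a small random array as a stand-in for the CSV; the data and the constant prediction of 0.5 are hypothetical.

```python
import math

import numpy as np
import torch

# Hypothetical stand-in for the diabetes CSV: 5 rows of 8 feature columns plus 1 label column.
rng = np.random.default_rng(0)
features = rng.random((5, 8), dtype=np.float32)
labels = np.array([[0], [1], [1], [0], [1]], dtype=np.float32)
xy = np.hstack([features, labels])       # shape (5, 9)

x_data = torch.from_numpy(xy[:, :-1])    # every column but the last -> (5, 8)
y_flat = torch.from_numpy(xy[:, -1])     # integer index drops the axis -> (5,)
y_data = torch.from_numpy(xy[:, [-1]])   # list index keeps the axis   -> (5, 1)
print(x_data.shape, y_flat.shape, y_data.shape)
# torch.Size([5, 8]) torch.Size([5]) torch.Size([5, 1])

# Stand-in for the model output: shape (5, 1), every prediction equal to 0.5.
y_pred = torch.full((5, 1), 0.5)
criterion = torch.nn.BCELoss(reduction='mean')
loss = criterion(y_pred, y_data).item()
print(round(loss, 4), round(math.log(2), 4))  # 0.6931 0.6931, since -log(0.5) = ln 2

try:
    criterion(y_pred, y_flat)  # target (5,) vs input (5, 1)
except ValueError as err:
    # Recent PyTorch versions reject mismatched input/target sizes in BCELoss.
    print("shape mismatch:", err)
```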

Loading the MNIST dataset
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# ToTensor() converts each PIL image into a float tensor of shape (1, 28, 28) with values in [0, 1]
train_dataset = datasets.MNIST(root='./datasets/mnist', train=True,
                               transform=transforms.ToTensor(),
                               download=True)
test_dataset = datasets.MNIST(root='./datasets/mnist', train=False,
                              transform=transforms.ToTensor(),
                              download=True)
# Pass the dataset object (train_dataset), not the torchvision.datasets module
train_loader = DataLoader(dataset=train_dataset, batch_size=32,
                          shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=32,
                         shuffle=False)  # the test set does not need shuffling

for batch_idx, (inputs, target) in enumerate(test_loader):
    pass  # training / evaluation code goes here
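The MNIST loaders yield the same (inputs, target) pairs as the diabetes loader, only with image-shaped inputs: MNIST has 60,000 training and 10,000 test images, and after ToTensor() each batch of inputs has shape (batch, 1, 28, 28). The sketch below assumes random tensors in place of the real images (so nothing is downloaded) to show the batch shapes and how a batch can be flattened to 784 values per image for a fully connected layer.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for MNIST after transforms.ToTensor(): each image is a float
# tensor of shape (1, 28, 28) (channel, height, width) with values in [0, 1),
# and each target is an integer class label 0-9.
fake_images = torch.rand(64, 1, 28, 28)
fake_targets = torch.randint(0, 10, (64,))
loader = DataLoader(TensorDataset(fake_images, fake_targets), batch_size=32, shuffle=False)

shapes = []
for batch_idx, (inputs, target) in enumerate(loader):
    # Flatten each image to 28 * 28 = 784 values, keeping the batch dimension.
    flat = inputs.view(inputs.size(0), -1)
    shapes.append((tuple(inputs.shape), tuple(target.shape), tuple(flat.shape)))
    print(batch_idx, *shapes[-1])
# 0 (32, 1, 28, 28) (32,) (32, 784)
# 1 (32, 1, 28, 28) (32,) (32, 784)
```

With the real test_loader the same loop runs ceil(10000 / 32) = 313 times, and the last batch holds 16 images.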