Bilibili, 刘二大人: Dataset and DataLoader (Lecture 8)
2022-07-06 05:33:00 【宁然也】
y_pred = model(x_data) uses the whole dataset in a single forward pass.
To train in mini-batches instead, a few concepts are needed first:
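import torch
from torch.utils.data import Dataset     # Dataset is an abstract class; it must be subclassed
from torch.utils.data import DataLoader  # DataLoader loads the data in batches

# Skeleton only: the method bodies are filled in by the complete listing further down
class DiabetesDataset(Dataset):
    def __getitem__(self, index):  # supports dataset[index]
        ...

    def __len__(self):  # returns the length of the dataset
        ...

dataset = DiabetesDataset()  # construct a DiabetesDataset object
train_loader = DataLoader(dataset=dataset,
                          batch_size=32,
                          shuffle=True,
                          num_workers=2)  # initialization parameters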
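As a rough illustration of the concepts involved (assuming the usual three: epoch, batch size, and iteration, which these notes do not spell out), the sketch below wraps synthetic data in a Dataset subclass and counts the iterations in one epoch. ToyDataset and its random data are hypothetical stand-ins, not part of the lecture code.

```python
import math

import torch
from torch.utils.data import Dataset, DataLoader


class ToyDataset(Dataset):
    """Hypothetical in-memory dataset: 100 samples, 8 features, 1 binary label."""

    def __init__(self, n_samples=100):
        generator = torch.Generator().manual_seed(0)
        self.x_data = torch.rand(n_samples, 8, generator=generator)
        self.y_data = torch.randint(0, 2, (n_samples, 1), generator=generator).float()

    def __getitem__(self, index):
        # Enables dataset[index]
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        # Enables len(dataset)
        return self.x_data.shape[0]


dataset = ToyDataset()
batch_size = 32
# num_workers stays at its default (0) so the sketch runs without a __main__ guard
loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

# One epoch = one full pass over the dataset.
# One iteration = one mini-batch; the last batch may be smaller than batch_size.
iterations_per_epoch = math.ceil(len(dataset) / batch_size)
batch_sizes = [inputs.shape[0] for inputs, labels in loader]

print(iterations_per_epoch, len(loader), batch_sizes)  # 4 4 [32, 32, 32, 4]
```

With 100 samples and a batch size of 32, one epoch takes 4 iterations, and the final batch holds the remaining 4 samples because drop_last defaults to False.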
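import numpy as np
import torch
import matplotlib.pyplot as plt
# Dataset is an abstract class; subclass it to define a custom dataset
from torch.utils.data import Dataset
# DataLoader is a concrete class that batches and shuffles samples from a Dataset
from torch.utils.data import DataLoader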
class LogisticRegressionModel(torch.nn.Module):
    def __init__(self):
        super(LogisticRegressionModel, self).__init__()
        # Input dimension 8, output dimension 6
        self.lay1 = torch.nn.Linear(8, 6)
        self.lay2 = torch.nn.Linear(6, 4)
        self.lay3 = torch.nn.Linear(4, 1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        x = self.sigmoid(self.lay1(x))
        x = self.sigmoid(self.lay2(x))
        x = self.sigmoid(self.lay3(x))
        return x
class DiabetesDataset(Dataset):
    def __init__(self, filepath):
        # Load the whole CSV into memory; the last column is the label
        xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)
        self.len = xy.shape[0]
        self.x_data = torch.from_numpy(xy[:, :-1])
        # Indexing with [-1] keeps the label as a column of shape (N, 1)
        self.y_data = torch.from_numpy(xy[:, [-1]])

    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        return self.len
dataset = DiabetesDataset("./datasets/diabetes.csv.gz")
train_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True)
model = LogisticRegressionModel()
criterion = torch.nn.BCELoss(reduction='mean')
optimizer = torch.optim.SGD(model.parameters(), lr=0.005)
epoch_list = []
loss_list = []
for epoch in range(100):
    for i, data in enumerate(train_loader, 0):
        # 1. Prepare the mini-batch
        inputs, label = data
        # 2. Forward pass
        y_pred = model(inputs)
        loss = criterion(y_pred, label)
        # Record one point per mini-batch
        epoch_list.append(epoch)
        loss_list.append(loss.item())
        optimizer.zero_grad()
        # 3. Backward pass
        loss.backward()
        # 4. Update the parameters
        optimizer.step()
plt.plot(epoch_list, loss_list)
plt.xlabel("epoch")
plt.ylabel("loss")
plt.show()
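The training loop records one loss value per mini-batch, so the plot shows several points for each epoch. If a single point per epoch is preferred, one option is to average the batch losses, as in the sketch below. It is a minimal example under stated assumptions: the random tensors, the one-layer model, and the 5 epochs are hypothetical stand-ins so that the snippet runs without diabetes.csv.gz.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical stand-in for the diabetes data: 128 samples, 8 features, binary label
x = torch.rand(128, 8)
y = (x.sum(dim=1, keepdim=True) > 4).float()
loader = DataLoader(TensorDataset(x, y), batch_size=32, shuffle=True)

model = torch.nn.Sequential(torch.nn.Linear(8, 1), torch.nn.Sigmoid())
criterion = torch.nn.BCELoss(reduction='mean')
optimizer = torch.optim.SGD(model.parameters(), lr=0.005)

epoch_losses = []
for epoch in range(5):
    running_loss = 0.0
    for inputs, labels in loader:
        loss = criterion(model(inputs), labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Weight each batch mean by its size so a smaller last batch counts correctly
        running_loss += loss.item() * inputs.shape[0]
    # Mean loss over all samples seen in this epoch
    epoch_losses.append(running_loss / len(loader.dataset))

print(epoch_losses)
```

Plotting epoch_losses against range(len(epoch_losses)) then gives one point per epoch.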
Loading the MNIST dataset
import torch
from torch.utils.data import DataLoader,Dataset
from torchvision import datasets,transforms
train_dataset = datasets.MNIST(root='./datasets/mnist', train=True,
                               transform=transforms.ToTensor(),
                               download=True)
test_dataset = datasets.MNIST(root='./datasets/mnist', train=False,
                              transform=transforms.ToTensor(),
                              download=True)
# Pass the dataset object (train_dataset), not the torchvision datasets module
train_loader = DataLoader(dataset=train_dataset, batch_size=32,
                          shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=32,
                         shuffle=False)
for batch_idx, (inputs, target) in enumerate(test_loader):
    pass
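To see what each iteration of such a loop yields, the sketch below inspects the shapes of one batch. To avoid downloading anything, it uses random tensors with MNIST-like shapes (28x28 grayscale images, which transforms.ToTensor() would turn into 1x28x28 float tensors in [0, 1]). The names fake_test_loader, images, and targets are hypothetical and only illustrate the shapes.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in with MNIST-like shapes: 1x28x28 images in [0, 1], labels 0-9
images = torch.rand(100, 1, 28, 28)
targets = torch.randint(0, 10, (100,))
fake_test_loader = DataLoader(TensorDataset(images, targets),
                              batch_size=32, shuffle=False)

for batch_idx, (inputs, target) in enumerate(fake_test_loader):
    if batch_idx == 0:
        # The loader stacks 32 samples along a new leading batch dimension
        first_shapes = (tuple(inputs.shape), tuple(target.shape))

print(first_shapes)  # ((32, 1, 28, 28), (32,))
```

Each batch is therefore a pair of tensors: the images with a leading batch dimension, and one integer label per image.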