Bilibili 刘二大人: Datasets and Data Loading (Lecture 8)
2022-07-06 05:33:00 【宁然也】
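So far, y_pred = model(x_data) has fed all of the data to the model in a single pass (full-batch training).
To train in mini-batches instead, a few terms are needed: an epoch is one complete forward and backward pass over all training samples; the batch size is the number of samples used in one forward and backward pass; and an iteration is one such pass over a single batch, so one epoch takes ceil(number of samples / batch size) iterations when the last, smaller batch is kept.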
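As a quick check of how these terms relate, the sketch below wraps 100 randomly generated samples (a hypothetical stand-in, not the diabetes data) in torch.utils.data.TensorDataset and counts the batches a DataLoader with batch_size=32 produces in one epoch. With the default drop_last=False the final, smaller batch is kept, so the count is the ceiling of 100 / 32.

```python
import math

import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical toy data: 100 samples with 8 features each (not the diabetes file)
n_samples, batch_size = 100, 32
x = torch.randn(n_samples, 8)
y = torch.randint(0, 2, (n_samples, 1)).float()
dataset = TensorDataset(x, y)

loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# One iteration = one mini-batch; one epoch = one pass over every sample
iterations_per_epoch = math.ceil(n_samples / batch_size)
batch_sizes = [xb.shape[0] for xb, yb in loader]

print(iterations_per_epoch)  # 4
print(len(loader))           # 4
print(batch_sizes)           # [32, 32, 32, 4]
```

Shuffling changes which samples land in each batch, but not how many batches there are or their sizes.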
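import torch
from torch.utils.data import Dataset     # Dataset is an abstract class; subclass it
from torch.utils.data import DataLoader  # DataLoader draws mini-batches from a Dataset

class DiabetesDataset(Dataset):
    def __init__(self, filepath):
        ...  # read the data from filepath
    def __getitem__(self, index):
        ...  # return the sample at position index, so that dataset[index] works
    def __len__(self):
        ...  # return the length of the dataset, so that len(dataset) works

dataset = DiabetesDataset("./datasets/diabetes.csv.gz")  # construct a DiabetesDataset object
train_loader = DataLoader(dataset=dataset,
                          batch_size=32,
                          shuffle=True,
                          num_workers=2)  # initialize the loader: batch size, shuffling, worker processes
# With num_workers > 0 on Windows, put the training loop under if __name__ == '__main__':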
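To see what the two required methods do, here is a minimal sketch with a hypothetical SquaresDataset (sample i is the pair (i, i squared)); it is not part of the lecture. DataLoader uses len(dataset) to know which indices exist, calls __getitem__ once per index, and its default collate function stacks the returned tensors into batch tensors. num_workers is left at 0 so the sketch runs in a single process.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: sample i is (i, i**2) as 1-element float tensors."""

    def __init__(self, n):
        self.x = torch.arange(n, dtype=torch.float32).unsqueeze(1)  # shape (n, 1)
        self.y = self.x ** 2                                         # shape (n, 1)

    def __getitem__(self, index):
        # DataLoader calls this once for every index its sampler produces
        return self.x[index], self.y[index]

    def __len__(self):
        # The sampler uses this to know the valid indices 0 .. n-1
        return self.x.shape[0]


dataset = SquaresDataset(10)
print(len(dataset))  # 10
print(dataset[3])    # (tensor([3.]), tensor([9.]))

loader = DataLoader(dataset, batch_size=4, shuffle=False, num_workers=0)
# The default collate function stacks 4 samples of shape (1,) into a (4, 1) batch
shapes = [tuple(xb.shape) for xb, yb in loader]
print(shapes)        # [(4, 1), (4, 1), (2, 1)]
```

Because the whole tensor is already in memory, __getitem__ is just an index lookup; for data too large to hold in memory, the same two methods could instead read one sample from disk per call.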
import numpy as np
import torch
import matplotlib.pyplot as plt
# Dataset is an abstract class; DiabetesDataset below subclasses it
from torch.utils.data import Dataset
# DataLoader wraps a Dataset and yields shuffled mini-batches
from torch.utils.data import DataLoader
class LogisticRegressionModel(torch.nn.Module):
    def __init__(self):
        super(LogisticRegressionModel, self).__init__()
        # input dimension 8, output dimension 6
        self.lay1 = torch.nn.Linear(8, 6)
        self.lay2 = torch.nn.Linear(6, 4)
        self.lay3 = torch.nn.Linear(4, 1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        x = self.sigmoid(self.lay1(x))
        x = self.sigmoid(self.lay2(x))
        x = self.sigmoid(self.lay3(x))
        return x
class DiabetesDataset(Dataset):
    def __init__(self, filepath):
        # load the whole file into memory (np.loadtxt decompresses .gz files itself)
        xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)
        self.len = xy.shape[0]
        self.x_data = torch.from_numpy(xy[:, :-1])   # features: every column except the last
        self.y_data = torch.from_numpy(xy[:, [-1]])  # labels: last column, kept 2-D as (N, 1)

    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        return self.len
dataset = DiabetesDataset("./datasets/diabetes.csv.gz")
train_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True)
model = LogisticRegressionModel()
criterion = torch.nn.BCELoss(reduction='mean')
optimizer = torch.optim.SGD(model.parameters(), lr=0.005)
epoch_list = []
loss_list = []
for epoch in range(100):
    epoch_loss = 0.0
    for i, data in enumerate(train_loader, 0):
        # 1. load a mini-batch
        inputs, label = data
        # 2. forward pass
        y_pred = model(inputs)
        loss = criterion(y_pred, label)
        epoch_loss += loss.item()
        optimizer.zero_grad()
        # 3. backward pass
        loss.backward()
        # 4. update the parameters
        optimizer.step()
    # record the mean batch loss once per epoch, so each epoch is a single point on the plot
    epoch_list.append(epoch)
    loss_list.append(epoch_loss / len(train_loader))
plt.plot(epoch_list, loss_list)
plt.xlabel("epoch")
plt.ylabel("loss")
plt.show()
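One detail in DiabetesDataset is worth isolating: xy[:, [-1]] (indexing with a list) keeps the label column two-dimensional, while xy[:, -1] (indexing with an integer) drops it to one dimension. The sketch below uses a hypothetical three-row CSV held in memory, not the real diabetes.csv.gz, to show the resulting shapes. The model's last layer, Linear(4, 1), outputs shape (N, 1), and BCELoss expects the target to have the same shape as its input, so the (N, 1) labels are the ones that match.

```python
import io

import numpy as np
import torch

# Hypothetical 3-row CSV: 2 feature columns + 1 label column (not the diabetes file)
csv = io.StringIO("0.1,0.2,1\n0.3,0.4,0\n0.5,0.6,1\n")
xy = np.loadtxt(csv, delimiter=',', dtype=np.float32)

x_data = torch.from_numpy(xy[:, :-1])   # every column except the last
y_vec = torch.from_numpy(xy[:, -1])     # integer index -> 1-D labels, shape (3,)
y_col = torch.from_numpy(xy[:, [-1]])   # list index -> 2-D labels, shape (3, 1)

print(x_data.shape)  # torch.Size([3, 2])
print(y_vec.shape)   # torch.Size([3])
print(y_col.shape)   # torch.Size([3, 1])

# Stand-in prediction of shape (3, 1), as a final Linear(4, 1) + Sigmoid would give;
# every value is 0.5, so the mean BCE loss is -log(0.5), about 0.6931
y_pred = torch.sigmoid(torch.zeros(3, 1))
loss = torch.nn.BCELoss()(y_pred, y_col)
print(loss.item())
```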
Loading the MNIST dataset
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# ToTensor() turns each 28x28 image into a float tensor of shape (1, 28, 28) with values in [0, 1]
train_dataset = datasets.MNIST(root='./datasets/mnist', train=True,
                               transform=transforms.ToTensor(),
                               download=True)
test_dataset = datasets.MNIST(root='./datasets/mnist', train=False,
                              transform=transforms.ToTensor(),
                              download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=32,
                          shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=32,
                         shuffle=False)  # no need to shuffle the test set
for batch_idx, (inputs, target) in enumerate(test_loader):
    pass  # inputs: (32, 1, 28, 28) images, target: (32,) labels; the last batch may be smaller
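The loop above is only a placeholder (pass). As a hedged sketch of what typically goes inside a test loop, the code below substitutes random tensors shaped like ToTensor() output for MNIST, so it runs without downloading anything, and uses a hypothetical, untrained single Linear layer as the model. It flattens each (1, 28, 28) image to 784 values and counts correct predictions under torch.no_grad(). The accuracy it computes is meaningless; only the loop structure matters here.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in for MNIST (assumption): random "images" shaped like ToTensor() output,
# i.e. float32 in [0, 1] with shape (1, 28, 28), plus integer labels 0..9
images = torch.rand(100, 1, 28, 28)
labels = torch.randint(0, 10, (100,))
test_loader = DataLoader(TensorDataset(images, labels), batch_size=32, shuffle=False)

model = torch.nn.Linear(28 * 28, 10)  # hypothetical untrained classifier

correct = total = 0
with torch.no_grad():  # no gradients are needed when only evaluating
    for batch_idx, (inputs, target) in enumerate(test_loader):
        # inputs: (B, 1, 28, 28) -> flatten to (B, 784) for the Linear layer
        logits = model(inputs.view(inputs.size(0), -1))
        predicted = logits.argmax(dim=1)             # class with the highest score
        correct += (predicted == target).sum().item()
        total += target.size(0)

accuracy = correct / total
print(total, len(test_loader))  # 100 4
```

Summing target.size(0) rather than multiplying the batch count by 32 keeps the total correct when the last batch is smaller.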