当前位置:网站首页>PyTorch exercise project
PyTorch exercise project
2022-07-03 06:23:00 【fksfdh】
config.py
class Hyperparameter:
    """Static hyper-parameters shared by the dataset, model and trainer modules.

    Every value is a class attribute; the module exposes one shared instance,
    ``HP``, which the other modules import.
    """

    # --- data ---------------------------------------------------------------
    device: str = 'cpu'                      # device tensors are moved to
    trainset_path: str = './data/train.txt'  # comma-separated training rows
    testset_path: str = './data/test.txt'    # comma-separated test rows
    seed: int = 1234                         # seed for all RNGs in the trainer

    # --- model --------------------------------------------------------------
    in_features: int = 4                     # number of leading feature columns
    out_dim: int = 2                         # width of the final (logit) layer
    # Widths of consecutive Linear layers: input -> hidden ... -> output.
    layer_list: list = [in_features, 64, 128, 64, out_dim]

    # --- training -----------------------------------------------------------
    init_lr: float = 1e-3                    # Adam learning rate
    batch_size: int = 64
    epochs: int = 100
    verbose_step: int = 10                   # evaluate every N optimizer steps
    save_step: int = 500                     # checkpoint every N optimizer steps


HP = Hyperparameter()
dataset_banknote.py
from torch.utils.data import Dataset,DataLoader
import numpy as np
import torch
from config import HP
class BanknoteDataset(Dataset):
    """Samples loaded from a comma-separated numeric text file.

    Each row holds ``HP.in_features`` feature columns followed by the label
    column. The whole file is read into memory once, in ``__init__``.
    """

    def __init__(self, dataset_path):
        # ndmin=2 keeps a single-row file as a (1, n_columns) matrix. Without
        # it np.loadtxt returns a 1-D array, and __len__ would report the
        # number of columns instead of the number of samples.
        self.dataset = np.loadtxt(dataset_path, delimiter=',', ndmin=2)

    def __getitem__(self, idx):
        """Return ``(features, label)`` for row ``idx``.

        ``features`` is a float32 tensor of the first ``HP.in_features``
        columns. ``label`` is the remaining column squeezed to a scalar int64
        tensor. Both are moved to ``HP.device``.
        """
        item = self.dataset[idx]
        x, y = item[:HP.in_features], item[HP.in_features:]
        features = torch.as_tensor(x, dtype=torch.float32).to(HP.device)
        label = torch.as_tensor(y).squeeze().long().to(HP.device)
        return features, label

    def __len__(self):
        """Number of rows (samples) read from the file."""
        return self.dataset.shape[0]
model.py
from torch import nn
import torch.nn.functional as F
from config import HP
import torch
class BanknoteClassificationModel(nn.Module):
    """Fully connected classifier built from a stack of ``nn.Linear`` layers.

    Layer ``i`` maps ``layer_list[i]`` features to ``layer_list[i + 1]``.
    ReLU is applied between layers but NOT after the last one, so ``forward``
    returns raw logits. ``nn.CrossEntropyLoss`` (used by the trainer) expects
    exactly that.

    Args:
        layer_list: sequence of layer widths, input width first and output
            width last. Defaults to ``HP.layer_list``.
    """

    def __init__(self, layer_list=None):
        super().__init__()
        if layer_list is None:
            layer_list = HP.layer_list
        # The attribute name stays ``linear_layer`` so previously saved
        # state_dict keys (``linear_layer.<i>.weight`` / ``.bias``) still load.
        self.linear_layer = nn.ModuleList([
            nn.Linear(in_features=in_dim, out_features=out_dim)
            for in_dim, out_dim in zip(layer_list[:-1], layer_list[1:])
        ])

    def forward(self, input_x):
        """Map an input whose last dim is ``layer_list[0]`` to logits of width ``layer_list[-1]``."""
        last_idx = len(self.linear_layer) - 1
        for idx, layer in enumerate(self.linear_layer):
            input_x = layer(input_x)
            # Skip the activation on the output layer. Clamping logits to
            # >= 0 would stop the loss from pushing wrong-class scores below
            # zero, and would zero the gradient of any negative logit.
            if idx < last_idx:
                input_x = F.relu(input_x)
        return input_x
if __name__ == '__main__':
    # Smoke test: push a random batch of 64 rows with 4 features through the
    # default network and print the output shape.
    dummy_batch = torch.randn((64, 4))
    net = BanknoteClassificationModel()
    print(net(dummy_batch).size())
trainer.py
import random
import os
from argparse import ArgumentParser
import numpy as np
import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from config import HP
from model import BanknoteClassificationModel
from dataset_banknote import BanknoteDataset
# Seed torch (CPU, then current CUDA device), Python's `random` and NumPy with
# HP.seed, in that order. torch.cuda.manual_seed is silently ignored when CUDA
# is unavailable.
for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed, random.seed, np.random.seed):
    _seed_fn(HP.seed)
del _seed_fn
def evaluate(model_, test_loader, crit):
    """Return the mean loss of ``model_`` over every sample in ``test_loader``.

    Args:
        model_: the network to evaluate.
        test_loader: iterable of ``(x, y)`` batches.
        crit: loss called as ``crit(pred, y)``. It is assumed to use mean
            reduction (the default for ``nn.CrossEntropyLoss``), so each
            batch loss is re-weighted by its batch size. A short final batch
            therefore does not skew the average.

    Returns:
        The sample-weighted mean loss as a Python float, or ``nan`` if the
        loader yields no samples (previously a ZeroDivisionError).

    Runs in eval mode with gradients disabled. The model's train/eval mode on
    entry is restored afterwards, even if an exception is raised.
    """
    was_training = model_.training
    model_.eval()
    total_loss, n_samples = 0.0, 0
    try:
        with torch.no_grad():
            for x, y in test_loader:
                batch_size = y.size(0)
                total_loss += crit(model_(x), y).item() * batch_size
                n_samples += batch_size
    finally:
        model_.train(was_training)
    return total_loss / n_samples if n_samples else float('nan')
def save_checkpoint(model_, epoch, optm, checkpoint_path):
    """Write model/optimizer state and the epoch to ``checkpoint_path``.

    The saved dict has the keys ``'model_state_dict'``,
    ``'optimizer_state_dict'`` and ``'epoch'``; ``train()`` reads these same
    keys when resuming. Missing parent directories are created first, because
    ``torch.save`` fails when the target directory does not exist.
    """
    parent_dir = os.path.dirname(checkpoint_path)
    if parent_dir:  # a bare filename has no directory to create
        os.makedirs(parent_dir, exist_ok=True)
    save_dict = {
        'model_state_dict': model_.state_dict(),
        'optimizer_state_dict': optm.state_dict(),
        'epoch': epoch,
    }
    torch.save(save_dict, checkpoint_path)
def train():
    """Train BanknoteClassificationModel, optionally resuming from a checkpoint.

    Command line:
        --c PATH   resume from a checkpoint written by ``save_checkpoint``.
                   Omit it to train from scratch.

    Every ``HP.verbose_step`` optimizer steps the model is evaluated on the
    test set and the train/test losses are printed. Every ``HP.save_step``
    steps a checkpoint is written to ``model_save/model_<epoch>_<step>.pth``.
    """
    parser = ArgumentParser(description="Model Training")
    parser.add_argument(
        '--c',
        default=None,
        type=str,
        help='train from scratch or resume training'
    )
    args = parser.parse_args()

    # data
    train_set = BanknoteDataset(HP.trainset_path)
    train_loader = DataLoader(train_set, batch_size=HP.batch_size, shuffle=True, drop_last=True)
    # Evaluate on every test sample in a fixed order. drop_last=True discarded
    # the tail, and yielded zero batches when the test set was smaller than
    # one batch.
    test_set = BanknoteDataset(HP.testset_path)
    test_loader = DataLoader(test_set, batch_size=HP.batch_size, shuffle=False, drop_last=False)

    # The model must live on the same device the dataset moves its tensors to.
    model = BanknoteClassificationModel().to(HP.device)
    criterion = nn.CrossEntropyLoss()
    optm = optim.Adam(model.parameters(), lr=HP.init_lr)

    start_epoch, step = 0, 0
    if args.c:
        # map_location lets a checkpoint written on another device load here.
        checkpoint = torch.load(args.c, map_location=HP.device)
        model.load_state_dict(checkpoint["model_state_dict"])
        optm.load_state_dict(checkpoint["optimizer_state_dict"])
        start_epoch = checkpoint["epoch"]
        print("Resume from %s" % args.c)
    else:
        print("Training from scratch !")

    # torch.save does not create missing directories.
    os.makedirs("model_save", exist_ok=True)
    model.train()
    for epoch in range(start_epoch, HP.epochs):
        # len(DataLoader) is already the number of batches (steps) per epoch.
        print("Start_Epoch:%d,Steps:%d" % (epoch, len(train_loader)))
        for x, y in train_loader:
            optm.zero_grad()
            pred = model(x)
            loss = criterion(pred, y)
            loss.backward()
            optm.step()

            if not step % HP.verbose_step:
                eval_loss = evaluate(model, test_loader, criterion)
                # float() accepts both a 0-d tensor and a Python float.
                print("Epoch:[%d/%d],step:%d,train_loss:%.5f,test_loss:%.5f"
                      % (epoch, HP.epochs, step, loss.item(), float(eval_loss)))
            if not step % HP.save_step:
                model_path = "model_%d_%d.pth" % (epoch, step)
                save_checkpoint(model, epoch, optm, os.path.join("model_save", model_path))
            step += 1
if __name__ == "__main__":
    # Script entry point: parse CLI flags and run the training loop.
    train()
边栏推荐
- Mysql database
- Merge and migrate data from small data volume, sub database and sub table Mysql to tidb
- Chapter 8. MapReduce production experience
- GPS坐标转百度地图坐标的方法
- Method of converting GPS coordinates to Baidu map coordinates
- .NET program configuration file operations (INI, CFG, config)
- Leetcode solution - 02 Add Two Numbers
- 认识弹性盒子flex
- Request weather interface format, automation
- [set theory] equivalence relation (concept of equivalence relation | examples of equivalence relation | equivalence relation and closure)
猜你喜欢

“我为开源打榜狂”第一周榜单公布,160位开发者上榜

.NET program configuration file operations (INI, CFG, config)

Kubernetes notes (III) controller

Reinstalling the system hangs on "setup is applying system settings"

【5G NR】UE注册流程

Simple lottery solution for small uploaders on Bilibili (Station B)

ThreadLocal的简单理解

輕松上手Fluentd,結合 Rainbond 插件市場,日志收集更快捷

使用conda创建自己的深度学习环境

Tabbar settings
随机推荐
远端rostopic的本地rviz调用及显示
认识弹性盒子flex
SQL实现将多行记录合并成一行
Simple lottery solution for small uploaders on Bilibili (Station B)
Zhiniu stock project -- 04
【系统设计】邻近服务
Cesium entity(entities) 实体删除方法
方差迭代公式推导
Kubernetes notes (VII): Kubernetes scheduling
【C#/VB.NET】 将PDF转为SVG/Image, SVG/Image转PDF
2022 CISP-PTE(三)命令执行
Migrate data from Mysql to tidb from a small amount of data
How to scan when Canon c3120l is a network shared printer
[system design] proximity service
YOLOV3学习笔记
ThreadLocal的简单理解
Kubernetes notes (X): Kubernetes monitoring & debugging
YOLOV1学习笔记
Difference between shortest path and minimum spanning tree
Simple understanding of ThreadLocal