当前位置:网站首页 > PyTorch exercise items
PyTorch exercise items
2022-07-03 06:23:00 【fksfdh】
config.py
class Hyperparameter:
    """Static hyper-parameters shared by the dataset, model and trainer.

    Every setting is a plain class attribute; the rest of the project reads
    them through the ready-made ``HP`` instance defined below.
    """

    # ---- data ---------------------------------------------------------
    device = 'cpu'                        # device the dataset moves tensors to
    trainset_path = './data/train.txt'    # comma-separated training split
    testset_path = './data/test.txt'      # comma-separated test split
    seed = 1234                           # value used to seed every RNG

    # ---- model --------------------------------------------------------
    in_features = 4                       # number of leading feature columns
    out_dim = 2                           # width of the output layer (classes)
    # Widths of the fully-connected stack: input, three hidden layers, output.
    layer_list = [in_features] + [64, 128, 64] + [out_dim]

    # ---- training -----------------------------------------------------
    init_lr = 0.001
    batch_size = 64
    epochs = 100
    verbose_step = 10                     # recompute test loss every N steps
    save_step = 500                       # write a checkpoint every N steps


HP = Hyperparameter()
dataset_banknote.py
from torch.utils.data import Dataset,DataLoader
import numpy as np
import torch
from config import HP
class BanknoteDataset(Dataset):
    """Map-style dataset over a comma-separated numeric table.

    Each row holds ``in_features`` feature columns followed by the label
    column. Items are returned as ``(features, label)`` tensors already
    placed on the configured device.

    Args:
        dataset_path: path to the comma-separated text file.
        in_features: number of leading feature columns; defaults to
            ``HP.in_features``.
        device: device the returned tensors are moved to; defaults to
            ``HP.device``.
    """

    def __init__(self, dataset_path, in_features=None, device=None):
        # ndmin=2 keeps a one-row file as a (1, cols) table; without it
        # np.loadtxt returns a 1-D row and __len__ would count columns.
        self.dataset = np.loadtxt(dataset_path, delimiter=',', ndmin=2)
        self.in_features = HP.in_features if in_features is None else in_features
        self.device = HP.device if device is None else device

    def __getitem__(self, idx):
        """Return ``(float32 features, int64 scalar label)`` for row ``idx``."""
        item = self.dataset[idx]
        x, y = item[:self.in_features], item[self.in_features:]
        features = torch.as_tensor(x, dtype=torch.float32).to(self.device)
        # squeeze() turns the single label column into a 0-d tensor, so the
        # default DataLoader collation yields labels of shape (batch,).
        label = torch.as_tensor(y).squeeze().long().to(self.device)
        return features, label

    def __len__(self):
        """Number of rows (samples) in the table."""
        return self.dataset.shape[0]
model.py
from torch import nn
import torch.nn.functional as F
from config import HP
import torch
class BanknoteClassificationModel(nn.Module):
    """Fully-connected classifier built from a list of layer widths.

    Consecutive pairs in ``layer_list`` (input width first, output width
    last) become ``nn.Linear`` layers. ReLU is applied between layers but
    not after the last one, so ``forward`` returns raw logits, which is what
    ``nn.CrossEntropyLoss`` expects.

    Args:
        layer_list: layer widths; defaults to ``HP.layer_list``.
    """

    def __init__(self, layer_list=None):
        super(BanknoteClassificationModel, self).__init__()
        widths = HP.layer_list if layer_list is None else list(layer_list)
        # Attribute name kept as ``linear_layer`` so saved state_dict keys
        # ("linear_layer.<i>.weight"/"bias") still match.
        self.linear_layer = nn.ModuleList([
            nn.Linear(in_features=in_dim, out_features=out_dim)
            for in_dim, out_dim in zip(widths[:-1], widths[1:])
        ])

    def forward(self, input_x):
        """Run ``input_x`` through every layer and return unnormalized logits."""
        last = len(self.linear_layer) - 1
        for i, layer in enumerate(self.linear_layer):
            input_x = layer(input_x)
            # No activation after the output layer: a ReLU there would clamp
            # every negative logit to 0 before the softmax in the loss.
            if i < last:
                input_x = F.relu(input_x)
        return input_x
if '__main__' == __name__:
    # Smoke test: push a random batch of 64 rows with 4 features through a
    # freshly built network and print the shape of the output.
    data = torch.randn(64, 4)
    model = BanknoteClassificationModel()
    res = model(data)
    print(res.shape)
trainer.py
import random
import os
from argparse import ArgumentParser
import numpy as np
import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from config import HP
from model import BanknoteClassificationModel
from dataset_banknote import BanknoteDataset
# Seed every RNG this script may draw from (torch CPU, torch CUDA, Python's
# `random`, NumPy) with the same configured value, in that order.
for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed, random.seed, np.random.seed):
    _seed_fn(HP.seed)
del _seed_fn
def evaluate(model_, test_loader, crit):
    """Return the average loss of ``model_`` over every batch in ``test_loader``.

    The model is switched to eval mode and gradients are disabled for the
    pass. Its previous train/eval mode is restored afterwards, even if an
    error is raised. Per-batch losses are weighted by batch size, so a
    smaller final batch does not skew the average. This assumes ``crit``
    averages over the batch, which is the ``nn.CrossEntropyLoss`` default.

    Args:
        model_: module to evaluate.
        test_loader: iterable of ``(x, y)`` batches.
        crit: loss function called as ``crit(pred, y)``.

    Returns:
        The sample-weighted mean loss as a Python float, or ``nan`` if the
        loader yields no samples (instead of dividing by zero).
    """
    was_training = model_.training
    model_.eval()
    total_loss, total_samples = 0.0, 0
    try:
        with torch.no_grad():
            for x, y in test_loader:
                pred = model_(x)
                n = y.shape[0] if y.dim() > 0 else 1
                # .item() keeps the accumulator a plain float, not a tensor.
                total_loss += crit(pred, y).item() * n
                total_samples += n
    finally:
        model_.train(was_training)
    if total_samples == 0:
        return float('nan')
    return total_loss / total_samples
def save_checkpoint(model_, epoch, optm, checkpoint_path):
    """Serialize model state, optimizer state and ``epoch`` to ``checkpoint_path``.

    The saved dict uses the keys ``model_state_dict``,
    ``optimizer_state_dict`` and ``epoch``, which are the keys ``train``
    reads when resuming. Missing parent directories are created first, so
    ``torch.save`` does not fail when the target folder does not exist yet.
    """
    parent = os.path.dirname(checkpoint_path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    save_dict = {
        'model_state_dict': model_.state_dict(),
        'optimizer_state_dict': optm.state_dict(),
        'epoch': epoch,
    }
    torch.save(save_dict, checkpoint_path)
def train():
    """Train the banknote classifier, optionally resuming from a checkpoint.

    Command line:
        --c PATH   resume from a checkpoint written by ``save_checkpoint``;
                   without it, training starts from scratch.

    The test loss is recomputed every ``HP.verbose_step`` steps, and a
    checkpoint is written to ``model_save/`` every ``HP.save_step`` steps.
    Progress is printed after each optimizer step.
    """
    parser = ArgumentParser(description="Model Training")
    parser.add_argument(
        '--c',
        default=None,
        type=str,
        help='train from scratch or resume training'
    )
    args = parser.parse_args()

    # data: shuffle and drop the ragged last batch only for training; the
    # evaluation pass should see every test sample in a fixed order.
    train_set = BanknoteDataset(HP.trainset_path)
    train_loader = DataLoader(train_set, batch_size=HP.batch_size, shuffle=True, drop_last=True)
    test_set = BanknoteDataset(HP.testset_path)
    test_loader = DataLoader(test_set, batch_size=HP.batch_size, shuffle=False, drop_last=False)

    # model: place it on the same device the dataset moves its tensors to.
    model = BanknoteClassificationModel().to(HP.device)
    criterion = nn.CrossEntropyLoss()
    optm = optim.Adam(model.parameters(), lr=HP.init_lr)

    start_epoch, step = 0, 0
    if args.c:
        # map_location lets a checkpoint saved on another device load here.
        checkpoint = torch.load(args.c, map_location=HP.device)
        model.load_state_dict(checkpoint["model_state_dict"])
        optm.load_state_dict(checkpoint["optimizer_state_dict"])
        start_epoch = checkpoint["epoch"]
        print("Resume from %s" % args.c)
    else:
        print("Training from scratch !")

    # Checkpoints are written under model_save/; make sure it exists.
    os.makedirs("model_save", exist_ok=True)
    model.train()
    for epoch in range(start_epoch, HP.epochs):
        # len(DataLoader) is already the number of batches (steps) per epoch.
        print("Start_Epoch:%d,Steps:%d" % (epoch, len(train_loader)))
        for x, y in train_loader:
            optm.zero_grad()
            pred = model(x)
            loss = criterion(pred, y)
            loss.backward()
            optm.step()

            # step starts at 0, so eval_loss is assigned on the first step
            # before the progress print below reads it.
            if not step % HP.verbose_step:
                eval_loss = evaluate(model, test_loader, criterion)
            if not step % HP.save_step:
                model_path = "model_%d_%d.pth" % (epoch, step)
                save_checkpoint(model, epoch, optm, os.path.join("model_save", model_path))
            step += 1
            print("Epoch:[%d/%d],step:%d,train_loss:%.5f,test_loss:%.5f"
                  % (epoch, HP.epochs, step, loss.item(), eval_loss))
if '__main__' == __name__:
    # Run training only when trainer.py is executed directly, not on import.
    train()
边栏推荐
- MySQL帶二進制的庫表導出導入
- Merging and migrating data from small sharded MySQL databases and tables into TiDB
- Virtual memory technology sharing
- 【C#/VB.NET】 将PDF转为SVG/Image, SVG/Image转PDF
- CKA certification notes - CKA certification experience post
- 【LeetCode】Day93-两个数组的交集 II
- Configuring phpStudy so its sites can be accessed by other computers on the LAN
- When PHP uses env to obtain file parameters, it gets strings
- 项目总结--2(Jsoup的基本使用)
- [set theory] relational closure (relational closure solution | relational graph closure | relational matrix closure | closure operation and relational properties | closure compound operation)
猜你喜欢

Project summary --01 (addition, deletion, modification and query of interfaces; use of multithreading)

ThreadLocal的简单理解

Zhiniu stock project -- 05

Configuring phpStudy so its sites can be accessed by other computers on the LAN

Kubernetes notes (VII): Kubernetes scheduling

Redis cluster creation, capacity expansion and capacity reduction

Docker advanced learning (container data volume, MySQL installation, dockerfile)

Cesium 点击获三维坐标(经纬度高程)

从小数据量分库分表 MySQL 合并迁移数据到 TiDB

Cesium Click to obtain the longitude and latitude elevation coordinates (3D coordinates) of the model surface
随机推荐
学习笔记 -- k-d tree 和 ikd-Tree 原理及对比
Apifix installation
Advanced technology management - do you know the whole picture of growth?
How to scan when Canon c3120l is a network shared printer
Docker advanced learning (container data volume, MySQL installation, dockerfile)
深入解析kubernetes controller-runtime
arcgis创建postgre企业级数据库
【无标题】8 简易版通讯录
项目总结--2(Jsoup的基本使用)
What's the difference between using the Service Worker Cache API and regular browser cache?
Oracle Database Introduction
Fluentd is easy to use. Combined with the rainbow plug-in market, log collection is faster
Interface test weather API
方差迭代公式推导
【LeetCode】Day93-两个数组的交集 II
Creating a PostgreSQL enterprise geodatabase with ArcGIS
Install VM tools
The mechanical hard disk is connected to the computer through USB and cannot be displayed
Selenium - 改变窗口大小,不同机型呈现的宽高长度会不一样
Oauth2.0 - using JWT to replace token and JWT content enhancement