当前位置:网站首页>Pytorch exercise items
PyTorch practice project: banknote classification
2022-07-03 06:23:00 【fksfdh】
config.py
class Hyperparameter:
    """Class-level hyperparameters shared by the banknote scripts.

    Every setting is a plain class attribute. The module-level ``HP``
    instance created below just exposes them, so other modules do
    ``from config import HP`` and read ``HP.<name>``.
    """

    # ---------------------------------------------------------------- data
    device = 'cpu'                       # device the dataset moves its tensors to
    trainset_path = './data/train.txt'
    testset_path = './data/test.txt'
    seed = 1234                          # seed for the trainer's RNGs

    # --------------------------------------------------------------- model
    in_features = 4                      # leading feature columns per data row
    out_dim = 2                          # width of the final layer (number of classes)
    # Widths of the fully connected stack: input -> hidden layers -> output.
    layer_list = [in_features] + [64, 128, 64] + [out_dim]

    # --------------------------------------------------------------- train
    init_lr = 0.001
    batch_size = 64
    epochs = 100
    verbose_step = 10                    # re-evaluate on the test set every N steps
    save_step = 500                      # write a checkpoint every N steps


HP = Hyperparameter()
dataset_banknote.py
from torch.utils.data import Dataset,DataLoader
import numpy as np
import torch
from config import HP
class BanknoteDataset(Dataset):
    """Map-style dataset over a comma-separated numeric text file.

    Each row holds ``HP.in_features`` feature columns followed by the label
    column(s). Items are returned as a ``(features, label)`` pair of tensors
    already moved to ``HP.device``.
    """

    def __init__(self, dataset_path):
        # ndmin=2 keeps a single-row file as shape (1, n_cols). Without it,
        # np.loadtxt returns a 1-D array, and __len__/__getitem__ would treat
        # every column as a separate sample.
        self.dataset = np.loadtxt(dataset_path, delimiter=',', ndmin=2)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        # Leading columns are features; the remainder is the label.
        x, y = item[:HP.in_features], item[HP.in_features:]
        features = torch.Tensor(x).float().to(HP.device)
        # squeeze() turns a single label column into a 0-d integer tensor, so
        # the default collate yields labels of shape (batch,).
        # Assumes exactly one label column -- TODO confirm for other datasets.
        label = torch.Tensor(y).squeeze().long().to(HP.device)
        return features, label

    def __len__(self):
        # Number of rows (samples) in the file.
        return self.dataset.shape[0]
model.py
from torch import nn
import torch.nn.functional as F
from config import HP
import torch
class BanknoteClassificationModel(nn.Module):
    """Fully connected classifier built from a list of layer widths.

    Consecutive pairs of widths define the ``nn.Linear`` layers; e.g.
    ``[4, 64, 128, 64, 2]`` builds 4->64->128->64->2. ReLU is applied
    between layers but NOT after the last one, so the output is raw logits.
    ``nn.CrossEntropyLoss`` expects raw logits as input.
    """

    def __init__(self, layer_list=None):
        """Build the linear layers.

        Args:
            layer_list: optional sequence of layer widths, input width first
                and number of classes last. Defaults to ``HP.layer_list``.
        """
        super(BanknoteClassificationModel, self).__init__()
        if layer_list is None:
            layer_list = HP.layer_list
        self.linear_layer = nn.ModuleList([
            nn.Linear(in_features=in_dim, out_features=out_dim)
            for in_dim, out_dim in zip(layer_list[:-1], layer_list[1:])
        ])

    def forward(self, input_x):
        """Return logits whose last dimension is ``layer_list[-1]``."""
        last = len(self.linear_layer) - 1
        for i, layer in enumerate(self.linear_layer):
            input_x = layer(input_x)
            # No activation on the output layer: a trailing ReLU would clamp
            # every negative logit to 0, so the softmax in the loss could
            # never push a class's probability below its zero-logit level.
            if i < last:
                input_x = F.relu(input_x)
        return input_x
if __name__ == '__main__':
    # Smoke test: push one random batch through the model and print the
    # output shape. Input dimensions come from HP, so the check stays valid
    # if the configured feature count or batch size changes.
    data = torch.randn((HP.batch_size, HP.in_features))
    model = BanknoteClassificationModel()
    with torch.no_grad():
        res = model(data)
    print(res.size())
trainer.py
import random
import os
from argparse import ArgumentParser
import numpy as np
import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from config import HP
from model import BanknoteClassificationModel
from dataset_banknote import BanknoteDataset
# Seed every random number generator the training code may draw from
# (Python's `random`, NumPy, and torch's CPU and CUDA generators) with
# HP.seed, so random draws repeat from run to run.
random.seed(HP.seed)
np.random.seed(HP.seed)
torch.manual_seed(HP.seed)
torch.cuda.manual_seed(HP.seed)  # seeds the current CUDA device's generator
def evaluate(model_, test_loader, crit):
    """Return the average loss of ``model_`` over every sample in ``test_loader``.

    Args:
        model_: the network to evaluate. It is switched to eval mode for the
            pass and restored to its previous train/eval mode afterwards.
        test_loader: iterable of ``(x, y)`` batches.
        crit: loss called as ``crit(pred, y)``. It is assumed to return the
            per-batch mean (the default ``reduction='mean'``), which is why
            each batch is re-weighted by its size below.

    Returns:
        The sample-weighted mean loss as a Python float, or ``nan`` when the
        loader yields no samples (previously this raised ZeroDivisionError).
    """
    was_training = model_.training
    model_.eval()
    total_loss = 0.
    total_samples = 0
    try:
        with torch.no_grad():
            for x, y in test_loader:
                loss = crit(model_(x), y)
                batch_size = y.size(0)
                # Weight by batch size so a smaller final batch does not skew the mean.
                total_loss += loss.item() * batch_size
                total_samples += batch_size
    finally:
        # Restore whatever mode the caller had, even if evaluation raised.
        model_.train(was_training)
    if total_samples == 0:
        return float('nan')
    return total_loss / total_samples
def save_checkpoint(model_, epoch, optm, checkpoint_path):
    """Write model state, optimizer state and the epoch to ``checkpoint_path``.

    The saved dict has the keys ``model_state_dict``,
    ``optimizer_state_dict`` and ``epoch``. These are the keys ``train()``
    reads back when resuming with ``--c``.

    The parent directory is created if needed; otherwise ``torch.save`` fails
    with FileNotFoundError on a fresh checkout that has no ``model_save/``.
    """
    parent = os.path.dirname(checkpoint_path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    save_dict = {
        'model_state_dict': model_.state_dict(),
        'optimizer_state_dict': optm.state_dict(),
        'epoch': epoch,
    }
    torch.save(save_dict, checkpoint_path)
def train():
    """Train BanknoteClassificationModel from scratch or resume from a checkpoint.

    Command line:
        --c PATH   optional checkpoint written by ``save_checkpoint``. When
                   given, model/optimizer state and the epoch are restored.

    Every ``HP.verbose_step`` optimizer steps the test loss is recomputed.
    Every ``HP.save_step`` steps a checkpoint is written to
    ``model_save/model_<epoch>_<step>.pth``. One summary line is printed per
    epoch.
    """
    parser = ArgumentParser(description="Model Training")
    parser.add_argument(
        '--c',
        default=None,
        type=str,
        help='train from scratch or resume training'
    )
    args = parser.parse_args()

    # data
    train_set = BanknoteDataset(HP.trainset_path)
    train_loader = DataLoader(train_set, batch_size=HP.batch_size, shuffle=True, drop_last=True)
    test_set = BanknoteDataset(HP.testset_path)
    # Evaluation should see every test sample in a fixed order. drop_last
    # silently discarded the remainder, and for a test set smaller than one
    # batch it produced an empty loader.
    test_loader = DataLoader(test_set, batch_size=HP.batch_size, shuffle=False, drop_last=False)

    # model -- placed on HP.device, which is where the dataset puts its tensors
    model = BanknoteClassificationModel().to(HP.device)
    # loss
    criterion = nn.CrossEntropyLoss()
    # optimizer
    optm = optim.Adam(model.parameters(), lr=HP.init_lr)

    start_epoch, step = 0, 0
    if args.c:
        # NOTE(security): torch.load unpickles the file; only resume from
        # trusted checkpoints. map_location lets a checkpoint saved on another
        # device (e.g. a GPU) load onto HP.device.
        checkpoint = torch.load(args.c, map_location=HP.device)
        model.load_state_dict(checkpoint["model_state_dict"])
        optm.load_state_dict(checkpoint["optimizer_state_dict"])
        start_epoch = checkpoint["epoch"]
        print("Resume from %s" % args.c)
    else:
        print("Training from scratch !")

    # Checkpoints are written into this directory; make sure it exists.
    os.makedirs("model_save", exist_ok=True)

    model.train()
    eval_loss = float('nan')
    for epoch in range(start_epoch, HP.epochs):
        # len(train_loader) is already the number of batches (= steps) per epoch.
        print("Start_Epoch:%d,Steps:%d" % (epoch, len(train_loader)))
        loss = None
        for x, y in train_loader:
            optm.zero_grad()
            pred = model(x)
            loss = criterion(pred, y)
            loss.backward()
            optm.step()

            if not step % HP.verbose_step:
                eval_loss = evaluate(model, test_loader, criterion)
            if not step % HP.save_step:
                model_path = "model_%d_%d.pth" % (epoch, step)
                save_checkpoint(model, epoch, optm, os.path.join("model_save", model_path))
            step += 1
        # Report the last batch's training loss; nan if the loader was empty.
        train_loss = loss.item() if loss is not None else float('nan')
        print("Epoch:[%d/%d],step:%d,train_loss:%.5f,test_loss:%.5f" % (epoch, HP.epochs, step, train_loss, eval_loss))
if __name__ == '__main__':
    # Script entry point: train() parses the command line and runs training.
    train()
边栏推荐
- Kubesphere - build Nacos cluster
- UNI-APP中条件注释实现跨端兼容、导航跳转和传参、组件创建使用和生命周期函数
- 认识弹性盒子flex
- ROS+Pytorch的联合使用示例(语义分割)
- 从 Amazon Aurora 迁移数据到 TiDB
- 项目总结--2(Jsoup的基本使用)
- Request weather interface format, automation
- The Win7 computer won't start: the CPU fan spins briefly and then stops
- Kubesphere - Multi tenant management
- Cesium entity (entities) entity deletion method
猜你喜欢
After the Chrome browser is updated, lodop printing cannot be called
技术管理进阶——你了解成长的全貌吗?
有意思的鼠标指针交互探究
Fluentd facile à utiliser avec le marché des plug-ins Rainbond pour une collecte de journaux plus rapide
Project summary --04
Scripy learning
CKA certification notes - CKA certification experience post
使用 Abp.Zero 搭建第三方登录模块(一):原理篇
Selenium - 改变窗口大小,不同机型呈现的宽高长度会不一样
Kubesphere - build Nacos cluster
随机推荐
Nacos service installation
Deep learning
PDF files print only the first page
[system design] proximity service
The mechanical hard disk is connected to the computer through USB and cannot be displayed
【5G NR】UE注册流程
conda和pip的区别
Project summary --01 (addition, deletion, modification and query of interfaces; use of multithreading)
Cesium Click to obtain the longitude and latitude elevation coordinates (3D coordinates) of the model surface
认识弹性盒子flex
Shell conditional statement
Kubernetes notes (II) pod usage notes
Leetcode solution - 01 Two Sum
JMeter performance automation test
Support vector machine for machine learning
Openresty best practices
After the Chrome browser is updated, lodop printing cannot be called
Local rviz call and display of remote rostopic
Docker advanced learning (container data volume, MySQL installation, dockerfile)
10万奖金被瓜分,快来认识这位上榜者里的“乘风破浪的姐姐”