当前位置:网站首页>pytorch练习小项目
pytorch练习小项目
2022-07-03 06:15:00 【fksfdh】
config.py
class Hyperparameter:
    """Static configuration shared by the dataset, model and trainer code.

    Every setting is a plain class attribute, so it can be read from the
    class itself or from the module-level ``HP`` instance.
    """

    # ---------------- data ----------------
    device = "cpu"  # device string the dataset passes to ``Tensor.to``
    trainset_path = "./data/train.txt"
    testset_path = "./data/test.txt"
    seed = 1234  # value handed to the RNG seeding calls in the trainer

    # ---------------- model ----------------
    in_features = 4  # leading columns of a data row used as model input
    out_dim = 2  # width of the final linear layer
    # Layer widths from input to output; each consecutive pair becomes one
    # ``nn.Linear`` in the model.
    layer_list = [in_features, *(64, 128, 64), out_dim]

    # ---------------- train ----------------
    init_lr = 0.001  # learning rate given to the Adam optimizer
    batch_size = 64
    epochs = 100
    verbose_step = 10  # evaluate on the test loader every this many steps
    save_step = 500  # write a checkpoint every this many steps


# Shared instance that the other modules import as ``from config import HP``.
HP = Hyperparameter()
dataset_banknote.py
from torch.utils.data import Dataset,DataLoader
import numpy as np
import torch
from config import HP
class BanknoteDataset(Dataset):
    """Map-style dataset over a comma-separated numeric text file.

    Each row is split into the first ``HP.in_features`` columns (features)
    and the remaining column(s) (label).
    """

    def __init__(self, dataset_path):
        """Load the whole file at ``dataset_path`` into memory.

        ``ndmin=2`` keeps the array two-dimensional even when the file holds
        a single row. Without it ``np.loadtxt`` returns a 1-D array for such
        a file, so ``__len__`` would report the number of columns and
        ``__getitem__`` would index scalars instead of rows.
        """
        self.dataset = np.loadtxt(dataset_path, delimiter=',', ndmin=2)

    def __getitem__(self, idx):
        """Return the ``(features, label)`` tensors for row ``idx``.

        Features are converted to float32 and the label to int64; both are
        moved to ``HP.device``. The label is squeezed, so with exactly one
        label column it is a 0-d tensor.
        """
        row = self.dataset[idx]
        features, label = row[:HP.in_features], row[HP.in_features:]
        x = torch.Tensor(features).float().to(HP.device)
        y = torch.Tensor(label).squeeze().long().to(HP.device)
        return x, y

    def __len__(self):
        """Return the number of rows loaded from the file."""
        return self.dataset.shape[0]
model.py
from torch import nn
import torch.nn.functional as F
from config import HP
import torch
class BanknoteClassificationModel(nn.Module):
    """Fully connected classifier built from a list of layer widths."""

    def __init__(self, layer_list=None):
        """Create one ``nn.Linear`` per consecutive pair of widths.

        Args:
            layer_list: Layer widths from input to output. Defaults to
                ``HP.layer_list``, so the existing no-argument construction
                builds the same layers as before.
        """
        super().__init__()
        if layer_list is None:
            layer_list = HP.layer_list
        # The attribute name is part of the state_dict keys; keep it stable
        # so existing checkpoints still load.
        self.linear_layer = nn.ModuleList([
            nn.Linear(in_features=in_dim, out_features=out_dim)
            for in_dim, out_dim in zip(layer_list[:-1], layer_list[1:])
        ])

    def forward(self, input_x):
        """Return the raw class scores (logits) for ``input_x``.

        ReLU follows every layer except the last. Applying it to the final
        layer as well would clamp all scores at zero, while the trainer's
        ``nn.CrossEntropyLoss`` expects unnormalized logits that are free to
        go negative.
        """
        last_idx = len(self.linear_layer) - 1
        for idx, layer in enumerate(self.linear_layer):
            input_x = layer(input_x)
            if idx != last_idx:
                input_x = F.relu(input_x)
        return input_x
if __name__ == '__main__':
    # Smoke test: push one random batch of 64 four-feature rows through a
    # freshly built model and print the size of the result.
    dummy_batch = torch.randn(64, 4)
    net = BanknoteClassificationModel()
    print(net(dummy_batch).size())
trainer.py
import random
import os
from argparse import ArgumentParser
import numpy as np
import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from config import HP
from model import BanknoteClassificationModel
from dataset_banknote import BanknoteDataset
# Seed every random number generator used here (torch, torch.cuda, the
# stdlib ``random`` module and NumPy) with the same configured value, in the
# same order as before.
for _seed_rng in (torch.manual_seed, torch.cuda.manual_seed, random.seed, np.random.seed):
    _seed_rng(HP.seed)
def evaluate(model_, test_loader, crit):
    """Return the mean per-batch loss of ``model_`` over ``test_loader``.

    The model is switched to eval mode and gradients are disabled for the
    pass. Afterwards the model is put back into whichever mode it was in on
    entry, even if a batch raises. Previously train mode was forced
    unconditionally, which silently flipped a model the caller had placed in
    eval mode.

    Args:
        model_: Module called as ``model_(x)``.
        test_loader: Sized iterable yielding ``(x, y)`` batches.
        crit: Callable ``crit(pred, y)`` returning the loss for one batch.

    Returns:
        The sum of the batch losses divided by ``len(test_loader)``.

    Raises:
        ZeroDivisionError: If ``test_loader`` contains no batches.
    """
    was_training = model_.training
    model_.eval()
    sum_loss = 0.
    try:
        with torch.no_grad():
            for x, y in test_loader:
                sum_loss += crit(model_(x), y)
    finally:
        # Restore the caller's mode rather than assuming it was training.
        model_.train(was_training)
    return sum_loss / len(test_loader)
def save_checkpoint(model_, epoch, optm, checkpoint_path):
    """Write the model state, optimizer state and epoch to ``checkpoint_path``.

    Args:
        model_: Module whose ``state_dict()`` is stored under
            ``'model_state_dict'``.
        epoch: Value stored under ``'epoch'``.
        optm: Optimizer whose ``state_dict()`` is stored under
            ``'optimizer_state_dict'``.
        checkpoint_path: Destination handed to ``torch.save``. When it is a
            filesystem path, missing parent directories are created first.
    """
    # torch.save does not create directories, so a first save into a folder
    # such as "model_save" would fail. Only path-like destinations have a
    # parent directory to create; anything else is passed through untouched.
    if isinstance(checkpoint_path, (str, os.PathLike)):
        parent_dir = os.path.dirname(os.fspath(checkpoint_path))
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

    save_dict = {
        'model_state_dict': model_.state_dict(),
        'optimizer_state_dict': optm.state_dict(),
        'epoch': epoch,
    }
    torch.save(save_dict, checkpoint_path)
def train():
    """Train the banknote classifier, optionally resuming from a checkpoint.

    Command line:
        --c PATH  Checkpoint to resume from. When omitted, training starts
                  from scratch.

    Side effects: reads the train/test files named in ``HP``, prints progress
    for every step, and writes checkpoints into the ``model_save`` directory
    every ``HP.save_step`` steps.
    """
    parser = ArgumentParser(description="Model Training")
    parser.add_argument(
        '--c',
        default=None,
        type=str,
        help='train from scratch or resume training'
    )
    args = parser.parse_args()

    # data
    train_set = BanknoteDataset(HP.trainset_path)
    train_loader = DataLoader(train_set, batch_size=HP.batch_size, shuffle=True, drop_last=True)
    test_set = BanknoteDataset(HP.testset_path)
    # Evaluation should see every test sample: no shuffling, and no dropping
    # of the final partial batch (with drop_last a test set smaller than one
    # batch yields zero batches and evaluate() divides by zero).
    test_loader = DataLoader(test_set, batch_size=HP.batch_size, shuffle=False, drop_last=False)

    # model -- placed on the same device the dataset moves its tensors to
    model = BanknoteClassificationModel().to(HP.device)

    # loss
    criterion = nn.CrossEntropyLoss()

    # optimizer
    optm = optim.Adam(model.parameters(), lr=HP.init_lr)

    start_epoch, step = 0, 0
    if args.c:
        # map_location lets a checkpoint saved on another device be resumed
        # on the configured one.
        checkpoint = torch.load(args.c, map_location=HP.device)
        model.load_state_dict(checkpoint["model_state_dict"])
        optm.load_state_dict(checkpoint["optimizer_state_dict"])
        start_epoch = checkpoint["epoch"]
        print("Resume from %s" % args.c)
    else:
        print("Training from scratch !")

    # Checkpoints are written below this directory; make sure it exists.
    os.makedirs("model_save", exist_ok=True)

    model.train()
    for epoch in range(start_epoch, HP.epochs):
        # len(train_loader) is already the number of batches per epoch;
        # dividing it by the batch size again under-reported the step count.
        print("Start_Epoch:%d,Steps:%d" % (epoch, len(train_loader)))
        for x, y in train_loader:
            optm.zero_grad()
            pred = model(x)
            loss = criterion(pred, y)
            loss.backward()
            optm.step()

            # step starts at 0, so the very first step always evaluates and
            # eval_loss is bound before the print below uses it.
            if not step % HP.verbose_step:
                eval_loss = evaluate(model, test_loader, criterion)

            if not step % HP.save_step:
                model_path = "model_%d_%d.pth" % (epoch, step)
                save_checkpoint(model, epoch, optm, os.path.join("model_save", model_path))

            step += 1
            print("Epoch:[%d/%d],step:%d,train_loss:%.5f,test_loss:%.5f" % (epoch, HP.epochs, step, loss.item(), eval_loss))


if __name__ == '__main__':
    train()
边栏推荐
- SVN分支管理
- Cesium 点击获三维坐标(经纬度高程)
- Cesium Click to obtain the longitude and latitude elevation coordinates (3D coordinates) of the model surface
- BeanDefinitionRegistryPostProcessor
- 从 Amazon Aurora 迁移数据到 TiDB
- Read blog type data from mysql, Chinese garbled code - solved
- Reinstalling the system displays "setup is applying system settings" stationary
- 技术管理进阶——你了解成长的全貌吗?
- 项目总结--04
- 代码管理工具
猜你喜欢
Synthetic keyword and NBAC mechanism
Migrating a small amount of data from MySQL to TiDB
Clickhouse learning notes (I): Clickhouse installation, data type, table engine, SQL operation
Redis cluster creation, capacity expansion and capacity reduction
Une exploration intéressante de l'interaction souris - pointeur
Kubesphere - Multi tenant management
Click cesium to obtain three-dimensional coordinates (longitude, latitude and elevation)
JDBC connection database steps
Skywalking8.7 source code analysis (II): Custom agent, service loading, witness component version identification, transform workflow
tabbar的设置
随机推荐
SVN分支管理
Project summary --01 (addition, deletion, modification and query of interfaces; use of multithreading)
The mechanical hard disk is connected to the computer through USB and cannot be displayed
Mysql database
Printer related problem record
Simple understanding of ThreadLocal
Cesium entity (entities) entity deletion method
YOLOV3学习笔记
Jedis source code analysis (II): jediscluster module source code analysis
Push box games C #
BIO, NIO, AIO details
BeanDefinitionRegistryPostProcessor
Oauth2.0 - using JWT to replace token and JWT content enhancement
项目总结--2(Jsoup的基本使用)
23 design patterns
Kubernetes notes (IV) kubernetes network
Use @data in Lombok to simplify entity class code
Shell conditional statement
有意思的鼠标指针交互探究
Support vector machine for machine learning