当前位置:网站首页>pytorch练习小项目
pytorch练习小项目
2022-07-03 06:15:00 【fksfdh】
config.py
class Hyperparameter:
    """Central store for the data, model and training settings.

    All settings are plain class attributes. The module-level ``HP``
    instance created below is the object the other modules import.
    """

    # ################################################
    # data
    # ################################################
    device = 'cpu'  # device the dataset moves its tensors to
    trainset_path = './data/train.txt'
    testset_path = './data/test.txt'
    seed = 1234  # shared seed for the random number generators

    # ################################################
    # model
    # ################################################
    in_features = 4  # number of leading feature columns in each data row
    out_dim = 2
    # Layer widths from input to output; consecutive pairs become Linear layers.
    layer_list = [in_features, 64, 128, 64, out_dim]

    # ################################################
    # train
    # ################################################
    init_lr = 1e-3
    batch_size = 64
    epochs = 100
    verbose_step = 10  # evaluate on the test set every `verbose_step` steps
    save_step = 500  # write a checkpoint every `save_step` steps


HP = Hyperparameter()
dataset_banknote.py
from torch.utils.data import Dataset,DataLoader
import numpy as np
import torch
from config import HP
class BanknoteDataset(Dataset):
    """Comma-separated numeric dataset served as ``(features, label)`` pairs.

    Every row of the file holds the feature columns first; whatever follows
    them is treated as the label.
    """

    def __init__(self, dataset_path, in_features=None, device=None):
        """Load the whole file into memory.

        Args:
            dataset_path: path of a comma-delimited text file readable by
                ``numpy.loadtxt``.
            in_features: number of leading columns that are features.
                Defaults to ``HP.in_features``.
            device: device the returned tensors are moved to.
                Defaults to ``HP.device``.
        """
        # ndmin=2 keeps a one-row file as shape (1, n_cols). Without it
        # loadtxt returns a 1-D array, so __len__ would count columns and
        # __getitem__ would hand back single scalars.
        self.dataset = np.loadtxt(dataset_path, delimiter=',', ndmin=2)
        self.in_features = HP.in_features if in_features is None else in_features
        self.device = HP.device if device is None else device

    def __getitem__(self, idx):
        """Return ``(float32 features, int64 label)`` for row ``idx``."""
        item = self.dataset[idx]
        x, y = item[:self.in_features], item[self.in_features:]
        features = torch.as_tensor(x, dtype=torch.float32).to(self.device)
        # squeeze() turns a single label column into a 0-dim tensor.
        label = torch.as_tensor(y).squeeze().long().to(self.device)
        return features, label

    def __len__(self):
        """Return the number of rows loaded from the file."""
        return self.dataset.shape[0]
model.py
from torch import nn
import torch.nn.functional as F
from config import HP
import torch
class BanknoteClassificationModel(nn.Module):
    """Stack of fully connected layers with ReLU between them.

    The final layer's output is returned without an activation so it can be
    used as raw class scores (logits).
    """

    def __init__(self, layer_list=None):
        """Build one ``nn.Linear`` per consecutive pair of layer widths.

        Args:
            layer_list: sequence of layer widths from input to output.
                Defaults to ``HP.layer_list``.
        """
        super().__init__()
        sizes = HP.layer_list if layer_list is None else list(layer_list)
        # The attribute name is part of the state_dict keys; keep it stable
        # so existing checkpoints still load.
        self.linear_layer = nn.ModuleList([
            nn.Linear(in_features=in_dim, out_features=out_dim)
            for in_dim, out_dim in zip(sizes[:-1], sizes[1:])
        ])

    def forward(self, input_x):
        """Run ``input_x`` through every layer and return the last output."""
        last_idx = len(self.linear_layer) - 1
        for idx, layer in enumerate(self.linear_layer):
            input_x = layer(input_x)
            # No ReLU after the output layer: clamping the scores to >= 0
            # would keep the loss from ever seeing a negative logit.
            if idx != last_idx:
                input_x = F.relu(input_x)
        return input_x
if __name__ == '__main__':
    # Smoke test: push one random batch through the model and print the
    # size of the output. The input size comes from the config so it stays
    # in step with the first layer's width.
    data = torch.randn((HP.batch_size, HP.in_features))
    model = BanknoteClassificationModel()
    res = model(data)
    print(res.size())
trainer.py
import random
import os
from argparse import ArgumentParser
import numpy as np
import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from config import HP
from model import BanknoteClassificationModel
from dataset_banknote import BanknoteDataset
# Seed every random number generator used here (Python, NumPy, torch CPU
# and torch CUDA) with the configured seed so runs are repeatable.
random.seed(HP.seed)
np.random.seed(HP.seed)
torch.manual_seed(HP.seed)
torch.cuda.manual_seed(HP.seed)
def evaluate(model_, test_loader, crit):
    """Return the mean per-batch loss of ``model_`` over ``test_loader``.

    The model is switched to eval mode for the duration of the call and is
    put back into whichever mode it was in before, even if a batch raises.

    Args:
        model_: module called as ``model_(x)``.
        test_loader: sized iterable yielding ``(x, y)`` batches.
        crit: callable ``crit(pred, y)`` returning the loss for one batch.

    Returns:
        The summed batch losses divided by ``len(test_loader)``.

    Raises:
        ZeroDivisionError: if ``test_loader`` has no batches.
    """
    was_training = model_.training
    model_.eval()
    sum_loss = 0.
    try:
        with torch.no_grad():
            for x, y in test_loader:
                sum_loss += crit(model_(x), y)
    finally:
        # Restore the caller's mode instead of always forcing train mode.
        model_.train(was_training)
    return sum_loss / len(test_loader)
def save_checkpoint(model_, epoch, optm, checkpoint_path):
    """Write model state, optimizer state and epoch number to one file.

    Args:
        model_: module whose ``state_dict()`` is saved.
        epoch: epoch number stored under the ``'epoch'`` key.
        optm: optimizer whose ``state_dict()`` is saved.
        checkpoint_path: destination file; missing parent directories are
            created.
    """
    save_dict = {
        'model_state_dict': model_.state_dict(),
        'optimizer_state_dict': optm.state_dict(),
        'epoch': epoch,
    }
    # torch.save does not create directories, so make sure the target exists.
    parent_dir = os.path.dirname(checkpoint_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    torch.save(save_dict, checkpoint_path)
def train():
    """Train the banknote classifier, optionally resuming from a checkpoint.

    Command-line arguments:
        --c: path of a checkpoint written by ``save_checkpoint``. When given,
            model and optimizer state are restored and training continues
            from the stored epoch; otherwise training starts from scratch.

    Every ``HP.verbose_step`` steps the model is evaluated on the test set,
    and every ``HP.save_step`` steps a checkpoint is written under
    ``model_save/``.
    """
    parser = ArgumentParser(description="Model Training")
    parser.add_argument(
        '--c',
        default=None,
        type=str,
        help='train from scratch or resume training'
    )
    args = parser.parse_args()

    # data
    train_set = BanknoteDataset(HP.trainset_path)
    train_loader = DataLoader(
        train_set, batch_size=HP.batch_size, shuffle=True, drop_last=True
    )
    test_set = BanknoteDataset(HP.testset_path)
    # Evaluate on every test sample in a fixed order. Dropping the last
    # partial batch would discard samples and leave no batches at all when
    # the test set is smaller than one batch.
    test_loader = DataLoader(
        test_set, batch_size=HP.batch_size, shuffle=False, drop_last=False
    )

    # model -- moved to the same device the dataset puts its tensors on
    model = BanknoteClassificationModel().to(HP.device)

    # loss
    criterion = nn.CrossEntropyLoss()

    # optimizer
    optm = optim.Adam(model.parameters(), lr=HP.init_lr)

    start_epoch, step = 0, 0
    if args.c:
        # map_location lets a checkpoint saved on another device be loaded.
        checkpoint = torch.load(args.c, map_location=HP.device)
        model.load_state_dict(checkpoint["model_state_dict"])
        optm.load_state_dict(checkpoint["optimizer_state_dict"])
        start_epoch = checkpoint["epoch"]
        print("Resume from %s" % args.c)
    else:
        print("Training from scratch !")

    save_dir = "model_save"
    os.makedirs(save_dir, exist_ok=True)  # torch.save will not create it

    model.train()
    for epoch in range(start_epoch, HP.epochs):
        # len(train_loader) is already the number of batches per epoch.
        print("Start_Epoch:%d,Steps:%d" % (epoch, len(train_loader)))
        for x, y in train_loader:
            optm.zero_grad()
            pred = model(x)
            loss = criterion(pred, y)
            loss.backward()
            optm.step()

            # step starts at 0, so eval_loss is always set on the first batch.
            if not step % HP.verbose_step:
                eval_loss = evaluate(model, test_loader, criterion)
            if not step % HP.save_step:
                model_path = "model_%d_%d.pth" % (epoch, step)
                save_checkpoint(model, epoch, optm, os.path.join(save_dir, model_path))
            step += 1
            # eval_loss is the value from the most recent evaluation.
            print("Epoch:[%d/%d],step:%d,train_loss:%.5f,test_loss:%.5f" % (epoch, HP.epochs, step, loss.item(), eval_loss))
# Run training only when this file is executed as a script.
if __name__ == '__main__':
    train()
边栏推荐
- Time format record
- Cesium 点击获三维坐标(经纬度高程)
- Click cesium to obtain three-dimensional coordinates (longitude, latitude and elevation)
- Oauth2.0 - Introduction and use and explanation of authorization code mode
- Kubesphere - build Nacos cluster
- Solve the problem that Anaconda environment cannot be accessed in PowerShell
- SVN分支管理
- Oauth2.0 - user defined mode authorization - SMS verification code login
- Kubernetes notes (VII) Kubernetes scheduling
- BeanDefinitionRegistryPostProcessor
猜你喜欢
![[set theory] relational closure (relational closure solution | relational graph closure | relational matrix closure | closure operation and relational properties | closure compound operation)](/img/a4/00aca72b268f77fe4fb24ac06289f5.jpg)
[set theory] relational closure (relational closure solution | relational graph closure | relational matrix closure | closure operation and relational properties | closure compound operation)

Zhiniu stock project -- 05

智牛股--03

Method of converting GPS coordinates to Baidu map coordinates

Disruptor learning notes: basic use, core concepts and principles

Project summary --2 (basic use of jsup)

Oauth2.0 - using JWT to replace token and JWT content enhancement

YOLOV1学习笔记

JMeter linked database

Reinstalling the system displays "setup is applying system settings" stationary
随机推荐
Difference between shortest path and minimum spanning tree
Kubesphere - build Nacos cluster
Kubernetes notes (VI) kubernetes storage
Phpstudy setting items can be accessed by other computers on the LAN
Analysis of Clickhouse mergetree principle
Install VM tools
Scrapy learning
How to scan when Canon c3120l is a network shared printer
Kubesphere - Multi tenant management
[set theory] relational closure (relational closure solution | relational graph closure | relational matrix closure | closure operation and relational properties | closure compound operation)
BeanDefinitionRegistryPostProcessor
剖析虚幻渲染体系(16)- 图形驱动的秘密
致即将毕业大学生的一封信
POI dealing with Excel learning
Mysql database binlog log enable record
BeanDefinitionRegistryPostProcessor
Jackson: what if there is a lack of property- Jackson: What happens if a property is missing?
Click cesium to obtain three-dimensional coordinates (longitude, latitude and elevation)
Cesium entity (entities) entity deletion method
Selenium ide installation recording and local project maintenance