PyTorch exercise project: banknote classification
2022-07-03 06:23:00 【fksfdh】
config.py
class Hyperparameter:
    # ################################################
    # data
    # ################################################
    device = 'cpu'
    trainset_path = './data/train.txt'
    testset_path = './data/test.txt'
    seed = 1234

    # ################################################
    # model
    # ################################################
    in_features = 4
    out_dim = 2
    layer_list = [in_features, 64, 128, 64, out_dim]

    # ################################################
    # train
    # ################################################
    init_lr = 1e-3
    batch_size = 64
    epochs = 100
    verbose_step = 10
    save_step = 500


HP = Hyperparameter()
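The config points at ./data/train.txt and ./data/test.txt, but the post does not show how those files are produced. A minimal sketch, assuming the raw banknote data (four comma-separated features plus a 0/1 label per line, e.g. the UCI banknote-authentication CSV) is saved as ./data/banknote.txt:

split_data.py (hypothetical helper, not part of the original post)

import numpy as np

np.random.seed(1234)  # same seed as HP.seed for reproducibility
data = np.loadtxt('./data/banknote.txt', delimiter=',')  # assumed raw file path
np.random.shuffle(data)

split = int(0.8 * len(data))  # 80/20 train/test split (an assumption)
np.savetxt('./data/train.txt', data[:split], delimiter=',', fmt='%g')
np.savetxt('./data/test.txt', data[split:], delimiter=',', fmt='%g')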
dataset_banknote.py
from torch.utils.data import Dataset, DataLoader
import numpy as np
import torch
from config import HP


class BanknoteDataset(Dataset):
    def __init__(self, dataset_path):
        self.dataset = np.loadtxt(dataset_path, delimiter=',')

    def __getitem__(self, idx):
        item = self.dataset[idx]
        x, y = item[:HP.in_features], item[HP.in_features:]
        return torch.Tensor(x).float().to(HP.device), torch.Tensor(y).squeeze().long().to(HP.device)

    def __len__(self):
        return self.dataset.shape[0]
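A quick sanity check of the dataset and loader (an illustrative sketch, assuming ./data/train.txt already exists in the comma-separated format described above):

from torch.utils.data import DataLoader
from dataset_banknote import BanknoteDataset
from config import HP

train_set = BanknoteDataset(HP.trainset_path)
loader = DataLoader(train_set, batch_size=HP.batch_size, shuffle=True)

x, y = next(iter(loader))
print(x.shape, x.dtype)  # expected: torch.Size([64, 4]) torch.float32
print(y.shape, y.dtype)  # expected: torch.Size([64]) torch.int64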
model.py
from torch import nn
import torch.nn.functional as F
from config import HP
import torch


class BanknoteClassificationModel(nn.Module):
    def __init__(self):
        super(BanknoteClassificationModel, self).__init__()
        self.linear_layer = nn.ModuleList([
            nn.Linear(in_features=in_dim, out_features=out_dim)
            for in_dim, out_dim in zip(HP.layer_list[:-1], HP.layer_list[1:])
        ])

    def forward(self, input_x):
        # ReLU on the hidden layers only; the last layer returns raw logits
        # for CrossEntropyLoss.
        for layer in self.linear_layer[:-1]:
            input_x = F.relu(layer(input_x))
        return self.linear_layer[-1](input_x)


if __name__ == '__main__':
    data = torch.randn((64, 4))
    model = BanknoteClassificationModel()
    res = model(data)
    print(res.size())  # torch.Size([64, 2])
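The model outputs raw logits of shape (batch, 2). A short sketch of turning them into class predictions (not part of the original post):

import torch
import torch.nn.functional as F
from model import BanknoteClassificationModel

model = BanknoteClassificationModel()
x = torch.randn(8, 4)                  # dummy batch: 8 samples, 4 features each
with torch.no_grad():
    logits = model(x)                  # (8, 2) raw class scores
    probs = F.softmax(logits, dim=-1)  # per-class probabilities
    preds = logits.argmax(dim=-1)      # predicted class index per sample
print(preds)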
trainer.py
import random
import os
from argparse import ArgumentParser
import numpy as np
import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from config import HP
from model import BanknoteClassificationModel
from dataset_banknote import BanknoteDataset

torch.manual_seed(HP.seed)
torch.cuda.manual_seed(HP.seed)
random.seed(HP.seed)
np.random.seed(HP.seed)


def evaluate(model_, test_loader, crit):
    model_.eval()
    sum_loss = 0.
    with torch.no_grad():
        for batch in test_loader:
            x, y = batch
            pred = model_(x)
            loss = crit(pred, y)
            sum_loss += loss.item()
    model_.train()
    return sum_loss / len(test_loader)


def save_checkpoint(model_, epoch, optm, checkpoint_path):
    save_dict = {
        'model_state_dict': model_.state_dict(),
        'optimizer_state_dict': optm.state_dict(),
        'epoch': epoch
    }
    torch.save(save_dict, checkpoint_path)


def train():
    parser = ArgumentParser(description="Model Training")
    parser.add_argument(
        '--c',
        default=None,
        type=str,
        help='train from scratch or resume training'
    )
    args = parser.parse_args()

    # data
    train_set = BanknoteDataset(HP.trainset_path)
    train_loader = DataLoader(train_set, batch_size=HP.batch_size, shuffle=True, drop_last=True)
    test_set = BanknoteDataset(HP.testset_path)
    test_loader = DataLoader(test_set, batch_size=HP.batch_size, shuffle=True, drop_last=True)

    # model
    model = BanknoteClassificationModel()

    # loss
    criterion = nn.CrossEntropyLoss()

    # optimizer
    optm = optim.Adam(model.parameters(), lr=HP.init_lr)

    start_epoch, step = 0, 0
    if args.c:
        checkpoint = torch.load(args.c)
        model.load_state_dict(checkpoint["model_state_dict"])
        optm.load_state_dict(checkpoint["optimizer_state_dict"])
        start_epoch = checkpoint["epoch"]
        print("Resume from %s" % args.c)
    else:
        print("Training from scratch!")

    os.makedirs("model_save", exist_ok=True)  # ensure the checkpoint directory exists

    model.train()
    for epoch in range(start_epoch, HP.epochs):
        print("Start_Epoch:%d,Steps:%d" % (epoch, len(train_loader)))
        for batch in train_loader:
            x, y = batch
            optm.zero_grad()
            pred = model(x)
            loss = criterion(pred, y)
            loss.backward()
            optm.step()
            if not step % HP.verbose_step:
                eval_loss = evaluate(model, test_loader, criterion)
            if not step % HP.save_step:
                model_path = "model_%d_%d.pth" % (epoch, step)
                save_checkpoint(model, epoch, optm, os.path.join("model_save", model_path))
            step += 1
            print("Epoch:[%d/%d],step:%d,train_loss:%.5f,test_loss:%.5f" % (epoch, HP.epochs, step, loss.item(), eval_loss))


if __name__ == '__main__':
    train()
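Training starts from scratch with python trainer.py, and resumes from a checkpoint with python trainer.py --c model_save/model_0_0.pth (the path is whatever save_checkpoint produced). A minimal inference sketch for a saved checkpoint, assuming one exists under model_save/ (not part of the original post):

inference.py (hypothetical helper)

import torch
from config import HP
from model import BanknoteClassificationModel

checkpoint_path = 'model_save/model_0_0.pth'  # assumed path written by save_checkpoint
checkpoint = torch.load(checkpoint_path, map_location=HP.device)

model = BanknoteClassificationModel()
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# one feature vector (assumed UCI banknote-authentication format:
# variance, skewness, kurtosis, entropy)
sample = torch.tensor([[3.6216, 8.6661, -2.8073, -0.44699]], dtype=torch.float32)
with torch.no_grad():
    pred = model(sample).argmax(dim=-1)
print('predicted class:', pred.item())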