当前位置:网站首页>PyTorch practice project
PyTorch practice project
2022-07-03 06:23:00 【fksfdh】
config.py
class Hyperparameter:
    """Namespace of the settings shared by the dataset, model and trainer.

    Every value is a plain class attribute, so it can be read either from the
    class or from the module-level ``HP`` instance created below.
    """

    # ################################################
    # data
    # ################################################
    device = 'cpu'  # device string the dataset passes to ``Tensor.to``
    trainset_path = './data/train.txt'
    testset_path = './data/test.txt'
    seed = 1234  # seed handed to the torch / random / numpy generators

    # ################################################
    # model
    # ################################################
    in_features = 4  # number of leading columns of a data row used as features
    out_dim = 2
    # Layer widths; each consecutive pair becomes one ``nn.Linear`` in the model.
    layer_list = [in_features, 64, 128, 64, out_dim]

    # ################################################
    # train
    # ################################################
    init_lr = 1e-3
    batch_size = 64
    epochs = 100
    verbose_step = 10  # run the evaluation every this many optimisation steps
    save_step = 500  # write a checkpoint every this many optimisation steps


HP = Hyperparameter()
dataset_banknote.py
from torch.utils.data import Dataset,DataLoader
import numpy as np
import torch
from config import HP
class BanknoteDataset(Dataset):
    """Map-style dataset over a comma-separated numeric text file.

    Each row is split into its first ``in_features`` columns (the features)
    and the remaining column(s) (the label).
    """

    def __init__(self, dataset_path, in_features=None, device=None):
        """Load the whole file into memory.

        Args:
            dataset_path: Anything ``numpy.loadtxt`` accepts as its source
                (a file name, an open file, or a sequence of text lines).
            in_features: Number of leading feature columns. ``None`` (the
                default) means "use ``HP.in_features`` at access time".
            device: Device the returned tensors are moved to. ``None`` (the
                default) means "use ``HP.device`` at access time".
        """
        # ndmin=2 keeps a single-row file two-dimensional; without it loadtxt
        # returns a 1-D array and __len__ would report the column count.
        self.dataset = np.loadtxt(dataset_path, delimiter=',', ndmin=2)
        self.in_features = in_features
        self.device = device

    def __getitem__(self, idx):
        """Return ``(features, label)`` for row ``idx``.

        ``features`` is a float32 tensor; ``label`` is the remaining column(s)
        squeezed and cast to int64. Both are moved to the configured device.
        """
        n_features = HP.in_features if self.in_features is None else self.in_features
        device = HP.device if self.device is None else self.device

        row = self.dataset[idx]
        features = torch.as_tensor(row[:n_features], dtype=torch.float32)
        label = torch.as_tensor(row[n_features:]).squeeze().long()
        return features.to(device), label.to(device)

    def __len__(self):
        """Return the number of rows loaded from the file."""
        return self.dataset.shape[0]
model.py
from torch import nn
import torch.nn.functional as F
from config import HP
import torch
class BanknoteClassificationModel(nn.Module):
    """Fully connected classifier built from a list of layer widths."""

    def __init__(self, layer_list=None):
        """Create one ``nn.Linear`` per consecutive pair of widths.

        Args:
            layer_list: Sequence of layer widths, input width first and
                number of classes last. Defaults to ``HP.layer_list``.
        """
        super().__init__()
        widths = HP.layer_list if layer_list is None else list(layer_list)
        # The attribute name is kept so existing state_dict keys still match.
        self.linear_layer = nn.ModuleList([
            nn.Linear(in_features=in_dim, out_features=out_dim)
            for in_dim, out_dim in zip(widths[:-1], widths[1:])
        ])

    def forward(self, input_x):
        """Return the raw class scores (logits) for ``input_x``.

        ReLU is applied after every layer except the last one: the output is
        meant for ``nn.CrossEntropyLoss``, which expects unnormalised logits,
        and a trailing ReLU would clamp every negative score to zero.
        """
        last_index = len(self.linear_layer) - 1
        for index, layer in enumerate(self.linear_layer):
            input_x = layer(input_x)
            if index != last_index:
                input_x = F.relu(input_x)
        return input_x
if __name__ == '__main__':
    # Smoke test: push one random batch through the model and print the
    # output size. The batch shape follows the configuration so the check
    # keeps working when HP.in_features changes.
    data = torch.randn((HP.batch_size, HP.in_features))
    model = BanknoteClassificationModel()
    res = model(data)
    print(res.size())
trainer.py
import random
import os
from argparse import ArgumentParser
import numpy as np
import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from config import HP
from model import BanknoteClassificationModel
from dataset_banknote import BanknoteDataset
# Seed every random number generator the training script touches with the
# configured seed, in the same order as before: torch (CPU), torch (CUDA),
# Python's ``random`` and NumPy.
for _seed_rng in (torch.manual_seed, torch.cuda.manual_seed, random.seed, np.random.seed):
    _seed_rng(HP.seed)
del _seed_rng  # do not leave the loop variable in the module namespace
def evaluate(model_, test_loader, crit):
    """Return the mean per-batch loss of ``model_`` over ``test_loader``.

    Args:
        model_: Module to evaluate. It is switched to eval mode for the
            duration of the call and put back into whatever mode it was in.
        test_loader: Iterable yielding ``(x, y)`` batches.
        crit: Callable ``crit(pred, y)`` returning a scalar loss tensor.

    Returns:
        The average of the batch losses as a Python float, or ``nan`` when
        the loader yields no batches (instead of dividing by zero).
    """
    was_training = model_.training
    model_.eval()
    sum_loss = 0.
    num_batches = 0
    try:
        with torch.no_grad():
            for x, y in test_loader:
                # .item() keeps the running sum a plain float, not a tensor.
                sum_loss += crit(model_(x), y).item()
                num_batches += 1
    finally:
        # Restore the caller's mode even if a batch raised.
        model_.train(was_training)
    if num_batches == 0:
        return float('nan')
    return sum_loss / num_batches
def save_checkpoint(model_, epoch, optm, checkpoint_path):
    """Write the model state, optimizer state and epoch to ``checkpoint_path``.

    Args:
        model_: Module whose ``state_dict()`` is stored under
            ``'model_state_dict'``.
        epoch: Value stored under ``'epoch'``.
        optm: Optimizer whose ``state_dict()`` is stored under
            ``'optimizer_state_dict'``.
        checkpoint_path: Destination handed to ``torch.save``. When it is a
            filesystem path, missing parent directories are created first.
    """
    # Only path-like targets have a directory to create; anything else
    # (e.g. an open file object) is passed to torch.save untouched.
    if isinstance(checkpoint_path, (str, os.PathLike)):
        parent_dir = os.path.dirname(os.fspath(checkpoint_path))
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

    save_dict = {
        'model_state_dict': model_.state_dict(),
        'optimizer_state_dict': optm.state_dict(),
        'epoch': epoch,
    }
    torch.save(save_dict, checkpoint_path)
def train():
    """Train the banknote classifier, optionally resuming from a checkpoint.

    Command line:
        --c PATH  Checkpoint to resume from. When omitted, training starts
                  from scratch.

    Every ``HP.verbose_step`` steps the model is evaluated on the test set,
    every ``HP.save_step`` steps a checkpoint is written to ``model_save/``,
    and one progress line is printed per step.
    """
    parser = ArgumentParser(description="Model Training")
    parser.add_argument(
        '--c',
        default=None,
        type=str,
        help='train from scratch or resume training'
    )
    args = parser.parse_args()

    # data
    train_set = BanknoteDataset(HP.trainset_path)
    train_loader = DataLoader(train_set, batch_size=HP.batch_size, shuffle=True, drop_last=True)
    test_set = BanknoteDataset(HP.testset_path)
    # Evaluate on every test sample in a fixed order: shuffling is pointless
    # here, and drop_last would discard samples (or leave the loader empty
    # when the test set is smaller than one batch).
    test_loader = DataLoader(test_set, batch_size=HP.batch_size, shuffle=False, drop_last=False)

    # model -- moved to the same device the dataset puts its tensors on
    model = BanknoteClassificationModel().to(HP.device)

    # loss
    criterion = nn.CrossEntropyLoss()

    # optimizer (created after the model has been moved to its device)
    optm = optim.Adam(model.parameters(), lr=HP.init_lr)

    start_epoch, step = 0, 0
    if args.c:
        # NOTE(review): torch.load unpickles the file; only resume from
        # checkpoints you trust.
        checkpoint = torch.load(args.c, map_location=HP.device)
        model.load_state_dict(checkpoint["model_state_dict"])
        optm.load_state_dict(checkpoint["optimizer_state_dict"])
        start_epoch = checkpoint["epoch"]
        print("Resume from %s" % args.c)
    else:
        print("Training from scratch !")

    # The first checkpoint is written at step 0, so the directory must exist.
    save_dir = "model_save"
    os.makedirs(save_dir, exist_ok=True)

    # Defined up front so the per-step print never hits an unbound name.
    eval_loss = float('nan')

    model.train()
    for epoch in range(start_epoch, HP.epochs):
        # len(train_loader) is already the number of batches per epoch.
        print("Start_Epoch:%d,Steps:%d" % (epoch, len(train_loader)))
        for x, y in train_loader:
            optm.zero_grad()
            pred = model(x)
            loss = criterion(pred, y)
            loss.backward()
            optm.step()

            if not step % HP.verbose_step:
                eval_loss = evaluate(model, test_loader, criterion)
            if not step % HP.save_step:
                model_path = "model_%d_%d.pth" % (epoch, step)
                save_checkpoint(model, epoch, optm, os.path.join(save_dir, model_path))
            step += 1
            print("Epoch:[%d/%d],step:%d,train_loss:%.5f,test_loss:%.5f" % (epoch, HP.epochs, step, loss.item(), eval_loss))
if __name__ == '__main__':
    # Script entry point: start (or resume) training.
    train()
边栏推荐
- Print time Hahahahahaha
- Click cesium to obtain three-dimensional coordinates (longitude, latitude and elevation)
- 远端rostopic的本地rviz调用及显示
- Nacos service installation
- Kubesphere - Multi tenant management
- Virtual memory technology sharing
- The mechanical hard disk is connected to the computer through USB and cannot be displayed
- Request weather interface format, automation
- Use @data in Lombok to simplify entity class code
- SVN分支管理
猜你喜欢

Push box games C #

Fluentd is easy to use. Combined with the rainbow plug-in market, log collection is faster

. Net program configuration file operation (INI, CFG, config)

Use selenium to climb the annual box office of Yien
![[set theory] relational closure (relational closure solution | relational graph closure | relational matrix closure | closure operation and relational properties | closure compound operation)](/img/a4/00aca72b268f77fe4fb24ac06289f5.jpg)
[set theory] relational closure (relational closure solution | relational graph closure | relational matrix closure | closure operation and relational properties | closure compound operation)

Selenium ide installation recording and local project maintenance

YOLOV3学习笔记

Reinstalling the system displays "setup is applying system settings" stationary

Kubernetes notes (IX) kubernetes application encapsulation and expansion

项目总结--01(接口的增删改查;多线程的使用)
随机推荐
冒泡排序的简单理解
Mysql database
Important knowledge points of redis
“我为开源打榜狂”第一周榜单公布,160位开发者上榜
数值法求解最优控制问题(一)——梯度法
JMeter performance automation test
Exportation et importation de tables de bibliothèque avec binaires MySQL
项目总结--2(Jsoup的基本使用)
Cesium Click to obtain the longitude and latitude elevation coordinates (3D coordinates) of the model surface
IE browser flash back, automatically open edge browser
Kubernetes notes (VII) kuberetes scheduling
Shell conditional statement
Use selenium to climb the annual box office of Yien
PHP用ENV获取文件参数的时候拿到的是字符串
arcgis创建postgre企业级数据库
ThreadLocal的简单理解
Request weather interface format, automation
Solve the problem that Anaconda environment cannot be accessed in PowerShell
Print time Hahahahahaha
【C#/VB.NET】 将PDF转为SVG/Image, SVG/Image转PDF