C3D Model PyTorch Source Code Explained Line by Line (Part 3)
2022-07-25 09:26:00 【zzh1370894823】
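3.1 Source code analysis
train.py explained
This script is the training part of the C3D model. It has two main parts: the preparation before training, and the training itself.
1. Preparation before training
1.1 Parameter settings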
import os
nEpochs = 101  # Number of epochs for training
resume_epoch = 0  # Default is 0 (train from scratch); set it to the epoch stored in a checkpoint to resume from there
useTest = True  # See evolution of the test set when training
nTestInterval = 20  # Run on test set every nTestInterval epochs
snapshot = 25  # Store a model every snapshot epochs
lr = 1e-5  # Learning rate
save_dir_root = os.path.dirname(os.path.abspath(__file__))  # folder containing train.py, e.g. '...\\C3D'
exp_name = os.path.basename(save_dir_root)  # last folder name, e.g. 'C3D' (unlike split('/'), basename also works for Windows paths)
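These lines set the training parameters. The training code below refers to nEpochs, snapshot and nTestInterval as num_epochs, save_epoch and test_interval respectively.
os.path.abspath(__file__) returns the absolute path of the running script, and os.path.dirname() strips the file name from it, leaving the folder that contains the script.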
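To see what these path calls return, here is a small sketch using the standard library's posixpath and ntpath modules (the POSIX and Windows implementations behind os.path), so the result does not depend on the machine it runs on. The two script paths are made up. It also shows why taking the folder name by splitting on '/' fails for Windows paths, while basename works for both.

```python
import ntpath
import posixpath

# Hypothetical script locations on Linux and on Windows
linux_script = '/home/user/C3D/train.py'
windows_script = 'C:\\Users\\user\\C3D\\train.py'

# dirname strips the file name, leaving the folder that contains the script
linux_dir = posixpath.dirname(linux_script)
windows_dir = ntpath.dirname(windows_script)

# Splitting on '/' only works for POSIX paths; a Windows path has no '/' to split on
linux_name_split = linux_dir.split('/')[-1]
windows_name_split = windows_dir.split('/')[-1]   # the whole path comes back unchanged

# basename is the portable way to get the last folder name
linux_name = posixpath.basename(linux_dir)
windows_name = ntpath.basename(windows_dir)

print(linux_dir, linux_name_split, linux_name)
print(windows_dir, windows_name_split, windows_name)
```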
1.2 Loading the model and the datasets
model = C3D_model.C3D(num_classes=num_classes, pretrained=False)
train_params = [{'params': C3D_model.get_1x_lr_params(model), 'lr': lr},
                {'params': C3D_model.get_10x_lr_params(model), 'lr': lr * 10}]
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(train_params, lr=lr, momentum=0.9, weight_decay=5e-4)  # optimizer: SGD with momentum
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)  # multiply the lr by 0.1 every 10 epochs
# Load the datasets
train_dataloader = DataLoader(VideoDataset(dataset=dataset, split='train', clip_len=16), batch_size=2, shuffle=True, num_workers=0)
val_dataloader = DataLoader(VideoDataset(dataset=dataset, split='val', clip_len=16), batch_size=2, num_workers=0)
test_dataloader = DataLoader(VideoDataset(dataset=dataset, split='test', clip_len=16), batch_size=2, num_workers=0)
trainval_loaders = {'train': train_dataloader, 'val': val_dataloader}  # group the train and val loaders in a dict
trainval_sizes = {x: len(trainval_loaders[x].dataset) for x in ['train', 'val']}
test_size = len(test_dataloader.dataset)
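train_params is a list of two dicts, each with two keys: 'params' (a group of model parameters) and 'lr' (the learning rate of that group). The parameters returned by get_10x_lr_params are trained with a learning rate 10 times larger than those returned by get_1x_lr_params.
scheduler: StepLR multiplies the learning rate of every group by 0.1 every 10 epochs.
The train and val loaders are put in a dict so that the training loop can pick the right one by phase name.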
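The two parameter groups and the StepLR schedule can be checked on a toy model. This is a minimal sketch assuming PyTorch is installed: the two nn.Linear layers are hypothetical stand-ins for the C3D layers returned by get_1x_lr_params and get_10x_lr_params, and no real training happens. It calls scheduler.step() after optimizer.step(), the order PyTorch expects since version 1.1.

```python
import torch.nn as nn
import torch.optim as optim

# Hypothetical stand-ins: 'backbone' plays the role of the 1x-lr layers, 'head' of the 10x-lr layers
backbone = nn.Linear(8, 4)
head = nn.Linear(4, 3)

lr = 1e-5
train_params = [
    {'params': backbone.parameters(), 'lr': lr},        # group 0: base learning rate
    {'params': head.parameters(), 'lr': lr * 10},       # group 1: 10x learning rate
]
optimizer = optim.SGD(train_params, lr=lr, momentum=0.9, weight_decay=5e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

lr_history = []  # learning rate of each group, recorded at the start of every epoch
for epoch in range(25):
    lr_history.append([group['lr'] for group in optimizer.param_groups])
    # ... the forward pass and loss.backward() of a real epoch would go here ...
    optimizer.step()   # no gradients exist in this sketch, so the weights stay unchanged
    scheduler.step()   # end of epoch: every step_size epochs, each group's lr is multiplied by gamma

for epoch in (0, 9, 10, 20):
    print(epoch, lr_history[epoch])
```

Both groups decay together, so the 10x ratio between them is kept throughout training.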
2. Training
for epoch in range(resume_epoch, num_epochs):
    for phase in ['train', 'val']:
        start_time = timeit.default_timer()
        # reset the running loss and the number of correct predictions
        running_loss = 0.0
        running_corrects = 0.0
        if phase == 'train':
            scheduler.step()  # update the learning rate (training phase only)
            model.train()
        else:
            model.eval()
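Each epoch consists of a train phase and a val phase.
start_time records when the phase starts.
scheduler.step() updates the learning rate and is called only in the training phase. Since PyTorch 1.1, lr_scheduler.step() is expected to be called after optimizer.step(); calling it at the start of the epoch, as here, triggers a warning and skips the first value of the learning-rate schedule, so it is better moved to the end of the training phase.
model.train() and model.eval() switch layers such as dropout between their training and evaluation behaviour.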
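model.eval() and torch.no_grad() are easy to confuse, so here is a small sketch (assuming PyTorch) on a hypothetical toy model with a dropout layer. model.eval() changes how layers such as dropout behave but does not turn off autograd; torch.no_grad() stops the computation graph from being built but does not change layer behaviour. The validation code below uses both.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical toy model: a linear layer followed by dropout
model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
x = torch.ones(3, 4)

model.train()                    # dropout active: elements are zeroed at random, the rest scaled by 1/(1-p) = 2
out_train = model(x)

model.eval()                     # dropout disabled: the output is deterministic
out_eval_a = model(x)
out_eval_b = model(x)

with torch.no_grad():            # no computation graph is recorded inside this block
    out_no_grad = model(x)

print(torch.equal(out_eval_a, out_eval_b))   # eval mode gives identical outputs
print(out_eval_a.requires_grad)              # eval() alone still tracks gradients
print(out_no_grad.requires_grad)             # no_grad() does not
```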
Feeding the inputs to the model
for inputs, labels in tqdm(trainval_loaders[phase]):
    # Variable is deprecated since PyTorch 0.4, and the inputs need no gradient: move the tensors to the device directly
    inputs = inputs.to(device)
    labels = labels.to(device)
    optimizer.zero_grad()
    if phase == 'train':
        outputs = model(inputs)
    else:
        with torch.no_grad():
            outputs = model(inputs)
    probs = nn.Softmax(dim=1)(outputs)
    preds = torch.max(probs, 1)[1]
    loss = criterion(outputs, labels.long())  # compute the loss
    if phase == 'train':
        loss.backward()
        optimizer.step()  # update the parameters (training phase only)
    running_loss += loss.item() * inputs.size(0)  # the loss is a batch mean, so multiply by the batch size
    running_corrects += torch.sum(preds == labels.data)  # number of correct predictions
# compute the loss and accuracy of the whole epoch
epoch_loss = running_loss / trainval_sizes[phase]
epoch_acc = running_corrects.double() / trainval_sizes[phase]
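tqdm is a fast, extensible Python progress bar that adds progress information to long-running loops.
with torch.no_grad(): disables gradient computation during validation, which reduces the memory the forward pass needs.
probs has torch.Size([2, 101]): the batch size here is 2 and there are 101 action classes, so each row holds the predicted probability of every class for one clip.
preds = torch.max(probs, 1)[1] finds the largest probability in each row and returns its index, which is the predicted label.
For example, preds = tensor([4, 32]) means the predicted labels are 4 and 32.
Finally, the loss and accuracy of each epoch are computed. Because loss.item() is the mean over one batch, multiplying it by inputs.size(0) and dividing the sum by the dataset size gives the mean over all samples, even when the last batch is smaller.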
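The prediction and averaging steps can be checked on made-up numbers. This is a minimal sketch assuming PyTorch; the logits, labels and the 3-sample TensorDataset are invented and stand in for C3D outputs and VideoDataset. It shows that torch.max(probs, 1)[1] picks the per-row argmax, and that summing batch-mean losses weighted by batch size and dividing by len(loader.dataset) matches the loss computed over all samples at once.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Made-up logits for a batch of 2 clips and 5 classes (the real model has 101 classes)
outputs = torch.tensor([[0.1, 0.2, 0.0, 0.3, 2.0],
                        [0.5, 0.1, 3.0, 0.2, 0.0]])
labels = torch.tensor([4, 1])

probs = nn.Softmax(dim=1)(outputs)        # shape (2, 5); each row sums to 1
preds = torch.max(probs, 1)[1]            # index of the largest probability in each row
corrects = torch.sum(preds == labels)     # number of correct predictions in the batch
# softmax keeps the order within a row, so the argmax of the raw logits gives the same labels
same_as_argmax = torch.equal(preds, outputs.argmax(dim=1))

# Epoch averaging: 3 samples with batch_size=2 give batches of 2 and 1
torch.manual_seed(0)
features = torch.randn(3, 5)
targets = torch.tensor([0, 3, 4])
loader = DataLoader(TensorDataset(features, targets), batch_size=2)
criterion = nn.CrossEntropyLoss()         # default reduction='mean': averaged over the batch

running_loss = 0.0
for x, y in loader:
    loss = criterion(x, y)                # in this sketch the features themselves serve as logits
    running_loss += loss.item() * x.size(0)       # undo the batch mean
epoch_loss = running_loss / len(loader.dataset)   # divide by samples (3), not batches (2)
full_loss = criterion(features, targets).item()   # mean over all 3 samples at once

print(preds, corrects.item(), same_as_argmax)
print(len(loader), len(loader.dataset), epoch_loss, full_loss)
```

Dividing by len(loader) (the number of batches) instead would give the wrong average whenever the last batch is smaller, which is why train.py uses the dataset sizes.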
Writing to TensorBoard
if phase == 'train':
    writer.add_scalar('data/train_loss_epoch', epoch_loss, epoch)
    writer.add_scalar('data/train_acc_epoch', epoch_acc, epoch)
else:
    writer.add_scalar('data/val_loss_epoch', epoch_loss, epoch)
    writer.add_scalar('data/val_acc_epoch', epoch_acc, epoch)
print("[{}] Epoch: {}/{} Loss: {} Acc: {}".format(phase, epoch+1, nEpochs, epoch_loss, epoch_acc))
stop_time = timeit.default_timer()  # record the end time
print("Execution time: " + str(stop_time - start_time) + "\n")
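The two branches differ only in the phase name inside the tag, so the tags could also be built from phase directly. This is a minimal sketch assuming the tensorboard package is installed, so that torch.utils.tensorboard.SummaryWriter is available (the original project may import a SummaryWriter from elsewhere, for example tensorboardX). The loss and accuracy values are placeholders, and the log folder is a throwaway temporary directory.

```python
import os
import tempfile
from torch.utils.tensorboard import SummaryWriter

log_dir = tempfile.mkdtemp()            # throwaway directory for this sketch
writer = SummaryWriter(log_dir=log_dir)

# Placeholder (loss, accuracy) values per epoch for the two phases
fake_metrics = {
    'train': [(1.20, 0.40), (0.90, 0.55)],
    'val': [(1.30, 0.35), (1.00, 0.50)],
}
for epoch in range(2):
    for phase in ['train', 'val']:
        epoch_loss, epoch_acc = fake_metrics[phase][epoch]
        # produces the same tags as train.py, e.g. 'data/train_loss_epoch' and 'data/val_acc_epoch'
        writer.add_scalar('data/{}_loss_epoch'.format(phase), epoch_loss, epoch)
        writer.add_scalar('data/{}_acc_epoch'.format(phase), epoch_acc, epoch)
writer.close()                          # flush the event file to disk

event_files = [f for f in os.listdir(log_dir) if f.startswith('events.out.tfevents')]
print(event_files)
```

The curves can then be viewed by running `tensorboard --logdir` on the log folder.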
Saving checkpoints (model and optimizer state)
if epoch % save_epoch == (save_epoch - 1):
    torch.save({
        'epoch': epoch + 1,
        'state_dict': model.state_dict(),
        'opt_dict': optimizer.state_dict(),
    }, os.path.join(save_dir, 'models', saveName + '_epoch-' + str(epoch) + '.pth.tar'))
    print("Save model at {}\n".format(os.path.join(save_dir, 'models', saveName + '_epoch-' + str(epoch) + '.pth.tar')))
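Which epochs trigger a save, and how such a checkpoint could be loaded again to resume training, can be sketched as follows. This assumes PyTorch; the tiny nn.Linear model, the temporary directory and the file prefix 'C3D_demo' are hypothetical stand-ins for C3D, save_dir and saveName. Note that the file name uses the 0-based epoch while the dict stores epoch + 1, the number of finished epochs, which is the value a resumed run would pass as resume_epoch since the loop runs over range(resume_epoch, num_epochs).

```python
import os
import tempfile
import torch
import torch.nn as nn
import torch.optim as optim

# 0-based epochs that satisfy the save condition for nEpochs = 101 and snapshot (save_epoch) = 25
save_epoch = 25
saved_epochs = [e for e in range(101) if e % save_epoch == save_epoch - 1]

# Hypothetical stand-ins for C3D, save_dir and saveName
model = nn.Linear(4, 3)
optimizer = optim.SGD(model.parameters(), lr=1e-5, momentum=0.9, weight_decay=5e-4)
save_dir = tempfile.mkdtemp()
epoch = saved_epochs[0]                  # the 25th epoch
path = os.path.join(save_dir, 'C3D_demo_epoch-' + str(epoch) + '.pth.tar')

torch.save({
    'epoch': epoch + 1,                  # number of finished epochs
    'state_dict': model.state_dict(),
    'opt_dict': optimizer.state_dict(),
}, path)

# Resuming: rebuild the model and optimizer, then restore their saved states
new_model = nn.Linear(4, 3)
new_optimizer = optim.SGD(new_model.parameters(), lr=1e-5, momentum=0.9, weight_decay=5e-4)
checkpoint = torch.load(path, map_location='cpu')
new_model.load_state_dict(checkpoint['state_dict'])
new_optimizer.load_state_dict(checkpoint['opt_dict'])
resume_epoch = checkpoint['epoch']       # the loop would continue with range(resume_epoch, num_epochs)

print(saved_epochs, resume_epoch)
```

Saving the optimizer state as well matters for SGD with momentum: without it, a resumed run would restart with empty momentum buffers.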
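Evaluating on the test set
This works like validation: no gradients are computed and no parameters are updated. It runs only when useTest is True, once every test_interval epochs.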
if useTest and epoch % test_interval == (test_interval - 1):
    model.eval()
    start_time = timeit.default_timer()
    running_loss = 0.0
    running_corrects = 0.0
    for inputs, labels in tqdm(test_dataloader):
        inputs = inputs.to(device)
        labels = labels.to(device)
        with torch.no_grad():
            outputs = model(inputs)
        probs = nn.Softmax(dim=1)(outputs)
        preds = torch.max(probs, 1)[1]
        loss = criterion(outputs, labels.long())
        running_loss += loss.item() * inputs.size(0)
        running_corrects += torch.sum(preds == labels.data)
    epoch_loss = running_loss / test_size
    epoch_acc = running_corrects.double() / test_size
    writer.add_scalar('data/test_loss_epoch', epoch_loss, epoch)
    writer.add_scalar('data/test_acc_epoch', epoch_acc, epoch)
    print("[test] Epoch: {}/{} Loss: {} Acc: {}".format(epoch+1, nEpochs, epoch_loss, epoch_acc))
    stop_time = timeit.default_timer()
    print("Execution time: " + str(stop_time - start_time) + "\n")
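The test loop repeats the validation logic almost line for line, so it could be factored into one helper. This is only a sketch of such a refactoring, not how train.py is organised: the function name evaluate, the toy nn.Linear model and the random TensorDataset are hypothetical, and PyTorch is assumed. It uses outputs.argmax(dim=1) directly, which gives the same labels as taking the max of the softmax probabilities.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def evaluate(model, loader, criterion, device):
    """Hypothetical helper: mean loss and accuracy of model over every sample in loader."""
    model.eval()                                   # eval behaviour for layers such as dropout
    running_loss, running_corrects = 0.0, 0
    with torch.no_grad():                          # no gradients, no parameter updates
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels.long())
            preds = outputs.argmax(dim=1)          # same labels as torch.max(softmax(outputs), 1)[1]
            running_loss += loss.item() * inputs.size(0)
            running_corrects += (preds == labels).sum().item()
    size = len(loader.dataset)
    return running_loss / size, running_corrects / size


torch.manual_seed(0)
device = torch.device('cpu')
features = torch.randn(10, 4)
targets = torch.randint(0, 3, (10,))
toy_model = nn.Linear(4, 3)
criterion = nn.CrossEntropyLoss()
loader = DataLoader(TensorDataset(features, targets), batch_size=3)   # batches of 3, 3, 3 and 1

test_loss, test_acc = evaluate(toy_model, loader, criterion, device)
print("[test] Loss: {} Acc: {}".format(test_loss, test_acc))
```

With test_interval = 20, the condition epoch % test_interval == test_interval - 1 holds for the 0-based epochs 19, 39, 59, 79 and 99, i.e. after the 20th, 40th, 60th, 80th and 100th epochs.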
These are purely my personal notes and summary. Mistakes are inevitable, so corrections are welcome and greatly appreciated.