C3D model PyTorch source code line-by-line analysis (III)
2022-07-25 10:46:00 【zzh1370894823】
3.1 Source code analysis
train.py explained
This file is the training part of the C3D model. It falls into two stages: preparation before training, and the training loop itself.
1. Preparation before training
1.1 Parameter settings
nEpochs = 101  # number of epochs to train for
resume_epoch = 0  # default 0; set it to a saved epoch to resume training from that checkpoint
useTest = True  # evaluate on the test set while training
nTestInterval = 20  # run on the test set every nTestInterval epochs
snapshot = 25  # save a checkpoint every snapshot epochs
lr = 1e-5  # learning rate
save_dir_root = os.path.join(os.path.dirname(os.path.abspath(__file__)))  # save_dir_root = '...\\C3D'
exp_name = os.path.dirname(os.path.abspath(__file__)).split('/')[-1]  # exp_name = '...\\C3D'
This section sets the training hyperparameters.
os.path.dirname(os.path.abspath(__file__)) returns the directory containing the currently running script. Note that .split('/')[-1] only isolates the folder name on POSIX paths; on Windows the separator is '\\', so exp_name ends up holding the whole path, which is why both comments above show '...\\C3D'.
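A more portable way to get just the folder name, offered here as a suggested alternative (not in the original script), is os.path.basename, which uses the host OS's own separator:

exp_name = os.path.basename(os.path.dirname(os.path.abspath(__file__)))  # e.g. 'C3D' on any OS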
1.2 Loading the model and datasets
model = C3D_model.C3D(num_classes=num_classes, pretrained=False)
train_params = [{'params': C3D_model.get_1x_lr_params(model), 'lr': lr},
                {'params': C3D_model.get_10x_lr_params(model), 'lr': lr * 10}]
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(train_params, lr=lr, momentum=0.9, weight_decay=5e-4)  # SGD with momentum and weight decay
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
# Load data set
train_dataloader = DataLoader(VideoDataset(dataset=dataset, split='train',clip_len=16), batch_size=2, shuffle=True, num_workers=0)
val_dataloader = DataLoader(VideoDataset(dataset=dataset, split='val', clip_len=16), batch_size=2, num_workers=0)
test_dataloader = DataLoader(VideoDataset(dataset=dataset, split='test', clip_len=16), batch_size=2, num_workers=0)
trainval_loaders = {'train': train_dataloader, 'val': val_dataloader}  # put the train and val loaders into one dict
trainval_sizes = {x: len(trainval_loaders[x].dataset) for x in ['train', 'val']}
test_size = len(test_dataloader.dataset)
train_params is a two-element list; each element is a dict with two keys, 'params' and 'lr'. Passing parameter groups like this lets the optimizer apply a different learning rate to each part of the network.
The scheduler decays the learning rate by a factor of 0.1 every 10 epochs.
Putting the train and val loaders into one dict makes it easy to loop over both phases inside each epoch; each sample from VideoDataset is a 16-frame clip (clip_len=16).
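The two groups exist because the backbone can start from pretrained C3D weights while the final classification layer is trained from scratch, so it gets a 10x higher learning rate. A sketch of what the two helper generators in C3D_model look like (the layer names follow the usual C3D definition; treat this as illustrative, not a verbatim copy):

def get_1x_lr_params(model):
    # Backbone layers (conv blocks plus the first two fc layers): base learning rate.
    layers = [model.conv1, model.conv2, model.conv3a, model.conv3b, model.conv4a,
              model.conv4b, model.conv5a, model.conv5b, model.fc6, model.fc7]
    for layer in layers:
        for param in layer.parameters():
            if param.requires_grad:
                yield param

def get_10x_lr_params(model):
    # The final classifier fc8 is newly initialized, so it trains with 10x the base lr.
    for param in model.fc8.parameters():
        if param.requires_grad:
            yield param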
2. Training part
for epoch in range(resume_epoch, num_epochs):
    for phase in ['train', 'val']:
        start_time = timeit.default_timer()
        # Reset the running loss and accuracy
        running_loss = 0.0
        running_corrects = 0.0
        if phase == 'train':
            scheduler.step()  # update the learning rate (train phase only)
            model.train()
        else:
            model.eval()
Every epoch is split into a train phase and a val phase.
start_time records when the phase begins.
scheduler.step() updates the learning rate and is only called in the train phase. (Note: since PyTorch 1.1 the recommended order is optimizer.step() before scheduler.step(); the original script keeps the older convention of stepping the scheduler at the start of the epoch.)
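A tiny standalone illustration (not part of train.py) of what StepLR(step_size=10, gamma=0.1) does to the learning rate:

import torch

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.SGD([param], lr=1e-5)
sched = torch.optim.lr_scheduler.StepLR(opt, step_size=10, gamma=0.1)
for epoch in range(25):
    if epoch in (0, 9, 10, 19, 20):
        print(epoch, opt.param_groups[0]['lr'])  # lr used during this epoch
    opt.step()    # PyTorch >= 1.1 expects optimizer.step() before scheduler.step()
    sched.step()
# prints roughly: 0 1e-05, 9 1e-05, 10 1e-06, 19 1e-06, 20 1e-07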
Feeding the inputs through the model:
for inputs, labels in tqdm(trainval_loaders[phase]):
    # Variable is deprecated since PyTorch 0.4; plain tensors carry gradients now
    inputs = Variable(inputs, requires_grad=True).to(device)
    labels = Variable(labels).to(device)
    optimizer.zero_grad()
    if phase == 'train':
        outputs = model(inputs)
    else:
        with torch.no_grad():  # no gradients needed for validation
            outputs = model(inputs)
    probs = nn.Softmax(dim=1)(outputs)
    preds = torch.max(probs, 1)[1]
    loss = criterion(outputs, labels.long())  # cross-entropy loss on the raw logits
    if phase == 'train':
        loss.backward()
        optimizer.step()  # update the parameters (train phase only)
    running_loss += loss.item() * inputs.size(0)  # loss.item() is the batch mean, so multiply by the batch size
    running_corrects += torch.sum(preds == labels.data)  # count correct predictions
# loss and accuracy over the whole epoch
epoch_loss = running_loss / trainval_sizes[phase]
epoch_acc = running_corrects.double() / trainval_sizes[phase]
tqdm is a fast, extensible Python progress bar; wrapping the dataloader with it adds a progress display to the long inner loop.
with torch.no_grad(): disables gradient computation for the validation phase, which reduces the memory needed.
probs has torch.Size([2, 101]): the batch size here is 2 and there are 101 action categories, so each row holds the predicted probability of every class.
preds = torch.max(probs, 1)[1] takes the index of the maximum probability along dim 1, i.e. the predicted label.
For example, preds = tensor([4, 32]) means the two clips in the batch were predicted as classes 4 and 32.
Finally, the loss and accuracy for the whole epoch are computed.
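A quick standalone check of the probs/preds step with dummy logits (batch size 2, 101 classes):

import torch
import torch.nn as nn

outputs = torch.randn(2, 101)       # dummy logits with shape (batch_size, num_classes)
probs = nn.Softmax(dim=1)(outputs)  # each row becomes a probability distribution summing to 1
preds = torch.max(probs, 1)[1]      # index of the largest probability = predicted class
print(probs.shape)                  # torch.Size([2, 101])
print(preds)                        # e.g. tensor([4, 32])

Since softmax is monotonic, torch.argmax(outputs, dim=1) would give the same preds directly; the softmax is only needed when the probabilities themselves are of interest.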
Writing the epoch metrics to TensorBoard:
if phase == 'train':
    writer.add_scalar('data/train_loss_epoch', epoch_loss, epoch)
    writer.add_scalar('data/train_acc_epoch', epoch_acc, epoch)
else:
    writer.add_scalar('data/val_loss_epoch', epoch_loss, epoch)
    writer.add_scalar('data/val_acc_epoch', epoch_acc, epoch)
print("[{}] Epoch: {}/{} Loss: {} Acc: {}".format(phase, epoch+1, nEpochs, epoch_loss, epoch_acc))
stop_time = timeit.default_timer()  # record the stop time
print("Execution time: " + str(stop_time - start_time) + "\n")
Saving a checkpoint. save_epoch corresponds to the snapshot parameter (25), so the model and optimizer state are written every 25 epochs:
if epoch % save_epoch == (save_epoch - 1):
    torch.save({
        'epoch': epoch + 1,
        'state_dict': model.state_dict(),
        'opt_dict': optimizer.state_dict(),
    }, os.path.join(save_dir, 'models', saveName + '_epoch-' + str(epoch) + '.pth.tar'))
    print("Save model at {}\n".format(os.path.join(save_dir, 'models', saveName + '_epoch-' + str(epoch) + '.pth.tar')))
Evaluating on the test set
The procedure mirrors the validation phase: no gradients are computed and no parameters are updated. test_interval corresponds to the nTestInterval parameter (20), so this runs every 20 epochs.
if useTest and epoch % test_interval == (test_interval - 1):
    model.eval()
    start_time = timeit.default_timer()
    running_loss = 0.0
    running_corrects = 0.0
    for inputs, labels in tqdm(test_dataloader):
        inputs = inputs.to(device)
        labels = labels.to(device)
        with torch.no_grad():
            outputs = model(inputs)
        probs = nn.Softmax(dim=1)(outputs)
        preds = torch.max(probs, 1)[1]
        loss = criterion(outputs, labels.long())
        running_loss += loss.item() * inputs.size(0)
        running_corrects += torch.sum(preds == labels.data)
    epoch_loss = running_loss / test_size
    epoch_acc = running_corrects.double() / test_size
    writer.add_scalar('data/test_loss_epoch', epoch_loss, epoch)
    writer.add_scalar('data/test_acc_epoch', epoch_acc, epoch)
    print("[test] Epoch: {}/{} Loss: {} Acc: {}".format(epoch+1, nEpochs, epoch_loss, epoch_acc))
    stop_time = timeit.default_timer()
    print("Execution time: " + str(stop_time - start_time) + "\n")
This is purely a personal summary of my own understanding; mistakes are inevitable. Corrections are welcome, and thank you for reading.