Building MobileNetV2 with PyTorch and Training via Transfer Learning
2022-06-25 08:13:00 【@BangBang】
The MobileNetV2 network structure is as follows; for a detailed explanation of the network, see the blog post MobileNet series (2): MobileNet-V2 network details.
From the structure table we can see that the model is basically a stack of inverted residual structures (bottlenecks), followed by a 1x1 standard convolution, a 7x7 average pooling layer, and finally a 1x1 convolution that produces the output. The key to building this network is the inverted residual structure: once it is implemented, assembling the rest of the network is straightforward.
Building the network in PyTorch
In model.py, first define the basic building blocks of the network. In MobileNetV2, almost every convolution is a Conv + BN + ReLU6 sequence.
Convolution block
Conv + BN + ReLU6
```python
import torch
from torch import nn


class ConvBNReLU(nn.Sequential):
    def __init__(self, in_channel, out_channel, kernel_size=3, stride=1, groups=1):
        padding = (kernel_size - 1) // 2
        super(ConvBNReLU, self).__init__(
            nn.Conv2d(in_channel, out_channel, kernel_size, stride, padding,
                      groups=groups, bias=False),
            nn.BatchNorm2d(out_channel),
            nn.ReLU6(inplace=True)
        )
```
Note that groups=1 builds an ordinary convolution, while groups equal to in_channel gives a depthwise (DW) convolution. Because a BN layer follows, the convolution bias is unnecessary, so it is set to bias=False.
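As a quick illustration of the groups parameter (a minimal sketch using the ConvBNReLU block above; the channel counts are arbitrary), a depthwise convolution has far fewer parameters than an ordinary one:

```python
# A 3x3 conv with groups=1 mixes all input channels; with groups equal
# to in_channel every filter sees a single channel (depthwise conv).
normal_conv = ConvBNReLU(32, 32, kernel_size=3, groups=1)
dw_conv = ConvBNReLU(32, 32, kernel_size=3, groups=32)

num_params = lambda m: sum(p.numel() for p in m.parameters())
print(num_params(normal_conv))  # 9280 = 32*32*3*3 conv weights + 2*32 BN params
print(num_params(dw_conv))      # 352  = 32*1*3*3  conv weights + 2*32 BN params
```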
Inverted residual structure
Define an InvertedResidual class that inherits from the nn.Module parent class. The diagram of the inverted residual structure is as follows:
The inverted residual structure is similar to the ordinary residual structure, but where the ordinary residual block is thick at both ends and thin in the middle, the inverted residual block is thin at both ends and thick in the middle. See: MobileNet series (2): MobileNet-V2 network details. The number of DW convolution filters equals the number of input channels, and each DW filter operates on a single channel, so a DW convolution does not change the channel count.
```python
class InvertedResidual(nn.Module):
    def __init__(self, in_channel, out_channel, stride, expand_ratio):
        super(InvertedResidual, self).__init__()
        hidden_channel = in_channel * expand_ratio
        # the shortcut is only used when stride == 1 and the shapes match
        self.use_shortcut = stride == 1 and in_channel == out_channel

        layers = []
        if expand_ratio != 1:
            # 1x1 pointwise conv (expansion)
            layers.append(ConvBNReLU(in_channel, hidden_channel, kernel_size=1))
        layers.extend([
            # 3x3 depthwise conv
            ConvBNReLU(hidden_channel, hidden_channel, stride=stride, groups=hidden_channel),
            # 1x1 pointwise conv (linear, i.e. no activation)
            nn.Conv2d(hidden_channel, out_channel, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channel)
        ])
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        if self.use_shortcut:
            return x + self.conv(x)
        else:
            return self.conv(x)
```
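A quick sanity check of the block (a sketch; the input sizes are arbitrary): with stride=1 and matching channels the shortcut is active and shapes are preserved, otherwise the block just transforms the input:

```python
block1 = InvertedResidual(32, 32, stride=1, expand_ratio=6)
block2 = InvertedResidual(32, 64, stride=2, expand_ratio=6)
x = torch.randn(1, 32, 56, 56)
print(block1.use_shortcut, block1(x).shape)  # True  torch.Size([1, 32, 56, 56])
print(block2.use_shortcut, block2(x).shape)  # False torch.Size([1, 64, 28, 28])
```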
MobileNet V2 network structure
Define the MobileNetV2 class, inheriting from nn.Module. The complete network construction code is as follows:
```python
class MobileNetV2(nn.Module):
    def __init__(self, num_classes=1000, alpha=1.0, round_nearest=8):
        super(MobileNetV2, self).__init__()
        block = InvertedResidual
        input_channel = _make_divisible(32 * alpha, round_nearest)
        last_channel = _make_divisible(1280 * alpha, round_nearest)

        # t: expansion factor, c: output channels, n: repeats, s: stride of the first block
        inverted_residual_setting = [
            # t, c, n, s
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            [6, 160, 3, 2],
            [6, 320, 1, 1]
        ]

        features = []
        # conv1 layer
        features.append(ConvBNReLU(3, input_channel, stride=2))
        # build inverted residual blocks
        for t, c, n, s in inverted_residual_setting:
            # round the number of convolution kernels to a multiple of round_nearest
            output_channel = _make_divisible(c * alpha, round_nearest)
            for i in range(n):
                stride = s if i == 0 else 1
                features.append(block(input_channel, output_channel, stride, expand_ratio=t))
                input_channel = output_channel
        # building last several layers
        features.append(ConvBNReLU(input_channel, last_channel, 1))
        # combine feature layers
        self.features = nn.Sequential(*features)

        # building classifier
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(last_channel, num_classes)
        )

        # weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)

    # forward pass
    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x
```
The _make_divisible function comes from the official TensorFlow implementation:
```python
def _make_divisible(ch, divisor=8, min_ch=None):
    """
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    """
    if min_ch is None:
        min_ch = divisor
    new_ch = max(min_ch, int(ch + divisor / 2) // divisor * divisor)
    # Make sure that rounding down does not go down by more than 10%.
    if new_ch < 0.9 * ch:
        new_ch += divisor
    return new_ch
```
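With the helper in place, both the rounding behaviour and the assembled model can be verified with a short sketch (e.g. at the bottom of model.py; the 224x224 input size matches the paper):

```python
print(_make_divisible(32 * 0.75))  # 24  (already a multiple of 8)
print(_make_divisible(67))         # 64  (rounds down, still >= 0.9 * 67)

net = MobileNetV2(num_classes=5)
x = torch.randn(1, 3, 224, 224)
print(net(x).shape)  # torch.Size([1, 5])
```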
Model training
First, let's look at how to download the official pretrained model parameters, taking the MobileNet pretrained model as an example.

```python
import torchvision.models.mobilenet
```

Ctrl-click torchvision.models.mobilenet to jump into the official definition, where you will find a model_urls dict; its URL is the download link for the model's pretrained weights:

```python
model_urls = {
    'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth'
}
```

Copy the URL into a download tool such as Xunlei. After downloading, save the file in the current project directory and name it mobilenet_v2.pth.
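Alternatively, the weights can be fetched programmatically instead of through a browser or Xunlei (a sketch using torch.hub; model_dir and file_name here are assumptions about where you want the file to land):

```python
import torch

url = 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth'
# downloads (and caches) the weights as ./mobilenet_v2.pth
state_dict = torch.hub.load_state_dict_from_url(url, model_dir='.',
                                                file_name='mobilenet_v2.pth')
```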
Training scripts
train.py
1. Import Python packages
```python
import os
import json

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from tqdm import tqdm

from model import MobileNetV2

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
```
2. Data preparation
```python
data_transform = {
    "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
    "val": transforms.Compose([transforms.Resize(256),
                               transforms.CenterCrop(224),
                               transforms.ToTensor(),
                               transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
}

data_root = os.path.abspath(os.path.join(os.getcwd(), '../..'))  # get data root path
image_path = data_root + "/data_set/flower_data/"  # flower data set path

train_dataset = datasets.ImageFolder(root=image_path + "train",
                                     transform=data_transform["train"])
train_num = len(train_dataset)

# {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflower': 3, 'tulips': 4}
flower_list = train_dataset.class_to_idx
cla_dict = dict((val, key) for key, val in flower_list.items())
# write dict into json file
json_str = json.dumps(cla_dict, indent=4)
with open('class_indices.json', 'w') as json_file:
    json_file.write(json_str)

batch_size = 16
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size, shuffle=True,
                                           num_workers=0)

validate_dataset = datasets.ImageFolder(root=image_path + "val",
                                        transform=data_transform["val"])
val_num = len(validate_dataset)
validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                              batch_size=batch_size, shuffle=False,
                                              num_workers=0)
```
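Before training, it can be worth verifying the pipeline with a quick sketch: pull one batch from the loader and read the class mapping back from class_indices.json (the printed values assume the five-class flower data set above):

```python
images, labels = next(iter(train_loader))
print(images.shape, labels.shape)  # torch.Size([16, 3, 224, 224]) torch.Size([16])

with open('class_indices.json', 'r') as f:
    class_indict = json.load(f)
print(class_indict)  # {'0': 'daisy', '1': 'dandelion', ...} (keys become strings in JSON)
```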
3. Load the model
```python
net = MobileNetV2(num_classes=5)

# load pretrained weights
model_weight_path = "./mobilenet_v2.pth"
assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)
pre_weights = torch.load(model_weight_path, map_location=device)

# drop the classifier weights, whose shapes no longer match num_classes=5
pre_dict = {k: v for k, v in pre_weights.items() if net.state_dict()[k].numel() == v.numel()}
# strict=False means only the matching weights are loaded
missing_keys, unexpected_keys = net.load_state_dict(pre_dict, strict=False)

# freeze the feature-extractor weights
for param in net.features.parameters():
    param.requires_grad = False

net.to(device)
```
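Since the feature extractor is frozen, only the classifier should remain trainable; a short check (a sketch) confirms this:

```python
# With net.features frozen, only the final Linear layer still requires grad.
trainable = [name for name, p in net.named_parameters() if p.requires_grad]
print(trainable)  # ['classifier.1.weight', 'classifier.1.bias']
```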
4. Model training
```python
# define loss function
loss_function = nn.CrossEntropyLoss()

# construct an optimizer over the trainable parameters only
params = [p for p in net.parameters() if p.requires_grad]
optimizer = optim.Adam(params, lr=0.0001)

epochs = 5
best_acc = 0.0
save_path = './MobileNetV2.pth'
train_steps = len(train_loader)
for epoch in range(epochs):
    # train
    net.train()
    running_loss = 0.0
    train_bar = tqdm(train_loader)
    for step, data in enumerate(train_bar):
        images, labels = data
        optimizer.zero_grad()
        logits = net(images.to(device))
        loss = loss_function(logits, labels.to(device))
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        train_bar.desc = "train epoch [{}/{}] loss:{:.3f}".format(epoch + 1, epochs, loss)

    # validate
    net.eval()
    acc = 0.0  # accumulated number of correct predictions per epoch
    with torch.no_grad():
        val_bar = tqdm(validate_loader)
        for val_data in val_bar:
            val_images, val_labels = val_data
            outputs = net(val_images.to(device))
            predict_y = torch.max(outputs, dim=1)[1]
            acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
            val_bar.desc = "valid epoch [{}/{}]".format(epoch + 1, epochs)

    val_accurate = acc / val_num
    print('[epoch %d] train_loss: %.3f val_accuracy: %.3f' %
          (epoch + 1, running_loss / train_steps, val_accurate))
    if val_accurate > best_acc:
        best_acc = val_accurate
        torch.save(net.state_dict(), save_path)

print('Finished Training')
```
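After training, the saved weights can be used for prediction. A minimal inference sketch (img is assumed to be a single image tensor already preprocessed with data_transform["val"]):

```python
net = MobileNetV2(num_classes=5)
net.load_state_dict(torch.load(save_path, map_location=device))
net.to(device)
net.eval()
with torch.no_grad():
    output = net(img.unsqueeze(0).to(device))  # add the batch dimension
    predict = torch.softmax(output, dim=1)
    class_id = predict.argmax(dim=1).item()
    print(cla_dict[class_id], predict[0, class_id].item())
```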