Knowledge distillation -- pytorch implementation
2022-07-27 00:17:00 【Ap21ril】
Lightweight networks
Knowledge distillation can be viewed as one of the tricks for building lightweight networks. Lightweight networks are a major trend in deep learning, especially on mobile devices and in edge computing, where both compute power and inference time are tightly constrained.
Lightweight networks can be obtained in four main ways:
1. Compressing a trained model: knowledge distillation, weight quantization, pruning, attention transfer (see the quantization sketch after this list)
2. Directly training a lightweight architecture: SqueezeNet, MobileNet, etc.
3. Speeding up convolution: low-rank decomposition
4. Hardware-aware deployment: TensorRT, Jetson, OpenVINO, etc.
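As a quick illustration of item 1, PyTorch ships post-training dynamic quantization, which stores the weights of selected layer types in int8. This is only a minimal sketch on a throwaway model, not part of the distillation pipeline below (depending on the PyTorch version, the API lives under torch.quantization or torch.ao.quantization):

import torch
from torch import nn

# A throwaway example model with Linear layers (stand-in for any trained model)
float_model = nn.Sequential(nn.Linear(784, 1200), nn.ReLU(), nn.Linear(1200, 10))

# Post-training dynamic quantization: Linear weights are stored in int8,
# shrinking the model and often speeding up CPU inference
quantized_model = torch.quantization.quantize_dynamic(
    float_model, {nn.Linear}, dtype=torch.qint8
)
print(quantized_model)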
Knowledge distillation
Knowledge distillation holds an important place among lightweight-network techniques. The overall procedure is: first train a large teacher model on the task, then train a small student model to match both the ground-truth labels and the teacher's temperature-softened output distribution.
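In equation form, the standard objective (following Hinton et al., 2015) is

    loss = alpha * CE(z_s, y) + (1 - alpha) * KL( softmax(z_t / T) || softmax(z_s / T) )

where z_s and z_t are the student and teacher logits, y are the ground-truth labels, T is the distillation temperature, and alpha is the weight on the hard cross-entropy loss. Hinton et al. additionally scale the soft term by T^2 to keep gradient magnitudes comparable across temperatures; the code in this post uses the plain weighted sum.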
Import packages
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from torchinfo import summary
from tqdm import tqdm
# Set the random seed for reproducibility
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Let cuDNN benchmark and pick the fastest convolution algorithms
torch.backends.cudnn.benchmark = True
Load the MNIST dataset
from torchvision.transforms.transforms import ToTensor
# Load training set
train_dataset = torchvision.datasets.MNIST(
root = 'dataset/',
train = True,
transform = transforms.ToTensor(),
download=True
)
# Load the test set
test_dataset = torchvision.datasets.MNIST(
root = 'dataset/',
train = False,
transform = transforms.ToTensor(),
download=True
)
# Generate dataloader
train_dataloader = DataLoader(dataset=train_dataset,batch_size=32,shuffle=True)
test_dataloader = DataLoader(dataset=test_dataset,batch_size=32,shuffle=True)
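A quick sanity check of one batch from the loaders above: each MNIST image is 1×28×28, i.e. 784 values once flattened, which is why the models below start from a Linear layer with 784 inputs.

images, labels = next(iter(train_dataloader))
print(images.shape)   # torch.Size([32, 1, 28, 28])
print(labels.shape)   # torch.Size([32])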
Building a teacher model
class TeacherModel(nn.Module):
    def __init__(self, in_channels=1, num_classes=10):
        super(TeacherModel, self).__init__()
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(784, 1200)
        self.fc2 = nn.Linear(1200, 1200)
        self.fc3 = nn.Linear(1200, num_classes)
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        x = x.view(-1, 784)
        x = self.fc1(x)
        x = self.dropout(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.dropout(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x
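The summary utility imported from torchinfo above can be used to inspect the teacher; a minimal check (with 784·1200 + 1200·1200 + 1200·10 weights plus biases, the teacher has roughly 2.4 million parameters):

# Print the layer-by-layer structure and parameter count of the teacher
summary(TeacherModel(), input_size=(1, 1, 28, 28))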
Training the teacher model
model = TeacherModel()
model = model.to(device)

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

epochs = 6
for epoch in range(epochs):
    model.train()
    # Train on the training set
    for data, targets in tqdm(train_dataloader):
        data = data.to(device)
        targets = targets.to(device)
        # Forward pass
        preds = model(data)
        loss = criterion(preds, targets)
        # Backpropagate and update the weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Evaluate performance on the test set
    model.eval()
    num_correct = 0
    num_samples = 0
    with torch.no_grad():
        for x, y in test_dataloader:
            x = x.to(device)
            y = y.to(device)
            preds = model(x)
            predictions = preds.max(1).indices
            num_correct += (predictions == y).sum()
            num_samples += predictions.size(0)
        acc = (num_correct / num_samples).item()
    model.train()

    print('Epoch:{}\t Accuracy:{:.4f}'.format(epoch + 1, acc))

# Keep a reference to the trained teacher for distillation later
teacher_model = model
Teacher model prediction results

Create a student model
class StudentModel(nn.Module):
    def __init__(self, in_channels=1, num_classes=10):
        super(StudentModel, self).__init__()
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(784, 20)
        self.fc2 = nn.Linear(20, 20)
        self.fc3 = nn.Linear(20, num_classes)
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        x = x.view(-1, 784)
        x = self.fc1(x)
        # x = self.dropout(x)
        x = self.relu(x)
        x = self.fc2(x)
        # x = self.dropout(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x
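Counting parameters directly shows how much smaller the student is; a quick check (784·20 + 20·20 + 20·10 weights plus biases comes to roughly 16 thousand parameters, versus roughly 2.4 million for the teacher):

teacher_params = sum(p.numel() for p in TeacherModel().parameters())
student_params = sum(p.numel() for p in StudentModel().parameters())
print(f'Teacher parameters: {teacher_params:,}')   # about 2.4M
print(f'Student parameters: {student_params:,}')   # about 16K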
Training the student model from scratch
model = StudentModel()
model = model.to(device)

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

epochs = 6
for epoch in range(epochs):
    model.train()
    # Train on the training set
    for data, targets in tqdm(train_dataloader):
        data = data.to(device)
        targets = targets.to(device)
        # Forward pass
        preds = model(data)
        loss = criterion(preds, targets)
        # Backpropagate and update the weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Evaluate performance on the test set
    model.eval()
    num_correct = 0
    num_samples = 0
    with torch.no_grad():
        for x, y in test_dataloader:
            x = x.to(device)
            y = y.to(device)
            preds = model(x)
            predictions = preds.max(1).indices
            num_correct += (predictions == y).sum()
            num_samples += predictions.size(0)
        acc = (num_correct / num_samples).item()
    model.train()

    print('Epoch:{}\t Accuracy:{:.4f}'.format(epoch + 1, acc))
Student model prediction results

The student model is much lighter than the teacher model (the teacher's hidden layers have 1200 neurons each, while the student's have only 20), so when trained on the hard labels alone its accuracy falls short of the teacher's.
# Keep the student trained without distillation for later comparison
student_model_scratch = model
Training the student model with knowledge distillation
# Prepare the pre-trained teacher model
teacher_model.eval()

# Prepare a fresh student model
model = StudentModel()
model = model.to(device)
model.train()

# Distillation temperature
temp = 7

# hard_loss: standard cross-entropy against the ground-truth labels
hard_loss = nn.CrossEntropyLoss()
# Weight of the hard loss
alpha = 0.3

# soft_loss: KL divergence between the softened teacher and student distributions
soft_loss = nn.KLDivLoss(reduction='batchmean')

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

epochs = 10
for epoch in range(epochs):
    # Train the student weights on the training set
    for data, targets in tqdm(train_dataloader):
        data = data.to(device)
        targets = targets.to(device)

        # Teacher predictions (no gradients needed)
        with torch.no_grad():
            teachers_preds = teacher_model(data)

        # Student predictions
        students_preds = model(data)

        # Compute hard_loss against the ground-truth labels
        students_loss = hard_loss(students_preds, targets)

        # Compute soft_loss between the temperature-softened distributions
        # (KLDivLoss expects log-probabilities as its first argument)
        distillation_loss = soft_loss(
            F.log_softmax(students_preds / temp, dim=1),
            F.softmax(teachers_preds / temp, dim=1)
        )

        # Weighted sum of hard_loss and soft_loss
        loss = alpha * students_loss + (1 - alpha) * distillation_loss

        # Backpropagate and update the weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Evaluate model performance on the test set
    model.eval()
    num_correct = 0
    num_samples = 0
    with torch.no_grad():
        for x, y in test_dataloader:
            x = x.to(device)
            y = y.to(device)
            preds = model(x)
            predictions = preds.max(1).indices
            num_correct += (predictions == y).sum()
            num_samples += predictions.size(0)
        acc = (num_correct / num_samples).item()
    model.train()

    print('Epoch:{}\t Accuracy:{:.4f}'.format(epoch + 1, acc))
Prediction results after knowledge distillation training

Although the results here are close, this is only a small demonstration of knowledge distillation: MNIST is a simple dataset, so the gap between the distilled student and the student trained from scratch is not dramatic. Still, the code makes the mechanics of knowledge distillation easy to follow, and distillation remains one of the foundational techniques for building lightweight networks.
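As a final check, the two students can be evaluated side by side. A minimal sketch, assuming student_model_scratch (trained on hard labels only) and model (the distilled student) from the code above:

def evaluate(net):
    # Compute test-set accuracy for a trained model
    net.eval()
    num_correct, num_samples = 0, 0
    with torch.no_grad():
        for x, y in test_dataloader:
            x, y = x.to(device), y.to(device)
            predictions = net(x).max(1).indices
            num_correct += (predictions == y).sum()
            num_samples += predictions.size(0)
    return (num_correct / num_samples).item()

print('Student (from scratch): {:.4f}'.format(evaluate(student_model_scratch)))
print('Student (distilled):    {:.4f}'.format(evaluate(model)))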