Current location: Home > A code example of the PCL method for Domain Generalization (DG)
A code example of the PCL method for Domain Generalization (DG)
2022-08-04 06:19:00 【nuomi666】
This post shares code for the paper *PCL: Proxy-based Contrastive Learning for Domain Generalization*. The official code is open source on GitHub; it is built on the DomainBed framework and is optimized with SWAD. Here I present a condensed version of that code, keeping only one algorithm per module.
DomainBed aims to implement many DG methods, so the framework is quite complex and heavily encapsulated. It is unfriendly to first-time users, who may not even be able to tell what the inputs and outputs are. If you are interested in DG and DA (domain adaptation), I recommend a library written by an expert that implements both DA and DG, the "Transfer Learning code library", which is much easier to follow!
Without further ado, here is the code.
Main file
main.py
import torch
import algorithm
from torch.autograd import Variable
from torchvision import datasets, transforms
# SWAD helper modules (swa_utils, swad) come from the SWAD GitHub repo; they are only needed when use_swad is True.
#import swa_utils
#import swad as swad_module
import torch.nn as nn
# Training-time augmentation: resize, small random rotation (+/-5 degrees, canvas
# expanded), centre crop to 224x224, random flips and colour jitter, then
# per-channel normalisation with mean 0.5 / std 0.5.
train_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomRotation((5), expand=True),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.ColorJitter(.3, .3, .3, .3),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])
# Evaluation transform: deterministic resize + centre crop, same normalisation.
val_dataTrans = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

train_data_dir = '../../data/train'
val_data_dir = '../../data/val'
test_data_dir = '../../data/test'
# ImageFolder expects one sub-directory per class inside each data directory.
train_dataset = datasets.ImageFolder(train_data_dir, train_transforms)
val_dataset = datasets.ImageFolder(val_data_dir, val_dataTrans)
test_dataset = datasets.ImageFolder(test_data_dir, val_dataTrans)
# Re-partition the datasets here if needed, e.g.:
# train_dataset = torch.utils.data.ConcatDataset([train_dataset1, test_dataset])
# val_dataset = datasets.ImageFolder(val_data_dir, _dataTrans)
# val_dataset, val_dataset_ = torch.utils.data.random_split(val_dataset, [5, len(val_dataset) - 5])
# train_dataset = torch.utils.data.ConcatDataset([train_dataset, val_dataset_])
train_dataloder = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=True)
# Sample order does not affect validation loss/accuracy, so no shuffling is needed.
val_dataloder = torch.utils.data.DataLoader(val_dataset, batch_size=4, shuffle=False)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Build the model. NOTE: this rebinds the name `algorithm` from the imported module
# to the model instance; the training loop below relies on that name.
# input_shape is not used by ERM, but it should describe the real input size:
# CenterCrop(224) produces 3x224x224 images. num_classes must match the number of
# class folders.
algorithm = algorithm.ERM(input_shape=[3, 224, 224], num_classes=4)
use_swad = False  # whether to apply SWAD weight averaging
# The optimizer is configured inside algorithm.ERM.
if use_swad:
    # Import the SWAD helpers only when they are needed, so enabling the flag
    # works without editing the commented-out imports at the top of the file.
    import swa_utils
    import swad as swad_module
    swad_algorithm = swa_utils.AveragedModel(algorithm)
    swad_cls = getattr(swad_module, 'LossValley')
    swad_kwargs = {'n_converge': 3, 'n_tolerance': 6, 'tolerance_ratio': 0.3}
    swad = swad_cls(**swad_kwargs)
algorithm.to(device)
lossfunc = nn.CrossEntropyLoss()
epochs = 10
if __name__ == '__main__':
    # Step counter that keeps increasing across epochs (the per-epoch batch
    # index would restart at 0 every epoch) -- passed to SWAD's averaged model.
    global_step = 0
    for epoch in range(epochs):
        # ---------------- training ----------------
        running_loss = 0.0
        running_corrects = 0.0
        algorithm.train()
        for inputs, labels in train_dataloder:
            # Move the batch to the model's device (also works on CPU-only hosts).
            inputs = inputs.to(device)
            labels = labels.to(device)
            # One optimisation step (classification + proxy contrastive loss).
            algorithm.update(inputs, labels)
            if use_swad:
                swad_algorithm.update_parameters(algorithm, step=global_step)
            global_step += 1
            # Logging only: re-run the updated model (still in train mode, so
            # dropout is active) without building an autograd graph.
            with torch.no_grad():
                _, outputs = algorithm.predict(inputs)
                _, preds = torch.max(outputs, 1)
                train_loss = lossfunc(outputs, labels)
            running_loss += train_loss.item()
            running_corrects += (preds == labels).sum().item()
        # Loss: mean of per-batch losses. Accuracy: fraction of correctly
        # classified samples (divide by #samples, not #batches).
        tr_epoch_loss = running_loss / len(train_dataloder)
        tr_epoch_acc = running_corrects / len(train_dataloder.dataset)
        print('{} Loss: {:.4f} Acc: {:.4f}'.format(epoch, tr_epoch_loss, tr_epoch_acc))

        # ---------------- validation ----------------
        algorithm.eval()
        running_loss = 0.0
        running_corrects = 0.0
        with torch.no_grad():
            for inputs, labels in val_dataloder:
                inputs = inputs.to(device)
                labels = labels.to(device)
                _, outputs = algorithm.predict(inputs)
                _, preds = torch.max(outputs, 1)
                loss = lossfunc(outputs, labels)
                running_loss += loss.item()
                running_corrects += (preds == labels).sum().item()
        te_epoch_loss = running_loss / len(val_dataloder)
        te_epoch_acc = running_corrects / len(val_dataloder.dataset)
        print('{} Val Loss: {:.4f} Acc: {:.4f}'.format(epoch, te_epoch_loss, te_epoch_acc))
        if use_swad:
            # NOTE(review): confirm this call matches update_and_evaluate's
            # signature in the SWAD repo's swad module.
            swad.update_and_evaluate(swad_algorithm, te_epoch_acc)
            swad_algorithm = swa_utils.AveragedModel(algorithm)  # reset
        # One checkpoint per epoch; the name records train/val loss and accuracy.
        filename = r'epoch{}_Loss{:.4f}_Acc{:.4f}_Loss{:.4f}_Acc{:.4f}.pth'.format(
            epoch, tr_epoch_loss, tr_epoch_acc, te_epoch_loss, te_epoch_acc)
        torch.save(algorithm.state_dict(), filename, _use_new_zipfile_serialization=False)
Main algorithm
The main algorithm uses Empirical Risk Minimization (ERM, Vapnik, 1998). The original DomainBed framework provides many algorithms, such as IRM, GroupDRO and RSC, which you can swap in as needed.
————————————————————
algorithm.py
import math
from model import *
from losses import ProxyLoss, ProxyPLoss
import torch
class ERM(torch.nn.Module):
    """Empirical Risk Minimization (ERM) model trained with an extra PCL proxy loss.

    Forward path (see ``predict``): ``featurizer`` -> ``encoder`` -> linear
    classifier logits, plus a ``fea_proj`` projection of the encoder output that
    feeds the proxy contrastive loss. The Adam optimizer is created here and
    stepped inside ``update``.

    Args:
        input_shape: Kept for interface compatibility; not used by this class.
        num_classes: Number of output classes (rows of the classifier weight).
    """

    def __init__(self, input_shape, num_classes):
        super().__init__()
        # encoder() returns the encoder module plus the proxy-loss scale and weight.
        self.encoder, self.scale, self.pcl_weights = encoder()
        self._initialize_weights(self.encoder)
        self.fea_proj, self.fc_proj = fea_proj()
        # fea_proj() allocates fc_proj without initialising it, so do it here.
        nn.init.kaiming_uniform_(self.fc_proj, mode='fan_out', a=math.sqrt(5))
        self.featurizer = ResNet()
        # Bias-free classifier weights (num_classes, 256), applied with F.linear.
        self.classifier = nn.Parameter(torch.FloatTensor(num_classes, 256))
        nn.init.kaiming_uniform_(self.classifier, mode='fan_out', a=math.sqrt(5))
        self.optimizer = torch.optim.Adam([
            {'params': self.featurizer.parameters()},
            {'params': self.encoder.parameters()},
            {'params': self.fea_proj.parameters()},
            {'params': self.fc_proj},
            {'params': self.classifier},
        ], lr=0.002, weight_decay=0.0)
        self.proxycloss = ProxyPLoss(num_classes=num_classes, scale=self.scale)

    def _initialize_weights(self, modules):
        """Initialise the Conv2d / BatchNorm2d / Linear layers yielded by ``modules``.

        Only the items produced by iterating ``modules`` are visited (for an
        ``nn.Sequential`` these are its direct children); other layer types,
        e.g. BatchNorm1d, are left untouched. Layers without a bias (or a
        non-affine BatchNorm) are handled instead of raising.
        """
        for m in modules:
            if isinstance(m, nn.Conv2d):
                # He-style init with std = sqrt(2 / fan_out).
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                if m.weight is not None:
                    m.weight.data.fill_(1)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                if m.bias is not None:
                    m.bias.data.zero_()

    def update(self, x, y, **kwargs):
        """Run one optimisation step on the batch ``(x, y)``.

        Total loss = cross-entropy on the classifier logits
                     + ``pcl_weights`` * proxy contrastive loss, where the class
                     proxies are the classifier weights projected by ``fc_proj``.

        Returns:
            dict with the float values of ``"loss_cls"`` and ``"loss_pcl"``.
        """
        rep, pred = self.predict(x)
        loss_cls = F.nll_loss(F.log_softmax(pred, dim=1), y)
        # Proxies (num_classes, 256) = classifier @ fc_proj^T.
        fc_proj = F.linear(self.classifier, self.fc_proj)
        # Sanity check: gradients must flow back into classifier / fc_proj.
        assert fc_proj.requires_grad
        loss_pcl = self.proxycloss(rep, y, fc_proj)
        loss = loss_cls + self.pcl_weights * loss_pcl
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return {"loss_cls": loss_cls.item(), "loss_pcl": loss_pcl.item()}

    def predict(self, x):
        """Return ``(rep, pred)``: projected features for the proxy loss and class logits."""
        x = self.featurizer(x)
        x = self.encoder(x)
        rep = self.fea_proj(x)
        pred = F.linear(x, self.classifier)
        return rep, pred
Network architecture
Next is the main network module. ResNet-50 extracts image features, and fully connected layers then serve as the encoder.
————————————————————
model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models
class Identity(nn.Module):
    """Pass-through layer: ``forward`` hands back the very object it receives."""

    # No __init__ override: nn.Module.__init__ is all the setup this layer needs.

    def forward(self, x):
        # Nothing is computed; useful as a drop-in replacement for a network's head.
        return x
class SqueezeLastTwo(nn.Module):
    """View a tensor as (N, C) using its first two dimensions.

    Intended for (N, C, 1, 1) inputs. Unlike ``squeeze()``, this keeps the
    batch dimension even when N == 1. The remaining dimensions must contribute
    no extra elements (numel must equal N * C), or ``view`` raises.
    """

    def forward(self, x):
        batch, channels = x.shape[0], x.shape[1]
        return x.view(batch, channels)
class ResNet(torch.nn.Module):
    """ResNet-50 feature extractor with the final fc removed and BatchNorm kept in eval mode.

    The torchvision classification layer is replaced by ``Identity`` so
    ``forward`` returns backbone features (after dropout) for the encoder.
    BatchNorm2d layers are switched to eval mode on construction and after
    every ``train()`` call, so their running statistics are not updated. Their
    affine parameters are still trainable.
    """

    def __init__(self):
        super().__init__()
        # To use a different backbone, change it here. The encoder's input size
        # (n_outputs in encoder()) must then match this network's feature size.
        # The backbone is randomly initialised (pretrained=False).
        network = torchvision.models.resnet50(pretrained=False)
        self.network = network
        # Replace the last layer with an identity so the network outputs features.
        # Tip: in most torchvision models the last layer is `fc` or `head`.
        self.network.fc = Identity()
        self.dropout = nn.Dropout(0.1)
        self.freeze_bn()

    def forward(self, x):
        """Encode x into a feature vector (2048-d for ResNet-50), with dropout applied."""
        return self.dropout(self.network(x))

    def train(self, mode=True):
        """Override train() to keep BatchNorm frozen; returns self like nn.Module.train."""
        super().train(mode)
        self.freeze_bn()
        return self

    def freeze_bn(self):
        """Put every BatchNorm2d inside ``self.network`` into eval mode."""
        for m in self.network.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()
def encoder(n_outputs=2048, hidden_size=512, out_dim=256, dropout_p=0.25,
            scale_weights=12, pcl_weights=1):
    """Build the MLP encoder that sits on top of the backbone features.

    All arguments default to the original hard-coded values, so ``encoder()``
    behaves exactly as before.

    Args:
        n_outputs: Feature size produced by the backbone (2048 for ResNet-50).
            Change it when swapping in a different backbone.
        hidden_size: Width of the hidden layer.
        out_dim: Output feature size. The classifier and projections elsewhere
            in this project assume 256.
        dropout_p: Dropout probability after the hidden activation.
        scale_weights: Returned unchanged; used as the proxy loss's logit scale.
        pcl_weights: Returned unchanged; used as the proxy loss's weight.

    Returns:
        ``(encoder, scale_weights, pcl_weights)``, where ``encoder`` is
        Linear -> BatchNorm1d -> ReLU -> Dropout -> Linear.
    """
    layers = nn.Sequential(
        nn.Linear(n_outputs, hidden_size),
        nn.BatchNorm1d(hidden_size),
        nn.ReLU(inplace=True),
        nn.Dropout(dropout_p),
        nn.Linear(hidden_size, out_dim),
    )
    return layers, scale_weights, pcl_weights
def fea_proj(out_dim=256):
    """Build the feature projection head and the proxy projection matrix.

    Args:
        out_dim: Feature size (defaults to the original 256).

    Returns:
        ``(fea_proj, fc_proj)``. ``fea_proj`` is ``nn.Sequential(Linear(out_dim,
        out_dim))``. ``fc_proj`` is an (out_dim, out_dim) float32 Parameter that
        is allocated but NOT initialised; the caller must initialise it
        (e.g. with ``nn.init``) before use.
    """
    proj = nn.Sequential(
        nn.Linear(out_dim, out_dim),
    )
    fc_proj = nn.Parameter(torch.empty(out_dim, out_dim, dtype=torch.float32))
    return proj, fc_proj
Loss functions
The loss functions described in the PCL paper
————————————————————————
losses.py
# coding: utf-8
''' custom loss function '''
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
# # ========================= proxy Contrastive loss ==========================
class ProxyLoss(nn.Module):
    """Cross-entropy over class logits augmented with in-batch negative similarities.

    For each sample the logits are ``[pred (C columns), cosine similarity to
    every sample in the batch (N columns)]``. Same-label entries (including the
    sample itself) are zeroed. Any entry below ``thres`` is then set to -inf,
    which also drops the zeroed entries whenever ``thres > 0``. Logits are
    multiplied by ``scale`` before the softmax.
    """

    def __init__(self, scale=1, thres=0.1):
        super().__init__()
        self.scale = scale
        self.thres = thres

    def forward(self, feature, pred, target):
        """Return the mean loss; ``pred`` holds the (N, C) class logits, ``target`` the labels."""
        unit = F.normalize(feature, p=2, dim=1)
        sim = unit @ unit.t()  # (N, N)
        same = target[:, None] == target[None, :]
        neg_sim = sim * (~same)  # keep only different-label pairs
        neg_sim = neg_sim.masked_fill(neg_sim < self.thres, -np.inf)
        logits = torch.cat((pred, neg_sim), dim=1)  # (N, C+N)
        log_prob = F.log_softmax(logits * self.scale, dim=1)
        return F.nll_loss(log_prob, target)
class ProxyPLoss(nn.Module):
    """Proxy-based contrastive loss (PCL).

    For each sample i the logits are::

        [ cos(f_i, proxy[y_i]),  cos(f_i, f_j) for every j in the batch ]

    Entries with y_j == y_i (including j == i), or with cos(f_i, f_j) < 1e-6,
    are set to -inf. The loss is the mean cross-entropy with the positive proxy
    (column 0) as the target, after multiplying the logits by ``scale``.
    Works on any device: nothing is hard-wired to CUDA.
    """

    def __init__(self, num_classes, scale):
        super().__init__()
        self.soft_plus = nn.Softplus()  # not used by forward; kept for compatibility
        # Class indices 0..num_classes-1. A non-persistent buffer follows
        # .to(device) but is not written to state_dict, so checkpoints are unchanged.
        self.register_buffer('label', torch.arange(num_classes, dtype=torch.long),
                             persistent=False)
        self.scale = scale

    def forward(self, feature, target, proxy):
        """Return the scalar loss.

        Args:
            feature: (N, D) sample features.
            target: (N,) integer class labels in [0, num_classes).
            proxy: (num_classes, D) class proxies.
        """
        feature = F.normalize(feature, p=2, dim=1)
        pred = F.linear(feature, F.normalize(proxy, p=2, dim=1))  # (N, C)
        # Positive score of each sample with its own class proxy, kept in SAMPLE
        # order. A masked_select over the (C, N) transpose returns them grouped
        # by class instead, which misaligns them with the per-sample negatives
        # below whenever the batch is not sorted by label.
        pos = pred.gather(1, target.unsqueeze(1))  # (N, 1)
        sim = torch.matmul(feature, feature.transpose(1, 0))  # (N, N)
        same_label = target.unsqueeze(1) == target.unsqueeze(0)
        neg = sim * ~same_label  # zero same-label pairs (incl. the diagonal)
        neg = neg.masked_fill(neg < 1e-6, -np.inf)
        logits = torch.cat([pos, neg], dim=1)  # (N, 1+N)
        # Target column 0 (the positive proxy), created on the inputs' device.
        label = torch.zeros(logits.size(0), dtype=torch.long, device=logits.device)
        return F.nll_loss(F.log_softmax(self.scale * logits, dim=1), label)
class PosAlign(nn.Module):
    """Positive-pair alignment term.

    Computes ``softplus(logsumexp(cos(f_i, f_j) over all pairs with y_i == y_j))``.
    The pairs include each sample paired with itself.
    """

    def __init__(self):
        super().__init__()
        self.soft_plus = nn.Softplus()

    def forward(self, feature, target):
        unit = F.normalize(feature, p=2, dim=1)
        sim = unit.matmul(unit.t())  # (N, N)
        positives = sim[target.unsqueeze(1) == target.unsqueeze(0)]  # 1-D
        return self.soft_plus(torch.logsumexp(positives, 0))
Tuning with SWAD
If you want to tune with SWAD, you can get `swad` and `swa_utils` from the GitHub repository.
I will go through the code line by line when I have time…
边栏推荐
- Data reading in yolov3 (1)
- 浅谈游戏音效测试点
- Use of double pointers
- 典型CCN网络——efficientNet(2019-Google-已开源)
- ConnectionRefusedError: [Errno 111] Connection refused问题解决
- Pytorch问题总结
- target has libraries with conflicting names: libcrypto.a and libssl.a.
- AWS使用EC2降低DeepRacer的训练成本:DeepRacer-for-cloud的实践操作
- Postgresql snapshot
- 迅雷关闭自动更新
猜你喜欢
随机推荐
The use of the attribute of the use of the animation and ButterKnife
【CV-Learning】语义分割
[CV-Learning] Convolutional Neural Network Preliminary Knowledge
PP-LiteSeg
2020-10-29
Learning curve learning_curve function in sklearn
2020-10-19
Introduction of linear regression 01 - API use cases
"A minute" Copy siege lion log 】 【 run MindSpore LeNet model
pytorch学习-没掌握的点
亚马逊云科技Build On-Amazon Neptune基于知识图谱的推荐模型构建心得
软著撰写注意事项
MFC 打开与保存点云PCD文件
TensorFlow2学习笔记:4、第一个神经网模型,鸢尾花分类
Amazon Cloud Technology Build On 2022 - AIot Season 2 IoT Special Experiment Experience
fuser 使用—— YOLOV5内存溢出——kill nvidai-smi 无pid 的 GPU 进程
SQL注入详解
Android foundation [Super detailed android storage method analysis (SharedPreferences, SQLite database storage)]
动手学深度学习_线性回归
Androd Day02