当前位置:网站首页>A code example of the PCL method in the domain of DG (Domain Generalization)
A code example of the PCL method for Domain Generalization (DG)
2022-08-04 06:19:00 【nuomi666】
This post shares code for "PCL: Proxy-based Contrastive Learning for Domain Generalization". The official code is open-sourced on GitHub; it is built on the DomainBed framework and optimized with SWAD. This post presents a more compact version of that code, keeping only one algorithm per module.
DomainBed mainly aims to implement many different methods from the DG field, so the framework is quite complicated and heavily encapsulated. For first-time users it is really unfriendly: it can be hard even to work out what the inputs and outputs are. If you are interested in DG and DA (domain adaptation), I recommend an expert-written library that implements both DA and DG, the transfer-learning code library, which is much easier to understand!
Without further ado, here is the code.
Main file
main.py
import torch
import algorithm
from torch.autograd import Variable
from torchvision import datasets, transforms
#使用swad调优的话,Asynchronous finally
#import swa_utils
#import swad as swad_module
import torch.nn as nn
# Per-channel (mean, std) shared by both pipelines: maps ToTensor output in
# [0, 1] to [-1, 1].
_MEAN_STD = ([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])

# Training pipeline: resize, small random rotation (canvas expanded), center
# crop, then tensor-space random flips and colour jitter before normalization.
train_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomRotation(5, expand=True),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.ColorJitter(0.3, 0.3, 0.3, 0.3),
    transforms.Normalize(*_MEAN_STD),
])

# Evaluation pipeline (validation and test): deterministic resize + crop.
val_dataTrans = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(*_MEAN_STD),
])
# Dataset roots in torchvision ImageFolder layout (one sub-folder per class).
train_data_dir = '../../data/train'
val_data_dir = '../../data/val'
test_data_dir = '../../data/test'
train_dataset = datasets.ImageFolder(train_data_dir, train_transforms)
val_dataset = datasets.ImageFolder(val_data_dir, val_dataTrans)
test_dataset = datasets.ImageFolder(test_data_dir, val_dataTrans)
# Re-partition the data set if needed, for example:
# train_dataset = torch.utils.data.ConcatDataset([train_dataset1, test_dataset])
# val_dataset = datasets.ImageFolder(val_data_dir, val_dataTrans)
# val_dataset, val_dataset_ = torch.utils.data.random_split(val_dataset, [5, len(val_dataset) - 5])
# train_dataset = torch.utils.data.ConcatDataset([train_dataset, val_dataset_])
BATCH_SIZE = 4
train_dataloder = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# Evaluation metrics are order-independent, so keep validation deterministic.
val_dataloder = torch.utils.data.DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
device = "cuda" if torch.cuda.is_available() else "cpu"
# setup hparams
# NOTE: this rebinds the name `algorithm` from the imported module to the model
# instance; the training loop below uses the instance under that name.
# input_shape matches the 224x224 CenterCrop used by the transforms (ERM does
# not read it); the class count comes from the training ImageFolder.
algorithm = algorithm.ERM(input_shape=[3, 224, 224], num_classes=len(train_dataset.classes))
use_swad = False  # whether to use SWAD weight averaging
# The optimizer is configured inside the algorithm module (ERM.optimizer).
if use_swad:
    # SWAD helpers (swa_utils.py / swad.py from the SWAD repository) are only
    # imported when enabled, so the script runs without them otherwise.
    import swa_utils
    import swad as swad_module
    swad_algorithm = swa_utils.AveragedModel(algorithm)
    swad_cls = getattr(swad_module, 'LossValley')
    swad_kwargs = {'n_converge': 3, 'n_tolerance': 6, 'tolerance_ratio': 0.3}
    # NOTE(review): upstream LossValley may also require an `evaluator`
    # argument -- confirm against the swad.py version in use.
    swad = swad_cls(**swad_kwargs)
algorithm.to(device)
lossfunc = nn.CrossEntropyLoss()
epochs = 10
if __name__ == '__main__':
    # Sample counts for accuracy; max(..., 1) avoids division by zero on empty sets.
    n_train = max(len(train_dataloder.dataset), 1)
    n_val = max(len(val_dataloder.dataset), 1)
    global_step = 0
    for epoch in range(epochs):
        # ------------------------------ training ------------------------------
        running_loss = 0.0
        running_corrects = 0
        algorithm.train()
        for step, (inputs, labels) in enumerate(train_dataloder):
            inputs = inputs.to(device)
            labels = labels.to(device)
            # One optimisation step (classification + PCL loss) inside ERM.update.
            algorithm.update(inputs, labels)
            if use_swad:
                # SWAD needs a step count that keeps increasing across epochs.
                swad_algorithm.update_parameters(algorithm, step=global_step)
            global_step += 1
            # Monitoring pass with the freshly updated weights; no graph needed.
            with torch.no_grad():
                _, outputs = algorithm.predict(inputs)
                train_loss = lossfunc(outputs, labels)
            preds = outputs.argmax(dim=1)
            # statistics
            running_loss += train_loss.item()
            running_corrects += (preds == labels).sum().item()
        # Loss is averaged over batches; accuracy over samples.
        tr_epoch_loss = running_loss / max(len(train_dataloder), 1)
        tr_epoch_acc = running_corrects / n_train
        print('{} Loss: {:.4f} Acc: {:.4f}'.format(epoch, tr_epoch_loss, tr_epoch_acc))

        # ----------------------------- validation -----------------------------
        algorithm.eval()
        running_loss = 0.0
        running_corrects = 0
        with torch.no_grad():
            for inputs, labels in val_dataloder:
                inputs = inputs.to(device)
                labels = labels.to(device)
                _, outputs = algorithm.predict(inputs)
                loss = lossfunc(outputs, labels)
                # statistics
                running_loss += loss.item()
                running_corrects += (outputs.argmax(dim=1) == labels).sum().item()
        te_epoch_loss = running_loss / max(len(val_dataloder), 1)
        te_epoch_acc = running_corrects / n_val
        if use_swad:
            # NOTE(review): upstream SWAD's update_and_evaluate may also expect
            # the validation loss -- confirm against the swad.py in use.
            swad.update_and_evaluate(swad_algorithm, te_epoch_acc)
            swad_algorithm = swa_utils.AveragedModel(algorithm)  # reset
        filename = r'epoch{}_Loss{:.4f}_Acc{:.4f}_Loss{:.4f}_Acc{:.4f}.pth'.format(
            epoch, tr_epoch_loss, tr_epoch_acc, te_epoch_loss, te_epoch_acc)
        torch.save(algorithm.state_dict(), filename, _use_new_zipfile_serialization=False)
Main algorithm
This is the main algorithm part. It uses Empirical Risk Minimization (ERM, Vapnik, 1998). The original DomainBed framework provides many algorithms, such as IRM, GroupDRO and RSC; pick whichever you need.
————————————————————
algorithm.py
import math
from model import *
from losses import ProxyLoss, ProxyPLoss
import torch
class ERM(torch.nn.Module):
    """Empirical Risk Minimization (ERM) trained jointly with the PCL loss.

    Pipeline: ``featurizer`` (ResNet backbone) -> ``encoder`` (MLP). The encoder
    output feeds both the linear ``classifier`` (class logits) and the
    ``fea_proj`` head whose output is used by the proxy-based contrastive loss.
    The optimizer is owned by this module and stepped in :meth:`update`.

    Args:
        input_shape: Kept for interface compatibility; not read by this class.
        num_classes: Number of target classes.
    """

    def __init__(self, input_shape, num_classes):
        super(ERM, self).__init__()
        self.encoder, self.scale, self.pcl_weights = encoder()
        self._initialize_weights(self.encoder)
        self.fea_proj, self.fc_proj = fea_proj()
        nn.init.kaiming_uniform_(self.fc_proj, mode='fan_out', a=math.sqrt(5))
        self.featurizer = ResNet()
        # Class weight matrix; 256 must equal the encoder's output width.
        self.classifier = nn.Parameter(torch.FloatTensor(num_classes, 256))
        nn.init.kaiming_uniform_(self.classifier, mode='fan_out', a=math.sqrt(5))
        self.optimizer = torch.optim.Adam([
            {'params': self.featurizer.parameters()},
            {'params': self.encoder.parameters()},
            {'params': self.fea_proj.parameters()},
            {'params': self.fc_proj},
            {'params': self.classifier},
        ], lr=0.002, weight_decay=0.0)
        self.proxycloss = ProxyPLoss(num_classes=num_classes, scale=self.scale)

    def _initialize_weights(self, modules):
        """Re-initialise Conv2d, BatchNorm1d/2d and Linear layers in place.

        Args:
            modules: an ``nn.Module`` (all nested sub-modules are visited) or
                any iterable of modules.
        """
        layers = modules.modules() if isinstance(modules, nn.Module) else modules
        for m in layers:
            if isinstance(m, nn.Conv2d):
                # He-normal initialisation based on fan-out.
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
                if m.weight is not None:  # affine=False layers have no parameters
                    m.weight.data.fill_(1)
                    m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                if m.bias is not None:
                    m.bias.data.zero_()

    def update(self, x, y, **kwargs):
        """Run one optimisation step on a batch.

        Args:
            x: Input batch passed to :meth:`predict`.
            y: Integer class labels for ``x``.
            **kwargs: Unused.

        Returns:
            dict with the scalar values ``loss_cls`` and ``loss_pcl``.
        """
        rep, pred = self.predict(x)
        loss_cls = F.nll_loss(F.log_softmax(pred, dim=1), y)
        # Project the class weights into the contrastive space; they act as proxies.
        fc_proj = F.linear(self.classifier, self.fc_proj)
        assert fc_proj.requires_grad  # internal sanity check: proxies must be trainable
        loss_pcl = self.proxycloss(rep, y, fc_proj)
        loss = loss_cls + self.pcl_weights * loss_pcl
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return {"loss_cls": loss_cls.item(), "loss_pcl": loss_pcl.item()}

    def predict(self, x):
        """Return ``(rep, pred)``: the PCL projection and the class logits."""
        x = self.featurizer(x)
        x = self.encoder(x)
        rep = self.fea_proj(x)
        pred = F.linear(x, self.classifier)
        return rep, pred
Network structure
Next is the main network-structure module. ResNet-50 extracts the image features, and fully connected layers are then used as the encoder.
————————————————————
model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models
class Identity(nn.Module):
    """Pass-through layer whose ``forward`` returns its input unchanged.

    Useful e.g. to replace a backbone's final layer so it emits features.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x
class SqueezeLastTwo(nn.Module):
    """Drop two trailing singleton dimensions, e.g. ``(N, C, 1, 1) -> (N, C)``.

    Uses ``view`` with explicit sizes rather than ``squeeze``, so a batch of
    size 1 keeps its batch dimension.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        batch, channels = x.shape[:2]
        return x.view(batch, channels)
class ResNet(torch.nn.Module):
    """ResNet-50 feature extractor with the classification head removed.

    The final ``fc`` layer is replaced by ``Identity()``, so ``forward`` returns
    the backbone's pooled features (followed by dropout). BatchNorm2d layers are
    kept in eval mode, which freezes their running statistics; their affine
    parameters remain trainable.

    Args:
        pretrained: Forwarded to ``torchvision.models.resnet50``; defaults to
            False, matching the previous behaviour.
    """

    def __init__(self, pretrained=False):
        super(ResNet, self).__init__()
        # To use another backbone, build it here and replace its last layer
        # (usually ``model.fc`` or ``model.head``) with Identity(); the
        # encoder's input width must then match the new feature width.
        self.network = torchvision.models.resnet50(pretrained=pretrained)
        self.network.fc = Identity()
        self.dropout = nn.Dropout(0.1)
        self.freeze_bn()

    def forward(self, x):
        """Encode ``x`` into a feature vector."""
        return self.dropout(self.network(x))

    def train(self, mode=True):
        """Same as ``nn.Module.train`` but keeps BatchNorm2d layers in eval mode."""
        super().train(mode)
        self.freeze_bn()
        return self  # honour nn.Module.train's contract of returning self

    def freeze_bn(self):
        """Put every BatchNorm2d inside ``self.network`` into eval mode."""
        for m in self.network.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()
def encoder(n_outputs=2048, hidden_size=512, out_dim=256, dropout_p=0.25,
            scale_weights=12, pcl_weights=1):
    """Build the MLP encoder that sits on top of the backbone features.

    Args:
        n_outputs: Feature width produced by the backbone (2048 for ResNet-50);
            set this when switching to another backbone.
        hidden_size: Width of the hidden layer.
        out_dim: Width of the output embedding.
        dropout_p: Dropout probability after the hidden activation.
        scale_weights: Scale value handed back to the caller (ERM uses it as
            the PCL logit scale).
        pcl_weights: Weight of the PCL loss term handed back to the caller.

    Returns:
        ``(encoder_module, scale_weights, pcl_weights)``.
    """
    net = nn.Sequential(
        nn.Linear(n_outputs, hidden_size),
        nn.BatchNorm1d(hidden_size),
        nn.ReLU(inplace=True),
        nn.Dropout(dropout_p),
        nn.Linear(hidden_size, out_dim),
    )
    return net, scale_weights, pcl_weights
def fea_proj(out_dim=256):
    """Build the PCL feature projection head and the proxy projection matrix.

    Args:
        out_dim: Embedding width (input and output of the projection).

    Returns:
        ``(fea_proj_module, fc_proj)``: a ``Sequential`` holding one
        ``Linear(out_dim, out_dim)`` and an ``(out_dim, out_dim)`` float32
        ``nn.Parameter``.

    Note:
        ``fc_proj`` is allocated uninitialised; the caller must initialise it
        (e.g. with ``nn.init.kaiming_uniform_``) before use.
    """
    proj = nn.Sequential(nn.Linear(out_dim, out_dim))
    fc_proj = nn.Parameter(torch.empty(out_dim, out_dim, dtype=torch.float32))
    return proj, fc_proj
Loss functions
The loss functions described in the PCL paper.
————————————————————————
losses.py
# coding: utf-8
''' custom loss function '''
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
# # ========================= proxy Contrastive loss ==========================
class ProxyLoss(nn.Module):
    """Cross-entropy over class logits extended with hard negative similarities.

    For sample ``i`` the logits are ``pred[i]`` followed by the cosine
    similarities ``sim(i, j)`` to every sample ``j``. Only pairs with a
    different label and similarity ``>= thres`` are kept; all others are masked
    to ``-inf``. The scaled logits are scored against ``target``.

    Args:
        scale: Multiplier applied to all logits before the softmax.
        thres: Minimum cosine similarity for a negative pair to be kept.
    """

    def __init__(self, scale=1, thres=0.1):
        super(ProxyLoss, self).__init__()
        self.scale = scale
        self.thres = thres

    def forward(self, feature, pred, target):
        """Compute the loss.

        Args:
            feature: (N, D) embeddings.
            pred: (N, C) class logits.
            target: (N,) integer labels in ``[0, C)``.

        Returns:
            Scalar loss tensor.
        """
        feature = F.normalize(feature, p=2, dim=1)
        sim = torch.matmul(feature, feature.transpose(1, 0))  # (N, N)
        same_label = target.unsqueeze(1) == target.unsqueeze(0)
        # Mask same-label pairs (incl. self-pairs) explicitly so they can never
        # act as negatives, even when thres <= 0; also drop weak negatives.
        sim = sim.masked_fill(same_label | (sim < self.thres), -np.inf)
        logits = torch.cat([pred, sim], dim=1)  # (N, C+N)
        return F.nll_loss(F.log_softmax(self.scale * logits, dim=1), target)
class ProxyPLoss(nn.Module):
    """Proxy-based contrastive loss (PCL).

    For sample ``i`` the logits are ``cos(f_i, proxy[target[i]])`` (the
    positive, column 0) followed by ``cos(f_i, f_j)`` for every sample ``j``.
    Only different-label samples with similarity ``>= 1e-6`` are kept as
    negatives; the rest are masked to ``-inf``. The loss is the cross-entropy
    of the scaled logits with column 0 as the target.

    Args:
        num_classes: Number of classes (rows of ``proxy``).
        scale: Multiplier applied to the logits before the softmax.
    """

    def __init__(self, num_classes, scale):
        super(ProxyPLoss, self).__init__()
        self.soft_plus = nn.Softplus()
        # Class-index vector kept for compatibility. It is a non-persistent
        # buffer, so it follows .to(device) without a hard-coded .cuda() and
        # does not add a state_dict entry.
        self.register_buffer('label', torch.arange(num_classes, dtype=torch.long), persistent=False)
        self.scale = scale

    def forward(self, feature, target, proxy):
        """Compute the loss.

        Args:
            feature: (N, D) embeddings.
            target: (N,) integer labels in ``[0, num_classes)``.
            proxy: (num_classes, D) class proxies.

        Returns:
            Scalar loss tensor.
        """
        feature = F.normalize(feature, p=2, dim=1)
        pred = F.linear(feature, F.normalize(proxy, p=2, dim=1))  # (N, C)
        # Similarity of each sample to its own class proxy, in sample order.
        # (A class-major masked_select would return these ordered by label,
        # misaligning rows whenever the batch is not sorted by class.)
        pos = pred.gather(1, target.unsqueeze(1))  # (N, 1)

        sim = torch.matmul(feature, feature.transpose(1, 0))  # (N, N)
        same_label = target.unsqueeze(1) == target.unsqueeze(0)
        # Same-label pairs and weak/negative similarities are not negatives.
        sim = sim.masked_fill(same_label | (sim < 1e-6), -np.inf)

        logits = torch.cat([pos, sim], dim=1)  # (N, 1+N)
        zeros = torch.zeros(logits.size(0), dtype=torch.long, device=logits.device)
        return F.nll_loss(F.log_softmax(self.scale * logits, dim=1), zeros)
class PosAlign(nn.Module):
    """Positive-pair alignment term.

    Returns ``softplus(logsumexp(S_pos))``, where ``S_pos`` holds every entry
    of the cosine-similarity matrix whose row and column samples share a label
    (each sample paired with itself included).
    """

    def __init__(self):
        super(PosAlign, self).__init__()
        self.soft_plus = nn.Softplus()

    def forward(self, feature, target):
        unit = F.normalize(feature, p=2, dim=1)
        similarity = unit @ unit.t()  # (N, N)
        same_label = target[:, None] == target[None, :]
        positives = similarity[same_label]
        return self.soft_plus(torch.logsumexp(positives, dim=0))
SWAD tuning
To tune with SWAD, get swad.py and swa_utils.py from the SWAD repository on GitHub.
I will add a line-by-line analysis when I have time…
边栏推荐
- 机器学习——分类问题对于文字标签的处理(特征工程)
- MNIST手写数字识别 —— 从二分类到十分类
- 【CV-Learning】Convolutional Neural Network
- 代码庆端午--粽你心意
- SQL注入详解
- 光条中心提取方法总结(一)
- MySQL leftmost prefix principle [I understand hh]
- Install dlib step pit record, error: WARNING: pip is configured with locations that require TLS/SSL
- PP-LiteSeg
- AIDL communication between two APPs
猜你喜欢

在AWS-EC2中安装Minikube集群

【CV-Learning】图像分类

fuser 使用—— YOLOV5内存溢出——kill nvidai-smi 无pid 的 GPU 进程

fill_between in Matplotlib; np.argsort() function

图像形变(插值方法)

【论文阅读】Multi-View Spectral Clustering with Optimal Neighborhood Laplacian Matrix

亚马逊云科技 Build On 2022 - AIot 第二季物联网专场实验心得
![[CV-Learning] Semantic Segmentation](/img/ad/ff5076495fa68e4bbf3be78f5ac6f2.png)
[CV-Learning] Semantic Segmentation

CSDN大礼包--高校圆桌派大礼包

Deep Adversarial Decomposition: A Unified Framework for Separating Superimposed Images
随机推荐
【论文阅读】Exploring Spatial Significance via Hybrid Pyramidal Graph Network for Vehicle Re-identificatio
Comparison of oracle's number and postgresql's numeric
Halcon缺陷检测
【CV-Learning】Image Classification
[CV-Learning] Semantic Segmentation
Vision Transformer 论文 + 详解( ViT )
AWS使用EC2降低DeepRacer的训练成本:DeepRacer-for-cloud的实践操作
【CV-Learning】线性分类器(SVM基础)
CSDN大礼包--高校圆桌派大礼包
Attention Is All You Need(Transformer)
MFC读取点云,只能正常显示第一个,显示后面时报错
TensorFlow2 study notes: 5. Common activation functions
MNIST手写数字识别 —— 从零构建感知机实现二分类
MNIST手写数字识别 —— ResNet-经典卷积神经网络
【CV-Learning】图像分类
read and study
腾讯、网易纷纷出手,火到出圈的元宇宙到底是个啥?
【CV-Learning】目标检测&实例分割
【论文阅读】Mining Cross-Image Semantics for Weakly Supervised Semantic Segmentation
Usage of Thread, Handler and IntentService