A Code Example of the PCL Method for Domain Generalization (DG)
2022-08-04 05:30:00 【nuomi666】
This post shares the paper PCL: Proxy-based Contrastive Learning for Domain Generalization. The code has been open-sourced on GitHub; it is built on SWAD, an optimization framework that itself extends DomainBed. This post mainly presents a stripped-down version of that code, keeping only one algorithm per module.
The DomainBed library aims to implement many different DG methods, so the framework is complex and heavily encapsulated, which makes it quite unfriendly to first-time users; you may struggle even to follow the inputs and outputs. If you are interested in DG and DA, I recommend a DA/DG library written by an expert, the transfer learning codebase (迁移学习代码库), which is much easier to read!
Without further ado, here is the code.
Main file
main.py
import torch
import torch.nn as nn
from torchvision import datasets, transforms

import algorithm
# If you want to tune with SWAD, see the end of this article
# import swa_utils
# import swad as swad_module
train_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomRotation(5, expand=True),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.ColorJitter(.3, .3, .3, .3),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
val_dataTrans = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
train_data_dir = '../../data/train'
val_data_dir = '../../data/val'
test_data_dir='../../data/test'
train_dataset = datasets.ImageFolder(train_data_dir, train_transforms)
val_dataset = datasets.ImageFolder(val_data_dir,val_dataTrans)
test_dataset = datasets.ImageFolder(test_data_dir, val_dataTrans)
# The datasets can be re-split as needed, for example:
# train_dataset = torch.utils.data.ConcatDataset([train_dataset1, test_dataset])
# val_dataset = datasets.ImageFolder(val_data_dir, _dataTrans)
# val_dataset, val_dataset_ = torch.utils.data.random_split(val_dataset, [5, len(val_dataset) - 5])
# train_dataset = torch.utils.data.ConcatDataset([train_dataset, val_dataset_])
train_dataloder = torch.utils.data.DataLoader(train_dataset,batch_size=4,shuffle=True)
val_dataloder = torch.utils.data.DataLoader(val_dataset,batch_size=4,shuffle=True)
device = "cuda" if torch.cuda.is_available() else"cpu"
# setup hparams
algorithm = algorithm.ERM(input_shape=[3, 224, 224], num_classes=4)
use_swad = False  # whether to use SWAD optimization
# The optimizer is set up inside the algorithm file
if use_swad:
    swad_algorithm = swa_utils.AveragedModel(algorithm)
    swad_cls = getattr(swad_module, 'LossValley')
    swad_kwargs = {'n_converge': 3, 'n_tolerance': 6, 'tolerance_ratio': 0.3}
    swad = swad_cls(**swad_kwargs)
algorithm.to(device)
lossfunc = nn.CrossEntropyLoss()
epochs = 10
if __name__ == '__main__':
    for epoch in range(epochs):
        running_loss = 0.0
        running_corrects = 0.0
        algorithm.train()
        for step, (inputs, labels) in enumerate(train_dataloder):
            inputs = inputs.to(device)
            labels = labels.to(device)
            step_vals = algorithm.update(inputs, labels)
            if use_swad:
                swad_algorithm.update_parameters(algorithm, step=step)
            # extra forward pass only to collect statistics
            with torch.no_grad():
                _, outputs = algorithm.predict(inputs)
            _, preds = torch.max(outputs.data, 1)
            train_loss = lossfunc(outputs, labels)
            # statistics
            running_loss += train_loss.item()
            train_acc = torch.sum(preds == labels.data).cpu().to(torch.float32)
            running_corrects += train_acc
        tr_epoch_loss = running_loss / len(train_dataloder)
        tr_epoch_acc = running_corrects / len(train_dataset)
        print('{} Loss: {:.4f} Acc: {:.4f}'.format(epoch, tr_epoch_loss, tr_epoch_acc))

        with torch.no_grad():
            algorithm.eval()
            running_loss = 0.0
            running_corrects = 0.0
            for step, (inputs, labels) in enumerate(val_dataloder):
                inputs = inputs.to(device)
                labels = labels.to(device)
                _, outputs = algorithm.predict(inputs)
                _, preds = torch.max(outputs.data, 1)
                loss = lossfunc(outputs, labels)
                # statistics
                running_loss += loss.item()
                val_acc = torch.sum(preds == labels.data).cpu().to(torch.float32)
                running_corrects += val_acc
            te_epoch_loss = running_loss / len(val_dataloder)
            te_epoch_acc = running_corrects / len(val_dataset)

        if use_swad:
            swad.update_and_evaluate(swad_algorithm, te_epoch_acc)
            swad_algorithm = swa_utils.AveragedModel(algorithm)  # reset

        filename = r'epoch{}_Loss{:.4f}_Acc{:.4f}_Loss{:.4f}_Acc{:.4f}.pth'.format(
            epoch, tr_epoch_loss, tr_epoch_acc, te_epoch_loss, te_epoch_acc)
        torch.save(algorithm.state_dict(), filename, _use_new_zipfile_serialization=False)
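The test_dataset defined above is never actually used in the loop. As a minimal sketch (the checkpoint path below is a placeholder; substitute whichever .pth file the loop saved), evaluating a saved model on the test split could look like this:

# Hedged sketch: evaluate a saved checkpoint on the test split.
# 'checkpoint.pth' is a placeholder name; use a file saved by the loop above.
test_dataloder = torch.utils.data.DataLoader(test_dataset, batch_size=4, shuffle=False)
state = torch.load('checkpoint.pth', map_location=device)
algorithm.load_state_dict(state)
algorithm.eval()
correct = 0
with torch.no_grad():
    for inputs, labels in test_dataloder:
        inputs, labels = inputs.to(device), labels.to(device)
        _, outputs = algorithm.predict(inputs)
        _, preds = torch.max(outputs, 1)
        correct += torch.sum(preds == labels).item()
print('test acc: {:.4f}'.format(correct / len(test_dataset)))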
Main algorithm
This is the main algorithm part. It uses Empirical Risk Minimization (ERM, Vapnik, 1998). The original DomainBed framework provides many other algorithms, such as IRM, GroupDRO, and RSC, which you can pull in as needed; a hedged IRM sketch is given after the algorithm.py listing below.
————————————————————
algorithm.py
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from model import *
from losses import ProxyLoss, ProxyPLoss
class ERM(torch.nn.Module):
    """Empirical Risk Minimization (ERM)"""

    def __init__(self, input_shape, num_classes):
        super(ERM, self).__init__()
        self.encoder, self.scale, self.pcl_weights = encoder()
        self._initialize_weights(self.encoder)

        self.fea_proj, self.fc_proj = fea_proj()
        nn.init.kaiming_uniform_(self.fc_proj, mode='fan_out', a=math.sqrt(5))

        self.featurizer = ResNet()
        self.classifier = nn.Parameter(torch.FloatTensor(num_classes, 256))
        nn.init.kaiming_uniform_(self.classifier, mode='fan_out', a=math.sqrt(5))

        self.optimizer = torch.optim.Adam([
            {'params': self.featurizer.parameters()},
            {'params': self.encoder.parameters()},
            {'params': self.fea_proj.parameters()},
            {'params': self.fc_proj},
            {'params': self.classifier},
        ], lr=0.002, weight_decay=0.0)

        self.proxycloss = ProxyPLoss(num_classes=num_classes, scale=self.scale)

    def _initialize_weights(self, modules):
        for m in modules:
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                n = m.weight.size(1)
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()

    def update(self, x, y, **kwargs):
        all_x = x
        all_y = y
        rep, pred = self.predict(all_x)

        loss_cls = F.nll_loss(F.log_softmax(pred, dim=1), all_y)
        fc_proj = F.linear(self.classifier, self.fc_proj)
        assert fc_proj.requires_grad == True
        loss_pcl = self.proxycloss(rep, all_y, fc_proj)

        loss = loss_cls + self.pcl_weights * loss_pcl

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return {"loss_cls": loss_cls.item(), "loss_pcl": loss_pcl.item()}

    def predict(self, x):
        x = self.featurizer(x)
        x = self.encoder(x)
        rep = self.fea_proj(x)
        pred = F.linear(x, self.classifier)
        return rep, pred
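As mentioned above, DomainBed offers other algorithms besides ERM. As an illustration only (this is not code from the PCL repository), here is a hedged sketch of the IRMv1 gradient penalty that an IRM-style update would add to the classification loss:

# Hedged sketch of the IRMv1 penalty (Arjovsky et al., 2019); illustration only,
# not part of the PCL repository.
import torch
import torch.nn.functional as F

def irm_penalty(logits, y):
    # Squared gradient of the loss with respect to a dummy scale of 1.0 on the logits.
    scale = torch.ones(1, requires_grad=True, device=logits.device)
    loss = F.cross_entropy(logits * scale, y)
    grad = torch.autograd.grad(loss, [scale], create_graph=True)[0]
    return (grad ** 2).sum()

An IRM-style update would compute this penalty per training environment and add a weighted sum of the penalties to the classification loss.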
Network structure
Next is the main network module: ResNet50 extracts the image features, and fully connected layers then act as the encoder (a hedged example of swapping the backbone follows the model.py listing).
————————————————————
model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models
class Identity(nn.Module):
    """An identity layer"""
    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        return x


class SqueezeLastTwo(nn.Module):
    """A module which squeezes the last two dimensions; an ordinary squeeze can be a problem for batch size 1"""
    def __init__(self):
        super(SqueezeLastTwo, self).__init__()

    def forward(self, x):
        return x.view(x.shape[0], x.shape[1])


class ResNet(torch.nn.Module):
    """ResNet with the softmax chopped off and the batchnorm frozen"""
    def __init__(self):
        super(ResNet, self).__init__()
        # To use another network for feature extraction, change it here,
        # but make sure the input size of the first Linear layer in encoder()
        # below matches the output feature dimension of the new backbone.
        network = torchvision.models.resnet50(pretrained=False)
        # network = resnet50(pretrained=hparams["pretrained"])
        self.network = network
        # adapt number of channels
        # save memory
        # del self.network.fc
        # Replace the backbone's output layer with an identity so it exposes
        # features for the encoder.
        # Tip: the last layer is usually model.fc or model.head.
        self.network.fc = Identity()
        self.dropout = nn.Dropout(0.1)
        self.freeze_bn()

    def forward(self, x):
        """Encode x into a feature vector of size n_outputs."""
        return self.dropout(self.network(x))

    def train(self, mode=True):
        """Override the default train() to freeze the BN parameters"""
        super().train(mode)
        self.freeze_bn()

    def freeze_bn(self):
        for m in self.network.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()
def encoder():
    scale_weights = 12
    pcl_weights = 1
    dropout = nn.Dropout(0.25)
    hidden_size = 512
    out_dim = 256
    # If you swap in a new backbone, remember to change this value
    # to its output feature dimension.
    n_outputs = 2048
    encoder = nn.Sequential(
        nn.Linear(n_outputs, hidden_size),
        nn.BatchNorm1d(hidden_size),
        nn.ReLU(inplace=True),
        dropout,
        nn.Linear(hidden_size, out_dim),
    )
    return encoder, scale_weights, pcl_weights


def fea_proj():
    dropout = nn.Dropout(0.25)
    hidden_size = 256
    out_dim = 256
    fea_proj = nn.Sequential(
        nn.Linear(out_dim, out_dim),
    )
    fc_proj = nn.Parameter(
        torch.FloatTensor(out_dim, out_dim)
    )
    return fea_proj, fc_proj
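Following the comments in ResNet and encoder() above, here is a hedged sketch (assuming torchvision's resnet18, whose feature dimension is 512) of what swapping the backbone would involve:

# Hedged sketch: swap the backbone to resnet18 (illustration only).
# resnet18's final fc layer takes 512-d features, so encoder() would need
# n_outputs = 512 instead of 2048.
import torchvision.models

def build_resnet18_backbone():
    net = torchvision.models.resnet18(pretrained=False)
    net.fc = Identity()   # expose the 512-d features instead of class logits
    return net, 512       # 512 is the value to use as n_outputs in encoder()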
Loss functions
The loss functions described in PCL (a small shape-check example follows the listing).
————————————————————————
losses.py
# coding: utf-8
''' custom loss function '''
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
# # ========================= proxy Contrastive loss ==========================
class ProxyLoss(nn.Module):
    '''Contrastive loss that augments the class predictions with in-batch negative pairs'''

    def __init__(self, scale=1, thres=0.1):
        super(ProxyLoss, self).__init__()
        self.scale = scale
        self.thres = thres

    def forward(self, feature, pred, target):
        feature = F.normalize(feature, p=2, dim=1)  # normalize
        feature = torch.matmul(feature, feature.transpose(1, 0))  # (B, B)
        label_matrix = target.unsqueeze(1) == target.unsqueeze(0)
        feature = feature * ~label_matrix  # get negative matrix
        feature = feature.masked_fill(feature < self.thres, -np.inf)
        pred = torch.cat([pred, feature], dim=1)  # (N, C+N)

        loss = F.nll_loss(F.log_softmax(self.scale * pred, dim=1), target)
        return loss


class ProxyPLoss(nn.Module):
    '''Proxy-based contrastive loss: the positive is each sample's similarity to its class proxy, the negatives are same-batch samples from other classes'''

    def __init__(self, num_classes, scale):
        super(ProxyPLoss, self).__init__()
        self.soft_plus = nn.Softplus()
        self.label = torch.LongTensor([i for i in range(num_classes)]).cuda()
        self.scale = scale

    def forward(self, feature, target, proxy):
        feature = F.normalize(feature, p=2, dim=1)
        pred = F.linear(feature, F.normalize(proxy, p=2, dim=1))  # (N, C)
        label = (self.label.unsqueeze(1) == target.unsqueeze(0))
        pred = torch.masked_select(pred.transpose(1, 0), label)
        pred = pred.unsqueeze(1)

        feature = torch.matmul(feature, feature.transpose(1, 0))  # (N, N)
        label_matrix = target.unsqueeze(1) == target.unsqueeze(0)

        index_label = torch.LongTensor([i for i in range(feature.shape[0])])  # generate index label
        index_matrix = index_label.unsqueeze(1) == index_label.unsqueeze(0)  # get index matrix

        feature = feature * ~label_matrix  # get negative matrix
        feature = feature.masked_fill(feature < 1e-6, -np.inf)

        logits = torch.cat([pred, feature], dim=1)  # (N, 1+N)
        label = torch.zeros(logits.size(0), dtype=torch.long).cuda()
        loss = F.nll_loss(F.log_softmax(self.scale * logits, dim=1), label)

        return loss


class PosAlign(nn.Module):
    '''Positive-pair alignment term: softplus of the log-sum-exp over same-class pair similarities'''

    def __init__(self):
        super(PosAlign, self).__init__()
        self.soft_plus = nn.Softplus()

    def forward(self, feature, target):
        feature = F.normalize(feature, p=2, dim=1)
        feature = torch.matmul(feature, feature.transpose(1, 0))  # (N, N)
        label_matrix = target.unsqueeze(1) == target.unsqueeze(0)

        positive_pair = torch.masked_select(feature, label_matrix)
        # print("positive_pair.shape", positive_pair.shape)

        loss = 1. * self.soft_plus(torch.logsumexp(positive_pair, 0))
        return loss
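To make the expected shapes concrete, here is a hedged smoke test for ProxyPLoss (it needs a GPU, because the class creates its label buffer with .cuda()). The shapes mirror the ERM.update() call above, where rep is (N, 256) and fc_proj is (num_classes, 256):

# Hedged smoke test for ProxyPLoss; shapes follow ERM.update() above.
import torch
from losses import ProxyPLoss

if torch.cuda.is_available():
    num_classes, feat_dim, batch = 4, 256, 8
    criterion = ProxyPLoss(num_classes=num_classes, scale=12)
    rep = torch.randn(batch, feat_dim).cuda()           # projected features (rep)
    target = torch.randint(0, num_classes, (batch,)).cuda()
    proxy = torch.randn(num_classes, feat_dim).cuda()   # class proxies (fc_proj)
    print(criterion(rep, target, proxy).item())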
SWAD tuning
If you want to tune with SWAD, you can grab the swad and swa_utils modules yourself from the GitHub repo.
I'll go through them line by line when I have time...
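If you only want simple weight averaging and do not want to copy the SWAD files, a hedged alternative (plain SWA via PyTorch's built-in torch.optim.swa_utils, assuming algorithm is the ERM instance from main.py) looks like this. Note that this is plain averaging, not the loss-valley interval selection that SWAD performs:

# Hedged alternative: plain SWA weight averaging with PyTorch's built-in
# torch.optim.swa_utils; it only averages weights, it is not SWAD.
from torch.optim.swa_utils import AveragedModel

swa_model = AveragedModel(algorithm)        # wraps the ERM module
# Inside the training loop, after each optimizer step:
#     swa_model.update_parameters(algorithm)
# After training, save the averaged weights:
#     torch.save(swa_model.module.state_dict(), 'swa_averaged.pth')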