SimSiam-pytorch
A simple PyTorch implementation of "Exploring Simple Siamese Representation Learning" (SimSiam), a self-supervised learning approach developed by the Facebook AI Research (FAIR) group. Unlike earlier Siamese methods, SimSiam needs neither negative samples (SimCLR), online clustering (SwAV), nor a momentum encoder (BYOL). It relies on a stop-gradient operation to avoid collapsing solutions.
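To make the idea concrete, here is a minimal sketch of a SimSiam-style model and its symmetrized loss, written from the paper's description rather than from this repository's code. The names `SimSiamSketch`, `negative_cosine_similarity` and `simsiam_loss` are hypothetical, and the default sizes (2048-d projection MLP, 512-d predictor bottleneck) are the paper's ImageNet settings. Each view passes through a shared backbone and projector to give `z`, a predictor maps `z` to `p`, and the loss compares `p` from one view with the detached (stop-gradient) `z` of the other.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def negative_cosine_similarity(p: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
    # D(p, z) = -cos(p, stopgrad(z)); detach() blocks gradients through z
    return -F.cosine_similarity(p, z.detach(), dim=-1).mean()


class SimSiamSketch(nn.Module):
    """Hypothetical SimSiam-style module: backbone -> projector (z) -> predictor (p)."""

    def __init__(self, backbone: nn.Module, feat_dim: int,
                 proj_dim: int = 2048, pred_hidden: int = 512):
        super().__init__()
        self.backbone = backbone  # any encoder that outputs (batch, feat_dim) features
        # Projection MLP: 3 layers, BN after every layer, no ReLU on the output
        self.projector = nn.Sequential(
            nn.Linear(feat_dim, proj_dim, bias=False), nn.BatchNorm1d(proj_dim), nn.ReLU(inplace=True),
            nn.Linear(proj_dim, proj_dim, bias=False), nn.BatchNorm1d(proj_dim), nn.ReLU(inplace=True),
            nn.Linear(proj_dim, proj_dim, bias=False), nn.BatchNorm1d(proj_dim, affine=False),
        )
        # Prediction MLP: bottleneck proj_dim -> pred_hidden -> proj_dim, no BN on the output
        self.predictor = nn.Sequential(
            nn.Linear(proj_dim, pred_hidden, bias=False), nn.BatchNorm1d(pred_hidden), nn.ReLU(inplace=True),
            nn.Linear(pred_hidden, proj_dim),
        )

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> dict:
        # Both views share the same backbone and projector weights
        z1 = self.projector(self.backbone(x1))
        z2 = self.projector(self.backbone(x2))
        p1 = self.predictor(z1)
        p2 = self.predictor(z2)
        return {"p1": p1, "p2": p2, "z1": z1, "z2": z2}


def simsiam_loss(p1, p2, z1, z2) -> torch.Tensor:
    # Symmetrized loss: L = D(p1, z2) / 2 + D(p2, z1) / 2
    return 0.5 * negative_cosine_similarity(p1, z2) + 0.5 * negative_cosine_similarity(p2, z1)
```

With torchvision's `resnet18`, for instance, you could replace its `fc` layer with `nn.Identity()` and pass `feat_dim=512`. Because the loss is a mean of negative cosine similarities, it lies in [-1, 1], reaching -1 when every prediction is perfectly aligned with its target.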
Usage
In this example, the SimSiam model uses the hyper-parameters specified in the original paper; feel free to experiment with others. The Adam optimizer below is for illustration only, since the paper itself pre-trains with SGD:
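```python
import torch
from torchvision.models import resnet18

# SimSiam and NegativeCosineSimilarity are defined in this repository;
# import them from wherever they live in your checkout.

model = resnet18()
learner = SimSiam(model)
opt = torch.optim.Adam(learner.parameters(), lr=0.001)
criterion = NegativeCosineSimilarity()


def sample_unlabelled_images():
    # Random stand-in for a batch of 20 unlabelled 256x256 RGB images
    return torch.randn(20, 3, 256, 256)


for step in range(100):
    images1 = sample_unlabelled_images()
    # Stand-in for a second view; real training uses two random augmentations
    images2 = images1 * 0.9
    # The learner returns a dict of predictions (p1, p2) and projections (z1, z2)
    p1, p2, z1, z2 = learner(images1, images2).values()
    loss = criterion(p1, p2, z1, z2)
    opt.zero_grad()
    loss.backward()
    opt.step()
    print(step + 1, loss.item())
```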
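The usage example trains with Adam (lr = 0.001). If you want to follow the paper's pre-training recipe instead, a sketch might look like the following: SGD with momentum 0.9, weight decay 1e-4, and a learning rate of 0.05 × batch size / 256 (the paper's default batch size is 512) with cosine decay. The helper name `make_paper_optimizer` is hypothetical, and stepping the cosine schedule once per epoch is an assumption of this sketch.

```python
import torch
import torch.nn as nn


def make_paper_optimizer(model: nn.Module, batch_size: int, epochs: int):
    # Linear scaling rule from the paper: lr = base_lr * batch_size / 256, base_lr = 0.05
    lr = 0.05 * batch_size / 256
    opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=1e-4)
    # Cosine decay to zero over the whole run (call sched.step() once per epoch)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
    return opt, sched
```

For example, `opt, sched = make_paper_optimizer(learner, batch_size=512, epochs=100)` starts at a learning rate of 0.1 and decays it to zero by the final epoch.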
To do
- Build and test the original architecture
- Add a description of each component of the architecture
- Build the model with PyTorch Lightning
Citation
@inproceedings{chen2021exploring,
title={Exploring simple siamese representation learning},
author={Chen, Xinlei and He, Kaiming},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={15750--15758},
year={2021}
}