A Practical Guide to Contrastive Learning


Now it’s time for some contrastive learning. To cope with scarce annotations and make full use of the large quantity of unlabelled data, contrastive learning can be used to help the backbone learn useful representations without a specific task. The backbone can then be frozen for a given downstream task, and only a shallow network is trained on a limited annotated dataset to achieve satisfactory results.

The most commonly used contrastive learning approaches include SimCLR, SimSiam, and MoCo (see my previous article on MoCo). Here, we compare SimCLR and SimSiam.

SimCLR computes its loss over positive and negative pairs within each data batch. It uses the NT-Xent loss (which extends the cosine similarity loss over a batch) and needs a large batch size so that each sample sees enough negatives. The SimCLR paper also uses the LARS optimizer to accommodate the large batch size.
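To make the role of in-batch negatives concrete, here is a minimal sketch of the NT-Xent loss. It is written in plain Python for readability rather than as the batched tensor version you would use in training, and the function names and the temperature value of 0.5 are illustrative assumptions, not taken from the SimCLR code.

```python
import math

def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

def nt_xent(views_a, views_b, temperature=0.5):
    """NT-Xent loss for N positive pairs (views_a[i], views_b[i]).

    Every other embedding in the 2N-sized batch acts as a negative.
    """
    z = views_a + views_b              # all 2N embeddings
    n = len(views_a)
    total = 0.0
    for i in range(2 * n):
        j = (i + n) % (2 * n)          # index of the positive partner of i
        # The denominator sums over every k != i (the positive and all negatives).
        denom = sum(math.exp(cosine(z[i], z[k]) / temperature)
                    for k in range(2 * n) if k != i)
        pos = math.exp(cosine(z[i], z[j]) / temperature)
        total += -math.log(pos / denom)
    return total / (2 * n)

# Toy 2-D embeddings: correctly paired views vs. deliberately mismatched views.
aligned = nt_xent([[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.1], [0.1, 1.0]])
shuffled = nt_xent([[1.0, 0.0], [0.0, 1.0]], [[0.1, 1.0], [1.0, 0.1]])
print(aligned < shuffled)
```

The loss is lower when each embedding is closest to its own positive partner. With a small batch there are few terms in the denominator, which is one way to see why SimCLR benefits from large batches.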

SimSiam, however, uses a Siamese architecture, which avoids negative pairs and therefore the need for large batch sizes. The differences between SimSiam and SimCLR are given in the table below.

Comparison between SimCLR and SimSiam. Image by author.
The SimSiam architecture. Image source: https://arxiv.org/pdf/2011.10566

We can see from the figure above that the SimSiam architecture comprises only two parts: the encoder/backbone and the predictor. During training, gradient propagation through one branch of the Siamese pair is stopped, and the cosine similarity is calculated between the predictor output of one view and the (stop-gradient) backbone output of the other view.
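In the notation of the SimSiam paper, with $z_i$ the encoder output and $p_i$ the predictor output for view $i$, this objective can be written as follows. Note that in the paper the encoder also includes a projection MLP, whereas the simplified model in this article uses the flattened backbone output as $z_i$.

```latex
\mathcal{D}(p_1, z_2) = -\frac{p_1}{\lVert p_1 \rVert_2} \cdot \frac{z_2}{\lVert z_2 \rVert_2},
\qquad
\mathcal{L} = \tfrac{1}{2}\,\mathcal{D}\big(p_1, \operatorname{stopgrad}(z_2)\big)
            + \tfrac{1}{2}\,\mathcal{D}\big(p_2, \operatorname{stopgrad}(z_1)\big)
```

The loss is bounded below by $-1$, which is reached when each prediction points in the same direction as the other view's encoder output.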

So, how do we implement this architecture in practice? Continuing from the supervised classification design, we keep the backbone the same and only modify the MLP layer. In the supervised learning architecture, the MLP outputs a 10-element vector indicating the probabilities of the 10 classes. But for SimSiam, the goal is not to perform “classification” but to learn the “representation,” so we need the output to have the same dimension as the backbone output for the loss calculation. The model and the negative_cosine_similarity loss are given below:

import torch.nn as nn


class SimSiam(nn.Module):

    def __init__(self):
        super(SimSiam, self).__init__()

        # Backbone: three stride-2 convolutions (28x28 -> 14x14 -> 7x7 -> 4x4)
        self.backbone = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(128),
        )

        # Predictor: bottleneck MLP whose output matches the flattened backbone output
        self.prediction_mlp = nn.Sequential(
            nn.Linear(128*4*4, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Linear(64, 128*4*4),
        )

    def forward(self, x):
        x = self.backbone(x)

        # Flatten the feature map to one vector per image
        x = x.view(-1, 128 * 4 * 4)
        pred_output = self.prediction_mlp(x)
        return x, pred_output


cos = nn.CosineSimilarity(dim=1, eps=1e-6)

def negative_cosine_similarity_stopgradient(pred, proj):
    # detach() stops gradients from flowing through the target branch
    return -cos(pred, proj.detach()).mean()
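The 128*4*4 input size of the prediction MLP above follows from the three stride-2 convolutions applied to a 28x28 FashionMNIST image. A quick sanity check, using the standard convolution output-size formula with dilation 1 (the helper name `conv_out` is only for illustration):

```python
def conv_out(size, kernel=3, stride=2, padding=1):
    """Output side length of a Conv2d layer (dilation 1) for a square input."""
    return (size + 2 * padding - kernel) // stride + 1

side = 28  # FashionMNIST images are 28x28
sides = []
for _ in range(3):  # three stride-2 conv layers in the backbone
    side = conv_out(side)
    sides.append(side)

flat_dim = 128 * side * side  # 128 output channels in the last conv layer
print(sides, flat_dim)  # [14, 7, 4] 2048
```

If you change the input resolution or the number of convolution layers, both the `view` call and the first `nn.Linear` layer need the new flattened size.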

The pseudo-code for training SimSiam is given in the original paper:

Training pseudo-code for SimSiam. Source: https://arxiv.org/pdf/2011.10566

And we convert it into real training code:

import os

import tqdm

import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.transforms import RandAugment

import wandb

wandb_config = {
    "learning_rate": 0.0001,
    "architecture": "simsiam",
    "dataset": "FashionMNIST",
    "epochs": 100,
    "batch_size": 256,
}

wandb.init(
    # set the wandb project where this run will be logged
    project="simsiam",
    # track hyperparameters and run metadata
    config=wandb_config,
)

# Initialize model and optimizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

simsiam = SimSiam().to(device)

random_augmenter = RandAugment(num_ops=5)

optimizer = optim.SGD(simsiam.parameters(),
                      lr=wandb_config["learning_rate"],
                      momentum=0.9,
                      weight_decay=1e-5,
                      )

# train_dataset is assumed to be defined earlier (FashionMNIST training split, tensors scaled to [0, 1])
train_dataloader = DataLoader(train_dataset, batch_size=wandb_config["batch_size"], shuffle=True)

# Make sure the checkpoint directory exists before saving
os.makedirs("weights", exist_ok=True)

def augment(image):
    # RandAugment expects uint8 images; when called on a whole batch,
    # the sampled ops are shared by every image in that batch
    image = (image * 255).to(dtype=torch.uint8)
    return random_augmenter(image).to(dtype=torch.float32) / 255.0

# Training loop
for epoch in range(wandb_config["epochs"]):
    simsiam.train()

    print(f"Epoch {epoch}")
    for batch_idx, (image, _) in enumerate(tqdm.tqdm(train_dataloader, total=len(train_dataloader))):
        optimizer.zero_grad()

        # Two independently augmented views of the same batch
        aug1, aug2 = augment(image).to(device), augment(image).to(device)

        proj1, pred1 = simsiam(aug1)
        proj2, pred2 = simsiam(aug2)

        # Symmetrized loss: each view's prediction is compared to the other view's output
        loss = (negative_cosine_similarity_stopgradient(pred1, proj2) / 2
                + negative_cosine_similarity_stopgradient(pred2, proj1) / 2)
        loss.backward()
        optimizer.step()

        wandb.log({"training loss": loss.item()})

    if (epoch+1) % 10 == 0:
        torch.save(simsiam.state_dict(), f"weights/simsiam_epoch{epoch+1}.pt")

We trained for 100 epochs as a fair comparison to the limited supervised training; the training loss is shown below. Note: Due to its Siamese design, SimSiam can be very sensitive to hyperparameters such as the learning rate and the MLP hidden layers. The original SimSiam paper provides a detailed configuration for the ResNet50 backbone. For a ViT-based backbone, we recommend reading the MoCo v3 paper, which adopts the SimSiam model in a momentum update scheme.

Training loss for SimSiam. Image by author.

Then, we run the trained SimSiam on the testing set and visualize the representations using UMAP reduction:

import tqdm
import numpy as np

import torch
from torch.utils.data import DataLoader

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

simsiam = SimSiam()

# test_dataset is assumed to be defined earlier (FashionMNIST test split)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)
simsiam.load_state_dict(torch.load("weights/simsiam_epoch100.pt", map_location=device))

simsiam.eval()
simsiam.to(device)

features = []
labels = []
for batch_idx, (image, target) in enumerate(tqdm.tqdm(test_dataloader, total=len(test_dataloader))):

    with torch.no_grad():
        proj, pred = simsiam(image.to(device))

    # Collect the predictor outputs and the ground-truth labels
    features.extend(pred.cpu().numpy().tolist())
    labels.extend(target.cpu().numpy().tolist())

import plotly.express as px
import umap.umap_ as umap

# Reduce to 3 components; only the first two are plotted below
reducer = umap.UMAP(n_components=3, n_neighbors=10, metric="cosine")
projections = reducer.fit_transform(np.array(features))

px.scatter(projections, x=0, y=1,
    color=labels, labels={'color': 'Fashion MNIST Labels'}
)

The UMAP of the SimSiam representation over the testing set. Image by author.

It’s interesting to see that there are two small islands in the reduced-dimension map above: classes 5, 7, 8, and some of 9. If we pull up the FashionMNIST class list, we see that these classes correspond to footwear and accessories: “Sandal,” “Sneaker,” “Bag,” and “Ankle boot.” The large purple cluster corresponds to clothing classes like “T-shirt/top,” “Trouser,” “Pullover,” “Dress,” “Coat,” and “Shirt.” SimSiam thus demonstrates that it learns a meaningful representation in the vision domain.
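For reference, the label-to-name mapping used in the reading above can be written out as a small lookup. The class names follow the FashionMNIST label order; the `ISLAND_LABELS` set and the `group_of` helper are hypothetical names introduced only to mirror the grouping seen in the plot.

```python
# Class names in label order (index = integer label in the dataset).
FASHION_MNIST_CLASSES = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]

# Hypothetical grouping, used only to read the UMAP plot:
# the labels that appear in the two small islands.
ISLAND_LABELS = {5, 7, 8, 9}

def group_of(label):
    """Return which visual group of the UMAP plot a label falls into."""
    return "footwear/bag" if label in ISLAND_LABELS else "clothing"

for label, name in enumerate(FASHION_MNIST_CLASSES):
    print(label, name, group_of(label))
```

Passing the class names instead of the integer labels as the `color` argument of the scatter plot would give a categorical legend, at the cost of losing the continuous color scale shown in the figure.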
