A Practical Guide to Contrastive Learning | by Mengliu Zhao | Jul, 2024

Now it's time for some contrastive learning. Labelled data is often scarce, while unlabelled data is plentiful. Contrastive learning addresses both by letting the backbone learn data representations without a specific task. For a given downstream task, the backbone can then be frozen and only a shallow network trained on a limited annotated dataset to achieve satisfactory results.

The most commonly used contrastive learning approaches include SimCLR, SimSiam, and MOCO (see my previous article on MOCO). Here, we compare SimCLR and SimSiam.

SimCLR computes its loss over positive and negative pairs within each data batch, and every other sample in the batch serves as a negative. It therefore relies on three things:
- many negatives;
- the NT-Xent loss, which extends the cosine similarity loss into a temperature-scaled cross-entropy over the whole batch;
- a large batch size.

SimCLR also uses the LARS optimizer to train stably with such large batches.

SimSiam, in contrast, uses a Siamese architecture that avoids negative pairs altogether and therefore does not need large batch sizes. The differences between SimSiam and SimCLR are given in the table below.

Comparison between SimCLR and SimSiam. Image by author.

The SimSiam architecture. Image source: https://arxiv.org/pdf/2011.10566

We can see from the figure above that the SimSiam architecture contains only two parts: the encoder/backbone and the predictor. Two augmented views of the same image pass through the shared encoder. During training, the gradient is stopped on one branch. The cosine similarity is then computed between the predictor output of one view and the encoder output of the other view.

So, how do we implement this architecture in practice? Continuing from the supervised classification design, we keep the backbone the same and only modify the MLP layer. In the supervised learning architecture, the MLP outputs a 10-element vector indicating the probabilities of the 10 classes. For SimSiam, however, the purpose is not to perform "classification" but to learn the "representation." We therefore need the output to have the same dimension as the backbone output for the loss calculation.
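To make the difference between the two objectives concrete, here is a minimal NumPy sketch. It computes SimCLR's NT-Xent loss and SimSiam's symmetric negative cosine loss on toy embeddings. It illustrates the math under my own assumptions and is not the reference implementation of either paper. The assumptions are the function names, a temperature of 0.5, random 16-dimensional embeddings, and treating the predictor as the identity in the SimSiam call. NumPy has no autograd, so the stop-gradient appears only as a comment.

```python
import numpy as np


def l2_normalize(x: np.ndarray) -> np.ndarray:
    # Row-wise L2 normalisation so that dot products become cosine similarities
    return x / np.linalg.norm(x, axis=1, keepdims=True)


def nt_xent(z1: np.ndarray, z2: np.ndarray, temperature: float = 0.5) -> float:
    """SimCLR-style NT-Xent loss for N positive pairs (z1[i], z2[i]).

    Every other sample in the 2N-sample batch acts as a negative.
    """
    n = z1.shape[0]
    z = l2_normalize(np.concatenate([z1, z2], axis=0))   # (2N, D)
    sim = z @ z.T / temperature                          # (2N, 2N) scaled cosine similarities
    np.fill_diagonal(sim, -np.inf)                       # a sample is never compared with itself
    # Row i's positive partner: i + N for the first view, i - N for the second
    positives = np.concatenate([np.arange(n, 2 * n), np.arange(n)])
    # Numerically stable log-softmax over each row
    row_max = sim.max(axis=1, keepdims=True)
    log_prob = sim - row_max - np.log(np.exp(sim - row_max).sum(axis=1, keepdims=True))
    return float(-log_prob[np.arange(2 * n), positives].mean())


def simsiam_loss(p1, z1, p2, z2) -> float:
    """SimSiam's symmetric negative cosine similarity; no negatives are involved.

    In PyTorch, z1 and z2 would be detached (stop-gradient); NumPy has no autograd.
    """
    def neg_cos(p, z):
        return float(-(l2_normalize(p) * l2_normalize(z)).sum(axis=1).mean())
    return 0.5 * neg_cos(p1, z2) + 0.5 * neg_cos(p2, z1)


rng = np.random.default_rng(0)
view1 = rng.normal(size=(8, 16))
view2 = view1 + 0.01 * rng.normal(size=view1.shape)  # a near-identical second view
unrelated = rng.normal(size=view1.shape)             # a "view" unrelated to view1

print("NT-Xent, aligned views:  ", nt_xent(view1, view2))
print("NT-Xent, unrelated views:", nt_xent(view1, unrelated))
# Toy assumption: predictor = identity, so p1 = z1 = view1 and p2 = z2 = view2
print("SimSiam, aligned views:  ", simsiam_loss(view1, view1, view2, view2))
```

NT-Xent needs the other 2N - 2 samples in the batch as negatives for every anchor, which is why SimCLR favours large batches. The SimSiam loss only compares the two views of the same image, reaching its minimum of -1 when they point in the same direction.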
The SimSiam model and the negative cosine similarity loss (with stop-gradient) are given below:

```python
import torch.nn as nn


class SimSiam(nn.Module):
    def __init__(self):
        super().__init__()
        # Same convolutional backbone as the supervised classifier
        self.backbone = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(128),
        )
        # Predictor MLP: maps the flattened backbone output back to the same dimension
        self.prediction_mlp = nn.Sequential(
            nn.Linear(128 * 4 * 4, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Linear(64, 128 * 4 * 4),
        )

    def forward(self, x):
        # A 28x28 input shrinks to 14 -> 7 -> 4 after three stride-2 convolutions
        x = self.backbone(x)
        x = x.view(-1, 128 * 4 * 4)
        pred_output = self.prediction_mlp(x)
        return x, pred_output


cos = nn.CosineSimilarity(dim=1, eps=1e-6)


def negative_cosine_similarity_stopgradient(pred, proj):
    # detach() implements the stop-gradient on the target branch
    return -cos(pred, proj.detach()).mean()
```

The pseudo-code for training SimSiam is given in the original paper:

Training pseudo-code for SimSiam. Source: https://arxiv.org/pdf/2011.10566

We convert it into real training code:

```python
import os

import tqdm
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.transforms import RandAugment
import wandb

wandb_config = {
    "learning_rate": 0.0001,
    "architecture": "simsiam",
    "dataset": "FashionMNIST",
    "epochs": 100,
    "batch_size": 256,
}

wandb.init(
    # set the wandb project where this run will be logged
    project="simsiam",
    # track hyperparameters and run metadata
    config=wandb_config,
)

# Initialize model and optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
simsiam = SimSiam().to(device)
random_augmenter = RandAugment(num_ops=5)
optimizer = optim.SGD(
    simsiam.parameters(),
    lr=wandb_config["learning_rate"],
    momentum=0.9,
    weight_decay=1e-5,
)
# train_dataset: the FashionMNIST training set from the supervised section
train_dataloader = DataLoader(train_dataset, batch_size=wandb_config["batch_size"], shuffle=True)
os.makedirs("weights", exist_ok=True)

# Training loop
for epoch in range(wandb_config["epochs"]):
    simsiam.train()
    print(f"Epoch {epoch}")
    train_loss = 0
    for batch_idx, (image, _) in enumerate(tqdm.tqdm(train_dataloader, total=len(train_dataloader))):
        optimizer.zero_grad()
        # Two independent augmentations of the same batch; RandAugment expects uint8 tensors
        aug1 = random_augmenter((image * 255).to(dtype=torch.uint8)).to(dtype=torch.float32) / 255.0
        aug2 = random_augmenter((image * 255).to(dtype=torch.uint8)).to(dtype=torch.float32) / 255.0
        proj1, pred1 = simsiam(aug1.to(device))
        proj2, pred2 = simsiam(aug2.to(device))
        # Symmetric loss: each view's prediction vs. the other view's stop-gradient backbone output
        loss = negative_cosine_similarity_stopgradient(pred1, proj2) / 2 + negative_cosine_similarity_stopgradient(pred2, proj1) / 2
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        wandb.log({"training loss": loss.item()})
    print(f"Average loss: {train_loss / len(train_dataloader):.4f}")
    if (epoch + 1) % 10 == 0:
        torch.save(simsiam.state_dict(), f"weights/simsiam_epoch{epoch+1}.pt")
```

We trained for 100 epochs as a fair comparison to the limited supervised training; the training loss is shown below.

Note: due to its Siamese design, SimSiam can be very sensitive to hyperparameters such as the learning rate and the MLP hidden layers. The original SimSiam paper provides a detailed configuration for the ResNet50 backbone. For a ViT-based backbone, we recommend reading the MOCO v3 paper, which adopts the SimSiam model in a momentum update scheme.

Training loss for SimSiam.
Image by author.

Then, we run the trained SimSiam on the test set and visualize the representations using UMAP reduction:

```python
import tqdm
import numpy as np
import torch
from torch.utils.data import DataLoader
import plotly.express as px
import umap.umap_ as umap

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
simsiam = SimSiam()
# test_dataset: the FashionMNIST test set from the supervised section
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)
simsiam.load_state_dict(torch.load("weights/simsiam_epoch100.pt", map_location=device))
simsiam.eval()
simsiam.to(device)

features = []
labels = []
for batch_idx, (image, target) in enumerate(tqdm.tqdm(test_dataloader, total=len(test_dataloader))):
    with torch.no_grad():
        proj, pred = simsiam(image.to(device))
    features.extend(pred.cpu().numpy().tolist())
    labels.extend(target.cpu().numpy().tolist())

# Reduce the features with UMAP, using the cosine metric like the training loss
reducer = umap.UMAP(n_components=3, n_neighbors=10, metric="cosine")
projections = reducer.fit_transform(np.array(features))
# Plot the first two UMAP components, coloured by ground-truth label
px.scatter(projections, x=0, y=1, color=labels, labels={"color": "Fashion MNIST Labels"}).show()
```

The UMAP of the SimSiam representation over the testing set. Image by author.

Interestingly, the reduced-dimension map above contains two small islands, made up of classes 5, 7, 8, and some of 9. Looking up the FashionMNIST class list, these classes are "Sandal," "Sneaker," "Bag," and "Ankle boot," that is, footwear plus bags. The big purple cluster corresponds to clothing classes: "T-shirt/top," "Trouser," "Pullover," "Dress," "Coat," and "Shirt." SimSiam thus learns a meaningful representation in the vision domain without using any labels during training.
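The interpretation above depends on mapping label indices back to FashionMNIST class names. A small helper makes that lookup explicit. The names `FASHION_MNIST_CLASSES` and `describe_cluster` are my own (hypothetical); the class order matches torchvision's `FashionMNIST.classes`. Pass it the labels of the points inside one island of the UMAP plot, and it returns the class names ranked by count.

```python
from collections import Counter

# FashionMNIST class names, indexed by label (0-9)
FASHION_MNIST_CLASSES = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]


def describe_cluster(cluster_labels, top_k=None):
    """Return (class name, count) pairs for one cluster, most frequent first."""
    counts = Counter(int(label) for label in cluster_labels)
    return [(FASHION_MNIST_CLASSES[idx], n) for idx, n in counts.most_common(top_k)]


# Toy labels standing in for the points inside one island of the plot
print(describe_cluster([5, 7, 7, 8, 9, 9, 9]))
# [('Ankle boot', 3), ('Sneaker', 2), ('Sandal', 1), ('Bag', 1)]
```

With the arrays from the evaluation code, `describe_cluster(np.array(labels)[mask])` would summarise one island. Here `mask` is a boolean selection over `projections` that you define yourself, for example by thresholding a UMAP coordinate.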
