
Implementing Soft Nearest Neighbor Loss in PyTorch


The class neighborhood of a dataset can be learned using the soft nearest neighbor loss

In this article, we discuss how to implement the soft nearest neighbor loss, which we also talked about here.

Representation learning is the task of learning the most salient features in a given dataset by a deep neural network. It is usually an implicit task done in a supervised learning paradigm, and it is an important factor in the success of deep learning (Krizhevsky et al., 2012; He et al., 2016; Simonyan et al., 2014). In other words, representation learning automates the process of feature extraction. With this, we can use the learned representations for downstream tasks such as classification, regression, and synthesis.

Figure 1. Illustration from the SNNL paper (Frosst et al., 2019). By minimizing the soft nearest neighbor loss, the distances among class-similar data points (as indicated by their color) are minimized while the distances among class-different data points are maximized.

We can also influence how the learned representations are formed to cater to specific use cases. In the case of classification, the representations are primed to have data points from the same class flock together, while for generation (e.g. in GANs), the representations are primed to have points of real data flock with the synthesized ones.

In the same sense, we have enjoyed using principal components analysis (PCA) to encode features for downstream tasks. However, we do not have any class or label information in PCA-encoded representations, hence the performance on downstream tasks may be further improved. We can improve the encoded representations by approximating the class or label information in them, i.e. by learning the neighborhood structure of the dataset: which features are clustered together. Such clusters would imply that the features belong to the same class, as per the clustering assumption in the semi-supervised learning literature (Chapelle et al., 2009).

To integrate the neighborhood structure into the representations, manifold learning techniques have been introduced, such as locally linear embeddings or LLE (Roweis & Saul, 2000), neighborhood components analysis or NCA (Hinton et al., 2004), and t-stochastic neighbor embedding or t-SNE (Maaten & Hinton, 2008).

However, the aforementioned manifold learning techniques have their own drawbacks. For instance, both LLE and NCA encode linear embeddings instead of nonlinear embeddings. Meanwhile, t-SNE embeddings result in different structures depending on the hyperparameters used.

To avoid such drawbacks, we can use an improved NCA algorithm: the soft nearest neighbor loss or SNNL (Salakhutdinov & Hinton, 2007; Frosst et al., 2019). The SNNL improves on the NCA algorithm by introducing nonlinearity, and it is computed for each hidden layer of a neural network instead of solely on the last encoding layer. This loss function is used to optimize the entanglement of points in a dataset.

In this context, entanglement is defined as how close class-similar data points are to one another compared to class-different data points. A low entanglement means that class-similar data points are much closer to each other than to class-different data points (see Figure 1). Having such a set of data points renders downstream tasks much easier to perform, with an even better performance. Frosst et al. (2019) expanded the SNNL objective by introducing a temperature factor T, giving us the following as the final loss function,

Figure 2. The soft nearest neighbor loss function. Figure by the author.
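In text form, the loss in Figure 2 for a batch of b samples (following the formulation of Frosst et al., 2019) reads:

\ell_{sn}(x, y, T) = -\frac{1}{b} \sum_{i=1}^{b} \log \left( \frac{\sum_{j \neq i,\; y_i = y_j} e^{-d(x_i, x_j) / T}}{\sum_{k \neq i} e^{-d(x_i, x_k) / T}} \right)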

where d is a distance metric on either raw input features or hidden layer representations of a neural network, and T is the temperature factor that’s directly proportional to the distances amongst data points in a hidden layer. For this implementation, we use the cosine distance as our distance metric for more stable computations.

Figure 3. The cosine distance formula. Figure by the author.
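In text form, the cosine distance between two vectors x_i and x_j is one minus their cosine similarity:

d(x_i, x_j) = 1 - \frac{x_i \cdot x_j}{\lVert x_i \rVert_2 \, \lVert x_j \rVert_2}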

The aim of this article is to help readers understand and implement the soft nearest neighbor loss, so we will dissect the loss function in order to understand it better.

Distance Metric

The first thing we should compute is the distances among data points, which may be either the raw input features or the hidden layer representations of the network.

Figure 4. The first step in computing the SNNL is to compute the distance metric for the input data points. Figure by the author.

For our implementation, we use the cosine distance metric (Figure 3) for more stable computations. For the time being, let us ignore the denoted subscripts ij and ik in the figure above, and let us just focus on computing the cosine distance among our input data points. We accomplish this with the following PyTorch code:

normalized_a = torch.nn.functional.normalize(features, dim=1, p=2)
normalized_b = torch.nn.functional.normalize(features, dim=1, p=2)
normalized_b = torch.conj(normalized_b).T
product = torch.matmul(normalized_a, normalized_b)
distance_matrix = torch.sub(torch.tensor(1.0), product)

In the code snippet above, we first normalize the input features in lines 1 and 2 using the Euclidean norm. Then in line 3, we take the conjugate transpose of the second set of normalized input features; we compute the conjugate transpose to account for complex vectors. In lines 4 and 5, we compute the cosine similarity and the cosine distance of the input features.

Concretely, consider the following set of features,

tensor([[ 1.0999, -0.9438,  0.7996, -0.4247],
        [ 1.2150, -0.2953,  0.0417, -1.2913],
        [ 1.3218,  0.4214, -0.1541,  0.0961],
        [-0.7253,  1.1685, -0.1070,  1.3683]])

Using the distance metric we defined above, we get the following distance matrix,

tensor([[ 0.0000e+00,  2.8502e-01,  6.2687e-01,  1.7732e+00],
        [ 2.8502e-01,  0.0000e+00,  4.6293e-01,  1.8581e+00],
        [ 6.2687e-01,  4.6293e-01, -1.1921e-07,  1.1171e+00],
        [ 1.7732e+00,  1.8581e+00,  1.1171e+00, -1.1921e-07]])
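The tiny negative values on the diagonal (-1.1921e-07) are floating-point error; the distance of each point to itself is zero. As a quick sanity check, we can verify one off-diagonal entry by hand, reusing the features tensor above:

a, b = features[0], features[1]
# Cosine distance between the first two feature vectors.
print(1.0 - torch.dot(a, b) / (a.norm() * b.norm()))  # tensor(0.2850)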

Sampling Probability

We can now compute the matrix that represents the probability of picking each feature given its pairwise distances to all other features. This is simply the probability of picking point i based on the distances between point i and points j or k.

Figure 5. The second step is to compute the sampling probability of picking points based on their distances. Figure by the author.
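In text form, the probability of picking point j as the neighbor of point i (before any label information comes in) is:

p_{ij} = \frac{e^{-d(x_i, x_j) / T}}{\sum_{k \neq i} e^{-d(x_i, x_k) / T}}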

We can compute this with the following code:

pairwise_distance_matrix = torch.exp(
    -(distance_matrix / temperature)
) - torch.eye(features.shape[0]).to(features.device)

The code first calculates the exponential of the negative of the distance matrix divided by the temperature factor, scaling the values to positive values. The temperature factor dictates how much importance is given to the distances between pairs of points; for instance, at low temperatures, the loss is dominated by small distances, while the actual distances between widely separated representations become less relevant.
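A small illustration of this effect (the distances here are made up for the example):

d = torch.tensor([0.1, 1.0, 2.0])
print(torch.exp(-d / 1.0))  # tensor([0.9048, 0.3679, 0.1353])
print(torch.exp(-d / 0.1))  # tensor([3.6788e-01, 4.5400e-05, 2.0612e-09])

At T = 0.1, only the smallest distance retains a non-negligible weight.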

Prior to the subtraction of torch.eye(features.shape[0]) (i.e. the identity matrix), the tensor was as follows,

tensor([[1.0000, 0.7520, 0.5343, 0.1698],
        [0.7520, 1.0000, 0.6294, 0.1560],
        [0.5343, 0.6294, 1.0000, 0.3272],
        [0.1698, 0.1560, 0.3272, 1.0000]])

We subtract the identity matrix from this matrix to remove all self-similarity terms (i.e. the distance or similarity of each point to itself).
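After the subtraction, the diagonal entries become (approximately) zero while the off-diagonal entries stay the same:

tensor([[0.0000, 0.7520, 0.5343, 0.1698],
        [0.7520, 0.0000, 0.6294, 0.1560],
        [0.5343, 0.6294, 0.0000, 0.3272],
        [0.1698, 0.1560, 0.3272, 0.0000]])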

Next, we can compute the sampling probability for each pair of data points with the following code:

pick_probability = pairwise_distance_matrix / (
    torch.sum(pairwise_distance_matrix, 1).view(-1, 1)
    + stability_epsilon
)
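Each row of pick_probability now forms a probability distribution over the other points, so a quick check is that the rows sum to roughly one (the stability_epsilon term makes the sums slightly smaller):

print(pick_probability.sum(dim=1))  # each row sums to approximately 1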

Masked Sampling Probability

So far, the sampling probability we have computed does not contain any label information. We integrate the label information into the sampling probability by masking it with the dataset labels.

Figure 6. We use the label information to isolate the probabilities for points belonging to the same class. Figure by the author.

First, we have to derive a pairwise matrix out of the label vector:

masking_matrix = torch.squeeze(
    torch.eq(labels, labels.unsqueeze(1)).float()
)
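For instance, with a hypothetical label vector for our four data points where the first two points share one class and the last two share another, the masking matrix comes out as:

labels = torch.tensor([0, 0, 1, 1])
masking_matrix = torch.squeeze(
    torch.eq(labels, labels.unsqueeze(1)).float()
)
print(masking_matrix)
# tensor([[1., 1., 0., 0.],
#         [1., 1., 0., 0.],
#         [0., 0., 1., 1.],
#         [0., 0., 1., 1.]])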

We apply the masking matrix to use the label information to isolate the probabilities for points that belong to the same class:

masked_pick_probability = pick_probability * masking_matrix

Next, we compute the total probability of sampling a particular feature by computing the sum of the masked sampling probabilities per row,

summed_masked_pick_probability = torch.sum(masked_pick_probability, dim=1)

Finally, we can compute the logarithm of the summed sampling probabilities (for computational convenience, with an additional stability variable), and take the mean to act as the soft nearest neighbor loss for the network,

snnl = torch.mean(
    -torch.log(summed_masked_pick_probability + stability_epsilon)
)
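Putting the pieces so far together, here is a minimal, self-contained sketch of the loss as a standalone function. The function name and default values are illustrative rather than taken from the original code, and it assumes real-valued features (so it skips the conjugate-transpose step):

import torch

def soft_nearest_neighbor_loss(
    features: torch.Tensor,
    labels: torch.Tensor,
    temperature: float = 1.0,
    stability_epsilon: float = 1e-5,
) -> torch.Tensor:
    # Pairwise cosine distances among the normalized features.
    normalized = torch.nn.functional.normalize(features, dim=1, p=2)
    distance_matrix = 1.0 - torch.matmul(normalized, normalized.T)

    # Temperature-scaled similarities, with self-similarity terms removed.
    pairwise = torch.exp(-distance_matrix / temperature) - torch.eye(
        features.shape[0], device=features.device
    )

    # Probability of picking each point as a neighbor, normalized per row.
    pick_probability = pairwise / (
        torch.sum(pairwise, dim=1, keepdim=True) + stability_epsilon
    )

    # Keep only the probabilities of same-class neighbors.
    masking_matrix = torch.eq(labels.unsqueeze(0), labels.unsqueeze(1)).float()
    summed = torch.sum(pick_probability * masking_matrix, dim=1)

    # Mean negative log probability of picking a same-class neighbor.
    return torch.mean(-torch.log(stability_epsilon + summed))

# Example usage on random data:
loss = soft_nearest_neighbor_loss(
    torch.randn(4, 8), torch.tensor([0, 0, 1, 1]), temperature=0.5
)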

We can now string these components together in a forward pass function to compute the soft nearest neighbor loss across all layers of a deep neural network,

def forward(
    self,
    model: torch.nn.Module,
    features: torch.Tensor,
    labels: torch.Tensor,
    outputs: torch.Tensor,
    epoch: int,
) -> Tuple:
    if self.use_annealing:
        self.temperature = 1.0 / ((1.0 + epoch) ** 0.55)

    primary_loss = self.primary_criterion(
        outputs, features if self.unsupervised else labels
    )

    activations = self.compute_activations(model=model, features=features)

    layers_snnl = []
    for key, value in activations.items():
        value = value[:, : self.code_units]
        distance_matrix = self.pairwise_cosine_distance(features=value)
        pairwise_distance_matrix = self.normalize_distance_matrix(
            features=value, distance_matrix=distance_matrix
        )
        pick_probability = self.compute_sampling_probability(
            pairwise_distance_matrix
        )
        summed_masked_pick_probability = self.mask_sampling_probability(
            labels, pick_probability
        )
        snnl = torch.mean(
            -torch.log(self.stability_epsilon + summed_masked_pick_probability)
        )
        layers_snnl.append(snnl)

    snn_loss = torch.stack(layers_snnl).sum()

    train_loss = torch.add(primary_loss, torch.mul(self.factor, snn_loss))

    return train_loss, primary_loss, snn_loss

Visualizing Disentangled Representations

We trained an autoencoder with the soft nearest neighbor loss, and visualized its learned disentangled representations. The autoencoder had (x-500-500-2000-d-2000-500-500-x) units, and was trained on a small labelled subset of the MNIST, Fashion-MNIST, and EMNIST-Balanced datasets. This is to simulate the scarcity of labelled examples, since autoencoders are supposed to be unsupervised models.
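For reference, a minimal PyTorch sketch of this architecture could look as follows; the text only specifies the layer sizes, so the choice of ReLU activations here is our assumption:

import torch

def build_autoencoder(input_dim: int, code_dim: int) -> torch.nn.Module:
    # (x-500-500-2000-d-2000-500-500-x) units, where x = input_dim
    # and d = code_dim (the latent code / bottleneck dimensionality).
    layer_sizes = [input_dim, 500, 500, 2000, code_dim, 2000, 500, 500, input_dim]
    layers = []
    for in_features, out_features in zip(layer_sizes[:-1], layer_sizes[1:]):
        layers.append(torch.nn.Linear(in_features, out_features))
        layers.append(torch.nn.ReLU())
    layers.pop()  # no activation on the reconstruction layer
    return torch.nn.Sequential(*layers)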

Figure 7. 3D visualization comparing the original representation and the disentangled latent representation of the three datasets. To achieve this visualization, the representations were encoded using t-SNE with perplexity = 50 and learning rate = 10, optimized for 5,000 iterations. Figure by the author.

For the EMNIST-Balanced dataset, we only visualized 10 arbitrarily chosen clusters for an easier and cleaner visualization. We can see in the figure above that the latent code representation became more clustering-friendly, with well-defined clusters, as indicated by the cluster dispersion, and correct cluster assignments, as indicated by the cluster colors.

Closing Remarks

In this article, we dissected the soft nearest neighbor loss function and showed how to implement it in PyTorch.

The soft nearest neighbor loss was first introduced by Salakhutdinov & Hinton (2007), where it was used to compute the loss on the latent code (bottleneck) representation of an autoencoder, and then the said representation was used for a downstream kNN classification task.

Frosst, Papernot, & Hinton (2019) then expanded the soft nearest neighbor loss by introducing a temperature factor and by computing the loss across all layers of a neural network.

Finally, we employed an annealing temperature factor for the soft nearest neighbor loss to further improve the learned disentangled representations of a network, and also to speed up the disentanglement process (Agarap & Azcarraga, 2020).

The full code implementation is available on GitLab.

References

  • Agarap, Abien Fred, and Arnulfo P. Azcarraga. “Improving k-means clustering performance with disentangled internal representations.” 2020 International Joint Conference on Neural Networks (IJCNN). IEEE, 2020.
  • Chapelle, Olivier, Bernhard Scholkopf, and Alexander Zien. “Semi-supervised learning (Chapelle, O. et al., eds.; 2006) [book reviews].” IEEE Transactions on Neural Networks 20.3 (2009): 542–542.
  • Frosst, Nicholas, Nicolas Papernot, and Geoffrey Hinton. “Analyzing and improving representations with the soft nearest neighbor loss.” International conference on machine learning. PMLR, 2019.
  • Goldberger, Jacob, et al. “Neighbourhood components analysis.” Advances in neural information processing systems. 2005.
  • He, Kaiming, et al. “Deep residual learning for image recognition.” Proceedings of the IEEE conference on computer vision and pattern recognition. 2016.
  • Hinton, G., et al. “Neighbourhood components analysis.” Proc. NIPS. 2004.
  • Krizhevsky, Alex, Ilya Sutskever, and Geoffrey E. Hinton. “Imagenet classification with deep convolutional neural networks.” Advances in neural information processing systems 25 (2012).
  • Roweis, Sam T., and Lawrence K. Saul. “Nonlinear dimensionality reduction by locally linear embedding.” Science 290.5500 (2000): 2323–2326.
  • Salakhutdinov, Ruslan, and Geoff Hinton. “Learning a nonlinear embedding by preserving class neighbourhood structure.” Artificial Intelligence and Statistics. 2007.
  • Simonyan, Karen, and Andrew Zisserman. “Very deep convolutional networks for large-scale image recognition.” arXiv preprint arXiv:1409.1556 (2014).
  • Van der Maaten, Laurens, and Geoffrey Hinton. “Visualizing data using t-SNE.” Journal of machine learning research 9.11 (2008).
