Advanced Selection from Tensors in PyTorch

Using torch.index_select, torch.gather and torch.take

In some situations, you’ll need to do some advanced indexing or selection with PyTorch, e.g. to answer the question: “how can I select elements from tensor A following the indices specified in tensor B?”

In this post we’ll present the three most common methods for such tasks, namely torch.index_select, torch.gather and torch.take. We’ll explain each of them in detail and contrast them with one another.

Photo by Jerin J on Unsplash

Admittedly, one motivation for this post was that I kept forgetting how and when to use which function, and ended up googling, browsing Stack Overflow, and reading the official documentation, which in my view is relatively brief and not too helpful. So, as mentioned, we take a deep dive into these functions here: we motivate when to use which, give examples in 2D and 3D, and show the resulting selection graphically.

I hope this post brings clarity about these functions and spares you further searching. Thanks for reading!

And now, without further ado, let’s go through the functions one after the other. For each, we start with a 2D example and visualize the resulting selection, and then move on to a somewhat more complex example in 3D. We also re-implement each operation in simple Python, so that you can look at pseudocode as another source of information on what these functions do. At the end, we summarize the functions and their differences in a table.

torch.index_select selects elements along one dimension while keeping the other dimensions unchanged. That is: keep all elements from all other dimensions, but pick elements in the target dimension following the index tensor. Let’s demonstrate this with a 2D example in which we select along dimension 1:

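import torch

# Assumed example sizes (not given in the post); any positive values work.
len_dim_0 = 3
len_dim_1 = 4
num_picks = 2

values = torch.rand((len_dim_0, len_dim_1))
# Indices into dimension 1, each in [0, len_dim_1)
indices = torch.randint(0, len_dim_1, size=(num_picks,))
# Resulting shape: [len_dim_0, num_picks]
picked = torch.index_select(values, 1, indices)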

The resulting tensor has shape [len_dim_0, num_picks]: for each element along dimension 0, we’ve picked the same elements from dimension 1.
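To make the selection rule concrete, here is a minimal sketch of the same operation in plain Python on nested lists, with no PyTorch required. The helper name index_select_dim1 and the example values are made up for illustration. It is only meant to mirror what torch.index_select(values, 1, indices) does for a 2D input, not how PyTorch implements it.

```python
def index_select_dim1(values, indices):
    """Mimic torch.index_select(values, 1, indices) for a 2D nested list.

    Hypothetical helper for illustration only.
    """
    picked = []
    for row in values:  # keep every entry along dimension 0
        # within each row, pick the entries of dimension 1 listed in indices
        picked.append([row[j] for j in indices])
    return picked


# Illustrative values: shape [3, 4], i.e. len_dim_0 = 3, len_dim_1 = 4
values = [
    [0.0, 0.1, 0.2, 0.3],
    [1.0, 1.1, 1.2, 1.3],
    [2.0, 2.1, 2.2, 2.3],
]
indices = [3, 0]  # num_picks = 2

picked = index_select_dim1(values, indices)
print(picked)  # [[0.3, 0.0], [1.3, 1.0], [2.3, 2.0]]
```

Every row contributes the same two columns (3 and 0, in that order), so the result has shape [3, 2], matching [len_dim_0, num_picks] above.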
