Advanced Selection from Tensors in Pytorch


Using torch.index_select, torch.gather and torch.take

In some situations you need to do advanced indexing / selection with PyTorch, e.g. to answer the question: “how can I select elements from Tensor A following the indices laid out in Tensor B?”

In this post we present the three most common methods for such tasks, namely torch.index_select, torch.gather and torch.take. We explain each of them in detail and contrast them with one another.
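To make the contrast concrete before going through each function, here is a minimal sketch on a small, fixed 3×4 tensor (the values and indices are chosen only for illustration). It relies on the documented semantics: torch.index_select picks whole slices along one dimension, torch.gather picks one element per position of an index tensor that has the same number of dimensions as the input, and torch.take treats the input as if it were flattened to 1D.

```python
import torch

# A small, fixed 3x4 tensor (values chosen for illustration):
# [[ 0,  10,  20,  30],
#  [40,  50,  60,  70],
#  [80,  90, 100, 110]]
values = torch.arange(12).reshape(3, 4) * 10

# torch.index_select: pick the same columns (0 and 2) for every row
cols = torch.tensor([0, 2])
selected = torch.index_select(values, 1, cols)  # shape [3, 2]

# torch.gather: pick one (possibly different) column per row;
# the index tensor has the same number of dimensions as `values`
per_row = torch.tensor([[3], [0], [1]])
gathered = torch.gather(values, 1, per_row)  # shape [3, 1]

# torch.take: index into `values` as if it were a flat 1D tensor
flat_idx = torch.tensor([0, 5, 11])
taken = torch.take(values, flat_idx)  # shape [3]

print(selected)  # rows [0, 20], [40, 60], [80, 100]
print(gathered)  # rows [30], [40], [90]
print(taken)     # [0, 50, 110]
```

Note how the output shape follows the index: index_select keeps every row and returns one column per index, while gather and take return a tensor shaped like the index tensor itself.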

Photo by Jerin J on Unsplash

Admittedly, one motivation for this post was that I kept forgetting how and when to use which function, and ended up googling, browsing Stack Overflow and the, in my opinion, relatively brief and not very helpful official documentation. So, as mentioned, we take a deep dive into these functions here: we motivate when to use which, give examples in 2D and 3D, and show the resulting selection graphically.

I hope this post brings clarity about these functions and removes the need for further exploration. Thanks for reading!

And now, without further ado, let’s dive into the functions one after the other. For each, we first start with a 2D example and visualize the resulting selection, and then move to a somewhat more complex example in 3D. Further, we re-implement the executed operation in plain Python, so that you can look at pseudocode as another source of information on what these functions do. Finally, we summarize the functions and their differences in a table.

torch.index_select selects elements along one dimension while keeping the other ones unchanged. That is: keep all elements from all other dimensions, but pick elements in the target dimension following the index tensor. Let’s demonstrate this with a 2D example, in which we select along dimension 1:

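import torch

# Illustrative sizes (assumed; any positive values work)
len_dim_0 = 5
len_dim_1 = 4
num_picks = 2

values = torch.rand((len_dim_0, len_dim_1))
# num_picks random indices into dimension 1
indices = torch.randint(0, len_dim_1, size=(num_picks,))
# [len_dim_0, num_picks]
picked = torch.index_select(values, 1, indices)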

The resulting tensor has shape [len_dim_0, num_picks]: for each element along dimension 0, we have picked the same elements from dimension 1. Let’s visualize this:
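As a second view next to the picture, here is a minimal sketch of what torch.index_select(values, 1, indices) computes in the 2D case, spelled out with plain Python loops, plus a quick check that the same call works unchanged in 3D. The helper name index_select_dim1 and the concrete sizes are hypothetical, chosen only for illustration; the loop version is meant for reading, not for speed.

```python
import torch

def index_select_dim1(values: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    """Hypothetical loop-based equivalent of torch.index_select(values, 1, indices)
    for a 2D tensor."""
    len_dim_0 = values.shape[0]
    num_picks = indices.shape[0]
    picked = torch.empty((len_dim_0, num_picks), dtype=values.dtype)
    for i in range(len_dim_0):       # keep every row (dimension 0) ...
        for j in range(num_picks):   # ... but only the chosen columns
            picked[i, j] = values[i, int(indices[j])]
    return picked

# Illustrative sizes (assumed)
len_dim_0, len_dim_1, num_picks = 5, 4, 2
values = torch.rand((len_dim_0, len_dim_1))
indices = torch.randint(0, len_dim_1, size=(num_picks,))

picked = torch.index_select(values, 1, indices)
picked_loop = index_select_dim1(values, indices)
print(torch.equal(picked, picked_loop))  # True

# In 3D, every dimension except the target one (here 1) is kept as is
values_3d = torch.rand((3, len_dim_1, 5))
picked_3d = torch.index_select(values_3d, 1, indices)  # shape [3, num_picks, 5]
print(picked_3d.shape)
```

The loop makes the rule from above explicit: the outer loop walks over the untouched dimension, and only the inner loop consults the index tensor.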
