Source code for torchpack.nn.functional.index

import torch

__all__ = ['batched_index_select']


def batched_index_select(inputs: torch.Tensor, indices: torch.Tensor,
                         dim: int) -> torch.Tensor:
    """Gather entries of ``inputs`` along ``dim`` with per-batch indices.

    Dimension 0 of ``inputs`` is treated as the batch dimension. ``indices``
    is viewed so that its batch dimension lines up with dimension 0 of
    ``inputs`` and its remaining entries lie along ``dim``; it is then
    broadcast (via ``expand``) over every other dimension before being passed
    to :func:`torch.gather`.

    Args:
        inputs: Source tensor whose dimension 0 is the batch dimension.
        indices: Integer tensor whose number of elements is a multiple of
            ``inputs.shape[0]``; typically of shape ``(B, K)``.
        dim: Dimension of ``inputs`` to select along. Negative values count
            from the last dimension, as elsewhere in PyTorch. When ``dim``
            resolves to 0 the batch rule takes precedence, so ``indices``
            must then hold exactly ``inputs.shape[0]`` entries.

    Returns:
        A tensor shaped like ``inputs`` except along ``dim``, whose size
        becomes the number of indices per batch element.
    """
    # Normalize a negative ``dim``: the loop below compares it against the
    # non-negative positions produced by ``enumerate`` and would otherwise
    # never recognise the selected dimension.
    if dim < 0:
        dim += inputs.dim()

    vsizes, esizes = [], []
    for k, size in enumerate(inputs.shape):
        if k == 0:
            # Batch dimension: indices already carry one row per batch item.
            vsizes.append(size)
            esizes.append(-1)
        elif k == dim:
            # Selected dimension: let ``view`` infer the number of indices.
            vsizes.append(-1)
            esizes.append(-1)
        else:
            # Every other dimension: singleton in the view, then broadcast
            # to the full input size so gather reads whole slices.
            vsizes.append(1)
            esizes.append(size)

    indices = indices.view(vsizes).expand(esizes)
    return torch.gather(inputs, dim, indices)