找回密码
 立即注册
首页 业界区 安全 PyTorch中的take_along_dim

PyTorch中的take_along_dim

揉幽递 13 小时前
技术背景

在此前的一篇博客中,我们介绍过take_along_axis这个算子的具体使用方法。这里针对于Pytorch的take_along_dim算子,再重新介绍一次。
Numpy版本使用

这里我们展示的案例是基于numpy-2.0.1版本实现的:
  1. $ python3 -m pip show numpy
  2. Name: numpy
  3. Version: 2.0.1
  4. Summary: Fundamental package for array computing in Python
  5. Home-page: https://numpy.org
  6. Author: Travis E. Oliphant et al.
复制代码
示例如下:
  1. In [1]: import numpy as np
  2. In [2]: a = np.arange(12).reshape((1,4,3))
  3. In [3]: a
  4. Out[3]:
  5. array([[[ 0,  1,  2],
  6.         [ 3,  4,  5],
  7.         [ 6,  7,  8],
  8.         [ 9, 10, 11]]])
  9. In [4]: idx = np.array([1,2])
  10. In [6]: b = np.take_along_axis(a, idx[None,:,None], axis=1)
  11. In [7]: b
  12. Out[7]:
  13. array([[[3, 4, 5],
  14.         [6, 7, 8]]])
  15. In [8]: b = np.take_along_axis(a, idx[None,None,:], axis=2)
  16. In [9]: b
  17. Out[9]:
  18. array([[[ 1,  2],
  19.         [ 4,  5],
  20.         [ 7,  8],
  21.         [10, 11]]])
复制代码
在这个基础示例中,我们分别展示了同一个索引矩阵,在不同的维度上进行索引的结果。使用take_along_axis有一个默认的要求:原始数组和索引数组的维度数量需要保持一致。但是因为这里的索引矩阵是一维的,那么我们只要用slice的方法对索引矩阵进行扩维就好了。例如,我们需要在第二个维度进行提取,那么就可以用arr[None,:,None]来进行扩维。
PyTorch版实现

这里我们使用的torch是2.5.1的稳定版:
  1. $ python3 -m pip show torch
  2. Name: torch
  3. Version: 2.5.1
  4. Summary: Tensors and Dynamic neural networks in Python with strong GPU acceleration
  5. Home-page: https://pytorch.org/
  6. Author: PyTorch Team
  7. Author-email: packages@pytorch.org
  8. License: BSD-3-Clause
  9. Location: /miniconda3/envs/pytorch/lib/python3.9/site-packages
  10. Requires: filelock, fsspec, jinja2, networkx, sympy, typing-extensions
  11. Required-by: torchaudio, torchmetrics, torchvision
复制代码
相关的API接口文档如下:
       
1.png
其实实现起来跟numpy的操作是非常类似的:
  1. In [1]: import torch as tc
  2. In [2]: a = tc.arange(12).reshape((1,4,3))
  3. In [3]: idx = tc.tensor([1,2])
  4. In [4]: b = tc.take_along_dim(a, idx[None,:,None], dim=1)
  5. In [5]: b
  6. Out[5]:
  7. tensor([[[3, 4, 5],
  8.          [6, 7, 8]]])
  9. In [6]: b = tc.take_along_dim(a, idx[None,None,:], dim=2)
  10. In [7]: b
  11. Out[7]:
  12. tensor([[[ 1,  2],
  13.          [ 4,  5],
  14.          [ 7,  8],
  15.          [10, 11]]])
复制代码
可以说是基本一致。那么同样的,也是要做一个扩维的处理。唯一一个不同的地方就是,在torch中是take_along_dim而不是像numpy或者mindspore中的take_along_axis,在torch中用dim替代了axis,包括函数名称和传入的关键词参数。
总结概要

接前面一篇take_along_axis的文章,本文主要介绍在PyTorch框架下,功能基本一样的函数take_along_dim。二者除了命名和一些关键词参数不一致之外,用法是一样的。需要注意的是,两者都要求输入的数组和索引数组维度数量一致。在特定场景下,需要手动进行扩维。
版权声明

本文首发链接为:https://www.cnblogs.com/dechinphy/p/take_along_dim.html
作者ID:DechinPhy
更多原著文章请参考:https://www.cnblogs.com/dechinphy/
打赏专用链接:https://www.cnblogs.com/dechinphy/gallery/image/379634.html
腾讯云专栏同步:https://cloud.tencent.com/developer/column/91958

来源:程序园用户自行投稿发布,如果侵权,请联系站长删除
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!
您需要登录后才可以回帖 登录 | 立即注册