找回密码
 立即注册
首页 业界区 安全 gather算子大不同

gather算子大不同

羔迪 2025-9-25 10:52:50
技术背景

在MindSpore和PyTorch框架中都有关于gather算子的实现。其实gather算子就是根据张量的索引,在指定维度下提取元素。但是gather算子在两个框架中的实现又有所不同,本文用几个示例来展开介绍一下。
MindSpore示例

在MindSpore中的gather实现,可以支持多维度的index:
  1. In [1]: import mindspore as ms
  2. In [2]: arr = ms.numpy.ones((1,27), dtype=ms.float32)
  3. In [3]: idx = ms.numpy.zeros((27,26), dtype=ms.int32)
  4. In [4]: res = ms.ops.gather(arr, idx, -1)
  5. In [5]: res.shape
  6. Out[5]: (1, 27, 26)
复制代码
这里index的维度是(27,26),而我们去索引的只有-1这个维度。其实在MindSpore中应该是通过内部实现,把index展平之后进行索引在做一个reshape,而这些内容都不在用户层去操作。
PyTorch示例

在PyTorch中,index维度必须跟输入的维度数量一致,否则就会发生RuntimeError,例如:
  1. In [1]: import torch as tc
  2. In [2]: arr = tc.ones((1,27),dtype=tc.float32)
  3. In [3]: idx = tc.zeros((27,26),dtype=tc.int64)
  4. In [4]: res = tc.gather(arr, -1, idx)
  5. --------------------------------------------------------------------------
  6. RuntimeError                             Traceback (most recent call last)
  7. Cell In[4], line 1
  8. ----> 1 res = tc.gather(arr, -1, idx)
  9. RuntimeError: Size does not match at dimension 0 expected index [27, 26] to be smaller than self [1, 27] apart from dimension 1
  10. In [5]: res = tc.gather(arr,-1,idx.reshape((1,-1))).reshape(idx.shape)
  11. In [6]: res.shape
  12. Out[6]: torch.Size([27, 26])
复制代码
在这个示例中我们可以看到,在PyTorch里面不能支持跟输入的维度数量不同的索引张量,要进行手动的展开,最后再手动的reshape回去。
总结概要

本文通过2个实际的案例,演示了一下gather算子在MindSpore框架下PyTorch框架下的异同点。两者的输入都是tensor-axis-index,一个是输入顺序上略有区别,另一个是对于输入的张量索引维度的要求。在PyTorch中,如果我们要实现类似于MindSpore中的gather功能,需要手动对输入索引的维度操作一下。
版权声明

本文首发链接为:https://www.cnblogs.com/dechinphy/p/gather-ops.html
作者ID:DechinPhy
更多原著文章:https://www.cnblogs.com/dechinphy/
请博主喝咖啡:https://www.cnblogs.com/dechinphy/gallery/image/379634.html

来源:程序园用户自行投稿发布,如果侵权,请联系站长删除
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!

相关推荐

4 天前

举报

谢谢楼主提供!
您需要登录后才可以回帖 登录 | 立即注册