羔迪 发表于 2025-9-25 10:52:50

gather算子大不同

技术背景

在MindSpore和PyTorch框架中都有关于gather算子的实现。其实gather算子就是根据张量的索引,在指定维度下提取元素。但是gather算子在两个框架中的实现又有所不同,本文用几个示例来展开介绍一下。
MindSpore示例

在MindSpore中的gather实现,可以支持多维度的index:
In : import mindspore as ms

In : arr = ms.numpy.ones((1,27), dtype=ms.float32)

In : idx = ms.numpy.zeros((27,26), dtype=ms.int32)

In : res = ms.ops.gather(arr, idx, -1)

In : res.shape
Out: (1, 27, 26)这里index的维度是(27,26),而我们去索引的只有-1这个维度。其实在MindSpore中应该是通过内部实现,把index展平之后进行索引在做一个reshape,而这些内容都不在用户层去操作。
PyTorch示例

在PyTorch中,index维度必须跟输入的维度数量一致,否则就会发生RuntimeError,例如:
In : import torch as tc

In : arr = tc.ones((1,27),dtype=tc.float32)

In : idx = tc.zeros((27,26),dtype=tc.int64)

In : res = tc.gather(arr, -1, idx)
--------------------------------------------------------------------------
RuntimeError                           Traceback (most recent call last)
Cell In, line 1
----> 1 res = tc.gather(arr, -1, idx)

RuntimeError: Size does not match at dimension 0 expected index to be smaller than self apart from dimension 1

In : res = tc.gather(arr,-1,idx.reshape((1,-1))).reshape(idx.shape)

In : res.shape
Out: torch.Size()在这个示例中我们可以看到,在PyTorch里面不能支持跟输入的维度数量不同的索引张量,要进行手动的展开,最后再手动的reshape回去。
总结概要

本文通过2个实际的案例,演示了一下gather算子在MindSpore框架下PyTorch框架下的异同点。两者的输入都是tensor-axis-index,一个是输入顺序上略有区别,另一个是对于输入的张量索引维度的要求。在PyTorch中,如果我们要实现类似于MindSpore中的gather功能,需要手动对输入索引的维度操作一下。
版权声明

本文首发链接为:https://www.cnblogs.com/dechinphy/p/gather-ops.html
作者ID:DechinPhy
更多原著文章:https://www.cnblogs.com/dechinphy/
请博主喝咖啡:https://www.cnblogs.com/dechinphy/gallery/image/379634.html

来源:程序园用户自行投稿发布,如果侵权,请联系站长删除
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!

宇文之 发表于 7 天前

谢谢楼主提供!
页: [1]
查看完整版本: gather算子大不同