关于pytorch中的dim的理解
2023-12-13 12:02:29
今天碰到一个代码看起来很简单,但是细究原理又感觉好像不太通不太对劲,就是多维tensor数据的操作,比如:
y.sum(dim=2)
,乍一看很简单数据相加操作,但是仔细一想,这里在第3维度的数据到底是横向相加还是纵向相加,带着疑问实验几次就明白了。
首先给个完整的例子:
import torch
y = torch.tensor([
[
[1, 2, 3],
[4, 5, 6]
],
[
[1, 2, 3],
[4, 5, 6]
],
[
[1, 2, 3],
[4, 5, 6]
]
])
print(y.sum(dim=2))
这里的y.shape = (3, 2, 3),三个维度的数据,所以dim可以是0~2也可以是-1~-3。我们每个维度都进行操作一遍就清楚了。
- 当dim=0时,相当于有3个二维的向量进行相加,结果还是一个二维向量(对应位置相加):
y.shape = (3, 2, 3) —> y.shape = (2, 3) - 当dim=1时,相当于有2个一维的向量进行相加×3,结果是1个一维向量×3则还是一个二维向量:
y.shape = (3, 2, 3) —> y.shape = (3, 3) - 当dim=2时,相当于有3个数值进行相加×2×3,结果两个值组成一维向量,三个一维向量组成二维向量:
y.shape = (3, 2, 3) —> y.shape = (3, 2)
其他的数据操作也是这样类似的思想。
总结:从中可以看出只要对一个n维度的数据的其中一维进行操作的话,得到的结果会是n-1维的向量,shape则是去掉那一维的个数。
文章来源:https://blog.csdn.net/weixin_45354497/article/details/134965390
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。 如若内容造成侵权/违法违规/事实不符,请联系我的编程经验分享网邮箱:veading@qq.com进行投诉反馈,一经查实,立即删除!
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。 如若内容造成侵权/违法违规/事实不符,请联系我的编程经验分享网邮箱:veading@qq.com进行投诉反馈,一经查实,立即删除!