torch.tensor.gather用法
2023-12-14 04:18:38
介绍
官网介绍:https://pytorch.org/docs/stable/generated/torch.Tensor.gather.html
torch.Tensor.gather
是PyTorch中的一个函数,它根据索引从输入张量中收集值。
示例代码1
以下是一个使用torch.Tensor.gather
的示例:
import torch
# 创建一个输入张量
input = torch.tensor([[1, 2], [3, 4]])
# 创建一个索引张量
index = torch.tensor([[0, 0], [1, 0]])
# 使用gather函数
output = input.gather(1, index)
print(output)
# tensor([[1, 1],
# [4, 3]])
在上述代码中,input.gather(1, index)
会沿着维度1(列)收集值。索引张量index
中的每个值指定了在相应位置收集哪个元素。
input
中[1, 2]
的索引为[0, 1]
,index[0, 0]
用两个0索引拿到的值为[1, 1]
;同理,input
中[3, 4]
的索引为[0, 1]
,index[1, 0]
用两个0索引拿到的值为[4, 3]
;因此,output
的值为[[1, 1], [4, 3]]
。
示例代码2
以下是一个使用torch.Tensor.gather
的3维tensor示例:
import torch
# 创建一个输入张量
input = torch.arange(0,8).view(2, 2, 2)
# tensor([[[0, 1],
# [2, 3]],
# [[4, 5],
# [6, 7]]])
# 创建一个索引张量
index = torch.tensor([[[0,0]],[[1,0]]])
# 使用gather函数
output = input.gather(1, index)
print(output)
# tensor([[[0, 1]],
# [[6, 5]]])
# 错误示例
index = torch.tensor([[[0,0]],[[10,0]]])
input.gather(1, index) # RuntimeError: index 10 is out of bounds for dimension 1 with size 2
文章来源:https://blog.csdn.net/qq_36892712/article/details/134889644
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。 如若内容造成侵权/违法违规/事实不符,请联系我的编程经验分享网邮箱:veading@qq.com进行投诉反馈,一经查实,立即删除!
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。 如若内容造成侵权/违法违规/事实不符,请联系我的编程经验分享网邮箱:veading@qq.com进行投诉反馈,一经查实,立即删除!