torch.where用法介绍

2024-01-03 05:56:04

torch.where用法1介绍

torch.where(condition, x, y) → Tensor

这个用法介绍直接搜就可以,不做介绍

torch.where用法2介绍(在yolov8中计算TP中出现)

torch.where(condition) → Tensor

返回的condition中为True的索引
直接举一个例子吧

condition = torch.tensor([[True, False, False], [False, True, False], [True, True, True]]) # torch.Size([3, 3])
torch.where(condition)
'''
(tensor([0, 1, 2, 2, 2]), tensor([0, 1, 0, 1, 2]))
'''

返回的是一个二维元祖,tensor([0, 1, 2, 2, 2])这个代表的是condition中的第一维所有的索引,tensor([0, 1, 0, 1, 2])这个代表的是condition第二维中所有的索引,例如,输出结果组合起来

[0, 1, 2, 2, 2]
[0, 1, 0, 1, 2]
一一组合起来
[(0, 0), (1, 1), (2, 0), (2, 1), (2, 2)]
可以看到
condition[0][0]=True, condition[1][1]=True, condition[2][0]=True, condition[2][1]=True, condition[2][2]=True
其余的都为False

文章来源:https://blog.csdn.net/shilichangtin/article/details/135335322
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。