函数torch.bincount( )的用法

2023-12-18 17:38:46

torch.bincount()函数是PyTorch中的一个函数,用于计算一维整数张量中每个非负整数值出现的频次

函数的用法 :

torch.bincount(input, weights=None, minlength=0) → Tensor

参数:

  • input:输入的一维整数张量
  • weights(可选):与input张量相同形状的张量,用于为每个值指定权重
  • minlength(可选):输出张量的最小长度

返回值:一个具有长度为max(input) + 1的一维长整型张量,其中索引i处的值表示i在输入张量中出现的频次

函数说明:

  • torch.bincount()函数 统计输入张量中每个非负整数值的频次。它适用于整数类型的张量,如torch.int8torch.int16torch.int32torch.int64
  • 输入张量可以是CPU上的张量,也可以是CUDA张量(GPU上的张量)
  • 输出张量的长度是输入张量中的最大值加1,即max(input) + 1
  • 输出张量中的元素顺序与输入张量中的非负整数值顺序相同

?例如:

import torch

input = torch.tensor([1, 2, 3, 2, 1, 1])
counts = torch.bincount(input)
print(counts)  # 输出: tensor([0, 3, 2, 1])
rrint(counts[1:]) # 输出:tensor([3,2,1])

在上面示例中,有一个输入张量input,包含一些非负整数值, 通过调用torch.bincount(input) ,计算了每个值在输入张量中出现的频次,得到了张量counts, counts[0]为0,因为0在输入张量中没有出现;counts[1]为3,因为1在输入张量中出现了3次,以此类推

注意:

在使用torch.bincount()函数时,它会计算一维整数张量中每个非负整数值的频次,包括最小值到最大值之间的所有整数值,即使某些整数值在输入张量中没有出现

在上述的例子中,input是一维张量[1, 2, 3, 2, 1, 1], 虽然 0 在 input 中没有出现,但torch.bincount(input)仍会考虑到0的存在 ,输出结果为 tensor([0, 3, 2, 1]),其中索引0 表示 0 这个整数值在input中出现的次数为0次,索引1出现了3次,索引2出现了2次,索引3出现了1次

torch.bincount()的输出张量长度与输入张量中的最大整数值相关。对于输入张量 input = torch.tensor([1, 2, 3, 2, 1, 1]),它包含了整数值1、2和3,torch.bincount(input) 的输出张量将具有长度为4,对应索引0到索引3。具体来说,输出张量的长度由输入张量中的最大整数值加1决定

在这个例子中,最大整数值是3,因此输出张量的长度为4

如果确保输入张量中不包含0,可以通过对输出进行切片来忽略索引0的值

例如,counts[1:]表示忽略索引0后的部分,得到tensor([3, 2, 1])

还可以传入一个与输入张量相同形状的权重张量 weights,可以为每个值指定权重

weights = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5, 0.6])
weighted_counts = torch.bincount(input, weights)
print(weighted_counts)  # 输出:tensor([0.0000, 1.2000, 0.6000, 0.3000])
# 对于0来说,没有出现就是0
# 对于1来说,出现了三次:第一次出现位置上对应的权重为0.1,第二次出现位置上对应的权重为0.5,第三次出现 # 位置上对应的权重为0.6,所以0.1+0.5+0.6=1.2
# 对2来说,出现两次:第一次出现位置对应的权重为0.2,第二次出现位置对应的权重为0.4,故0.2+0.4=0.6
# 对于3来说,出现了一次:第一次出现位置上对应的权重为0.3,所以为0.3

通过调用 torch.bincount(input, weights),计算了每个值在输入张量中出现的加权频次,得到了张量 weighted_counts

此外,可以通过设置 minlength 参数来指定输出张量的最小长度

minlength_counts = torch.bincount(input, minlength=5)
print(minlength_counts)  # 输出: tensor([0, 3, 2, 1, 0])

在上面的示例中,我们调用torch.bincount(input, minlength=5),将最小长度设置为5,得到了张量 minlength_counts,它的长度为5,包含了输入张量中每个非负整数值的频次

补充:对于numpy数组有 numpy.bincount( )函数的用法:?? numpy.bincount( )函数的用法-CSDN博客?可以参考博文对比理解

文章来源:https://blog.csdn.net/Kelly_Ai_Bai/article/details/135062284
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。