Triplet Loss三元组损失函数

2023-12-21 09:50:43

基础知识

三元组损失(Triplet Loss)是一种用于学习深度神经网络嵌入的损失函数,它的主要目标是确保在我们的嵌入空间中,来自相同类别的样本更接近彼此,而不同类别的样本更远离彼此。三元组损失(Triplet Loss)常在人脸识别、图像检索等需要计算相似度的任务中使用

三元组损失需要三个样本来计算损失,这三个样本被称为锚(Anchor)、正(Positive)和负(Negative)样本。其中,锚样本是我们关注的样本,正样本与锚样本具有相同的类别标签,负样本与锚样本具有不同的类别标签。
假设我们已经通过神经网络得到了这三个样本在嵌入空间的位置,分别是 A(锚样本),P(正样本)和 N(负样本)。则三元组损失函数的形式为:
L = max(d(A, P) - d(A, N) + margin, 0)
其中,d(A, P) 和 d(A, N) 分别是锚样本与正样本,锚样本与负样本在嵌入空间的距离,"margin"是一个预设定的阈值,用于控制正样本与负样本之间的差异,我们希望锚样本比与负样本的距离比至少比与正样本的距离大。
例如:
我们有三个样本锚样本A, 正样本P, 负样本N。它们分别被一个神经网络映射到一个三维空间,得到的嵌入向量是:
A = [1, 1, 1]P = [1.1, 1.1, 1.1]
N = [2, 2, 2]
我们可以看到,正样本P比锚样本A更接近,而负样本N则比正样本P和锚样本A更远,这就是我们希望的结果。但如果网络没有很好的训练,可能会得到违背这一原则的嵌入,例如负样本N离锚样本A更近,那么这就需要三元组损失来调整网络的权重,使得同类样本更接近,不同类样本更远离。

Triplet Loss三元组损失函数 在模型训练中,batchsize不能设置太小:

  • 多样性:在一个Batch中,我们需要包含足够多的类别,以便从中选择出质量较好的三元组。如果Batch太小,可能只包含少量的类别,这将限制我们选择三元组的可能性。
  • 稳定性:较大的batch size可以使网络的训练更稳定。每个batch的梯度计算都是对全局梯度的一个估计,batch size越大,这个估计的准确性就越高,训练过程也就越稳定。

代码讲解

Triplet Loss三元组损失函数如下:

def triplet_loss(embedding, targets, margin, norm_feat, hard_mining):
    r"""Modified from Tong Xiao's open-reid (https://github.com/Cysu/open-reid).
    Related Triplet Loss theory can be found in paper 'In Defense of the Triplet
    Loss for Person Re-Identification'."""
    if norm_feat:
        dist_mat = cosine_dist(embedding, embedding)
    else:
        dist_mat = euclidean_dist(embedding, embedding)

    # For distributed training, gather all features from different process.
    # if comm.get_world_size() > 1:
    #     all_embedding = torch.cat(GatherLayer.apply(embedding), dim=0)
    #     all_targets = concat_all_gather(targets)
    # else:
    #     all_embedding = embedding
    #     all_targets = targets
    
    # 获取相似度矩阵dist_mat的行数,即样本数量
    N = dist_mat.size(0)
    # 创建两个相同大小的矩阵is_pos和is_neg,分别存储样本之间是否属于相同类别(正样本对)及不同类别(负样本对)
    is_pos = targets.view(N, 1).expand(N, N).eq(targets.view(N, 1).expand(N, N).t()).float()
    is_neg = targets.view(N, 1).expand(N, N).ne(targets.view(N, 1).expand(N, N).t()).float()

    if hard_mining:
        dist_ap, dist_an = hard_example_mining(dist_mat, is_pos, is_neg)
    else:
        dist_ap, dist_an = weighted_example_mining(dist_mat, is_pos, is_neg)

    y = dist_an.new().resize_as_(dist_an).fill_(1)

    if margin > 0:
        loss = F.margin_ranking_loss(dist_an, dist_ap, y, margin=margin)
    else:
        loss = F.soft_margin_loss(dist_an - dist_ap, y)
        # fmt: off
        if loss == float('Inf'): loss = F.margin_ranking_loss(dist_an, dist_ap, y, margin=0.3)
        # fmt: on

    return loss

对上面代码进行解析:

定义函数

def triplet_loss(embedding, targets, margin, norm_feat, hard_mining):

定义了一个名为triplet_loss的函数,输入参数为embedding(嵌入特征)、targets(目标标签)、margin(用于增加正负样本之间间距的值)、norm_feat(决定是否对特征进行归一化)以及hard_mining(决定是否启动困难样本挖掘)。

数据归一化处理

    if norm_feat:
        dist_mat = cosine_dist(embedding, embedding)
    else:
        dist_mat = euclidean_dist(embedding, embedding)

判断是否对特征进行归一化,若决定归一化,就用余弦距离度量相似度;若不归一化,则用欧氏距离度量相似度。

cosine_dist(embedding, embedding)是将embedding中的每一个向量与embedding中的每一个向量都计算一遍余弦距离。

  • 举一个简单的例子:
假设你的embedding是一个(3, 2)的张量,内容如下:
[[a1, a2],
 [b1, b2],
 [c1, c2]]
其中,[a1, a2],[b1, b2]和[c1, c2]是这个embedding中的3个向量。
当你执行cosine_dist(embedding, embedding)时,实际上计算的是:
[[cosine_dist([a1, a2], [a1, a2]), cosine_dist([a1, a2], [b1, b2]), cosine_dist([a1, a2], [c1, c2])],
 [cosine_dist([b1, b2], [a1, a2]), cosine_dist([b1, b2], [b1, b2]), cosine_dist([b1, b2], [c1, c2])],
 [cosine_dist([c1, c2], [a1, a2]), cosine_dist([c1, c2], [b1, b2]), cosine_dist([c1, c2], [c1, c2])]]
这个结果是一个(3, 3)的矩阵,表示embedding中的每一个向量与embedding中的每一个向量之间的余弦距离。
当if norm_feat:这个条件语句为真时,即当我们想对embedding进行归一化处理时,就会使用这种方法计算embedding中所有向量之间的余弦距离。

矩阵is_pos和is_neg构建

    N = dist_mat.size(0)
    is_pos = targets.view(N, 1).expand(N, N).eq(targets.view(N, 1).expand(N, N).t()).float()
    is_neg = targets.view(N, 1).expand(N, N).ne(targets.view(N, 1).expand(N, N).t()).float()

创建两个相同大小的矩阵is_pos和is_neg,分别存储样本之间是否属于相同类别(正样本对)及不同类别(负样本对)。

  • 假设我们有4个样本,它们的类标签targets是[1, 2, 1, 2],矩阵的行和列分别代表样本的索引,而值则表示相对应的两个样本是否属于同一类别(is_pos)或不同类别(is_neg)。
targets.view(N, 1).expand(N, N),得到的结果是:
1 1 1 1
2 2 2 2
1 1 1 1
2 2 2 2
执行targets.view(N, 1).expand(N, N).t(),得到的结果是:
1 2 1 2
1 2 1 2
1 2 1 2
1 2 1 2
当我们用eq()去判断两个矩阵对应位置是否相等时,得到的结果(is_pos)是:
1 0 1 0
0 1 0 1
1 0 1 0
0 1 0 1
对应位置用ne()去判断是否不相等,得到的结果(is_neg)是:
0 1 0 1
1 0 1 0
0 1 0 1
1 0 1 0

样本挖掘

if hard_mining:
    dist_ap, dist_an = hard_example_mining(dist_mat, is_pos, is_neg)
else:
    dist_ap, dist_an = weighted_example_mining(dist_mat, is_pos, is_neg)

根据是否进行困难样本挖掘,采用不同的挖掘方法获取到每个样本对的距离。

# 对于每个锚点样本,找到最难正样本(最远的具有相同类别标签的样本)和最难负样本(最近的具有不同类别标签的样本)。
def hard_example_mining(dist_mat, is_pos, is_neg):
    """For each anchor, find the hardest positive and negative sample.
    Args:
      dist_mat: pair wise distance between samples, shape [N, M]
      is_pos: positive index with shape [N, M]
      is_neg: negative index with shape [N, M]
    Returns:
      dist_ap: pytorch Variable, distance(anchor, positive); shape [N]
      dist_an: pytorch Variable, distance(anchor, negative); shape [N]
      p_inds: pytorch LongTensor, with shape [N];
        indices of selected hard positive samples; 0 <= p_inds[i] <= N - 1
      n_inds: pytorch LongTensor, with shape [N];
        indices of selected hard negative samples; 0 <= n_inds[i] <= N - 1
    NOTE: Only consider the case in which all labels have same num of samples,
      thus we can cope with all anchors in parallel.
    """

    assert len(dist_mat.size()) == 2

    # `dist_ap` means distance(anchor, positive)
    # both `dist_ap` and `relative_p_inds` with shape [N]
    # dist_ap表示锚点样本与正样本之间的距离。通过在距离矩阵和正样本矩阵做逐元素相乘后,取每行(每个锚点)的最大值。
    dist_ap, _ = torch.max(dist_mat * is_pos, dim=1)
    # `dist_an` means distance(anchor, negative)
    # both `dist_an` and `relative_n_inds` with shape [N]
    # dist_an表示锚点样本与负样本之间的距离。首先,通过在距离矩阵和负样本矩阵做逐元素相乘后,再将正样本矩阵与大数(1e9)相乘并加到上述结果上,旨在将负样本对里的正样本对的距离设置地非常大。之后取每行的最小值,找出与锚点样本最近且类别不同的样本。
    dist_an, _ = torch.min(dist_mat * is_neg + is_pos * 1e9, dim=1)

    return dist_ap, dist_an


def weighted_example_mining(dist_mat, is_pos, is_neg):
    """For each anchor, find the weighted positive and negative sample.
    Args:
      dist_mat: pytorch Variable, pair wise distance between samples, shape [N, N]
      is_pos:
      is_neg:
    Returns:
      dist_ap: pytorch Variable, distance(anchor, positive); shape [N]
      dist_an: pytorch Variable, distance(anchor, negative); shape [N]
    """
    assert len(dist_mat.size()) == 2

    is_pos = is_pos
    is_neg = is_neg
    # 对于每个锚点样本,找到正样本和负样本的加权距离
    dist_ap = dist_mat * is_pos
    dist_an = dist_mat * is_neg
	
	# 分别通过softmax函数计算正样本和负样本的权重,注意负样本在计算权重之前要取负数。
    weights_ap = softmax_weights(dist_ap, is_pos)
    weights_an = softmax_weights(-dist_an, is_neg)
	
	# 计算的是加权距离,将距离与对应的权重相乘,然后对结果进行累加求和,得到最后的加权距离。
    dist_ap = torch.sum(dist_ap * weights_ap, dim=1)
    dist_an = torch.sum(dist_an * weights_an, dim=1)

    return dist_ap, dist_an

loss计算

y = dist_an.new().resize_as_(dist_an).fill_(1)

创建一个和dist_an相同大小并内容全部为1的向量。

  • y在F.margin_ranking_loss函数中起到了标记的作用,决定了两个输入之间期望的相对大小和顺序。当我们设置y为1时,表示我们期望dist_an(锚点到负样本的距离)大于dist_ap(锚点到正样本的距离)。这也符合我们在训练过程中的期望:即我们希望模型将锚点与其类别内(正样本)的距离保持小,将其与其他类别(负样本)的距离保持大。
  • 如果y不设置为1,而是设置为-1,那么其含义将完全颠倒,此时,我们期望dist_an(锚点到负样本的距离)小于dist_ap(锚点到正样本的距离)。这显然违背了我们在进行特征学习时的初衷,无法良好地反映出同类间的聚合性和异类间的分离性。
    if margin > 0:
        loss = F.margin_ranking_loss(dist_an, dist_ap, y, margin=margin)
    else:
        loss = F.soft_margin_loss(dist_an - dist_ap, y)
        # fmt: off
        if loss == float('Inf'): loss = F.margin_ranking_loss(dist_an, dist_ap, y, margin=0.3)
        # fmt: on

计算最终的三元组损失:

  • 如果margin值大于0,那就使用margin ranking loss。这将试图确保正样本对的距离比负样本对的距离小于margin;
  • 如果margin值不大于0,那就使用soft margin loss,它是margin ranking loss的一个变体,其中margin被设置为0,并在损失函数中引入了一个logistic损失。在计算soft margin loss之后,如果得到的loss值为infinity,则将margin值手动设置为0.3,再次使用margin ranking loss计算损失。

F.margin_ranking_loss函数是用来实现三元组损失的一个实用方法,它接受两组数据和一个目标向量作为输入来计算定制的秩序损失。

dist_ap代表锚点和正样本之间的距离,最大距离;dist_an代表锚点和负样本之间的距离,最小距离。y是目标向量,经常被设置为1,表示我们希望dist_an(锚点和负样本之间的距离)比dist_ap(锚点和正样本之间的距离)大。margin是我们希望两者之间的最小差距。

文章来源:https://blog.csdn.net/qq_42178122/article/details/135101463
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。