python中[None, :]的用法

2023-12-19 22:30:40

看下面的position embedding的代码:

class LearnablePositionalEmbedding(torch.nn.Module):
    """Shorthand for a learnable embedding."""

    def __init__(self, embed_dim, max_position_embeddings=1024, dropout=0.0):
        super().__init__()
        self.embedding = torch.nn.Embedding(max_position_embeddings, embed_dim)
        self.dropout = torch.nn.Dropout(p=dropout)

    def forward(self, input_embeddings):
        """This is a batch-first implementation"""
        position_ids = torch.arange(input_embeddings.shape[1], device=self.embedding.weight.device)
        position_embeddings = self.embedding(position_ids[None, :])
        return self.dropout(input_embeddings + position_embeddings)

简言之:

position_ids[None, :] 的目的是为了将其变成一个二维张量,以便与 input_embeddings 进行相加。在这里,position_ids 的长度(N)应该与 input_embeddings 张量的第二个维度长度相同

input_embeddings 的形状是 (batch_size, sequence_length, embed_dim),那么 position_ids[None, :] 的形状将变为 (1, sequence_length),然后通过广播(broadcasting)机制,它会与 input_embeddings 的第一个维度进行广播,使得两者的形状能够相加。这样,每个位置的嵌入都与相应的位置信息相加,从而引入了位置编码。?

文章来源:https://blog.csdn.net/vivi_cin/article/details/135089921
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。