self-attention|李宏毅机器学习21年

2023-12-14 16:45:35

来源:https://www.bilibili.com/video/BV1Bb4y1L7FT?p=1&vd_source=f66cebc7ed6819c67fca9b4fa3785d39

self-attention要解决的问题:输入的sequence是变长的、长度不等。

引言

如何解决输入同样的saw,第一个输出v.第二个输出n.?
使用FC可以考虑上下文的资讯。

如何考虑一整个sequence的资讯呢?
把Windows开到sequence中最大的长度。
在这里插入图片描述

self-attention

在这里插入图片描述
可以将self-attention与FC交替使用:
self-attention处理整个句子的资讯
FC专注于处理某一个位置的资讯、
在这里插入图片描述

运作机制

在这里插入图片描述

b1是如何产生的

1、计算出attention score α \alpha α:在这个长长的sequence里找出和a1有关联的vector,每个向量与a1的关联性用数值 α \alpha α表示。
在这里插入图片描述
在这里插入图片描述
2、根据attention score抽取sequence里的重要资讯,即可计算出b1
在这里插入图片描述
注:b1-b4是同时被产生的

怎么求关联性数值 α \alpha α

两种方法:
在这里插入图片描述

最常用的是向量点积法,也是用在transformer里的方法。

从矩阵乘法的角度再来一次

从A得到Q、K、V

在这里插入图片描述

从Q、K得到 α \alpha α矩阵

在这里插入图片描述

由V和A’得到b1-b4

在这里插入图片描述

总结:从I到O就是在做self-attention

在这里插入图片描述

Muti-head Self-attention

几个head,是一个需要调的超参。
为什么要用Muti-head?
使用不同的q代表不同种类的相关性。
在这里插入图片描述
在这里插入图片描述

位置编码

在这里插入图片描述
举例:
假设我们想要为一个长度为 seq_length = 4 的序列生成位置编码,并且我们想要的编码维度是 d_model = 8。

初始化位置和维度索引矩阵:

位置矩阵 position (shape: [seq_length, 1]):

[[0],
[1],
[2],
[3]]
维度索引矩阵 i (shape: [1, d_model]):

[[0, 1, 2, 3, 4, 5, 6, 7]]
计算角速率:

使用公式 angle_rates = 1 / (10000^(2 * (i//2) / d_model)) 计算 angle_rates:

angle_rates = 1 / (10000^(2 * ([0, 1, 2, 3, 4, 5, 6, 7]//2) / 8))
angle_rates = 1 / (10000^(2 * [0, 0, 1, 1, 2, 2, 3, 3] / 8))
angle_rates = 1 / (10000^(0, 0, 0.25, 0.25, 0.5, 0.5, 0.75, 0.75))
假设我们计算后得到如下的 angle_rates (shape: [1, d_model]):

[[1.0, 1.0, 0.1778, 0.1778, 0.0316, 0.0316, 0.0056, 0.0056]]
计算角度值:

将 position 矩阵与 angle_rates 矩阵相乘得到 angle_rads:

angle_rads = position * angle_rates
假设我们得到如下的 angle_rads (shape: [seq_length, d_model]):

[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[1.0000, 1.0000, 0.1778, 0.1778, 0.0316, 0.0316, 0.0056, 0.0056],
[2.0000, 2.0000, 0.3556, 0.3556, 0.0632, 0.0632, 0.0112, 0.0112],
[3.0000, 3.0000, 0.5334, 0.5334, 0.0948, 0.0948, 0.0168, 0.0168]]
应用正弦和余弦函数:

对偶数索引应用正弦函数,对奇数索引应用余弦函数:

PE(pos, 2i) = sin(angle_rads[:, 2i])
PE(pos, 2i+1) = cos(angle_rads[:, 2i+1])
假设我们得到如下的位置编码 position_encoding (shape: [seq_length, d_model]):

[[0.0000, 1.0000, 0.0000, 1.0000, 0.0000, 1.0000, 0.0000, 1.0000],
[0.8415, 0.5403, 0.1768, 0.9843, 0.0316, 0.9995, 0.0056, 0.9999],
[0.9093, -0.4161, 0.3484, 0.9373, 0.0629, 0.9980, 0.0112, 0.9997],
[0.1411, -0.9900, 0.5121, 0.8590, 0.0941, 0.9955, 0.0168, 0.9994]]

文章来源:https://blog.csdn.net/m0_57290240/article/details/134905031
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。