AttentionFreeTransformer core structure diagrams (redrawn in GraphViz)
2024-01-10 13:35:33
AFTFull
digraph AFTFull {
rankdir=BT
node [
style=filled,
color=Black,
fontcolor=White,
fillcolor="#30638e",
fontname="SimHei",
fontsize=32,
width=5, height=2,
]
inp [label="Input\n[BatchSize,\n SeqLen,\n HidSize]", shape="Mrecord"]
llq [label="LinearQ\n[HidSize, ProjSize]", shape="box"]
llk [label="LinearK\n[HidSize, ProjSize]", shape="box"]
llv [label="LinearV\n[HidSize, ProjSize]", shape="box"]
w [label="W:Param\n[SeqLen, SeqLen]", shape="Mrecord"]
q [label="Q\n[BatchSize,\n SeqLen,\n ProjSize]", shape="Mrecord"]
k [label="K\n[BatchSize,\n SeqLen,\n ProjSize]", shape="Mrecord"]
v [label="V\n[BatchSize,\n SeqLen,\n ProjSize]", shape="Mrecord"]
σ [label="Sigmoid", shape="box", width=3]
atten_op [label="exp(W) @ (exp(K) * V)\n/ (exp(W) @ exp(K))", shape="box"]
atten [label="[BatchSize,\n SeqLen,\n ProjSize]", shape="Mrecord"]
mul [label="*", shape="box", width=3]
llo [label="LinearO\n[ProjSize, HidSize]", shape="box"]
oup [label="Output\n[BatchSize,\n SeqLen,\n HidSize]", shape="Mrecord"]
inp -> llq
inp -> llk
inp -> llv
llq -> q
llk -> k
llv -> v
q -> σ
w -> atten_op
k -> atten_op
v -> atten_op
atten_op -> atten
σ -> mul
atten -> mul
mul -> llo
llo -> oup
}
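To make the AFTFull data flow concrete, here is a minimal pure-Python sketch for a single sequence, with BatchSize dropped. It makes several assumptions. LinearQ/K/V/O are taken to be bias-free, because the diagram shows only weight shapes. It also skips the max-subtraction tricks a real implementation would use to stop exp() from overflowing. Helper names such as `matmul` and `aft_full_core` are hypothetical. The numerator is exp(W) @ (exp(K) * V). The denominator is exp(W) @ exp(K), so position t is normalised by the sum over t' of exp(W[t][t'] + K[t']).

```python
import math
import random

def matmul(a, b):
    """Plain-list matrix product: [n, m] @ [m, p] -> [n, p]."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def elementwise(f, a, b):
    """Apply f to matching entries of two equally shaped matrices."""
    return [[f(x, y) for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]

def exp_all(a):
    return [[math.exp(x) for x in row] for row in a]

def aft_full_core(q, k, v, w_bias):
    """sigmoid(Q) * [exp(W) @ (exp(K) * V)] / [exp(W) @ exp(K)] for one sequence.

    q, k, v: [SeqLen, ProjSize]; w_bias: the learned W parameter, [SeqLen, SeqLen].
    """
    exp_w, exp_k = exp_all(w_bias), exp_all(k)
    num = matmul(exp_w, elementwise(lambda a, b: a * b, exp_k, v))  # exp(W) @ (exp(K) * V)
    den = matmul(exp_w, exp_k)                                      # exp(W) @ exp(K)
    atten = elementwise(lambda a, b: a / b, num, den)               # [SeqLen, ProjSize]
    return elementwise(lambda a, b: b / (1.0 + math.exp(-a)), q, atten)  # sigmoid(Q) * atten

def aft_full(x, wq, wk, wv, wo, w_bias):
    """Whole AFTFull block: LinearQ/K/V (bias-free here), the AFT core, then LinearO."""
    q, k, v = matmul(x, wq), matmul(x, wk), matmul(x, wv)
    return matmul(aft_full_core(q, k, v, w_bias), wo)

def rand(rows, cols):
    return [[random.uniform(-1, 1) for _ in range(cols)] for _ in range(rows)]

random.seed(0)
seq_len, hid, proj = 4, 3, 2
x = rand(seq_len, hid)
out = aft_full(x, rand(hid, proj), rand(hid, proj), rand(hid, proj), rand(proj, hid),
               rand(seq_len, seq_len))
print(len(out), len(out[0]))  # 4 3: back to [SeqLen, HidSize]
```

Row t of the numerator is the sum over t' of exp(W[t][t'] + K[t']) * V[t']. Each output position is therefore a weighted average of V, with weights that combine a learned pairwise bias and the key. No Q·K dot product is computed; Q only gates the result through the sigmoid.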
AFTSimple
digraph AFTSimple {
rankdir=BT
node [
style=filled,
color=Black,
fontcolor=White,
fillcolor="#30638e",
fontname="SimHei",
fontsize=32,
width=5, height=2,
]
inp [label="Input\n[BatchSize,\n SeqLen,\n HidSize]", shape="Mrecord"]
llq [label="LinearQ\n[HidSize, ProjSize]", shape="box"]
llk [label="LinearK\n[HidSize, ProjSize]", shape="box"]
llv [label="LinearV\n[HidSize, ProjSize]", shape="box"]
q [label="Q\n[BatchSize,\n SeqLen,\n ProjSize]", shape="Mrecord"]
k [label="K\n[BatchSize,\n SeqLen,\n ProjSize]", shape="Mrecord"]
v [label="V\n[BatchSize,\n SeqLen,\n ProjSize]", shape="Mrecord"]
σ [label="Sigmoid", shape="box", width=3]
atten_op [label="sum(softmax(K, 1) * V, 1)", shape="box"]
atten [label="[BatchSize, 1, ProjSize]", shape="Mrecord"]
mul [label="*", shape="box", width=3]
llo [label="LinearO\n[ProjSize, HidSize]", shape="box"]
oup [label="Output\n[BatchSize,\n SeqLen,\n HidSize]", shape="Mrecord"]
inp -> llq
inp -> llk
inp -> llv
llq -> q
llk -> k
llv -> v
q -> σ
k -> atten_op
v -> atten_op
atten_op -> atten
σ -> mul
atten -> mul
mul -> llo
llo -> oup
}
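AFTSimple drops W entirely. The softmax runs over the sequence axis (dim 1 of [BatchSize, SeqLen, ProjSize]), so the weighted sum of V collapses to a single [1, ProjSize] context vector. That vector is broadcast against sigmoid(Q) at every position. Below is a minimal pure-Python sketch for one sequence; the helper names are hypothetical. No [SeqLen, SeqLen] matrix is built, so the cost grows linearly with SeqLen.

```python
import math

def softmax_over_seq(k):
    """Softmax of K along the sequence axis, independently for each feature column."""
    seq_len, proj = len(k), len(k[0])
    out = [[0.0] * proj for _ in range(seq_len)]
    for d in range(proj):
        m = max(k[t][d] for t in range(seq_len))  # subtract the max to keep exp() in range
        exps = [math.exp(k[t][d] - m) for t in range(seq_len)]
        total = sum(exps)
        for t in range(seq_len):
            out[t][d] = exps[t] / total
    return out

def aft_simple_core(q, k, v):
    """sigmoid(Q) * sum(softmax(K, 1) * V, 1) for one sequence; q, k, v are [SeqLen, ProjSize]."""
    sk = softmax_over_seq(k)
    # One context vector of length ProjSize, shared by every position (the [BatchSize, 1, ProjSize] node).
    context = [sum(sk[t][d] * v[t][d] for t in range(len(v))) for d in range(len(v[0]))]
    # Gate the shared context with sigmoid(Q) row by row.
    return [[c / (1.0 + math.exp(-qd)) for qd, c in zip(row, context)] for row in q]

q = [[0.0, 0.0], [2.0, -2.0]]
k = [[0.0, 0.0], [0.0, 0.0]]
v = [[1.0, 2.0], [3.0, 4.0]]
print(aft_simple_core(q, k, v)[0])  # [1.0, 1.5]: sigmoid(0) = 0.5 times the mean of V
```

Mathematically, this equals the AFTFull formula with every entry of W set to 0. In that case the exp(W) factors are all 1, and the ratio reduces to softmax(K) applied over the sequence.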
AFTLocal
digraph AFTLocal {
rankdir=BT
node [
style=filled,
color=Black,
fontcolor=White,
fillcolor="#30638e",
fontname="SimHei",
fontsize=32,
width=5, height=2,
]
inp [label="Input\n[BatchSize,\n SeqLen,\n HidSize]", shape="Mrecord"]
llq [label="LinearQ\n[HidSize, ProjSize]", shape="box"]
llk [label="LinearK\n[HidSize, ProjSize]", shape="box"]
llv [label="LinearV\n[HidSize, ProjSize]", shape="box"]
w [label="W:Param\n[SeqLen, SeqLen]", shape="Mrecord"]
mask [label="mask\n[SeqLen, SeqLen]\nabs(i - j) < S? 1: 0", shape="box"]
q [label="Q\n[BatchSize,\n SeqLen,\n ProjSize]", shape="Mrecord"]
k [label="K\n[BatchSize,\n SeqLen,\n ProjSize]", shape="Mrecord"]
v [label="V\n[BatchSize,\n SeqLen,\n ProjSize]", shape="Mrecord"]
σ [label="Sigmoid", shape="box", width=3]
atten_op [label="exp(W) @ (exp(K) * V)\n/ (exp(W) @ exp(K))", shape="box"]
atten [label="[BatchSize,\n SeqLen,\n ProjSize]", shape="Mrecord"]
mul [label="*", shape="box", width=3]
llo [label="LinearO\n[ProjSize, HidSize]", shape="box"]
oup [label="Output\n[BatchSize,\n SeqLen,\n HidSize]", shape="Mrecord"]
inp -> llq
inp -> llk
inp -> llv
llq -> q
llk -> k
llv -> v
q -> σ
w -> mask
mask -> atten_op
k -> atten_op
v -> atten_op
atten_op -> atten
σ -> mul
atten -> mul
mul -> llo
llo -> oup
}
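In AFTLocal the mask node keeps W[i][j] only where abs(i - j) < S. The sketch below assumes the 0/1 mask is multiplied elementwise into W. Under that assumption, biases outside the window become 0 rather than -inf. Distant positions therefore still contribute through exp(K); they just lose their learned bias. Names such as `local_bias` and `aft_local_core` are hypothetical, and the code handles one sequence in pure Python.

```python
import math

def local_bias(w_bias, s):
    """W * mask, where mask[i][j] = 1 if abs(i - j) < s else 0."""
    n = len(w_bias)
    return [[w_bias[i][j] if abs(i - j) < s else 0.0 for j in range(n)] for i in range(n)]

def aft_core(q, k, v, w_bias):
    """sigmoid(Q) * sum_j exp(W[t][j] + K[j]) * V[j] / sum_j exp(W[t][j] + K[j]),
    the per-position form of exp(W) @ (exp(K) * V) / (exp(W) @ exp(K))."""
    seq_len, proj = len(q), len(q[0])
    out = []
    for t in range(seq_len):
        row = []
        for d in range(proj):
            wts = [math.exp(w_bias[t][j] + k[j][d]) for j in range(seq_len)]
            atten = sum(wt * v[j][d] for j, wt in enumerate(wts)) / sum(wts)
            row.append(atten / (1.0 + math.exp(-q[t][d])))  # gate with sigmoid(Q)
        out.append(row)
    return out

def aft_local_core(q, k, v, w_bias, s):
    """AFTLocal: the AFTFull computation with W restricted to the local window."""
    return aft_core(q, k, v, local_bias(w_bias, s))

w = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
print(local_bias(w, 2))  # [[1.0, 2.0, 0.0], [4.0, 5.0, 6.0], [0.0, 8.0, 9.0]]
```

A window with S at least SeqLen leaves W untouched, so the result is exactly AFTFull. With S = 0, every bias is removed, which gives the softmax(K)-weighted average computed by AFTSimple.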
Source: https://blog.csdn.net/wizardforcel/article/details/135500782
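This article was submitted by an internet user. The views expressed are the author's own and do not represent the position of this site. The site only provides information storage space, does not own the rights, and bears no related legal liability. If any content infringes rights, violates laws or regulations, or is factually inaccurate, please send a complaint to the 我的编程经验分享网 mailbox, veading@qq.com; once verified, it will be deleted immediately.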