多头注意力

其实这里没问题,只是把多个头的输入的变换矩阵拼成一个大矩阵而已

这里确实是query_size,从w_q的形状就能看出来。
这个多头注意力看了好久终于算看明白了。
其实,小结中“多头注意力融合了来自于多个注意力汇聚的不同知识,这些知识的不同来源于相同的查询、键和值的不同的子空间表示。” 这句总结有助于理解代码,本文将Q,K,V分成不同的子空间,对每个子空间进行注意力汇聚,然后再将输出合并,代码实现的很巧妙,精华就是transpose_qkv和transpose_output这两个函数,值得细品

如果 num_hiddens ≠ key_size 那么程序就跑不了,会报矩阵不能相乘,既然这样那为什么还要把这两个参数分离开,明明只有一个自由度

说出了我的疑惑。。。。。。。。。。。。。。。。。。

你这段话非常正确,本质就是把每个query,key,value都切分成num_heads段,对每一段应用注意力机制,从而更好提取局部特征,也就是作者说的子空间。

1 Like

是不是可以这样理解,沐神写的那个代码直接把每个头的全连接层拼在了一起,也就是说那个上边写的num_hiddens其实并不是每个头的num_hiddens,而是这个值 * num_headers,