Cross attention

相同点:

机制:两者都使用了点积注意力机制(scaled dot-product attention)来计算注意力权重。
参数:无论是自注意力还是交叉注意力,它们都有查询(Query)、键(Key)和值(Value)的概念。
计算:两者都使用查询和键之间的点积,然后应用softmax函数来计算注意力权重。
输出:在计算完注意力权重后,两者都将这些权重应用于值来得到输出。
可变性:两者都可以通过掩码(masking)来控制某些位置不被其他位置关注。

不同点:

Self Attention: 查询、键和值都来自同一个输入序列。这使得模型能够关注输入序列中的其他部分以产生一个位置的输出。主要目的是捕捉输入序列内部的依赖关系。在Transformer的编码器(Encoder)和解码器(Decoder)的每一层都有自注意力。它允许输入序列的每个部分关注序列中的其他部分。

Cross Attention: 查询来自一个输入序列,而键和值来自另一个输入序列。这在诸如序列到序列模型(如机器翻译)中很常见,其中一个序列需要“关注”另一个序列。目的是使一个序列能够关注另一个不同的序列。主要出现在Transformer的解码器。它允许解码器关注编码器的输出,这在机器翻译等任务中尤为重要。

Code

import torch
import torch.nn as nn
import torch.nn.functional as F 

class CrossAttention(nn.Module):
    def __init__(self, dim: int, heads = 8):
        super().__init__()
        self.dim = dim
        self.heads = heads 
        self.scale = dim ** -0.5
        self.query = nn.Linear(dim, dim) # project to add more variability
        self.key = nn.Linear(dim, dim)
        self.value = nn.Linear(dim, dim)

    def forward(self, queries, keys, values, mask = None):
        b, n, _, h = *queries.shape, self.heads  # b: batch size, n: sequence length, h: heads, _ for dim
        queries = self.query(queries)
        keys = self.key(keys)
        values = self.value(values)

        # the last dim denote the actual feature vector for per head per batch(dim per head)
        queries = queries.view(b, n, h, -1).transpose(1, 2) # torch.size([1, 8, 10, 64]) 每一个头分到64个dim
        keys = keys.view(b, n, h, -1).transpose(1, 2)
        values = values.view(b, n, h, -1).transpose(1, 2) 
        # einsum function do dot product on the last two dim: 
        dots = torch.einsum('bhid, bhjd -> bhij', queries, keys) * self.scale 

        if mask is not None:
            mask = F.pad(mask.flatten(1), (1, 0), value = True) # (b, n) -> (b, n+1)
            assert mask.shape[-1] == dots.shape[-1], 'Mask has incorrect dimensions'
            mask = mask[:, None, :].expand(-1, n, -1)
            dots.masked_fill_(~mask, float('-inf'))

        attn = dots.softmax(dim = -1)
        out = torch.einsum('bhij, bhjd -> bhid', attn, values)
        out = out.transpose(1, 2).contiguous().view(b, n, -1) # contiguous函数显式将内存中对应ndarray的数据连续存储
        return out

if __name__ == '__main__':
    model = CrossAttention(512)
    queries = torch.randn(1, 10, 512)
    keys = torch.randn(1, 10, 512)
    values = torch.randn(1, 10, 512)
    out = model(queries, keys, values)
    print(out.shape)

没看懂einsum的一些细节, 后续有机会在进行补充.

暂无评论

发送评论 编辑评论


				
|´・ω・)ノ
ヾ(≧∇≦*)ゝ
(☆ω☆)
(╯‵□′)╯︵┴─┴
 ̄﹃ ̄
(/ω\)
∠( ᐛ 」∠)_
(๑•̀ㅁ•́ฅ)
→_→
୧(๑•̀⌄•́๑)૭
٩(ˊᗜˋ*)و
(ノ°ο°)ノ
(´இ皿இ`)
⌇●﹏●⌇
(ฅ´ω`ฅ)
(╯°A°)╯︵○○○
φ( ̄∇ ̄o)
ヾ(´・ ・`。)ノ"
( ง ᵒ̌皿ᵒ̌)ง⁼³₌₃
(ó﹏ò。)
Σ(っ °Д °;)っ
( ,,´・ω・)ノ"(´っω・`。)
╮(╯▽╰)╭
o(*////▽////*)q
>﹏<
( ๑´•ω•) "(ㆆᴗㆆ)
😂
😀
😅
😊
🙂
🙃
😌
😍
😘
😜
😝
😏
😒
🙄
😳
😡
😔
😫
😱
😭
💩
👻
🙌
🖕
👍
👫
👬
👭
🌚
🌝
🙈
💊
😶
🙏
🍦
🍉
😣
Source: github.com/k4yt3x/flowerhd
颜文字
Emoji
小恐龙
花!
上一篇
下一篇