相同点:
机制:两者都使用了点积注意力机制(scaled dot-product attention)来计算注意力权重。
参数:无论是自注意力还是交叉注意力,它们都有查询(Query)、键(Key)和值(Value)的概念。
计算:两者都使用查询和键之间的点积,然后应用softmax函数来计算注意力权重。
输出:在计算完注意力权重后,两者都将这些权重应用于值来得到输出。
可变性:两者都可以通过掩码(masking)来控制某些位置不被其他位置关注。
不同点:
Self Attention: 查询、键和值都来自同一个输入序列。这使得模型能够关注输入序列中的其他部分以产生一个位置的输出。主要目的是捕捉输入序列内部的依赖关系。在Transformer的编码器(Encoder)和解码器(Decoder)的每一层都有自注意力。它允许输入序列的每个部分关注序列中的其他部分。
Cross Attention: 查询来自一个输入序列,而键和值来自另一个输入序列。这在诸如序列到序列模型(如机器翻译)中很常见,其中一个序列需要“关注”另一个序列。目的是使一个序列能够关注另一个不同的序列。主要出现在Transformer的解码器。它允许解码器关注编码器的输出,这在机器翻译等任务中尤为重要。
Code
import torch
import torch.nn as nn
import torch.nn.functional as F
class CrossAttention(nn.Module):
def __init__(self, dim: int, heads = 8):
super().__init__()
self.dim = dim
self.heads = heads
self.scale = dim ** -0.5
self.query = nn.Linear(dim, dim) # project to add more variability
self.key = nn.Linear(dim, dim)
self.value = nn.Linear(dim, dim)
def forward(self, queries, keys, values, mask = None):
b, n, _, h = *queries.shape, self.heads # b: batch size, n: sequence length, h: heads, _ for dim
queries = self.query(queries)
keys = self.key(keys)
values = self.value(values)
# the last dim denote the actual feature vector for per head per batch(dim per head)
queries = queries.view(b, n, h, -1).transpose(1, 2) # torch.size([1, 8, 10, 64]) 每一个头分到64个dim
keys = keys.view(b, n, h, -1).transpose(1, 2)
values = values.view(b, n, h, -1).transpose(1, 2)
# einsum function do dot product on the last two dim:
dots = torch.einsum('bhid, bhjd -> bhij', queries, keys) * self.scale
if mask is not None:
mask = F.pad(mask.flatten(1), (1, 0), value = True) # (b, n) -> (b, n+1)
assert mask.shape[-1] == dots.shape[-1], 'Mask has incorrect dimensions'
mask = mask[:, None, :].expand(-1, n, -1)
dots.masked_fill_(~mask, float('-inf'))
attn = dots.softmax(dim = -1)
out = torch.einsum('bhij, bhjd -> bhid', attn, values)
out = out.transpose(1, 2).contiguous().view(b, n, -1) # contiguous函数显式将内存中对应ndarray的数据连续存储
return out
if __name__ == '__main__':
model = CrossAttention(512)
queries = torch.randn(1, 10, 512)
keys = torch.randn(1, 10, 512)
values = torch.randn(1, 10, 512)
out = model(queries, keys, values)
print(out.shape)
没看懂einsum
的一些细节, 后续有机会在进行补充.