Cross attention(交叉注意力)是深度学习里的一个重要概念,在很多模型尤其是Transformer架构及其变体中被广泛应用。下面为你详细介绍:
概念
在注意力机制里,主要有自注意力(self-attention)和交叉注意力。自注意力关注的是序列内元素之间的关系,也就是让序列里的每个元素都和序列内的其他元素进行交互,以此捕获序列的上下文信息。而交叉注意力则是在两个不同的序列之间计算注意力权重,一个序列作为查询(query),另一个序列作为键(key)和值(value),从而让查询序列能够聚焦于键 - 值序列里的相关信息。
计算过程
交叉注意力的计算和自注意力类似,不过它的查询、键和值来自不同的序列。具体步骤如下:
-
线性变换:对查询序列 $Q$、键序列 $K$ 和值序列 $V$ 分别进行线性变换,得到 $Q'$、$K'$ 和 $V'$。
-
计算注意力分数:计算查询和键之间的相似度,一般使用点积来计算:$scores = Q'K'^T$。
-
应用掩码(可选):在某些情形下,需要应用掩码来屏蔽掉一些不需要关注的位置。
-
计算注意力权重:对分数应用 softmax 函数,从而得到注意力权重:$weights = softmax(scores / \sqrt{d_k})$,这里的 $d_k$ 是键的维度。
-
加权求和:用注意力权重对值进行加权求和,得到输出:$output = weightsV'$。
公式
交叉注意力的输出可以用以下公式表示:
代码示例
以下是使用 PyTorch 实现的简单交叉注意力代码:
import torch
import torch.nn as nn
class CrossAttention(nn.Module):
def __init__(self, input_dim, key_dim, value_dim):
super(CrossAttention, self).__init__()
self.query_proj = nn.Linear(input_dim, key_dim)
self.key_proj = nn.Linear(input_dim, key_dim)
self.value_proj = nn.Linear(input_dim, value_dim)
self.softmax = nn.Softmax(dim=-1)
def forward(self, query, key, value):
Q = self.query_proj(query)
K = self.key_proj(key)
V = self.value_proj(value)
scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(Q.size(-1), dtype=torch.float32))
attention_weights = self.softmax(scores)
output = torch.matmul(attention_weights, V)
return output
# 示例使用
input_dim = 64
key_dim = 32
value_dim = 32
query = torch.randn(10, 5, input_dim) # 批次大小为 10,序列长度为 5
key = torch.randn(10, 8, input_dim) # 批次大小为 10,序列长度为 8
value = torch.randn(10, 8, input_dim) # 批次大小为 10,序列长度为 8
cross_attn = CrossAttention(input_dim, key_dim, value_dim)
output = cross_attn(query, key, value)
print(output.shape)
应用场景
-
机器翻译:在编码器 - 解码器架构里,解码器能够利用交叉注意力机制关注编码器输出的相关信息,从而生成准确的翻译结果。
-
图像生成:在文本到图像生成任务中,文本嵌入可以作为查询,图像特征作为键和值,通过交叉注意力让生成的图像和文本描述相匹配。
-
多模态任务:在处理多种模态数据(像文本和图像)时,交叉注意力可以帮助不同模态之间进行信息交互。
在推荐系统中,Cross Attention和Target Attention虽然都涉及不同序列间的注意力交互,但两者在定义和应用场景上存在差异,并非完全相同的概念。以下是具体分析:
- Target Attention • 定义:Target Attention的核心是通过计算目标物品(Target Item)与用户历史行为序列之间的关联程度,来衡量用户对目标物品的兴趣。其Query(Q)来自目标物品,Key(K)和Value(V)来自用户行为序列。
• 特点:
• 专注于目标物品与行为序列的交叉关系,忽略行为序列内部的依赖(如DIN模型)。
• 常用于电商推荐,例如通过用户历史点击序列预测对候选商品的兴趣。
• 示例:在阿里ETA模型中,Target Attention通过SimHash技术高效检索与目标物品相关的用户行为子序列。
- Cross Attention • 定义:Cross Attention是更通用的跨序列注意力机制,其Query来自一个序列(如目标序列),Key和Value来自另一个序列(如源序列),用于融合不同来源的信息。
• 特点:
• 不局限于推荐系统,广泛应用于机器翻译、多模态任务(如图像字幕生成)。
• 在推荐系统中,Cross Attention可能用于融合用户画像和物品特征等多模态数据。
• 与Target Attention的关系:Target Attention可视为Cross Attention在推荐系统中的一种特例(Q来自目标物品,K/V来自行为序列)。
- 关键区别
维度 | Target Attention | Cross Attention |
---|---|---|
应用场景 | 推荐系统(用户行为序列与目标物品) | 多领域(如NLP、多模态) |
输入来源 | Q: 目标物品;K/V: 行为序列 | Q/K/V可来自任意两个不同序列 |
设计目标 | 捕捉用户对特定物品的兴趣 | 通用跨序列信息融合 |
- 总结 • 相同点:均通过Query与Key/Value的交互实现跨序列注意力。
• 不同点:
• Target Attention是推荐场景下的专用设计,强调目标物品与用户行为的关联;
• Cross Attention是更通用的机制,适用于任意跨序列交互任务。
在推荐系统中,若需明确区分,建议根据具体任务选择:Target Attention更适合用户-物品交互建模,而Cross Attention更适合多模态或复杂跨序列融合。