DeepSeek最近凭借一己之力击落Nvidia神话,从此硬件不再是AI的决定性条件,低成本训练LLM模型成为现实。之所以能够实现低成本训练大模型,主要归功于DeepSeek底层采用的MLA和MoE结构,MoE结构我们之前已经介绍过,相比于训练一个统一的专业大模型,MoE通过训练N的专家模型,能够更高效的实现推理任务。 
MLA(Multi-head Latent Attention)结构是DeepSeek的核心创新之一,旨在优化Transformer模型中的注意力机制,以减少推理过程中的KV缓存(键值缓存)并提高计算效率。其核心原理如下: 1. MLA的核心思想:低秩联合压缩MLA的核心思想是对键(Key)和值(Value)进行低秩联合压缩。具体来说,传统的Transformer模型在计算注意力时,需要分别计算和存储每个头的键(K)和值(V),这会导致大量的内存消耗和计算开销。而MLA通过引入一个低秩向量(c),将K和V的存储和计算过程进行压缩。具体步骤如下: - 计算低秩向量c:首先计算一个低秩向量c,其维度远小于K和V的维度。
- 转换回K和V:然后通过两个权重矩阵(W)将c转换回K和V。
这样,在推理阶段,只需要存储低秩向量c,而不是完整的K和V,从而显著减少了KV缓存的大小。 2. 权重矩阵的合并与吸收MLA的另一大特点是权重矩阵的合并与吸收。传统的MHA(Multi-head Attention)模型中,每个头都有独立的Q、K、V投影矩阵,而在MLA中,这些矩阵被合并和吸收,以减少参数量和计算量。具体来说: - Q的投影矩阵:Q的投影矩阵可以与后续的投影计算合并。
- K和V的投影矩阵:K和V的投影矩阵也可以进行类似的合并操作。
通过这种合并,MLA在存储和计算上都优于传统的MHA。例如,KV缓存的大小可以从 减少到 ,其中h是head的数量。 3. 解耦的旋转位置编码(RoPE)MLA还对旋转位置编码(RoPE)进行了优化。RoPE对K和Q的位置敏感,但由于低秩向量c改变了位置,因此无法直接使用原有的位置编码。为了解决这个问题,MLA采用了以下方法: - 使用额外的多头Q和共享的K:引入额外的多头Q和共享的K来携带旋转位置编码。
这种解耦的方法确保了位置编码的正确性,同时不影响MLA的低秩压缩和权重合并。 4. MLA的计算流程MLA的计算流程可以概括如下: - 低秩压缩:Q和K路径中分别进行低秩压缩,生成latent / low-rank部分。
- 解耦RoPE:K路径中的一部分进行解耦的RoPE变换。
- 缓存与合并:推理阶段,压缩后的KV矩阵和位置编码后的矩阵被缓存,并进行合并吸收。
5. MLA的优势- 减少KV缓存:通过低秩联合压缩,MLA显著减少了KV缓存的大小。
- 降低计算成本:权重矩阵的合并与吸收减少了参数量和计算量。
- 提高推理效率:在相同的硬件条件下,MLA可以实现更长的上下文长度或更大的batch size,从而提高推理速度或吞吐量。
6. MLA与MQA的比较MLA的计算过程在某些方面与MQA(Multi-query Attention)类似,但MLA在QK维度上进行了RoPE变换,并且V与未施加RoPE的K共享激活值,这使得MLA在保持高效的同时,仍然能够保持较好的模型性能。 QuantML-Qlib实现MoE之前已经集成在我们的框架之中,详见: QuantML-Qlib开发版 | MoE混合专家系统用于提升Transformer表现【附代码】
本次我们将MLA结构也集成进框架之中,代码路径:examples/benchmarks/DeepSeek 
其中MLA结构的代码为:
class MultiHeadLatentAttention(nn.Module): def __init__(self, config: Config): super().__init__()
assert config.v_head_dim is not None , f"v_head_dim is not defined {config.v_head_dim=}" assert config.q_lora_rank is not None , f"q_lora_rank is not defined {config.q_lora_rank=}" assert config.kv_lora_rank is not None , f"kv_lora_rank is not defined {config.kv_lora_rank=}" assert config.rope_head_dim is not None , f"rope_head_dim is not defined {config.rope_head_dim=}"
self.config = config
self.dim = config.d_model self.num_heads = config.num_heads self.v_head_dim = config.v_head_dim
self.nope_head_dim = config.nope_head_dim self.rope_head_dim = config.rope_head_dim
self.q_lora_rank = config.q_lora_rank self.kv_lora_rank = config.kv_lora_rank
self.dropout = config.dropout
self.value_dim = self.num_heads * self.v_head_dim
# this is dims between wQ and wK self.nope_dim = self.num_heads * self.nope_head_dim self.rope_dim = self.num_heads * self.rope_head_dim
# query compression self.compress_q_linear = nn.Linear(self.dim, self.q_lora_rank, bias=False) # W_DQ self.decompress_q_nope = nn.Linear(self.q_lora_rank, self.nope_dim, bias=False) self.decompress_q_rope = nn.Linear(self.q_lora_rank, self.rope_dim, bias=False) self.q_norm = RMSNorm(dim=self.q_lora_rank) # key and value compression self.compress_kv_linear = nn.Linear(self.dim, self.kv_lora_rank, bias=False) # W_DKV self.decompress_k_nope = nn.Linear(self.kv_lora_rank, self.nope_dim, bias=False) self.decompress_v_linear = nn.Linear(self.kv_lora_rank, self.value_dim, bias=False) self.kv_norm = RMSNorm(dim=self.kv_lora_rank)
self.k_rope_linear = nn.Linear(self.dim, self.rope_head_dim , bias=False) # self.rope_norm = RMSNorm(self.rope_dim) # not in deepseekv2
self.proj = nn.Linear(self.value_dim , self.dim, bias=False) self.res_dropout = nn.Dropout(p=config.dropout)
def forward(self, x: Tensor,mask: torch.Tensor, freqs_cis: Tensor): batch_size, seq_len, _ = x.shape
compressed_q = self.compress_q_linear(x) norm_q = self.q_norm(compressed_q) query_nope:Tensor = self.decompress_q_nope(norm_q) query_rope:Tensor = self.decompress_q_rope(norm_q)
compressed_kv = self.compress_kv_linear(x) norm_kv = self.kv_norm(compressed_kv) key_nope: Tensor = self.decompress_k_nope(norm_kv) value: Tensor = self.decompress_v_linear(norm_kv)
key_rope:Tensor = self.k_rope_linear(x) # norm_rope = self.rope_norm(key_rope)
query_nope = query_nope.view(batch_size, seq_len, self.num_heads, self.nope_head_dim).transpose(1,2) query_rope = query_rope.view(batch_size, seq_len, self.num_heads, self.rope_head_dim).transpose(1,2)
key_rope = key_rope.view(batch_size, seq_len, 1, self.rope_head_dim).transpose(1,2) key_nope = key_nope.view(batch_size, seq_len, self.num_heads, self.nope_head_dim).transpose(1,2)
value = value.view(batch_size, seq_len, self.num_heads, self.v_head_dim).transpose(1,2)
# *** the line that fixes MLA :) *** key_rope = key_rope/self.num_heads
q_rope,k_rope = apply_rope(query_rope,key_rope, cis=freqs_cis)
q_recombined = torch.empty((batch_size,self.num_heads,seq_len, self.rope_head_dim + self.nope_head_dim), device=x.device) k_recombined = torch.empty((batch_size, self.num_heads, seq_len, self.rope_head_dim + self.nope_head_dim), device=x.device)
q_recombined[:,:,:,:self.nope_head_dim] = query_nope q_recombined[:,:,:,self.nope_head_dim:] = q_rope
# k_rope = torch.repeat_interleave(k_rope, self.num_heads, dim=1) # >> you dont need to do this << # 👇 broadcasting will do replication krope to all heads automagically k_recombined[:,:,:,:self.nope_head_dim] = key_nope k_recombined[:,:,:,self.nope_head_dim:] = k_rope
output = F.scaled_dot_product_attention(q_recombined, k_recombined, value, is_causal=True, dropout_p=self.dropout)
output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.num_heads * self.v_head_dim)
output = self.proj(output) output = self.res_dropout(output) return output
简单测试了一下,500内的IC为0.025,欢迎继续测试优化代码结果。
|