NaN problem of nn.MultiHeadAttention
in PyTorch
Published:
MultiHeadAttention
NaN
问题
nn.MultiheadAttention causes gradients to become NaN under some use cases · Issue #41508 · pytorch/pytorch · GitHub 这几天持续跟踪了一下 pytorch 实现的 nn.MultiHeadAttention
计算过程中出现 NaN
的问题。根本原因是 tokenizer 在左侧增加 padding token(只能在左侧加,在右侧加是错误的,LLM 自回归生成,无法跟在 padding token 后面继续生成),导致 causal mask 和 padding mask 合并之后存在 attention matrix 前几行整行被 mask 的情况。pytorch 对于被 mask 部分的处理方式是填充 float("-inf")
,导致经过 softmax
计算之后,整行都是 NaN
。