| fix: correct grad_probs dtype in npu_moe_token_unpermute_grad to match probs dtype
Co-authored-by: wang_ziqi<wangziqi4@huawei.com>
# message auto-generated for no-merge-commit merge:
!4834 merge fix/grad-probs-dtype-26.0.0 into 26.0.0
fix: correct grad_probs dtype in npu_moe_token_unpermute_grad to match probs dtype
Created-by: wang-ziqi-code
Commit-by: wang_ziqi
Merged-by: ascend-robot
Description: # 【合入来源】
> <font color="red">**如有社区issue,请关联issue链接**</font>\
> <font color="red">**请勿携带内部流程信息(需求链接、问题单、内部issue等)**</font>
- [ ] 需求
- [x] 问题单
- [ ] issue/工单
- [ ] 重构优化
- [ ] 资料更新
# 【修改方案】
修复 npu_moe_token_unpermute_grad 算子中 grad_probs 输出数据类型推导错误的问题。
1. **问题根因**:grad_probs 作为 probs 的梯度,其 dtype 应与 probs 保持一致,但当前实现中错误地使用了 grad_unpermuted_tokens 的 dtype。当 probs 与 grad_unpermuted_tokens 数据类型不同时(如 MoE 混合精度训练中 probs 为 float32、grad_unpermuted_tokens 为 bfloat16),grad_probs 的 dtype 推导结果错误,导致精度丢失,进而引发网络 loss 为 NaN。
2. **修改内容**:
- op_plugin/python/meta/_meta_registrations.py:将 torch.empty_like(probs, dtype=grad_unpermuted_tokens.dtype) 改为 torch.empty_like(probs, dtype=probs.dtype),确保 grad_probs 的 dtype 与 probs 一致;probs 为 None 时 grad_probs 也为 None,不创建张量
- op_plugin/config/op_plugin_functions.yaml:将 grad_probs 的 dtype 字段从 grad_unpermuted_tokens 改为 probs.has_value() ? probs->scalar_type() : permuted_tokens.scalar_type()。由于 probs 是 Tensor?(optional)类型,代码生成器不会自动追加 .scalar_type(),需要显式写出完整 C++ 表达式;probs 为空时 grad_probs 为空张量,使用 permuted_tokens.scalar_type() 作为占位 dtype
3. **为什么之前未被发现**:现有测试中 probs 和 grad_unpermuted_tokens 使用相同 dtype(均为 bfloat16),此时 dtype=grad_unpermuted_tokens.dtype 恰好等于 probs.dtype,bug 被掩盖。仅在混合精度场景下才会暴露。
# 【资料变更】
不涉及
# 【接口变更】
不涉及。本次修复仅纠正 grad_probs 输出的 dtype 推导逻辑,算子签名和接口行为无变更。
# 【功能验证】
1. 测试场景:npu_moe_token_unpermute_grad 算子在 probs 与 grad_unpermuted_tokens dtype 不同时的输出 dtype 正确性
2. 测试方法:已有测试用例 test_fake_tensor.py 中包含 self.assertEqual(grad_probs.dtype, probs.dtype) 断言,与修复后的行为一致
3. 验证方式:通过 Meta Tensor / FakeTensor 机制验证 grad_probs 的 dtype 与 probs 一致;端到端混合精度训练场景下 loss 不再出现 NaN
# 【CheckList】
> PR提交人对以下CheckList自检项进行全量自检,自检通过或不涉及,均修改 [ ] 为 [x]
- [x] 代码注释完备,正确记录错误日志
- [x] 代码实现进行了返回值、空指针等校验
- [x] PR标题正确使用类型标签,如:feat、fix、refactor、docs、test等
- [x] PR持续集成流水线(CI)执行通过,代码检查无异常
See merge request: Ascend/op-plugin!4834 | 1 个月前 |