# 09401f51 created on March 22, 2025 (commit history)
import torch
import torch.nn as nn

import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests


class TestNpuAttnSoftMax(TestCase):
    """Check the torch_npu attention-softmax ops against an autograd reference."""

    def npu_attn_softmax(self, attention_logits):
        """Call ``torch_npu.npu_attn_softmax_`` on ``attention_logits``.

        Nothing is returned; the calling test reads the result back from the
        same tensor it passed in.
        """
        torch_npu.npu_attn_softmax_(attention_logits)

    def npu_attn_softmax_backward(self, attention_logits, grad_output, v):
        """Call ``torch_npu.npu_attn_softmax_backward_`` with the given tensors.

        Nothing is returned; the calling test afterwards compares
        ``attention_logits`` with the reference gradient.
        """
        torch_npu.npu_attn_softmax_backward_(attention_logits, grad_output, v)

    def golden_calc(self, attention_logits, grad_output, values):
        """Compute the reference softmax output and input gradient via autograd.

        Args:
            attention_logits: tensor to apply softmax to along its last dim.
            grad_output: gradient fed into ``backward`` for the matmul result.
            values: right-hand operand multiplied with the softmax output.

        Returns:
            A ``(softmax_output, grad_wrt_logits)`` tuple.
        """
        # Operate on a detached copy so the caller's tensor is left untouched.
        logits_ref = attention_logits.detach().clone().requires_grad_(True)
        probs_ref = nn.functional.softmax(logits_ref, dim=-1)
        # Backpropagate grad_output through softmax(logits) @ values.
        torch.matmul(probs_ref, values).backward(grad_output)
        return probs_ref, logits_ref.grad

    def test_npu_attn_softmax(self):
        """Run forward then backward on the NPU and compare with the reference."""
        batch, q_len, kv_len, feature_dim = 10, 4096, 4096, 128

        # Creation order is kept as logits, grad_output, values.
        logits = torch.randn(batch, q_len, kv_len).npu()
        upstream_grad = torch.randn(batch, q_len, feature_dim).npu()
        value_states = torch.randn(batch, kv_len, feature_dim).npu()

        # The reference must be computed first: the NPU calls below reuse
        # `logits` as the tensor holding their results.
        expected_probs, expected_grad = self.golden_calc(
            logits, upstream_grad, value_states
        )

        # Forward: `logits` is compared against the reference softmax output.
        self.npu_attn_softmax(logits)
        self.assertRtolEqual(expected_probs, logits)

        # Backward: the tensor from the forward step is passed as the first
        # argument and is then compared against the reference gradient.
        self.npu_attn_softmax_backward(logits, upstream_grad, value_states)
        self.assertRtolEqual(expected_grad, logits)


# When executed as a script, hand control to the run_tests helper imported
# from torch_npu.testing.testcase.
if __name__ == "__main__":
    run_tests()