import torch
import torch_npu
import torchair
import mie_ops
import numpy as np
import torch.nn as nn
from torch_npu.testing.testcase import TestCase, run_tests
device_id = 0
torch_npu.npu.set_device(int(device_id))
def _lightning_indexer(query, key, weights, actual_seq_lengths_query, actual_seq_lengths_key, layout_key,
selected_count, sparse_mode):
layout_query = "BSND"
layout_key = layout_key
selected_count = selected_count
sparse_mode = sparse_mode
s1 = query.shape[1]
query = np.transpose(query, axes=(0, 2, 1, 3)).astype(np.float32)
key = np.transpose(key.reshape(query.shape[0], -1, key.shape[-1]), axes=(0, 2, 1)).astype(np.float32)
weights = np.transpose(weights, axes=(0, 2, 1, 3)).astype(np.float32)
relu_out = np.maximum(0, query.reshape(query.shape[0], -1, query.shape[-1]) @ key)
relu_out = relu_out.reshape(query.shape[0], -1, s1, key.shape[-1])
weight_out = relu_out * weights
reduce_out = np.sum(weight_out, axis=1, keepdims=True)
s1 = reduce_out.shape[2]
s2 = reduce_out.shape[3]
if sparse_mode == 3:
for i in range(s1):
reduce_out[:, :, -1 - i, s2 - i:] = float('-inf')
sorted_indices = np.argsort(-reduce_out, kind="stable", axis=-1)
if sparse_mode == 3:
for i in range(s1):
sorted_indices[:, :, -1 - i, s2 - i:] = -1
sorted_res = sorted_indices[..., :selected_count]
pad_width = [(0, 0)] * sorted_res.ndim
pad_width[-1] = (0, selected_count - sorted_res.shape[-1])
sorted_res = np.pad(sorted_res, pad_width, mode='constant', constant_values=-1)
return sorted_res.astype(np.int32)
class TestCustomLightningIndexer(TestCase):
def test_lightning_indexer_eager(self):
B = 1
S1 = 1
S2 = 8192
N1 = 64
N2 = 1
D = 128
block_size = 256
T = 8192
layout_query = 'BSND'
np.random.seed(0)
if layout_query == 'BSND':
query = torch.tensor(np.random.uniform(-10, 10, (B, S1, N1, D))).to(torch.bfloat16)
else:
query = torch.tensor(np.random.uniform(-10, 10, (T, N1, D))).to(torch.bfloat16)
key = torch.tensor(np.random.uniform(-10, 10, (B * (S2 // block_size), block_size, N2, D))).to(torch.bfloat16)
weights = torch.tensor(np.random.uniform(-1, 1, (B, S1, N1, 1))).to(torch.bfloat16)
actual_seq_lengths_query = torch.tensor(np.random.uniform(S1, S1, (B))).to(torch.int32)
actual_seq_lengths_key = torch.tensor(np.random.uniform(S2, S2, (B))).to(torch.int32)
block_table = torch.tensor([range(B * S2 // block_size)], dtype=torch.int32).reshape(B, -1)
layout_key = 'PA_BSND'
selected_count = 2048
sparse_mode = 3
cpuout = _lightning_indexer(query.to(torch.float).numpy(), key.to(torch.float).numpy(),
weights.to(torch.float).numpy(), actual_seq_lengths_query.numpy(),
actual_seq_lengths_key.numpy(),
layout_key, selected_count, sparse_mode)
torch_npu.npu.set_device(int(device_id))
query = query.to("npu:%s" % device_id)
key = key.to("npu:%s" % device_id)
weights = weights.to("npu:%s" % device_id)
actual_seq_lengths_query = actual_seq_lengths_query.to("npu:%s" % device_id)
actual_seq_lengths_key = actual_seq_lengths_key.to("npu:%s" % device_id)
block_table = block_table.to("npu:%s" % device_id)
print(f'======================== PTA eager BEGIN ========================')
npu_out = torch.ops.mie_ops.npu_lightning_indexer(
query, key, weights, actual_seq_lengths_query=actual_seq_lengths_query,
actual_seq_lengths_key=actual_seq_lengths_key, block_table=block_table, layout_query=layout_query,
layout_key=layout_key, selected_count=selected_count, sparse_mode=sparse_mode)
npu_out = npu_out.reshape(B, S1, selected_count).cpu()
cpuout = cpuout.reshape(B, S1, selected_count)
for i in range(B):
for j in range(S1):
for k in range(selected_count):
res = npu_out[i][j][k] == cpuout[i][j][k]
if not res.all():
print("B S K npu cpu = ", i, j, k, npu_out[i][j][k], cpuout[i][j][k])
print(f'======================== PTA eager FINISH ========================')
if __name__ == "__main__":
run_tests()