AIKG-SWFT

SWFT Introduction

SWFT is a kernel compiler for Ascend that aims for minimal code and high performance. It currently serves as the backend for AIKG kernel generation on Ascend310P.

AIKG-SWFT Analysis

  • SWFT's Python front end is flexible, which makes it well suited to LLM code generation.
    • In addition to basic Ascend syntax, it offers higher-level abstractions. For example, data movement accepts arbitrary lengths, and SWFT sets the actual repeat, block, and stride parameters internally.
    • The Python syntax maps roughly onto the front-end Sketch design (tile2slice, move2copy, vec_compute).
  • Automatic static memory allocation removes the need to manage buffers explicitly or to set up pipelines by hand.
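
To illustrate the kind of bookkeeping that "arbitrary length" support hides, here is a minimal sketch of how a vector operation over N elements might be split into repeat-limited instructions plus a masked tail. This is not SWFT's internal code. The names plan_vector_op and VecIssue are hypothetical, and the constants (32-byte blocks, 8 blocks per repeat, at most 255 repeats per instruction) are assumptions about typical Ascend vector instruction limits.

```python
from dataclasses import dataclass

# Assumed hardware limits, for illustration only.
BLOCK_BYTES = 32         # size of one block
BLOCKS_PER_REPEAT = 8    # blocks processed in one repeat (256 bytes)
MAX_REPEAT = 255         # maximum repeat count of a single instruction


@dataclass
class VecIssue:
    """One hypothetical instruction issue."""
    offset: int  # starting element offset
    repeat: int  # number of full repeats (0 for a tail issue)
    tail: int    # leftover elements handled under a mask (0 if none)


def plan_vector_op(num_elems: int, elem_bytes: int) -> list:
    """Split num_elems elements into full-repeat issues and one masked tail."""
    elems_per_repeat = BLOCK_BYTES * BLOCKS_PER_REPEAT // elem_bytes
    full_repeats, tail = divmod(num_elems, elems_per_repeat)
    plans = []
    offset = 0
    while full_repeats > 0:
        repeat = min(full_repeats, MAX_REPEAT)
        plans.append(VecIssue(offset, repeat, 0))
        offset += repeat * elems_per_repeat
        full_repeats -= repeat
    if tail:
        plans.append(VecIssue(offset, 0, tail))
    return plans


# A 7168-element FP16 row fits in a single issue under these assumptions.
print(plan_vector_op(7168, 2))
```

With SWFT, the kernel author writes a single call over the whole tensor and never sees this decomposition.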

Reference Code

The following moe_token_unpermute_op kernel was generated by AIKG:

hidden = 7168
@sub_kernel(core_num=8)
def moe_token_unpermute_op_impl_npu(gm_permute_token, gm_sorted_idx, gm_probs, gm_output, tiling):
    block_idx = get_block_idx()

    # Load the sorted indices, probabilities, and tiling parameters into UB
    ub_idx = move_to_ub(gm_sorted_idx)
    prob_ub = move_to_ub(gm_probs)
    ub_tiling = move_to_ub(tiling)
    token_num = move_to_scalar(ub_tiling[0])
    top_k = move_to_scalar(ub_tiling[1])
    tokens_per_core = move_to_scalar(ub_tiling[2])
    for i in dynamic_loop(tokens_per_core):
        start_token = block_idx * tokens_per_core + i
        # Zero-initialize the FP16 accumulator for this token's output row
        tmp_s = Scalar("FP16", 0.0)
        local_out = vector_dup(tmp_s, [1, hidden], False)
        for k in dynamic_loop(top_k):
            # Load sorted_idx
            idx = move_to_scalar(ub_idx[k * token_num + start_token])
            # Load permute_token row
            dst_row = slice_to_ub(gm_permute_token, [idx, 0], slicesize=[1, hidden])
            # Load prob
            prob = move_to_scalar(prob_ub[start_token, k])
            # Compute weighted row and accumulate
            weighted_row = vmuls(dst_row, prob)
            local_out = vadd(local_out, weighted_row)

        # Write back to GM
        insert_to_gm(gm_output, local_out, [start_token, 0], slicesize=[1, hidden])
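
For readers who want to check what the kernel computes, the sketch below restates the same arithmetic in plain Python on nested lists. It is a host-side reference written for this document, not code generated by AIKG. The function name moe_token_unpermute_ref is hypothetical, and the input layout is assumed from the kernel's indexing: sorted_idx is flat with entry k * token_num + t, and probs is indexed as [t][k]. It accumulates in Python floats, so results can differ slightly from the FP16 accumulation on device.

```python
def moe_token_unpermute_ref(permute_token, sorted_idx, probs):
    """Reference for the kernel above: out[t] = sum_k probs[t][k] * row(k, t)."""
    token_num = len(probs)
    top_k = len(probs[0])
    hidden = len(permute_token[0])
    output = []
    for token in range(token_num):
        acc = [0.0] * hidden  # plays the role of local_out
        for k in range(top_k):
            # Same flat index the kernel uses for ub_idx
            row = permute_token[sorted_idx[k * token_num + token]]
            weight = probs[token][k]
            acc = [a + weight * x for a, x in zip(acc, row)]
        output.append(acc)
    return output


# Small example: 2 tokens, top_k = 2, hidden = 4
permute_token = [[float(4 * r + c) for c in range(4)] for r in range(4)]
sorted_idx = [2, 0, 3, 1]
probs = [[0.5, 0.5], [1.0, 0.0]]
print(moe_token_unpermute_ref(permute_token, sorted_idx, probs))
```

In the kernel, the outer loop over tokens is split across the 8 cores, with each core handling tokens_per_core consecutive tokens starting at block_idx * tokens_per_core.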