* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef CROSS_H
#define CROSS_H
#include "kernel_operator.h"
#include "simt_api/asc_simt.h"
#include "cross_struct.h"
#include "cross_tiling_key.h"
using namespace AscendC;
// Thread count used for the SimtCrossCompute kernel: passed as dim3(...) to asc_vf_call
// and as that kernel's LAUNCH_BOUND argument.
constexpr int32_t CROSS_THREAD_DIM = 1024;
// Thread count used for the SimtCrossComputeInt64 kernel (the variant that accumulates
// element offsets in int64_t): passed as dim3(...) to asc_vf_call and as its LAUNCH_BOUND argument.
constexpr int32_t CROSS_THREAD_DIM_FOR_INT64 = 512;
// Stride tables handed by value to the SIMT kernels. Every table has 8 slots; the kernels
// read slots [0, dimNum) in their decomposition loop and slot [dim] for the component stride.
struct CrossStridePara {
// Divisors used to split a linear vector index into per-dimension coordinates
// (filled from tilingData.mergedStride).
int32_t ms[8];
// Element strides applied to x1 for each coordinate (filled from tilingData.x1Stride).
int32_t x1s[8];
// Element strides applied to x2 for each coordinate (filled from tilingData.x2Stride).
int32_t x2s[8];
// Element strides applied to y for each coordinate (filled from tilingData.yStride).
int32_t ys[8];
};
namespace CrossKernel {
/**
 * Cross-product operator for element type T.
 *
 * Each "vector" is a triple of elements spaced by the stride of dimension `dim`.
 * For every vector handled by this block, the kernels read three elements from x1
 * and three from x2, evaluate the cross product in float, and store the three
 * results (converted back to the element type) into y.
 */
template <typename T>
class Cross {
public:
    __aicore__ inline Cross() {}
    __aicore__ inline ~Cross() {}

    /**
     * Copies the tiling parameters into members and binds the three buffers.
     *
     * @param x1          address of the first input
     * @param x2          address of the second input
     * @param y           address of the output
     * @param tilingData  tiling description; all fields read here are copied verbatim
     */
    __aicore__ inline void Init(GM_ADDR x1, GM_ADDR x2, GM_ADDR y,
                                const CrossRegbaseTilingData& tilingData)
    {
        // Bind the typed views first; this is independent of the scalar tiling fields.
        x1Gm_.SetGlobalBuffer((__gm__ T*)x1);
        x2Gm_.SetGlobalBuffer((__gm__ T*)x2);
        yGm_.SetGlobalBuffer((__gm__ T*)y);

        totalVectors_ = tilingData.totalVectors;
        vectorsPerCore_ = tilingData.vectorsPerCore;
        coreNum_ = tilingData.coreNum;
        formerCore_ = tilingData.formerCore;
        dim_ = tilingData.dim;
        dimNum_ = tilingData.dimNum;
        dimStride_ = tilingData.dimStride;
        usedInt64_ = tilingData.usedInt64;

        // All 8 slots of every stride table are copied, regardless of dimNum.
        constexpr int kStrideSlots = 8;
        for (int slot = 0; slot < kStrideSlots; ++slot) {
            para_.ms[slot] = tilingData.mergedStride[slot];
            para_.x1s[slot] = tilingData.x1Stride[slot];
            para_.x2s[slot] = tilingData.x2Stride[slot];
            para_.ys[slot] = tilingData.yStride[slot];
        }
    }

    /**
     * Works out this block's slice of vectors and launches the matching SIMT kernel.
     *
     * Blocks with index below formerCore_ take (vectorsPerCore_ + 1) vectors each;
     * the remaining blocks take vectorsPerCore_. A block with nothing to do returns
     * without launching.
     */
    __aicore__ inline void Process()
    {
        const int32_t blockId = GetBlockIdx();

        int32_t firstVector;
        int32_t vectorCount;
        if (blockId < formerCore_) {
            firstVector = (vectorsPerCore_ + 1) * blockId;
            vectorCount = vectorsPerCore_ + 1;
        } else {
            // The formerCore_ leading blocks each took one extra vector, hence the offset.
            firstVector = formerCore_ + vectorsPerCore_ * blockId;
            vectorCount = vectorsPerCore_;
        }

        if (vectorCount <= 0) {
            return;
        }

        if (usedInt64_) {
            // Variant that accumulates element offsets in int64_t.
            asc_vf_call<SimtCrossComputeInt64<T>>(
                dim3(CROSS_THREAD_DIM_FOR_INT64),
                (__gm__ T*)x1Gm_.GetPhyAddr(),
                (__gm__ T*)x2Gm_.GetPhyAddr(),
                (__gm__ T*)yGm_.GetPhyAddr(),
                firstVector, vectorCount, dim_, dimNum_, dimStride_, para_
            );
            return;
        }

        // Variant that accumulates element offsets in int32_t.
        asc_vf_call<SimtCrossCompute<T>>(
            dim3(CROSS_THREAD_DIM),
            (__gm__ T*)x1Gm_.GetPhyAddr(),
            (__gm__ T*)x2Gm_.GetPhyAddr(),
            (__gm__ T*)yGm_.GetPhyAddr(),
            firstVector, vectorCount, dim_, dimNum_, dimStride_, para_
        );
    }

private:
    /**
     * SIMT kernel using int32_t element offsets.
     *
     * Thread t handles local vector indices t, t + blockDim.x, t + 2 * blockDim.x, ...
     * while they stay below `count`.
     *
     * @param x1Gm, x2Gm  input element pointers
     * @param yGm         output element pointer
     * @param startIdx    linear index of the first vector owned by this block
     * @param count       number of vectors owned by this block
     * @param dim         slot in the stride tables giving the spacing between the 3 components
     * @param dimNum      number of stride-table slots used to decompose a vector index
     * @param dimStride   not read by this kernel
     * @param para        stride tables
     */
    template <typename U>
    __simt_vf__ LAUNCH_BOUND(CROSS_THREAD_DIM)
    static void SimtCrossCompute(
        __gm__ U* x1Gm, __gm__ U* x2Gm, __gm__ U* yGm,
        int32_t startIdx, int32_t count, int32_t dim, int32_t dimNum, int32_t dimStride,
        CrossStridePara para)
    {
        const int32_t lane = threadIdx.x;
        const int32_t laneCount = blockDim.x;

        // Spacing between the three components of a vector in each buffer.
        const int32_t x1Step = para.x1s[dim];
        const int32_t x2Step = para.x2s[dim];
        const int32_t yStep = para.ys[dim];

        for (int32_t local = lane; local < count; local += laneCount) {
            // Peel one coordinate per dimension off the linear index and
            // accumulate the base element offset in every buffer.
            int32_t rest = startIdx + local;
            int32_t x1Base = 0;
            int32_t x2Base = 0;
            int32_t yBase = 0;
            for (int axis = 0; axis < dimNum; axis++) {
                int32_t coord = rest / para.ms[axis];
                rest -= coord * para.ms[axis];
                x1Base += coord * para.x1s[axis];
                x2Base += coord * para.x2s[axis];
                yBase += coord * para.ys[axis];
            }

            // Load both operands as float.
            float a0 = (float)x1Gm[x1Base];
            float a1 = (float)x1Gm[x1Base + x1Step];
            float a2 = (float)x1Gm[x1Base + 2*x1Step];
            float b0 = (float)x2Gm[x2Base];
            float b1 = (float)x2Gm[x2Base + x2Step];
            float b2 = (float)x2Gm[x2Base + 2*x2Step];

            // y = a x b, converted back to the element type on store.
            yGm[yBase] = (U)(a1*b2 - a2*b1);
            yGm[yBase + yStep] = (U)(a2*b0 - a0*b2);
            yGm[yBase + 2*yStep] = (U)(a0*b1 - a1*b0);
        }
    }

    /**
     * SIMT kernel using int64_t element offsets.
     *
     * Same work split and arithmetic as SimtCrossCompute; the stride-table entries
     * are widened to int64_t before being multiplied, so the accumulated offsets are
     * 64-bit. The linear vector index itself stays int32_t.
     *
     * @param x1Gm, x2Gm  input element pointers
     * @param yGm         output element pointer
     * @param startIdx    linear index of the first vector owned by this block
     * @param count       number of vectors owned by this block
     * @param dim         slot in the stride tables giving the spacing between the 3 components
     * @param dimNum      number of stride-table slots used to decompose a vector index
     * @param dimStride   not read by this kernel
     * @param para        stride tables
     */
    template <typename U>
    __simt_vf__ LAUNCH_BOUND(CROSS_THREAD_DIM_FOR_INT64)
    static void SimtCrossComputeInt64(
        __gm__ U* x1Gm, __gm__ U* x2Gm, __gm__ U* yGm,
        int32_t startIdx, int32_t count, int32_t dim, int32_t dimNum, int32_t dimStride,
        CrossStridePara para)
    {
        const int32_t lane = threadIdx.x;
        const int32_t laneCount = blockDim.x;

        // Spacing between the three components of a vector, widened to 64 bits.
        const int64_t x1Step = static_cast<int64_t>(para.x1s[dim]);
        const int64_t x2Step = static_cast<int64_t>(para.x2s[dim]);
        const int64_t yStep = static_cast<int64_t>(para.ys[dim]);

        for (int32_t local = lane; local < count; local += laneCount) {
            int32_t rest = startIdx + local;
            int64_t x1Base = 0;
            int64_t x2Base = 0;
            int64_t yBase = 0;
            for (int axis = 0; axis < dimNum; axis++) {
                const int64_t divisor = static_cast<int64_t>(para.ms[axis]);
                const int64_t x1Stride = static_cast<int64_t>(para.x1s[axis]);
                const int64_t x2Stride = static_cast<int64_t>(para.x2s[axis]);
                const int64_t yStride = static_cast<int64_t>(para.ys[axis]);

                const int64_t coord = static_cast<int64_t>(rest) / divisor;
                // Remainder of the division, narrowed back to the 32-bit index.
                rest = static_cast<int32_t>(static_cast<int64_t>(rest) - coord * divisor);

                x1Base += coord * x1Stride;
                x2Base += coord * x2Stride;
                yBase += coord * yStride;
            }

            // Load both operands as float.
            float a0 = (float)x1Gm[x1Base];
            float a1 = (float)x1Gm[x1Base + x1Step];
            float a2 = (float)x1Gm[x1Base + 2 * x1Step];
            float b0 = (float)x2Gm[x2Base];
            float b1 = (float)x2Gm[x2Base + x2Step];
            float b2 = (float)x2Gm[x2Base + 2 * x2Step];

            // y = a x b, converted back to the element type on store.
            yGm[yBase] = (U)(a1 * b2 - a2 * b1);
            yGm[yBase + yStep] = (U)(a2 * b0 - a0 * b2);
            yGm[yBase + 2 * yStep] = (U)(a0 * b1 - a1 * b0);
        }
    }

    // Typed views over the two inputs and the output.
    GlobalTensor<T> x1Gm_;
    GlobalTensor<T> x2Gm_;
    GlobalTensor<T> yGm_;
    // Tiling scalars copied in Init(); totalVectors_ and coreNum_ are stored but not read in this class.
    int32_t totalVectors_{0};
    int32_t vectorsPerCore_{0};
    int32_t coreNum_{0};
    int32_t dim_{0};
    int32_t dimNum_{0};
    int32_t dimStride_{1};
    int32_t formerCore_{0};
    // Selects SimtCrossComputeInt64 instead of SimtCrossCompute in Process().
    bool usedInt64_{false};
    // Stride tables forwarded by value to the kernels.
    CrossStridePara para_;
};
}
#endif