#include "catlass_kernel.h"
#include "common/optimized_macro_generator.h"
#include "jit_compiler.h"
#include "jit_macro_generator.h"
namespace CatlassKernel {
/// @brief Host-side launcher for the JIT-compiled "w8a16_matmul" kernel.
///
/// Generates the macro set for "w8a16_matmul" from @p tParams, lets
/// ApplyOptMacros adjust it using the m/n/k values of @p params and the
/// nz/trans flags of operands "A" and "B", pins three macros to fixed values,
/// requests the kernel built from "w8a16_matmul_impl.cpp" from the JIT
/// compiler, launches it when an entry point is returned, and finally
/// synchronizes @p stream.
///
/// @param blockNum Forwarded unchanged as the first argument of the kernel entry.
/// @param stream   Forwarded to the kernel entry and synchronized before returning.
/// @param tParams  Source of the generated macros and of the nz/trans flags for "A" and "B".
/// @param params   Supplies m, n and k to ApplyOptMacros; its address is handed to the kernel entry.
///
/// @note When the JIT compiler yields no entry point nothing is launched, but
///       the stream is still synchronized. The function returns void, so
///       neither that case nor a synchronization failure is reported to the caller.
extern "C" void W8A16Matmul(
    const uint32_t blockNum, aclrtStream stream, const TParams& tParams, const MatmulParams& params)
{
    auto macros = JitMacroGenerator<TParams>::generate("w8a16_matmul", tParams);
    ApplyOptMacros(
        macros, params.m, params.n, params.k, tParams.nz("A"), tParams.trans("A"), tParams.nz("B"), tParams.trans("B"));

    // Assigned after generation and ApplyOptMacros, so these values replace any
    // entry with the same key that the earlier steps may have produced.
    // NOTE(review): presumably "1.0"/"0.0" are the dequantization scale and
    // zero point for the 8-bit operand - confirm against w8a16_matmul_impl.cpp.
    macros["CATLASS_JIT_ELEMENT_B"] = "half";
    macros["CATLASS_JIT_SCALAR"] = "1.0";
    macros["CATLASS_JIT_ZERO_POINT"] = "0.0";

    auto* entry = JitCompiler::instance().getKernel("w8a16_matmul_impl.cpp", macros, JitKernelType::MIX);
    if (entry) {
        // The parameter block is passed by address.
        entry(blockNum, stream, &params);
    }

    // Runs whether or not a kernel was launched; any status returned by the
    // call is discarded.
    aclrtSynchronizeStream(stream);
}
}