from ..core.dtype import KnownTypes as KT
from ..core.struct import Field, Struct
from ..core.utils import global_builder
from .types import MatmulConfig
class MatmulApiStaticTiling(Struct):
    """Struct declaration for the matmul API static tiling record.

    Every ``Field`` below is declared with dtype ``KT.int32`` and a default
    of ``-1``; ``cfg`` is bound to ``MatmulConfig``. The declaration order is
    significant and is kept exactly as written.

    NOTE(review): ``name`` is presumably the identifier used for the field on
    the backend side -- confirm against ``Struct``/``Field``.
    """

    # usedCoreNum and the M / N / Ka / Kb fields.
    used_core_num = Field(name="usedCoreNum", dtype=KT.int32, default=-1)
    m = Field(name="M", dtype=KT.int32, default=-1)
    n = Field(name="N", dtype=KT.int32, default=-1)
    k_a = Field(name="Ka", dtype=KT.int32, default=-1)
    k_b = Field(name="Kb", dtype=KT.int32, default=-1)

    # singleCore* and base* fields.
    single_core_m = Field(name="singleCoreM", dtype=KT.int32, default=-1)
    single_core_n = Field(name="singleCoreN", dtype=KT.int32, default=-1)
    single_core_k = Field(name="singleCoreK", dtype=KT.int32, default=-1)
    base_m = Field(name="baseM", dtype=KT.int32, default=-1)
    base_n = Field(name="baseN", dtype=KT.int32, default=-1)
    base_k = Field(name="baseK", dtype=KT.int32, default=-1)

    # depth*1 and stepM / stepN fields.
    depth_a1 = Field(name="depthA1", dtype=KT.int32, default=-1)
    depth_b1 = Field(name="depthB1", dtype=KT.int32, default=-1)
    step_m = Field(name="stepM", dtype=KT.int32, default=-1)
    step_n = Field(name="stepN", dtype=KT.int32, default=-1)

    # isBias, transLength, iterateOrder.
    is_bias = Field(name="isBias", dtype=KT.int32, default=-1)
    trans_length = Field(name="transLength", dtype=KT.int32, default=-1)
    iterate_order = Field(name="iterateOrder", dtype=KT.int32, default=-1)

    # share* fields.
    share_mode = Field(name="shareMode", dtype=KT.int32, default=-1)
    share_l1_size = Field(name="shareL1Size", dtype=KT.int32, default=-1)
    share_l0c_size = Field(name="shareL0CSize", dtype=KT.int32, default=-1)
    share_ub_size = Field(name="shareUbSize", dtype=KT.int32, default=-1)

    # stepKa / stepKb and depth*L1CacheUB fields.
    step_k_a = Field(name="stepKa", dtype=KT.int32, default=-1)
    step_k_b = Field(name="stepKb", dtype=KT.int32, default=-1)
    depth_a_l1_cache_ub = Field(name="depthAL1CacheUB", dtype=KT.int32, default=-1)
    depth_b_l1_cache_ub = Field(name="depthBL1CacheUB", dtype=KT.int32, default=-1)

    # dbL0* fields.
    db_l0a = Field(name="dbL0A", dtype=KT.int32, default=-1)
    db_l0b = Field(name="dbL0B", dtype=KT.int32, default=-1)
    db_l0c = Field(name="dbL0C", dtype=KT.int32, default=-1)

    # ALayoutInfo* fields.
    a_layout_info_b = Field(name="ALayoutInfoB", dtype=KT.int32, default=-1)
    a_layout_info_s = Field(name="ALayoutInfoS", dtype=KT.int32, default=-1)
    a_layout_info_n = Field(name="ALayoutInfoN", dtype=KT.int32, default=-1)
    a_layout_info_g = Field(name="ALayoutInfoG", dtype=KT.int32, default=-1)
    a_layout_info_d = Field(name="ALayoutInfoD", dtype=KT.int32, default=-1)

    # BLayoutInfo* fields.
    b_layout_info_b = Field(name="BLayoutInfoB", dtype=KT.int32, default=-1)
    b_layout_info_s = Field(name="BLayoutInfoS", dtype=KT.int32, default=-1)
    b_layout_info_n = Field(name="BLayoutInfoN", dtype=KT.int32, default=-1)
    b_layout_info_g = Field(name="BLayoutInfoG", dtype=KT.int32, default=-1)
    b_layout_info_d = Field(name="BLayoutInfoD", dtype=KT.int32, default=-1)

    # CLayoutInfo* fields.
    c_layout_info_b = Field(name="CLayoutInfoB", dtype=KT.int32, default=-1)
    c_layout_info_s1 = Field(name="CLayoutInfoS1", dtype=KT.int32, default=-1)
    c_layout_info_n = Field(name="CLayoutInfoN", dtype=KT.int32, default=-1)
    c_layout_info_g = Field(name="CLayoutInfoG", dtype=KT.int32, default=-1)
    c_layout_info_s2 = Field(name="CLayoutInfoS2", dtype=KT.int32, default=-1)

    # BatchNum and mxTypePara.
    batch_num = Field(name="BatchNum", dtype=KT.int32, default=-1)
    mx_type_para = Field(name="mxTypePara", dtype=KT.int32, default=-1)

    # Bound to the MatmulConfig class itself, not to a Field instance.
    cfg = MatmulConfig

    @classmethod
    def get_ir_type(cls):
        """Return the type produced by the global IR builder for this struct.

        Looks up the builder through ``global_builder.get_ir_builder()`` and
        returns the result of its ``get_asc_MatmulApiStaticTilingType()``.
        """
        ir_builder = global_builder.get_ir_builder()
        return ir_builder.get_asc_MatmulApiStaticTilingType()
class RmsNormTiling(Struct):
    """Struct declaration for the RmsNorm tiling record.

    All fields are ``KT.int32`` with default ``0`` except
    ``reciprocal_of_h_length``, which is ``KT.float32`` with default ``0.0``.
    The declaration order is significant and is kept exactly as written.

    NOTE(review): ``name`` is presumably the identifier used for the field on
    the backend side -- confirm against ``Struct``/``Field``.
    """

    # bLength / sLength / hLength and originalHLength.
    b_length = Field(name="bLength", dtype=KT.int32, default=0)
    s_length = Field(name="sLength", dtype=KT.int32, default=0)
    h_length = Field(name="hLength", dtype=KT.int32, default=0)
    original_h_length = Field(name="originalHLength", dtype=KT.int32, default=0)

    # The only floating-point field of this struct.
    reciprocal_of_h_length = Field(
        name="reciprocalOfHLength", dtype=KT.float32, default=0.0
    )

    # main* fields.
    main_bsh_length = Field(name="mainBshLength", dtype=KT.int32, default=0)
    main_bs_length = Field(name="mainBsLength", dtype=KT.int32, default=0)
    main_bs_length_align = Field(name="mainBsLengthAlign", dtype=KT.int32, default=0)

    # loopRound and inputTailPos.
    loop_round = Field(name="loopRound", dtype=KT.int32, default=0)
    input_tail_pos = Field(name="inputTailPos", dtype=KT.int32, default=0)

    # tail* fields.
    tail_bsh_length = Field(name="tailBshLength", dtype=KT.int32, default=0)
    tail_bs_length = Field(name="tailBsLength", dtype=KT.int32, default=0)

    @classmethod
    def get_ir_type(cls):
        """Return the type produced by the global IR builder for this struct.

        Looks up the builder through ``global_builder.get_ir_builder()`` and
        returns the result of its ``get_asc_RmsNormTilingType()``.
        """
        ir_builder = global_builder.get_ir_builder()
        return ir_builder.get_asc_RmsNormTilingType()
class SoftmaxTiling(Struct):
    """Struct declaration for the softmax tiling record.

    Every ``Field`` below is declared with dtype ``KT.int32`` and a default
    of ``0``. The declaration order is significant and is kept exactly as
    written.

    NOTE(review): ``name`` is presumably the identifier used for the field on
    the backend side -- confirm against ``Struct``/``Field``.
    """

    # src* fields.
    src_m = Field(name="srcM", dtype=KT.int32, default=0)
    src_k = Field(name="srcK", dtype=KT.int32, default=0)
    src_size = Field(name="srcSize", dtype=KT.int32, default=0)

    # outMax* fields.
    out_max_m = Field(name="outMaxM", dtype=KT.int32, default=0)
    out_max_k = Field(name="outMaxK", dtype=KT.int32, default=0)
    out_max_size = Field(name="outMaxSize", dtype=KT.int32, default=0)

    # split* fields.
    split_m = Field(name="splitM", dtype=KT.int32, default=0)
    split_k = Field(name="splitK", dtype=KT.int32, default=0)
    split_size = Field(name="splitSize", dtype=KT.int32, default=0)

    # reduce* fields.
    reduce_m = Field(name="reduceM", dtype=KT.int32, default=0)
    reduce_k = Field(name="reduceK", dtype=KT.int32, default=0)
    reduce_size = Field(name="reduceSize", dtype=KT.int32, default=0)

    # rangeM and the tail* fields.
    range_m = Field(name="rangeM", dtype=KT.int32, default=0)
    tail_m = Field(name="tailM", dtype=KT.int32, default=0)
    tail_split_size = Field(name="tailSplitSize", dtype=KT.int32, default=0)
    tail_reduce_size = Field(name="tailReduceSize", dtype=KT.int32, default=0)

    @classmethod
    def get_ir_type(cls):
        """Return the type produced by the global IR builder for this struct.

        Looks up the builder through ``global_builder.get_ir_builder()`` and
        returns the result of its ``get_asc_SoftMaxTilingType()`` (note the
        ``SoftMax`` capitalisation of the builder method).
        """
        ir_builder = global_builder.get_ir_builder()
        return ir_builder.get_asc_SoftMaxTilingType()
class TCubeTiling(Struct):
    """Struct declaration for the TCubeTiling record.

    Every ``Field`` below is declared with dtype ``KT.int32`` and a default
    of ``0``. The declaration order is significant and is kept exactly as
    written.

    NOTE(review): ``name`` is presumably the identifier used for the field on
    the backend side -- confirm against ``Struct``/``Field``.
    """

    # usedCoreNum and the M / N / Ka / Kb fields.
    used_core_num = Field(dtype=KT.int32, default=0, name="usedCoreNum")
    m = Field(dtype=KT.int32, default=0, name="M")
    n = Field(dtype=KT.int32, default=0, name="N")
    k_a = Field(dtype=KT.int32, default=0, name="Ka")
    k_b = Field(dtype=KT.int32, default=0, name="Kb")

    # singleCore* and base* fields.
    single_core_m = Field(dtype=KT.int32, default=0, name="singleCoreM")
    single_core_n = Field(dtype=KT.int32, default=0, name="singleCoreN")
    single_core_k = Field(dtype=KT.int32, default=0, name="singleCoreK")
    base_m = Field(dtype=KT.int32, default=0, name="baseM")
    base_n = Field(dtype=KT.int32, default=0, name="baseN")
    base_k = Field(dtype=KT.int32, default=0, name="baseK")

    # depth*1 and stepM / stepN fields.
    depth_a1 = Field(dtype=KT.int32, default=0, name="depthA1")
    depth_b1 = Field(dtype=KT.int32, default=0, name="depthB1")
    step_m = Field(dtype=KT.int32, default=0, name="stepM")
    step_n = Field(dtype=KT.int32, default=0, name="stepN")

    # isBias, transLength, iterateOrder.
    is_bias = Field(dtype=KT.int32, default=0, name="isBias")
    # Was "trans_length": the only snake_case name in this file. Changed to
    # camelCase to agree with every sibling field and with the equivalent
    # MatmulApiStaticTiling field.
    # NOTE(review): confirm against the backend TCubeTiling member name.
    trans_length = Field(dtype=KT.int32, default=0, name="transLength")
    iterate_order = Field(dtype=KT.int32, default=0, name="iterateOrder")

    # share* fields.
    share_mode = Field(dtype=KT.int32, default=0, name="shareMode")
    share_l1_size = Field(dtype=KT.int32, default=0, name="shareL1Size")
    share_l0c_size = Field(dtype=KT.int32, default=0, name="shareL0CSize")
    share_ub_size = Field(dtype=KT.int32, default=0, name="shareUbSize")

    # batch* and singleBatch* fields.
    batch_m = Field(dtype=KT.int32, default=0, name="batchM")
    batch_n = Field(dtype=KT.int32, default=0, name="batchN")
    single_batch_m = Field(dtype=KT.int32, default=0, name="singleBatchM")
    single_batch_n = Field(dtype=KT.int32, default=0, name="singleBatchN")

    # stepKa / stepKb and depth*L1CacheUB fields.
    step_k_a = Field(dtype=KT.int32, default=0, name="stepKa")
    step_k_b = Field(dtype=KT.int32, default=0, name="stepKb")
    depth_a_l1_cache_ub = Field(dtype=KT.int32, default=0, name="depthAL1CacheUB")
    depth_b_l1_cache_ub = Field(dtype=KT.int32, default=0, name="depthBL1CacheUB")

    # dbL0* fields.
    db_l0a = Field(dtype=KT.int32, default=0, name="dbL0A")
    db_l0b = Field(dtype=KT.int32, default=0, name="dbL0B")
    db_l0c = Field(dtype=KT.int32, default=0, name="dbL0C")

    # ALayoutInfo* fields.
    a_layout_info_b = Field(dtype=KT.int32, default=0, name="ALayoutInfoB")
    a_layout_info_s = Field(dtype=KT.int32, default=0, name="ALayoutInfoS")
    a_layout_info_n = Field(dtype=KT.int32, default=0, name="ALayoutInfoN")
    a_layout_info_g = Field(dtype=KT.int32, default=0, name="ALayoutInfoG")
    a_layout_info_d = Field(dtype=KT.int32, default=0, name="ALayoutInfoD")

    # BLayoutInfo* fields.
    b_layout_info_b = Field(dtype=KT.int32, default=0, name="BLayoutInfoB")
    b_layout_info_s = Field(dtype=KT.int32, default=0, name="BLayoutInfoS")
    b_layout_info_n = Field(dtype=KT.int32, default=0, name="BLayoutInfoN")
    b_layout_info_g = Field(dtype=KT.int32, default=0, name="BLayoutInfoG")
    b_layout_info_d = Field(dtype=KT.int32, default=0, name="BLayoutInfoD")

    # CLayoutInfo* fields.
    c_layout_info_b = Field(dtype=KT.int32, default=0, name="CLayoutInfoB")
    c_layout_info_s1 = Field(dtype=KT.int32, default=0, name="CLayoutInfoS1")
    c_layout_info_n = Field(dtype=KT.int32, default=0, name="CLayoutInfoN")
    c_layout_info_g = Field(dtype=KT.int32, default=0, name="CLayoutInfoG")
    c_layout_info_s2 = Field(dtype=KT.int32, default=0, name="CLayoutInfoS2")

    # BatchNum.
    batch_num = Field(dtype=KT.int32, default=0, name="BatchNum")

    @classmethod
    def get_ir_type(cls):
        """Return the type produced by the global IR builder for this struct.

        Looks up the builder through ``global_builder.get_ir_builder()`` and
        returns the result of its ``get_asc_TCubeTilingType()``.
        """
        return global_builder.get_ir_builder().get_asc_TCubeTilingType()