#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# ----------------------------------------------------------------------------
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
# CANN Open Software License Agreement Version 2.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.
# ----------------------------------------------------------------------------

"""
Function:
The replay function entry (generates the tiling definition header file)
Copyright Information:
Huawei Technologies Co., Ltd. All Rights Reserved © 2020
"""

import sys
import os
import stat
import re
import const_var


# Every macro this generator reacts to starts with one of these prefixes
# (the _ARR and _STRUCT variants share the TILING_DATA_FIELD_DEF prefix).
_TILING_MACRO_PREFIXES = ('BEGIN_TILING_DATA_DEF', 'TILING_DATA_FIELD_DEF', 'END_TILING_DATA_DEF')


def _macro_args(line: str, where: str) -> str:
    """Return the text between the first '(' of ``line`` and its matching ')'.

    Matching is depth-aware, so parentheses that appear after the macro call
    (for example inside a trailing comment) are not swallowed into the
    arguments. When the parentheses are unbalanced, the last ')' on the line
    closes the group.

    Raises:
        RuntimeError: if ``line`` contains no parenthesised group at all;
            ``where`` is used as the location prefix of the message.
    """
    start = line.find('(')
    if start >= 0:
        depth = 0
        for pos in range(start, len(line)):
            char = line[pos]
            if char == '(':
                depth += 1
            elif char == ')':
                depth -= 1
                if depth == 0:
                    return line[start + 1:pos]
        # More '(' than ')': close the group at the last ')' if there is one.
        end = line.rfind(')')
        if end > start:
            return line[start + 1:end]
    raise RuntimeError('{}: tiling macro without a parenthesised argument list: {}'.format(where, line))


def _macro_fields(line: str, where: str, count: int) -> list:
    """Split the macro arguments of ``line`` on ',' and strip each item.

    Raises:
        RuntimeError: if fewer than ``count`` comma-separated items are found.
    """
    fields = [item.strip() for item in _macro_args(line, where).split(',')]
    if len(fields) < count:
        raise RuntimeError('{}: tiling macro needs at least {} arguments: {}'.format(where, count, line))
    return fields


def _struct_epilogue(struct_name: str) -> str:
    """Return the C++ text that closes ``struct_name`` and defines its Init function.

    Two variants of ``Init<struct_name>`` are emitted, selected by the
    ``__NPU_TILING__`` preprocessor symbol: one copying 32-bit words from a
    ``__gm__`` pointer, one copying 64-bit words from a plain pointer.

    NOTE(review): the generated loops copy ``sizeof(struct) / 4`` and
    ``sizeof(struct) / 8`` words respectively, so a struct whose size is not a
    multiple of the word size would have its tail bytes left uncopied --
    confirm that tiling structs are always sized accordingly.
    """
    npu_signature = 'inline [aicore] void Init{stru}(const __gm__ uint8_t* tiling, {stru}* const_data)\n'
    host_signature = 'inline void Init{stru}(uint8_t* tiling, {stru}* const_data)\n'
    npu_loop = '    for (auto i = 0; i < sizeof({}) / 4; i++) *(dst + i) = *(src + i);\n'
    host_loop = '    for (auto i = 0; i < sizeof({}) / 8; i++) *(dst + i) = *(src + i);\n'
    return ''.join((
        '};\n',
        '#pragma pack()\n\n',
        '#ifdef __NPU_TILING__\n',
        npu_signature.format(stru=struct_name),
        '{\n',
        '    const __gm__ uint32_t *src = (const __gm__ uint32_t *)tiling;\n',
        '    uint32_t *dst = (uint32_t *)const_data;\n',
        npu_loop.format(struct_name),
        '}\n',
        '#else\n',
        host_signature.format(stru=struct_name),
        '{\n',
        '    uint64_t *src = (uint64_t *)tiling;\n',
        '    uint64_t *dst = (uint64_t *)const_data;\n',
        host_loop.format(struct_name),
        '}\n',
        '#endif\n\n',
    ))


def gen_tiling(tiling_header_file: str, tiling_file_out: str) -> None:
    """Translate a macro-based tiling header into a plain C++ struct header.

    The input is scanned line by line. Lines that (after stripping) start with
    one of the tiling macros are converted:

    * ``BEGIN_TILING_DATA_DEF(Name)`` opens ``struct Name`` under
      ``#pragma pack(1)``.
    * ``TILING_DATA_FIELD_DEF(type, name)`` becomes ``type name = 0;``.
    * ``TILING_DATA_FIELD_DEF_ARR(type, len, name)`` becomes
      ``type name[len] = {};``.
    * ``TILING_DATA_FIELD_DEF_STRUCT(type, name)`` becomes ``type name;``.
    * ``END_TILING_DATA_DEF`` closes the struct and emits its ``Init<Name>``
      copy function.

    All other lines are ignored. A ``GET_TILING_DATA`` macro is appended for
    the last struct that was closed, and the whole text is wrapped in an
    include guard derived from the input file's base name.

    Args:
        tiling_header_file: path of the user tiling header to read. If it does
            not exist a warning is printed and nothing is written.
        tiling_file_out: path of the header to generate. It is opened with the
            flags and mode taken from ``const_var.WFLAGS`` / ``const_var.WMODES``.

    Raises:
        RuntimeError: if a tiling macro line has no argument list or too few
            arguments, or if ``END_TILING_DATA_DEF`` appears before any
            ``BEGIN_TILING_DATA_DEF``. Nothing is written in that case.
    """
    if not os.path.exists(tiling_header_file):
        print("warning: no userdef tiling header file: ", tiling_header_file)
        return
    print("generate tiling def header file: ", tiling_file_out)
    guard = os.path.splitext(os.path.basename(tiling_header_file))[0].upper()
    chunks = [
        '#ifndef __{}_H__\n'.format(guard),
        '#define __{}_H__\n\n'.format(guard),
        '#include <cstdint>\n',
        '#include <cstring>\n\n',
        '#include "kernel_tiling/kernel_tiling.h"\n\n',
    ]
    macro_source = ""
    struct_name = None
    with open(tiling_header_file, 'r') as fd:
        for line_no, raw_line in enumerate(fd, start=1):
            line = raw_line.strip()
            if not line.startswith(_TILING_MACRO_PREFIXES):
                continue
            where = '{}:{}'.format(tiling_header_file, line_no)
            if line.startswith('BEGIN_TILING_DATA_DEF'):
                # The whole argument text is the struct name; strip it so that
                # "( Name )" does not leak blanks into the Init<Name> identifier.
                struct_name = _macro_args(line, where).strip()
                chunks.append('#pragma pack(1)\n')
                chunks.append('struct {} {{\n'.format(struct_name))
            # The _ARR and _STRUCT variants must be tested before the plain
            # macro because they share its prefix.
            elif line.startswith('TILING_DATA_FIELD_DEF_ARR'):
                dtype, length, name = _macro_fields(line, where, 3)[:3]
                chunks.append('    {} {}[{}] = {{}};\n'.format(dtype, name, length))
            elif line.startswith('TILING_DATA_FIELD_DEF_STRUCT'):
                dtype, name = _macro_fields(line, where, 2)[:2]
                chunks.append('    {} {};\n'.format(dtype, name))
            elif line.startswith('TILING_DATA_FIELD_DEF'):
                dtype, name = _macro_fields(line, where, 2)[:2]
                chunks.append('    {} {} = 0;\n'.format(dtype, name))
            elif line.startswith('END_TILING_DATA_DEF'):
                if struct_name is None:
                    raise RuntimeError(
                        '{}: END_TILING_DATA_DEF without a preceding BEGIN_TILING_DATA_DEF'.format(where))
                chunks.append(_struct_epilogue(struct_name))
                # Overwritten on every END: only the last closed struct gets
                # the GET_TILING_DATA macro.
                macro_source = (
                    '\n'
                    '#undef GET_TILING_DATA\n'
                    '#define GET_TILING_DATA(tiling_data, tiling_arg) \\\n'
                    '{stru} tiling_data; \\\n'
                    'Init{stru}(tiling_arg, &tiling_data)\n'
                    '\n'
                ).format(stru=struct_name)
    chunks.append(macro_source)
    chunks.append('#endif')
    with os.fdopen(os.open(tiling_file_out, const_var.WFLAGS, const_var.WMODES), 'w') as ofd:
        ofd.write(''.join(chunks))


if __name__ == '__main__':
    # Script usage: <this script> <input tiling header> <output header>.
    # Anything after the second positional argument is ignored.
    if len(sys.argv) < 3:
        raise RuntimeError('arguments must greater than 2')
    gen_tiling(*sys.argv[1:3])