#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# ----------------------------------------------------------------------------
# Copyright (c) 2025 Huawe                                                                                                                                                                                                                     i Technologies Co., Ltd.
# This program is free software, you can redistribute it and/or modify it under the terms and conditions of                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        aa
# CANN Open Software License Agreement Version 2.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.
# ----------------------------------------------------------------------------

import torch
import numpy as np

from atk.configs.dataset_config import InputDataset
from atk.tasks.api_execute import register
from atk.tasks.api_execute.base_api import BaseApi

@register("ascend_aclnn_signbitspack")
class AclnnSigbBitsPack(BaseApi):

    def __call__(self, input_data: InputDataset, with_output: bool = False):
        selfTensor = input_data.kwargs["self"]
        size = input_data.kwargs["size"]
        # 将float列表转换为numpy数组以便进行位操作等处理
        float_array = selfTensor.numpy()
        # 获取符号位
        sign_bits = np.sign(float_array).astype(np.int8)
        # 这里简单地将除符号位外的其他位都设为0(假设这就是你想要的1位Adam表示方式,实际可能需根据具体需求调整)
        one_bit_adam = np.where(sign_bits >= 0, 1, 0)
        num_elements = len(one_bit_adam)
        remainder = num_elements % 8
        if remainder!= 0:
            # 需要填充的元素个数
            padding_size = 8 - remainder
            # 创建填充的1位Adam值,这里对于 -1按照之前假设为1
            padding_values = np.full(padding_size, 0, dtype=np.uint8)
            one_bit_adam = np.concatenate((one_bit_adam, padding_values), axis=0)
        # 将1位Adam值重新整形为二维数组,每行8个元素
        one_bit_adam_reshaped = one_bit_adam.reshape(-1, 8)
        # 用于存储打包后的uint8值的列表
        packed_uint8_list = []
        for row in one_bit_adam_reshaped:
            # 创建一个8位的二进制数,初始值为0
            binary_value = 0
            for i, bit in enumerate(row):
                # 将每个1位Adam值按照从左到右(对应二进制低位到高位)的顺序设置到二进制数中
                binary_value += bit << i
            # 将二进制数转换为uint8类型并添加到列表中
            packed_uint8_list.append(binary_value)
        # 将列表转换为numpy数组并返回
        packed_uint8_array = np.array(packed_uint8_list, dtype=np.uint8)
        num_packed = len(packed_uint8_array)
        reshaped_size = num_packed // size
        # 将打包后的uint8数组转化为二维Tensor
        tensor_result = torch.from_numpy(packed_uint8_array).reshape(size, reshaped_size)
        return tensor_result