#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from enum import Enum
import os

import numpy as np

# Self-named string keys: each constant's value is identical to its own name.
ASCEND_GLOBAL_HASHTABLE_COLLECTION = "ASCEND_GLOBAL_HASHTABLE_COLLECTION"
ASCEND_CUTTING_POINT_INITIALIZER = "ASCEND_CUTTING_POINT_INITIALIZER"
ASCEND_SPARSE_LOOKUP_ENTRANCE = "ASCEND_SPARSE_LOOKUP_ENTRANCE"
ASCEND_SPARSE_LOOKUP_ID_OFFSET = "ASCEND_SPARSE_LOOKUP_ID_OFFSET"
ASCEND_TIMESTAMP = "ASCEND_TIMESTAMP"
ASCEND_SPARSE_LOOKUP_LOCAL_EMB = "ASCEND_SPARSE_LOOKUP_LOCAL_EMB"
EMPTY_STR = ""

# Size in bytes of one element of each numeric type.
FLOAT32_BYTES = 4
UINT32_BYTES = 4
UINT64_BYTES = 8

# Message reported when the ConfigInitializer instance cannot be obtained.
GET_CONFIG_INSTANCE_ERR_MSG = "Please init the environment for mx_rec at first."

# Prefix the slicer uses to find orphan lookup keys.
ORPHAN_LOOKUP_KEY_PREFIX = "orphan"

# Name fragment of embedding tables merged by a third party (unset by default).
ASCEND_TABLE_NAME_MUST_CONTAIN = None

# Maximum depth of a while loop.
MAX_WHILE_SIZE = 800

# Data depth of the ACL (host-device) channel: default and allowed range.
DEFAULT_HD_CHANNEL_SIZE = 40
MIN_HD_CHANNEL_SIZE = 2
MAX_HD_CHANNEL_SIZE = 8192

# CM_WORKER_SIZE: number of cluster nodes, plus the default chief device.
DEFAULT_CM_WORKER_SIZE = 0
DEFAULT_CM_CHIEF_DEVICE = 0
MIN_CM_WORKER_SIZE = 0
MAX_CM_WORKER_SIZE = 512

# Number of key-process threads: default and allowed range.
DEFAULT_KP_THREAD_NUM = 6
MIN_KP_THREAD_NUM = 1
MAX_KP_THREAD_NUM = 10

# Maximum number of threads for fast-unique deduplication: default and range.
DEFAULT_FAST_UNIQUE_THREAD_NUM = 8
MIN_FAST_UNIQUE_THREAD_NUM = 1
MAX_FAST_UNIQUE_THREAD_NUM = 8

# Hot-embedding update step count: default and allowed range.
DEFAULT_HOT_EMB_UPDATE_STEP = 1000
MIN_HOT_EMB_UPDATE_STEP = 1
MAX_HOT_EMB_UPDATE_STEP = 1000

MULTI_LOOKUP_TIMES = 128
# One day, expressed in seconds.
DEFAULT_EVICT_TIME_INTERVAL = 24 * 60 * 60
TRAIN_CHANNEL_ID = 0
EVAL_CHANNEL_ID = 1
HASHTABLE_COLLECTION_NAME_LENGTH = 30
MAX_VOCABULARY_SIZE = 1_000_000_000
MAX_DEVICE_VOCABULARY_SIZE = 1_000_000_000

# Permissions used when saving files and directories.
SAVE_FILE_MODE = 0o640
SAVE_DIR_MODE = 0o750
# NOTE(review): no os.O_TRUNC here, so opening an existing, longer file with
# these flags and writing less data leaves stale trailing bytes unless the
# caller truncates or removes the file first; confirm at call sites.
SAVE_FILE_FLAG = os.O_CREAT | os.O_WRONLY


# Cannot be moved to saver.constant: doing so would cause a circular import.
class SsdCompactLevel(Enum):
    """Integer-valued compaction levels, ordered from none (0) to full (2).

    NOTE(review): presumably selected through the SSD_SAVE_COMPACT_LEVEL
    environment option when saving SSD-backed data; confirm at call sites.
    """

    NO_COMPACT = 0       # no compaction
    PARTIAL_COMPACT = 1  # partial compaction
    FULL_COMPACT = 2     # full compaction

# Model kinds used for incremental checkpoints.
BASE_MODEL = "base"
DELTA_MODEL = "delta"

# Path prefixes that identify files on an HDFS file system.
HDFS_FILE_PREFIX = ["viewfs://", "hdfs://"]

# Shared-library (.so) file names.
LIBREC_TF_NPU_OPS_SO = "librecsdk_tf_npu_ops.so"
LIBREC_TF_REC_V1_CPU_SO = "librecsdk_tf_rec_v1_cpu_ops.so"
LIBREC_EOS_OPS_SO = "librec_eos_ops.so"

# Characters treated as invalid: the ASCII whitespace/control characters below
# plus DEL. Each code point is listed exactly once; the short escapes already
# denote U+000A, U+000C, U+000D, U+0008, U+0009 and U+000B, so spelling those
# again as \uXXXX escapes would only duplicate entries.
INVALID_CHARS = frozenset({
    "\n",      # U+000A line feed
    "\f",      # U+000C form feed
    "\r",      # U+000D carriage return
    "\b",      # U+0008 backspace
    "\t",      # U+0009 horizontal tab
    "\v",      # U+000B vertical tab
    "\u007F",  # U+007F delete
})


class BaseEnum(Enum):
    """Enum base class that adds value-based member lookup through ``mapping``."""

    @classmethod
    def mapping(cls, key):
        """Return the member of ``cls`` whose value equals ``key``.

        Args:
            key: A raw value, or a ``BaseEnum`` member (possibly from a
                different ``BaseEnum`` subclass), in which case its ``.value``
                is what gets compared.

        Returns:
            The first member of ``cls``, in definition order, whose value
            equals the (unwrapped) key.

        Raises:
            KeyError: If no member of ``cls`` has a matching value.
        """
        # Unwrap an enum key once; it does not depend on the member under test.
        key_value = key.value if isinstance(key, BaseEnum) else key
        for mode in cls:
            if key_value == mode.value:
                return mode

        raise KeyError(f"Cannot find a corresponding mode in current Enum "
                       f"class {cls}, given parameter '{key}[{key.__class__}]' is illegal, "
                       f"please choose a valid one from "
                       f"'{list(map(lambda c: c.value, cls))}'.")


class EnvOption(Enum):
    """Environment option names; each member's value is the option's name string.

    Values are case-sensitive and do not always equal the member name
    (e.g. ACL_TIMEOUT -> "AclTimeout", GLOG_STDERRTHREAHOLD -> "GLOG_stderrthreshold").
    NOTE(review): presumably read from the process environment; confirm at call sites.
    """

    MXREC_LOG_LEVEL = "MXREC_LOG_LEVEL"
    ACL_TIMEOUT = "AclTimeout"
    HD_CHANNEL_SIZE = "HD_CHANNEL_SIZE"
    KEY_PROCESS_THREAD_NUM = "KEY_PROCESS_THREAD_NUM"
    MAX_UNIQUE_THREAD_NUM = "MAX_UNIQUE_THREAD_NUM"
    FAST_UNIQUE = "FAST_UNIQUE"
    HOT_EMB_UPDATE_STEP = "HOT_EMB_UPDATE_STEP"
    # Misspelled member name ("THREAHOLD"), kept as the canonical member for
    # backward compatibility with existing callers.
    GLOG_STDERRTHREAHOLD = "GLOG_stderrthreshold"
    USE_COMBINE_FAAE = "USE_COMBINE_FAAE"
    RECORD_KEY_COUNT = "RECORD_KEY_COUNT"
    SSD_SAVE_COMPACT_LEVEL = "SSD_SAVE_COMPACT_LEVEL"
    USE_SHM_SWAP = "USE_SHM_SWAP"
    # Correctly spelled name. Because it repeats an existing value, Enum turns
    # it into an alias of GLOG_STDERRTHREAHOLD: it resolves to that same member
    # and is not yielded when iterating over EnvOption.
    GLOG_STDERRTHRESHOLD = "GLOG_stderrthreshold"


class DataName(Enum):
    """Lower-case names of individual data items.

    NOTE(review): presumably used to label saved/loaded table data; confirm
    at call sites.
    """

    KEY = "key"
    EMBEDDING = "embedding"
    FEATURE_MAPPING = "feature_mapping"
    OFFSET = "offset"
    THRESHOLD = "threshold"
    VALID_LEN = "valid_len"
    VALID_BUCKET_NUM = "valid_bucket_num"


class DataAttr(Enum):
    """Attribute names describing a data item: its shape and its data type."""

    SHAPE = "shape"         # shape attribute
    DATATYPE = "data_type"  # data-type attribute


class ASCAnchorAttr(Enum):
    """Attribute keys; each member's value is the lower-case key string.

    NOTE(review): presumably used to attach lookup-related metadata (table
    instance, restore vectors, id offsets, ...) to an anchor object; confirm
    at call sites.
    """

    TABLE_INSTANCE = "table_instance"
    IS_TRAINING = "is_training"
    RESTORE_VECTOR = "restore_vector"
    ID_OFFSETS = "id_offsets"
    FEATURE_SPEC = "feature_spec"
    ALL2ALL_MATRIX = "all2all_matrix"
    HOT_POS = "hot_pos"
    LOOKUP_RESULT = "lookup_result"
    MOCK_LOOKUP_RESULT = "mock_lookup_result"
    RESTORE_VECTOR_SECOND = "restore_vector_second"
    UNIQUE_KEYS = "unique_keys"
    IS_GRAD = "is_grad"
    TABLE_NAME = "table_name"
    CHANNEL_ID = "channel_id"


class OptimizerType(Enum):
    """Supported optimizers; ``OPTIMIZER_STATE_META`` lists each one's state names."""

    LAZY_ADAM = "LazyAdam"
    SGD = "SGD"

    @staticmethod
    def get_optimizer_state_meta(mode):
        """Return the optimizer state names registered for ``mode``.

        Args:
            mode: An ``OptimizerType`` member, or its string value
                (``"LazyAdam"`` / ``"SGD"``). Raw values are accepted because
                the error message below tells callers to choose one of them.

        Returns:
            A new list of state names (e.g. ``["momentum", "velocity"]`` for
            LazyAdam, ``[]`` for SGD). It is a copy, so mutating it does not
            alter the shared ``OPTIMIZER_STATE_META`` table.

        Raises:
            ValueError: If ``mode`` does not identify a known optimizer.
        """
        if not isinstance(mode, OptimizerType):
            # Resolve a raw value to its member; anything unknown becomes None,
            # which is not a key of the table and falls through to the error.
            try:
                mode = OptimizerType(mode)
            except ValueError:
                mode = None
        state_names = OPTIMIZER_STATE_META.get(mode)
        # Compare against None explicitly: SGD legitimately maps to [].
        if state_names is not None:
            return list(state_names)

        raise ValueError(f"Invalid mode value, please choose one from {list(map(lambda c: c.value, OptimizerType))}")


# Optimizer state names per optimizer; SGD has none.
OPTIMIZER_STATE_META = {OptimizerType.LAZY_ADAM: ["momentum", "velocity"], OptimizerType.SGD: []}


class All2allGradientsOp(BaseEnum):
    """How gradients are combined across ranks in the all-to-all exchange.

    Members: plain sum, or sum followed by division by the rank size.
    Inherits ``mapping`` for lookup by value.
    """

    SUM_GRADIENTS = "sum_gradients"
    SUM_GRADIENTS_AND_DIV_BY_RANKSIZE = "sum_gradients_and_div_by_ranksize"


class RecCPPLogLevel(Enum):
    """Log levels encoded as numeric strings, from most verbose ("-2") to least ("2").

    NOTE(review): presumably forwarded to the C++ backend; confirm at call sites.
    """

    TRACE = "-2"
    DEBUG = "-1"
    INFO = "0"
    WARN = "1"
    ERROR = "2"


class Flag(Enum):
    """Boolean flag encoded as a string: "1" for true, "0" for false."""

    TRUE = "1"   # true
    FALSE = "0"  # false