#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright Huawei Technologies Co., Ltd. 2024-2024. All rights reserved.

import os
from dataclasses import dataclass
from typing import Optional, Any, TypeVar, Generic
from enum import Enum, auto

from .singleton import Singleton
from .utils import PrintBorder, ColorText
from .functional import uncurry
from .platform import DeviceName


T = TypeVar("T")


class EnvItemProp(Enum):
    REQUIRED = auto()
    FALLBACK = auto()
    DROP = auto()


@dataclass
class EnvItemAttr(Generic[T]):
    name: str
    prop: EnvItemProp
    default: Optional[T]


@dataclass
class EnvItem(Generic[T]):
    value: Optional[T]
    attr: EnvItemAttr[T]


@dataclass
class Environment:
    ascend_home_path: EnvItem[str]
    device_count: EnvItem[int]
    device_name: EnvItem[DeviceName]


class EnvChecker(metaclass=Singleton):
    def __init__(self, path):
        self._toolkit_path = path + "/ascend-toolkit/latest"
        self._env = self._check()

    @property
    def env(self):
        return self._env


    def _check(self) -> Optional[Environment]:
        _checks = [
            (EnvItemAttr("ASCEND_HOME_PATH", EnvItemProp.REQUIRED, None), self._check_ascend_home_path),
            (EnvItemAttr("Device count", EnvItemProp.DROP, None), self._check_device_count),
            (EnvItemAttr("Device name", EnvItemProp.FALLBACK, DeviceName.Ascend910B1), self._check_device_name),
        ]

        def _run_check(_attr: EnvItemAttr, _check) -> Optional[EnvItem]:
            print(f"{ColorText.run_test} check {_attr.name}")
            value = _check()
            if value is not None:
                print(f"{ColorText.run_ok} {_attr.name} check ok. value: {value}")
                return EnvItem(value, _attr)
            if _attr.prop is EnvItemProp.DROP:
                print(f"{ColorText.run_warn} {_attr.name} missing. some test cases will be dropped")
                return EnvItem(value, _attr)
            if _attr.prop is EnvItemProp.FALLBACK and _attr.default is not None:
                print(f"{ColorText.run_warn} {_attr.name} missing. fallback to {_attr.default}")
                return EnvItem(_attr.default, _attr)

            print(f"{ColorText.run_failed} {_attr.name} check failed")
            return None

        with PrintBorder("check environment", "check environment done"):
            args = tuple(map(uncurry(_run_check), _checks))
            if any(map(lambda x: x is None, args)):
                return None
            return Environment(*args)

    def _check_ascend_home_path(self) -> Optional[str]:
        path = os.getenv("ASCEND_HOME_PATH")
        return path

    def _check_device_count(self) -> Optional[int]:
        try:
            import torch
            import torch_npu
        except ModuleNotFoundError:
            return None
        return torch.npu.device_count()

    def _check_device_name(self) -> Optional[DeviceName]:
        return DeviceName.Ascend910B1