import logging
from typing import Callable, List, Optional, Union

from ..device import DeviceProfile
from .base import PerformanceModel
from .op_invoke_info import OpInvokeInfo


# Registry of estimators, nested as {device_name: {op: estimator}}.
# The ``None`` device key holds the default table that ``get_op_estimator``
# consults for devices without a table of their own; within a device table the
# ``None`` op key holds the generic estimator used for ops lacking a dedicated one.
_op_estimator_table: dict = dict()

logger = logging.getLogger(__name__)


def register_op_estimator(op, device_names: Optional[Union[str, List[str]]], override: Optional[bool] = False):
    """Return a decorator that registers an estimator for ``op`` on one or more devices.

    Args:
        op: Key identifying the operator. Registering ``None`` provides the
            generic estimator that ``get_op_estimator`` falls back to when a
            device table has no entry for the requested op.
        device_names: A single device name, or a list/tuple of names. ``None``
            registers into the default table that ``get_op_estimator`` uses for
            devices that have no table of their own.
        override: If true, replace an existing registration (logging a warning
            per replaced device) instead of raising.

    Returns:
        A decorator that records the estimator for every listed device and
        returns the estimator unchanged.

    Raises:
        ValueError: When the decorator is applied, if ``op`` is already
            registered for any listed device and ``override`` is false. In
            that case nothing is registered for any of the devices.
    """
    if not isinstance(device_names, (list, tuple)):
        device_names = [device_names]
    # Snapshot (so later mutation of the caller's list has no effect) and
    # de-duplicate in order (so a device listed twice does not conflict with itself).
    device_names = tuple(dict.fromkeys(device_names))

    def decorator(estimator):
        # Validate every device before mutating the table, so a conflict on one
        # device cannot leave the op half-registered on the others.
        conflicts = [name for name in device_names if op in _op_estimator_table.get(name, {})]
        if conflicts and not override:
            raise ValueError(f"Op {op} already registered for device {conflicts[0]}")
        for device_name in conflicts:
            logger.warning(
                "Overwriting existing estimator for op %s (device: %s)",
                op,
                device_name,
            )
        for device_name in device_names:
            _op_estimator_table.setdefault(device_name, {})[op] = estimator
        return estimator

    return decorator


def get_op_estimator(
    op, device_name: Optional[str]
) -> Callable[[OpInvokeInfo, DeviceProfile], PerformanceModel.Result]:
    """Look up the estimator registered for ``op`` on ``device_name``.

    Resolution order:
      1. If ``device_name`` has no table of its own, the default table
         (device key ``None``) is used instead.
      2. Within the chosen table, if ``op`` has no entry, the generic entry
         (op key ``None``) is used instead.

    A missing op does not fall back from a device's own table to the default
    table; only a missing device triggers that fallback.

    Args:
        op: Operator key, as passed to ``register_op_estimator``.
        device_name: Device key, as passed to ``register_op_estimator``.

    Returns:
        The registered estimator callable.

    Raises:
        KeyError: If neither the requested nor the fallback table/entry exists.
    """
    # Devices without a table of their own share the default (``None``) table.
    table_key = device_name if device_name in _op_estimator_table else None
    if table_key not in _op_estimator_table:
        raise KeyError(
            f"No estimators registered for device {device_name!r} and no default "
            "(device_name=None) estimators are registered"
        )
    estimators = _op_estimator_table[table_key]

    # Ops without a dedicated estimator use the table's generic (``None``) entry.
    op_key = op if op in estimators else None
    if op_key not in estimators:
        raise KeyError(
            f"No estimator registered for op {op!r} on device {table_key!r} and no "
            "generic (op=None) estimator is registered for that device"
        )
    return estimators[op_key]