import os
from functools import wraps
from logging import getLogger

import torch
import torch_npu

# Module-level logger, named after this module via __name__.
LOG = getLogger(__name__)


def extend_seed_all(seed=1234):
    """Switch the NPU stack into deterministic mode and seed the NPU RNGs.

    Environment variables written (values are strings):

    * ``HCCL_DETERMINISTIC = 'True'`` -- deterministic switch at ops level
      (``'False'`` would disable it).
    * ``CLOSE_MATMUL_K_SHIFT = '1'`` -- closes the matmul K-axis shift
      (``'0'`` would leave the shift enabled).
    * ``PYTHONHASHSEED = str(seed)`` -- Python hash seed, which must be a
      string of a non-negative integer. NOTE: CPython reads this variable
      only at interpreter start-up. Setting it here therefore affects Python
      processes launched afterwards, not hashing in the current process.

    Afterwards, ``torch.use_deterministic_algorithms(True)`` is enabled and
    the NPU generators are seeded through ``torch_npu.npu.manual_seed_all``
    and ``torch_npu.npu.manual_seed``.

    Args:
        seed: Seed for the NPU generators. It is also exported as
            PYTHONHASHSEED, so it should be a non-negative integer.
    """
    fixed_flags = (
        ('HCCL_DETERMINISTIC', 'True'),
        ('CLOSE_MATMUL_K_SHIFT', '1'),
    )
    for flag_name, flag_value in fixed_flags:
        os.environ[flag_name] = flag_value
    # Only takes effect for child interpreters started after this point.
    os.environ['PYTHONHASHSEED'] = str(seed)

    torch.use_deterministic_algorithms(True)

    npu_module = torch_npu.npu
    npu_module.manual_seed_all(seed)  # seed all NPU devices
    npu_module.manual_seed(seed)      # seed the current NPU (kept as in the original)


def npu_deterministic_wrapper(fn):
    """Decorate a seeding function so it also enables NPU deterministic mode.

    The returned wrapper first calls ``fn(seed, *args, **kwargs)``. It then
    calls ``extend_seed_all(seed)`` with the same seed and logs at INFO level
    that deterministic computing was applied.

    Args:
        fn: Seeding callable whose first positional argument is the seed.

    Returns:
        A wrapper carrying ``fn``'s metadata (via ``functools.wraps``). The
        wrapper requires ``seed`` as its first positional argument (or as the
        ``seed`` keyword) and returns whatever ``fn`` returned.
    """
    @wraps(fn)
    def wrapper(seed, *args, **kwargs):
        # Run the original seeding first; the NPU-specific settings are applied
        # afterwards, so they take precedence over anything fn set.
        result = fn(seed, *args, **kwargs)
        extend_seed_all(seed)
        LOG.info("Deterministic computing is applied for npu.")
        # Propagate fn's return value instead of silently dropping it.
        return result
    return wrapper