# Names re-exported as this package's public API; this list is also what
# ``from <package> import *`` binds. Entries must stay in sync with the
# imports below.
# NOTE(review): ``register_asd_hook`` and ``set_asd_loss_scale`` are imported
# below under public (non-underscored) names but are not listed here --
# confirm whether they are meant to be part of the public API.
__all__ = [
    # tensor-combination helpers (from .combine_tensors)
    "npu_combine_tensors",
    "get_part_combined_tensor",
    "is_combined_tensor_valid",
    # FLOPs counting (from .flops_count, re-exported under a public alias)
    "FlopsCounter",
    # thread-affinity controls (from .affinity, re-exported under public aliases)
    "set_thread_affinity",
    "reset_thread_affinity",
    # serialization (from .serialization)
    "save_async",
    # re-exported from torch_npu.npu.utils
    "get_cann_version",
]
from torch_npu.npu.utils import get_cann_version
from ._inductor import _max_unpoolnd_patch
from .affinity import (
_reset_thread_affinity as reset_thread_affinity,
_set_thread_affinity as set_thread_affinity,
)
from .asd_detector import register_asd_hook, set_asd_loss_scale
from .combine_tensors import (
get_part_combined_tensor,
is_combined_tensor_valid,
npu_combine_tensors,
)
from .flops_count import _FlopsCounter as FlopsCounter
from .serialization import save_async