# Public API of this package. Every name listed here is imported further down
# in this module; the list is what a star-import of the package exposes.
__all__ = [
    "RoIPointPool3d",
    "SparseConv3d",
    "SparseInverseConv3d",
    "SubMConv3d",
    "SparseConvolution",
    "SparseConvTensor",
    "SparseModule",
    "SparseSequential",
    "Voxelization",
    "assign_score_withk",
    "bev_pool",
    "bev_pool_v2",
    "bev_pool_v3",
    "border_align",
    "box_iou_quadri",
    "box_iou_rotated",
    "boxes_overlap_bev",
    "npu_boxes_overlap_bev",
    "boxes_iou_bev",
    "deform_conv2d",
    "DeformConv2dFunction",
    "dynamic_scatter",
    "furthest_point_sampling",
    "furthest_point_sample_with_dist",
    "npu_fused_bias_leaky_relu",
    "geometric_kernel_attention",
    "grid_sampler2d_v2",
    "grid_sampler3d_v1",
    "group_points",
    "hypot",
    "knn",
    "modulated_deform_conv2d",
    "ModulatedDeformConv2dFunction",
    "multi_scale_deformable_attn",
    "npu_index_select",
    "npu_multi_scale_deformable_attn_function",
    "nms3d_normal",
    "npu_add_relu",
    "npu_deformable_aggregation",
    "npu_batch_matmul",
    "deformable_aggregation",
    "npu_dynamic_scatter",
    "npu_max_pool2d",
    "nms3d",
    "MultiScaleDeformableAttnFunction",
    "npu_points_in_box",
    "npu_points_in_box_all",
    "points_in_box",
    "points_in_boxes_all",
    "pixel_group",
    "roi_align_rotated",
    "roiaware_pool3d",
    "roipoint_pool3d",
    "npu_rotated_iou",
    "npu_rotated_overlaps",
    "scatter_max",
    "scatter_mean",
    "scatter_add",
    "three_interpolate",
    "three_nn",
    "npu_voxel_pooling_train",
    "voxelization",
    "unique_voxel",
    "cal_anchors_heading",
    "npu_gaussian",
    "npu_draw_gaussian_to_heatmap",
    "npu_assign_target_of_single_head",
    "diff_iou_rotated_2d",
    "nms3d_on_sight",
    "cartesian_to_frenet",
    "radius",
    "npu_unique",
    "graph_softmax",
    "cylinder_query",
    "sigmoid_focal_loss",
    "npu_fake_tensor_quant",
    "npu_fake_tensor_quant_inplace",
    "npu_fake_tensor_quant_with_axis",
    "default_patcher_builder",
    "patch_mmcv_version",
]

# ╔═══════════════════════════════════════════════════════════════════════════╗
# ║                         Built-in Dependencies                              ║
# ╚═══════════════════════════════════════════════════════════════════════════╝
# Pre-import torch and torch_npu to ensure they are available for mx_driving.
# This prevents errors when users forget to import these modules before using mx_driving.
import torch  # noqa: F401
import torch_npu  # noqa: F401
from torch_npu.contrib import transfer_to_npu  # noqa: F401


import os
import warnings

from .get_chip_info import Dsmi_dc_Func

import mx_driving._C

from .modules.roi_point_pool_3d import RoIPointPool3d
from .modules.sparse_conv import SparseConv3d, SubMConv3d, SparseInverseConv3d, SparseConvolution
from .modules.sparse_modules import SparseConvTensor, SparseModule, SparseSequential
from .modules.voxelization import Voxelization
from .ops.assign_score_withk import assign_score_withk
from .ops.bev_pool import bev_pool
from .ops.bev_pool_v2 import bev_pool_v2
from .ops.bev_pool_v3 import bev_pool_v3
from .ops.border_align import border_align
from .ops.box_iou import box_iou_quadri, box_iou_rotated
from .ops.boxes_overlap_bev import boxes_overlap_bev, npu_boxes_overlap_bev, boxes_iou_bev
from .ops.deform_conv2d import DeformConv2dFunction, deform_conv2d
from .ops.furthest_point_sampling import furthest_point_sampling
from .ops.furthest_point_sampling_with_dist import furthest_point_sample_with_dist
from .ops.fused_bias_leaky_relu import npu_fused_bias_leaky_relu
from .ops.group_points import group_points
from .ops.geometric_kernel_attention import geometric_kernel_attention
from .ops.grid_sampler2d_v2 import grid_sampler2d_v2
from .ops.grid_sampler3d_v1 import grid_sampler3d_v1
from .ops.hypot import hypot
from .ops.knn import knn
from .ops.modulated_deform_conv2d import ModulatedDeformConv2dFunction, modulated_deform_conv2d
from .ops.multi_scale_deformable_attn import (
    MultiScaleDeformableAttnFunction,
    multi_scale_deformable_attn,
    npu_multi_scale_deformable_attn_function,
)
from .ops.nms3d_normal import nms3d_normal
from .ops.npu_add_relu import npu_add_relu
from .ops.npu_deformable_aggregation import npu_deformable_aggregation, deformable_aggregation
from .ops.npu_dynamic_scatter import npu_dynamic_scatter, dynamic_scatter
from .ops.npu_max_pool2d import npu_max_pool2d
from .ops.nms3d import nms3d
from .ops.npu_points_in_box import npu_points_in_box, points_in_box
from .ops.npu_points_in_box_all import npu_points_in_box_all, points_in_boxes_all
from .ops.npu_index_select import npu_index_select
from .ops.pixel_group import pixel_group
from .ops.roi_align_rotated import roi_align_rotated
from .ops.roiaware_pool3d import roiaware_pool3d
from .ops.roipoint_pool3d import roipoint_pool3d
from .ops.rotated_iou import npu_rotated_iou
from .ops.rotated_overlaps import npu_rotated_overlaps
from .ops.scatter_max import scatter_max
from .ops.scatter_mean import scatter_mean
from .ops.scatter_add import scatter_add
from .ops.three_interpolate import three_interpolate
from .ops.three_nn import three_nn
from .ops.voxel_pooling_train import npu_voxel_pooling_train
from .ops.voxelization import voxelization
from .ops.unique_voxel import unique_voxel
from .ops.cal_anchors_heading import cal_anchors_heading
from .ops.npu_gaussian import npu_gaussian
from .ops.npu_draw_gaussian_to_heatmap import npu_draw_gaussian_to_heatmap
from .ops.npu_assign_target_of_single_head import npu_assign_target_of_single_head
from .ops.diff_iou_rotated import diff_iou_rotated_2d
from .ops.npu_batch_matmul import npu_batch_matmul
from .ops.nms3d_on_sight import nms3d_on_sight
from .ops.cartesian_to_frenet import cartesian_to_frenet
from .ops.radius import radius
from .ops.npu_unique import npu_unique
from .ops.graph_softmax import graph_softmax
from .ops.cylinder_query import cylinder_query
from .ops.sigmoid_focal_loss import sigmoid_focal_loss
from .patcher import default_patcher_builder, patch_mmcv_version
from .ops.npu_fake_tensor_quant import npu_fake_tensor_quant
from .ops.npu_fake_tensor_quant import npu_fake_tensor_quant_inplace
from .ops.npu_fake_tensor_quant import npu_fake_tensor_quant_with_axis


def _set_env():
    """Register the bundled custom operator package for the detected chip.

    The function:

    1. Selects the vendor directory name and op-api shared-library name.
       The defaults are ``customize`` / ``libcust_opapi.so``; when the chip
       version string reported by ``Dsmi_dc_Func().chip_version_h()``
       contains ``"95"`` the ``*_arch35`` variants are used instead.
    2. Puts ``<package dir>/packages/vendors/<vendor dir>`` at the front of
       the ``ASCEND_CUSTOM_OPP_PATH`` environment variable, keeping any
       entries that were already there.
    3. Passes the op-api library path to
       ``mx_driving._C._init_op_api_so_path``.

    Any failure while querying the chip version is reported with a warning
    and the default names are used.

    The function is idempotent with respect to ``ASCEND_CUSTOM_OPP_PATH``:
    entries equal to the vendor directory that are already present (for
    example because the variable was inherited from a parent process that
    had already imported this package, or because the module was reloaded)
    are dropped before the directory is prepended, so repeated calls do not
    grow the variable.
    """
    mx_driving_root = os.path.dirname(os.path.abspath(__file__))

    # Default configuration.
    customize = "customize"
    opapi_name = "libcust_opapi.so"

    try:
        dsmi = Dsmi_dc_Func()
        soc_version = dsmi.chip_version_h()
        # Check whether the chip is an A5 part. The equality test is already
        # implied by the substring test; it is kept to name the known value.
        if soc_version and ("95" in soc_version or soc_version == "Ascend950PR"):
            customize = "customize_arch35"
            opapi_name = "libcust_opapi_arch35.so"
    except Exception:
        # Best effort: chip detection must never prevent the package from
        # being imported, so fall back to the default names.
        warnings.warn("Failed to get chip version, falling back to default logic.")

    # Pick the vendor directory that matches the selected configuration.
    mx_driving_opp_path = os.path.join(mx_driving_root, "packages", "vendors", customize)

    # Keep our directory first and append the pre-existing entries, skipping
    # earlier copies of our own directory so the value stays stable across
    # repeated calls. Other entries (including empty ones) are preserved
    # verbatim and in order.
    opp_entries = [mx_driving_opp_path]
    ascend_custom_opp_path = os.environ.get("ASCEND_CUSTOM_OPP_PATH")
    if ascend_custom_opp_path:
        opp_entries.extend(
            entry
            for entry in ascend_custom_opp_path.split(":")
            if entry != mx_driving_opp_path
        )
    os.environ["ASCEND_CUSTOM_OPP_PATH"] = ":".join(opp_entries)

    mx_driving_op_api_so_path = os.path.join(mx_driving_opp_path, "op_api", "lib", opapi_name)
    mx_driving._C._init_op_api_so_path(mx_driving_op_api_so_path)


# Run the environment setup once, as a side effect of importing this module.
_set_env()