6cffd44a创建于 2025年3月17日历史提交
diff --git a/setup.py b/setup.py
index 46328c3f..0b952e90 100644
--- a/setup.py
+++ b/setup.py
@@ -217,7 +217,7 @@ def get_extensions():
         include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common/cuda'))
         cuda_args = os.getenv('MMCV_CUDA_ARGS')
         extra_compile_args = {
-            'nvcc': [cuda_args, '-std=c++14'] if cuda_args else ['-std=c++14'],
+            'nvcc': [cuda_args, '-std=c++14', 'std=c++17'] if cuda_args else ['-std=c++14', '-std=c++17'],
             'cxx': ['-std=c++14'],
         }
         if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1':
@@ -267,15 +267,7 @@ def get_extensions():
         # to compile those cpp files, so there is no need to add the
         # argument
         if platform.system() != 'Windows':
-            if parse_version(torch.__version__) <= parse_version('1.12.1'):
-                extra_compile_args['cxx'] = ['-std=c++14']
-            else:
-                extra_compile_args['cxx'] = ['-std=c++17']
-        else:
-            if parse_version(torch.__version__) <= parse_version('1.12.1'):
-                extra_compile_args['cxx'] = ['/std:c++14']
-            else:
-                extra_compile_args['cxx'] = ['/std:c++17']
+            extra_compile_args['cxx'] = ['-std=c++14', '-std=c++17']

         include_dirs = []

@@ -424,33 +416,21 @@ def get_extensions():
             extra_compile_args['cxx'] += ['-ObjC++']
             # src
             op_files = glob.glob('./mmcv/ops/csrc/pytorch/*.cpp') + \
-                glob.glob('./mmcv/ops/csrc/pytorch/cpu/*.cpp')
-            # TODO: support mps ops on torch>=2.1.0
-            if parse_version(torch.__version__) < parse_version('2.1.0'):
-                op_files += glob.glob('./mmcv/ops/csrc/common/mps/*.mm') + \
-                    glob.glob('./mmcv/ops/csrc/pytorch/mps/*.mm')
+                glob.glob('./mmcv/ops/csrc/pytorch/cpu/*.cpp') + \
+                glob.glob('./mmcv/ops/csrc/common/mps/*.mm') + \
+                glob.glob('./mmcv/ops/csrc/pytorch/mps/*.mm')
             extension = CppExtension
             include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common'))
             include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common/mps'))
         elif (os.getenv('FORCE_NPU', '0') == '1'):
             print(f'Compiling {ext_name} only with CPU and NPU')
             try:
-                import importlib
-
                 from torch_npu.utils.cpp_extension import NpuExtension
-                extra_compile_args['cxx'] += [
-                    '-D__FILENAME__=\"$$(notdir $$(abspath $$<))\"'
-                ]
-                extra_compile_args['cxx'] += [
-                    '-I' + importlib.util.find_spec(
-                        'torch_npu').submodule_search_locations[0] +
-                    '/include/third_party/acl/inc'
-                ]
                 define_macros += [('MMCV_WITH_NPU', None)]
                 extension = NpuExtension
-                if parse_version(torch.__version__) < parse_version('2.1.0'):
+                if parse_version(torch.__version__) <= parse_version('2.0.0'):
                     define_macros += [('MMCV_WITH_XLA', None)]
-                if parse_version(torch.__version__) >= parse_version('2.1.0'):
+                if parse_version(torch.__version__) > parse_version('2.0.0'):
                     define_macros += [('MMCV_WITH_KPRIVATE', None)]
             except Exception:
                 raise ImportError('can not find any torch_npu')
@@ -476,10 +456,7 @@ def get_extensions():
         # to compile those cpp files, so there is no need to add the
         # argument
         if 'nvcc' in extra_compile_args and platform.system() != 'Windows':
-            if parse_version(torch.__version__) <= parse_version('1.12.1'):
-                extra_compile_args['nvcc'] += ['-std=c++14']
-            else:
-                extra_compile_args['nvcc'] += ['-std=c++17']
+            extra_compile_args['nvcc'] += ['-std=c++14', '-std=c++17']

         ext_ops = extension(
             name=ext_name,