import os
import subprocess
import sys
import sysconfig
from logging import raiseExceptions
from setuptools import setup, find_packages
from setuptools.command.build_py import build_py
# Name of the environment variable that selects the target model. It is read
# by inject_verl_plugin() (compared against "Qwen3vl") and by
# CustomBuildPy.run() (compared against "Qwen2_5vl").
MODEL_SELECT = "MODEL_SELECT"
# Name of the environment variable that, when set to any value, makes
# CustomBuildPy.run() call inject_seed_code_to_fsdp_workers().
DETERMINISTIC = "DETERMINISTIC"
def inject_seed_code_to_fsdp_workers(verl_path, insert_index=97):
    """Insert a global seeding snippet into verl's ``fsdp_workers.py``.

    The snippet defines ``seed_all()`` (which seeds ``random``, ``numpy``,
    ``torch`` and ``torch_npu``, sets several determinism-related environment
    variables and enables ``torch.use_deterministic_algorithms``) and calls it
    once at import time of the patched module.

    Args:
        verl_path: Directory that contains the ``verl`` package, i.e. the
            file patched is ``<verl_path>/verl/workers/fsdp_workers.py``.
        insert_index: Zero-based, non-negative line index at which the
            snippet is inserted. Files shorter than this are padded with
            blank lines first. Defaults to 97 (the snippet starts at line 98).

    Returns:
        True if the snippet was written or is already present; False if the
        path is missing, the file does not exist, or reading/writing failed.
    """
    # The caller may pass an unset VERL_PATH (None); os.path.join would raise
    # TypeError on it, so report the problem instead of crashing the build.
    if not verl_path:
        print("Error: verl path is not set; cannot locate fsdp_workers.py")
        return False

    target_file = os.path.join(verl_path, "verl", "workers", "fsdp_workers.py")
    if not os.path.exists(target_file):
        print(f"Error: fsdp_workers.py not found at {target_file}")
        return False

    # The snippet imports ``os`` itself so that it does not depend on the
    # target file having imported it before the insertion point.
    seed_code = '''import os
import random
import numpy as np
import torch
import torch_npu
def seed_all(seed=1234):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    os.environ['HCCL_DETERMINISTIC'] = str(True)
    os.environ['LCCL_DETERMINISTIC'] = str(1)
    os.environ['CLOSE_MATMUL_K_SHIFT'] = str(1)
    os.environ['ATB_LLM_LCOC_ENABLE'] = "0"
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.use_deterministic_algorithms(True)
    torch_npu.npu.manual_seed_all(seed)
    torch_npu.npu.manual_seed(seed)
seed_all()
'''

    try:
        with open(target_file, 'r', encoding='utf-8') as f:
            lines = f.readlines()
    except (OSError, UnicodeError) as e:
        print(f"Error reading {target_file}: {e}")
        return False

    # Re-running the build must not stack a second copy of the snippet.
    if any(line.startswith("def seed_all(") for line in lines):
        print(f"Info: seed_all code already present in {target_file}")
        return True

    # Terminate an unterminated last line so padding/snippet start on a new line.
    if lines and not lines[-1].endswith("\n"):
        lines[-1] += "\n"
    # Pad short files with blank lines up to the insertion point (no-op otherwise).
    lines.extend("\n" for _ in range(insert_index - len(lines)))
    lines.insert(insert_index, seed_code)

    try:
        with open(target_file, 'w', encoding='utf-8') as f:
            f.writelines(lines)
    except (OSError, UnicodeError) as e:
        print(f"Error writing to {target_file}: {e}")
        return False

    print(f"Successfully injected seed_all code at line {insert_index + 1} of {target_file}")
    return True
def inject_verl_plugin(custom_path=None):
    """Inject NPU acceleration support into the verl package.

    The verl package directory is resolved in this order: the ``VERL_PATH``
    environment variable (with ``verl`` appended), then *custom_path* (used
    as given), then an automatic search of the interpreter's install
    directories and ``sys.path``, and finally the ``Location:`` field of
    ``pip show verl``.

    Args:
        custom_path: Optional path of the verl package directory; only used
            when ``VERL_PATH`` is not present in the environment.

    Returns:
        False when verl cannot be located; otherwise the combined result of
        ``modify_init_fun`` and, when ``MODEL_SELECT`` is ``"Qwen3vl"``,
        ``qwen3vl_fun_modify``.
    """
    print("Starting verl plugin injection...")

    if 'VERL_PATH' in os.environ:
        verl_path = os.path.join(os.environ['VERL_PATH'], "verl")
        print(f"Using verl path from environment variable: {verl_path}")
    elif custom_path:
        verl_path = custom_path
        print(f"Using custom verl path: {verl_path}")
    else:
        print("Searching for verl package automatically...")
        install_dirs = sysconfig.get_paths()
        search_roots = [install_dirs["purelib"], install_dirs["platlib"], *sys.path]
        package_dirs = (os.path.join(root, "verl") for root in search_roots if root)
        # First existing ``verl`` directory wins; None when nothing matches.
        verl_path = next((d for d in package_dirs if os.path.isdir(d)), None)

        if verl_path is None:
            # Fall back to asking pip where the distribution is installed.
            try:
                pip_info = subprocess.run(
                    [sys.executable, "-m", "pip", "show", "verl"],
                    capture_output=True,
                    text=True,
                    check=True
                )
            except (subprocess.CalledProcessError, FileNotFoundError) as e:
                print(f"pip show failed: {e}")
            else:
                for entry in pip_info.stdout.splitlines():
                    if entry.startswith("Location:"):
                        verl_path = os.path.join(entry.split(": ")[1], "verl")
                        break

        if not verl_path:
            print("Error: verl package not found. Please specify with VERL_PATH environment variable.")
            return False

    print(f"Found verl at: {verl_path}")

    init_ok = modify_init_fun(verl_path)
    qwen3vl_ok = True
    if os.environ.get(MODEL_SELECT, None) == "Qwen3vl":
        qwen3vl_ok = qwen3vl_fun_modify(verl_path)
    return init_ok and qwen3vl_ok
def qwen3vl_fun_modify(verl_path) -> bool:
    """Apply the three source patches used when ``MODEL_SELECT`` is ``"Qwen3vl"``.

    Each patch replaces one anchor line in a verl source file with a
    multi-line payload via ``modify_fun_common``. All three patches are
    always attempted, even if an earlier one fails.

    Args:
        verl_path: Path of the verl package directory.

    Returns:
        True only if every ``modify_fun_common`` call reported success.
    """
    # NOTE(review): modify_fun_common replaces the whole anchor line, so each
    # payload must carry the indentation of its target location itself.
    # Confirm the leading whitespace inside these strings against the verl
    # files being patched.
    # NOTE(review): the first payload compares version strings with ">",
    # which is a lexicographic comparison -- confirm this is intended.
    transformers_import_payload = """
if get_version("transformers") > "4.57.1":
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_utils import PreTrainedModel
else:
from transformers.modeling_utils import PretrainedConfig, PreTrainedModel
"""
    rollout_padding_payload = """
if "padding_mode" not in self.config.engine_kwargs:
pass
elif self.config.engine_kwargs.get('padding_mode', 0) == 1:
response_attention_mask = torch.ones([attention_mask.shape[0], 1024], dtype=attention_mask.dtype, device=attention_mask.device)
attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1)
"""
    trainer_padding_payload = """
batch: DataProto = DataProto.from_single_dict(batch_dict)
if "padding_mode" not in self.config.data:
print("DEBUG: Padding mode not configured, skipping attention mask modification")
pass
elif self.config.data.padding_mode == 1:
batch.batch['attention_mask'] = torch.ones_like(batch.batch['attention_mask'])
print("INFO: Padding sequences to 17408 tokens (16 * 1024+1024) for alignment")
else:
print("DEBUG: Other padding mode specified, no additional processing required")
pass
"""

    # (file relative to verl_path, replacement payload, anchor line to replace)
    patches = (
        (
            "models/transformers/npu_patch.py",
            transformers_import_payload,
            "from transformers.modeling_utils import PretrainedConfig, PreTrainedModel",
        ),
        (
            "workers/rollout/vllm_rollout/vllm_rollout_spmd.py",
            rollout_padding_payload,
            "attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1)",
        ),
        (
            "trainer/ppo/ray_trainer.py",
            trainer_padding_payload,
            "batch: DataProto = DataProto.from_single_dict(batch_dict)",
        ),
    )

    # Build the full list first so that no patch is skipped by short-circuiting.
    npu_patch_ok, rollout_ok, trainer_ok = [
        modify_fun_common(verl_path, relative_file, payload, anchor)
        for relative_file, payload, anchor in patches
    ]
    return npu_patch_ok and rollout_ok and trainer_ok
def modify_fun_common(verl_path, file_path, import_content, line_to_change):
    """Replace an anchor line in a verl source file with a block of text.

    Every line of ``<verl_path>/<file_path>`` whose stripped text equals
    *line_to_change* is replaced, as a whole line, by *import_content*. The
    replacement is inserted verbatim, so it must already contain whatever
    indentation the target location requires.

    Args:
        verl_path: Base directory the relative *file_path* is joined to.
        file_path: Path of the file to patch, relative to *verl_path*.
        import_content: Text that replaces each matching line.
        line_to_change: Anchor text, compared against each stripped line.

    Returns:
        False if the file does not exist or cannot be read/written. True
        otherwise, including when the file already contains *import_content*
        and when no line matched (a warning is printed in that case).
    """
    modify_file = os.path.join(verl_path, file_path)
    if not os.path.exists(modify_file):
        print(f"Error: file to modify not found: {modify_file}")
        return False

    try:
        with open(modify_file, "r", encoding="utf-8") as f:
            lines = f.readlines()
    except (OSError, UnicodeError) as e:
        print(f"Error reading {modify_file}: {e}")
        return False

    # Idempotency guard: a previous run already inserted the payload.
    if import_content in "".join(lines):
        print(f"Info: {modify_file} already contains the replacement content")
        return True

    # The payload replaces a whole line, so it must end with a newline or the
    # following line would be glued onto it.
    replacement = import_content if import_content.endswith("\n") else import_content + "\n"

    modified = False
    new_lines = []
    for line in lines:
        if line.strip() == line_to_change:
            new_lines.append(replacement)
            print(f"Changed out line in {modify_file}: {line.strip()}")
            modified = True
        else:
            new_lines.append(line)

    if not modified:
        # Kept as a non-fatal outcome, but no longer silent.
        print(f"Warning: line to change not found in {modify_file}: {line_to_change}")
        return True

    try:
        with open(modify_file, "w", encoding="utf-8") as f:
            f.writelines(new_lines)
    except (OSError, UnicodeError) as e:
        print(f"Error modifying {modify_file}: {e}")
        return False

    print(f"Successfully modified {modify_file}")
    return True
def modify_init_fun(verl_path):
    """Append the NPU plugin import hook to verl's ``__init__.py``.

    The appended snippet imports ``is_npu_available`` from
    ``verl.utils.device`` and, when it is truthy, imports ``verl_npu`` and
    prints a confirmation message.

    Args:
        verl_path: Path of the verl package directory (the directory that
            contains ``__init__.py``).

    Returns:
        True if the snippet was appended or is already present; False if
        ``__init__.py`` does not exist or cannot be read/written.
    """
    init_file = os.path.join(verl_path, "__init__.py")
    if not os.path.exists(init_file):
        print(f"Error: verl initialization file not found: {init_file}")
        return False

    # The snippet is appended at module level, so the body of the ``if`` must
    # be indented for the patched __init__.py to remain importable.
    import_content = """
# NPU acceleration support added by mindspeed-mm plugin
from verl.utils.device import is_npu_available
if is_npu_available:
    import verl_npu
    print("NPU acceleration enabled for verl")
"""

    try:
        with open(init_file, "r", encoding="utf-8") as f:
            content = f.read()
    except (OSError, UnicodeError) as e:
        print(f"Error reading {init_file}: {e}")
        return False

    # Idempotency guard: do not append the hook twice.
    if import_content in content:
        print(f"Info: {init_file} already contains NPU acceleration import")
        return True

    try:
        with open(init_file, "a", encoding="utf-8") as f:
            f.write(import_content)
    except (OSError, UnicodeError) as e:
        print(f"Error writing to {init_file}: {e}")
        return False

    print(f"Successfully modified {init_file} to add NPU acceleration support")
    return True
def inject_vllm_plugin():
    """Patch ``vllm_ascend/ops/rotary_embedding.py`` to pass contiguous positions.

    The line starting the ``torch_npu.npu_mrope(positions,`` call is rewritten
    to ``torch_npu.npu_mrope(positions.contiguous(),``. The matched line's own
    indentation and line ending are preserved.

    Returns:
        False if ``vllm_ascend`` cannot be located or the file cannot be
        read/written. True otherwise, including when the rotary embedding
        file is missing, already patched, or does not contain the target
        line (a message is printed in each case).
    """
    print("Searching for vllm ascend package automatically...")
    vllm_path = get_vllm_path()
    if not vllm_path:
        print("Error: vllm_ascend package not found. Please specify with VLLM_PATH environment variable.")
        return False
    print(f"Found vllm_ascend at: {vllm_path}")

    rotary_embedding_file = os.path.join(vllm_path, "ops", "rotary_embedding.py")
    if not os.path.exists(rotary_embedding_file):
        print(f"Warning: rotary_embedding file not found: {rotary_embedding_file}")
        return True

    line_to_change = "query, key = torch_npu.npu_mrope(positions,"
    line_change_to = "query, key = torch_npu.npu_mrope(positions.contiguous(),"

    try:
        with open(rotary_embedding_file, "r", encoding="utf-8") as f:
            lines = f.readlines()

        modified = False
        new_lines = []
        for line in lines:
            if line.strip() == line_to_change:
                # Substituting inside the original line keeps its leading
                # whitespace and newline instead of hard-coding them.
                new_lines.append(line.replace(line_to_change, line_change_to, 1))
                print(f"Changed out line in {rotary_embedding_file}: {line.strip()}")
                modified = True
            else:
                new_lines.append(line)

        if modified:
            with open(rotary_embedding_file, "w", encoding="utf-8") as f:
                f.writelines(new_lines)
            print(f"Successfully modified {rotary_embedding_file}")
        elif any(line_change_to in line for line in lines):
            print(f"Info: line already changed in {rotary_embedding_file}")
        else:
            print(f"Warning: line to change not found in {rotary_embedding_file}: {line_to_change}")
    except (OSError, UnicodeError) as e:
        print(f"Error modifying {rotary_embedding_file}: {e}")
        return False
    return True
def get_vllm_path():
    """Try several strategies to locate the ``vllm_ascend`` package directory.

    Strategies, in order:

    1. ``$VLLM_PATH/vllm_ascend`` when the ``VLLM_PATH`` environment variable
       is set and that directory exists.
    2. A ``vllm_ascend`` directory under the interpreter's ``purelib`` /
       ``platlib`` install directories or any ``sys.path`` entry.
    3. The ``Editable project location:`` field of ``pip show vllm_ascend``.
    4. Importing ``vllm_ascend`` and using the first entry of its ``__path__``.

    Returns:
        The package directory as a string, or None when every strategy fails.
    """
    # inject_vllm_plugin() tells users to set VLLM_PATH, so honour it first.
    env_root = os.environ.get("VLLM_PATH")
    if env_root:
        candidate = os.path.join(env_root, "vllm_ascend")
        if os.path.isdir(candidate):
            return candidate
        print(f"Warning: VLLM_PATH is set but {candidate} is not a directory; searching automatically")

    install_dirs = sysconfig.get_paths()
    paths_to_try = [install_dirs["purelib"], install_dirs["platlib"], *sys.path]
    for path in paths_to_try:
        if not path:
            continue
        candidate = os.path.join(path, "vllm_ascend")
        if os.path.isdir(candidate):
            return candidate

    try:
        result = subprocess.run(
            [sys.executable, "-m", "pip", "show", "vllm_ascend"],
            capture_output=True,
            text=True,
            check=True
        )
    except (subprocess.CalledProcessError, FileNotFoundError) as e:
        print(f"pip show failed: {e}")
    else:
        for line in result.stdout.splitlines():
            if line.startswith("Editable project location:"):
                # partition() tolerates an empty value and paths containing ": ".
                location = line.partition(":")[2].strip()
                if location:
                    return os.path.join(location, "vllm_ascend")
                break

    # Last resort: let the import system find the package. A missing package
    # must yield None so the caller can report "not found" instead of crashing.
    try:
        import vllm_ascend
    except ImportError as e:
        print(f"import vllm_ascend failed: {e}")
        return None
    return vllm_ascend.__path__[0]
class CustomBuildPy(build_py):
    """``build_py`` command that patches verl / vllm_ascend after the normal build.

    Behaviour is driven by environment variables: ``MODEL_SELECT`` chooses
    model-specific patches, ``VERL_PATH`` points at the verl checkout, and
    ``DETERMINISTIC`` (any value) requests the seeding snippet injection.
    """

    def run(self):
        """Run the standard ``build_py`` step, then apply the source patches.

        Failures of the individual injection steps are reported with
        ``print`` and do not abort the build.
        """
        super().run()

        model_select = os.environ.get(MODEL_SELECT, None)
        deterministic_select = os.environ.get(DETERMINISTIC, None)
        if model_select is None:
            print("Error: Environment variable 'MODEL_SELECT' is required. Please set MODEL_SELECT to specify the model.")

        custom_path = os.environ.get('VERL_PATH', None)
        if not inject_verl_plugin(custom_path):
            print("Error: verl injection failed. Please check installation.")

        # The vllm_ascend patch is only applied for the Qwen2_5vl model.
        if model_select == "Qwen2_5vl" and not inject_vllm_plugin():
            print("Error: vllm injection failed. Please check installation.")

        if deterministic_select is None:
            print("Deteministic is not enabled")
            return

        if custom_path is None:
            # The seeding injection builds its target path from VERL_PATH;
            # passing None on would raise TypeError inside os.path.join.
            print("Error: Environment variable 'VERL_PATH' is required to enable the deterministic.")
            success = False
        else:
            success = inject_seed_code_to_fsdp_workers(custom_path)

        if success:
            print("Deteministic is enabled")
        else:
            print("Failed to enable the deterministic")
# Package metadata and build configuration for the verl NPU plugin.
setup(
    # Distribution identity.
    name="verl_npu",
    version="0.0.1",
    description="verl npu backend plugin",
    license="Apache 2.0",
    # Only the ``verl_npu`` package is shipped.
    packages=find_packages(include=["verl_npu"]),
    python_requires=">=3.9",
    # Route ``build_py`` through CustomBuildPy so the patches run at build time.
    cmdclass={"build_py": CustomBuildPy},
    classifiers=[
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
        "Programming Language :: Python :: 3.11",
        "License :: OSI Approved :: Apache Software License",
        "Intended Audience :: Developers",
        "Intended Audience :: Information Technology",
        "Intended Audience :: Science/Research",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
        "Topic :: Scientific/Engineering :: Information Analysis",
    ],
)