import ast
import importlib.metadata as md
import logging
import os
import shutil
import subprocess
import sys
import vllm
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
TARGET_VLLM_VERSIONS = ("0.20.2", "0.21.0", "0.22.1", "0.23.0")
PATCH_SPECS = [
("config/load.py", "vllm_shuffle_load_config.patch"),
("model_executor/model_loader/default_loader.py", "vllm_shuffle_default_loader.patch"),
("model_executor/model_loader/weight_utils.py", "vllm_shuffle_weight_utils.patch"),
]
def should_apply_patch() -> bool:
"""Return True when the installed vLLM version should be patched."""
version = md.version("vllm")
if version.split("+")[0].split("-")[0] not in TARGET_VLLM_VERSIONS:
logger.info("Skip shuffle safetensors patch: vLLM %s is not in %s", version, TARGET_VLLM_VERSIONS)
return False
logger.info("Applying shuffle safetensors patch for vLLM %s", version)
return True
def is_shuffle_patched(path: str) -> bool:
"""Return True if the target file is already patched and remains valid Python."""
try:
with open(path, encoding="utf-8") as f:
content = f.read()
if "shuffle_safetensors_files" not in content:
return False
ast.parse(content)
return True
except (OSError, SyntaxError):
return False
def apply_patch(target_file: str, patch_file: str) -> bool:
"""Apply a patch to a single vLLM source file; skip if already patched."""
patch_bin = shutil.which("patch")
if not patch_bin:
logger.error("patch command not found in PATH")
return False
result = subprocess.run(
[patch_bin, "-p0", "--fuzz=500", "--ignore-whitespace", target_file, patch_file],
capture_output=True,
text=True,
check=False,
)
if result.returncode == 0:
logger.info("Patch applied successfully to %s", target_file)
return True
if is_shuffle_patched(target_file):
logger.info("Already patched: %s", target_file)
return True
logger.error("Failed to apply patch to %s\n%s", target_file, result.stderr.strip())
return False
def main() -> int:
"""Apply all patches in PATCH_SPECS; return 0 on success or skip, 1 on failure."""
if not should_apply_patch():
return 0
script_dir = os.path.dirname(os.path.abspath(__file__))
version = md.version("vllm").split("+")[0].split("-")[0]
patch_dir = os.path.join(script_dir, version)
vllm_root = vllm.__path__[0]
failed = 0
for rel_path, patch_name in PATCH_SPECS:
if not apply_patch(os.path.join(vllm_root, rel_path), os.path.join(patch_dir, patch_name)):
failed += 1
return 1 if failed else 0
if __name__ == "__main__":
sys.exit(main())