#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/blob/main/tests/conftest.py
#

import contextlib
import copy
import functools
import gc
import json
import logging
import multiprocessing
import os
import shlex
import subprocess
import sys
import threading
import time
import traceback
from pathlib import Path
from typing import Any, TypeVar

import huggingface_hub
import numpy as np
import openai
import psutil
import pytest
import requests
import torch
from modelscope import snapshot_download  # type: ignore[import-untyped]
from PIL import Image
from requests.exceptions import RequestException
from torch import nn
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, BatchEncoding, BatchFeature
from transformers.models.auto.auto_factory import _BaseAutoModelClass
from vllm import LLM, SamplingParams
from vllm.config.model import ConvertOption, RunnerOption, _get_and_verify_dtype
from vllm.inputs import TextPrompt
from vllm.outputs import RequestOutput
from vllm.platforms import current_platform
from vllm.transformers_utils.utils import maybe_model_redirect
from vllm.utils.network_utils import get_open_port

from tests.e2e.model_utils import TokensTextLogprobs, TokensTextLogprobsPromptLogprobs
from tests.e2e.nightly.multi_node.scripts.multi_node_config import DisaggregatedPrefillCfg, NodeInfo
from vllm_ascend.ascend_config import clear_ascend_config

# TODO: remove this part once the patches are merged into vLLM. Unless they
# are applied explicitly here, some of them may not take effect when running
# under pytest.
# NOTE(review): adapt_patch is invoked with both True and False; presumably
# each call applies a different patch set — confirm in vllm_ascend.utils.
from vllm_ascend.utils import adapt_patch  # noqa E402

adapt_patch(True)
adapt_patch(False)

# Imported only after the patch calls above (hence the E402 suppression),
# presumably so the patched vLLM module state is what gets bound here.
from vllm.distributed.parallel_state import (  # noqa E402
    destroy_distributed_environment,
    destroy_model_parallel,
)

# Type variable constrained to torch modules/tensors, Hugging Face
# BatchEncoding/BatchFeature containers and plain dicts.
_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature, dict)
# Element type of a multimodal prompt input.
_M = TypeVar("_M")

# A multimodal input is either a flat list of items or a list of item lists;
# presumably one entry per prompt (it is indexed alongside prompts when
# sharding for data parallelism).
_PromptMultiModalInput = list[_M] | list[list[_M]]

PromptImageInput = _PromptMultiModalInput[Image.Image]
# Presumably (waveform, sample_rate) pairs — TODO confirm.
PromptAudioInput = _PromptMultiModalInput[tuple[np.ndarray, int]]
PromptVideoInput = _PromptMultiModalInput[np.ndarray]


logger = logging.getLogger(__name__)

# Directory containing this file; test assets are resolved relative to it.
_TEST_DIR = os.path.dirname(__file__)
_LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "long_prompt.txt")]

# Three directory levels above this file (presumably the repository root),
# and the examples directory underneath it.
_REPO_ROOT = Path(__file__).parent.parent.parent
_EXAMPLES_DIR = _REPO_ROOT / "examples"

# Proxy entry points launched by the disaggregated-serving helpers below.
DISAGG_EPD_PROXY_SCRIPT = _EXAMPLES_DIR / "disaggregated_encoder" / "disagg_epd_proxy.py"
DISAGG_PD_PROXY_SCRIPT = _EXAMPLES_DIR / "disaggregated_prefill_v1" / "load_balance_proxy_server_example.py"


def _check_npu_memory_worker(target_free_percentage: float, max_wait_seconds: float):
    """Poll NPU memory in a child process until enough of it is free.

    Returns normally once ``free / total`` NPU memory reaches
    ``target_free_percentage``. If that does not happen within
    ``max_wait_seconds``, prints a message to stderr and exits the process
    with status 1.

    Args:
        target_free_percentage: required ``free / total`` ratio (a fraction, not a percent).
        max_wait_seconds: how long to keep polling before giving up.
    """
    # We can try to clean up memory in this subprocess, though it mostly affects this process.
    # But if there are any lingering contexts in this process (unlikely for a fresh spawn), it helps.
    gc.collect()
    torch.npu.empty_cache()

    _, total_npu_memory = torch.npu.mem_get_info()
    # Monotonic clock so the timeout is immune to wall-clock adjustments.
    start_time = time.monotonic()

    while True:
        free_bytes, _ = torch.npu.mem_get_info()
        if free_bytes / total_npu_memory >= target_free_percentage:
            print("check_npu_memory_worker: npu free memory reached target value.")
            return  # Success

        elapsed = time.monotonic() - start_time
        if elapsed > max_wait_seconds:
            # Print to stderr so it's visible in test logs even if captured
            print(
                f"Timeout: NPU memory free size did not reach "
                f"{target_free_percentage} of total npu memory within {max_wait_seconds} seconds.",
                file=sys.stderr,
            )
            sys.exit(1)  # Failure

        print(
            f"Waiting for NPU memory to be free: "
            f"{free_bytes / 1024**3:.2f} GB available, "
            f"Elapsed time: {elapsed:.2f} s."
        )
        # Try to clean up
        gc.collect()
        torch.npu.empty_cache()
        time.sleep(1)


def wait_until_npu_memory_free(target_free_percentage: float = 0.5, max_wait_seconds: float = 50):
    """Decorator: before each call, wait until enough NPU memory is free.

    On every call the wrapper first runs ``cleanup_dist_env_and_memory()`` in the
    current process. It then checks NPU memory from a ``spawn``-ed subprocess, so
    the NPU is never initialized in the caller, and invokes the wrapped function
    only if the checker succeeds.

    Args:
        target_free_percentage (float): Required fraction of total NPU memory that must be free.
        max_wait_seconds (float): How long the checker polls before giving up.

    Raises:
        TimeoutError: raised by the wrapped call when the checker exits with a
            non-zero code, or does not finish within ``max_wait_seconds`` plus a
            start-up grace period.
    """
    # Extra time on top of max_wait_seconds for the spawned checker to start up
    # and initialize the NPU runtime before it begins polling.
    startup_grace_seconds = 120.0

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Clean up non-NPU resources in the main process
            cleanup_dist_env_and_memory()

            # Use a spawned subprocess to check NPU memory to avoid initializing NPU in the main process
            ctx = multiprocessing.get_context("spawn")
            p = ctx.Process(target=_check_npu_memory_worker, args=(target_free_percentage, max_wait_seconds))
            p.start()
            # Bounded join: a checker stuck inside the NPU runtime must not hang the test session.
            p.join(timeout=max_wait_seconds + startup_grace_seconds)
            if p.is_alive():
                p.terminate()
                p.join(timeout=5)
                if p.is_alive():
                    p.kill()
                    p.join()

            if p.exitcode != 0:
                raise TimeoutError(
                    f"Timeout: NPU memory free size did not reach "
                    f"{target_free_percentage} of total npu memory within {max_wait_seconds} seconds "
                    f"(checker exit code: {p.exitcode})."
                )

            return func(*args, **kwargs)

        return wrapper

    return decorator


def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
    """Tear down vLLM distributed state and release cached memory.

    Args:
        shutdown_ray: also call ``ray.shutdown()``; Ray is imported only in that case.
    """
    # vLLM's model-parallel groups go first, then the distributed environment.
    destroy_model_parallel()
    destroy_distributed_environment()
    # An AssertionError from torch here is tolerated (presumably raised when
    # there is no default process group to destroy).
    with contextlib.suppress(AssertionError):
        torch.distributed.destroy_process_group()

    if shutdown_ray:
        import ray  # Lazy import Ray

        ray.shutdown()

    gc.collect()

    # Never touch the NPU unless this process has already initialized it:
    # creating an NPU context in the main process would break later forks.
    if not hasattr(torch, "npu") or not torch.npu.is_initialized():
        return
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()


class MooncakeLauncher:
    """Context manager that runs a ``mooncake_master`` process while the block is active."""

    # Directory placed at the front of LD_LIBRARY_PATH before launching mooncake.
    MOONCAKE_LD_PATH = "/usr/local/Ascend/ascend-toolkit/latest/python/site-packages/mooncake"

    def __init__(
        self,
        mooncake_port,
        mooncake_metrics_port,
        eviction_high_watermark_ratio=0.8,
        eviction_ratio=0.05,
    ):
        self.mooncake_port = mooncake_port
        self.mooncake_metrics_port = mooncake_metrics_port
        self.eviction_high_watermark_ratio = eviction_high_watermark_ratio
        self.eviction_ratio = eviction_ratio
        # Set by __enter__; None means nothing has been launched, so __exit__ is a no-op.
        self.process: subprocess.Popen | None = None

    @classmethod
    def _with_mooncake_ld_path(cls, current: str) -> str:
        """Return the LD_LIBRARY_PATH value ``current`` with the mooncake dir first.

        Existing copies of the mooncake dir and empty entries are dropped. This
        stops repeated launches from growing the variable, and avoids adding an
        empty entry, which the dynamic loader treats as the current directory.
        """
        entries = [entry for entry in current.split(os.pathsep) if entry and entry != cls.MOONCAKE_LD_PATH]
        return os.pathsep.join([cls.MOONCAKE_LD_PATH, *entries])

    def __enter__(self):
        cmd = [
            "mooncake_master",
            "--eviction_high_watermark_ratio",
            str(self.eviction_high_watermark_ratio),
            "--eviction_ratio",
            str(self.eviction_ratio),
            "--port",
            str(self.mooncake_port),
            "--metrics_port",
            str(self.mooncake_metrics_port),
        ]

        logger.info("Launching mooncake: %s", " ".join(cmd))
        # NOTE(review): this deliberately updates the test process's own
        # environment, not just the child's. Presumably processes launched later
        # in the same test also need mooncake's libraries — confirm before
        # scoping it to the child only.
        os.environ["LD_LIBRARY_PATH"] = self._with_mooncake_ld_path(os.environ.get("LD_LIBRARY_PATH", ""))
        self.process = subprocess.Popen(cmd, env=os.environ.copy())
        return self

    def __exit__(self, exc_type, exc, tb):
        if self.process is None:
            return
        logger.info("Stopping mooncake server...")
        self.process.terminate()
        try:
            self.process.wait(timeout=5)
        except subprocess.TimeoutExpired:
            self.process.kill()
            # Reap the killed process so it does not linger as a zombie.
            self.process.wait()


class RemoteOpenAIServer:
    """Launch a ``vllm serve`` process for a test and talk to it over HTTP.

    The constructor starts the server and blocks until it is healthy. In the
    single-server case that means its ``/health`` endpoint; in
    disaggregated-prefill mode it means the proxy and every non-headless node.
    Use it as a context manager so the whole process tree is terminated on exit.
    """

    DUMMY_API_KEY = "token-abc123"  # vLLM's OpenAI server does not need API key

    def _start_server(self, model: str, server_cmd: list[str], env_dict: dict[str, str] | None) -> None:
        """Subclasses override this method to customize server process launch"""
        env = os.environ.copy()
        # the current process might initialize npu,
        # to be safe, we should use spawn method
        env["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
        if env_dict is not None:
            env.update(env_dict)
        logger.info("Starting server with command: %s", " ".join(server_cmd))
        self.proc: subprocess.Popen = subprocess.Popen(
            server_cmd,
            env=env,
            stdout=sys.stdout,
            stderr=sys.stderr,
        )

    def __init__(
        self,
        model: str,
        vllm_serve_args: list[str] | str,
        *,
        server_host: str = "0.0.0.0",
        server_port: int = 8080,
        env_dict: dict[str, str] | None = None,
        seed: int | None = None,
        auto_port: bool = True,
        nodes_info: list[NodeInfo] | None = None,
        disaggregated_prefill: DisaggregatedPrefillCfg | None = None,
        proxy_port: int | None = None,
        max_wait_seconds: float | None = None,
        override_hf_configs: dict[str, Any] | None = None,
    ) -> None:
        """Start the server and wait until it is healthy.

        Args:
            model: model name/path, used to build ``vllm serve <model>`` when
                ``vllm_serve_args`` is a list.
            vllm_serve_args: extra ``vllm serve`` arguments (list), or a complete
                command line (str) that is ``shlex``-split and run as-is.
            server_host: host used for health checks and clients.
            server_port: port used for health checks and clients. It is replaced
                by the generated port when ``auto_port`` appends one.
            env_dict: extra environment variables for the server process.
            seed: appended as ``--seed`` when given.
            auto_port: append ``--port <free port>``, unless ``--uds`` is used.
            nodes_info: multi-node layout (required for disaggregated prefill).
            disaggregated_prefill: enables the multi-node PD readiness check.
            proxy_port: PD proxy port (required for disaggregated prefill).
            max_wait_seconds: readiness deadline; a falsy value means 2800 s.
            override_hf_configs: passed as JSON via ``--hf-overrides``.

        Raises:
            ValueError: if ``-p``/``--port`` is given with ``auto_port=True``, or
                ``--seed`` is given together with ``seed``.
            AssertionError: in disaggregated-prefill mode without ``proxy_port``
                or ``nodes_info``. This is checked before anything is launched.
            RuntimeError: if the server exits with an error or never becomes ready.
        """
        if isinstance(vllm_serve_args, str):
            vllm_serve_args = shlex.split(vllm_serve_args)
        else:
            vllm_serve_args = ["vllm", "serve", model, *vllm_serve_args]

        port = int(server_port)
        if auto_port:
            if "-p" in vllm_serve_args or "--port" in vllm_serve_args:
                raise ValueError("You have manually specified the port when `auto_port=True`.")

            # No need for a port if using unix sockets
            if "--uds" not in vllm_serve_args:
                # The server will listen on this port, so health checks and
                # clients must use it instead of ``server_port``.
                port = int(get_open_port())
                # Don't mutate the input args
                vllm_serve_args = vllm_serve_args + ["--port", str(port)]
        if seed is not None:
            if "--seed" in vllm_serve_args:
                raise ValueError(f"You have manually specified the seed when `seed={seed}`.")

            vllm_serve_args = vllm_serve_args + ["--seed", str(seed)]

        if override_hf_configs is not None:
            vllm_serve_args = vllm_serve_args + ["--hf-overrides", json.dumps(override_hf_configs)]

        if disaggregated_prefill:
            # Validate before launching so a bad config cannot leak a server process.
            assert proxy_port is not None, "for disaggregated_prefill, proxy port must be provided"
            assert nodes_info is not None, "cluster info must be provided"

        self.host = str(server_host)
        self.port = port
        # for multi-nodes test
        self.nodes_info = nodes_info
        self.disaggregated_prefill = disaggregated_prefill
        # NOTE: a str when LWS_WORKER_INDEX is set, the int 0 otherwise.
        self.cur_index = os.getenv("LWS_WORKER_INDEX", 0)
        self.proxy_port = proxy_port

        self._start_server(model, vllm_serve_args, env_dict)
        max_wait_seconds = max_wait_seconds or 2800
        if self.disaggregated_prefill:
            self._wait_for_server_pd(timeout=max_wait_seconds)
        else:
            self._wait_for_multiple_servers([(self.host, self.url_for("health"))], timeout=max_wait_seconds)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self._terminate_server()

    def _poll(self) -> int | None:
        """Subclasses override this method to customize process polling"""
        return self.proc.poll()

    def hang_until_terminated(self, url) -> None:
        """
        Wait until the server process terminates.
        This is for headless mode, where the api server
        process only exists in the leader node.

        Polls ``url`` every 5 s. When it stops answering 200 (or the request
        fails), the local server is terminated.
        """
        logger.info("Hanging until server process terminates...")
        try:
            while True:
                try:
                    resp = requests.get(url, timeout=5)
                except RequestException:
                    break
                if resp.status_code != 200:
                    break
                time.sleep(5)
        finally:
            self._terminate_server()

    def _wait_for_server_pd(self, timeout: float):
        """Wait for the PD proxy on the first node, then every non-headless api_server."""
        assert self.nodes_info is not None, "cluster info must be provided"
        proxy_port = self.proxy_port

        def url_health(ip: str, port: int) -> str:
            return f"http://{ip}:{port}/health"

        targets = [
            (node_info.ip, url_health(node_info.ip, self.port))
            for node_info in self.nodes_info
            if not node_info.headless
        ]

        # Wait for proxy ready
        master_node = self.nodes_info[0]
        url_proxy = f"http://{master_node.ip}:{proxy_port}/healthcheck"

        # Wait for master node proxy first
        self._wait_for_multiple_servers([(master_node.ip, url_proxy)], timeout=timeout)

        # Then wait for all api_server nodes
        self._wait_for_multiple_servers(targets=targets, timeout=timeout)

    def _wait_for_multiple_servers(
        self,
        targets,
        timeout: float,
        log_interval: float = 30.0,
        always_check_nodes: bool = False,
        request_timeout: float = 10.0,
    ):
        """Block until every health URL in ``targets`` answers with HTTP 200.

        Args:
            targets: list of ``(node_ip, url)`` pairs. Several targets may share
                a ``node_ip`` (e.g. multiple servers on one host), so readiness
                is tracked per URL.
            timeout: overall deadline in seconds.
            log_interval: minimum seconds between "still waiting" debug logs.
            always_check_nodes: re-probe targets that were already ready on
                every round, so all of them must be healthy in the same round.
            request_timeout: per-probe timeout, so a hung endpoint cannot stall the loop.

        Raises:
            RuntimeError: if ``_poll()`` reports a non-zero exit code, or the
                deadline passes. The server is terminated before raising.
        """
        start = time.monotonic()
        ready = {url: False for _, url in targets}
        last_log_time = float("-inf")

        while True:
            now = time.monotonic()
            should_log = (now - last_log_time) >= log_interval

            for node_ip, url in targets:
                if ready[url] and not always_check_nodes:
                    continue

                try:
                    resp = requests.get(url, timeout=request_timeout)
                except RequestException:
                    healthy = False
                    if should_log:
                        logger.debug("[WAIT] %s: connection failed", url)
                else:
                    # A response that is not 200 (e.g. 503 while loading) is not ready yet.
                    healthy = resp.status_code == 200
                    if not healthy and should_log:
                        logger.debug("[WAIT] %s: HTTP %s", url, resp.status_code)

                if healthy and not ready[url]:
                    logger.info("[READY] Node %s: %s is ready.", node_ip, url)
                ready[url] = healthy

                if not healthy:
                    # check unexpected exit
                    result = self._poll()
                    if result is not None and result != 0:
                        self._terminate_server()
                        raise RuntimeError(f"Server at {node_ip} exited unexpectedly.") from None

            if should_log:
                last_log_time = now

            if all(ready.values()):
                break

            if time.monotonic() - start > timeout:
                not_ready_nodes = [f"{node_ip} ({url})" for node_ip, url in targets if not ready[url]]
                self._terminate_server()
                raise RuntimeError(
                    f"Timeout: these nodes did not become ready: {not_ready_nodes} in time: {timeout}s"
                ) from None

            time.sleep(5)

    @property
    def url_root(self) -> str:
        return f"http://{self.host}:{self.port}"

    def _terminate_server(self) -> None:
        """Subclasses override this method to customize server process termination"""
        self._terminate_process_tree(self.proc)

    def _terminate_process_tree(self, proc: subprocess.Popen) -> None:
        """Terminate ``proc`` and all its descendants.

        Sends SIGTERM first, waits up to 10 s, then kills whatever is still alive.
        """
        try:
            parent = psutil.Process(proc.pid)
            children = parent.children(recursive=True)
        except psutil.NoSuchProcess:
            # Already gone (possibly between the two calls above).
            return

        for child in children:
            with contextlib.suppress(psutil.NoSuchProcess):
                child.terminate()

        _, still_alive = psutil.wait_procs(children, timeout=10)

        for child in still_alive:
            with contextlib.suppress(psutil.NoSuchProcess):
                child.kill()

        try:
            parent.terminate()
            parent.wait(timeout=10)
        except (psutil.NoSuchProcess, psutil.TimeoutExpired):
            with contextlib.suppress(psutil.NoSuchProcess, psutil.TimeoutExpired):
                parent.kill()
                parent.wait(timeout=5)

    def url_for(self, *parts: str) -> str:
        return self.url_root + "/" + "/".join(parts)

    def get_client(self, **kwargs):
        """Return a synchronous OpenAI client for this server (default timeout 600 s, no retries)."""
        kwargs.setdefault("timeout", 600)
        return openai.OpenAI(
            base_url=self.url_for("v1"),
            api_key=self.DUMMY_API_KEY,
            max_retries=0,
            **kwargs,
        )

    def get_async_client(self, **kwargs):
        """Return an asynchronous OpenAI client for this server (default timeout 600 s, no retries)."""
        kwargs.setdefault("timeout", 600)
        return openai.AsyncOpenAI(base_url=self.url_for("v1"), api_key=self.DUMMY_API_KEY, max_retries=0, **kwargs)


def _get_pd_server_required_devices(vllm_serve_args: list[str]) -> int:
    def get_size(arg_name: str) -> int:
        value = 1
        if arg_name in vllm_serve_args:
            value = int(vllm_serve_args[vllm_serve_args.index(arg_name) + 1])
        if value <= 0:
            raise ValueError(f"{arg_name} must be positive, got {value}.")
        return value

    tensor_parallel_size = get_size("--tensor-parallel-size")
    data_parallel_arg = (
        "--data-parallel-size-local" if "--data-parallel-size-local" in vllm_serve_args else "--data-parallel-size"
    )
    data_parallel_size = get_size(data_parallel_arg)
    return tensor_parallel_size * data_parallel_size


class RemotePDServer(RemoteOpenAIServer):
    """Launch one or more local ``vllm serve`` instances (e.g. prefill and decode).

    Each instance gets its own contiguous slice of NPUs through
    ``ASCEND_RT_VISIBLE_DEVICES``, sized by ``_get_pd_server_required_devices``.
    Its stdout/stderr are echoed with a ``[PD_<i>]`` prefix. The constructor
    blocks until every instance's ``/health`` endpoint returns 200. After
    construction, ``self.port`` is the port of the last instance.
    """

    def __init__(
        self,
        vllm_serve_args: list[str] | list[list[str]],
        server_host: str = "127.0.0.1",
        env_dict: dict[str, str] | None = None,
        max_wait_seconds: float | None = 600,
    ) -> None:
        """
        Args:
            vllm_serve_args: arguments for a single instance (flat list), or one
                list per instance. Each must contain ``--port``; ``vllm serve`` is
                prepended automatically.
            server_host: host used to build the health-check URLs.
            env_dict: extra environment variables applied to every instance.
            max_wait_seconds: readiness deadline; ``None`` means 2800 s.

        Raises:
            RuntimeError: if ``vllm_serve_args`` is not a list.
            ValueError: if it is empty, an instance lacks a ``--port`` value, or
                a parallel size is invalid. Nothing is launched in that case.
        """
        self._proc_list: list[subprocess.Popen] = []

        self.env_dict: dict[str, str] = {}
        if env_dict is not None:
            self.env_dict.update(env_dict)

        self.env_dict["VLLM_ALLOW_LONG_MAX_MODEL_LEN"] = "1"
        self.env_dict["PYTORCH_NPU_ALLOC_CONF"] = "expandable_segments:True"
        self.env_dict["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

        self.vllm_serve_args_list: list[list[str]] = []
        self.health_url_list: list[str] = []
        self.host = server_host

        if not isinstance(vllm_serve_args, list):
            raise RuntimeError("vllm_serves_args must be a list")
        if not vllm_serve_args:
            raise ValueError("vllm_serve_args must not be empty")
        if all(isinstance(item, list) for item in vllm_serve_args):
            self.vllm_serve_args_list = [[str(arg) for arg in sublist] for sublist in vllm_serve_args]
        else:
            self.vllm_serve_args_list = [[str(arg) for arg in vllm_serve_args]]

        # Validate every instance before launching any, so a bad configuration
        # cannot leave earlier instances running.
        launch_plan = [
            (self._parse_port(serve_args), _get_pd_server_required_devices(serve_args))
            for serve_args in self.vllm_serve_args_list
        ]

        start_device_id = 0
        try:
            for i, (serve_args, (port, required_devices)) in enumerate(
                zip(self.vllm_serve_args_list, launch_plan)
            ):
                self.port = port
                self.health_url_list.append(self.url_for("health"))

                server_env = dict(self.env_dict)
                server_env["ASCEND_RT_VISIBLE_DEVICES"] = ",".join(
                    str(device_id) for device_id in range(start_device_id, start_device_id + required_devices)
                )
                start_device_id += required_devices

                proc = self._start_server_with_prefix(["vllm", "serve", *serve_args], server_env, f"[PD_{i}] ")
                self._proc_list.append(proc)
        except BaseException:
            # Don't leak the instances that did start.
            self._terminate_server()
            raise

        timeout_value = float(max_wait_seconds) if max_wait_seconds is not None else 2800.0
        self._wait_for_multiple_servers(
            [(self.host, url) for url in self.health_url_list], timeout=timeout_value, always_check_nodes=True
        )

    @staticmethod
    def _parse_port(serve_args: list[str]) -> int:
        """Return the integer value following ``--port`` in ``serve_args``."""
        if "--port" not in serve_args:
            raise ValueError("You have to manually specify the port")
        index = serve_args.index("--port")
        if index + 1 >= len(serve_args):
            raise ValueError(f"--port requires a value in args: {serve_args}")
        return int(serve_args[index + 1])

    def _poll(self) -> int | None:
        """Return the first non-zero exit code among launched processes, else None."""
        for proc in self._proc_list:
            result = proc.poll()
            if result is not None and result != 0:
                return result
        return None

    def _read_output(self, pipe, prefix):
        """Echo each line from ``pipe`` with ``prefix`` until EOF (runs in a thread)."""
        try:
            with pipe:
                for line in iter(pipe.readline, ""):
                    if line:
                        print(f"{prefix}: {line}", end="")

        except Exception as e:
            print(f"error: {e}")
            traceback.print_exc()

    def _start_server_with_prefix(self, server_cmd: list[str], env_dict: dict[str, str] | None, log_prefix: str):
        """Start ``server_cmd`` with ``env_dict`` layered over os.environ, and stream its output."""
        env = os.environ.copy()
        if env_dict is not None:
            env.update(env_dict)
        proc = subprocess.Popen(
            server_cmd, env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True, bufsize=1
        )
        stdout_thread = threading.Thread(target=self._read_output, args=(proc.stdout, log_prefix), daemon=True)
        stderr_thread = threading.Thread(target=self._read_output, args=(proc.stderr, log_prefix), daemon=True)

        stdout_thread.start()
        stderr_thread.start()
        return proc

    def _terminate_server(self) -> None:
        print("pd instance is stopping")
        for proc in self._proc_list:
            self._terminate_process_tree(proc)


class DisaggPDProxy(RemotePDServer):
    """Run the PD load-balancing proxy script in front of local prefillers/decoders.

    Every prefiller and decoder is addressed on ``host``. The constructor
    blocks until the proxy's ``/healthcheck`` endpoint returns 200.
    """

    def __init__(
        self,
        port: int,
        prefiller_ports: list[int],
        decoder_ports: list[int],
        host: str = "127.0.0.1",
        env_dict: dict[str, str] | None = None,
        max_wait_seconds: float | None = 600,
    ) -> None:
        self.env_dict: dict[str, str] = {}
        if env_dict is not None:
            self.env_dict.update(env_dict)
        self._proc_list = []
        self.host = host
        self.port = int(port)

        prefiller_port_args = [str(p) for p in prefiller_ports]
        decoder_port_args = [str(p) for p in decoder_ports]
        # One host entry per backend, all pointing at the same machine.
        prefiller_hosts = [host] * len(prefiller_ports)
        decoder_hosts = [host] * len(decoder_ports)
        self.proxy_args = [
            "--host",
            host,
            "--port",
            str(port),
            "--prefiller-hosts",
            *prefiller_hosts,
            "--prefiller-ports",
            *prefiller_port_args,
            "--decoder-hosts",
            *decoder_hosts,
            "--decoder-ports",
            *decoder_port_args,
        ]

        print(f"proxy param is: {self.proxy_args}")
        launch_cmd = [sys.executable, str(DISAGG_PD_PROXY_SCRIPT), *self.proxy_args]
        self._proc_list.append(self._start_server_with_prefix(launch_cmd, self.env_dict, "[PD_PROXY] "))

        wait_seconds = 600.0 if max_wait_seconds is None else float(max_wait_seconds)
        self._wait_for_multiple_servers([(self.host, self.url_for("healthcheck"))], timeout=wait_seconds)


class RemoteEPDServer(RemoteOpenAIServer):
    """Launch several single-NPU ``vllm serve`` instances for E/P/D disaggregation tests.

    Instance ``i`` is pinned to NPU ``i`` through ``ASCEND_RT_VISIBLE_DEVICES``,
    and its output is echoed with a ``[VLLM_<i>]`` prefix. The constructor
    blocks until every instance's ``/health`` endpoint returns 200. After
    construction, ``self.port`` is the port of the last instance.
    """

    def _start_server(self, model: str, server_cmd: list[str], env_dict: dict[str, str] | None) -> None:
        """Subclasses override this method to customize server process launch"""
        raise NotImplementedError("RemoteEPDServer should use _start_server_with_prefix instead")

    def __init__(
        self,
        vllm_serve_args: list[str] | list[list[str]],
        server_host: str = "0.0.0.0",
        env_dict: dict[str, str] | None = None,
        max_wait_seconds: float | None = 2800,
    ) -> None:
        """
        Args:
            vllm_serve_args: arguments for a single instance (flat list), or one
                list per instance. Each must contain ``--port``; ``vllm serve`` is
                prepended automatically.
            server_host: host used to build the health-check URLs.
            env_dict: extra environment variables applied to every instance.
            max_wait_seconds: readiness deadline; ``None`` means 2800 s.

        Raises:
            RuntimeError: if ``vllm_serve_args`` is not a list.
            ValueError: if it is empty or an instance lacks a ``--port`` value.
                Nothing is launched in that case.
        """
        self._proc_list: list[subprocess.Popen] = []

        self.env_dict: dict[str, str] = {}
        if env_dict is not None:
            self.env_dict.update(env_dict)

        self.env_dict["VLLM_ALLOW_LONG_MAX_MODEL_LEN"] = "1"
        self.env_dict["PYTORCH_NPU_ALLOC_CONF"] = "expandable_segments:True"
        self.env_dict["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

        self.vllm_serve_args_list: list[list[str]] = []
        self.health_url_list: list[str] = []
        self.host = server_host

        if not isinstance(vllm_serve_args, list):
            raise RuntimeError("vllm_serves_args must be a list")
        if not vllm_serve_args:
            raise ValueError("vllm_serve_args must not be empty")
        if all(isinstance(item, list) for item in vllm_serve_args):
            self.vllm_serve_args_list = [[str(arg) for arg in sublist] for sublist in vllm_serve_args]
        else:
            self.vllm_serve_args_list = [[str(arg) for arg in vllm_serve_args]]

        # Resolve every port before launching, so a bad configuration cannot
        # leave earlier instances running.
        ports = [self._parse_port(serve_args) for serve_args in self.vllm_serve_args_list]

        try:
            for i, (serve_args, port) in enumerate(zip(self.vllm_serve_args_list, ports)):
                self.port = port
                self.health_url_list.append(self.url_for("health"))
                # Per-instance environment: instance i is pinned to device i.
                server_env = {**self.env_dict, "ASCEND_RT_VISIBLE_DEVICES": str(i)}
                proc = self._start_server_with_prefix(["vllm", "serve", *serve_args], server_env, f"[VLLM_{i}] ")
                self._proc_list.append(proc)
        except BaseException:
            # Don't leak the instances that did start.
            self._terminate_server()
            raise

        timeout_value = float(max_wait_seconds) if max_wait_seconds is not None else 2800.0
        self._wait_for_multiple_servers(
            [(self.host, url) for url in self.health_url_list], timeout=timeout_value, always_check_nodes=True
        )

    @staticmethod
    def _parse_port(serve_args: list[str]) -> int:
        """Return the integer value following ``--port`` in ``serve_args``."""
        if "--port" not in serve_args:
            raise ValueError("You have to manually specify the port")
        index = serve_args.index("--port")
        if index + 1 >= len(serve_args):
            raise ValueError(f"--port requires a value in args: {serve_args}")
        return int(serve_args[index + 1])

    def _poll(self) -> int | None:
        """Return the first non-zero exit code among launched processes, else None.

        This lets the readiness wait fail fast when an instance crashes, instead
        of running until the deadline.
        """
        for proc in self._proc_list:
            result = proc.poll()
            if result is not None and result != 0:
                return result
        return None

    def _delete_shm(self) -> None:
        """Remove the ``shared_storage_path`` named in each instance's ``--ec-transfer-config``.

        Blocks until each ``rm`` finishes, so the paths are gone on return.
        """
        for serve_args in self.vllm_serve_args_list:
            if "--ec-transfer-config" not in serve_args:
                continue
            config_str = serve_args[serve_args.index("--ec-transfer-config") + 1]
            config_dict = json.loads(config_str)
            ec_connector_extra_config = config_dict.get("ec_connector_extra_config", {})
            shm_path = ec_connector_extra_config.get("shared_storage_path")
            if shm_path:
                args = ["rm", "-r", "-f", str(shm_path)]
                print(f"delete shm_path is: {shm_path}")
                self._start_server_with_prefix(args, None, "[DELETE] ").wait()

    def _read_output(self, pipe, prefix):
        """Echo each line from ``pipe`` with ``prefix`` until EOF (runs in a thread)."""
        try:
            with pipe:
                for line in iter(pipe.readline, ""):
                    if line:
                        print(f"{prefix}: {line}", end="")

        except Exception as e:
            print(f"error: {e}")
            traceback.print_exc()

    def _start_server_with_prefix(self, server_cmd: list[str], env_dict: dict[str, str] | None, log_prefix: str):
        """Start ``server_cmd`` with ``env_dict`` layered over os.environ, and stream its output."""
        env = os.environ.copy()
        if env_dict is not None:
            env.update(env_dict)
        proc = subprocess.Popen(
            server_cmd,
            env=env,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            universal_newlines=True,
            bufsize=1,
        )
        stdout_thread = threading.Thread(target=self._read_output, args=(proc.stdout, log_prefix), daemon=True)
        stderr_thread = threading.Thread(target=self._read_output, args=(proc.stderr, log_prefix), daemon=True)

        stdout_thread.start()
        stderr_thread.start()
        return proc

    def _terminate_server(self) -> None:
        """Kill server processes and their children."""
        print("vllm instance is stopping")
        for proc in self._proc_list:
            self._terminate_process_tree(proc)

    def __enter__(self):
        """Context manager entry point."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit point - clean up all processes."""
        self._terminate_server()


class DisaggEpdProxy(RemoteEPDServer):
    """Run the disaggregated-encoder (EPD) proxy script and wait for its ``/health``.

    The proxy runs under the current Python interpreter, and its output is
    echoed with a ``[PROXY]`` prefix.
    """

    def __init__(
        self,
        proxy_args: list[str] | str | None = None,
        env_dict: dict[str, str] | None = None,
        server_host: str = "0.0.0.0",
        max_wait_seconds: float | None = 2800,
    ) -> None:
        """
        Args:
            proxy_args: CLI arguments for the proxy script, as a list or a
                shell-style string. Must include ``--port <port>``.
            env_dict: extra environment variables for the proxy process.
            server_host: host used for the health-check URL.
            max_wait_seconds: readiness deadline; ``None`` means 2800 s.

        Raises:
            ValueError: if ``--port`` is missing or has no value. This is checked
                before the proxy is launched, so no process is leaked.
        """
        if proxy_args is None:
            proxy_args_list: list[str] = []
        elif isinstance(proxy_args, str):
            proxy_args_list = shlex.split(proxy_args)
        else:
            # Copy (and stringify) so later mutation of the caller's list has no effect.
            proxy_args_list = [str(arg) for arg in proxy_args]

        self.proxy_args = proxy_args_list
        self.env_dict: dict[str, str] = {}
        if env_dict is not None:
            self.env_dict.update(env_dict)
        self._proc_list = list()
        # The proxy owns no vllm instances; these keep inherited helpers usable.
        self.vllm_serve_args_list = []
        self.health_url_list = []
        self.host = server_host

        if "--port" not in self.proxy_args:
            raise ValueError("You have to manually specify the port")
        index = self.proxy_args.index("--port")
        if index + 1 >= len(self.proxy_args):
            raise ValueError("--port requires a value in proxy args")
        self.port = int(self.proxy_args[index + 1])

        print(f"proxy param is: {self.proxy_args}")
        # Use the running interpreter rather than whatever `python` is on PATH.
        proxy_cmd = [sys.executable, str(DISAGG_EPD_PROXY_SCRIPT), *self.proxy_args]
        proc = self._start_server_with_prefix(proxy_cmd, self.env_dict, "[PROXY] ")
        self._proc_list.append(proc)

        timeout_value = float(max_wait_seconds) if max_wait_seconds is not None else 2800.0
        self._wait_for_multiple_servers([(self.host, self.url_for("health"))], timeout=timeout_value)

    def __enter__(self):
        """Context manager entry point."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit point - clean up all processes."""
        super()._terminate_server()


_DP_RUNNER_START_TIMEOUT_SECONDS = 900.0
_DP_RUNNER_REQUEST_TIMEOUT_SECONDS = 900.0
_DP_RUNNER_SHUTDOWN_TIMEOUT_SECONDS = 30.0


def _split_data_parallel_indices(num_items: int, dp_size: int) -> list[list[int]]:
    if num_items < 0:
        raise ValueError("num_items must be non-negative")
    if dp_size <= 0:
        raise ValueError("dp_size must be positive")

    floor = num_items // dp_size
    remainder = num_items % dp_size

    def start(rank: int) -> int:
        return rank * floor + min(rank, remainder)

    return [list(range(start(rank), start(rank + 1))) for rank in range(dp_size)]


def _slice_optional_inputs(inputs: PromptImageInput | PromptAudioInput | PromptVideoInput | None, indices: list[int]):
    """Select ``inputs[i]`` for each ``i`` in ``indices``; ``None`` is passed through unchanged."""
    if inputs is not None:
        return list(map(inputs.__getitem__, indices))
    return None


def _slice_list_inputs(items: list[Any], indices: list[int]) -> list[Any]:
    return [items[index] for index in indices]


def _merge_data_parallel_results(total_items: int, shard_results: list[tuple[list[int], list[Any]]]) -> list[Any]:
    merged: list[Any] = [None] * total_items
    for indices, results in shard_results:
        if not indices:
            continue
        if len(indices) != len(results):
            raise RuntimeError("Mismatched result count returned by data parallel worker")
        for index, result in zip(indices, results):
            merged[index] = result

    if any(result is None for result in merged):
        raise RuntimeError("Some data parallel results were not returned")

    return merged


def _normalize_score_inputs(text_1: str | list[str], text_2: str | list[str]) -> tuple[list[str], list[str]]:
    if isinstance(text_1, str) and isinstance(text_2, str):
        return [text_1], [text_2]
    if isinstance(text_1, str):
        return [text_1] * len(text_2), list(text_2)
    if isinstance(text_2, str):
        return list(text_1), [text_2] * len(text_1)
    if len(text_1) != len(text_2):
        raise ValueError("`text_1` and `text_2` must have the same length")
    return list(text_1), list(text_2)


def _run_vllm_runner_dp_worker(conn, llm_kwargs: dict[str, Any], dp_rank: int, dp_size: int, master_port: int) -> None:
    llm = None
    try:
        os.environ["VLLM_DP_RANK"] = str(dp_rank)
        os.environ["VLLM_DP_RANK_LOCAL"] = str(dp_rank)
        os.environ["VLLM_DP_SIZE"] = str(dp_size)
        os.environ["VLLM_DP_MASTER_IP"] = "127.0.0.1"
        os.environ["VLLM_DP_MASTER_PORT"] = str(master_port)
        os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

        llm = LLM(**llm_kwargs)
        conn.send({"status": "ready", "rank": dp_rank})

        while True:
            request = conn.recv()
            command = request["command"]
            if command == "shutdown":
                break

            result: Any
            if command == "generate":
                req_outputs = llm.generate(
                    request["inputs"], sampling_params=request["sampling_params"], **request["kwargs"]
                )
                result = VllmRunner._finalize_generate_outputs(req_outputs)
            elif command == "generate_w_logprobs":
                req_outputs = llm.generate(
                    request["inputs"], sampling_params=request["sampling_params"], **request["kwargs"]
                )
                result = VllmRunner._final_steps_generate_w_logprobs(req_outputs)
            elif command == "classify":
                req_outputs = llm.classify(request["prompts"])
                result = [req_output.outputs.probs for req_output in req_outputs]
            elif command == "embed":
                req_outputs = llm.embed(request["inputs"], *request["args"], **request["kwargs"])
                result = [req_output.outputs.embedding for req_output in req_outputs]
            elif command == "encode":
                req_outputs = llm.encode(request["prompts"])
                result = [req_output.outputs.data for req_output in req_outputs]
            elif command == "reward":
                req_outputs = llm.reward(request["prompts"])
                result = [req_output.outputs.data for req_output in req_outputs]
            elif command == "score":
                req_outputs = llm.score(request["text_1"], request["text_2"], *request["args"], **request["kwargs"])
                result = [req_output.outputs.score for req_output in req_outputs]
            else:
                raise ValueError(f"Unsupported data parallel command: {command}")

            conn.send({"status": "ok", "rank": dp_rank, "indices": request["indices"], "result": result})
    except Exception:
        with contextlib.suppress(Exception):
            conn.send({"status": "error", "rank": dp_rank, "traceback": traceback.format_exc()})
        raise
    finally:
        if llm is not None:
            del llm
        clear_ascend_config()
        cleanup_dist_env_and_memory()
        with contextlib.suppress(Exception):
            conn.close()


class VllmRunner:
    """Test helper wrapping a single in-process ``vllm.LLM`` and normalizing its outputs.

    Usable as a context manager; ``__exit__`` drops the model and calls
    ``clear_ascend_config()`` and ``cleanup_dist_env_and_memory()``.
    ``data_parallel_size > 1`` is rejected; use ``DPVllmRunner`` for that.
    """

    def __init__(
        self,
        model_name: str,
        runner: RunnerOption = "auto",
        convert: ConvertOption = "auto",
        tokenizer_name: str | None = None,
        tokenizer_mode: str = "auto",
        max_model_len: int | None = 1024,
        dtype: str = "auto",
        disable_log_stats: bool = True,
        tensor_parallel_size: int = 1,
        block_size: int = 16,
        enable_chunked_prefill: bool = True,
        swap_space: int = 4,
        enforce_eager: bool | None = False,
        quantization: str | None = None,
        **kwargs,
    ) -> None:
        """Construct the underlying ``LLM``.

        All arguments are forwarded to ``vllm.LLM``: ``model_name`` as ``model`` and
        ``tokenizer_name`` as ``tokenizer``. The call always passes
        ``trust_remote_code=True`` plus any extra ``kwargs``.

        Raises:
            ValueError: if ``kwargs`` requests ``data_parallel_size > 1``.
        """
        data_parallel_size = int(kwargs.get("data_parallel_size", 1))
        if data_parallel_size > 1:
            raise ValueError("VllmRunner does not support `data_parallel_size > 1`; use `DPVllmRunner` instead.")

        self.model = LLM(
            model=model_name,
            runner=runner,
            convert=convert,
            tokenizer=tokenizer_name,
            tokenizer_mode=tokenizer_mode,
            trust_remote_code=True,
            dtype=dtype,
            swap_space=swap_space,
            enforce_eager=enforce_eager,
            disable_log_stats=disable_log_stats,
            tensor_parallel_size=tensor_parallel_size,
            max_model_len=max_model_len,
            block_size=block_size,
            enable_chunked_prefill=enable_chunked_prefill,
            quantization=quantization,
            **kwargs,
        )

    @staticmethod
    def _finalize_generate_outputs(req_outputs: list[RequestOutput]) -> list[tuple[list[list[int]], list[str]]]:
        """Convert ``RequestOutput`` objects into ``(token_id_lists, text_list)`` pairs.

        For each request, every sample contributes ``prompt_token_ids + sample.token_ids``
        and ``prompt + sample.text``. A missing prompt string or missing prompt token ids
        are treated as empty.
        """
        outputs: list[tuple[list[list[int]], list[str]]] = []
        for req_output in req_outputs:
            prompt_str = req_output.prompt
            prompt_ids = req_output.prompt_token_ids
            if prompt_ids is None:
                # NOTE(review): presumably None for prompt-embeds requests — confirm. Treat it as
                # empty, like the prompt string below, instead of failing on `None + list`.
                prompt_ids = []
            req_sample_output_ids: list[list[int]] = []
            req_sample_output_strs: list[str] = []
            for sample in req_output.outputs:
                output_str = sample.text
                output_ids = list(sample.token_ids)
                req_sample_output_ids.append(prompt_ids + output_ids)
                req_sample_output_strs.append((prompt_str or "") + output_str)
            outputs.append((req_sample_output_ids, req_sample_output_strs))
        return outputs

    def get_inputs(
        self,
        prompts: list[str] | list[torch.Tensor] | list[list[int]],
        images: PromptImageInput | None = None,
        videos: PromptVideoInput | None = None,
        audios: PromptAudioInput | None = None,
    ) -> list[TextPrompt]:
        """Build one ``TextPrompt`` per prompt, attaching per-prompt multimodal data.

        Prompts are interpreted by type:
        - ``str`` becomes ``prompt``;
        - a ``list`` (of token ids) becomes ``prompt_token_ids``;
        - anything else (e.g. a tensor) becomes ``prompt_embeds``.

        ``None`` entries in the multimodal lists are skipped.

        Raises:
            ValueError: if a non-None multimodal list's length differs from ``len(prompts)``.
        """
        if any(x is not None and len(x) != len(prompts) for x in [images, videos, audios]):
            raise ValueError("All non-None multimodal inputs must have the same length as prompts")

        inputs = []
        for i, prompt in enumerate(prompts):
            multi_modal_data = {}
            if images is not None and (image := images[i]) is not None:
                multi_modal_data["image"] = image
            if videos is not None and (video := videos[i]) is not None:
                multi_modal_data["video"] = video  # type: ignore
            if audios is not None and (audio := audios[i]) is not None:
                multi_modal_data["audio"] = audio  # type: ignore

            text_prompt_kwargs: dict[str, Any] = {"multi_modal_data": multi_modal_data or None}
            if isinstance(prompt, str):
                text_prompt_kwargs["prompt"] = prompt
            elif isinstance(prompt, list):
                text_prompt_kwargs["prompt_token_ids"] = prompt
            else:
                text_prompt_kwargs["prompt_embeds"] = prompt

            inputs.append(TextPrompt(**text_prompt_kwargs))

        return inputs

    def generate(
        self,
        prompts: list[str] | list[torch.Tensor] | list[list[int]],
        sampling_params: SamplingParams,
        images: PromptImageInput | None = None,
        videos: PromptVideoInput | None = None,
        audios: PromptAudioInput | None = None,
        **kwargs: Any,
    ) -> list[tuple[list[list[int]], list[str]]]:
        """Generate for each prompt; return per-request ``(token_id_lists, text_list)`` pairs."""
        inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios)
        req_outputs = self.model.generate(inputs, sampling_params=sampling_params, **kwargs)
        return self._finalize_generate_outputs(req_outputs)

    @staticmethod
    def _final_steps_generate_w_logprobs(
        req_outputs: list[RequestOutput],
    ) -> list[TokensTextLogprobsPromptLogprobs]:
        """Return ``(token_ids, text, logprobs, prompt_logprobs)`` per request.

        Only the last sample of each request is reported.
        """
        outputs: list[TokensTextLogprobsPromptLogprobs] = []
        for req_output in req_outputs:
            assert len(req_output.outputs) > 0
            for sample in req_output.outputs:
                output_str = sample.text
                output_ids = list(sample.token_ids)
                output_logprobs = sample.logprobs
            outputs.append((output_ids, output_str, output_logprobs, req_output.prompt_logprobs))
        return outputs

    def generate_w_logprobs(
        self,
        prompts: list[str],
        sampling_params: SamplingParams,
        images: PromptImageInput | None = None,
        audios: PromptAudioInput | None = None,
        videos: PromptVideoInput | None = None,
        **kwargs: Any,
    ) -> list[TokensTextLogprobs] | list[TokensTextLogprobsPromptLogprobs]:
        """Generate with logprobs.

        The trailing ``prompt_logprobs`` element is dropped when
        ``sampling_params.prompt_logprobs`` is None.
        """
        inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios)

        req_outputs = self.model.generate(inputs, sampling_params=sampling_params, **kwargs)

        toks_str_logsprobs_prompt_logprobs = self._final_steps_generate_w_logprobs(req_outputs)
        # Omit prompt logprobs if not required by sampling params
        return (
            [x[0:-1] for x in toks_str_logsprobs_prompt_logprobs]
            if sampling_params.prompt_logprobs is None
            else toks_str_logsprobs_prompt_logprobs
        )

    def generate_greedy(
        self,
        prompts: list[str] | list[torch.Tensor] | list[list[int]],
        max_tokens: int,
        images: PromptImageInput | None = None,
        videos: PromptVideoInput | None = None,
        audios: PromptAudioInput | None = None,
        **kwargs: Any,
    ) -> list[tuple[list[int], str]]:
        """Greedy (temperature 0) generation; returns the first sample's ``(token_ids, text)`` per prompt."""
        greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
        outputs = self.generate(prompts, greedy_params, images=images, videos=videos, audios=audios, **kwargs)
        return [(output_ids[0], output_str[0]) for output_ids, output_str in outputs]

    def generate_greedy_logprobs(
        self,
        prompts: list[str],
        max_tokens: int,
        num_logprobs: int | None,
        num_prompt_logprobs: int | None = None,
        images: PromptImageInput | None = None,
        audios: PromptAudioInput | None = None,
        videos: PromptVideoInput | None = None,
        stop_token_ids: list[int] | None = None,
        stop: list[str] | None = None,
        **kwargs: Any,
    ) -> list[TokensTextLogprobs] | list[TokensTextLogprobsPromptLogprobs]:
        """Greedy (temperature 0) generation with ``num_logprobs`` / ``num_prompt_logprobs`` requested."""
        greedy_logprobs_params = SamplingParams(
            temperature=0.0,
            max_tokens=max_tokens,
            logprobs=num_logprobs,
            prompt_logprobs=num_prompt_logprobs,
            stop_token_ids=stop_token_ids,
            stop=stop,
        )

        return self.generate_w_logprobs(
            prompts, greedy_logprobs_params, images=images, audios=audios, videos=videos, **kwargs
        )

    def classify(self, prompts: list[str]) -> list[list[float]]:
        """Return ``outputs.probs`` from ``LLM.classify`` for each prompt."""
        req_outputs = self.model.classify(prompts)
        return [req_output.outputs.probs for req_output in req_outputs]

    def embed(
        self,
        prompts: list[str],
        images: PromptImageInput | None = None,
        videos: PromptVideoInput | None = None,
        audios: PromptAudioInput | None = None,
        *args,
        **kwargs,
    ) -> list[list[float]]:
        """Return ``outputs.embedding`` from ``LLM.embed`` for each prompt (extra args forwarded)."""
        inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios)

        req_outputs = self.model.embed(inputs, *args, **kwargs)
        return [req_output.outputs.embedding for req_output in req_outputs]

    def encode(self, prompts: list[str]) -> list[list[float]]:
        """Return ``outputs.data`` from ``LLM.encode`` for each prompt."""
        req_outputs = self.model.encode(prompts)
        return [req_output.outputs.data for req_output in req_outputs]

    def reward(self, prompts: list[str]) -> list[list[float]]:
        """Return ``outputs.data`` from ``LLM.reward`` for each prompt."""
        req_outputs = self.model.reward(prompts)
        return [req_output.outputs.data for req_output in req_outputs]

    def score(
        self,
        text_1: str | list[str],
        text_2: str | list[str],
        *args,
        **kwargs,
    ) -> list[float]:
        """Return ``outputs.score`` from ``LLM.score(text_1, text_2, ...)`` for each pair."""
        req_outputs = self.model.score(text_1, text_2, *args, **kwargs)
        return [req_output.outputs.score for req_output in req_outputs]

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        del self.model
        clear_ascend_config()
        cleanup_dist_env_and_memory()


class DPVllmRunner(VllmRunner):
    """``VllmRunner`` variant that shards every request across data-parallel worker processes.

    Each DP rank runs ``_run_vllm_runner_dp_worker`` in its own spawned process.
    This object only slices inputs, ships them over pipes, and merges the replies
    back into input order.

    A worker error, timeout or pipe failure shuts every worker down. After that,
    further non-empty requests raise ``RuntimeError``.
    """

    def __init__(
        self,
        model_name: str,
        runner: RunnerOption = "auto",
        convert: ConvertOption = "auto",
        tokenizer_name: str | None = None,
        tokenizer_mode: str = "auto",
        max_model_len: int | None = 1024,
        dtype: str = "auto",
        disable_log_stats: bool = True,
        tensor_parallel_size: int = 1,
        block_size: int = 16,
        enable_chunked_prefill: bool = True,
        swap_space: int = 4,
        enforce_eager: bool | None = False,
        quantization: str | None = None,
        data_parallel_size: int = 2,
        **kwargs,
    ) -> None:
        """Spawn ``data_parallel_size`` workers and wait until each reports ready.

        Optional ``dp_start_timeout`` / ``dp_request_timeout`` keyword arguments
        (seconds) override the start-up and per-worker reply timeouts. They are
        removed from ``kwargs``; everything else is forwarded to each worker's
        ``LLM``. ``data_parallel_size`` itself is not forwarded: workers learn their
        rank and size from ``VLLM_DP_*`` environment variables.

        Raises:
            ValueError: if ``data_parallel_size < 2``.
            TimeoutError / RuntimeError: if a worker fails to start.
        """
        if data_parallel_size < 2:
            raise ValueError("DPVllmRunner requires `data_parallel_size >= 2`")

        self._dp_size = data_parallel_size
        self._dp_parent_conns: list[Any] = []
        self._dp_processes: list[Any] = []
        self._dp_start_timeout = float(kwargs.pop("dp_start_timeout", _DP_RUNNER_START_TIMEOUT_SECONDS))
        self._dp_request_timeout = float(kwargs.pop("dp_request_timeout", _DP_RUNNER_REQUEST_TIMEOUT_SECONDS))

        llm_kwargs = dict(
            model=model_name,
            runner=runner,
            convert=convert,
            tokenizer=tokenizer_name,
            tokenizer_mode=tokenizer_mode,
            trust_remote_code=True,
            dtype=dtype,
            swap_space=swap_space,
            enforce_eager=enforce_eager,
            disable_log_stats=disable_log_stats,
            tensor_parallel_size=tensor_parallel_size,
            max_model_len=max_model_len,
            block_size=block_size,
            enable_chunked_prefill=enable_chunked_prefill,
            quantization=quantization,
            **kwargs,
        )

        cleanup_dist_env_and_memory()
        self._start_data_parallel_workers(llm_kwargs)

    @property
    def model(self) -> LLM:
        raise RuntimeError("Direct access to `runner.model` is not supported by `DPVllmRunner`.")

    def _start_data_parallel_workers(self, llm_kwargs: dict[str, Any]) -> None:
        """Spawn one worker per DP rank and block until every worker reports ``ready``.

        On any failure, all already-started workers are stopped before re-raising.
        """
        ctx = multiprocessing.get_context("spawn")
        master_port = get_open_port()

        try:
            for dp_rank in range(self._dp_size):
                parent_conn, child_conn = ctx.Pipe()
                proc = ctx.Process(
                    target=_run_vllm_runner_dp_worker,
                    args=(child_conn, llm_kwargs, dp_rank, self._dp_size, master_port),
                )
                proc.start()
                # The child owns its end now; closing ours lets recv() see EOF if it dies.
                child_conn.close()
                self._dp_parent_conns.append(parent_conn)
                self._dp_processes.append(proc)

            for rank, conn in enumerate(self._dp_parent_conns):
                if not conn.poll(self._dp_start_timeout):
                    raise TimeoutError(f"Timed out waiting for data parallel worker {rank} to start")
                message = conn.recv()
                if message["status"] != "ready":
                    raise RuntimeError(
                        f"Failed to start data parallel worker {rank}:\n{message.get('traceback', 'unknown error')}"
                    )
        except Exception:
            self._stop_data_parallel_workers()
            raise

    def _stop_data_parallel_workers(self) -> None:
        """Ask every worker to shut down, join (then kill) the processes, and close the pipes."""
        for conn in self._dp_parent_conns:
            with contextlib.suppress(Exception):
                conn.send({"command": "shutdown"})

        for proc in self._dp_processes:
            proc.join(timeout=_DP_RUNNER_SHUTDOWN_TIMEOUT_SECONDS)
            if proc.is_alive():
                proc.kill()
                proc.join(timeout=5)

        for conn in self._dp_parent_conns:
            with contextlib.suppress(Exception):
                conn.close()

        self._dp_parent_conns.clear()
        self._dp_processes.clear()

    def _ensure_workers_running(self) -> None:
        """Raise ``RuntimeError`` if the worker pool has already been shut down."""
        if not self._dp_parent_conns:
            raise RuntimeError(
                "DPVllmRunner data parallel workers are not running "
                "(they were shut down after an earlier failure or on exit)."
            )

    def _execute_sharded(self, command: str, total_items: int, make_payload) -> list[Any]:
        """Run one sharded ``command`` on every worker and merge the replies in input order.

        Args:
            command: Worker command name.
            total_items: Number of items in the full request.
            make_payload: Callable mapping a shard's item indices to the
                command-specific request fields for those items.

        Every request is built before anything is sent. An exception while
        slicing inputs therefore cannot leave some workers holding a request
        whose reply the next call would misread. Any send/receive failure,
        timeout, or worker error stops all workers before re-raising.
        """
        self._ensure_workers_running()
        shard_indices = _split_data_parallel_indices(total_items, self._dp_size)

        requests: list[dict[str, Any]] = []
        for indices in shard_indices:
            # A rank with an empty shard is still sent item 0 as a placeholder
            # (presumably so every DP rank takes part in the step); its reply
            # carries empty ``indices`` and is discarded when merging.
            requests.append({"command": command, "indices": indices, **make_payload(indices or [0])})

        try:
            for conn, request in zip(self._dp_parent_conns, requests):
                conn.send(request)

            shard_results: list[tuple[list[int], list[Any]]] = []
            for rank, conn in enumerate(self._dp_parent_conns):
                if not conn.poll(self._dp_request_timeout):
                    raise TimeoutError(f"Timed out waiting for data parallel worker {rank} to finish `{command}`")
                message = conn.recv()
                if message["status"] != "ok":
                    raise RuntimeError(
                        f"Data parallel worker {rank} failed during `{command}`:\n"
                        f"{message.get('traceback', 'unknown error')}"
                    )
                shard_results.append((message["indices"], message["result"]))
        except Exception:
            # Workers are now out of step with this process; tear them all down rather
            # than risk pairing a stale reply with a later request.
            self._stop_data_parallel_workers()
            raise

        return _merge_data_parallel_results(total_items, shard_results)

    def _dispatch_prompt_command(
        self,
        command: str,
        prompts: list[str] | list[torch.Tensor] | list[list[int]],
        *,
        images: PromptImageInput | None = None,
        videos: PromptVideoInput | None = None,
        audios: PromptAudioInput | None = None,
        **payload: Any,
    ) -> list[Any]:
        """Shard a command whose workers consume ``TextPrompt`` inputs (``generate*``, ``embed``).

        ``payload`` entries are added verbatim to every worker request.

        Raises:
            ValueError: if a non-None multimodal list's length differs from ``len(prompts)``
                (the same check ``VllmRunner.get_inputs`` performs).
        """
        if not prompts:
            return []
        # Validate against the full batch: per-shard slicing would otherwise either
        # raise IndexError or silently drop surplus multimodal items.
        if any(x is not None and len(x) != len(prompts) for x in (images, videos, audios)):
            raise ValueError("All non-None multimodal inputs must have the same length as prompts")

        def make_payload(worker_indices: list[int]) -> dict[str, Any]:
            return {
                "inputs": self.get_inputs(
                    _slice_list_inputs(prompts, worker_indices),
                    images=_slice_optional_inputs(images, worker_indices),
                    videos=_slice_optional_inputs(videos, worker_indices),
                    audios=_slice_optional_inputs(audios, worker_indices),
                ),
                **payload,
            }

        return self._execute_sharded(command, len(prompts), make_payload)

    def _dispatch_text_command(self, command: str, prompts: list[str]) -> list[Any]:
        """Shard a command whose workers consume raw ``prompts`` (``classify``/``encode``/``reward``)."""
        if not prompts:
            return []

        def make_payload(worker_indices: list[int]) -> dict[str, Any]:
            return {"prompts": _slice_list_inputs(prompts, worker_indices)}

        return self._execute_sharded(command, len(prompts), make_payload)

    def generate(
        self,
        prompts: list[str] | list[torch.Tensor] | list[list[int]],
        sampling_params: SamplingParams,
        images: PromptImageInput | None = None,
        videos: PromptVideoInput | None = None,
        audios: PromptAudioInput | None = None,
        **kwargs: Any,
    ) -> list[tuple[list[list[int]], list[str]]]:
        return self._dispatch_prompt_command(
            "generate",
            prompts,
            images=images,
            videos=videos,
            audios=audios,
            sampling_params=sampling_params,
            kwargs=kwargs,
        )

    def generate_w_logprobs(
        self,
        prompts: list[str],
        sampling_params: SamplingParams,
        images: PromptImageInput | None = None,
        audios: PromptAudioInput | None = None,
        videos: PromptVideoInput | None = None,
        **kwargs: Any,
    ) -> list[TokensTextLogprobs] | list[TokensTextLogprobsPromptLogprobs]:
        toks_str_logsprobs_prompt_logprobs = self._dispatch_prompt_command(
            "generate_w_logprobs",
            prompts,
            images=images,
            videos=videos,
            audios=audios,
            sampling_params=sampling_params,
            kwargs=kwargs,
        )
        # Omit prompt logprobs if not required by sampling params
        return (
            [x[0:-1] for x in toks_str_logsprobs_prompt_logprobs]
            if sampling_params.prompt_logprobs is None
            else toks_str_logsprobs_prompt_logprobs
        )

    def classify(self, prompts: list[str]) -> list[list[float]]:
        return self._dispatch_text_command("classify", prompts)

    def embed(
        self,
        prompts: list[str],
        images: PromptImageInput | None = None,
        videos: PromptVideoInput | None = None,
        audios: PromptAudioInput | None = None,
        *args,
        **kwargs,
    ) -> list[list[float]]:
        return self._dispatch_prompt_command(
            "embed",
            prompts,
            images=images,
            videos=videos,
            audios=audios,
            args=args,
            kwargs=kwargs,
        )

    def encode(self, prompts: list[str]) -> list[list[float]]:
        return self._dispatch_text_command("encode", prompts)

    def reward(self, prompts: list[str]) -> list[list[float]]:
        return self._dispatch_text_command("reward", prompts)

    def score(
        self,
        text_1: str | list[str],
        text_2: str | list[str],
        *args,
        **kwargs,
    ) -> list[float]:
        """Score text pairs after expanding them with ``_normalize_score_inputs`` and sharding the pairs."""
        normalized_text_1, normalized_text_2 = _normalize_score_inputs(text_1, text_2)
        if not normalized_text_1:
            return []

        def make_payload(worker_indices: list[int]) -> dict[str, Any]:
            return {
                "text_1": _slice_list_inputs(normalized_text_1, worker_indices),
                "text_2": _slice_list_inputs(normalized_text_2, worker_indices),
                "args": args,
                "kwargs": kwargs,
            }

        return self._execute_sharded("score", len(normalized_text_1), make_payload)

    def __exit__(self, exc_type, exc_value, traceback):
        self._stop_data_parallel_workers()
        clear_ascend_config()
        cleanup_dist_env_and_memory()


# Alternative, longer name for DPVllmRunner (the very same class object).
DataParallelVllmRunner = DPVllmRunner


class HfRunner:
    """Reference runner that executes a Hugging Face model locally for comparison against vLLM.

    Depending on the constructor flags, the model is an ``auto_cls`` model, a
    sentence-transformers ``SentenceTransformer``, or a ``CrossEncoder``.
    """

    def get_default_device(self):
        """Return ``"cpu"`` on CPU platforms, otherwise the current platform's device type."""
        return "cpu" if current_platform.is_cpu() else current_platform.device_type

    def wrap_device(self, x: _T, device: str | None = None) -> _T:
        """Move ``x`` to ``device`` (default ``self.device``), recursing into plain dicts.

        ``None`` and booleans are returned unchanged. Objects whose ``.device.type``
        already equals ``device`` are returned as is. Anything else must provide
        ``.to(device)``.
        """
        if x is None or isinstance(x, bool):
            return x

        if device is None:
            device = self.device

        if isinstance(x, dict):
            return {k: self.wrap_device(v, device) for k, v in x.items()}

        if hasattr(x, "device") and x.device.type == device:
            return x

        return x.to(device)

    def __init__(
        self,
        model_name: str,
        dtype: str = "auto",
        *,
        model_kwargs: dict[str, Any] | None = None,
        trust_remote_code: bool = True,
        is_sentence_transformer: bool = False,
        is_cross_encoder: bool = False,
        skip_tokenizer_init: bool = False,
        auto_cls: type[_BaseAutoModelClass] = AutoModelForCausalLM,
    ) -> None:
        """Load the config, model, tokenizer and processor for ``model_name``.

        Args:
            model_name: Model id or path (first passed through ``maybe_model_redirect``).
            dtype: Requested dtype, resolved with ``_get_and_verify_dtype``.
            model_kwargs: Extra model-construction kwargs. The caller's dict is copied,
                never mutated; ``torch_dtype`` defaults to the resolved dtype.
            trust_remote_code: Forwarded to every ``from_pretrained``/constructor call.
            is_sentence_transformer: Load a ``SentenceTransformer`` instead of ``auto_cls``.
            is_cross_encoder: Load a ``CrossEncoder`` instead of ``auto_cls``.
            skip_tokenizer_init: Use ``processor.tokenizer`` instead of loading ``AutoTokenizer``.
            auto_cls: Hugging Face auto class used for plain models.
        """
        model_name = maybe_model_redirect(model_name)
        self.model_name = model_name

        self.config = AutoConfig.from_pretrained(
            model_name,
            trust_remote_code=trust_remote_code,
        )
        self.device = self.get_default_device()
        self.dtype = torch_dtype = _get_and_verify_dtype(
            self.model_name,
            self.config,
            dtype=dtype,
            is_pooling_model=is_sentence_transformer or is_cross_encoder,
        )

        # Copy so setdefault() below never writes `torch_dtype` into a dict owned by the caller
        # (which could otherwise leak the dtype into a later HfRunner reusing that dict).
        model_kwargs = dict(model_kwargs) if model_kwargs is not None else {}
        model_kwargs.setdefault("torch_dtype", torch_dtype)

        if is_sentence_transformer:
            # Lazy init required for AMD CI
            from sentence_transformers import SentenceTransformer

            self.model = SentenceTransformer(
                model_name,
                device=self.device,
                model_kwargs=model_kwargs,
                trust_remote_code=trust_remote_code,
            )
        elif is_cross_encoder:
            # Lazy init required for AMD CI
            from sentence_transformers import CrossEncoder

            self.model = CrossEncoder(
                model_name,
                device=self.device,
                automodel_args=model_kwargs,
                trust_remote_code=trust_remote_code,
            )
        else:
            model = auto_cls.from_pretrained(
                model_name,
                trust_remote_code=trust_remote_code,
                **model_kwargs,
            )

            # in case some unquantized custom models are not in same dtype
            if getattr(model, "quantization_method", None) is None and any(
                p.dtype != self.dtype for p in model.parameters()
            ):
                model = model.to(dtype=self.dtype)

            # Only move single-device, non-bitsandbytes models; multi-device placements are left alone.
            if (
                getattr(model, "quantization_method", None) != "bitsandbytes"
                and len({p.device for p in model.parameters()}) < 2
            ):
                model = model.to(device=self.device)

            self.model = model

        if not skip_tokenizer_init:
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                torch_dtype=torch_dtype,
                trust_remote_code=trust_remote_code,
            )

        # don't put this import at the top level
        # it will call torch.cuda.device_count()
        from transformers import AutoProcessor  # noqa: F401

        self.processor = AutoProcessor.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            trust_remote_code=trust_remote_code,
        )
        if skip_tokenizer_init:
            self.tokenizer = self.processor.tokenizer

    def get_inputs(
        self,
        prompts: list[str],
        images: PromptImageInput | None = None,
        videos: PromptVideoInput | None = None,
        audios: PromptAudioInput | None = None,
    ) -> list[BatchFeature | BatchEncoding]:
        """Run the processor once per prompt, attaching that prompt's multimodal items.

        ``BatchFeature`` results are cast to ``self.dtype``.
        An audio entry of length 2 is unpacked as ``(audio, sampling_rate)``.
        """
        if images is not None:
            assert len(prompts) == len(images)

        if videos is not None:
            assert len(prompts) == len(videos)

        if audios is not None:
            assert len(prompts) == len(audios)

        all_inputs: list[BatchFeature | BatchEncoding] = []
        for i, prompt in enumerate(prompts):
            processor_kwargs: dict[str, Any] = {
                "text": prompt,
                "return_tensors": "pt",
            }
            if images is not None and (image := images[i]) is not None:
                processor_kwargs["images"] = image
            if videos is not None and (video := videos[i]) is not None:
                processor_kwargs["videos"] = video
            if audios is not None and (audio_inputs := audios[i]) is not None:
                # HACK - not all processors take sampling_rate; we should
                # clean this up in the future.
                if len(audio_inputs) == 2:
                    audio, sr = audio_inputs
                    processor_kwargs["audio"] = audio
                    processor_kwargs["sampling_rate"] = sr
                else:
                    processor_kwargs["audio"] = audio_inputs

            inputs = self.processor(**processor_kwargs)
            if isinstance(inputs, BatchFeature):
                inputs = inputs.to(dtype=self.dtype)

            all_inputs.append(inputs)

        return all_inputs

    def classify(self, prompts: list[str]) -> list[list[float]]:
        """Return per-prompt scores derived from the first row of ``output.logits``.

        The transform depends on ``config.problem_type``:
        - ``"regression"``: raw logits;
        - ``"multi_label_classification"``: sigmoid;
        - otherwise: softmax over the last dimension.
        """
        # output is final logits
        all_inputs = self.get_inputs(prompts)
        outputs = []
        problem_type = getattr(self.config, "problem_type", "")

        for inputs in all_inputs:
            output = self.model(**self.wrap_device(inputs))
            if problem_type == "regression":
                logits = output.logits[0].tolist()
            elif problem_type == "multi_label_classification":
                logits = output.logits.sigmoid()[0].tolist()
            else:
                logits = output.logits.softmax(dim=-1)[0].tolist()
            outputs.append(logits)

        return outputs

    def encode(self, prompts: list[str], *args, **kwargs) -> list[list[torch.Tensor]]:
        """Delegate to the wrapped model's ``encode`` (e.g. a ``SentenceTransformer``)."""
        return self.model.encode(prompts, *args, **kwargs)

    def predict(self, prompts: list[list[str]], *args, **kwargs) -> torch.Tensor:
        """Delegate to the wrapped model's ``predict`` (e.g. a ``CrossEncoder``), returning a tensor."""
        return self.model.predict(prompts, *args, convert_to_tensor=True, **kwargs)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        del self.model
        cleanup_dist_env_and_memory()


@pytest.fixture(scope="session")
def ilama_lora_files():
    """Session-scoped local path of the ``vllm-ascend/ilama-text2sql-spider`` LoRA adapter.

    Fetched with ModelScope's ``snapshot_download``. ``local_files_only`` follows
    ``huggingface_hub.constants.HF_HUB_OFFLINE``.
    """
    offline_only = huggingface_hub.constants.HF_HUB_OFFLINE
    return snapshot_download(repo_id="vllm-ascend/ilama-text2sql-spider", local_files_only=offline_only)


@pytest.fixture(scope="session")
def llama32_lora_files():
    """Session-scoped local path of the ``jeeejeee/llama32-3b-text2sql-spider`` LoRA adapter.

    Resolved via Hugging Face Hub's ``snapshot_download`` with ``local_files_only=True``.
    """
    from huggingface_hub import snapshot_download as hf_snapshot_download

    adapter_repo = "jeeejeee/llama32-3b-text2sql-spider"
    return hf_snapshot_download(repo_id=adapter_repo, local_files_only=True)


@pytest.fixture(scope="session")
def qwen35_text_lora_files():
    """Session-scoped local path of the ``vllm-ascend/qwen35-4b-text-only-sql-lora`` adapter (ModelScope)."""
    adapter_repo = "vllm-ascend/qwen35-4b-text-only-sql-lora"
    return snapshot_download(repo_id=adapter_repo)


def qwen_prompt(questions: list[str]) -> list[str]:
    """Wrap each question in a Qwen chat template containing one image placeholder.

    Every prompt consists of three parts:
    - a fixed system message;
    - a user turn with ``<|image_pad|>`` between vision markers, followed by the question;
    - an open assistant turn.
    """
    template = (
        "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
        "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
        "{}<|im_end|>\n<|im_start|>assistant\n"
    )
    return [template.format(question) for question in questions]


def hunyuan_prompt(questions: list[str]) -> list[str]:
    """Build HunyuanOCR prompts.

    Each prompt is the begin-of-sentence token, three image placeholder tokens,
    the question, and the user token.
    """
    image_placeholder = "<|hy_place▁holder▁no▁100|><|hy_place▁holder▁no▁102|><|hy_place▁holder▁no▁101|>"  # noqa: E501

    def wrap(question: str) -> str:
        return f"<|hy_begin▁of▁sentence|>{image_placeholder}{question}<|hy_User|>"

    return list(map(wrap, questions))


# Vision-language test configurations, one per parametrization of the `vl_config` fixture.
# Each entry maps a config name to:
#   "model": model repository id,
#   "prompt_fn": callable turning a list of questions into model-specific prompt strings,
#   "mm_processor_kwargs": extra multimodal-processor kwargs (presumably forwarded to the
#       model's processor by the consuming tests — confirm at call sites).
# An entry may also define "skip"; `vl_config` then skips that parametrization with it as reason.
PROMPT_CONFIGS = {
    "qwen-vl": {
        "model": "Qwen/Qwen3-VL-8B-Instruct",
        "prompt_fn": qwen_prompt,
        "mm_processor_kwargs": {
            "min_pixels": 28 * 28,
            "max_pixels": 1280 * 28 * 28,
            "fps": 1,
        },
    },
    "hunyuan-vl": {
        "model": "Tencent-Hunyuan/HunyuanOCR",
        "prompt_fn": hunyuan_prompt,
        "mm_processor_kwargs": {},
    },
}


@pytest.fixture(params=PROMPT_CONFIGS.keys())
def vl_config(request):
    """Return the ``PROMPT_CONFIGS`` entry for this parametrization, skipping entries with a ``"skip"`` key."""
    selected = PROMPT_CONFIGS[request.param]
    if "skip" not in selected:
        return selected
    pytest.skip(selected["skip"])