#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py
#
"""Compare the short outputs of HF and vLLM when using greedy sampling.

Run `pytest tests/test_offline_inference.py`.
"""

import os
from unittest.mock import patch

from vllm import SamplingParams
from vllm.assets.audio import AudioAsset
from vllm.assets.image import ImageAsset

from tests.e2e.conftest import VllmRunner


@patch.dict(os.environ, {"VLLM_WORKER_MULTIPROC_METHOD": "spawn"})
def test_multimodal_vl(vl_config):
    """Greedy-decode several questions about one image with the configured VL model.

    ``vl_config`` is a fixture defined outside this file. This test reads its
    ``"model"``, ``"prompt_fn"`` and ``"mm_processor_kwargs"`` entries.
    ``prompt_fn`` receives the list of questions and returns the prompts that
    are sent to the model.
    """
    questions = [
        "What is the content of this image?",
        "Describe the content of this image in detail.",
        "What's in the image?",
        "Where is this image taken?",
    ]
    cherry_blossom = ImageAsset("cherry_blossom").pil_image.convert("RGB")
    # Every question is asked about the same picture, so supply one copy per question.
    image_batch = [cherry_blossom for _ in questions]
    model_prompts = vl_config["prompt_fn"](questions)

    runner_options = {
        "mm_processor_kwargs": vl_config["mm_processor_kwargs"],
        "max_model_len": 8192,
        "cudagraph_capture_sizes": [1, 2, 4, 8],
        "limit_mm_per_prompt": {"image": 1},
    }
    with VllmRunner(vl_config["model"], **runner_options) as runner:
        results = runner.generate_greedy(
            prompts=model_prompts,
            images=image_batch,
            max_tokens=64,
        )

        assert len(results) == len(model_prompts)

        # Each result is unpacked as a 2-tuple; only the text element is checked.
        for _, text in results:
            assert text, "Generated output should not be empty."


@patch.dict(os.environ, {"VLLM_WORKER_MULTIPROC_METHOD": "spawn"})
def test_multimodal_vl_language_model_only():
    """Run text-only greedy generation on a VL checkpoint with ``language_model_only=True``.

    Uses plain text prompts (no images) against Qwen3-VL-8B-Instruct. It asserts
    one output per prompt and a non-empty text element in each, matching the
    checks made in ``test_multimodal_vl``.
    """
    example_prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    max_tokens = 5
    with VllmRunner(
        "Qwen/Qwen3-VL-8B-Instruct",
        max_model_len=4096,
        cudagraph_capture_sizes=[1, 2, 4, 8],
        gpu_memory_utilization=0.90,
        language_model_only=True,
    ) as vllm_model:
        outputs = vllm_model.generate_greedy(example_prompts, max_tokens)

        # Without these checks the test only verified that generation did not raise.
        assert len(outputs) == len(example_prompts)

        # Each output is unpacked as a 2-tuple; only the text element is checked.
        for _, output_str in outputs:
            assert output_str, "Generated output should not be empty."


@patch.dict(os.environ, {"VLLM_WORKER_MULTIPROC_METHOD": "spawn"})
def test_multimodal_audio():
    """Generate text for one chat prompt that references two audio clips (Qwen2-Audio).

    Decoding is sampled (temperature 0.2, at most 10 tokens). The test only
    asserts that some outputs come back, not what they say.
    """
    clip_names = ("mary_had_lamb", "winning_call")

    # One numbered placeholder group per clip: "Audio 1: ...", "Audio 2: ...".
    placeholders = "".join(
        f"Audio {number}: <|audio_bos|><|AUDIO|><|audio_eos|>\n"
        for number in range(1, len(clip_names) + 1)
    )
    question = "What sport and what nursery rhyme are referenced?"
    chat_prompt = (
        "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
        "<|im_start|>user\n"
        f"{placeholders}{question}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )

    clips = [AudioAsset(name).audio_and_sample_rate for name in clip_names]
    request = {"prompt": chat_prompt, "multi_modal_data": {"audio": clips}}

    params = SamplingParams(temperature=0.2, max_tokens=10, stop_token_ids=None)

    # NOTE(review): `request` is a single prompt dict, not a list. Confirm that
    # VllmRunner.generate (tests/e2e/conftest.py) forwards it as one prompt. If
    # it iterates its prompts argument, the dict keys would be sent as text
    # prompts and the audio clips would never reach the model, while the
    # assertions below would still pass.
    with VllmRunner(
        "Qwen/Qwen2-Audio-7B-Instruct",
        max_model_len=4096,
        max_num_seqs=5,
        dtype="bfloat16",
        # Allow as many audio items per prompt as there are clips.
        limit_mm_per_prompt={"audio": len(clip_names)},
        cudagraph_capture_sizes=[1, 2, 4, 8],
        gpu_memory_utilization=0.9,
    ) as runner:
        outputs = runner.generate(request, sampling_params=params)

        assert outputs is not None, "Generated outputs should not be None."
        assert len(outputs) > 0, "Generated outputs should not be empty."