import unittest

import pytest
import torch
from tensor_cast.core.input_generator import generate_inputs
from tensor_cast.core.model_runner import ModelRunner, ModelRunnerMetrics
from tensor_cast.core.quantization.datatypes import QuantizeLinearAction
from tensor_cast.core.user_config import UserInputConfig
from tensor_cast.patch_torch import patch_torch


class TestMaskedScatterMetaSafe(unittest.TestCase):
    """Checks ``Tensor.masked_scatter`` on meta tensors while ``patch_torch()`` is active."""

    def test_masked_scatter_on_meta_is_shape_safe(self):
        """An all-False mask with an empty source keeps the input's metadata.

        The result must keep the input's shape, dtype and ``meta`` device.
        """
        grid = (4, 8)
        with patch_torch():
            half = torch.float16
            target = torch.empty(grid, dtype=half, device="meta")
            # All-zero boolean mask: no positions are selected for scattering.
            no_positions = torch.zeros(grid, dtype=torch.bool, device="meta")
            # Zero-length source, matching the empty selection.
            empty_source = torch.empty((0,), dtype=half, device="meta")

            scattered = target.masked_scatter(no_positions, empty_source)

            self.assertEqual(tuple(scattered.shape), grid)
            self.assertEqual(scattered.dtype, half)
            self.assertEqual(scattered.device.type, "meta")


@pytest.mark.nightly
class TestVLCompilePrefillNightly(unittest.TestCase):
    """Nightly end-to-end run of a compiled vision-language model prefill."""

    def test_glm45v_prefill_with_compile(self):
        """Run one GLM-4.5V query with a single 1080x1920 image under compile.

        Graph breaks are disallowed and linear quantization is disabled.
        """
        config = UserInputConfig(
            device="TEST_DEVICE",
            model_id="zai-org/GLM-4.5V",
            num_queries=1,
            query_len=30,
            context_length=0,
            image_batch_size=1,
            image_height=1080,
            image_width=1920,
            do_compile=True,
            allow_graph_break=False,
            quantize_linear_action=QuantizeLinearAction.DISABLED,
        )
        runner = ModelRunner(config)
        self.assertTrue(runner.model.is_vl_model)

        metrics = runner.run_inference(generate_inputs_func=generate_inputs)
        # The metric checks below apply only when run_inference yields
        # ModelRunnerMetrics.
        # NOTE(review): any other return type passes this test without further
        # checks — confirm that is intended.
        if not isinstance(metrics, ModelRunnerMetrics):
            return

        elapsed = metrics.execution_time_s
        if isinstance(elapsed, dict):
            # Per-key timings: only the first entry is checked.
            elapsed = next(iter(elapsed.values()))
        self.assertGreaterEqual(elapsed, 0.0)
        self.assertIn("Total time for analytic", metrics.table_result)


if __name__ == "__main__":
    # Allow running this file directly with Python through unittest's default runner.
    unittest.main()