"""Unit tests for tensor_cast.core.user_config."""

from dataclasses import fields

import pytest

from tensor_cast.core.quantization.datatypes import QuantizeAttentionAction, QuantizeLinearAction
from tensor_cast.core.user_config import UserInputConfig
from tensor_cast.model_config import WordEmbeddingTPMode


class TestUserInputConfigPrintInfo:
    """Tests for ``UserInputConfig._print_info`` and ``UserInputConfig.get_quant_config``."""

    def test_print_info_reports_mxfp4_and_backbone_override(self, capsys):
        """MXFP4 linear quant plus an FP8 backbone override are all printed."""
        cfg = UserInputConfig(
            model_id="deepseek-ai/DeepSeek-V4",
            quantize_linear_action=QuantizeLinearAction.MXFP4,
            quantize_backbone_linear_action=QuantizeLinearAction.FP8,
            mxfp4_group_size=32,
            quantize_attention_action=QuantizeAttentionAction.FP8,
        )

        cfg._print_info()
        printed = capsys.readouterr().out

        # Each summary line must appear somewhere in the captured stdout.
        expected_lines = (
            "Quantization Linear: MXFP4",
            "MXFP4 group size: 32",
            "Quantization Backbone Linear (override): FP8",
            "Quantization Attention: FP8",
        )
        for line in expected_lines:
            assert line in printed

    def test_print_info_reports_disabled_quantization(self, capsys):
        """With every quantization action disabled, no backbone override line is printed."""
        cfg = UserInputConfig(
            model_id="test/model",
            quantize_linear_action=QuantizeLinearAction.DISABLED,
            quantize_backbone_linear_action=QuantizeLinearAction.DISABLED,
            quantize_attention_action=QuantizeAttentionAction.DISABLED,
        )

        cfg._print_info()
        printed = capsys.readouterr().out

        for line in ("Quantization Linear: Disabled", "Quantization Attention: Disabled"):
            assert line in printed
        assert "Quantization Backbone Linear" not in printed

    def test_get_quant_config_mxfp4_experts_fp8_backbone(self):
        """Expert projections resolve to MXFP4 and backbone projections resolve to FP8."""
        cfg = UserInputConfig(
            quantize_linear_action=QuantizeLinearAction.MXFP4,
            quantize_backbone_linear_action=QuantizeLinearAction.FP8,
            mxfp4_group_size=32,
        )

        quant_cfg = cfg.get_quant_config()

        # Imported locally (after building the config), as in the original test.
        from tensor_cast.quantize_utils import LinearQuantType, QuantGranularity, get_quant_config

        # Resolve one attention projection and one MoE expert projection by module name.
        q_proj_cfg = get_quant_config("model.layers.0.self_attn.q_proj", quant_cfg, "default_dit")
        expert_up_cfg = get_quant_config("model.layers.0.mlp.experts.1.up_proj", quant_cfg, "default_dit")

        assert q_proj_cfg.quant_type == LinearQuantType.FP8
        assert expert_up_cfg.quant_type == LinearQuantType.MXFP4
        assert expert_up_cfg.weight_group_size == 32
        assert expert_up_cfg.weight_quant_granularity == QuantGranularity.PER_GROUP


class TestUserInputConfigWordEmbeddingTp:
    """Tests for the nullable ``word_embedding_tp`` option on ``UserInputConfig``."""

    def test_word_embedding_tp_is_single_nullable_mode(self):
        """A string mode is parsed to the enum, and the legacy ``*_mode`` fields no longer exist."""
        cfg = UserInputConfig(word_embedding_tp="row")

        assert cfg.word_embedding_tp == WordEmbeddingTPMode.row

        # The retired field names are built by concatenation. This is presumably
        # so text searches for them don't match this file; confirm with the author.
        legacy_user_field = "word_embedding_tp" + "_mode"
        legacy_parallel_field = "embedding_parallel" + "_mode"
        user_field_names = {f.name for f in fields(UserInputConfig)}
        assert legacy_user_field not in user_field_names

        parallel = cfg.get_parallel_config()
        assert parallel.embedding_parallel == WordEmbeddingTPMode.row
        parallel_field_names = {f.name for f in fields(type(parallel))}
        assert legacy_parallel_field not in parallel_field_names

    def test_legacy_bool_word_embedding_tp_is_still_normalized(self):
        """``True`` maps to the ``col`` mode and ``False`` maps to ``None``."""
        on_cfg = UserInputConfig(word_embedding_tp=True)
        off_cfg = UserInputConfig(word_embedding_tp=False)

        assert on_cfg.word_embedding_tp == WordEmbeddingTPMode.col
        assert on_cfg.get_parallel_config().embedding_parallel == WordEmbeddingTPMode.col
        assert off_cfg.word_embedding_tp is None
        assert off_cfg.get_parallel_config().embedding_parallel is None

    def test_word_embedding_tp_invalid_value_raises(self):
        """An unknown string or a non-string, non-bool value raises ``ValueError``."""
        for bad_value in ("invalid", 123):
            with pytest.raises(ValueError, match="word_embedding_tp must be one of"):
                UserInputConfig(word_embedding_tp=bad_value)