from __future__ import annotations

import itertools
import sys
from typing import Any

from tensor_cast.model_config import RemoteSource

from .schemas import ExperimentTask
from .utils import (
    normalize_value,
    parse_optional_number,
    parse_scalar_or_list,
    stable_hash,
)

# Canonical deployment-mode labels for the throughput-optimizer task builder.
OPT_DEPLOY_PD_MIXED = "PD Aggregated"
OPT_DEPLOY_PD_SPLIT = "PD Disaggregated"
OPT_DEPLOY_PD_RATIO = "PD Ratio"
# Remote source used when a form does not provide one. The task builders in
# this module leave ``--remote-source`` off the command line when the
# effective value equals this default.
DEFAULT_REMOTE_SOURCE = RemoteSource.huggingface.value

# Accepted spellings of a deployment mode, keyed by the spelling and mapped to
# one of the canonical labels above. The empty string maps to the aggregated
# mode. Lower-case variants are listed explicitly, and the two escaped
# non-ASCII keys are Chinese-language labels for the aggregated and
# disaggregated modes respectively. Each canonical label also maps to itself.
OPT_DEPLOY_MODE_ALIASES = {
    "": OPT_DEPLOY_PD_MIXED,
    "Aggregation": OPT_DEPLOY_PD_MIXED,
    "aggregation": OPT_DEPLOY_PD_MIXED,
    "PD Mixed": OPT_DEPLOY_PD_MIXED,
    "pd mixed": OPT_DEPLOY_PD_MIXED,
    "PD Aggregated": OPT_DEPLOY_PD_MIXED,
    "pd aggregated": OPT_DEPLOY_PD_MIXED,
    "PD \u6df7\u90e8": OPT_DEPLOY_PD_MIXED,
    OPT_DEPLOY_PD_MIXED: OPT_DEPLOY_PD_MIXED,
    "Disagg": OPT_DEPLOY_PD_SPLIT,
    "disagg": OPT_DEPLOY_PD_SPLIT,
    "PD Split": OPT_DEPLOY_PD_SPLIT,
    "pd split": OPT_DEPLOY_PD_SPLIT,
    "PD Disaggregated": OPT_DEPLOY_PD_SPLIT,
    "pd disaggregated": OPT_DEPLOY_PD_SPLIT,
    "PD \u5206\u79bb": OPT_DEPLOY_PD_SPLIT,
    OPT_DEPLOY_PD_SPLIT: OPT_DEPLOY_PD_SPLIT,
    OPT_DEPLOY_PD_RATIO: OPT_DEPLOY_PD_RATIO,
}


def _normalize_optimizer_deployment_mode(mode: Any) -> str:
    """Map a user-supplied deployment mode onto its canonical label.

    The value is stringified (falsy values become ``""``) and stripped of
    surrounding whitespace. An exact match in ``OPT_DEPLOY_MODE_ALIASES`` wins;
    otherwise the aliases are compared case-insensitively so that spellings
    such as ``"pd ratio"`` or ``"DISAGG"`` still resolve.

    Args:
        mode: Raw mode value from the form; any type is accepted.

    Returns:
        The canonical label for a recognised spelling, or the stripped text
        unchanged when no alias matches.
    """
    text = str(mode or "").strip()
    exact = OPT_DEPLOY_MODE_ALIASES.get(text)
    if exact is not None:
        return exact
    # Fall back to a case-insensitive comparison; the table only lists a few
    # hand-written lower-case variants.
    folded = text.casefold()
    for alias, canonical in OPT_DEPLOY_MODE_ALIASES.items():
        if alias.casefold() == folded:
            return canonical
    return text


def _device_matrix(primary_device: str, competitor_devices: list[str] | None) -> list[str]:
    """Return the primary device followed by its competitors, without repeats.

    Falsy entries (``None``, ``""``) are dropped and the first occurrence of
    each device keeps its position.

    Args:
        primary_device: Device that should come first in the result.
        competitor_devices: Additional devices to compare against. ``None`` is
            treated as an empty list, which is what ``form.get(key, [])``
            yields when the key is present but null.

    Returns:
        The ordered, de-duplicated device list. When every candidate is falsy
        the result is ``[primary_device]`` so callers always get one entry.
    """
    candidates = [primary_device, *(competitor_devices or ())]
    # dict preserves insertion order, giving an order-stable de-duplication.
    unique = list(dict.fromkeys(device for device in candidates if device))
    return unique or [primary_device]


def _base_cmd(module_name: str) -> list[str]:
    """Return the argv prefix that runs *module_name* via ``-m`` with the current interpreter."""
    argv = [sys.executable]
    argv.extend(("-m", module_name))
    return argv


def _as_bool(value: Any) -> bool:
    """Coerce *value* to ``bool`` using plain Python truthiness.

    Note that any non-empty string, including ``"false"``, is truthy.
    """
    return True if value else False


def _optional_str(value: Any) -> str | None:
    """Return ``str(value)``, or ``None`` when *value* is ``None`` or the empty string."""
    return None if value in (None, "") else str(value)


def _performance_models(value: Any) -> list[str]:
    """Normalise the performance-model selection to a list of names.

    A non-empty string is delegated to ``parse_scalar_or_list``; any other
    truthy value is iterated and each item stringified. Falsy input yields the
    single default ``"analytic"``.
    """
    if isinstance(value, str) and value:
        return parse_scalar_or_list(value, str)
    if not value:
        return ["analytic"]
    return list(map(str, value))


def build_text_generate_tasks(form: dict[str, Any]) -> list[ExperimentTask]:
    """Expand a text-generation form into one ``ExperimentTask`` per sweep point.

    The cartesian product of devices (primary plus competitors), query counts,
    TP sizes and linear/attention quantization actions is enumerated. A truthy
    ``*_sweep`` entry takes precedence over its single-value counterpart. Each
    combination yields a params dict, a ``python -m cli.inference.text_generate``
    argv list, a hash computed from the params, and a display label.

    Optional arguments are only appended to the command when they differ from
    the default used here (presumably mirroring the CLI defaults; confirm
    against the CLI parser).

    Args:
        form: Raw form values. ``device``, ``model_id``, ``num_devices`` and
            ``query_length`` are always required. ``num_queries``,
            ``quantize_linear_action`` and ``quantize_attention_action`` are
            required unless the matching sweep entry is truthy. All other keys
            are optional. ``mtp_acceptance_rate`` may be a comma-separated
            string, a single number, or a list/tuple of rates.

    Returns:
        The tasks in ``itertools.product`` order.

    Raises:
        KeyError: If a required form key is missing.
    """
    devices = _device_matrix(form["device"], form.get("competitor_devices", []))
    num_queries_list = parse_scalar_or_list(form.get("num_queries_sweep") or form["num_queries"], int)
    tp_size_list = parse_scalar_or_list(form.get("tp_sweep") or form.get("tp_size", 1), int)
    quant_linear_list = parse_scalar_or_list(form.get("quant_linear_sweep") or form["quantize_linear_action"], str)
    quant_attention_list = parse_scalar_or_list(
        form.get("quant_attention_sweep") or form["quantize_attention_action"], str
    )
    decode_values = [bool(form.get("decode", False))]

    # (CLI flag, params key) tables. Tuple order is the order on the command line.
    leading_switches = (
        ("--disable-repetition", "disable_repetition"),
        ("--compile", "compile"),
        ("--enable-multistream", "enable_multistream"),
        ("--compile-allow-graph-break", "compile_allow_graph_break"),
    )
    # Emitted only when the parsed value is not None.
    parallel_size_options = (
        ("--dp-size", "dp_size"),
        ("--o-proj-tp-size", "o_proj_tp_size"),
        ("--o-proj-dp-size", "o_proj_dp_size"),
        ("--mlp-tp-size", "mlp_tp_size"),
        ("--mlp-dp-size", "mlp_dp_size"),
        ("--lmhead-tp-size", "lmhead_tp_size"),
        ("--lmhead-dp-size", "lmhead_dp_size"),
        ("--moe-tp-size", "moe_tp_size"),
    )
    expert_switches = (
        ("--enable-redundant-experts", "enable_redundant_experts"),
        ("--enable-external-shared-experts", "enable_external_shared_experts"),
        ("--host-external-shared-experts", "host_external_shared_experts"),
        ("--enable-sequence-parallel", "enable_sequence_parallel"),
        ("--enable-shared-expert-tp", "enable_shared_expert_tp"),
        ("--enable-dispatch-ffn-combine", "enable_dispatch_ffn_combine"),
    )
    image_options = (
        ("--image-batch-size", "image_batch_size"),
        ("--image-height", "image_height"),
        ("--image-width", "image_width"),
    )

    tasks: list[ExperimentTask] = []
    for device, num_queries, tp_size, qlin, qattn, decode in itertools.product(
        devices,
        num_queries_list,
        tp_size_list,
        quant_linear_list,
        quant_attention_list,
        decode_values,
    ):
        # Key order is kept stable because the dict feeds the task hash below.
        params = {
            "model_id": form["model_id"],
            "device": device,
            "num_devices": int(form["num_devices"]),
            "num_queries": int(num_queries),
            "query_length": int(form["query_length"]),
            "context_length": int(form.get("context_length", 0) or 0),
            "decode": decode,
            "num_mtp_tokens": int(form.get("num_mtp_tokens", 0) or 0),
            "mtp_acceptance_rate": form.get("mtp_acceptance_rate", ""),
            "compile": bool(form.get("compile", False)),
            "quantize_linear_action": qlin,
            "quantize_attention_action": qattn,
            "tp_size": int(tp_size),
            "dp_size": parse_optional_number(form.get("dp_size"), int),
            "ep_size": int(form.get("ep_size", 1) or 1),
            "image_batch_size": parse_optional_number(form.get("image_batch_size"), int),
            "image_height": parse_optional_number(form.get("image_height"), int),
            "image_width": parse_optional_number(form.get("image_width"), int),
            "prefix_cache_hit_rate": float(form.get("prefix_cache_hit_rate") or 0.0),
            "reserved_memory_gb": float(form.get("reserved_memory_gb") or 0.0),
            "log_level": str(form.get("log_level") or "error"),
            "enable_multistream": _as_bool(form.get("enable_multistream", True)),
            "compile_allow_graph_break": _as_bool(form.get("compile_allow_graph_break", False)),
            "disable_repetition": _as_bool(form.get("disable_repetition", False)),
            "quantize_lmhead": _as_bool(form.get("quantize_lmhead", False)),
            "mxfp4_group_size": int(form.get("mxfp4_group_size") or 32),
            "graph_log_url": _optional_str(form.get("graph_log_url")),
            "dump_input_shapes": _as_bool(form.get("dump_input_shapes", False)),
            "chrome_trace": _optional_str(form.get("chrome_trace")),
            "num_hidden_layers_override": int(form.get("num_hidden_layers_override") or 0),
            "o_proj_tp_size": parse_optional_number(form.get("o_proj_tp_size"), int),
            "o_proj_dp_size": parse_optional_number(form.get("o_proj_dp_size"), int),
            "mlp_tp_size": parse_optional_number(form.get("mlp_tp_size"), int),
            "mlp_dp_size": parse_optional_number(form.get("mlp_dp_size"), int),
            "lmhead_tp_size": parse_optional_number(form.get("lmhead_tp_size"), int),
            "lmhead_dp_size": parse_optional_number(form.get("lmhead_dp_size"), int),
            "moe_tp_size": parse_optional_number(form.get("moe_tp_size"), int),
            "moe_dp_size": int(form.get("moe_dp_size") or 1),
            "word_embedding_tp": _optional_str(form.get("word_embedding_tp")),
            "enable_redundant_experts": _as_bool(form.get("enable_redundant_experts", False)),
            "enable_external_shared_experts": _as_bool(form.get("enable_external_shared_experts", False)),
            "host_external_shared_experts": _as_bool(form.get("host_external_shared_experts", False)),
            "enable_sequence_parallel": _as_bool(form.get("enable_sequence_parallel", False)),
            "enable_shared_expert_tp": _as_bool(form.get("enable_shared_expert_tp", False)),
            "enable_dispatch_ffn_combine": _as_bool(form.get("enable_dispatch_ffn_combine", False)),
            "remote_source": str(form.get("remote_source") or DEFAULT_REMOTE_SOURCE),
            "performance_model": _performance_models(form.get("performance_model")),
            "profiling_database": _optional_str(form.get("profiling_database")),
            "export_empirical_metrics": _optional_str(form.get("export_empirical_metrics")),
        }

        # Mandatory arguments.
        cmd = _base_cmd("cli.inference.text_generate")
        cmd += [params["model_id"], "--device", device, "--num-devices", str(params["num_devices"])]
        cmd += ["--num-queries", str(params["num_queries"]), "--query-length", str(params["query_length"])]
        cmd += ["--context-length", str(params["context_length"])]

        if params["decode"]:
            cmd.append("--decode")
        if params["num_mtp_tokens"] > 0:
            cmd += ["--num-mtp-tokens", str(params["num_mtp_tokens"])]
            raw_rates = params["mtp_acceptance_rate"]
            if raw_rates:
                # The form may deliver a comma-separated string, a bare
                # number, or a sequence; reduce all of them to string tokens.
                if isinstance(raw_rates, (list, tuple)):
                    pieces = [str(rate) for rate in raw_rates]
                else:
                    pieces = str(raw_rates).split(",")
                rates = [piece.strip() for piece in pieces if piece.strip()]
                if rates:
                    cmd += ["--mtp-acceptance-rate", *rates]
        if params["prefix_cache_hit_rate"] > 0:
            cmd += ["--prefix-cache-hit-rate", str(params["prefix_cache_hit_rate"])]
        for switch, key in leading_switches:
            if params[key]:
                cmd.append(switch)

        cmd += ["--quantize-linear-action", qlin, "--quantize-attention-action", qattn]
        if params["quantize_lmhead"]:
            cmd.append("--quantize-lmhead")
        if qlin == "MXFP4" and params["mxfp4_group_size"] != 32:
            cmd += ["--mxfp4-group-size", str(params["mxfp4_group_size"])]

        cmd += ["--tp-size", str(params["tp_size"]), "--ep-size", str(params["ep_size"])]
        for flag, key in parallel_size_options:
            if params[key] is not None:
                cmd += [flag, str(params[key])]
        if params["moe_dp_size"] != 1:
            cmd += ["--moe-dp-size", str(params["moe_dp_size"])]
        if params["word_embedding_tp"]:
            cmd += ["--word-embedding-tp", params["word_embedding_tp"]]
        for switch, key in expert_switches:
            if params[key]:
                cmd.append(switch)
        for flag, key in image_options:
            if params[key] is not None:
                cmd += [flag, str(params[key])]

        if params["remote_source"] != DEFAULT_REMOTE_SOURCE:
            cmd += ["--remote-source", params["remote_source"]]
        if params["reserved_memory_gb"] != 0.0:
            cmd += ["--reserved-memory-gb", str(params["reserved_memory_gb"])]
        if params["log_level"] != "error":
            cmd += ["--log-level", params["log_level"]]
        if params["graph_log_url"]:
            cmd += ["--graph-log-url", params["graph_log_url"]]
        if params["dump_input_shapes"]:
            cmd.append("--dump-input-shapes")
        if params["chrome_trace"]:
            cmd += ["--chrome-trace", params["chrome_trace"]]
        if params["num_hidden_layers_override"] != 0:
            cmd += ["--num-hidden-layers-override", str(params["num_hidden_layers_override"])]
        if params["performance_model"] != ["analytic"]:
            # The flag is repeated once per selected model.
            for perf_model in params["performance_model"]:
                cmd += ["--performance-model", perf_model]
        if params["profiling_database"]:
            cmd += ["--profiling-database", params["profiling_database"]]
        if params["export_empirical_metrics"]:
            cmd += ["--export-empirical-metrics", params["export_empirical_metrics"]]

        thash = stable_hash({"sim_type": "text_generate", **normalize_value(params)})
        label = (
            f"{params['model_id']} | {device} | nq={params['num_queries']} | tp={params['tp_size']} | {qlin}/{qattn}"
        )
        tasks.append(ExperimentTask("text_generate", params, cmd, thash, label))
    return tasks


def build_video_generate_tasks(form: dict[str, Any]) -> list[ExperimentTask]:
    """Expand a video-generation form into one ``ExperimentTask`` per sweep point.

    Devices (primary plus competitors), linear quantization actions and
    Ulysses sizes are combined as a cartesian product; a truthy ``*_sweep``
    entry takes precedence over the single-value field. Each combination
    produces a params dict, a ``python -m cli.inference.video_generate`` argv
    list, a hash computed from the params, and a display label.

    Args:
        form: Raw form values. ``device``, ``model_id``, ``batch_size``,
            ``seq_len``, ``height``, ``width``, ``frame_num``, ``sample_step``
            and ``world_size`` are required; ``quantize_linear_action`` and
            ``ulysses_size`` are required unless their sweep entry is truthy.

    Returns:
        The tasks in ``itertools.product`` order.

    Raises:
        KeyError: If a required form key is missing.
    """
    device_list = _device_matrix(form["device"], form.get("competitor_devices", []))
    qlin_options = parse_scalar_or_list(form.get("quant_linear_sweep") or form["quantize_linear_action"], str)
    ulysses_options = parse_scalar_or_list(form.get("ulysses_sweep") or form["ulysses_size"], int)

    # Cache tuning options, only forwarded when the DiT cache is switched on.
    cache_options = (
        ("--cache-step-range", "cache_step_range"),
        ("--cache-step-interval", "cache_step_interval"),
        ("--cache-block-range", "cache_block_range"),
    )

    built: list[ExperimentTask] = []
    for device, qlin, ulysses in itertools.product(device_list, qlin_options, ulysses_options):
        # Keys are inserted in a fixed order because the dict feeds the task hash.
        params: dict[str, Any] = {
            "model_id": form["model_id"],
            "remote_source": str(form.get("remote_source") or DEFAULT_REMOTE_SOURCE),
            "device": device,
        }
        for key in ("batch_size", "seq_len", "height", "width", "frame_num", "sample_step"):
            params[key] = int(form[key])
        params["dtype"] = str(form.get("dtype") or "float16")
        params["quantize_linear_action"] = qlin
        params["world_size"] = int(form["world_size"])
        params["ulysses_size"] = int(ulysses)
        for key in ("use_cfg", "cfg_parallel", "dit_cache"):
            params[key] = bool(form.get(key, False))
        params["cache_step_range"] = form.get("cache_step_range") or None
        params["cache_step_interval"] = parse_optional_number(form.get("cache_step_interval"), int) or 1
        params["cache_block_range"] = form.get("cache_block_range") or None
        params["chrome_trace"] = _optional_str(form.get("chrome_trace"))
        params["log_level"] = str(form.get("log_level") or "info")

        # Arguments that are always present.
        cmd = [
            *_base_cmd("cli.inference.video_generate"),
            params["model_id"],
            *("--device", device),
            *("--batch-size", str(params["batch_size"])),
            *("--seq-len", str(params["seq_len"])),
            *("--height", str(params["height"])),
            *("--width", str(params["width"])),
            *("--frame-num", str(params["frame_num"])),
            *("--sample-step", str(params["sample_step"])),
            *("--dtype", params["dtype"]),
            *("--quantize-linear-action", qlin),
            *("--world-size", str(params["world_size"])),
            *("--ulysses-size", str(params["ulysses_size"])),
        ]

        # Arguments emitted only when enabled or different from the default.
        for switch, key in (("--use-cfg", "use_cfg"), ("--cfg-parallel", "cfg_parallel")):
            if params[key]:
                cmd.append(switch)
        if params["dit_cache"]:
            cmd.append("--dit-cache")
            for flag, key in cache_options:
                if params[key]:
                    cmd.extend((flag, str(params[key])))
        if params["chrome_trace"]:
            cmd.extend(("--chrome-trace", params["chrome_trace"]))
        if params["log_level"] != "info":
            cmd.extend(("--log-level", params["log_level"]))
        if params["remote_source"] != DEFAULT_REMOTE_SOURCE:
            cmd.extend(("--remote-source", params["remote_source"]))

        task_hash = stable_hash({"sim_type": "video_generate", **normalize_value(params)})
        caption = f"{params['model_id']} | {device} | usp={params['ulysses_size']} | {qlin}"
        built.append(ExperimentTask("video_generate", params, cmd, task_hash, caption))
    return built


def _mode_name(ttft, tpot):
    """Name the optimization mode implied by which latency limits are set.

    A limit counts as set when it is not ``None``. With neither limit set the
    mode is ``"offline"``.
    """
    names = {
        (False, False): "offline",
        (True, True): "ttft_tpot_constrained",
        (True, False): "ttft_constrained",
        (False, True): "tpot_constrained",
    }
    return names[(ttft is not None, tpot is not None)]


def build_optimizer_tasks(form: dict[str, Any]) -> list[ExperimentTask]:
    """Expand a throughput-optimizer form into one ``ExperimentTask`` per sweep point.

    Devices (primary plus competitors), linear/attention quantization actions
    and TPOT/TTFT limits are combined as a cartesian product; a truthy
    ``*_sweep`` entry takes precedence over the single-value field. Each
    combination yields a params dict, a
    ``python -m cli.inference.throughput_optimizer`` argv list, a hash
    computed from the params, and a display label.

    The deployment mode is normalised first. ``--disagg`` is passed when the
    mode is the disaggregated label. Prefill/decode ratio optimisation is
    enabled when the mode is the ratio label or the form sets
    ``enable_optimize_prefill_decode_ratio``; the per-instance device counts
    are discarded unless it is enabled. Optional arguments are appended only
    when they differ from the default used here (presumably mirroring the CLI
    defaults; confirm against the CLI parser).

    Args:
        form: Raw form values. ``device``, ``model_id``, ``num_devices``,
            ``input_length`` and ``output_length`` are required;
            ``quantize_linear_action`` and ``quantize_attention_action`` are
            required unless their sweep entry is truthy.
            ``mtp_acceptance_rate`` may be a comma-separated string, a single
            number, or a list/tuple of rates.

    Returns:
        The tasks in ``itertools.product`` order.

    Raises:
        KeyError: If a required form key is missing.
        ValueError: If a numeric field cannot be converted.
    """
    devices = _device_matrix(form["device"], form.get("competitor_devices", []))
    quant_linear_list = parse_scalar_or_list(form.get("quant_linear_sweep") or form["quantize_linear_action"], str)
    quant_attention_list = parse_scalar_or_list(
        form.get("quant_attention_sweep") or form["quantize_attention_action"], str
    )
    tpot_list = parse_scalar_or_list(form.get("tpot_sweep") or form.get("tpot_limits") or "None", str)
    ttft_list = parse_scalar_or_list(form.get("ttft_sweep") or form.get("ttft_limits") or "None", str)

    def _optional_int_list(key: str):
        """Parse ``form[key]`` as a list of ints, or return None when unset."""
        raw = form.get(key, "")
        return parse_scalar_or_list(raw, int) if raw else None

    tp_sizes = _optional_int_list("tp_sizes")
    ep_sizes = _optional_int_list("ep_sizes")
    moe_dp_sizes = _optional_int_list("moe_dp_sizes")
    batch_range = _optional_int_list("batch_range")

    concurrency_search_strategy = str(form.get("concurrency_search_strategy") or "exponential")

    jobs = int(form.get("jobs") or 8)
    deployment_mode = _normalize_optimizer_deployment_mode(form.get("deployment_mode"))
    disagg = deployment_mode == OPT_DEPLOY_PD_SPLIT
    enable_pd_ratio = deployment_mode == OPT_DEPLOY_PD_RATIO or bool(
        form.get("enable_optimize_prefill_decode_ratio", False)
    )
    compile_allow_graph_break = bool(form.get("compile_allow_graph_break", False))
    enable_multistream = bool(form.get("enable_multistream", True))
    mxfp4_group_size = int(form.get("mxfp4_group_size") or 32)
    prefix_cache_hit_rate = float(form.get("prefix_cache_hit_rate") or 0.0)

    prefill_devices_per_instance = parse_optional_number(form.get("prefill_devices_per_instance"), int)
    decode_devices_per_instance = parse_optional_number(form.get("decode_devices_per_instance"), int)
    if not enable_pd_ratio:
        # Per-instance device counts only apply to ratio optimisation.
        prefill_devices_per_instance = None
        decode_devices_per_instance = None

    # These values do not depend on the sweep point, so parse them once.
    num_mtp_tokens = int(form.get("num_mtp_tokens") or 0)
    raw_rates = form.get("mtp_acceptance_rate") or "0.9,0.6,0.4,0.2"
    # The form may deliver a comma-separated string, a bare number, or a
    # sequence; reduce all of them to string tokens before converting.
    if isinstance(raw_rates, (list, tuple)):
        rate_tokens = [str(rate) for rate in raw_rates]
    else:
        rate_tokens = str(raw_rates).split(",")
    mtp_acceptance_rate = [float(token.strip()) for token in rate_tokens if token.strip()]
    max_batched_tokens = int(form.get("max_batched_tokens") or 8192)

    list_options = (
        ("--tp-sizes", tp_sizes),
        ("--ep-sizes", ep_sizes),
        ("--moe-dp-sizes", moe_dp_sizes),
        ("--batch-range", batch_range),
    )
    image_options = (
        ("--image-batch-size", "image_batch_size"),
        ("--image-height", "image_height"),
        ("--image-width", "image_width"),
    )

    tasks: list[ExperimentTask] = []
    for device, qlin, qattn, tpot_raw, ttft_raw in itertools.product(
        devices, quant_linear_list, quant_attention_list, tpot_list, ttft_list
    ):
        tpot = parse_optional_number(tpot_raw, float)
        ttft = parse_optional_number(ttft_raw, float)

        # Key order is kept stable because the dict feeds the task hash below.
        params = {
            "model_id": form["model_id"],
            "device": device,
            "num_devices": int(form["num_devices"]),
            "input_length": int(form["input_length"]),
            "output_length": int(form["output_length"]),
            "compile": bool(form.get("compile", False)),
            "quantize_linear_action": qlin,
            "quantize_attention_action": qattn,
            "tpot_limits": tpot,
            "ttft_limits": ttft,
            "num_mtp_tokens": num_mtp_tokens,
            # Fresh list per task so tasks never share a mutable value.
            "mtp_acceptance_rate": list(mtp_acceptance_rate),
            "max_batched_tokens": max_batched_tokens,
            "image_batch_size": parse_optional_number(form.get("image_batch_size"), int),
            "image_height": parse_optional_number(form.get("image_height"), int),
            "image_width": parse_optional_number(form.get("image_width"), int),
            "optimization_mode": _mode_name(ttft, tpot),
            "deployment_mode": deployment_mode,
            "tp_sizes": tp_sizes,
            "ep_sizes": ep_sizes,
            "moe_dp_sizes": moe_dp_sizes,
            "batch_range": batch_range,
            "concurrency_search_strategy": concurrency_search_strategy,
            "jobs": jobs,
            "disagg": disagg,
            "prefix_cache_hit_rate": prefix_cache_hit_rate,
            "prefill_devices_per_instance": prefill_devices_per_instance,
            "decode_devices_per_instance": decode_devices_per_instance,
            "enable_optimize_prefill_decode_ratio": enable_pd_ratio,
            "compile_allow_graph_break": compile_allow_graph_break,
            "enable_multistream": enable_multistream,
            "mxfp4_group_size": mxfp4_group_size,
            "reserved_memory_gb": float(form.get("reserved_memory_gb") or 0.0),
            "log_level": str(form.get("log_level") or "error"),
            "serving_cost": float(form.get("serving_cost") or 0.0),
            "dump_original_results": _as_bool(form.get("dump_original_results", False)),
        }

        # Mandatory arguments.
        cmd = _base_cmd("cli.inference.throughput_optimizer")
        cmd += [params["model_id"], "--device", device, "--num-devices", str(params["num_devices"])]
        cmd += ["--input-length", str(params["input_length"]), "--output-length", str(params["output_length"])]

        if params["compile"]:
            cmd.append("--compile")
        if enable_multistream:
            cmd.append("--enable-multistream")
        if compile_allow_graph_break:
            cmd.append("--compile-allow-graph-break")
        cmd += ["--quantize-linear-action", qlin, "--quantize-attention-action", qattn]
        if tpot is not None:
            cmd += ["--tpot-limits", str(tpot)]
        if ttft is not None:
            cmd += ["--ttft-limits", str(ttft)]
        for flag, values in list_options:
            if values:
                cmd += [flag, *(str(value) for value in values)]
        if jobs != 8:
            cmd += ["--jobs", str(jobs)]
        if params["serving_cost"] != 0.0:
            cmd += ["--serving-cost", str(params["serving_cost"])]
        if params["reserved_memory_gb"] != 0.0:
            cmd += ["--reserved-memory-gb", str(params["reserved_memory_gb"])]
        if params["log_level"] != "error":
            cmd += ["--log-level", params["log_level"]]
        if params["dump_original_results"]:
            cmd.append("--dump-original-results")
        if prefix_cache_hit_rate > 0:
            cmd += ["--prefix-cache-hit-rate", str(prefix_cache_hit_rate)]
        if disagg:
            cmd.append("--disagg")
        if enable_pd_ratio:
            cmd.append("--enable-optimize-prefill-decode-ratio")
            if prefill_devices_per_instance is not None:
                cmd += ["--prefill-devices-per-instance", str(prefill_devices_per_instance)]
            if decode_devices_per_instance is not None:
                cmd += ["--decode-devices-per-instance", str(decode_devices_per_instance)]
        if qlin == "MXFP4" and mxfp4_group_size != 32:
            cmd += ["--mxfp4-group-size", str(mxfp4_group_size)]
        if num_mtp_tokens > 0:
            cmd += ["--num-mtp-tokens", str(num_mtp_tokens)]
            cmd += ["--mtp-acceptance-rate", *(str(rate) for rate in mtp_acceptance_rate)]
        if max_batched_tokens != 8192:
            cmd += ["--max-batched-tokens", str(max_batched_tokens)]
        for flag, key in image_options:
            if params[key] is not None:
                cmd += [flag, str(params[key])]
        if concurrency_search_strategy != "exponential":
            cmd += ["--concurrency-search-strategy", concurrency_search_strategy]

        thash = stable_hash({"sim_type": "throughput_optimizer", **normalize_value(params)})
        label = f"{params['model_id']} | {device} | {params['optimization_mode']} | {deployment_mode} | {qlin}/{qattn}"
        if num_mtp_tokens > 0:
            label += f" | mtp={num_mtp_tokens}"
        if prefix_cache_hit_rate > 0:
            label += f" | cache={prefix_cache_hit_rate:g}"
        if enable_pd_ratio and prefill_devices_per_instance is not None and decode_devices_per_instance is not None:
            label += f" | p:d={prefill_devices_per_instance}:{decode_devices_per_instance}"
        tasks.append(ExperimentTask("throughput_optimizer", params, cmd, thash, label))
    return tasks