from __future__ import annotations

import dataclasses
from collections.abc import Mapping, Sequence
from typing import Literal

import sympy


JsonScalar = int | float | bool | str | None


@dataclasses.dataclass(frozen=True, slots=True)
class GroupFeatureSpec:
    name: str
    source: Literal["primary_axis", "outer_product", "reduction_product", "axis"]
    axis_names: tuple[str, ...]
    buckets: tuple[int, ...]


@dataclasses.dataclass(frozen=True, slots=True)
class KernelVariant:
    variant_id: str
    constexpr_kwargs: tuple[tuple[str, int], ...]
    num_warps: int
    num_stages: int
    extra_compile_kwargs: tuple[tuple[str, JsonScalar], ...]


@dataclasses.dataclass(frozen=True, slots=True)
class LaunchPolicy:
    policy_id: str
    group_id: int
    static_blocks: tuple[tuple[str, int], ...]
    runtime_block_rules: tuple[tuple[str, tuple[tuple[str, JsonScalar], ...]], ...]
    grid_target: int


@dataclasses.dataclass(frozen=True, slots=True)
class GroupedCandidate:
    group_id: int
    variant_id: str
    policy_id: str


class UnsupportedGroupedPlan(RuntimeError):
    """Raised when a grouped plan cannot be built from the given features and axes."""


@dataclasses.dataclass(frozen=True, slots=True)
class GroupedKernelMeta:
    enabled: bool
    template: str
    primary_group_axis: str | None
    static_split_axes: tuple[str, ...]
    secondary_runtime_symbolic_axes: tuple[str, ...]
    group_features: tuple[GroupFeatureSpec, ...]
    runtime_block_arg_names: tuple[str, ...]

    def to_payload(self) -> dict[str, object]:
        return {
            "enabled": self.enabled,
            "template": self.template,
            "primary_group_axis": self.primary_group_axis,
            "static_split_axes": list(self.static_split_axes),
            "secondary_runtime_symbolic_axes": list(
                self.secondary_runtime_symbolic_axes
            ),
            "group_features": [
                {
                    "name": spec.name,
                    "source": spec.source,
                    "axis_names": list(spec.axis_names),
                    "buckets": list(spec.buckets),
                }
                for spec in self.group_features
            ],
            "runtime_block_arg_names": list(self.runtime_block_arg_names),
        }

    @staticmethod
    def from_payload(payload: Mapping[str, object]) -> "GroupedKernelMeta":
        return GroupedKernelMeta(
            enabled=_require_bool(payload.get("enabled"), "enabled"),
            template=_require_str(payload.get("template"), "template"),
            primary_group_axis=_require_optional_str(
                payload.get("primary_group_axis"), "primary_group_axis"
            ),
            static_split_axes=_require_str_tuple(
                payload.get("static_split_axes"), "static_split_axes"
            ),
            secondary_runtime_symbolic_axes=_require_str_tuple(
                payload.get(
                    "secondary_runtime_symbolic_axes"
                ),
                "secondary_runtime_symbolic_axes",
            ),
            group_features=tuple(
                _group_feature_spec_from_payload(item)
                for item in _require_sequence(
                    payload.get("group_features"), "group_features"
                )
            ),
            runtime_block_arg_names=_require_str_tuple(
                payload.get("runtime_block_arg_names"), "runtime_block_arg_names"
            ),
        )


def is_runtime_symbolic_length(length) -> bool:
    """Return True unless *length* is a concrete integer.

    Python ``int`` (which includes ``bool``) and ``sympy.Integer`` count as
    static; any other object -- e.g. a sympy symbol or expression -- is
    treated as a length that is only known at runtime.
    """
    static_types = (int, sympy.Integer)
    is_static = isinstance(length, static_types)
    return not is_static


def bucketize(value: int, buckets: tuple[int, ...]) -> int:
    """Return the index of the first bound in *buckets* that *value* does not exceed.

    Bounds are inclusive upper limits. A value above every bound lands in the
    overflow bucket ``len(buckets)``, so results range over
    ``0..len(buckets)``.
    """
    return next(
        (position for position, bound in enumerate(buckets) if value <= bound),
        len(buckets),
    )


def make_group_id(
    feature_values: tuple[int, ...], feature_specs: tuple[GroupFeatureSpec, ...]
) -> int:
    """Encode per-feature bucket indices into a single mixed-radix group id.

    Feature ``i`` contributes the digit ``bucketize(value_i, buckets_i)`` in
    radix ``len(buckets_i) + 1`` (one slot per finite bucket plus the open
    bucket). The first feature is the least-significant digit.
    ``decode_group_bucket_indices`` recovers the per-feature bucket indices
    from the result.

    Specs may be ``GroupFeatureSpec`` objects or mappings with a ``"buckets"``
    key, like the other group-id helpers accept. Either argument may be any
    iterable; both are materialized once.

    Raises:
        ValueError: if the number of values and specs differ.
    """
    feature_values = tuple(feature_values)
    feature_specs = tuple(feature_specs)
    if len(feature_values) != len(feature_specs):
        raise ValueError(
            "feature_values and feature_specs must have the same length: "
            f"{len(feature_values)} != {len(feature_specs)}"
        )

    group_id = 0
    stride = 1
    for value, spec in zip(feature_values, feature_specs):
        # _feature_field keeps this consistent with decode_group_bucket_indices,
        # which accepts both attribute-style and mapping-style specs.
        buckets = tuple(_feature_field(spec, "buckets"))
        group_id += bucketize(value, buckets) * stride
        stride *= len(buckets) + 1
    return group_id


def decode_group_bucket_indices(
    group_features, group_id: int
) -> tuple[int, ...]:
    """Split a mixed-radix group id back into per-feature bucket indices.

    This reverses the encoding of ``make_group_id``. Feature ``i`` has radix
    ``len(buckets_i) + 1``, and the first feature is the least-significant
    digit. Specs may be objects or mappings with a ``"buckets"`` entry.

    Raises:
        ValueError: if *group_id* is negative or not smaller than the product
            of all radices. Such ids would otherwise silently alias a
            different, valid group.
    """
    requested_id = int(group_id)
    if requested_id < 0:
        raise ValueError(f"group_id must be non-negative, got {requested_id}")

    bucket_indices = []
    remaining = requested_id
    for feature_spec in tuple(group_features):
        radix = len(tuple(_feature_field(feature_spec, "buckets"))) + 1
        remaining, bucket_idx = divmod(remaining, radix)
        bucket_indices.append(bucket_idx)

    # Any leftover high-order digits mean the id exceeds the encodable range.
    if remaining:
        raise ValueError(
            f"group_id {requested_id} is out of range for the given group_features"
        )
    return tuple(bucket_indices)


def find_primary_feature_index(group_features, primary_group_axis: str) -> int:
    """Return the index of the first feature whose ``axis_names`` include the primary axis.

    Raises:
        ValueError: if no feature mentions *primary_group_axis*.
    """
    matching_indices = (
        position
        for position, spec in enumerate(tuple(group_features))
        if primary_group_axis in tuple(_feature_field(spec, "axis_names"))
    )
    found = next(matching_indices, None)
    if found is None:
        raise ValueError(
            f"primary_group_axis {primary_group_axis} does not appear in group_features"
        )
    return found


def is_open_bucket_group(
    group_features,
    primary_feature_index: int,
    group_id: int,
) -> bool:
    """Report whether *group_id* falls in the primary feature's open bucket.

    The open (overflow) bucket is index ``len(buckets)``, i.e. values above
    every finite bound.
    """
    specs = tuple(group_features)
    decoded = decode_group_bucket_indices(specs, group_id)
    primary_spec = specs[primary_feature_index]
    open_index = len(tuple(_feature_field(primary_spec, "buckets")))
    return decoded[primary_feature_index] == open_index


def _feature_field(feature_spec, field_name: str):
    if isinstance(feature_spec, Mapping):
        return feature_spec[field_name]
    return getattr(feature_spec, field_name)


def build_group_representatives(
    group_features,
    axis_names,
    axis_static_values,
) -> dict[str, object]:
    """Pick one benchmark representative for every encodable group id.

    Group ids use the mixed-radix encoding of ``make_group_id``: feature ``i``
    contributes a digit in ``[0, len(buckets_i)]``, and the first feature is
    least significant. For each id, this chooses a feature value inside the
    decoded bucket, plus a concrete value for every axis in *axis_names*. If
    no assignment can produce the id, it is marked unreachable.

    Args:
        group_features: ``GroupFeatureSpec`` objects or mappings with
            ``name``/``source``/``axis_names``/``buckets``.
        axis_names: axes, in the order used for each returned axis-value tuple.
        axis_static_values: ``(axis_name, value)`` pairs, or a mapping from
            axis name to value, for axes with a fixed size. Every other axis
            is treated as symbolic.

    Returns:
        A dict with ``group_id_count``, ``reachable_group_ids``,
        ``unreachable_group_ids``, and per-group-id tuples
        ``benchmark_feature_inputs_by_group`` and
        ``benchmark_axis_values_by_group`` (``()`` for unreachable ids).

    Raises:
        UnsupportedGroupedPlan: if any of the following holds:
            an axis appears in more than one feature; a feature has empty
            buckets, an unsupported source, the wrong number of axes for its
            source, or more than one symbolic axis; a static axis used by a
            product feature is non-positive; or no group is reachable.
    """
    group_features = tuple(group_features)
    axis_names = tuple(axis_names)
    # Accept a mapping as well as an iterable of (name, value) pairs; iterating
    # a dict directly would yield bare keys and mis-unpack them.
    static_pairs = (
        axis_static_values.items()
        if isinstance(axis_static_values, Mapping)
        else axis_static_values
    )
    axis_static_values = {
        axis_name: int(axis_value) for axis_name, axis_value in static_pairs
    }

    # Each axis may feed at most one feature, so per-feature representative
    # choices never assign conflicting values to the same axis.
    seen_feature_axes = set()
    for feature_spec in group_features:
        for axis_name in tuple(_feature_field(feature_spec, "axis_names")):
            if axis_name in seen_feature_axes:
                raise UnsupportedGroupedPlan(
                    f"axis {axis_name} appears in multiple group features"
                )
            seen_feature_axes.add(axis_name)

    def bucket_bounds(feature_spec, bucket_idx: int):
        """Return ``(lower, upper)``: the bucket covers values in ``(lower, upper]``.

        ``lower`` is 0 for the first bucket; ``upper`` is None for the open
        bucket at index ``len(buckets)``.
        """
        buckets = tuple(_feature_field(feature_spec, "buckets"))
        if not buckets:
            raise UnsupportedGroupedPlan(
                f"group feature {_feature_field(feature_spec, 'name')} must define non-empty buckets"
            )
        lower = int(buckets[bucket_idx - 1]) if bucket_idx > 0 else 0
        upper = int(buckets[bucket_idx]) if bucket_idx < len(buckets) else None
        return lower, upper

    def in_bucket(value: int, feature_spec, bucket_idx: int) -> bool:
        """Return True when *value* lies in the bucket's ``(lower, upper]`` range."""
        lower, upper = bucket_bounds(feature_spec, bucket_idx)
        return value > lower and (upper is None or value <= upper)

    def representative_feature_value(feature_spec, bucket_idx: int) -> int:
        """Return the bucket's upper bound, or a value above the last bound for the open bucket."""
        lower, upper = bucket_bounds(feature_spec, bucket_idx)
        if upper is None:
            # ``lower * 2`` alone is not above ``lower`` when lower <= 0 (e.g.
            # buckets == (0,) would give 0, which falls in bucket 0), so take at
            # least ``lower + 1`` -- the same rule choose_symbolic_axis_value uses.
            return max(lower + 1, lower * 2)
        return upper

    def choose_symbolic_axis_value(static_factor: int, bucket_idx: int, feature_spec):
        """Pick a size for the one symbolic axis so ``size * static_factor`` is in the bucket.

        Returns None when no positive multiple of *static_factor* fits.
        """
        lower, upper = bucket_bounds(feature_spec, bucket_idx)
        # Smallest d with d * static_factor > lower (static_factor is positive).
        dyn_min = lower // static_factor + 1
        if upper is None:
            target = max(lower + 1, lower * 2)
            # Ceiling division so the product reaches at least ``target``.
            return max(1, (target + static_factor - 1) // static_factor)
        dyn_max = upper // static_factor
        if dyn_min > dyn_max:
            return None
        return max(1, dyn_max)

    def representative_for_feature(feature_spec, bucket_idx: int):
        """Return ``(feature_value, ((axis, value), ...))`` for a bucket, or None if unreachable."""
        feature_axis_names = tuple(_feature_field(feature_spec, "axis_names"))
        source = _feature_field(feature_spec, "source")
        if not feature_axis_names:
            return representative_feature_value(feature_spec, bucket_idx), ()
        if source in ("primary_axis", "axis"):
            if len(feature_axis_names) != 1:
                raise UnsupportedGroupedPlan(
                    f"group feature {_feature_field(feature_spec, 'name')} expects one axis"
                )
            axis_name = feature_axis_names[0]
            if axis_name in axis_static_values:
                # A static axis always has its fixed size: the bucket is
                # reachable only if that size falls inside it, and substituting
                # a made-up size would benchmark a shape that never occurs.
                value = axis_static_values[axis_name]
                if not in_bucket(value, feature_spec, bucket_idx):
                    return None
            else:
                value = representative_feature_value(feature_spec, bucket_idx)
            return value, ((axis_name, value),)
        if source not in ("outer_product", "reduction_product"):
            raise UnsupportedGroupedPlan(
                f"group feature source {source} is not supported"
            )

        # Product feature: multiply the static axes together and let at most
        # one symbolic axis absorb the rest of the target value.
        static_factor = 1
        static_axis_values = []
        symbolic_axis_names = []
        for axis_name in feature_axis_names:
            if axis_name in axis_static_values:
                axis_value = axis_static_values[axis_name]
                if axis_value <= 0:
                    raise UnsupportedGroupedPlan(
                        f"axis {axis_name} has non-positive static value {axis_value}"
                    )
                static_factor *= axis_value
                static_axis_values.append((axis_name, axis_value))
            else:
                symbolic_axis_names.append(axis_name)

        if len(symbolic_axis_names) > 1:
            raise UnsupportedGroupedPlan(
                f"group feature {_feature_field(feature_spec, 'name')} has multiple symbolic axes"
            )
        if not symbolic_axis_names:
            # Fully static product: the feature value is fixed.
            if not in_bucket(static_factor, feature_spec, bucket_idx):
                return None
            return static_factor, tuple(static_axis_values)

        symbolic_axis_value = choose_symbolic_axis_value(
            static_factor, bucket_idx, feature_spec
        )
        if symbolic_axis_value is None:
            return None
        symbolic_name = symbolic_axis_names[0]
        axis_values = [
            (axis_name, symbolic_axis_value)
            if axis_name == symbolic_name
            else (axis_name, axis_static_values[axis_name])
            for axis_name in feature_axis_names
        ]
        return symbolic_axis_value * static_factor, tuple(axis_values)

    def representative_for_group(group_id: int):
        """Return ``(feature_values, axis_values)`` for *group_id*, or None if unreachable."""
        feature_values = []
        axis_values = {}
        bucket_indices = decode_group_bucket_indices(group_features, group_id)
        for feature_spec, bucket_idx in zip(group_features, bucket_indices):
            representative = representative_for_feature(feature_spec, bucket_idx)
            if representative is None:
                return None
            feature_value, feature_axis_values = representative
            feature_values.append(feature_value)
            for axis_name, axis_value in feature_axis_values:
                axis_values[axis_name] = int(axis_value)

        # Axes not set by any feature keep their static size, or 1 if symbolic.
        for axis_name in axis_names:
            axis_values.setdefault(axis_name, axis_static_values.get(axis_name, 1))

        return (
            tuple(feature_values),
            tuple((axis_name, int(axis_values[axis_name])) for axis_name in axis_names),
        )

    group_id_count = 1
    for feature_spec in group_features:
        group_id_count *= len(tuple(_feature_field(feature_spec, "buckets"))) + 1

    representatives = tuple(
        representative_for_group(group_id) for group_id in range(group_id_count)
    )
    reachable_group_ids = tuple(
        group_id
        for group_id, representative in enumerate(representatives)
        if representative is not None
    )
    if not reachable_group_ids:
        raise UnsupportedGroupedPlan("grouped plan has no reachable groups")

    return {
        "group_id_count": group_id_count,
        "reachable_group_ids": reachable_group_ids,
        "unreachable_group_ids": tuple(
            group_id
            for group_id, representative in enumerate(representatives)
            if representative is None
        ),
        "benchmark_feature_inputs_by_group": tuple(
            representative[0] if representative is not None else ()
            for representative in representatives
        ),
        "benchmark_axis_values_by_group": tuple(
            representative[1] if representative is not None else ()
            for representative in representatives
        ),
    }
def serialize_grouped_plan(plan: GroupedKernelMeta) -> dict[str, object]:
    """Return the plain-dict payload of *plan* (delegates to ``plan.to_payload()``)."""
    payload = plan.to_payload()
    return payload


def deserialize_grouped_plan(payload: dict[str, object]) -> GroupedKernelMeta:
    """Rebuild a ``GroupedKernelMeta`` from *payload* via ``GroupedKernelMeta.from_payload``."""
    meta = GroupedKernelMeta.from_payload(payload)
    return meta


def _group_feature_spec_from_payload(payload: object) -> GroupFeatureSpec:
    """Parse one entry of a serialized ``group_features`` list.

    Raises:
        TypeError: if the entry is not a mapping or a field has the wrong type.
        ValueError: if ``source`` is not one of the sources allowed by
            ``GroupFeatureSpec.source``. Without this check, an invalid source
            on a feature with no axes would be silently accepted downstream.
    """
    # Must mirror the Literal annotation on GroupFeatureSpec.source.
    allowed_sources = ("primary_axis", "outer_product", "reduction_product", "axis")

    mapping = _require_mapping(payload, "group_features[]")
    name = _require_str(mapping.get("name"), "group_features[].name")
    source = _require_str(mapping.get("source"), "group_features[].source")
    if source not in allowed_sources:
        raise ValueError(
            f"group_features[].source must be one of {allowed_sources}, got {source!r}"
        )
    axis_names = _require_str_tuple(
        mapping.get("axis_names"), "group_features[].axis_names"
    )
    buckets = _require_int_tuple(mapping.get("buckets"), "group_features[].buckets")
    return GroupFeatureSpec(
        name=name,
        source=source,
        axis_names=axis_names,
        buckets=buckets,
    )


def _require_mapping(value: object, field_name: str) -> Mapping[str, object]:
    if not isinstance(value, Mapping):
        raise TypeError(f"{field_name} must be a mapping, got {type(value).__name__}")
    return value


def _require_sequence(value: object, field_name: str) -> Sequence[object]:
    if not isinstance(value, Sequence) or isinstance(value, (str, bytes, bytearray)):
        raise TypeError(f"{field_name} must be a sequence, got {type(value).__name__}")
    return value


def _require_str(value: object, field_name: str) -> str:
    if not isinstance(value, str):
        raise TypeError(f"{field_name} must be a string, got {type(value).__name__}")
    return value


def _require_optional_str(value: object, field_name: str) -> str | None:
    if value is None:
        return None
    return _require_str(value, field_name)


def _require_bool(value: object, field_name: str) -> bool:
    if not isinstance(value, bool):
        raise TypeError(f"{field_name} must be a bool, got {type(value).__name__}")
    return value


def _require_str_tuple(value: object, field_name: str) -> tuple[str, ...]:
    """Validate *value* as a non-text sequence of strings and return it as a tuple.

    Stops at the first non-string item and raises TypeError for it.
    """
    validated = []
    for element in _require_sequence(value, field_name):
        validated.append(_require_str(element, field_name))
    return tuple(validated)


def _require_int_tuple(value: object, field_name: str) -> tuple[int, ...]:
    """Validate *value* as a non-text sequence of ints and return it as a tuple.

    ``bool`` is rejected even though it subclasses ``int``, matching
    ``_require_bool``, which rejects ints. Otherwise a malformed payload such
    as ``[true, 64]`` would pass, and ``True`` would silently act as the
    integer 1.

    Raises:
        TypeError: if *value* is not a sequence or any item is not a
            (non-bool) int.
    """
    items = _require_sequence(value, field_name)
    result = []
    for item in items:
        if isinstance(item, bool) or not isinstance(item, int):
            raise TypeError(
                f"{field_name} must contain ints, got {type(item).__name__}"
            )
        result.append(item)
    return tuple(result)