# -*- coding: utf-8 -*-
# Copyright (c) Huawei Technologies Co., Ltd. 2025-2026. All rights reserved.
# MindIE is licensed under Mulan PSL v2.
# You can use this software according to the terms and conditions of the Mulan PSL v2.
# You may obtain a copy of Mulan PSL v2 at:
#         http://license.coscl.org.cn/MulanPSL2
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
# See the Mulan PSL v2 for more details.

import time
from enum import Enum

from pydantic import BaseModel, Field, PrivateAttr

import anyio

from motor.common.resources.instance import PDRole
from motor.coordinator.domain.scheduling_constraint import SchedulingConstraint
from motor.coordinator.tracer.tracing import TraceObj
from motor.coordinator.models.constants import OpenAIField


class RequestType(Enum):
    """Supported request types.

    Each member's value is the lowercase string form of its name, so a member
    can be looked up from that string with ``RequestType("openai")``.
    """

    OPENAI = "openai"
    ANTHROPIC = "anthropic"
    TRITON = "triton"
    TGI = "tgi"
    VLLM = "vllm"
    MINDIE = "mindie"


class ReqState(Enum):
    """Lifecycle states a request can be in.

    The member value is the human-readable label of the state. ``E``, ``P`` and
    ``D`` in the names refer to the E, P and D instance kinds being scheduled.
    """

    # Request has arrived.
    ARRIVE = "Arrive"
    # An E instance is being scheduled / has been allocated.
    E_SCHEDULING = "E_Scheduling"
    E_ALLOCATED = "E_Allocated"
    # A P instance is being scheduled / has been allocated.
    P_SCHEDULING = "P_Scheduling"
    P_ALLOCATED = "P_Allocated"
    # Prefill has completed.
    PREFILL_END = "Prefill End"
    FIRST_TOKEN_FINISH = "First Token Finish"  # nosec B105 -- request lifecycle state label, not a credential
    # A D instance is being scheduled / has been allocated.
    D_SCHEDULING = "D_Scheduling"
    D_ALLOCATED = "D_Allocated"
    # Decode has completed.
    DECODE_END = "Decode End"
    # Abnormal outcomes: invalid state, timeout, exception.
    INVALID = "Invalid"
    TIMEOUT = "Timeout"
    EXCEPTION = "Exception"
    # Request is being recomputed.
    RECOMPUTE = "Recompute"


class RequestInfo(BaseModel):
    """Per-request record: payload, routing metadata, lifecycle state and caches.

    Besides the declared fields, an instance keeps up to three private
    ``anyio.CancelScope`` references (one each for the P, D and E roles) and a
    ``status`` map holding the ``time.time()`` timestamp at which each
    ``ReqState`` was entered.
    """

    req_id: str = Field(..., description="Request ID generated by RequestManager")
    req_data: dict = Field(..., description="Request json content")
    req_len: int = Field(..., description="Request body length")
    token_ids: list[int] | None = Field(
        default=None,
        description="Prompt token ids tokenized once at routing (KV affinity); reused for "
        "prefill load accounting so load and affinity share the same token unit",
    )
    kv_affinity_debug: dict | None = Field(
        default=None,
        exclude=True,
        description="Per-endpoint (matched_tokens, load_cost, prefill_cost) cached by the "
        "kv_cache_affinity policy at selection; the worker forwards prefill_cost for the "
        "scheduler's global fresh-load re-rank and logs matched/load for the committed endpoint. "
        "Keyed by (instance_id, endpoint_id) tuples, so excluded from serialization.",
    )
    api: str = Field(..., description="API need to be forwarded")
    entry_api: str = Field(
        default="",
        description="Original client HTTP path (e.g. v1/chat/completions); unchanged on recompute retry",
    )
    # to be deleted
    recompute_engine_mode: str | None = Field(
        default=None,
        description="Set to 'completions' when recompute retry uses v1/completions body",
    )
    client_expects_token_ids: bool = Field(
        default=False,
        description="True if the original client request contains a 'return_token_ids' parameter",
    )
    client_expects_chat_shape: bool = Field(
        default=False,
        description="True if the original client request was chat completions (messages present at ingress)",
    )
    state: ReqState = Field(default=ReqState.ARRIVE, description="Request current status")
    # NOTE: the mutable literal defaults ({} / []) on this model rely on pydantic
    # copying field defaults per instance; they are not shared between requests.
    status: dict[ReqState, float] = Field(default={}, description="Request status time")
    trace_obj: TraceObj = Field(default_factory=TraceObj, description="Tracing object")
    _p_cancel_scope: anyio.CancelScope | None = PrivateAttr(default=None)
    _d_cancel_scope: anyio.CancelScope | None = PrivateAttr(default=None)
    _e_cancel_scope: anyio.CancelScope | None = PrivateAttr(default=None)
    prompt_tokens_details: dict = Field(default={}, description="prefill prompt_tokens_details")
    prompt_token_ids: list = Field(default=[], description="prefill prompt_token_ids")
    cached_token_ids: list = Field(default=[], description="Cached token_ids")
    p_instance_id: int | None = Field(
        default=None, description="P instance ID set by metaserver handler in CDP-like modes"
    )
    scheduling_constraint: SchedulingConstraint | None = Field(
        default=None,
        description="Internal pin-to-instance constraint (e.g. precision probe); not from client API",
    )

    def __init__(self, **data):
        """Build the model from ``data`` and stamp the arrival time.

        The ``ReqState.ARRIVE`` entry of ``status`` is always set to the current
        ``time.time()``, replacing any ARRIVE timestamp supplied in ``data``.
        """
        super().__init__(**data)
        self.status[ReqState.ARRIVE] = time.time()

    @property
    def is_cancelled(self) -> bool:
        """Return True if any registered cancel scope has had ``cancel()`` called.

        Always returns a real ``bool`` (``False`` when no scope is registered).
        """
        scopes = (self._p_cancel_scope, self._d_cancel_scope, self._e_cancel_scope)
        return any(scope is not None and scope.cancel_called for scope in scopes)

    def effective_entry_api(self) -> str:
        """Path used for client-contract checks (Chat vs Completion); falls back to ``api``."""
        return self.entry_api or self.api

    def update_state(self, new_state: ReqState) -> None:
        """Switch to ``new_state`` and record the current time for it in ``status``."""
        self.state = new_state
        self.status[new_state] = time.time()

    def update_prompt_tokens_details(self, prompt_tokens_details: dict) -> None:
        """Replace the stored ``prompt_tokens_details`` with the given dict."""
        self.prompt_tokens_details = prompt_tokens_details

    def set_cancel_scope(self, cancel_scope: anyio.CancelScope, role: PDRole) -> None:
        """Register ``cancel_scope`` for the given role (P, D or E).

        A role other than ``ROLE_P`` / ``ROLE_D`` / ``ROLE_E`` is ignored.
        """
        if role == PDRole.ROLE_P:
            self._p_cancel_scope = cancel_scope
        elif role == PDRole.ROLE_D:
            self._d_cancel_scope = cancel_scope
        elif role == PDRole.ROLE_E:
            self._e_cancel_scope = cancel_scope

    def cancel_scope(self) -> None:
        """Cancel every registered scope that has not been cancelled yet (P, then D, then E)."""
        for scope in (self._p_cancel_scope, self._d_cancel_scope, self._e_cancel_scope):
            if scope and not scope.cancel_called:
                scope.cancel()

    def update_token_id_cache(self, chunk_json: dict) -> None:
        """Accumulate ``return_token_ids`` response fields from one response chunk.

        - Root ``prompt_token_ids``: stored in ``self.prompt_token_ids`` once (first
          list/tuple seen while that cache is still empty).
        - ``choices[0].prompt_token_ids`` (Completion stream): used when the root
          field is absent or null.
        - ``choices[0].token_ids``: appended to ``self.cached_token_ids`` when a list.

        A missing or empty ``choices``, or a first choice that is not a dict, is
        skipped instead of raising.
        """
        # Resolve the first choice once; anything that is not a dict is treated as absent.
        choices = chunk_json.get(OpenAIField.CHOICES)
        first_choice = None
        if isinstance(choices, (list, tuple)) and choices and isinstance(choices[0], dict):
            first_choice = choices[0]

        prompt_ids = chunk_json.get(OpenAIField.PROMPT_TOKEN_IDS)
        if prompt_ids is None and first_choice is not None:
            prompt_ids = first_choice.get(OpenAIField.PROMPT_TOKEN_IDS)
        # Only the first non-empty capture wins; later chunks never overwrite it.
        if isinstance(prompt_ids, (list, tuple)) and not self.prompt_token_ids:
            self.prompt_token_ids = list(prompt_ids)

        if first_choice is None:
            return
        token_ids = first_choice.get(OpenAIField.TOKEN_IDS)
        if isinstance(token_ids, list):
            self.cached_token_ids.extend(token_ids)