from concurrent.futures import ThreadPoolExecutor, as_completed
import logging
import os
import subprocess
import time

from user_config_loader import UserConfig
from utils import convert_args_dict_to_list, resolve_with_retry

# Configure the root logger at import time: INFO level, with timestamp,
# level name and message in every record.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)

# Number of resolution attempts made per hostname before giving up.
MAX_RESOLVE_ATTEMPTS = 5
# Pause between two consecutive resolution attempts for the same hostname
# (handed to time.sleep, i.e. seconds).
RESOLVE_DELAY = 10
# Upper bound on the size of the thread pool used to resolve hostnames.
MAX_WORKERS = 20


def _validate_router_config(user_config: UserConfig) -> int:
    if user_config.router_config is None:
        raise ValueError("router_config must not be None")

    if 'port' not in user_config.router_config:
        raise ValueError("router_config must contain 'port' field")

    port = user_config.router_config['port']
    if not isinstance(port, int) or port <= 0 or port > 65535:
        raise ValueError(f"port must be a valid positive integer between 1 and 65535, got: {port}")

    return port


def get_prefiller_or_decoder_hosts(user_config: UserConfig, role: str) -> list:
    """Resolve the IP address of every prefill or decode instance.

    One hostname per instance is built from the ``INFER_SERVICE_NAME`` and
    ``INFER_SERVICE_INDEX`` environment variables, the role, the instance
    index and the deployment namespace. The hostnames are then resolved
    concurrently in a thread pool.

    Args:
        user_config: Loaded user configuration; ``deploy_config.namespace``
            and ``deploy_config.<role>.instance_count`` are read.
        role: Either ``'prefill'`` or ``'decode'``.

    Returns:
        The resolved addresses ordered by instance index, i.e. element ``i``
        belongs to instance ``i``.

    Raises:
        ValueError: If ``role`` is unsupported, the instance count is not
            positive, or a hostname is still unresolved after
            ``MAX_RESOLVE_ATTEMPTS`` attempts.
    """
    infer_service_name = os.environ.get('INFER_SERVICE_NAME')
    infer_service_index = os.environ.get('INFER_SERVICE_INDEX')
    namespace = user_config.deploy_config.namespace
    if role == 'prefill':
        instance_count = user_config.deploy_config.prefill.instance_count
    elif role == 'decode':
        instance_count = user_config.deploy_config.decode.instance_count
    else:
        raise ValueError(f"Unsupported role: {role}")

    # ThreadPoolExecutor rejects max_workers <= 0 with an unrelated-looking
    # ValueError; fail with the same exception type but a message that names
    # the actual problem.
    if instance_count <= 0:
        raise ValueError(
            f"instance_count for role {role} must be a positive integer, got: {instance_count}"
        )

    if infer_service_name is None or infer_service_index is None:
        logging.warning(
            "INFER_SERVICE_NAME and/or INFER_SERVICE_INDEX is not set; "
            "generated hostnames will contain the literal 'None'"
        )

    instance_prefix = f"{infer_service_name}-{infer_service_index}-{role}"
    hostnames = [
        f"{instance_prefix}-{instance_index}-0"
        f".service-{instance_prefix}-{instance_index}"
        f".{namespace}.svc.cluster.local"
        for instance_index in range(instance_count)
    ]

    def resolve_hostname(hostname: str) -> str:
        """Resolve one hostname, retrying until the attempt budget is spent."""
        for attempt in range(1, MAX_RESOLVE_ATTEMPTS + 1):
            ip = resolve_with_retry(hostname)
            if ip is not None:
                return ip
            logging.debug(
                "Attempt %s/%s failed to resolve hostname %s. Retrying in %s seconds...",
                attempt,
                MAX_RESOLVE_ATTEMPTS,
                hostname,
                RESOLVE_DELAY
            )
            # No point sleeping after the final attempt.
            if attempt < MAX_RESOLVE_ATTEMPTS:
                time.sleep(RESOLVE_DELAY)

        raise ValueError(f"Failed to resolve hostname {hostname} after {MAX_RESOLVE_ATTEMPTS} attempts")

    max_workers = min(instance_count, MAX_WORKERS)
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # executor.map yields results in submission order, so the returned
        # list lines up with the instance indices regardless of which lookup
        # finishes first. A failed lookup re-raises its exception here.
        return list(executor.map(resolve_hostname, hostnames))


def get_prefiller_or_decoder_ports(user_config: UserConfig, role: str) -> list:
    """Build the list of ports for the instances of the given role.

    The result repeats ``engine_common_config.server_port`` once per
    configured instance of the role.

    Args:
        user_config: Loaded user configuration; reads
            ``deploy_config.<role>.instance_count`` and
            ``engine_common_config.server_port``.
        role: Either ``'prefill'`` or ``'decode'``.

    Returns:
        A list holding the server port ``instance_count`` times.

    Raises:
        ValueError: If ``role`` is neither ``'prefill'`` nor ``'decode'``.
    """
    # Reject unknown roles before touching the configuration at all.
    if role not in ('prefill', 'decode'):
        raise ValueError(f"Unsupported role: {role}")

    # The role name doubles as the attribute name of its deploy section.
    role_section = getattr(user_config.deploy_config, role)
    instance_total = role_section.instance_count
    server_port = user_config.engine_common_config.server_port
    return [server_port] * instance_total


def run_router(config_path):
    """Start the router proxy as a child process and block until it exits.

    Loads the user configuration, validates the router port, resolves the
    prefill and decode backends, converts everything into command-line
    arguments and runs ``load_balance_proxy_layerwise_server_example.py``
    (looked up next to this file) with them.

    The child inherits this process's stdout/stderr, so its output is not
    captured here.

    Args:
        config_path: Path handed to ``UserConfig.load_from_file``.

    Raises:
        subprocess.CalledProcessError: If the router exits with a non-zero
            return code.
        Exception: Any other failure while preparing or running the router
            is logged and re-raised unchanged.
    """
    try:
        user_config = UserConfig.load_from_file(config_path)
        port = _validate_router_config(user_config)
        args_dict = {
            'port': port,
            'host': os.environ.get('POD_IP'),
            'prefiller_hosts': get_prefiller_or_decoder_hosts(user_config, 'prefill'),
            'prefiller_ports': get_prefiller_or_decoder_ports(user_config, 'prefill'),
            'decoder_hosts': get_prefiller_or_decoder_hosts(user_config, 'decode'),
            'decoder_ports': get_prefiller_or_decoder_ports(user_config, 'decode'),
        }

        converted_args_list = convert_args_dict_to_list(args_dict)
        current_dir = os.path.dirname(__file__)
        script_path = os.path.join(current_dir, 'load_balance_proxy_layerwise_server_example.py')
        router_cmd = ['python', script_path] + converted_args_list
        logging.info("Starting router with command: %s", ' '.join(router_cmd))

        # No pipes are requested, so there is no output to collect: the
        # child writes straight to the inherited stdout/stderr and we only
        # need to wait for its exit status.
        process = subprocess.Popen(router_cmd, shell=False)
        return_code = process.wait()
        if return_code != 0:
            logging.error("Router process failed with return code %s", return_code)
            raise subprocess.CalledProcessError(return_code, router_cmd)

    except Exception as e:
        logging.error("Error in run_router: %s", e)
        raise