use std::{path::PathBuf, str::FromStr};

use clap::Subcommand;
use ek_base::{
    config::get_ek_settings,
    error::{EKError, EKResult},
};
use ek_computation::{
    proto::ek::control::v1::{
        DuplicateReq, ManualReq, RebalanceReq, plan_service_client::PlanServiceClient,
    },
    state::{
        io::StateReaderImpl,
        models::{NewExpert, NewInstance, NewNode},
        writer::StateWriterImpl,
    },
};
use ek_db::{safetensor::ExpertKey, weight_srv::client::WeightSrvClient};
use indicatif::ProgressBar;
use log::info;
use rand::random;
use serde::Deserialize;
use tokio::task::JoinSet;
use tonic::transport::Endpoint;

#[derive(Subcommand, Debug)]
pub enum ScheduleCommand {
    Static {
        #[arg(long, short, help = "file that contains worker nodes information")]
        inventory: PathBuf,
    },
    Rebalance,
    Duplicate {
        #[arg(
            long,
            help = "Specific hostnames to duplicate to (if not specified, duplicates to all nodes)",
            value_delimiter = ','
        )]
        hostnames: Option<Vec<String>>,
    },
    Manual {
        #[arg(long, help = "Hostnames of the target nodes", value_delimiter = ',')]
        hostnames: Vec<String>,
        #[arg(long, help = "Layer ranges to assign (e.g., '1-2,4-8')")]
        layers: String,
    },
}

pub async fn execute_schedule(cmd: ScheduleCommand) -> EKResult<()> {
    match cmd {
        ScheduleCommand::Static { inventory } => {
            execute_static_schedule(inventory).await?;
            Ok(())
        }

        ScheduleCommand::Rebalance => {
            execute_rebalance().await?;
            Ok(())
        }

        ScheduleCommand::Duplicate { hostnames } => {
            execute_duplicate(hostnames).await?;
            Ok(())
        }

        ScheduleCommand::Manual { hostnames, layers } => {
            execute_manual(hostnames, layers).await?;
            Ok(())
        }
    }
}

async fn execute_rebalance() -> EKResult<()> {
    let settings = get_ek_settings();
    let controller_addr = format!(
        "http://{}:{}",
        settings.controller.broadcast, settings.controller.ports.inter
    );
    log::info!("connect to controller at {controller_addr}");
    let endpoint = Endpoint::from_str(controller_addr.as_str()).unwrap();
    let mut cli = PlanServiceClient::connect(endpoint).await?;
    cli.rebalance(RebalanceReq {}).await?;
    log::info!("rebalance done");
    Ok(())
}

async fn execute_duplicate(hostnames: Option<Vec<String>>) -> EKResult<()> {
    let settings = get_ek_settings();
    let controller_addr = format!(
        "http://{}:{}",
        settings.controller.broadcast, settings.controller.ports.inter
    );
    log::info!("connect to controller at {controller_addr}");
    let endpoint = Endpoint::from_str(controller_addr.as_str()).unwrap();
    let mut cli = PlanServiceClient::connect(endpoint).await?;
    cli.duplicate(DuplicateReq {
        hostnames: hostnames.unwrap_or_default(),
    })
    .await?;
    log::info!("duplicate done");
    Ok(())
}

async fn execute_manual(hostnames: Vec<String>, layers: String) -> EKResult<()> {
    let settings = get_ek_settings();
    let controller_addr = format!(
        "http://{}:{}",
        settings.controller.broadcast, settings.controller.ports.inter
    );
    log::info!("connect to controller at {controller_addr}");
    let endpoint = Endpoint::from_str(controller_addr.as_str()).unwrap();
    let mut cli = PlanServiceClient::connect(endpoint).await?;
    cli.manual(ManualReq { hostnames, layers }).await?;
    log::info!("manual assignment done");
    Ok(())
}

#[derive(Debug, Clone, Deserialize)]
pub struct Node {
    pub id: String,
    pub address: String,
    pub channel: String,
    pub device: String,
}
#[derive(Debug, Clone, Deserialize)]
pub struct Inventory {
    pub nodes: Vec<Node>,
}

async fn upsert_nodes(inventory: PathBuf) -> EKResult<Vec<i32>> {
    if !inventory.exists() {
        log::error!("inventory file not exists");
        return Err(EKError::NotFound("inventory file".to_string()));
    }

    let writer = StateWriterImpl::new();
    let contents = tokio::fs::read(inventory).await?;
    let inventory = serde_yaml::from_slice::<Inventory>(&contents).map_err(|e| {
        log::error!("failed to parse inventory file: {e}");
        EKError::InvalidInput("inventory file".to_string())
    })?;
    let mut node_ids = vec![];
    for node in inventory.nodes {
        let new_node = NewNode {
            hostname: node.id.clone(),
            device: node.device.clone(),
            config: serde_json::json!({
                "addr": node.address,
                "channel": node.channel,
            }),
        };
        let node = writer.node_upsert(new_node).await?;
        node_ids.push(node.id);
    }
    Ok(node_ids)
}

async fn execute_static_schedule(inventory: PathBuf) -> EKResult<()> {
    // implement the static scheduling logic here
    let settings = get_ek_settings();
    let model_name = settings.inference.model_name.clone();
    let instance_name = settings.inference.instance_name.clone();
    let ws_addr = settings.weight.server.as_ref().unwrap().addr.clone();
    info!(
        "Running static schedule for model: {model_name}, instance: {instance_name}, weight server: {ws_addr}"
    );
    let cli = WeightSrvClient::new(ws_addr);
    let vital = cli.load_meta_vital(&model_name).await?;
    info!("model info : {:?}", &vital);

    let reader = StateReaderImpl::new();
    let model = reader
        .model_by_name(&model_name)
        .await?
        .ok_or(EKError::NotFound("model not found".to_string()))?;

    let writer = StateWriterImpl::new();
    let node_ids = upsert_nodes(inventory).await?;

    let instance_obj = writer
        .instance_upsert(NewInstance {
            model_id: model.id,
            name: instance_name,
        })
        .await?;

    let mut experts = vec![];
    for layer in vital.moe_layers.0..vital.moe_layers.1 {
        for expert in 0..vital.routed_experts {
            experts.push(ExpertKey::new(model_name.clone(), layer, expert));
        }
    }
    log::info!("total experts to schedule {}", experts.len());

    let pb = ProgressBar::new(experts.len() as u64);

    writer.expert_del_by_instance(instance_obj.id).await?;

    let mut js = JoinSet::new();
    for e in experts {
        let e = e.clone();
        let node_ids = node_ids.clone();
        let p = pb.clone();
        js.spawn(async move {
            let writer = StateWriterImpl::new();
            let rand = random::<u16>();
            writer
                .expert_upsert(NewExpert {
                    instance_id: instance_obj.id,
                    node_id: node_ids[(rand % node_ids.len() as u16) as usize],
                    expert_id: e.as_object_key(),
                    replica: 1,
                    state: serde_json::json!({}),
                })
                .await
                .unwrap();
            p.inc(1);
        });
    }
    js.join_all().await;
    pb.finish();
    log::info!("all experts scheduled");

    Ok(())
}