import os
import signal
import subprocess
import sys
from argparse import Namespace
from typing import Dict, List

from mindspeed.auto_settings.config.system_config import get_system_config
from mindspeed.auto_settings.utils.logger import get_logger


class Runner:
    """Launch jobs through ``torchrun`` using the cluster layout from the system config.

    The distributed-launch settings (node count, processes per node, this
    node's rank, and the master address/port) are read once from
    ``get_system_config()`` at construction time. They are prepended as
    ``torchrun`` options to every job started by :meth:`run`.
    """

    # Seconds to wait for the job's process group to exit after SIGTERM
    # before escalating to SIGKILL during cleanup.
    _TERMINATE_TIMEOUT_S: float = 10.0

    def __init__(self):
        system_config = get_system_config()
        # run() passes each of these through str(), so only their string form
        # reaches torchrun. NOTE(review): annotations below mirror the original;
        # confirm the actual types exposed by system_config.
        self.nnodes: str = system_config.nnodes
        self.nproc_per_node: str = system_config.nproc_per_node
        self.node_rank: int = system_config.node_rank
        self.master_addr: str = system_config.master_addr
        self.master_port: int = system_config.master_port
        self._logger = get_logger("runner")

    def get_base_argv(self) -> List[str]:
        """Return a copy of this process's ``sys.argv``.

        A copy is returned so callers can modify it freely without affecting
        ``sys.argv`` itself.
        """
        return sys.argv.copy()

    def get_base_env(self) -> Dict[str, str]:
        """Return a copy of this process's environment variables.

        A copy is returned so callers can modify it freely without affecting
        ``os.environ`` itself.
        """
        return os.environ.copy()

    def run(
            self,
            modified_argv: List[str],
            modified_env: Dict[str, str]
    ) -> int:
        """Run one job under ``torchrun`` and block until it finishes.

        Args:
            modified_argv: Arguments appended after the torchrun options. When
                built from :meth:`get_base_argv`, the first element is the
                original ``sys.argv[0]``.
            modified_env: Environment for the child process. It is passed as
                ``env=`` to ``subprocess.Popen``, so it replaces (rather than
                extends) the inherited environment.

        Returns:
            The job's exit status as reported by ``Popen.wait()``. On POSIX, a
            negative value ``-N`` means the job was killed by signal ``N``.

        Raises:
            Any exception that interrupts the wait (e.g. ``KeyboardInterrupt``)
            is re-raised after the job's process group has been terminated.
        """
        cmd = [
                  "torchrun",
                  "--nnodes", str(self.nnodes),
                  "--nproc-per-node", str(self.nproc_per_node),
                  "--node-rank", str(self.node_rank),
                  "--master-addr", str(self.master_addr),
                  "--master-port", str(self.master_port)
              ] + modified_argv
        # Lazy %-args: the env dict can be large, so only render it when debug
        # logging is actually enabled. Output text matches the former f-string.
        self._logger.debug("Next job command: %s with env %s", cmd, modified_env)

        # The child calls setpgrp(), so it leads its own process group. Signals
        # sent to our group (e.g. Ctrl+C from the terminal) are not delivered
        # to it.
        process = subprocess.Popen(
            cmd,
            preexec_fn=os.setpgrp,
            env=modified_env
        )
        try:
            return_code = process.wait()
        except BaseException:
            # Because the job is outside our process group, it would outlive us
            # if we are interrupted or fail here. Tear it down before
            # propagating the error.
            self._terminate_process_group(process)
            raise
        self._logger.info("Last job returns %d.", return_code)

        return return_code

    def _terminate_process_group(self, process: subprocess.Popen) -> None:
        """Stop ``process`` and its process group: SIGTERM, then SIGKILL on timeout."""
        if process.poll() is not None:
            return
        # The child ran os.setpgrp(), so its process-group id equals its pid.
        try:
            os.killpg(process.pid, signal.SIGTERM)
        except ProcessLookupError:
            return
        try:
            process.wait(timeout=self._TERMINATE_TIMEOUT_S)
        except subprocess.TimeoutExpired:
            try:
                os.killpg(process.pid, signal.SIGKILL)
            except ProcessLookupError:
                pass
            process.wait()