#!/usr/bin/env python3



# Copyright (c) Facebook, Inc. and its affiliates.

#

# This source code is licensed under the MIT license found in the

# LICENSE file in the root directory of this source tree.



# Copyright 2020 Huawei Technologies Co., Ltd

#

# Licensed under the BSD 3-Clause License (the "License");

# you may not use this file except in compliance with the License.

# You may obtain a copy of the License at

#

# https://spdx.org/licenses/BSD-3-Clause.html

#

# Unless required by applicable law or agreed to in writing, software

# distributed under the License is distributed on an "AS IS" BASIS,

# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

# See the License for the specific language governing permissions and

# limitations under the License.







"""Train a classification model."""

import sys

sys.path.append("./")

import pycls.core.config as config

import pycls.core.distributed as dist

import pycls.core.trainer as trainer

from pycls.core.config import cfg

import argparse,sys,os,torch

import torch

if torch.__version__ >= '1.8':

    import torch_npu



def init_process_group(proc_rank, world_size, device_type="npu", port="29588"):
    """Initialize the default process group and bind this process to a device.

    The rendezvous address is fixed to ``127.0.0.1``, so this only supports
    processes on a single host. The backend comes from ``cfg.DIST_BACKEND``.

    Args:
        proc_rank: Rank of this process. It is also used as the device index
            passed to ``torch.npu.set_device`` / ``torch.cuda.set_device``.
        world_size: Total number of processes in the group.
        device_type: ``"npu"`` or ``"gpu"``.
        port: Rendezvous port, as a str or an int.

    Raises:
        ValueError: If ``device_type`` is neither ``"npu"`` nor ``"gpu"``.
    """
    # Fail fast. Previously an unknown device type silently skipped both
    # process-group initialization and device selection while still
    # reporting success.
    if device_type not in ("npu", "gpu"):
        raise ValueError(
            "Unsupported device_type {!r}; expected 'npu' or 'gpu'".format(device_type)
        )
    master_addr = '127.0.0.1'
    # os.environ only accepts str values; this also allows callers to pass an int port.
    port = str(port)

    # Initialize the process group
    print("==================================")
    print('Begin init_process_group')
    os.environ['MASTER_ADDR'] = master_addr
    os.environ['MASTER_PORT'] = port
    if device_type == "npu":
        # No init_method is given, so torch.distributed falls back to "env://",
        # which reads the MASTER_ADDR / MASTER_PORT variables set above.
        torch.distributed.init_process_group(
            backend=cfg.DIST_BACKEND,
            world_size=world_size,
            rank=proc_rank
        )
    else:
        torch.distributed.init_process_group(
            backend=cfg.DIST_BACKEND,
            init_method="tcp://{}:{}".format(master_addr, port),
            world_size=world_size,
            rank=proc_rank
        )

    print("==================================")
    print("Done init_process_group")

    # Bind this process to the device whose index equals its rank.
    if device_type == "npu":
        torch.npu.set_device(proc_rank)
    else:
        torch.cuda.set_device(proc_rank)
    print('Done set device', device_type, cfg.DIST_BACKEND, world_size, proc_rank)



def main():
    """Load config from command line arguments and set any specified options.

    Then select the compute device and start training.

    Command-line interface:
        --cfg        Path to the config file (required).
        --device     ``npu`` (default) or ``gpu``. Any other value is rejected
                     by argparse.
        --profperf   Integer forwarded to ``trainer.train_model`` (help text
                     documents it as 0 or 1).
        --rank_id    Process rank; used only when ``cfg.NUM_GPUS > 1``.
        --device_id  Device index; used only when ``cfg.NUM_GPUS <= 1``.
        opts         Trailing config overrides, passed to
                     ``config._C.merge_from_list``.

    Exits with status 1 after printing help when run with no arguments.
    """
    parser = argparse.ArgumentParser(description="Config file options.")
    # Restrict --device to supported backends. Previously any other value left
    # ``cur_device`` unbound and crashed only after the config had been loaded.
    parser.add_argument("--device", help="gpu or npu", default="npu", type=str,
                        choices=["gpu", "npu"])
    parser.add_argument("--profperf", help="0 or 1", default=0, type=int)
    help_s = "Config file location"
    parser.add_argument("--cfg", dest="cfg_file", help=help_s, required=True, type=str)
    parser.add_argument("--rank_id", dest="rank_id", default=0, type=int)
    parser.add_argument("--device_id", dest="device_id", default=0, type=int)
    help_s = "See pycls/core/config.py for all options"
    parser.add_argument("opts", help=help_s, default=None, nargs=argparse.REMAINDER)
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)
    args = parser.parse_args()
    print(args)

    # Build the (then frozen) global config: file first, then CLI overrides.
    config.merge_from_file(args.cfg_file)
    config._C.merge_from_list(args.opts)
    config.assert_and_infer_cfg()
    cfg.freeze()

    if cfg.NUM_GPUS > 1:
        # Multi-process run: init_process_group binds this process to device
        # index ``rank_id``, so ``--device_id`` is not used here.
        init_process_group(proc_rank=args.rank_id, world_size=cfg.NUM_GPUS, device_type=args.device)
    elif args.device == "npu":
        torch.npu.set_device(args.device_id)
    else:
        torch.cuda.set_device(args.device_id)

    # argparse guarantees args.device is "npu" or "gpu", so cur_device is always bound.
    if args.device == "npu":
        cur_device = torch.npu.current_device()
    else:
        cur_device = torch.cuda.current_device()
    print('cur_device: ', cur_device)

    trainer.train_model(args.device, args.profperf)





if __name__ == "__main__":
    # Script entry point: run training only when executed directly, not on import.
    main()