# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# encoding: utf-8
"""
@author:  sherlock
@contact: sherlockliao01@gmail.com
"""

import argparse
import os
import sys
import torch
if torch.__version__ >= '1.8':
    import torch_npu
from torch.backends import cudnn
import torch.distributed as dist#lmm
import torch.multiprocessing as mp

sys.path.append('.')
from config import cfg
from data import make_data_loader
from engine.trainer import do_train, do_train_with_center
from modeling import build_model
from layers import make_loss, make_loss_with_center
from solver import make_optimizer, make_optimizer_with_center, WarmupMultiStepLR

from utils.logger import setup_logger
import torch.distributed as dist   #add by lmm
from apex import amp

def train(rank, cfg, args):
    """Run one distributed ReID training process on an NPU.

    Joins the HCCL process group, binds the process to ``npu:<rank>``, builds
    the data loaders and the model, restores the checkpoint found at
    ``cfg.TEST.WEIGHT`` (classifier entries excluded) and, when
    ``cfg.MODEL.IF_WITH_CENTER == 'yes'``, hands control to
    ``do_train_with_center``.

    Args:
        rank: Rank of this process in the group; also used as the NPU index.
        cfg: Project configuration node. ``TEST.WEIGHT``, ``MODEL.DEVICE``,
            ``MODEL.IF_WITH_CENTER``, ``MODEL.METRIC_LOSS_TYPE`` and the
            ``SOLVER.*`` warm-up/step settings are read here.
        args: Parsed command-line namespace; ``world_size`` and
            ``loss_scale`` are read here.
    """
    # Local import: the file-level imports do not provide OrderedDict.
    from collections import OrderedDict

    dist.init_process_group(
        backend='hccl',
        world_size=args.world_size,
        rank=rank,
    )
    torch.npu.set_device('npu:{}'.format(rank))

    # prepare dataset
    train_loader, val_loader, num_query, num_classes = make_data_loader(cfg)

    # prepare model
    model = build_model(cfg, num_classes)

    # Deserialize onto the CPU so loading does not depend on the device the
    # checkpoint was saved from; the model is moved to the NPU afterwards.
    state_dict = torch.load(cfg.TEST.WEIGHT, map_location='cpu')
    parallel_prefix = 'module.'
    new_state_dict = OrderedDict()
    for key, value in state_dict.items():
        # The classifier head is skipped and therefore keeps its fresh
        # initialization.
        if "classifier" in key:
            continue
        # Drop the wrapper prefix only when it is actually present, so keys
        # saved from an unwrapped model are not truncated.
        if key.startswith(parallel_prefix):
            key = key[len(parallel_prefix):]
        new_state_dict[key] = value
    # state_dict() returns a fresh dict, so updating it never touches the
    # model; load_state_dict copies the tensors into the parameters.
    # strict=False tolerates the classifier keys removed above.
    model.load_state_dict(new_state_dict, strict=False)

    if "npu" in cfg.MODEL.DEVICE:      # device
        model = model.npu()

    # NOTE(review): only the center-loss path is implemented here; with any
    # other IF_WITH_CENTER value the function returns without training.
    if cfg.MODEL.IF_WITH_CENTER == 'yes':
        print('Train with center loss, the loss type is', cfg.MODEL.METRIC_LOSS_TYPE)
        loss_func, center_criterion = make_loss_with_center(cfg, num_classes)
        optimizer, optimizer_center = make_optimizer_with_center(cfg, model, center_criterion)

        # NOTE(review): the trainer has always been handed rank 0 here
        # regardless of the process rank; kept as-is — confirm whether the
        # real rank should be forwarded instead.
        trainer_rank = 0

        start_epoch = 0
        scheduler = WarmupMultiStepLR(
            optimizer,
            cfg.SOLVER.STEPS,
            cfg.SOLVER.GAMMA,
            cfg.SOLVER.WARMUP_FACTOR,
            cfg.SOLVER.WARMUP_ITERS,
            cfg.SOLVER.WARMUP_METHOD,
        )

        if "npu" in cfg.MODEL.DEVICE:
            # Mixed precision (O2) for the model and both optimizers.
            model, [optimizer, optimizer_center] = amp.initialize(
                model,
                [optimizer, optimizer_center],
                opt_level="O2",
                loss_scale=args.loss_scale,
            )

        do_train_with_center(
            cfg,
            model,
            center_criterion,
            train_loader,
            val_loader,
            optimizer,
            optimizer_center,
            scheduler,
            loss_func,
            num_query,
            start_epoch,
            args.world_size,   # device count; the old `num_npus` name was never defined
            trainer_rank,
        )


def main():
    """Parse the command line, finalize the config and start training.

    Reads ``--config_file``, ``--loss_scale``, ``--npus``, ``--local_rank``
    and trailing config overrides, sets the rendezvous address/port
    environment variables, merges and freezes the global ``cfg``, creates
    the output directory, logs the effective configuration and then calls
    ``train`` with ``args.local_rank`` as the rank.
    """
    parser = argparse.ArgumentParser(description="ReID Baseline Training")
    parser.add_argument(
        "--config_file", default="", help="path to config file", type=str
    )
    parser.add_argument('--loss_scale', default="dynamic", type=str)
    parser.add_argument('-g', '--npus', default=1, type=int,
                        help='number of gpus per node')
    parser.add_argument('-r', '--local_rank', default=0, type=int,
                        help='ranking within the npus')
    parser.add_argument("opts", help="Modify config options using the command-line", default=None,
                        nargs=argparse.REMAINDER)

    args = parser.parse_args()

    # One process per NPU, so the world size equals the NPU count.
    args.world_size = args.npus
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '22222'

    if args.config_file != "":
        cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    output_dir = cfg.OUTPUT_DIR
    if output_dir:
        # exist_ok avoids the check-then-create race when several ranks are
        # started as separate processes and reach this point together.
        os.makedirs(output_dir, exist_ok=True)

    logger = setup_logger("reid_baseline", output_dir, 0)
    logger.info("Using {} NPUS".format(args.world_size))
    logger.info(args)

    if args.config_file != "":
        logger.info("Loaded configuration file {}".format(args.config_file))
        with open(args.config_file, 'r') as cf:
            config_str = "\n" + cf.read()
            logger.info(config_str)
    logger.info("Running with config:\n{}".format(cfg))

    train(args.local_rank, cfg, args)


# Start training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()