# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#           http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sklearn
import argparse
import copy
import os
import random
import time
import npu_fused_adamw

try:
    import nvidia_dlprof_pytorch_nvtx as pyprof
except ModuleNotFoundError:
    import pyprof

import torch
import numpy as np
import torch.cuda.profiler as profiler
import torch.distributed as dist
from apex import amp
#from apex.parallel import DistributedDataParallel
from torch.nn.parallel import DistributedDataParallel
from common import helpers
# from common.dali.data_loader import DaliDataLoader
from common.dataset import AudioDataset, get_data_loader
from common.features import BaseFeatures, FilterbankFeatures
from common.helpers import (Checkpointer, greedy_wer, num_weights, print_once,
                            process_evaluation_epoch)
from common.optimizers import AdamW, lr_policy, Novograd
from common.tb_dllogger import flush_log, init_log, log
from common.utils import BenchmarkStats
from jasper import config
from jasper.model import CTCLossNM, GreedyCTCDecoder, Jasper


def parse_args():
    """Build the command-line interface and parse ``sys.argv``.

    The options are split into three argparse groups: training setup,
    optimization setup, and feature/checkpointing setup.

    Returns:
        argparse.Namespace: the parsed arguments. ``--model_config``,
        ``--train_manifests``, ``--val_manifests``, ``--dataset_dir`` and
        ``--output_dir`` are required; everything else has a default.
    """
    parser = argparse.ArgumentParser(description='Jasper')

    training = parser.add_argument_group('training setup')
    training.add_argument('--epochs', default=400, type=int,
                          help='Number of epochs for the entire training; influences the lr schedule')
    training.add_argument("--warmup_epochs", default=0, type=int,
                          help='Initial epochs of increasing learning rate')
    training.add_argument("--hold_epochs", default=0, type=int,
                          help='Constant max learning rate epochs after warmup')
    # The two literals are concatenated, so the first one must end with a
    # space to keep the sentences apart in the rendered help text.
    training.add_argument('--epochs_this_job', default=0, type=int,
                          help=('Run for a number of epochs with no effect on the lr schedule. '
                                'Useful for re-starting the training.'))
    # NOTE: `store_true` combined with `default=True` means the two flags
    # below are always True; they cannot be switched off from the CLI.
    training.add_argument('--cudnn_benchmark', action='store_true', default=True,
                          help='Enable cudnn benchmark')
    training.add_argument('--amp', '--fp16', action='store_true', default=True,
                          help='Use mixed precision training')
    training.add_argument('--seed', default=42, type=int, help='Random seed')
    # The default comes from the LOCAL_RANK environment variable when set;
    # argparse applies `type=int` to string defaults.
    training.add_argument('--local_rank', default=os.getenv('LOCAL_RANK', 0),
                          type=int, help='GPU id used for distributed training')
    training.add_argument('--pre_allocate_range', default=None, type=int, nargs=2,
                          help='Warmup with batches of length [min, max] before training')
    training.add_argument('--pyprof', action='store_true', help='Enable pyprof profiling')

    optim = parser.add_argument_group('optimization setup')
    optim.add_argument('--batch_size', default=32, type=int,
                       help='Global batch size')
    optim.add_argument('--lr', default=1e-3, type=float,
                       help='Peak learning rate')
    optim.add_argument("--min_lr", default=1e-5, type=float,
                       help='minimum learning rate')
    optim.add_argument("--lr_policy", default='exponential', type=str,
                       choices=['exponential', 'legacy'], help='lr scheduler')
    optim.add_argument("--lr_exp_gamma", default=0.99, type=float,
                       help='gamma factor for exponential lr scheduler')
    optim.add_argument('--weight_decay', default=1e-3, type=float,
                       help='Weight decay for the optimizer')
    optim.add_argument('--grad_accumulation_steps', default=1, type=int,
                       help='Number of accumulation steps')
    optim.add_argument('--optimizer', default='novograd', type=str,
                       choices=['novograd', 'adamw'], help='Optimization algorithm')
    optim.add_argument('--ema', type=float, default=0.0,
                       help='Discount factor for exp averaging of model weights')

    io = parser.add_argument_group('feature and checkpointing setup')
    io.add_argument('--dali_device', type=str, choices=['none', 'cpu', 'gpu'],
                    default='none', help='Use DALI pipeline for fast data processing')
    io.add_argument('--resume', action='store_true',
                    help='Try to resume from last saved checkpoint.')
    io.add_argument('--ckpt', default=None, type=str,
                    help='Path to a checkpoint for resuming training')
    io.add_argument('--save_frequency', default=10, type=int,
                    help='Checkpoint saving frequency in epochs')
    io.add_argument('--keep_milestones', default=[100, 200, 300], type=int, nargs='+',
                    help='Milestone checkpoints to keep from removing')
    io.add_argument('--save_best_from', default=380, type=int,
                    help='Epoch on which to begin tracking best checkpoint (dev WER)')
    io.add_argument('--eval_frequency', default=200, type=int,
                    help='Number of steps between evaluations on dev set')
    io.add_argument('--log_frequency', default=25, type=int,
                    help='Number of steps between printing training stats')
    io.add_argument('--prediction_frequency', default=100, type=int,
                    help='Number of steps between printing sample decodings')
    io.add_argument('--model_config', type=str, required=True,
                    help='Path of the model configuration file')
    io.add_argument('--train_manifests', type=str, required=True, nargs='+',
                    help='Paths of the training dataset manifest file')
    io.add_argument('--val_manifests', type=str, required=True, nargs='+',
                    help='Paths of the evaluation datasets manifest files')
    io.add_argument('--dataset_dir', required=True, type=str,
                    help='Root dir of dataset')
    io.add_argument('--output_dir', type=str, required=True,
                    help='Directory for logs and checkpoints')
    io.add_argument('--log_file', type=str, default=None,
                    help='Path to save the training logfile.')
    io.add_argument('--benchmark_epochs_num', type=int, default=1,
                    help='Number of epochs accounted in final average throughput.')
    io.add_argument('--override_config', type=str, action='append',
                    help='Overrides a value from a config .yaml.'
                         ' Syntax: `--override_config nested.config.key=val`.')
    return parser.parse_args()


def reduce_tensor(tensor, num_gpus):
    """Return the cross-process sum of `tensor` divided by `num_gpus`.

    The input is left untouched: the all-reduce runs on a clone. Requires an
    initialized `torch.distributed` process group.
    """
    summed = tensor.clone()
    dist.all_reduce(summed, op=dist.ReduceOp.SUM)
    averaged = torch.true_divide(summed, num_gpus)
    return averaged


def apply_ema(model, ema_model, decay):
    """Blend the weights of `model` into `ema_model` in place.

    Every entry of the EMA state dict becomes
    ``decay * ema + (1 - decay) * current``. A falsy `decay` (0, 0.0, None)
    turns the call into a no-op. If `model` has a `.module` attribute
    (e.g. a parallel wrapper), the wrapped module's state dict is used.
    """
    if decay:
        unwrapped = getattr(model, 'module', model)
        current = unwrapped.state_dict()
        for name, averaged in ema_model.state_dict().items():
            kept = decay * averaged
            fresh = (1 - decay) * current[name]
            averaged.copy_(kept + fresh)


@torch.no_grad()
def evaluate(epoch, step, val_loader, val_feat_proc, labels, model,
             ema_model, ctc_loss, greedy_decoder, use_amp, use_dali=False):
    """Run one pass over `val_loader` for `model` and, if given, `ema_model`.

    Each evaluated model is switched to eval mode, fed every batch, and put
    back into train mode afterwards. Loss and WER are logged under the
    subset name 'dev' for `model` and 'dev_ema' for `ema_model`.

    Args:
        epoch, step: current position in training, used only for logging.
        val_loader: iterable yielding 4-tuples; with `use_dali` they are
            unpacked as (feat, feat_lens, txt, txt_lens), otherwise as
            (audio, audio_lens, txt, txt_lens).
        val_feat_proc: feature processor called as
            ``val_feat_proc(x, x_lens, use_amp)``; may be None only when
            `use_dali` is true.
        labels: symbol table passed to the gather helpers.
        model: model to evaluate.
        ema_model: optional second model to evaluate; skipped when None.
        ctc_loss: loss callable.
        greedy_decoder: decoder callable applied to the log-probabilities.
        use_amp: forwarded to the feature processor.
        use_dali: selects how a batch is unpacked (see `val_loader`).

    Returns:
        The WER of the last model evaluated, i.e. of `ema_model` when it is
        not None, otherwise of `model`.
    """
    # A dedicated loop name keeps the `model` parameter from being shadowed.
    for eval_model, subset in [(model, 'dev'), (ema_model, 'dev_ema')]:
        if eval_model is None:
            continue

        eval_model.eval()
        start_time = time.time()
        agg = {'losses': [], 'preds': [], 'txts': []}

        for batch in val_loader:
            if use_dali:
                # with DALI, the data is already on GPU
                feat, feat_lens, txt, txt_lens = batch
                if val_feat_proc is not None:
                    feat, feat_lens = val_feat_proc(feat, feat_lens, use_amp)
            else:
                audio, audio_lens, txt, txt_lens = batch
                feat, feat_lens = val_feat_proc(audio, audio_lens, use_amp)

            # Only the tensors consumed below are moved to the NPU. The raw
            # audio is not transferred: it is never read after feature
            # extraction, and it does not exist at all on the DALI path.
            feat = feat.npu()
            feat_lens = feat_lens.npu()
            txt = txt.npu()
            log_probs, enc_lens = eval_model.forward(feat, feat_lens)
            loss = ctc_loss(log_probs, txt, enc_lens, txt_lens)
            pred = greedy_decoder(log_probs)

            agg['losses'] += helpers.gather_losses([loss])
            agg['preds'] += helpers.gather_predictions([pred], labels)
            agg['txts'] += helpers.gather_transcripts([txt], [txt_lens], labels)

        wer, loss = process_evaluation_epoch(agg)
        log((epoch,), step, subset, {'loss': loss, 'wer': 100.0 * wer,
                                     'took': time.time() - start_time})
        eval_model.train()
    return wer


def main():
    """Train a Jasper model on NPU devices; script entry point.

    Steps, in order: parse the CLI, initialise the `hccl` process group,
    seed the RNGs, build the data loaders and feature processors, create the
    model/loss/decoder/optimizer, optionally wrap with apex amp and
    DistributedDataParallel, optionally restore a checkpoint, optionally run
    a pre-allocation warm-up, then run the epoch/step training loop with
    periodic logging, evaluation and checkpointing. A final evaluation and
    checkpoint are written when the last epoch was reached.

    Raises:
        AssertionError: if no NPU is available or the frequency / batch-size
            arguments are inconsistent.
        ValueError: if `--optimizer` is not a supported value.
    """
    args = parse_args()

    assert torch.npu.is_available()
    assert args.prediction_frequency % args.log_frequency == 0

    # set up distributed training; this port always runs in distributed mode
    multi_gpu = True
    if multi_gpu:
        torch.npu.set_device(args.local_rank)
        dist.init_process_group(backend='hccl', init_method='env://')
        world_size = dist.get_world_size()
        print_once(f'Distributed training with {world_size} GPUs\n')
    else:
        world_size = 1

    # Offset the seed by the local rank so each process draws different
    # random numbers.
    torch.manual_seed(args.seed + args.local_rank)
    np.random.seed(args.seed + args.local_rank)
    random.seed(args.seed + args.local_rank)

    init_log(args)

    cfg = config.load(args.model_config)
    config.apply_config_overrides(cfg, args)

    symbols = helpers.add_ctc_blank(cfg['labels'])

    assert args.grad_accumulation_steps >= 1
    assert args.batch_size % args.grad_accumulation_steps == 0
    # Size of one micro-batch handed to the data loader.
    batch_size = args.batch_size // args.grad_accumulation_steps

    print_once('Setting up datasets...')
    train_dataset_kw, train_features_kw = config.input(cfg, 'train')
    val_dataset_kw, val_features_kw = config.input(cfg, 'val')

    # DALI is hard-disabled here regardless of `--dali_device`.
    # NOTE: the `DaliDataLoader` import at the top of the file is commented
    # out, so the branch below cannot run without restoring it.
    use_dali = False
    if use_dali:
        assert train_dataset_kw['ignore_offline_speed_perturbation'], \
            "DALI doesn't support offline speed perturbation"

        # pad_to_max_duration is not supported by DALI - have simple padders
        if train_features_kw['pad_to_max_duration']:
            train_feat_proc = BaseFeatures(
                pad_align=train_features_kw['pad_align'],
                pad_to_max_duration=True,
                max_duration=train_features_kw['max_duration'],
                sample_rate=train_features_kw['sample_rate'],
                window_size=train_features_kw['window_size'],
                window_stride=train_features_kw['window_stride'])
            train_features_kw['pad_to_max_duration'] = False
        else:
            train_feat_proc = None

        if val_features_kw['pad_to_max_duration']:
            val_feat_proc = BaseFeatures(
                pad_align=val_features_kw['pad_align'],
                pad_to_max_duration=True,
                max_duration=val_features_kw['max_duration'],
                sample_rate=val_features_kw['sample_rate'],
                window_size=val_features_kw['window_size'],
                window_stride=val_features_kw['window_stride'])
            val_features_kw['pad_to_max_duration'] = False
        else:
            val_feat_proc = None

        train_loader = DaliDataLoader(gpu_id=args.local_rank,
                                      dataset_path=args.dataset_dir,
                                      config_data=train_dataset_kw,
                                      config_features=train_features_kw,
                                      json_names=args.train_manifests,
                                      batch_size=batch_size,
                                      grad_accumulation_steps=args.grad_accumulation_steps,
                                      pipeline_type="train",
                                      device_type=args.dali_device,
                                      symbols=symbols)

        val_loader = DaliDataLoader(gpu_id=args.local_rank,
                                    dataset_path=args.dataset_dir,
                                    config_data=val_dataset_kw,
                                    config_features=val_features_kw,
                                    json_names=args.val_manifests,
                                    batch_size=batch_size,
                                    pipeline_type="val",
                                    device_type=args.dali_device,
                                    symbols=symbols)
    else:
        train_dataset_kw, train_features_kw = config.input(cfg, 'train')
        train_dataset = AudioDataset(args.dataset_dir,
                                     args.train_manifests,
                                     symbols,
                                     **train_dataset_kw)
        train_loader = get_data_loader(train_dataset,
                                       batch_size,
                                       multi_gpu=multi_gpu,
                                       shuffle=True,
                                       num_workers=4)
        train_feat_proc = FilterbankFeatures(**train_features_kw)

        val_dataset_kw, val_features_kw = config.input(cfg, 'val')
        val_dataset = AudioDataset(args.dataset_dir,
                                   args.val_manifests,
                                   symbols,
                                   **val_dataset_kw)
        val_loader = get_data_loader(val_dataset,
                                     batch_size,
                                     multi_gpu=multi_gpu,
                                     shuffle=False,
                                     num_workers=4,
                                     drop_last=False)
        val_feat_proc = FilterbankFeatures(**val_features_kw)

        dur = train_dataset.duration / 3600
        dur_f = train_dataset.duration_filtered / 3600
        nsampl = len(train_dataset)
        print_once(f'Training samples: {nsampl} ({dur:.1f}h, '
                   f'filtered {dur_f:.1f}h)')

    # Feature extraction stays on the CPU; its outputs are moved to the NPU
    # per batch. The processors are None on the DALI path when no padding is
    # needed, hence the guards.
    if train_feat_proc is not None:
        train_feat_proc.cpu()
    if val_feat_proc is not None:
        val_feat_proc.cpu()
    steps_per_epoch = len(train_loader) // args.grad_accumulation_steps

    # set up the model
    model = Jasper(encoder_kw=config.encoder(cfg),
                   decoder_kw=config.decoder(cfg, n_classes=len(symbols)))
    model.npu()
    ctc_loss = CTCLossNM(n_classes=len(symbols))
    greedy_decoder = GreedyCTCDecoder()

    print_once(f'Model size: {num_weights(model) / 10**6:.1f}M params\n')

    # optimization
    kw = {'lr': args.lr, 'weight_decay': args.weight_decay}
    if args.optimizer == "novograd":
        optimizer = Novograd(model.parameters(), **kw)
    elif args.optimizer == "adamw":
        # NOTE(review): `npu_fused_adamw` is bound by `import npu_fused_adamw`
        # at the top of the file, i.e. it is a module; calling it looks like
        # a TypeError. Confirm which optimizer class it should expose.
        optimizer = npu_fused_adamw(model.parameters(), **kw)
    else:
        raise ValueError(f'Invalid optimizer "{args.optimizer}"')

    def adjust_lr(step, epoch, optimizer):
        # Apply the configured learning-rate policy for this step.
        return lr_policy(
            step, epoch, args.lr, optimizer, steps_per_epoch=steps_per_epoch,
            warmup_epochs=args.warmup_epochs, hold_epochs=args.hold_epochs,
            num_epochs=args.epochs, policy=args.lr_policy, min_lr=args.min_lr,
            exp_gamma=args.lr_exp_gamma)

    if args.amp:
        model, optimizer = amp.initialize(models=model, optimizers=optimizer,
                                          loss_scale=32, combine_grad=True)

    if args.ema > 0:
        ema_model = copy.deepcopy(model)
    else:
        ema_model = None

    if multi_gpu:
        model = DistributedDataParallel(model, device_ids=[args.local_rank],
                                        broadcast_buffers=False)

    if args.pyprof:
        pyprof.init(enable_function_stack=True)

    # load checkpoint
    meta = {'best_wer': 10**6, 'start_epoch': 0}
    checkpointer = Checkpointer(args.output_dir, 'Jasper',
                                args.keep_milestones, args.amp)
    if args.resume:
        args.ckpt = checkpointer.last_checkpoint() or args.ckpt

    if args.ckpt is not None:
        checkpointer.load(args.ckpt, model, ema_model, optimizer, meta)

    start_epoch = meta['start_epoch']
    best_wer = meta['best_wer']
    # `epoch` is pre-bound so the code after the loop works even when the
    # epoch loop body never runs.
    epoch = 1
    step = start_epoch * steps_per_epoch + 1

    # Keep the context-manager instance so that the same object that was
    # entered is the one exited after training.
    nvtx_ctx = None
    if args.pyprof:
        nvtx_ctx = torch.autograd.profiler.emit_nvtx()
        nvtx_ctx.__enter__()
        profiler.start()

    # training loop
    model.train()

    # pre-allocate
    if args.pre_allocate_range is not None:
        n_feats = train_features_kw['n_filt']
        pad_align = train_features_kw['pad_align']
        a, b = args.pre_allocate_range
        for n_frames in range(a, b + pad_align, pad_align):
            print_once(f'Pre-allocation ({batch_size}x{n_feats}x{n_frames})...')

            # Mirror the device placement of the training loop: features,
            # feature lengths and transcripts on the NPU, transcript lengths
            # on the CPU.
            feat = torch.randn(batch_size, n_feats, n_frames, device='cpu').npu()
            feat_lens = torch.ones(batch_size, device='cpu').fill_(n_frames).npu()
            txt = torch.randint(high=len(symbols)-1, size=(batch_size, 100),
                                device='cpu').npu()
            txt_lens = torch.ones(batch_size, device='cpu').fill_(100)
            log_probs, enc_lens = model(feat, feat_lens)
            del feat
            loss = ctc_loss(log_probs, txt, enc_lens, txt_lens)
            loss.backward()
            model.zero_grad()

    bmark_stats = BenchmarkStats()

    for epoch in range(start_epoch + 1, args.epochs + 1):
        if multi_gpu and not use_dali:
            train_loader.sampler.set_epoch(epoch)

        epoch_utts = 0
        epoch_loss = 0
        accumulated_batches = 0
        epoch_start_time = time.time()

        for batch in train_loader:

            # First micro-batch of an optimizer step: update the learning
            # rate and reset the per-step accumulators.
            if accumulated_batches == 0:
                adjust_lr(step, epoch, optimizer)
                optimizer.zero_grad()
                step_loss = 0
                step_utts = 0
                step_start_time = time.time()

            if use_dali:
                # with DALI, the data is already on GPU
                feat, feat_lens, txt, txt_lens = batch
                if train_feat_proc is not None:
                    feat, feat_lens = train_feat_proc(feat, feat_lens, args.amp)
            else:
                audio, audio_lens, txt, txt_lens = batch
                feat, feat_lens = train_feat_proc(audio, audio_lens, args.amp)
                # Only the tensors consumed by the model and the loss are
                # moved; the raw audio is not read again after this point.
                feat = feat.npu()
                feat_lens = feat_lens.npu()
                txt = txt.npu()
            log_probs, enc_lens = model(feat, feat_lens)

            loss = ctc_loss(log_probs, txt, enc_lens, txt_lens)
            loss /= args.grad_accumulation_steps

            if torch.isnan(loss).any():
                # NOTE(review): `accumulated_batches` is not advanced here, so
                # when it is still 0 the `% grad_accumulation_steps == 0`
                # check below passes and `optimizer.step()` still runs.
                # Confirm whether that is intended.
                print_once('WARNING: loss is NaN; skipping update')
            else:
                if multi_gpu:
                    step_loss += reduce_tensor(loss.data, world_size).item()
                else:
                    step_loss += loss.item()

                if args.amp:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()
                step_utts += batch[0].size(0) * world_size
                epoch_utts += batch[0].size(0) * world_size
                accumulated_batches += 1

            if accumulated_batches % args.grad_accumulation_steps == 0:
                epoch_loss += step_loss
                optimizer.step()
                apply_ema(model, ema_model, args.ema)

                if step % args.log_frequency == 0:
                    preds = greedy_decoder(log_probs)
                    wer, pred_utt, ref = greedy_wer(preds, txt, txt_lens, symbols)

                    if step % args.prediction_frequency == 0:
                        print_once(f'  Decoded:   {pred_utt[:90]}')
                        print_once(f'  Reference: {ref[:90]}')

                    step_time = time.time() - step_start_time
                    log((epoch, step % steps_per_epoch or steps_per_epoch, steps_per_epoch),
                        step, 'train',
                        {'loss': step_loss,
                         'wer': 100.0 * wer,
                         'throughput': step_utts / step_time,
                         'took': step_time,
                         'lrate': optimizer.param_groups[0]['lr']})

                step_start_time = time.time()

                if step % args.eval_frequency == 0:
                    wer = evaluate(epoch, step, val_loader, val_feat_proc,
                                   symbols, model, ema_model, ctc_loss,
                                   greedy_decoder, args.amp, use_dali)

                    if wer < best_wer and epoch >= args.save_best_from:
                        # Update before saving so the checkpoint records the
                        # WER it actually achieved, not the previous best.
                        best_wer = wer
                        checkpointer.save(model, ema_model, optimizer, epoch,
                                          step, best_wer, is_best=True)

                step += 1
                accumulated_batches = 0
                # end of step

            # DALI iterator need to be exhausted;
            # if not using DALI, simulate drop_last=True with grad accumulation
            if not use_dali and step > steps_per_epoch * epoch:
                break

        epoch_time = time.time() - epoch_start_time
        epoch_loss /= steps_per_epoch
        log((epoch,), None, 'train_avg', {'throughput': epoch_utts / epoch_time,
                                          'took': epoch_time,
                                          'loss': epoch_loss})
        bmark_stats.update(epoch_utts, epoch_time, epoch_loss)

        if epoch % args.save_frequency == 0 or epoch in args.keep_milestones:
            checkpointer.save(model, ema_model, optimizer, epoch, step, best_wer)

        if 0 < args.epochs_this_job <= epoch - start_epoch:
            print_once(f'Finished after {args.epochs_this_job} epochs.')
            break
        # end of epoch

    if args.pyprof:
        profiler.stop()
        nvtx_ctx.__exit__(None, None, None)

    log((), None, 'train_avg', bmark_stats.get(args.benchmark_epochs_num))

    # Final evaluation and checkpoint only when the full schedule finished.
    if epoch == args.epochs:
        evaluate(epoch, step, val_loader, val_feat_proc, symbols, model,
                 ema_model, ctc_loss, greedy_decoder, args.amp, use_dali)

        checkpointer.save(model, ema_model, optimizer, epoch, step, best_wer)
    flush_log()


# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()