# Copyright © 2022 BAAI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License")
import sys
import os
import torch
import copy

# Command-line interface: <script> <checkpoint_dir> <target_mp>
#   checkpoint_dir: directory holding the source checkpoint.
#   target_mp:      desired number of model-parallel shards.
if len(sys.argv) < 3:
    sys.exit("Usage: {} <checkpoint_dir> <target_mp>".format(sys.argv[0]))

checkpoint = sys.argv[1]
target_mp = int(sys.argv[2])
# A zero or negative target would later cause a division by zero or make the
# conversion loops run zero times without reporting anything.
if target_mp < 1:
    sys.exit("target_mp must be a positive integer, got {}".format(target_mp))

# Explicit check instead of `assert`, which is stripped under `python -O`.
if not os.path.isdir(checkpoint):
    sys.exit("Checkpoint directory not found: {}".format(checkpoint))

# When 'latest_checkpointed_iteration.txt' exists, it names the sub-directory
# (the iteration number) that holds the shard files; otherwise the shards are
# read directly from `checkpoint` and `iteration` stays None.
iteration_file = os.path.join(checkpoint, 'latest_checkpointed_iteration.txt')
if os.path.exists(iteration_file):
    with open(iteration_file) as fin:
        iteration = int(fin.read().strip())
    checkpoint = os.path.join(checkpoint, str(iteration))
else:
    iteration = None

# Collect the shard files ("mp_rank_<N>_...") of the source checkpoint and
# order them by the integer rank N (third '_'-separated field), so that list
# position equals model-parallel rank.
filenames = [
    filename for filename in os.listdir(checkpoint)
    if filename.startswith("mp_rank_")
]
filenames.sort(key=lambda x: int(x.split('_')[2]))
filenames = [os.path.join(checkpoint, x) for x in filenames]

# Without any shard there is nothing to convert; failing here avoids a
# ZeroDivisionError further down (`target_mp % len(filenames)`).
if not filenames:
    sys.exit("No 'mp_rank_*' files found in {}".format(checkpoint))

if target_mp == len(filenames):
    print("MP size keeps the same.")
    # sys.exit rather than the builtin exit(), which is only injected by the
    # `site` module and is absent under `python -S` or in frozen builds.
    sys.exit(0)

# The converted checkpoint is written next to the source one, in a directory
# named "<source>_MP<target_mp>". All trailing path separators are stripped
# first: stripping only one would turn "ckpt//" into "ckpt/_MP2", i.e. a
# directory *inside* the source checkpoint.
new_checkpoint = sys.argv[1].rstrip('/' + os.sep) + '_MP' + sys.argv[2]
# exist_ok avoids the check-then-create race of exists() followed by mkdir().
os.makedirs(new_checkpoint, exist_ok=True)
if iteration is not None:
    # Mirror the source layout: an iteration marker file plus a sub-directory
    # named after the iteration that will hold the converted shards.
    with open(
            os.path.join(new_checkpoint, 'latest_checkpointed_iteration.txt'),
            'w') as fout:
        fout.write("{}\n".format(iteration))
    new_checkpoint = os.path.join(new_checkpoint, str(iteration))
    os.makedirs(new_checkpoint, exist_ok=True)

# Top-level checkpoint keys (other than 'module') whose values are carried
# over unchanged into every converted shard. Any other non-'module' key is
# reset to None, except "mp_world_size", which is overwritten with the target
# model-parallel size.
preserve_keys = [
    "lr_scheduler",
    "skipped_steps",
    "global_steps",
    "global_samples",
    "dp_world_size",
    "iteration",
    "client_lr_scheduler",
    "np_rng_state",
    "random_rng_state",
    "torch_rng_state",
    "cuda_rng_state",
    "rng_tracker_states",
]

def _merge_qkv(target, v):
    """Merge one more shard into a tensor made of three stacked chunks.

    Both `target` and `v` are treated as three equal-sized chunks stacked
    along dim 0 (presumably query, key and value, going by the parameter
    names this is applied to). The result keeps that three-chunk layout, each
    chunk being the matching chunk of `target` followed by the one of `v`.
    Only dim 0 is sliced, so this works for 1-D and 2-D tensors alike.
    """
    size_1 = target.shape[0] // 3
    size_2 = v.shape[0] // 3
    return torch.cat([
        target[:size_1], v[:size_2],
        target[size_1:size_1 * 2], v[size_2:size_2 * 2],
        target[size_1 * 2:], v[size_2 * 2:],
    ], 0)


if target_mp < len(filenames):
    # Merge every group of `ratio` consecutive source shards into one shard.
    print("Decrease MP size.")
    # Explicit check instead of `assert`: with asserts stripped (`python -O`)
    # a non-divisible shard count would floor `ratio` and silently drop the
    # trailing source shards.
    if len(filenames) % target_mp != 0:
        sys.exit(
            "Cannot merge {} shards into {}: the source MP size must be a "
            "multiple of the target MP size.".format(
                len(filenames), target_mp))
    ratio = len(filenames) // target_mp
    for i in range(target_mp):
        start = ratio * i
        end = ratio * (i + 1)
        # The first shard of the group is the base that the others are merged
        # into; its preserved metadata is what ends up in the output.
        # NOTE: torch.load unpickles the file - only use trusted checkpoints.
        d = torch.load(filenames[start], map_location='cpu')
        for k in d:
            if k == 'module' or k in preserve_keys:
                continue
            d[k] = target_mp if k == "mp_world_size" else None
        merged = d['module']
        for j in range(start + 1, end):
            d_new = torch.load(filenames[j], map_location='cpu')
            for k, v in d_new['module'].items():
                ndim = len(v.shape)
                if ndim >= 3:
                    raise ValueError(
                        "Unsupported {}-D tensor for key '{}' in {}".format(
                            ndim, k, filenames[j]))
                if ndim == 2 and 'position' not in k:
                    if 'query' in k:
                        merged[k] = _merge_qkv(merged[k], v)
                    elif ('word' in k or 'h_to_4h' in k or 'relative' in k
                          or "r_w_bias" in k or "r_r_bias" in k):
                        # These matrices are partitioned along dim 0.
                        merged[k] = torch.cat([merged[k], v], 0)
                    else:
                        # Every other matrix is partitioned along dim 1.
                        merged[k] = torch.cat([merged[k], v], 1)
                elif ndim == 1 and 'query_key_value' in k:
                    merged[k] = _merge_qkv(merged[k], v)
                elif ndim == 1 and ('dense_h_to_4h' in k
                                    or "attention.relative" in k):
                    merged[k] = torch.cat([merged[k], v], 0)
                # Anything else (2-D keys containing 'position', remaining
                # 1-D and 0-D tensors) keeps the value of the group's first
                # shard.
        filename = os.path.join(new_checkpoint,
                                "mp_rank_{:02d}_model_states.pt".format(i))
        torch.save(d, filename)

def _split_qkv(v, shift, ratio):
    """Return slice number `shift` (of `ratio`) of a three-chunk tensor.

    `v` is treated as three equal-sized chunks stacked along dim 0
    (presumably query, key and value, going by the parameter names this is
    applied to). Each chunk is cut into `ratio` pieces along dim 0 and piece
    `shift` of each chunk is concatenated, so the result has the same
    three-chunk layout. Only dim 0 is sliced, so this works for 1-D and 2-D
    tensors alike.

    NOTE(review): the floor division silently drops trailing rows when
    `v.shape[0]` is not a multiple of `3 * ratio` - confirm that the
    checkpoints this is used on guarantee divisibility.
    """
    part = v.shape[0] // ratio // 3
    # torch.cat allocates a new tensor, so the slices need no explicit clone.
    return torch.cat([
        v[shift * part:(shift + 1) * part],
        v[(shift + ratio) * part:(shift + 1 + ratio) * part],
        v[(shift + 2 * ratio) * part:(shift + 1 + 2 * ratio) * part],
    ], 0)


if target_mp > len(filenames):
    # Split every source shard into `ratio` consecutive target shards.
    print("Increase MP size.")
    # Explicit check instead of `assert`: with asserts stripped (`python -O`)
    # a non-divisible target would floor `ratio` and write fewer shards than
    # the "mp_world_size" recorded in them.
    if target_mp % len(filenames) != 0:
        sys.exit(
            "Cannot split {} shards into {}: the target MP size must be a "
            "multiple of the source MP size.".format(
                len(filenames), target_mp))
    ratio = target_mp // len(filenames)
    for i, source_file in enumerate(filenames):
        start = ratio * i
        end = ratio * (i + 1)
        # NOTE: torch.load unpickles the file - only use trusted checkpoints.
        d = torch.load(source_file, map_location='cpu')
        for j in range(start, end):
            d_new = {}
            # Index of this target shard within the current source shard.
            shift = j - start
            for k in d:
                if k == 'module':
                    continue
                if k in preserve_keys:
                    # Deep copy so the output shards do not share objects.
                    d_new[k] = copy.deepcopy(d[k])
                elif k == "mp_world_size":
                    d_new[k] = target_mp
                else:
                    d_new[k] = None
            d_new['module'] = {}
            with torch.no_grad():
                for k, v in d['module'].items():
                    ndim = len(v.shape)
                    if ndim >= 3:
                        raise ValueError(
                            "Unsupported {}-D tensor for key '{}' in {}".format(
                                ndim, k, source_file))
                    # Plain slices are cloned so that each one owns its
                    # storage; a view would make torch.save serialize the
                    # whole parent tensor.
                    if ndim == 2 and 'position' not in k:
                        if 'query' in k:
                            d_new['module'][k] = _split_qkv(v, shift, ratio)
                        elif ('word' in k or 'h_to_4h' in k or 'relative' in k
                              or "r_w_bias" in k or "r_r_bias" in k):
                            # These matrices are partitioned along dim 0.
                            part = v.shape[0] // ratio
                            d_new['module'][k] = v[shift * part:
                                                   (shift + 1) * part, :].clone()
                        else:
                            # Every other matrix is partitioned along dim 1.
                            part = v.shape[1] // ratio
                            d_new['module'][k] = v[:, shift * part:
                                                   (shift + 1) * part].clone()
                    elif ndim == 1 and ('dense_h_to_4h' in k
                                        or "attention.relative" in k):
                        part = v.shape[0] // ratio
                        d_new['module'][k] = v[shift * part:
                                               (shift + 1) * part].clone()
                    elif ndim == 1 and 'query_key_value' in k:
                        d_new['module'][k] = _split_qkv(v, shift, ratio)
                    else:
                        # 2-D keys containing 'position' and all remaining
                        # 1-D / 0-D tensors are copied whole to every shard.
                        d_new['module'][k] = v.clone()
            filename = os.path.join(new_checkpoint,
                                    "mp_rank_{:02d}_model_states.pt".format(j))
            torch.save(d_new, filename)