"""Utilities for logging and serialization"""
import os
import random
import time
import numpy as np
import torch
import json
import subprocess
from fp16 import FP16_Optimizer
import mpu
from tensorboardX import SummaryWriter
# Name of the sub-directory, placed under the base directory by get_log_dir,
# that holds the per-run summary folders.
SUMMARY_WRITER_DIR_NAME = 'runs'
def get_log_dir(name, base):
    """Build the summary log directory path for a run.

    Args:
        name: Run name; becomes the innermost path component.
        base: Root directory under which the summary folder lives.

    Returns:
        ``base`` joined with ``SUMMARY_WRITER_DIR_NAME`` and then ``name``.
    """
    summary_root = os.path.join(base, SUMMARY_WRITER_DIR_NAME)
    return os.path.join(summary_root, name)
def get_sample_writer(log_dir, iteration=0):
    """Create a tensorboard summary writer.

    Args:
        log_dir: Directory handed to ``SummaryWriter`` as ``log_dir``.
        iteration: Value handed to ``SummaryWriter`` as ``purge_step``.

    Returns:
        The newly constructed ``SummaryWriter``.
    """
    writer = SummaryWriter(log_dir=log_dir, purge_step=iteration)
    return writer
def print_rank_0(message):
    """Print ``message`` once per job.

    When ``torch.distributed`` is initialized only global rank 0 prints;
    otherwise the message is always printed. Output is flushed immediately.
    """
    dist = torch.distributed
    # Every rank other than 0 stays silent once the process group exists.
    if dist.is_initialized() and dist.get_rank() != 0:
        return
    print(message, flush=True)
def get_hostname():
    """Return the first address printed by ``hostname -I``.

    Despite the name, the value returned is the first whitespace-separated
    token of the ``hostname -I`` output, used by callers as the master
    address.

    Returns:
        The first reported address as a ``str``.

    Raises:
        subprocess.CalledProcessError: If ``hostname -I`` exits non-zero.
        IndexError: If the command prints no addresses.
    """
    # Pass an argument vector and skip the shell: the command is fixed, so
    # there is nothing for a shell to expand, and a list combined with
    # shell=True only works by accident of how POSIX shells take arguments.
    output = subprocess.check_output(["hostname", "-I"])
    addresses = output.decode('utf-8').split()
    return addresses[0]
def get_spare_port(args):
    """Pick a random port on rank 0 and broadcast it to all ranks.

    The port is drawn uniformly from 10000-65535 and is guaranteed to differ
    from ``args.master_port``. It is NOT probed for availability.

    Args:
        args: Object exposing ``master_port``, the port to avoid.

    Returns:
        The chosen port as a Python ``int``, identical on every rank.

    Note:
        Must be called by every rank of an initialized process group, since
        it ends in a collective broadcast of a CUDA tensor.
    """
    if torch.distributed.get_rank() == 0:
        # SystemRandom draws from the OS entropy source, so it neither
        # depends on nor disturbs the seeded global ``random`` state, and it
        # removes the dependency on the external ``shuf`` binary.
        rng = random.SystemRandom()
        port = rng.randint(10000, 65535)
        # Keep drawing until the port differs from the master port; a single
        # retry could still collide.
        while port == args.master_port:
            port = rng.randint(10000, 65535)
        port = torch.cuda.LongTensor([port])
    else:
        # Placeholder that the broadcast below overwrites.
        port = torch.cuda.LongTensor([0])
    torch.distributed.broadcast(port, 0)
    return port.item()
def print_and_save_args(args, verbose=True, log_dir=None):
    """Print the run arguments and optionally persist them as JSON.

    Args:
        args: Namespace-like object; ``vars(args)`` supplies the values. When
            ``log_dir`` is given, ``args.deepspeed`` and
            ``args.deepspeed_config`` are read as well.
        verbose: When true, print one dot-padded line per argument.
        log_dir: Directory that receives ``config.json`` and, if DeepSpeed is
            enabled with a config file, a copy of that config named
            ``config_gpt_large.json``. Nothing is written when ``None``.
    """
    if verbose:
        print('arguments:', flush=True)
        for key in vars(args):
            padding = '.' * (29 - len(key))
            line = ' {} {} {}'.format(key, padding, getattr(args, key))
            print(line, flush=True)

    if log_dir is None:
        return

    config_path = os.path.join(log_dir, "config.json")
    with open(config_path, "w") as config_out:
        json.dump(vars(args), config_out, sort_keys=True)

    if not (args.deepspeed and args.deepspeed_config is not None):
        return

    # Keep a copy of the DeepSpeed configuration next to the run's own config.
    with open(args.deepspeed_config) as ds_in:
        ds_settings = json.load(ds_in)
    ds_copy_path = os.path.join(log_dir, "config_gpt_large.json")
    with open(ds_copy_path, "w") as ds_out:
        json.dump(ds_settings, ds_out)
def print_params_min_max_norm(optimizer, iteration):
    """Print min, max, and norm of every parameter held by the optimizer.

    Args:
        optimizer: Optimizer whose ``param_groups`` are walked. An
            ``FP16_Optimizer`` is unwrapped to its inner ``optimizer`` first.
        iteration: Integer iteration number shown in each row.

    Each parameter must expose a ``model_parallel`` attribute; it is printed
    as an int. The whole table is emitted with a single ``print`` call on
    every rank that calls this function.
    """
    rank = torch.distributed.get_rank()
    inner = optimizer.optimizer if isinstance(optimizer, FP16_Optimizer) else optimizer

    rows = ['iteration, rank, index, model-parallel,min, max, norm\n']
    index = 0  # running 1-based parameter counter across all groups
    for group in inner.param_groups:
        for param in group['params']:
            index += 1
            values = param.data
            lowest = values.min()
            highest = values.max()
            l2_norm = values.norm()
            prefix = '{:7d}, {:4d}, {:4d}, {:2d}, '.format(
                iteration, rank, index, int(param.model_parallel))
            stats = '{:.6E}, {:.6E}, {:.6E}\n'.format(lowest, highest, l2_norm)
            rows.append(prefix + stats)
    print(''.join(rows), flush=True)
class Timers:
    """Group of named timers.

    Calling the instance with a name returns the ``Timer`` registered under
    that name, creating it on first use.
    """

    class Timer:
        """Accumulating wall-clock timer.

        ``elapsed_`` holds the seconds accumulated over all completed
        start/stop intervals, ``started_`` tells whether an interval is
        currently open, and ``start_time`` is the ``time.time()`` stamp of
        the most recent start.
        """

        def __init__(self, name):
            self.name_ = name
            self.elapsed_ = 0.0
            self.started_ = False
            self.start_time = time.time()

        @staticmethod
        def _synchronize():
            """Wait for pending CUDA work; do nothing when CUDA is unavailable.

            Synchronizing makes the wall-clock stamp include GPU work that
            was launched asynchronously. The availability guard keeps the
            timer usable on CPU-only hosts, where an unconditional
            ``torch.cuda.synchronize()`` raises.
            """
            if torch.cuda.is_available():
                torch.cuda.synchronize()

        def start(self):
            """Start the timer.

            Raises:
                AssertionError: If the timer is already running.
            """
            assert not self.started_, 'timer has already been started'
            self._synchronize()
            self.start_time = time.time()
            self.started_ = True

        def stop(self):
            """Stop the timer and add the interval to the accumulated time.

            Raises:
                AssertionError: If the timer is not running.
            """
            assert self.started_, 'timer is not started'
            self._synchronize()
            self.elapsed_ += (time.time() - self.start_time)
            self.started_ = False

        def reset(self):
            """Reset timer: clear the accumulated time and mark it stopped."""
            self.elapsed_ = 0.0
            self.started_ = False

        def elapsed(self, reset=True):
            """Calculate the elapsed time.

            Args:
                reset: When true, clear the accumulated time after reading it.

            Returns:
                Accumulated seconds, including the currently open interval
                if the timer is running.
            """
            started_ = self.started_
            # If the timer is running, close the open interval first so it is
            # counted in the reading.
            if self.started_:
                self.stop()
            elapsed_ = self.elapsed_
            if reset:
                self.reset()
            # Resume timing if the timer was running when we were called.
            if started_:
                self.start()
            return elapsed_

    def __init__(self):
        self.timers = {}

    def __call__(self, name):
        """Return the timer called ``name``, creating it if necessary."""
        if name not in self.timers:
            self.timers[name] = self.Timer(name)
        return self.timers[name]

    def log(self, names, normalizer=1.0, reset=True):
        """Log a group of timers through ``print_rank_0``.

        Args:
            names: Iterable of timer names; each must already exist
                (an unknown name raises ``KeyError``).
            normalizer: Positive divisor applied to every reading, e.g. the
                number of iterations covered.
            reset: Forwarded to ``Timer.elapsed``.
        """
        assert normalizer > 0.0
        string = 'time (ms)'
        for name in names:
            # Seconds -> milliseconds, then averaged by the normalizer.
            elapsed_time = self.timers[name].elapsed(
                reset=reset) * 1000.0 / normalizer
            string += ' | {}: {:.2f}'.format(name, elapsed_time)
        print_rank_0(string)
def report_memory(name):
    """Simple GPU memory report, printed via ``print_rank_0``.

    Args:
        name: Label placed at the start of the report line.

    The line lists, in MB, the current and peak allocated memory and the
    current and peak memory reserved by the caching allocator. The
    "cached" labels are kept so existing log parsers keep working.
    """
    mega_bytes = 1024.0 * 1024.0
    string = name + ' memory (MB)'
    string += ' | allocated: {}'.format(
        torch.cuda.memory_allocated() / mega_bytes)
    string += ' | max allocated: {}'.format(
        torch.cuda.max_memory_allocated() / mega_bytes)
    # memory_cached() is the deprecated name of memory_reserved().
    string += ' | cached: {}'.format(
        torch.cuda.memory_reserved() / mega_bytes)
    # Peak value: must be max_memory_reserved(), not the current reservation.
    string += ' | max cached: {}'.format(
        torch.cuda.max_memory_reserved() / mega_bytes)
    print_rank_0(string)
def get_checkpoint_name(checkpoints_path, iteration, release=False, zero=False):
    """Build the checkpoint file path for the current model-parallel rank.

    Args:
        checkpoints_path: Root directory of the checkpoints.
        iteration: Iteration number or tag used as the directory name.
        release: When true, the directory is named ``release`` instead.
        zero: When true, ``_zero_dp_rank_<data-parallel rank>`` is appended
            to the directory name.

    Returns:
        ``checkpoints_path/<directory>/mp_rank_XX_model_states.pt`` where
        ``XX`` is the zero-padded model-parallel rank.
    """
    directory = 'release' if release else '{}'.format(iteration)
    if zero:
        directory += '_zero_dp_rank_{}'.format(mpu.get_data_parallel_rank())
    filename = 'mp_rank_{:02d}_model_states.pt'.format(mpu.get_model_parallel_rank())
    return os.path.join(checkpoints_path, directory, filename)
def ensure_directory_exists(filename):
    """Create the parent directory of ``filename`` if it is missing.

    Args:
        filename: Path of a file about to be written. A bare file name with
            no directory component refers to the current directory, so
            nothing needs to be created.
    """
    dirname = os.path.dirname(filename)
    # An empty dirname would make os.makedirs('') raise FileNotFoundError.
    # exist_ok=True already tolerates an existing directory (including one
    # created concurrently by another rank), so no separate exists() check.
    if dirname:
        os.makedirs(dirname, exist_ok=True)
def get_checkpoint_tracker_filename(checkpoints_path):
    """Return the path of the file that records the latest checkpoint tag.

    Args:
        checkpoints_path: Root directory of the checkpoints.
    """
    tracker_basename = 'latest_checkpointed_iteration.txt'
    return os.path.join(checkpoints_path, tracker_basename)
def save_zero_checkpoint(args, iteration, optimizer):
    """Save the optimizer state under the per-data-parallel-rank ZeRO path.

    Args:
        args: Object exposing ``save``, the checkpoint root directory.
        iteration: Iteration number stored in the file and used in its path.
        optimizer: Optimizer whose ``state_dict()`` is saved.
    """
    payload = dict(iteration=iteration,
                   optimizer_state_dict=optimizer.state_dict())
    target = get_checkpoint_name(args.save, iteration, zero=True)
    ensure_directory_exists(target)
    torch.save(payload, target)
    print(' successfully saved {}'.format(target))
def save_checkpoint(iteration, model, optimizer, lr_scheduler, args, tag=None, barrier=True,
                    only_changed_parameters=False, no_deepspeed=False, no_save_optim=False):
    """Save a model checkpoint and record its tag in the tracker file.

    Args:
        iteration: Current iteration; stored in the checkpoint and used as
            the tag when ``tag`` is ``None``.
        model: Model to save (a DeepSpeed engine when ``args.deepspeed``).
        optimizer: Optimizer to save, or ``None``.
        lr_scheduler: Learning-rate scheduler to save, or ``None``.
        args: Run arguments; reads ``deepspeed``, ``save``,
            ``no_save_optim`` and ``no_save_rng``.
        tag: Checkpoint directory name; defaults to ``str(iteration)``.
        barrier: When true, all ranks synchronize after saving.
        only_changed_parameters: When true, keep only state-dict entries
            whose parameter has ``requires_grad`` set.
        no_deepspeed: Bypass DeepSpeed's own saver even if DeepSpeed is on.
        no_save_optim: Skip optimizer and scheduler state for this call.
    """
    tag = str(iteration) if tag is None else tag

    use_deepspeed_saver = args.deepspeed and not no_deepspeed
    if use_deepspeed_saver:
        save_ds_checkpoint(iteration, model, lr_scheduler, args, tag=tag)
    elif mpu.get_data_parallel_rank() == 0:
        # Only data-parallel rank 0 writes the checkpoint to disk.
        checkpoint_name = get_checkpoint_name(args.save, tag)
        print('global rank {} is saving checkpoint at iteration {:7d} to {}'.
              format(torch.distributed.get_rank(), iteration, checkpoint_name))

        sd = {'iteration': iteration}

        # Unwrap the DeepSpeed engine to reach the underlying module.
        if args.deepspeed:
            model = model.module
        weights = model.state_dict()
        if only_changed_parameters:
            trainable = {name: parameter.requires_grad
                         for name, parameter in model.named_parameters()}
            weights = {key: value for key, value in weights.items() if trainable[key]}
        sd['module'] = weights

        # Optimizer and scheduler state.
        skip_optim = args.no_save_optim or no_save_optim
        if not skip_optim:
            if optimizer is not None:
                sd['optimizer'] = optimizer.state_dict()
            if lr_scheduler is not None:
                sd['lr_scheduler'] = lr_scheduler.state_dict()

        # Random number generator states.
        if not args.no_save_rng:
            sd.update({
                'random_rng_state': random.getstate(),
                'np_rng_state': np.random.get_state(),
                'torch_rng_state': torch.get_rng_state(),
                'cuda_rng_state': torch.cuda.get_rng_state(),
                'rng_tracker_states': mpu.get_cuda_rng_tracker().get_states(),
            })

        ensure_directory_exists(checkpoint_name)
        torch.save(sd, checkpoint_name)
        print(' successfully saved {}'.format(checkpoint_name))

    # Wait until every rank is done before the tracker file is updated.
    if barrier:
        torch.distributed.barrier()

    # Global rank 0 records the tag of the checkpoint just written.
    if torch.distributed.get_rank() == 0:
        tracker_filename = get_checkpoint_tracker_filename(args.save)
        with open(tracker_filename, 'w') as tracker:
            tracker.write(tag)
def save_ds_checkpoint(iteration, model, lr_scheduler, args, tag):
    """Save a checkpoint through the model's own ``save_checkpoint`` method.

    Args:
        iteration: Iteration number stored in the client state.
        model: Object exposing ``save_checkpoint(save_dir, tag, client_state=...)``.
        lr_scheduler: Scheduler whose state is stored under
            ``client_lr_scheduler``, or ``None`` to skip it.
        args: Run arguments; reads ``no_save_rng`` and ``save``.
        tag: Checkpoint tag passed through to the model.
    """
    client_state = {'iteration': iteration}
    if lr_scheduler is not None:
        client_state['client_lr_scheduler'] = lr_scheduler.state_dict()

    # Random number generator states.
    if not args.no_save_rng:
        client_state.update({
            'random_rng_state': random.getstate(),
            'np_rng_state': np.random.get_state(),
            'torch_rng_state': torch.get_rng_state(),
            'cuda_rng_state': torch.cuda.get_rng_state(),
            'rng_tracker_states': mpu.get_cuda_rng_tracker().get_states(),
        })

    model.save_checkpoint(args.save, tag, client_state=client_state)
def get_checkpoint_iteration(load_path):
    """Work out which checkpoint under ``load_path`` should be loaded.

    Args:
        load_path: Checkpoint root directory (containing the tracker file),
            or the directory of one specific checkpoint.

    Returns:
        A tuple ``(load_dir, tag, release, success)``:

        * tracker file present: ``(load_path, <tracker content>,
          <content == 'release'>, True)``;
        * no tracker file but ``load_path`` is a directory: its parent and
          its last component are returned as ``load_dir`` and ``tag``;
        * otherwise ``(load_path, 0, False, False)``.
    """
    tracker_filename = get_checkpoint_tracker_filename(load_path)

    if os.path.isfile(tracker_filename):
        with open(tracker_filename, 'r') as tracker:
            metastring = tracker.read().strip()
        return load_path, metastring, metastring == 'release', True

    print_rank_0('WARNING: could not find the metadata file {} '.format(
        tracker_filename))

    if not os.path.isdir(load_path):
        print_rank_0(' will not load any checkpoints and will start from '
                     'random')
        return load_path, 0, False, False

    # Treat load_path itself as the checkpoint directory: parent + tag.
    load_dir, tag = os.path.split(os.path.normpath(load_path))
    print_rank_0('Try to directly load the checkpoint from the directory')
    return load_dir, tag, False, True
def load_checkpoint(model, optimizer, lr_scheduler, args, no_deepspeed=False, no_load_optim=False, no_load_rng=False):
    """Load a model checkpoint.

    Args:
        model: Model to load into (a DeepSpeed engine when ``args.deepspeed``).
        optimizer: Optimizer to restore, or ``None``.
        lr_scheduler: Learning-rate scheduler to restore, or ``None``.
        args: Run arguments; reads ``load``, ``deepspeed``, ``finetune``,
            ``no_load_optim``, ``no_load_lr_scheduler`` and ``no_load_rng``.
        no_deepspeed: Bypass DeepSpeed's own loader even if DeepSpeed is on.
        no_load_optim: Skip optimizer state for this call.
        no_load_rng: Skip random-number-generator state for this call.

    Returns:
        The iteration stored in the checkpoint; ``0`` when nothing was found,
        when fine-tuning, or when loading a release checkpoint. If the
        DeepSpeed loader reports failure, the checkpoint tag is returned
        unchanged (NOTE(review): that is a tag, not an iteration count --
        confirm callers expect this).
    """
    load_dir, tag, release, success = get_checkpoint_iteration(args.load)

    if not success:
        return 0

    if args.deepspeed and not no_deepspeed:
        checkpoint_name, sd = model.load_checkpoint(
            load_dir, tag,
            load_optimizer_states=not args.no_load_optim and not no_load_optim,
            load_lr_scheduler_states=not args.no_load_lr_scheduler)
        # Check for failure BEFORE touching ``sd``: DeepSpeed reports a failed
        # load with a ``None`` path, and the client state may be ``None`` too,
        # in which case a membership test on it would raise TypeError.
        if checkpoint_name is None:
            if mpu.get_data_parallel_rank() == 0:
                print("Unable to load checkpoint.")
            return tag
        if (not args.no_load_lr_scheduler and lr_scheduler is not None
                and "client_lr_scheduler" in sd):
            lr_scheduler.load_state_dict(sd["client_lr_scheduler"])
            print_rank_0("Load lr scheduler state")
    else:
        # Checkpoint.
        checkpoint_name = get_checkpoint_name(load_dir, tag, release)

        if mpu.get_data_parallel_rank() == 0:
            print('global rank {} is loading checkpoint {}'.format(
                torch.distributed.get_rank(), checkpoint_name))

        # Load the checkpoint.
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        sd = torch.load(checkpoint_name, map_location='cpu')

        # Model. Unwrap the DeepSpeed engine to reach the underlying module.
        if args.deepspeed:
            model = model.module
        missing_keys, unexpected_keys = model.load_state_dict(sd['module'], strict=False)
        if missing_keys or unexpected_keys:
            print_rank_0(f"Missing keys {missing_keys}, unexpected keys {unexpected_keys}")

        # Optimizer.
        if not release and not args.finetune and not args.no_load_optim and not no_load_optim:
            try:
                if optimizer is not None:
                    optimizer.load_state_dict(sd['optimizer'])
                if lr_scheduler is not None:
                    lr_scheduler.load_state_dict(sd['lr_scheduler'])
            except KeyError:
                # Best effort: the message is printed and loading continues.
                print_rank_0('Unable to load optimizer from checkpoint {}, exiting. '
                             'Specify --no-load-optim or --finetune to prevent '
                             'attempting to load the optimizer '
                             'state.'.format(checkpoint_name))

    # Iterations.
    if args.finetune or release:
        iteration = 0
    else:
        try:
            iteration = sd['iteration']
        except KeyError:
            try:
                # Fallback key, tried when 'iteration' is absent.
                iteration = sd['total_iters']
            except KeyError:
                print_rank_0('A metadata file exists but Unable to load iteration '
                             ' from checkpoint {}, starting from 0 iteration'.format(checkpoint_name))
                iteration = 0

    # Random number generator states.
    if not release and not args.finetune and not args.no_load_rng and not no_load_rng:
        try:
            random.setstate(sd['random_rng_state'])
            np.random.set_state(sd['np_rng_state'])
            torch.set_rng_state(sd['torch_rng_state'])
            torch.cuda.set_rng_state(sd['cuda_rng_state'])
            mpu.get_cuda_rng_tracker().set_states(sd['rng_tracker_states'])
        except KeyError:
            # Best effort: the message is printed and loading continues.
            print_rank_0('Unable to load random state from checkpoint {}, exiting. '
                         'Specify --no-load-rng or --finetune to prevent '
                         'attempting to load the random '
                         'state.'.format(checkpoint_name))

    if mpu.get_data_parallel_rank() == 0:
        print(' successfully loaded {}'.format(checkpoint_name))

    return iteration
def load_weights(src, dst, dst2src=False):
    """
    Copy parameters between ``src`` and ``dst`` in place.

    src is a huggingface gpt2model, while dst is one of our models.
    By default every parameter of src is copied into the same-named
    parameter of dst; dst2src=True copies from dst into src instead.
    When the type name of src contains 'Conv1D', parameters whose name
    contains 'weight' are transposed before being copied.
    ^dst2src is still untested
    """
    is_conv1d = 'Conv1D' in str(type(src))
    for name, src_param in src.named_parameters():
        src_tensor = src_param.data
        dst_tensor = dst._parameters[name].data
        if dst2src:
            source, target = dst_tensor, src_tensor
        else:
            source, target = src_tensor, dst_tensor
        if is_conv1d and 'weight' in name:
            # The two layouts store this weight transposed w.r.t. each other.
            source = source.t().contiguous()
        target.copy_(source)
def load_mlp(our, oai, dst2src=False):
    """Copy MLP weights between ``oai`` and ``our`` via ``load_weights``.

    ``oai.c_fc`` pairs with ``our.dense_h_to_4h`` and ``oai.c_proj`` pairs
    with ``our.dense_4h_to_h``; ``dst2src`` is forwarded unchanged.
    """
    layer_pairs = (
        ('c_fc', 'dense_h_to_4h'),
        ('c_proj', 'dense_4h_to_h'),
    )
    for oai_attr, our_attr in layer_pairs:
        load_weights(getattr(oai, oai_attr), getattr(our, our_attr), dst2src)
def load_attention(our, oai, dst2src=False):
    """Copy attention weights between ``oai`` and ``our`` via ``load_weights``.

    ``oai.c_attn`` pairs with ``our.query_key_value`` and ``oai.c_proj``
    pairs with ``our.dense``; ``dst2src`` is forwarded unchanged.
    """
    layer_pairs = (
        ('c_attn', 'query_key_value'),
        ('c_proj', 'dense'),
    )
    for oai_attr, our_attr in layer_pairs:
        load_weights(getattr(oai, oai_attr), getattr(our, our_attr), dst2src)
def load_transformer_layer(our, oai, dst2src=False):
    """Copy one transformer layer between ``oai`` and ``our``.

    The two layer norms are handled first (``ln_1`` with
    ``input_layernorm``, ``ln_2`` with ``post_attention_layernorm``),
    then the MLP, then the attention block; ``dst2src`` is forwarded
    unchanged.
    """
    norm_pairs = (
        ('ln_1', 'input_layernorm'),
        ('ln_2', 'post_attention_layernorm'),
    )
    for oai_attr, our_attr in norm_pairs:
        load_weights(getattr(oai, oai_attr), getattr(our, our_attr), dst2src)
    load_mlp(our.mlp, oai.mlp, dst2src)
    load_attention(our.attention, oai.attn, dst2src)
def move_weights(our, oai, dst2src=False):
    """
    Copy all weights between `oai` and `our` in place.

    `oai` is a huggingface gpt2model, while `our` is one of our models.
    The final layer norm, the word embeddings and the position embeddings
    are handled first, followed by the transformer layers pairwise.
    dst2src=True loads parameters from our models into huggingface's.
    ^dst2src=True is still untested
    """
    hf_transformer = oai.transformer
    load_weights(hf_transformer.ln_f, our.transformer.final_layernorm, dst2src)
    load_weights(hf_transformer.wte, our.word_embeddings, dst2src)
    load_weights(hf_transformer.wpe, our.position_embeddings, dst2src)

    layer_pairs = zip(our.transformer.layers, oai.transformer.h)
    for our_layer, oai_layer in layer_pairs:
        load_transformer_layer(our_layer, oai_layer, dst2src)
def debug_finetune_data(local_vars, batch_id, tokenizer):
    """Print a human-readable dump of one fine-tuning sample.

    Args:
        local_vars: Mapping with the keys ``tokens``, ``target_ids``,
            ``attention_mask``, ``logit_mask`` and ``position_ids``.
        batch_id: Index of the sample to dump.
        tokenizer: Object exposing ``IdToToken`` and ``DecodeIds``.

    ``attention_mask[batch_id]`` is read as a scalar separator position:
    tokens before it are printed one by one (a ``[MASK]`` token is shown as
    its first-row position id in brackets); positions from the separator
    onwards that are set in ``logit_mask`` are treated as targets, and
    their indices, decoded tokens, decoded target ids and position-id
    columns are printed.
    """
    tokens = local_vars["tokens"]
    target_ids = local_vars["target_ids"]
    attention_mask = local_vars["attention_mask"]
    logit_mask = local_vars["logit_mask"]
    position_ids = local_vars["position_ids"]

    sep = attention_mask[batch_id].item()

    # Context part: everything before the separator.
    context_pieces = []
    for idx, token_id in enumerate(tokens[batch_id][:sep].tolist()):
        piece = tokenizer.IdToToken(token_id)
        if piece == '[MASK]':
            piece = f"[{position_ids[batch_id][0, idx].item()}]"
        context_pieces.append(piece)
    print(" ".join(context_pieces))

    # Target part: masked-in positions from the separator to the end.
    target_positions = [pos for pos in range(sep, tokens.size(-1))
                        if logit_mask[batch_id][pos]]
    print(target_positions)
    print(tokenizer.DecodeIds(tokens[batch_id][target_positions].tolist()))

    if len(target_ids.shape) > 2:
        selected_targets = target_ids[batch_id][target_positions]
    else:
        selected_targets = target_ids[batch_id]
    print(tokenizer.DecodeIds(selected_targets.tolist()))
    print(position_ids[batch_id][:, target_positions])