#!/bin/bash
#SBATCH --exclusive
#SBATCH --mem=0
#SBATCH --overcommit

# Copyright (c) 2021 NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

set -eux

# The following variables need to be set before this script runs
# (an unset or empty value aborts the job below, before any job step is launched):
# Base container to be used (CONT)
# Location of dataset for phase 1
#readonly DATADIR="<machine_specific_path_here>"
# Location of dataset for phase 2
#readonly DATADIR_PHASE2="<machine_specific_path_here>"
# Path to where trained checkpoints will be saved on the system
#readonly CHECKPOINTDIR="<machine_specific_path_here>"
# Path to pretrained Phase1 checkpoint
#readonly CHECKPOINTDIR_PHASE1="<machine_specific_path_here>"
# Location of the evaluation dataset
#readonly EVALDIR="<machine_specific_path_here>"

# Vars without defaults
: "${DGXSYSTEM:?DGXSYSTEM not set}"
: "${CONT:?CONT not set}"
# Tasks per node for the training step; checked here because a failure inside
# the per-trial subshell is hidden behind the tee pipeline.
: "${DGXNGPU:?DGXNGPU not set}"
# Host paths that are bind-mounted into the container through _cont_mounts.
: "${DATADIR:?DATADIR not set}"
: "${DATADIR_PHASE2:?DATADIR_PHASE2 not set}"
: "${CHECKPOINTDIR:?CHECKPOINTDIR not set}"
: "${CHECKPOINTDIR_PHASE1:?CHECKPOINTDIR_PHASE1 not set}"
: "${EVALDIR:?EVALDIR not set}"

# Vars with defaults
: "${NEXP:=5}"
: "${DATESTAMP:=$(date +'%y%m%d%H%M%S%N')}"
: "${LOGDIR:=./results}"
: "${CLEAR_CACHES:=1}"

# Other vars
readonly _logfile_base="${LOGDIR}/${DATESTAMP}"
readonly _cont_name=language_model
readonly _cont_mounts="${DATADIR}:/workspace/data,${DATADIR_PHASE2}:/workspace/data_phase2,${CHECKPOINTDIR}:/results,${CHECKPOINTDIR_PHASE1}:/workspace/phase1,${EVALDIR}:/workspace/evaldata"

# Per-trial logs are written by tee under LOGDIR; create it up front so the
# first trial's log is not lost to a missing directory.
mkdir -p "${LOGDIR}"

# Create the checkpoint directory with one task on every node of the job.
srun --ntasks="${SLURM_JOB_NUM_NODES}" --ntasks-per-node=1 mkdir -p "${CHECKPOINTDIR}"

# Step budget:
#   - THROUGHPUT_RUN non-empty: run only 4 steps, enough to measure throughput.
#   - otherwise: keep a caller-provided MAX_STEPS or fall back to 1536000; the
#     run is expected to stop on the mlm_accuracy threshold well before that.
: "${THROUGHPUT_RUN:=}"
if [[ -n "${THROUGHPUT_RUN}" ]]; then
  MAX_STEPS=4
else
  : "${MAX_STEPS:=1536000}"
fi

# Phase-2 hyperparameters and data/checkpoint locations, flattened into one
# string that is spliced into the training command line. Each tunable falls
# back to the default shown when its environment variable is unset or empty.
# printf re-applies the '    %s ' format to every argument, which yields the
# same spacing as a backslash-continued multi-line string literal; the final
# append adds the trailing indent of that literal's closing line.
printf -v PHASE2 '    %s ' \
    "--train_batch_size=${BATCHSIZE:-10}" \
    "--learning_rate=${LR:-4e-3}" \
    "--opt_lamb_beta_1=${OPT_LAMB_BETA_1:-0.9}" \
    "--opt_lamb_beta_2=${OPT_LAMB_BETA_2:-0.999}" \
    "--warmup_proportion=${WARMUP_PROPORTION:-0.0}" \
    "--warmup_steps=${WARMUP_STEPS:-0.0}" \
    "--start_warmup_step=${START_WARMUP_STEP:-0.0}" \
    "--max_steps=${MAX_STEPS}" \
    "--phase2" \
    "--max_seq_length=512" \
    "--max_predictions_per_seq=76" \
    "--input_dir=/workspace/data_phase2" \
    "--init_checkpoint=/workspace/phase1/model.ckpt-28252.pt"
PHASE2+='    '

# Phase label shown in the banner below (defaults to 2).
: "${PHASE:=2}"

# Banner with the phase and the job geometry reported by Slurm.
printf '***** Running Phase %s *****\n' "${PHASE}"
printf '***** SLURM_NNODES: %s *****\n' "${SLURM_NNODES}"
printf '***** SLURM_NTASKS: %s *****\n' "${SLURM_NTASKS}"

# Sample-count limits for termination and for the evaluation schedule.
: "${MAX_SAMPLES_TERMINATION:=14000000}"
: "${EVAL_ITER_START_SAMPLES:=500000}"
: "${EVAL_ITER_SAMPLES:=500000}"

# Gradient accumulation steps and the DDP switch.
: "${GRADIENT_STEPS:=2}"
: "${USE_DDP:=0}"

# Full pretraining command line, kept as a single string that is later handed
# to /workspace/bert/run_and_time.sh. Training is bounded by the target MLM
# accuracy and by the maximum number of training samples.
#   - EXTRA_PARAMS is optional; it defaults to empty so that leaving it unset
#     does not abort the script under `set -u`.
#   - \${SLURM_LOCALID} is deliberately left unexpanded here so that it is
#     resolved later by the shell that runs the command in each task.
BERT_CMD="\
    python -u /workspace/bert/run_pretraining.py \
    $PHASE2 \
    --do_train \
    --skip_checkpoint \
    --train_mlm_accuracy_window_size=0 \
    --target_mlm_accuracy=${TARGET_MLM_ACCURACY:-0.720} \
    --weight_decay_rate=${WEIGHT_DECAY_RATE:-0.01} \
    --max_samples_termination=${MAX_SAMPLES_TERMINATION} \
    --eval_iter_start_samples=${EVAL_ITER_START_SAMPLES} --eval_iter_samples=${EVAL_ITER_SAMPLES} \
    --eval_batch_size=16 --eval_dir=/workspace/evaldata \
    --cache_eval_data \
    --output_dir=/results \
    --fp16 --allreduce_post_accumulation --allreduce_post_accumulation_fp16 --fused_gelu_bias --dense_seq_output --fused_mha ${EXTRA_PARAMS:-} \
    --gradient_accumulation_steps=${GRADIENT_STEPS} --clip_and_accumulate \
    --log_freq=1 \
    --local_rank=\${SLURM_LOCALID} \
    --bert_config_path=/workspace/phase1/bert_config.json"

# Container image handed to srun; taken verbatim from the required CONT variable.
CONT_FILE="${CONT}"

# Setup container: run a no-op step (`true`) with as many tasks as there are
# nodes, attaching the image to the named container. Later steps refer to the
# container by --container-name only.
srun \
    --ntasks="$((SLURM_JOB_NUM_NODES))" \
    --container-image="${CONT_FILE}" \
    --container-name="${_cont_name}" \
    true

# Python snippet executed inside the container after the caches were dropped;
# it logs the CACHE_CLEAR event through mlperf_logger.
_cache_clear_py="
import mlperf_logger
mlperf_logger.log_event(key=mlperf_logger.constants.CACHE_CLEAR, value=True)"

# Run experiments: NEXP trials, one after another. Each trial runs in its own
# subshell whose stdout and stderr are both copied to a per-trial log file.
for (( _experiment_index = 1; _experiment_index <= NEXP; _experiment_index++ )); do
    (
        echo "Beginning trial ${_experiment_index} of ${NEXP}"

        # Clear caches: drop the page cache on the hosts, then record the event.
        if [ "${CLEAR_CACHES}" -eq 1 ]; then
            srun --ntasks="${SLURM_JOB_NUM_NODES}" \
                bash -c "echo -n 'Clearing cache on ' && hostname && sync && sudo /sbin/sysctl vm.drop_caches=3"
            srun --ntasks="${SLURM_JOB_NUM_NODES}" --container-name="${_cont_name}" \
                python -c "${_cache_clear_py}"
        fi

        # Run experiment: DGXNGPU tasks on every node, each invoking
        # run_and_time.sh with the training command and a seed (SEED if set,
        # otherwise a fresh $RANDOM for this trial).
        srun -l --mpi=none \
            --ntasks="$(( SLURM_JOB_NUM_NODES * DGXNGPU ))" \
            --ntasks-per-node="${DGXNGPU}" \
            --container-name="${_cont_name}" \
            --container-mounts="${_cont_mounts}" \
            sh -c "/workspace/bert/run_and_time.sh \"${BERT_CMD}\" ${SEED:-$RANDOM} "

    ) 2>&1 | tee "${_logfile_base}_${_experiment_index}.log"
done