import os
from typing import Optional
from megatron.training import get_args
try:
    from megatron.training.microbatches import build_num_microbatches_calculator, NumMicroBatchesCalculator
except ImportError:
    try:
        from megatron.microbatches import build_num_microbatches_calculator, NumMicroBatchesCalculator
    except ImportError:
        try:
            from megatron.core.num_microbatches_calculator import build_num_microbatches_calculator, \
                NumMicroBatchesCalculator
        except ImportError:
            # None of the known Megatron module layouts could be imported: fall back
            # to a minimal constant (no ramp-up) calculator defined here.
            class NumMicroBatchesCalculator:
                """Fallback calculator with a constant number of micro-batches.

                The count is ``global_batch_size // (micro_batch_size * data_parallel_size)``,
                re-read from ``args`` on every call. Batch-size ramp-up is not
                supported, so ``update`` is a no-op.
                """

                def __init__(self, args):
                    # ``args`` must expose ``micro_batch_size`` and ``global_batch_size``;
                    # ``data_parallel_size`` is optional and defaults to 1.
                    self.args = args

                def get(self):
                    """Return the number of micro-batches in one global batch.

                    Raises:
                        ValueError: if the global batch size is not a positive multiple
                            of ``micro_batch_size * data_parallel_size``.
                    """
                    micro_batch_size = self.args.micro_batch_size
                    data_parallel_size = getattr(self.args, 'data_parallel_size', None) or 1
                    global_batch_size = self.args.global_batch_size
                    if global_batch_size is None:
                        # No explicit global batch size: treat it as exactly one
                        # micro-batch per data-parallel rank.
                        return 1
                    micro_batch_times_data_parallel = micro_batch_size * data_parallel_size
                    if (micro_batch_times_data_parallel <= 0
                            or global_batch_size <= 0
                            or global_batch_size % micro_batch_times_data_parallel != 0):
                        raise ValueError(
                            f'global batch size ({global_batch_size}) must be a positive multiple of '
                            f'micro_batch_size ({micro_batch_size}) * '
                            f'data_parallel_size ({data_parallel_size})'
                        )
                    return global_batch_size // micro_batch_times_data_parallel

                def get_current_global_batch_size(self):
                    """Return the configured global batch size (constant, no ramp-up)."""
                    return self.args.global_batch_size

                def update(self, consumed_samples, consistency_check=True):
                    """No-op: the fallback never changes the batch configuration."""
                    pass

            def build_num_microbatches_calculator(args, *positional):
                """Build the fallback calculator.

                Accepts two call forms:
                  * ``build_num_microbatches_calculator(args)``: a single namespace
                    carrying the batch settings.
                  * ``build_num_microbatches_calculator(rank, rampup_batch_size,
                    global_batch_size, micro_batch_size, data_parallel_size, ...)``:
                    the positional form, also used by ``_build_num_microbatches_calculator``
                    when ``ML_VERSION == "2.0.0"``. Here the first argument is the rank
                    and extra trailing values are ignored.

                Raises:
                    TypeError: if the positional form has fewer than five arguments.
                """
                if positional:
                    if len(positional) < 4:
                        raise TypeError(
                            'build_num_microbatches_calculator() expects either a single args '
                            'object or (rank, rampup_batch_size, global_batch_size, '
                            'micro_batch_size, data_parallel_size)'
                        )
                    from types import SimpleNamespace
                    rampup_batch_size, global_batch_size, micro_batch_size, data_parallel_size = \
                        positional[:4]
                    args = SimpleNamespace(
                        rank=args,
                        rampup_batch_size=rampup_batch_size,
                        global_batch_size=global_batch_size,
                        micro_batch_size=micro_batch_size,
                        data_parallel_size=data_parallel_size,
                    )
                return NumMicroBatchesCalculator(args)
# Process-wide calculator instance; assigned by _build_num_microbatches_calculator().
_GLOBAL_NUM_MICROBATCHES_CALCULATOR = None


def _require_num_microbatches_calculator():
    """Return the global calculator, failing clearly if it has not been built yet.

    Raises:
        RuntimeError: if the calculator has not been initialized.
    """
    calculator = _GLOBAL_NUM_MICROBATCHES_CALCULATOR
    if calculator is None:
        raise RuntimeError(
            'num-microbatches calculator is not initialized; '
            'call rebuild_num_microbatches_calculator() first'
        )
    return calculator


def get_num_microbatches():
    """Return the current number of micro-batches from the global calculator."""
    return _require_num_microbatches_calculator().get()


def get_current_global_batch_size():
    """Return the current global batch size from the global calculator."""
    return _require_num_microbatches_calculator().get_current_global_batch_size()


def update_num_microbatches(consumed_samples, consistency_check=True):
    """Forward ``consumed_samples`` to the global calculator's ``update``."""
    _require_num_microbatches_calculator().update(consumed_samples, consistency_check)
def _build_num_microbatches_calculator(args):
    """(Re)create the process-wide num-microbatches calculator from ``args``.

    The ``ML_VERSION`` environment variable (default ``"1.1"``) selects the call
    convention. For exactly ``"2.0.0"``, the individual batch settings are passed
    positionally. For any other value, the whole ``args`` object is passed.
    """
    global _GLOBAL_NUM_MICROBATCHES_CALCULATOR
    uses_positional_api = os.getenv('ML_VERSION', "1.1") == "2.0.0"
    if not uses_positional_api:
        _GLOBAL_NUM_MICROBATCHES_CALCULATOR = build_num_microbatches_calculator(args)
        return
    # Positional convention: (rank, rampup_batch_size, global_batch_size,
    # micro_batch_size, data_parallel_size).
    _GLOBAL_NUM_MICROBATCHES_CALCULATOR = build_num_microbatches_calculator(
        args.rank,
        args.rampup_batch_size,
        args.global_batch_size,
        args.micro_batch_size,
        args.data_parallel_size,
    )
def rebuild_num_microbatches_calculator():
    """Rebuild the global calculator from the arguments returned by ``get_args()``."""
    _build_num_microbatches_calculator(get_args())