"""AnnealingLR LR Schedule."""
import math
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.nn.learning_rate_schedule import LearningRateSchedule
from mindspore.ops import operations as P
from mindformers.tools.register import MindFormerRegister, MindFormerModuleType
__all__ = ['AnnealingLR']
@MindFormerRegister.register(MindFormerModuleType.LR)
class AnnealingLR(LearningRateSchedule):
""" AnnealingLR implementation for visualglm """
DECAY_STYLES = ["linear", "cosine", "exponential", "constant", "None"]
def __init__(self, learning_rate, warmup_steps, num_iters, total_steps, decay_style="cosine", last_iter=-1,
decay_ratio=0.1, auto_warmup_steps=100, auto_warmup_rate=0.05):
super(AnnealingLR, self).__init__()
self.total_steps = total_steps
self.start_lr = learning_rate
self.warmup_iter = Tensor(warmup_steps, mstype.float32)
self.init_step = last_iter
self.num_iters = Tensor(last_iter + 1, mstype.float32)
self.end_iter = Tensor(num_iters, mstype.float32)
self.decay_style = decay_style.lower() if isinstance(decay_style, str) else None
self.decay_ratio = 1 / decay_ratio
self.auto_warmup_steps = auto_warmup_steps
self.auto_warmup_rate = auto_warmup_rate
self.cos = P.Cos()
self.min = P.Minimum()
def construct(self, global_step):
""" method entrance """
if global_step <= self.init_step + self.auto_warmup_steps:
auto_lr = float(self.start_lr) * self.auto_warmup_rate
schedule_lr = float(self.start_lr) * global_step / self.warmup_iter
return self.min(auto_lr, schedule_lr)
if self.warmup_iter > 0 and global_step <= self.warmup_iter:
return float(self.start_lr) * global_step / self.warmup_iter
if self.decay_style == self.DECAY_STYLES[0]:
return self.start_lr * ((self.end_iter - (global_step - self.warmup_iter)) / self.end_iter)
if self.decay_style == self.DECAY_STYLES[1]:
tmp_decay_step_ratio = (global_step - self.warmup_iter) / self.end_iter
decay_step_ratio = self.min(1.0, tmp_decay_step_ratio)
return self.start_lr / self.decay_ratio * (
(self.cos(math.pi * decay_step_ratio) + 1) * (self.decay_ratio - 1) / 2 + 1)
if self.decay_style == self.DECAY_STYLES[2]:
return self.start_lr
return self.start_lr