import math
from typing import TYPE_CHECKING, Dict, Literal
import logging
from torch.optim.lr_scheduler import LambdaLR
if TYPE_CHECKING:
from torch.optim import Optimizer
logger = logging.getLogger(__name__)
class MultiLRScheduler(dict):
"""
Mapping of name -> LR scheduler with convenience methods to mirror a single scheduler.
"""
_is_multi_lr_scheduler: bool = True
def step(self) -> None:
for sched in self.values():
sched.step()
def state_dict(self) -> Dict[str, any]:
return {name: sched.state_dict() for name, sched in self.items()}
def load_state_dict(self, state_dict: Dict[str, any]) -> None:
for name, sched in self.items():
if name in state_dict:
sched.load_state_dict(state_dict[name])
def get_last_lr(self):
if not self:
return [0.0]
first = next(iter(self.values()))
return first.get_last_lr()
def build_lr_scheduler(
optimizer: "Optimizer",
train_steps: int,
lr: float = 1e-3,
lr_decay_style: Literal["constant", "linear", "cosine"] = "constant",
lr_decay_ratio: float = 1.0,
lr_warmup_ratio: float = 0.0,
lr_min: float = 1e-7,
lr_start: float = 0.0,
):
if hasattr(optimizer, "_is_multi_optimizer") or isinstance(optimizer, dict):
schedulers = {}
for key_name in optimizer.key_names:
schedulers[key_name] = build_lr_scheduler(
optimizer=optimizer.optimizers_dict[key_name],
train_steps=train_steps,
lr=lr,
lr_decay_style=lr_decay_style,
lr_decay_ratio=lr_decay_ratio,
lr_warmup_ratio=lr_warmup_ratio,
lr_min=lr_min,
lr_start=lr_start,
)
return MultiLRScheduler(schedulers)
lr_warmup_steps = int(train_steps * lr_warmup_ratio)
if lr_decay_style == "constant":
return get_constant_schedule_with_warmup(
optimizer=optimizer,
num_warmup_steps=lr_warmup_steps,
lr_start=lr_start,
init_lr=lr,
)
if lr_decay_style == "linear":
return get_linear_schedule_with_warmup(
optimizer=optimizer,
num_warmup_steps=lr_warmup_steps,
num_training_steps=train_steps,
init_lr=lr,
lr_start=lr_start,
)
if lr_decay_style == "cosine":
return get_cosine_schedule_with_warmup(
optimizer=optimizer,
num_warmup_steps=lr_warmup_steps,
num_training_steps=train_steps,
init_lr=lr,
lr_decay_ratio=lr_decay_ratio,
min_lr=lr_min,
lr_start=lr_start,
)
raise ValueError(f"Unknown learning rate decay style: {lr_decay_style}.")
def get_constant_schedule_with_warmup(
optimizer: "Optimizer",
num_warmup_steps: int,
init_lr: float,
last_epoch: int = -1,
lr_start: float = 0.0,
):
"""
Creates a schedule with a constant learning rate preceded by a warmup period during which the learning rate
increases linearly between 0 and the initial lr set in the optimizer.
"""
def _lr_lambda(current_step: int):
if init_lr == 0:
return 0.0
if current_step < num_warmup_steps:
return (lr_start + (init_lr - lr_start) * current_step / max(1, num_warmup_steps)) / init_lr
return 1.0
return LambdaLR(optimizer, _lr_lambda, last_epoch=last_epoch)
def get_linear_schedule_with_warmup(
optimizer: "Optimizer",
num_warmup_steps: int,
num_training_steps: int,
init_lr: float,
last_epoch: int = -1,
min_lr: float = 1e-7,
lr_start: float = 0.0,
):
"""
Creates a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0,
after a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
"""
def _lr_lambda(current_step: int):
if init_lr == 0:
return 0.0
if current_step < num_warmup_steps:
return (lr_start + (init_lr - lr_start) * current_step / max(1, num_warmup_steps)) / init_lr
min_lr_ratio = min_lr / init_lr
return max(
min_lr_ratio,
float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)),
)
return LambdaLR(optimizer, _lr_lambda, last_epoch)
def get_cosine_schedule_with_warmup(
optimizer: "Optimizer",
num_warmup_steps: int,
num_training_steps: int,
init_lr: float,
num_cycles: float = 0.5,
last_epoch: int = -1,
lr_decay_ratio: float = 1.0,
min_lr: float = 1e-7,
lr_start: float = 0.0,
):
"""
Creates a schedule with a learning rate that decreases following the values of the cosine function between
the initial lr set in the optimizer to min_lr, after a warmup period during which it increases linearly between 0
and the initial lr set in the optimizer.
"""
def lr_lambda(current_step: int):
if init_lr == 0:
return 0.0
lr_decay_steps = int(num_training_steps * lr_decay_ratio)
if current_step < num_warmup_steps:
return (lr_start + (init_lr - lr_start) * current_step / max(1, num_warmup_steps)) / init_lr
min_lr_ratio = min_lr / init_lr
if current_step > lr_decay_steps:
return min_lr_ratio
progress = float(current_step - num_warmup_steps) / float(max(1, lr_decay_steps - num_warmup_steps))
if progress < 0 or progress > 1:
raise ValueError(
f"Progress value {progress:.6f} is out of range [0, 1]. "
f"Current step: {current_step}, "
f"Warmup steps: {num_warmup_steps}, "
f"Decay steps: {lr_decay_steps}"
)
factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
factor = factor * (1 - min_lr_ratio) + min_lr_ratio
return max(0, factor)
return LambdaLR(optimizer, lr_lambda, last_epoch)