from typing import List, Optional
from torch.optim.lr_scheduler import LRScheduler
class LULoRALRScheduler:
    """
    Wrapper that drives two learning rate schedulers together.

    One scheduler belongs to the backpropagation-trained parameters and the
    other to the LU-LoRA parameters. A single :meth:`step` call advances both.
    Learning-rate reporting (:meth:`get_last_lr`) and checkpoint state
    (:meth:`state_dict` / :meth:`load_state_dict`) are delegated to the
    backpropagation scheduler only.
    """

    # Public read-only accessors; the values live in the underscore-prefixed
    # instance attributes assigned in ``__init__``.
    backprop_scheduler: LRScheduler
    lu_lora_scheduler: LRScheduler

    def __init__(
        self,
        backprop_scheduler: LRScheduler,
        lu_lora_scheduler: LRScheduler,
    ) -> None:
        """
        Initialize an instance of LULoRALRScheduler.

        Args:
            backprop_scheduler (LRScheduler): Scheduler for the
                backpropagation optimizer. It must accept an ``increment``
                keyword argument in its ``step`` method.
            lu_lora_scheduler (LRScheduler): Scheduler for the LU-LoRA
                optimizer. Its ``step`` method is called without arguments.
        """
        self._lu_lora_scheduler = lu_lora_scheduler
        self._backprop_scheduler = backprop_scheduler

    @property
    def backprop_scheduler(self) -> LRScheduler:
        """
        Property to get backpropagation learning rate scheduler.

        Returns:
            LRScheduler: Backpropagation learning rate scheduler.
        """
        return self._backprop_scheduler

    @property
    def lu_lora_scheduler(self) -> LRScheduler:
        """
        Property to get LU-LoRA learning rate scheduler.

        Returns:
            LRScheduler: LU-LoRA learning rate scheduler.
        """
        return self._lu_lora_scheduler

    def step(self, increment: Optional[int] = None) -> None:
        """
        Step both learning rate schedulers.

        The backpropagation scheduler is stepped first, then the LU-LoRA
        scheduler.

        Args:
            increment (Optional[int]): A step value for Megatron scheduler.
                It is forwarded as the ``increment`` keyword argument to the
                backpropagation scheduler only.
        """
        self._backprop_scheduler.step(increment=increment)
        self._lu_lora_scheduler.step()

    def get_last_lr(self) -> List[float]:
        """
        Return the last computed learning rate by the backpropagation scheduler.

        Returns:
            List[float]: Current learning rates.
        """
        return self._backprop_scheduler.get_last_lr()

    def state_dict(self) -> dict:
        """
        Return the state of the backpropagation scheduler.

        The LU-LoRA scheduler state is not part of the returned dictionary.

        Returns:
            dict: Current state.
        """
        return self._backprop_scheduler.state_dict()

    def load_state_dict(self, state_dict: dict) -> None:
        """
        Load the state of the backpropagation scheduler.

        The LU-LoRA scheduler is left untouched.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        self._backprop_scheduler.load_state_dict(state_dict)