from typing import Callable, List, Optional, Union
import torch
from megatron.core.parallel_state import get_tensor_model_parallel_group
from megatron.training import get_args
from mindspeed_llm.tasks.posttrain.lu_lora.rules import OjasRule
def _reduce_all_async(input_: torch.Tensor):
"""
Reduce-scatter the input tensor across the model parallel group async.
Args:
input_ (torch.Tensor): Input.
Returns:
Async work handle.
"""
return torch.distributed.all_reduce(
input_,
op=torch.distributed.ReduceOp.SUM,
group=get_tensor_model_parallel_group(),
async_op=True
)
def get_learning_rate_provider(
last_lr_provider: Callable[[], Union[List[float], float]],
element: int,
coef: float = 1.
) -> Callable[[], float]:
"""
Wrapped _get_learning_rates
Args:
last_lr_provider (Callable[[], Union[List[float], float]]): last_lr_provider function.
element (int): Index of learning rate group.
coef (float): Coefficient for lr adaptation.
Returns:
Callable[[], float]: Wrapped _get_learning_rates function
"""
def _get_learning_rates() -> float:
learning_rates = last_lr_provider()
return learning_rates[element] / coef
return _get_learning_rates
class LULoRALayerOptimizer:
"""
LULoRALayerOptimizer class implementation.
"""
_rule: OjasRule
_get_lr: Optional[Callable[[], float]]
_accumulation_step: int
_current_step: int
_delta: torch.Tensor
def __init__(self) -> None:
"""
Initialize an instance of LULoRALayerOptimizer.
"""
args = get_args()
self._rule = OjasRule()
self._get_lr = None
self._accumulation_step = args.global_batch_size // args.micro_batch_size // args.data_parallel_size
self._current_step = 0
self._delta = 0
def set_lr_provider(self, lr_provider: Callable[[], float]) -> None:
"""
Set the learning rate provider.
Args:
lr_provider (Callable[[], float]): A function provides the current learning rate.
"""
self._get_lr = lr_provider
@torch.no_grad()
def update(self, x: torch.Tensor, y: torch.Tensor, weight: torch.nn.parameter.Parameter) -> None:
"""
Update weights.
Args:
x (torch.Tensor): Input for the Linear layer with weights W.
y (torch.Tensor): Result of the forward the Linear layer: y=xW^t.
weight (torch.nn.parameter.Parameter): Weights of the Linear layer.
"""
self._current_step += 1
if not x.numel():
return
if x.dim() == 2:
x = x.unsqueeze(1)
y = y.unsqueeze(1)
dtype = x.dtype if x.dtype == torch.bfloat16 else torch.float32
weight_data = weight.to(dtype)
delta = self._rule(x.to(dtype), y.to(dtype), weight_data)
handle = _reduce_all_async(delta)
handle.wait()
self._delta += delta
if self._current_step % self._accumulation_step != 0:
return
weight.data = (
weight_data - self._get_lr() * self._delta / self._accumulation_step
).to(weight.dtype)
torch.nn.init.zeros_(self._delta)