import torch
def gating(self, input: torch.Tensor) -> torch.Tensor:
    """Forward pass of the router gate.

    Applies a bias-free linear projection of ``input`` with ``self.weight``.
    The computation runs in the dtype selected by
    ``self.config.moe_router_dtype``: ``'fp32'`` -> ``torch.float32``,
    ``'fp64'`` -> ``torch.float64``, any other value -> the dtype of ``input``.

    Args:
        input (torch.Tensor): Input tensor. Its last dimension must equal
            ``self.weight.shape[1]`` (requirement of
            ``torch.nn.functional.linear``).

    Returns:
        torch.Tensor: Logits tensor in the routing dtype; its last dimension
        equals ``self.weight.shape[0]``.
    """
    # Weights that still live on the CPU are moved to the current CUDA device.
    # The availability check keeps CPU-only runs working instead of failing
    # inside torch.cuda.current_device().
    if self.weight.device.type == 'cpu' and torch.cuda.is_available():
        self.weight.data = self.weight.data.to(device=torch.cuda.current_device())

    # Select the dtype used for the routing computation.
    router_dtype = input.dtype
    if self.config.moe_router_dtype == 'fp32':
        router_dtype = torch.float32
    elif self.config.moe_router_dtype == 'fp64':
        router_dtype = torch.float64

    # Both operands must share a dtype for the matmul, so the weight is cast
    # alongside the input. ``Tensor.to`` returns the tensor itself when the
    # dtype already matches, so this adds no copy in the common case.
    logits = torch.nn.functional.linear(
        input.to(router_dtype), self.weight.to(router_dtype)
    )
    return logits