import torch
import triton
import triton.language as tl
import torch.nn as nn
# Default multipliers of the MurmurHash3 32-bit finalizer (fmix32); used by
# Model / ModelSglang.forward when C1 / C2 are not supplied.
FMIX32_C1: int = 0x85EB_CA6B
FMIX32_C2: int = 0xC2B2_AE35
from sglang.srt.layers.multimodal import _fmix32 as _fmix32_sgl
@triton.jit
def _fmix32(x, C1: tl.constexpr, C2: tl.constexpr):
    """MurmurHash3 32-bit finalizer (fmix32) with configurable multipliers.

    Computes x ^= x >> 16; x *= C1; x ^= x >> 13; x *= C2; x ^= x >> 16
    on 32-bit unsigned lanes and returns a uint32 tensor.

    The input is converted to uint32 first. Every ``>>`` is then a logical
    shift and both multiplications wrap modulo 2**32, whether the caller
    passed signed or unsigned 32-bit integers.
    """
    # Same-width integer cast only reinterprets the bits (no-op for uint32).
    # Without it, a signed input makes Triton emit an arithmetic shift for the
    # first ``x >> 16``, sign-extending into the upper half for any value with
    # the top bit set.
    x = x.to(tl.uint32)
    c1 = tl.full((), C1, tl.uint32)
    c2 = tl.full((), C2, tl.uint32)
    x ^= x >> 16
    x = x * c1  # uint32 * uint32 wraps modulo 2**32
    x ^= x >> 13
    x = x * c2
    x ^= x >> 16
    return x
@triton.jit
def fmix32_kernel(
    input_ptr,
    output_ptr,
    n_elements,
    C1: tl.constexpr,
    C2: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Apply the local ``_fmix32`` elementwise over a flat buffer of 32-bit ints.

    Program ``pid`` handles offsets [pid * BLOCK_SIZE, (pid + 1) * BLOCK_SIZE).
    Lanes at or beyond ``n_elements`` are masked out of both load and store.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    # The buffer holds int32 storage that represents uint32 values. Bitcast so
    # fmix32's right shifts are logical; on signed ints Triton lowers ``>>``
    # to an arithmetic (sign-extending) shift, which corrupts the hash.
    x = tl.load(input_ptr + offsets, mask=mask, other=0).to(tl.uint32, bitcast=True)
    result = _fmix32(x, C1=C1, C2=C2)
    # tl.store casts the uint32 result to the pointer's element type; a
    # same-width integer cast keeps the 32-bit pattern unchanged.
    tl.store(output_ptr + offsets, result, mask=mask)
@triton.jit
def fmix32_kernel_sgl(
    input_ptr,
    output_ptr,
    n_elements,
    C1: tl.constexpr,
    C2: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Apply sglang's ``_fmix32`` elementwise over a flat buffer of 32-bit ints.

    Mirrors ``fmix32_kernel`` but calls the imported ``_fmix32_sgl`` so both
    implementations can be compared on identical inputs.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    # Reinterpret the int32 storage as uint32 before hashing. The imported
    # helper applies ``>>`` directly to its argument, and on signed ints Triton
    # emits an arithmetic shift, which is wrong for values >= 2**31.
    x = tl.load(input_ptr + offsets, mask=mask, other=0).to(tl.uint32, bitcast=True)
    result = _fmix32_sgl(x, C1=C1, C2=C2)
    # Stored back through the int32 pointer; the 32-bit pattern is preserved.
    tl.store(output_ptr + offsets, result, mask=mask)
class Model(nn.Module):
    """Hash every element of an int32 tensor with the local Triton fmix32 kernel.

    Elements are interpreted as uint32 bit patterns. The result is a new int32
    tensor of the same shape holding the uint32 hash bit patterns.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input_tensor, C1=None, C2=None):
        """Launch ``fmix32_kernel`` over ``input_tensor``.

        Args:
            input_tensor: int32 tensor whose elements represent uint32 values.
                Presumably must be on a GPU device Triton can launch on;
                this is not checked here.
            C1, C2: fmix32 multipliers. ``None`` selects ``FMIX32_C1`` /
                ``FMIX32_C2``. They are ``tl.constexpr`` kernel arguments, so
                each distinct pair compiles a separate kernel specialization.

        Returns:
            A new int32 tensor with the same shape as ``input_tensor``.

        Raises:
            AssertionError: if ``input_tensor`` is not int32.
        """
        if C1 is None:
            C1 = FMIX32_C1
        if C2 is None:
            C2 = FMIX32_C2
        assert input_tensor.dtype == torch.int32, "Input must be int32 (representing uint32)"
        # The kernel walks raw memory as one flat dense buffer, so strided or
        # sliced views must be densified or elements would be read out of order.
        input_tensor = input_tensor.contiguous()
        n_elements = input_tensor.numel()
        output = torch.empty_like(input_tensor)
        BLOCK_SIZE = 1024
        grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
        fmix32_kernel[grid](
            input_tensor,
            output,
            n_elements,
            C1=C1,
            C2=C2,
            BLOCK_SIZE=BLOCK_SIZE,
        )
        return output
class ModelSglang(nn.Module):
    """Hash every element of an int32 tensor with sglang's fmix32 via Triton.

    Same contract as ``Model``, but the kernel calls the imported
    ``_fmix32_sgl``. It serves as the reference path for comparison.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input_tensor, C1=None, C2=None):
        """Launch ``fmix32_kernel_sgl`` over ``input_tensor``.

        Args:
            input_tensor: int32 tensor whose elements represent uint32 values.
                Presumably must be on a GPU device Triton can launch on;
                this is not checked here.
            C1, C2: fmix32 multipliers. ``None`` selects ``FMIX32_C1`` /
                ``FMIX32_C2``. They are ``tl.constexpr`` kernel arguments, so
                each distinct pair compiles a separate kernel specialization.

        Returns:
            A new int32 tensor with the same shape as ``input_tensor``.

        Raises:
            AssertionError: if ``input_tensor`` is not int32.
        """
        if C1 is None:
            C1 = FMIX32_C1
        if C2 is None:
            C2 = FMIX32_C2
        assert input_tensor.dtype == torch.int32, "Input must be int32 (representing uint32)"
        # The kernel walks raw memory as one flat dense buffer, so strided or
        # sliced views must be densified or elements would be read out of order.
        input_tensor = input_tensor.contiguous()
        n_elements = input_tensor.numel()
        output = torch.empty_like(input_tensor)
        BLOCK_SIZE = 1024
        grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
        fmix32_kernel_sgl[grid](
            input_tensor,
            output,
            n_elements,
            C1=C1,
            C2=C2,
            BLOCK_SIZE=BLOCK_SIZE,
        )
        return output
def get_inputs():
    """Return the default positional arguments for ``forward``.

    The list is ``[input_tensor, C1, C2]``. ``input_tensor`` holds 1024 values
    drawn uniformly from [0, 2**32) as int64, then narrowed with
    ``.to(torch.int32)``, so the int32 bits cover the whole uint32 range.
    The tensor is created without a device argument (default device);
    presumably the caller moves it to the GPU before ``forward``. TODO confirm.
    """
    count = 1024
    upper_exclusive = (2**32 - 1) + 1  # randint's high bound is exclusive
    wide_values = torch.randint(0, upper_exclusive, (count,), dtype=torch.int64)
    return [wide_values.to(torch.int32), FMIX32_C1, FMIX32_C2]
def get_init_inputs():
    """Return constructor arguments for the models (they take none)."""
    return list()