import random
from functools import wraps
import numpy as np
import torch
import torch.utils._device
def setup_seed(seed):
    """Seed every RNG this module uses and force deterministic cuDNN kernels.

    The same ``seed`` is applied, in order, to Python's ``random``, NumPy's
    global RNG, torch's CPU RNG and all CUDA device RNGs. cuDNN is then put in
    deterministic mode with autotuning (benchmark) disabled.

    Args:
        seed: Integer seed shared by all generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    # Reproducibility over speed: fixed algorithms, no per-shape autotuning.
    cudnn.deterministic = True
    cudnn.benchmark = False
class EmptyInitOnDevice(torch.overrides.TorchFunctionMode):
    """Torch-function mode that skips ``torch.nn.init`` and defaults factory devices.

    While the mode is active:

    * Any function whose ``__module__`` is ``'torch.nn.init'`` is not executed.
      The target tensor is returned as-is (the ``tensor`` kwarg if present,
      otherwise the first positional argument), leaving its storage
      uninitialised.
    * If ``device`` was given, calls to torch's tensor-constructor functions (as
      listed by ``torch.utils._device._device_constructors()``) that pass no
      ``device`` kwarg, or pass ``device=None``, get ``device=self.device``
      injected.
    """

    def __init__(self, device=None):
        # Device injected into constructor calls; None disables the injection.
        self.device = device

    def __torch_function__(self, func, types, args=(), kwargs=None):
        call_kwargs = kwargs if kwargs else {}

        if getattr(func, '__module__', None) == 'torch.nn.init':
            # Skip the initialiser entirely and hand back the tensor it targeted.
            return call_kwargs['tensor'] if 'tensor' in call_kwargs else args[0]

        should_set_device = (
            self.device is not None
            and func in torch.utils._device._device_constructors()
            and call_kwargs.get('device') is None
        )
        if should_set_device:
            call_kwargs['device'] = self.device

        return func(*args, **call_kwargs)
def with_empty_init(func):
    """Decorate ``func`` so every call runs inside ``EmptyInitOnDevice('cpu')``.

    Inside that mode, ``torch.nn.init`` calls are skipped and tensor
    constructors without an explicit device default to ``'cpu'``. The wrapper
    keeps ``func``'s metadata via ``functools.wraps`` and returns ``func``'s
    result unchanged.
    """

    @wraps(func)
    def _run_without_init(*call_args, **call_kwargs):
        with EmptyInitOnDevice('cpu'):
            return func(*call_args, **call_kwargs)

    return _run_without_init
def culens2mask(
    cu_seqlens=None,
    cu_seqlens_kv=None,
    max_seqlen=None,
    max_seqlen_kv=None,
    is_causal=False
):
    """Expand cumulative (packed/varlen) sequence lengths into a dense boolean mask.

    Convention: ``True`` marks a (query, key) pair that may attend; ``False``
    marks padding and, when ``is_causal`` is set, future keys.

    Args:
        cu_seqlens: 1-D cumulative query lengths of length ``bsz + 1``. The
            difference of consecutive entries is each sample's query length.
            Must support slicing, subtraction and ``.tolist()`` (e.g. a
            ``torch.Tensor``).
        cu_seqlens_kv: Cumulative key/value lengths. Must have the same length
            as ``cu_seqlens``.
        max_seqlen: Query dimension of the result. If ``None``, the longest
            query length is used.
        max_seqlen_kv: Key dimension of the result. If ``None``, the longest
            key/value length is used.
        is_causal: If True, query ``q`` may only attend to keys ``k <= q``
            within its own sample (top-left aligned lower triangle).

    Returns:
        ``torch.bool`` tensor of shape ``(bsz, max_seqlen, max_seqlen_kv)``,
        allocated on torch's default device (no device is passed explicitly).

    Raises:
        AssertionError: if ``cu_seqlens`` and ``cu_seqlens_kv`` differ in length.
    """
    assert len(cu_seqlens) == len(cu_seqlens_kv), (
        "cu_seqlens and cu_seqlens_kv must describe the same batch size"
    )
    bsz = len(cu_seqlens) - 1
    # Plain Python ints: avoids per-element 0-d tensor indexing in the loop.
    seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
    seqlens_kv = (cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]).tolist()

    # The None defaults previously crashed inside torch.zeros; infer the sizes.
    if max_seqlen is None:
        max_seqlen = max(seqlens, default=0)
    if max_seqlen_kv is None:
        max_seqlen_kv = max(seqlens_kv, default=0)

    attn_mask = torch.zeros(bsz, max_seqlen, max_seqlen_kv, dtype=torch.bool)
    for i, (seq_len, seq_len_kv) in enumerate(zip(seqlens, seqlens_kv)):
        if is_causal:
            # Lower triangle (k <= q) keeps the True = "may attend" convention
            # of the non-causal branch. triu(..., diagonal=1) would select only
            # the future keys, i.e. the inverse of a causal mask.
            block = torch.tril(torch.ones(seq_len, seq_len_kv)).bool()
        else:
            block = torch.ones(seq_len, seq_len_kv, dtype=torch.bool)
        attn_mask[i, :seq_len, :seq_len_kv] = block
    return attn_mask