386d82fd创建于 2025年4月21日历史提交
import random
from functools import wraps
import numpy as np
import torch
import torch.utils._device


def setup_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


class EmptyInitOnDevice(torch.overrides.TorchFunctionMode):
    def __init__(self, device=None):
        self.device = device

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if getattr(func, '__module__', None) == 'torch.nn.init':
            if 'tensor' in kwargs:
                return kwargs['tensor']
            else:
                return args[0]
        if self.device is not None and func in torch.utils._device._device_constructors() and kwargs.get('device') is None:
            kwargs['device'] = self.device
        return func(*args, **kwargs)
    

def with_empty_init(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        with EmptyInitOnDevice('cpu'):
            return func(*args, **kwargs)
    return wrapper


def culens2mask(
    cu_seqlens=None,
    cu_seqlens_kv=None,
    max_seqlen=None,
    max_seqlen_kv=None,
    is_causal=False
):
    assert len(cu_seqlens) == len(cu_seqlens_kv); # "q k v should have same bsz..."
    bsz = len(cu_seqlens) - 1
    seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
    seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
    
    attn_mask = torch.zeros(bsz, max_seqlen, max_seqlen_kv, dtype=torch.bool)
    for i, (seq_len, seq_len_kv) in enumerate(zip(seqlens, seqlens_kv)):
        if is_causal:
            attn_mask[i, :seq_len, :seq_len_kv] = torch.triu(torch.ones(seq_len, seq_len_kv), diagonal=1).bool()
        else:
            attn_mask[i, :seq_len, :seq_len_kv] = torch.ones([seq_len, seq_len_kv], dtype=torch.bool)

    return attn_mask