from .float8_tensor import Float8Tensor, Float8Tensor2D
from .float8_tensor_cpu import Float8TensorCpu
from .mxfp8_tensor import MXFP8Tensor
from .mxfp8_tensor_cpu import MXFP8TensorCpu
from .float8_block_tensor import Float8BlockTensor

# Classes accepted by is_fp8_tensor(), in a form usable as isinstance()'s
# second argument.
# NOTE(review): Float8TensorCpu, MXFP8Tensor, MXFP8TensorCpu and
# Float8BlockTensor are imported above but not listed here. Confirm that
# excluding them is intentional.
FP8_TENSOR = (
    Float8Tensor,
    Float8Tensor2D,
)


def is_fp8_tensor(tensor):
    """Report whether *tensor* is an instance of any class in ``FP8_TENSOR``.

    Args:
        tensor: Any object. It does not need to be a tensor at all.

    Returns:
        ``True`` if ``isinstance(tensor, FP8_TENSOR)`` holds, otherwise ``False``.
    """
    # The module-level tuple is looked up on every call, so any later
    # rebinding of FP8_TENSOR is still honoured.
    accepted_types = FP8_TENSOR
    return isinstance(tensor, accepted_types)


def is_fp8_tensor_2d(tensor):
    """Report whether *tensor* is a ``Float8Tensor2D`` (or a subclass of it).

    Args:
        tensor: Any object.

    Returns:
        ``True`` if ``isinstance(tensor, Float8Tensor2D)`` holds, otherwise
        ``False``.
    """
    target_cls = Float8Tensor2D
    return isinstance(tensor, target_cls)