import torch
from tensor_cast.layers.quant_linear import QuantLinearBase
from tensor_cast.model_config import ParallelConfig
from tensor_cast.quantize_utils import LinearQuantType
from tests.regression.tensor_cast.test_common import get_linear_quant_config
def test_tensor_cast_parallel_layer_smoke():
    """Smoke-test ParallelConfig construction and a QuantLinearBase forward pass.

    Builds a parallel config with world size 4 split into tensor-parallel 2 and
    data-parallel 2, checks the attributes it exposes, then wraps a small
    bias-free float32 ``torch.nn.Linear`` in ``QuantLinearBase`` using a W8A16
    quant config and verifies only the shape of the forward output.
    """
    batch_size, in_features, out_features = 2, 8, 4

    # Parallel configuration checks.
    parallel_cfg = ParallelConfig(
        world_size=4,
        tensor_parallel_size=2,
        data_parallel_size=2,
    )
    assert parallel_cfg.has_attn_tp()
    assert parallel_cfg.data_parallel_size == 2

    # Quantized linear layer built on top of a plain float32 linear module.
    base_linear = torch.nn.Linear(
        in_features, out_features, bias=False, dtype=torch.float32
    )
    quant_config = get_linear_quant_config(
        LinearQuantType.W8A16, base_linear.weight.data
    )
    quant_layer = QuantLinearBase(base_linear, quant_config)

    # Only the output shape is checked; values are not compared.
    inputs = torch.randn(batch_size, in_features, dtype=torch.float32)
    outputs = quant_layer(inputs)
    assert outputs.shape == (batch_size, out_features)