import torch.nn as nn
from torchtitan_npu.converters.base_converter import BaseConverter
class MockImplDefault(nn.Module):
    """Empty placeholder module registered under the ``_default`` key."""
class MockImplDeepseek(nn.Module):
    """Empty placeholder module registered under the ``deepseek`` key."""
class MockImplLlama(nn.Module):
    """Empty placeholder module registered under the ``llama`` key."""
class TestConverter(BaseConverter):
    """Converter stub mapping model names to the mock implementations above.

    Declares no ``SUPPORTED_MODELS`` of its own; it is used by the
    ``get_impl_cls`` and ``is_compatible`` tests in this module.
    """

    # The name matches pytest's default ``Test*`` class-collection pattern;
    # opt out so pytest does not try to collect this helper as a test class.
    __test__ = False

    MODEL_IMPL = {
        "deepseek": MockImplDeepseek,
        "llama": MockImplLlama,
        "_default": MockImplDefault,
    }

    @classmethod
    def apply(cls, model, model_name, **kwargs):
        """Stub ``apply``: performs no conversion and returns 0."""
        return 0
def test_get_impl_cls_match():
    """Explicitly registered model names resolve to their own implementation."""
    # Compare class objects by identity: we want the exact registered class.
    assert TestConverter.get_impl_cls("deepseek") is MockImplDeepseek
    assert TestConverter.get_impl_cls("llama") is MockImplLlama
def test_get_impl_cls_default():
    """An unregistered model name is expected to fall back to ``_default``."""
    assert TestConverter.get_impl_cls("not_impl") is MockImplDefault
def test_is_compatible_supports_wildcard_models():
    """A converter that declares no SUPPORTED_MODELS accepts any model name."""
    compatible = TestConverter.is_compatible(
        job_config=None,
        model_name="anything",
    )
    assert compatible
class RestrictedConverter(BaseConverter):
    """Converter stub whose SUPPORTED_MODELS is limited to ``llama3``.

    Only a ``_default`` implementation is registered.
    """

    MODEL_IMPL = {"_default": MockImplDefault}
    # Restricting the allowed model names is what the rejection test exercises.
    SUPPORTED_MODELS = {"llama3"}

    @classmethod
    def apply(cls, model, model_name, **kwargs):
        """Stub ``apply``: performs no conversion and returns 0."""
        return 0
def test_is_compatible_rejects_unsupported_model():
    """A model name outside SUPPORTED_MODELS is reported as incompatible."""
    compatible = RestrictedConverter.is_compatible(
        job_config=None,
        model_name="deepseek_v3",
    )
    assert not compatible