import contextlib
import logging
import threading
import torch
from .performance_model.utils import run_once
# Module-level logger for this file's diagnostics.
logger = logging.getLogger(name=__name__)
# Per-thread storage for the emulated "meta"-device autocast flag; accessed via
# _is_meta_autocast_enabled() / _set_meta_autocast_enabled().
_meta_autocast_enabled = threading.local()
def _is_meta_autocast_enabled():
    """Return this thread's emulated autocast flag for "meta" (False if never set)."""
    try:
        return _meta_autocast_enabled.value
    except AttributeError:
        # Nothing has been stored for the current thread yet.
        return False
def _set_meta_autocast_enabled(enabled):
    """Store *enabled* as this thread's emulated autocast flag for "meta"."""
    setattr(_meta_autocast_enabled, "value", enabled)
# Name of the attribute attached to the wrapper installed by
# patch_dtype_to_type(); it holds the wrapped original function, and its
# presence tells nested entries that the patch is already active.
_DTYPE_TO_TYPE_ORIGINAL_ATTR: str = "_tensor_cast_original_dtype_to_type"
# Set to True by the outermost support_autocast_for_meta() entry and reset on
# its exit; nested entries see True and skip re-patching.
_in_support_autocast_for_meta: bool = False
@contextlib.contextmanager
def support_autocast_for_meta():
    """
    PyTorch doesn't support the "meta" device for autocast. This prevents us
    from running a PyTorch model that calls autocast, such as those models in
    the Transformers, on meta tensors. We patch PyTorch to work around it.

    While the context is active:

    * ``torch._C._is_autocast_available("meta")`` returns True;
    * ``torch.get_autocast_dtype("meta")`` returns ``torch.half``;
    * ``torch.is_autocast_enabled("meta")`` and
      ``torch.set_autocast_enabled("meta", enabled)`` read/write a thread-local
      flag instead of calling into PyTorch;
    * ``torch.fx.experimental._config.meta_nonzero_assume_all_nonzero`` is True.

    Any other call is forwarded unchanged to the original function. That
    includes the legacy no-argument ``torch.is_autocast_enabled()``, the
    single-argument ``torch.set_autocast_enabled(enabled)`` overloads, and
    ``device_type=`` keyword calls.

    Nested entries are no-ops; the outermost exit restores every patched
    attribute.
    """
    global _in_support_autocast_for_meta
    if _in_support_autocast_for_meta:
        # An enclosing context already installed (and will remove) the patches.
        yield
        return

    import torch.fx.experimental._config as config

    meta_nonzero_assume_all_nonzero_orig = config.meta_nonzero_assume_all_nonzero
    get_autocast_dtype_orig = torch.get_autocast_dtype
    _C_is_autocast_available_orig = torch._C._is_autocast_available
    is_autocast_enabled_orig = torch.is_autocast_enabled
    set_autocast_enabled_orig = torch.set_autocast_enabled

    def _device_type(args, kwargs):
        # The patched torch APIs take the device type either as the first
        # positional argument or as the ``device_type=`` keyword. For the
        # legacy overloads this yields None or a bool, which never equals "meta".
        if args:
            return args[0]
        return kwargs.get("device_type")

    def get_autocast_dtype(*args, **kwargs):
        if _device_type(args, kwargs) == "meta":
            return torch.half
        return get_autocast_dtype_orig(*args, **kwargs)

    def is_autocast_available(*args, **kwargs):
        if _device_type(args, kwargs) == "meta":
            return True
        return _C_is_autocast_available_orig(*args, **kwargs)

    def is_autocast_enabled(*args, **kwargs):
        if _device_type(args, kwargs) == "meta":
            return _is_meta_autocast_enabled()
        return is_autocast_enabled_orig(*args, **kwargs)

    def _set_meta(device_type, enabled):
        # Binds like the (device_type, enabled) overload, so a malformed call
        # still raises TypeError as the real API would.
        _set_meta_autocast_enabled(enabled)

    def set_autocast_enabled(*args, **kwargs):
        if _device_type(args, kwargs) == "meta":
            _set_meta(*args, **kwargs)
        else:
            set_autocast_enabled_orig(*args, **kwargs)

    _in_support_autocast_for_meta = True
    try:
        # Install inside the try so a failure part-way through still restores
        # everything; restoring an attribute that was never patched is harmless.
        torch.get_autocast_dtype = get_autocast_dtype
        torch.is_autocast_enabled = is_autocast_enabled
        torch.set_autocast_enabled = set_autocast_enabled
        torch._C._is_autocast_available = is_autocast_available
        config.meta_nonzero_assume_all_nonzero = True
        yield
    finally:
        torch._C._is_autocast_available = _C_is_autocast_available_orig
        torch.set_autocast_enabled = set_autocast_enabled_orig
        torch.is_autocast_enabled = is_autocast_enabled_orig
        torch.get_autocast_dtype = get_autocast_dtype_orig
        config.meta_nonzero_assume_all_nonzero = meta_nonzero_assume_all_nonzero_orig
        _in_support_autocast_for_meta = False
@contextlib.contextmanager
def specialize_float():
"""
Patch torch._dynamo.config.specialize_float to True, so that the float dtype
information can be preserved in the graph. We assume floats are specialized
in our pattern matching passes like RMSNorm for params like eps.
"""
old_flag = torch._dynamo.config.specialize_float
torch._dynamo.config.specialize_float = True
try:
yield
finally:
torch._dynamo.config.specialize_float = old_flag
@contextlib.contextmanager
def patch_fallback_node_due_to_unsupported_type():
    """
    Make ``torch._inductor.pattern_matcher.fallback_node_due_to_unsupported_type``
    always return False while the context is active, so pattern-matching passes
    are not blocked by meta tensors.

    If the installed PyTorch has no such function, nothing is patched. The
    original function is put back on exit.
    """
    import torch._inductor.pattern_matcher as pattern_matcher

    attr_name = "fallback_node_due_to_unsupported_type"
    absent = object()
    saved_func = getattr(pattern_matcher, attr_name, absent)
    if saved_func is absent:
        yield
        return

    def _never_fallback(*_args, **_kwargs):
        return False

    setattr(pattern_matcher, attr_name, _never_fallback)
    try:
        yield
    finally:
        setattr(pattern_matcher, attr_name, saved_func)
@contextlib.contextmanager
def prepare_freezing():
"""
Prepare PyTorch Dynamo for graph freezing by enabling the relevant config.
We need this for the `freeze()` call from inductor to work properly.
"""
old_flag = torch._dynamo.config.prepare_freezing
torch._dynamo.config.prepare_freezing = True
try:
yield
finally:
torch._dynamo.config.prepare_freezing = old_flag
@contextlib.contextmanager
def patch_dtype_abbrs():
    """
    Patch ``torch.utils._dtype_abbrs`` so FX graph dumps can print the int4
    dtype used by MXFP4.

    Adds a ``torch.int4 -> "i4"`` entry to the shared ``dtype_abbrs`` dict for
    the duration of the context. On exit only that entry is restored (removed,
    or reset to its previous value), so entries other code adds meanwhile are
    kept. Does nothing if the module, the ``dtype_abbrs`` name, or
    ``torch.int4`` is unavailable in the installed PyTorch.
    """
    try:
        from torch.utils._dtype_abbrs import dtype_abbrs
    except ImportError:  # also covers ModuleNotFoundError
        yield
        return
    int4 = getattr(torch, "int4", None)
    if int4 is None:
        yield
        return

    missing = object()
    previous = dtype_abbrs.get(int4, missing)
    dtype_abbrs[int4] = "i4"
    try:
        yield
    finally:
        # Touch only our key: clearing and rebuilding the global dict would
        # discard concurrent additions and briefly leave it empty.
        if previous is missing:
            dtype_abbrs.pop(int4, None)
        else:
            dtype_abbrs[int4] = previous
@contextlib.contextmanager
def patch_dtype_to_type():
    """This patch tries to fix the FX graph tracing issue when int4 dtype is used.

    For example, ``torch.cat`` fails with int4 tensors because the
    ``dtype_to_type`` function in ``torch._prims_common`` does not support the
    int4 dtype.

    While active, ``torch._prims_common.dtype_to_type`` is replaced by a
    wrapper that maps ``torch.int4`` to ``int`` and delegates every other dtype
    to the original function. The wrapper carries the original under the
    ``_DTYPE_TO_TYPE_ORIGINAL_ATTR`` attribute, so nested entries detect the
    active patch and do nothing.

    On exit the original is restored only if the wrapper is still installed.
    If ``torch._prims_common`` or ``torch.int4`` is unavailable, the context
    runs unpatched.
    """
    try:
        from torch import _prims_common
    except ImportError:  # "cannot import name" is ImportError, not ModuleNotFoundError
        yield
        return
    # Resolve once: on a PyTorch build without torch.int4, looking it up
    # inside the wrapper would make *every* dtype_to_type call raise.
    int4 = getattr(torch, "int4", None)
    current_dtype_to_type = _prims_common.dtype_to_type
    if int4 is None or hasattr(current_dtype_to_type, _DTYPE_TO_TYPE_ORIGINAL_ATTR):
        # Nothing to map, or an enclosing context already installed the wrapper.
        yield
        return
    original_dtype_to_type = current_dtype_to_type

    def dtype_to_type_patched(dtype: torch.dtype) -> type:
        if dtype == int4:
            return int
        return original_dtype_to_type(dtype)

    setattr(dtype_to_type_patched, _DTYPE_TO_TYPE_ORIGINAL_ATTR, original_dtype_to_type)
    _prims_common.dtype_to_type = dtype_to_type_patched
    try:
        yield
    finally:
        # Don't clobber a replacement someone else installed in the meantime.
        if _prims_common.dtype_to_type is dtype_to_type_patched:
            _prims_common.dtype_to_type = original_dtype_to_type
@contextlib.contextmanager
def patch_masked_scatter():
    """Patch Tensor.masked_scatter to work with meta device tensors.

    While active, calling ``Tensor.masked_scatter`` on a tensor whose device
    type is "meta" does not run the real op. It passes a warning to
    ``run_once`` (presumably emitted once per key) and returns a new,
    uninitialized tensor with ``self``'s dtype and device. The result's shape
    is ``self`` broadcast against ``mask``, the output shape of the real
    out-of-place op.

    Calls on any other device go to the original method. If the patch cannot
    be installed, a warning is logged and the context runs unpatched. The
    original method is restored on exit.
    """

    def masked_scatter_meta_safe(self, mask, source):
        if isinstance(self, torch.Tensor) and self.device.type == "meta":
            run_once(
                "tensor_cast.patch_torch.masked_scatter.meta",
                logger.warning,
                "TensorCast: masked_scatter on meta is bypassed (returns empty_like); "
                "shape/dtype preserved and op time is ~0.",
            )
            # The real out-of-place op first expands self and mask to a common
            # shape, so its result can be larger than self when mask broadcasts.
            out_shape = torch.broadcast_shapes(self.shape, mask.shape)
            return self.new_empty(out_shape)
        return original_masked_scatter(self, mask, source)

    try:
        original_masked_scatter = torch.Tensor.masked_scatter
        torch.Tensor.masked_scatter = masked_scatter_meta_safe
    except Exception:
        # Best effort: run unpatched rather than failing the caller, but say so.
        logger.warning(
            "TensorCast: failed to patch Tensor.masked_scatter; running unpatched.",
            exc_info=True,
        )
        yield
        return
    try:
        yield
    finally:
        torch.Tensor.masked_scatter = original_masked_scatter
@contextlib.contextmanager
def patch_torch():
    """
    Apply every PyTorch patch from this module for the duration of the context.

    Patches are entered in the order listed below and unwound in reverse order
    on exit, including when the body raises.
    """
    patch_factories = (
        support_autocast_for_meta,
        specialize_float,
        patch_fallback_node_due_to_unsupported_type,
        patch_dtype_abbrs,
        patch_dtype_to_type,
        prepare_freezing,
        patch_masked_scatter,
    )
    with contextlib.ExitStack() as stack:
        for make_patch in patch_factories:
            stack.enter_context(make_patch())
        yield