import hashlib
import inspect
from abc import ABC, abstractmethod
from typing import Any
import torch
class TensorCastGraphModulePass(ABC):
"""Use the same interface as Inductor's CustomGraphPass"""
@abstractmethod
def __call__(self, graph: torch.fx.GraphModule) -> torch.fx.GraphModule:
"""
Implementation of the custom pass.
"""
def uuid(self) -> Any:
"""
Provide a unique identifier for the pass, used for code cache.
This should depend on the pass implementation, so that changes to the
pass result in recompilation.
By default, the object source is hashed.
"""
hasher = hashlib.sha256()
src = inspect.getsource(self.__class__)
hasher.update(src.encode("utf-8"))
return hasher.hexdigest()