import torch
from torch._dynamo.utils import detect_fake_mode
from torch.fx.passes.shape_prop import ShapeProp


def shape_propagation(gm: torch.fx.GraphModule, inputs) -> torch.fx.GraphModule:
    ShapeProp(
        gm=gm,
        fake_mode=detect_fake_mode(inputs),
    ).propagate(*tuple(inputs))
    return gm