import torch
import torch.fx
from torch.fx.node import Node


def get_node_shape(node: Node) -> torch.Size | None:
    """Retrieve the shape of the tensor represented by the node, if available."""
    if not hasattr(node, "meta") or "val" not in node.meta:
        return None
    if isinstance(node.meta["val"], torch.Tensor):
        return node.meta["val"].shape
    else:
        return None


def is_non_scalar_tensor_node(node: Node) -> bool:
    """Tell whether ``node.meta["val"]`` is a tensor with at least one dimension.

    Returns:
        ``True`` only when the node has a ``meta`` mapping containing ``"val"``,
        that value is a ``torch.Tensor``, and its shape is non-empty.
        ``False`` in every other case, including zero-dimensional tensors.
    """
    if hasattr(node, "meta") and "val" in node.meta:
        candidate = node.meta["val"]
        if isinstance(candidate, torch.Tensor):
            # An empty shape means a zero-dimensional (scalar) tensor.
            return len(candidate.shape) > 0
    return False


def maybe_copy_meta(target_node: Node, source_node: Node):
    """Give ``target_node`` a shallow copy of ``source_node``'s ``meta``.

    When ``source_node`` has no ``meta`` attribute, ``target_node`` is left
    untouched. Otherwise ``target_node.meta`` is replaced by a new ``dict``
    built from ``source_node.meta``; the values themselves are shared, not
    copied.
    """
    if not hasattr(source_node, "meta"):
        return
    # Build a fresh dict so later key changes on one node do not leak to the other.
    meta_snapshot = dict(source_node.meta)
    target_node.meta = meta_snapshot