import logging

import torch

from mindspeed.fsdp.utils.log import print_rank
from .flash_attn.flash_attn import apply_transformers_attention_patch
from .fully_shard.fully_shard import apply_fully_shard_patch


logger = logging.getLogger(__name__)


def apply_ops_patch():
    apply_transformers_attention_patch()
    print_rank(logger.info, "✅ MindSpeed-MM ops patch applied.")

    # apply modify fully_shard patch
    apply_fully_shard_patch()
    print_rank(logger.info, "✅ MindSpeed-MM fully_shard patch applied.")