# Copyright (c) 2025 Syslong Technology Co., Ltd. All Rights Reserved.
# Copyright (c) 2025 Shanghai Jiao Tong University
# Copyright (c) 2026, HUAWEI CORPORATION.  All rights reserved.
#
# Licensed under the Mulan PSL v2.
# You may obtain a copy of the License at:
#     http://license.coscl.org.cn/MulanPSL2
#
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
# See the Mulan PSL v2 for more details.
import argparse
from collections import deque
import logging
import os
import time

import numpy as np
import torch
from torch import Tensor

from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from modeling_pi0_om import model_init
from utils import AclLiteResource

# Module-scoped logger; handlers/format are configured in main() via logging.basicConfig.
logger = logging.getLogger(name=__name__)


def main():
    logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
    parser = argparse.ArgumentParser(description="Run pi0 OM inference end-to-end")
    parser.add_argument("--vlm-model-path", default="outputs/om/pi0-vlm.om", help="Path to part1 OM file (PaliGemma)")
    parser.add_argument(
        "--action-expert-model-path",
        default="outputs/om/pi0-action_expert.om",
        help="Path to part2 OM file (Gemma)",
    )
    parser.add_argument("--device", default=0, help="device id")
    parser.add_argument("--mean-path", default=None, help="Path to action mean .pt file")
    parser.add_argument("--std-path", default=None, help="Path to action std .pt file")
    args = parser.parse_args()

    config = {
        'output_features': {
            'action': PolicyFeature(
                type=FeatureType.ACTION,
                shape=(14,))
        },
        'normalization_mapping': {
            'VISUAL': NormalizationMode.IDENTITY,
            'STATE': NormalizationMode.MEAN_STD,
            'ACTION': NormalizationMode.MEAN_STD,
        },
        'stats': None,
    }

    # Override stats with user-provided mean/std tensors when both paths are given.
    if args.mean_path or args.std_path:
        if not (args.mean_path and args.std_path):
            raise ValueError("Both --mean-path and --std-path must be provided together.")
        if not os.path.exists(args.mean_path):
            raise FileNotFoundError(f"Mean file not found: {args.mean_path}")
        if not os.path.exists(args.std_path):
            raise FileNotFoundError(f"Std file not found: {args.std_path}")

        loaded_mean = torch.load(args.mean_path)
        loaded_std = torch.load(args.std_path)
        config['stats'] = {
            'action': {
                'mean': loaded_mean,
                'std': loaded_std,
            }
        }
    else:
        # Default action normalization stats if no paths are supplied.
        default_mean = torch.tensor([
            -0.0054, -0.4803, 1.0102, -0.0042, -0.5298, 1.1214, 0.5875,
            0.0196, -0.3138, 0.4702, -0.0231, 0.7722, 0.0375, 0.5962
        ], dtype=torch.float32)
        default_std = torch.tensor([
            0.0037, 0.5198, 0.1978, 0.0163, 0.3605, 0.5996, 0.4241,
            0.1111, 0.4944, 0.4435, 0.1452, 0.2956, 0.2278, 0.3861
        ], dtype=torch.float32)
        config['stats'] = {
            'action': {
                'mean': default_mean,
                'std': default_std,
            }
        }
    # add your model paths here
    # model init params: vlm_model_path, action_expert_model_path, config
    model = model_init(args.vlm_model_path, args.action_expert_model_path, config)
    # def interface(self, state, image, tokens_ids):
    state_shape = (1, 14)
    state = np.zeros(state_shape, dtype=np.float32)

    image_shape = (1, 3, 480, 640)
    image = np.zeros(image_shape, dtype=np.float32)
    vocab_size = 32000
    max_len = 48
    lang_tokens = torch.randint(low=1, high=vocab_size, size=(1, max_len), dtype=torch.long)
    lang_masks = torch.ones_like(lang_tokens, dtype=torch.bool)

    # warm up
    model.select_action(state, image, lang_tokens, lang_masks)
    model.reset()

    # inference
    t0 = time.time()
    action = model.select_action(state, image, lang_tokens, lang_masks)
    time_elapsed = time.time() - t0
    logger.info("Action inference time: %.4f seconds", time_elapsed)
    logger.info("Selected action: %s", action)


if __name__ == "__main__":
    # Script entry point: run the end-to-end warm-up + timed inference.
    main()