import argparse
from collections import deque
import logging
import os
import time
import numpy as np
import torch
from torch import Tensor
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from modeling_pi0_om import model_init
from utils import AclLiteResource
# Module-level logger for this script's timing and result messages.
logger: logging.Logger = logging.getLogger(__name__)
# Length of the action vector declared in the policy config; the built-in
# normalization statistics below have exactly this many entries.
_ACTION_DIM = 14

# Built-in per-dimension action normalization statistics, used when neither
# --mean-path nor --std-path is given.
# NOTE(review): the dataset these were computed on is not visible here — confirm.
_DEFAULT_ACTION_MEAN = (
    -0.0054, -0.4803, 1.0102, -0.0042, -0.5298, 1.1214, 0.5875,
    0.0196, -0.3138, 0.4702, -0.0231, 0.7722, 0.0375, 0.5962,
)
_DEFAULT_ACTION_STD = (
    0.0037, 0.5198, 0.1978, 0.0163, 0.3605, 0.5996, 0.4241,
    0.1111, 0.4944, 0.4435, 0.1452, 0.2956, 0.2278, 0.3861,
)


def _parse_args():
    """Parse the command-line options of the end-to-end OM inference demo."""
    parser = argparse.ArgumentParser(description="Run pi0 OM inference end-to-end")
    parser.add_argument("--vlm-model-path", default="outputs/om/pi0-vlm.om", help="Path to part1 OM file (PaliGemma)")
    parser.add_argument(
        "--action-expert-model-path",
        default="outputs/om/pi0-action_expert.om",
        help="Path to part2 OM file (Gemma)",
    )
    # type=int keeps CLI-supplied values consistent with the int default.
    # NOTE(review): --device is parsed but never used in this script — confirm
    # whether it should be forwarded to the OM runtime.
    parser.add_argument("--device", type=int, default=0, help="device id")
    parser.add_argument("--mean-path", default=None, help="Path to action mean .pt file")
    parser.add_argument("--std-path", default=None, help="Path to action std .pt file")
    return parser.parse_args()


def _load_stat_tensor(path, label):
    """Load one action statistic from ``path`` as a float32 tensor.

    Raises:
        ValueError: if the loaded value's trailing dimension is not _ACTION_DIM.
    """
    # NOTE: torch.load unpickles the file — only point it at trusted files.
    # map_location="cpu" lets statistics saved from an accelerator load on the host.
    value = torch.as_tensor(torch.load(path, map_location="cpu"), dtype=torch.float32)
    if value.shape[-1:] != (_ACTION_DIM,):
        raise ValueError(
            f"{label} statistics in {path} must have trailing dimension "
            f"{_ACTION_DIM}, got shape {tuple(value.shape)}"
        )
    return value


def _resolve_action_stats(mean_path, std_path):
    """Return ``(mean, std)`` action tensors from the given files, or the defaults.

    Raises:
        ValueError: if exactly one of the two paths is given, or a loaded
            statistic has the wrong trailing dimension.
        FileNotFoundError: if a given path does not exist.
    """
    if not (mean_path or std_path):
        return (
            torch.tensor(_DEFAULT_ACTION_MEAN, dtype=torch.float32),
            torch.tensor(_DEFAULT_ACTION_STD, dtype=torch.float32),
        )
    if not (mean_path and std_path):
        raise ValueError("Both --mean-path and --std-path must be provided together.")
    # Check both paths before loading anything, so a missing file is reported first.
    for label, path in (("Mean", mean_path), ("Std", std_path)):
        if not os.path.exists(path):
            raise FileNotFoundError(f"{label} file not found: {path}")
    return _load_stat_tensor(mean_path, "Mean"), _load_stat_tensor(std_path, "Std")


def _build_config(action_mean, action_std):
    """Assemble the policy config dict passed to ``model_init``."""
    # NOTE(review): STATE is mapped to MEAN_STD, but only 'action' stats are
    # supplied — confirm model_init does not expect state statistics.
    return {
        'output_features': {
            'action': PolicyFeature(
                type=FeatureType.ACTION,
                shape=(_ACTION_DIM,)),
        },
        'normalization_mapping': {
            'VISUAL': NormalizationMode.IDENTITY,
            'STATE': NormalizationMode.MEAN_STD,
            'ACTION': NormalizationMode.MEAN_STD,
        },
        'stats': {
            'action': {
                'mean': action_mean,
                'std': action_std,
            }
        },
    }


def main():
    """Build the pi0 OM model, run one untimed and one timed inference, log results.

    Action normalization statistics come from --mean-path/--std-path when both
    are given, otherwise from built-in defaults. Inputs are dummy data: an
    all-zero state and image plus random language tokens with a full mask.

    Raises:
        ValueError: if only one of --mean-path/--std-path is given, or a loaded
            statistic does not have a trailing dimension of 14.
        FileNotFoundError: if a given statistics file does not exist.
    """
    logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
    args = _parse_args()

    action_mean, action_std = _resolve_action_stats(args.mean_path, args.std_path)
    config = _build_config(action_mean, action_std)
    model = model_init(args.vlm_model_path, args.action_expert_model_path, config)

    # Dummy inputs. NOTE(review): these shapes must match the exported OM
    # graphs' input signatures, which are not visible here.
    state_shape = (1, 14)
    state = np.zeros(state_shape, dtype=np.float32)
    image_shape = (1, 3, 480, 640)
    image = np.zeros(image_shape, dtype=np.float32)
    vocab_size = 32000
    max_len = 48
    lang_tokens = torch.randint(low=1, high=vocab_size, size=(1, max_len), dtype=torch.long)
    lang_masks = torch.ones_like(lang_tokens, dtype=torch.bool)

    # The first call is not timed (warm-up). reset() then clears any state that
    # select_action kept between calls before the measured call.
    model.select_action(state, image, lang_tokens, lang_masks)
    model.reset()

    # perf_counter is monotonic and high-resolution; time.time() can jump
    # when the system clock is adjusted.
    t0 = time.perf_counter()
    action = model.select_action(state, image, lang_tokens, lang_masks)
    time_elapsed = time.perf_counter() - t0
    logger.info("Action inference time: %.4f seconds", time_elapsed)
    logger.info("Selected action: %s", action)
if __name__ == "__main__":
    # Execute the demo only when run as a script, never on import.
    main()