# Copyright 2026 Huawei Technologies Co., Ltd

#

# Licensed under the Apache License, Version 2.0 (the "License");

# you may not use this file except in compliance with the License.

# You may obtain a copy of the License at

#

#     http://www.apache.org/licenses/LICENSE-2.0

#

# Unless required by applicable law or agreed to in writing, software

# distributed under the License is distributed on an "AS IS" BASIS,

# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

# See the License for the specific language governing permissions and

# limitations under the License.



# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

# pyre-unsafe





import os

import json

import argparse

from concurrent.futures import ProcessPoolExecutor, as_completed

import multiprocessing as mp



import numpy as np

from PIL import Image

from tqdm import tqdm

import torch

from pycocotools import mask as mask_util

from ais_bench.infer.interface import InferSession

from transformers.models.sam3 import Sam3Processor



# ========== Global Configuration ==========

# Token length used as the tokenizer's max_length; text inputs are
# padded/truncated to this many tokens in prepare_batch_inputs.
SEQ_LEN = 32
# Side length of the square mask outputs; used to reshape the raw OM outputs
# in process_batch_core.
MASK_SIZE = 288
# Number of predictions per image in the raw OM outputs; used to reshape them
# in process_batch_core.
# NOTE(review): presumably all three must match the static shapes the OM
# model was exported with -- confirm against the export script.
NUM_QUERIES = 200



# Ground truth file definitions per subset

# Every subset ships three annotation files that differ only by the
# "a"/"b"/"c" suffix: gold_<subset>_merged_<suffix>_release_test.json.
# Insertion order of the subsets is kept as listed below.
saco_gold_gts = {
    subset: [
        f"gold_{subset}_merged_{suffix}_release_test.json"
        for suffix in ("a", "b", "c")
    ]
    for subset in (
        "metaclip",
        "sa1b",
        "crowded",
        "fg_food",
        "fg_sports_equipment",
        "attributes",
        "wiki_common",
    )
}





def load_processor(model_path):
    """Build and return a Sam3Processor from ``model_path``.

    The processor is created with ``local_files_only=True``; progress
    messages are printed before and after loading.
    """
    print(f"Loading processor from: {model_path}")
    sam3_processor = Sam3Processor.from_pretrained(
        model_path,
        local_files_only=True,
    )
    print("Processor loaded successfully")
    return sam3_processor





def load_om_model(model_path, device_id):
    """Create and return an InferSession for the OM file on ``device_id``.

    Progress messages are printed before and after the session is created.
    """
    print(f"Loading OM model: {model_path}, device={device_id}")
    om_session = InferSession(device_id, model_path)
    print("OM model loaded successfully")
    return om_session





def collect_samples(gt_json_path, image_root):
    """
    Collect evaluation samples from ground truth JSON.

    An entry of ``data['images']`` is kept only when it is flagged
    ``is_instance_exhaustive``, carries a non-empty ``text_input`` and its
    image file exists under ``image_root``. All other entries are skipped
    silently.

    Args:
        gt_json_path: Path to the ground truth annotation JSON file.
        image_root: Directory the ``file_name`` entries are relative to.

    Returns:
        list of tuples: (image_path, text_input, image_id, height, width, category_id)
        where category_id is the entry's ``queried_category`` (1 when absent).
    """
    # JSON is UTF-8 by specification. Relying on the platform default
    # encoding can garble or reject non-ASCII text prompts.
    with open(gt_json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    samples = []
    for img in data['images']:
        if not img.get('is_instance_exhaustive', False):
            continue
        text_input = img.get('text_input', '')
        if not text_input:
            continue
        img_path = os.path.join(image_root, img['file_name'])
        if not os.path.exists(img_path):
            continue
        # The remaining fields are read only for entries that are kept, so a
        # skipped entry does not need to be fully populated.
        samples.append((
            img_path,
            text_input,
            img['id'],
            img['height'],
            img['width'],
            img.get('queried_category', 1),
        ))
    return samples





def rle_encode(orig_mask, return_areas=False):
    """Encode a batch of binary masks as compressed COCO RLE dicts.

    This function emulates the behavior of the COCO API's encode function:
    run boundaries are found with tensor ops on the mask's own device, and
    only the final compression step goes through pycocotools.

    Args:
        orig_mask (torch.Tensor): Masks of shape (N, H, W) with
            dtype=torch.bool.
        return_areas (bool): If True, add the number of set pixels of each
            mask to its RLE dict under the "area" key. Default is False.

    Returns:
        list[dict]: One RLE dict per mask, as returned by
        ``mask_util.frPyObjects`` with "counts" decoded to a ``str``
        (plus "area" when requested). An empty list is returned when the
        input has no elements.

    Raises:
        ValueError: If ``orig_mask`` is not 3-dimensional.
        TypeError: If ``orig_mask`` is not of dtype ``torch.bool``.
    """
    if orig_mask.ndim != 3:
        raise ValueError(f"Mask must be of shape (N, H, W), got {orig_mask.ndim} dimensions")
    if orig_mask.dtype != torch.bool:
        raise TypeError(f"Mask must have dtype=torch.bool, got {orig_mask.dtype}")

    if orig_mask.numel() == 0:
        return []

    # First, transpose the spatial dimensions.
    # This is necessary because the COCO API uses Fortran order
    mask = orig_mask.transpose(1, 2)

    # Flatten the mask: one row of H*W pixels per mask
    flat_mask = mask.reshape(mask.shape[0], -1)
    if return_areas:
        mask_areas = flat_mask.sum(-1).tolist()
    # Find the indices where the mask changes.
    # `differences` has one extra column; it starts all-True so the last
    # column (index H*W) always marks the end of the final run.
    differences = torch.ones(
        mask.shape[0], flat_mask.shape[1] + 1, device=mask.device, dtype=torch.bool
    )
    # Interior columns: True where a pixel differs from its predecessor.
    differences[:, 1:-1] = flat_mask[:, :-1] != flat_mask[:, 1:]
    # Column 0 is True only when the mask starts with a set pixel, which
    # yields a leading run of length 0 for that mask.
    differences[:, 0] = flat_mask[:, 0]
    # Column positions of every change, concatenated over all masks.
    _, change_indices = torch.where(differences)

    # boundaries[i] is the end offset of mask i's entries in change_indices.
    try:
        boundaries = torch.cumsum(differences.sum(-1), 0).cpu()
    except RuntimeError as _:
        # Retry with the reduction done on CPU.
        # NOTE(review): presumably covers devices where this op is not
        # supported -- confirm which backend triggers it.
        boundaries = torch.cumsum(differences.cpu().sum(-1), 0)

    # The subtraction below is in place, so the original positions are read
    # from an untouched copy.
    change_indices_clone = change_indices.clone()
    # First pass computes the RLEs on the tensor's device, in a flattened
    # format: absolute change positions become run lengths.
    for i in range(mask.shape[0]):
        # Get the change indices for this batch item
        beg = 0 if i == 0 else boundaries[i - 1].item()
        end = boundaries[i].item()
        change_indices[beg + 1: end] -= change_indices_clone[beg: end - 1]

    # Now we can split the RLEs of each batch item, and convert them to strings
    # No more device work at this point
    change_indices = change_indices.tolist()

    batch_rles = []
    # Process each mask in the batch separately
    for i in range(mask.shape[0]):
        beg = 0 if i == 0 else boundaries[i - 1].item()
        end = boundaries[i].item()
        run_lengths = change_indices[beg:end]

        # "size" is taken from the original (un-transposed) mask: [H, W].
        uncompressed_rle = {"counts": run_lengths, "size": list(orig_mask.shape[1:])}
        h, w = uncompressed_rle["size"]
        rle = mask_util.frPyObjects(uncompressed_rle, h, w)
        # Decode the counts so the dict holds a str rather than bytes.
        rle["counts"] = rle["counts"].decode("utf-8")
        if return_areas:
            rle["area"] = mask_areas[i]
        batch_rles.append(rle)

    return batch_rles





def prepare_batch_inputs(batch_samples, processor, model_batch_size):
    """
    Prepare input tensors for the model.

    Args:
        batch_samples: Non-empty list of sample tuples; index 0 is the image
            path and index 1 is the text prompt.
        processor: Object exposing ``image_processor`` and ``tokenizer``.
        model_batch_size: Batch size the returned tensors are padded up to
            when ``batch_samples`` is shorter.

    Returns:
        pixel_values, input_ids, attention_mask, actual_batch_size

    Raises:
        ValueError: If ``batch_samples`` is empty.
    """
    actual_bs = len(batch_samples)
    if actual_bs == 0:
        raise ValueError("batch_samples must contain at least one sample")

    # Image preprocessing.
    # convert() returns a new image, so the source file can be closed right
    # away. The context manager releases the handle deterministically instead
    # of leaving it to garbage collection.
    images = []
    for sample in batch_samples:
        with Image.open(sample[0]) as img:
            images.append(img.convert('RGB'))
    pixel_values = processor.image_processor(images, return_tensors="pt")["pixel_values"]

    # Text preprocessing
    texts = [s[1] for s in batch_samples]
    tokenizer = processor.tokenizer
    text_inputs = tokenizer(
        texts,
        return_tensors="pt",
        truncation=True,
        padding='max_length',
        max_length=SEQ_LEN
    )
    input_ids = text_inputs["input_ids"]
    attn_mask = text_inputs["attention_mask"]

    # Adjust sequence length. The tokenizer call above already requests
    # exactly SEQ_LEN tokens, so this is only a safeguard.
    cur_len = input_ids.shape[1]
    if cur_len > SEQ_LEN:
        input_ids = input_ids[:, :SEQ_LEN]
        attn_mask = attn_mask[:, :SEQ_LEN]
    elif cur_len < SEQ_LEN:
        pad_len = SEQ_LEN - cur_len
        input_ids = torch.nn.functional.pad(input_ids, (0, pad_len), value=0)
        attn_mask = torch.nn.functional.pad(attn_mask, (0, pad_len), value=0)

    # Pad to fixed batch size if necessary by repeating the first sample.
    # The caller gets actual_bs back so it can drop the padded rows again.
    if actual_bs < model_batch_size:
        pad_num = model_batch_size - actual_bs
        pixel_values = torch.cat([pixel_values, pixel_values[0:1].repeat(pad_num, 1, 1, 1)], dim=0)
        input_ids = torch.cat([input_ids, input_ids[0:1].repeat(pad_num, 1)], dim=0)
        attn_mask = torch.cat([attn_mask, attn_mask[0:1].repeat(pad_num, 1)], dim=0)

    return pixel_values, input_ids, attn_mask, actual_bs





def model_infer(session, inputs):
    """Forward ``inputs`` to ``session.infer`` and hand back its result."""
    outputs = session.infer(inputs)
    return outputs





class _RawModelOutputs:
    """Plain attribute holder handed to the processor's post-processing."""

    def __init__(self, pred_masks, pred_logits, pred_boxes, presence_logits, semantic_seg):
        self.pred_masks = pred_masks
        self.pred_logits = pred_logits
        self.pred_boxes = pred_boxes
        self.presence_logits = presence_logits
        self.semantic_seg = semantic_seg


def process_batch_core(session, processor, pixel_values, input_ids, attn_mask,
                       batch_samples, actual_batch_size, model_batch_size, threshold):
    """
    Run one padded batch through the OM session and post-process it.

    The raw outputs are reshaped to their fixed shapes, cut back to
    ``actual_batch_size`` rows, passed to
    ``processor.post_process_instance_segmentation`` (``threshold`` is used
    for both the score and the mask threshold) and flattened into per-instance
    prediction dicts.

    Returns:
        List of COCO-style predictions (image_id, category_id, bbox, score,
        segmentation). Images without any kept score contribute nothing.
    """
    def _to_contiguous(tensor, dtype):
        # The session is fed contiguous numpy arrays of a fixed dtype.
        return np.ascontiguousarray(tensor.numpy().astype(dtype))

    feeds = [
        _to_contiguous(pixel_values, np.float32),
        _to_contiguous(input_ids, np.int64),
        _to_contiguous(attn_mask, np.int64),
    ]
    raw_outputs = model_infer(session, feeds)

    # Target shape of each raw output, in output order:
    # pred_masks, pred_logits, pred_boxes, presence_logits, semantic_seg.
    expected_shapes = (
        (model_batch_size, NUM_QUERIES, MASK_SIZE, MASK_SIZE),
        (model_batch_size, NUM_QUERIES),
        (model_batch_size, NUM_QUERIES, 4),
        (model_batch_size, 1),
        (model_batch_size, 1, MASK_SIZE, MASK_SIZE),
    )
    full_batch = [
        torch.from_numpy(raw_outputs[idx].reshape(shape))
        for idx, shape in enumerate(expected_shapes)
    ]
    # Drop the rows that were only added to fill up the fixed batch size.
    trimmed = [tensor[:actual_batch_size] for tensor in full_batch]
    outputs_obj = _RawModelOutputs(*trimmed)

    # (height, width) of every original image
    target_sizes = [[sample[3], sample[4]] for sample in batch_samples]

    batch_results = processor.post_process_instance_segmentation(
        outputs_obj,
        threshold=threshold,
        mask_threshold=threshold,
        target_sizes=target_sizes
    )

    # Flatten the per-image results into COCO-format records.
    coco_records = []
    for idx, per_image in enumerate(batch_results):
        image_id = batch_samples[idx][2]
        masks = per_image.get('masks', [])
        scores = per_image.get('scores', []).tolist()
        boxes = per_image.get('boxes', []).tolist()

        if not scores:
            continue

        encoded_masks = rle_encode(masks.bool())
        coco_records.extend([
            {
                "image_id": image_id,
                "category_id": 1,
                "bbox": boxes[k],
                "score": score,
                "segmentation": encoded_masks[k]
            }
            for k, score in enumerate(scores)
        ])

    return coco_records





def worker_process(args_tuple):
    """
    Evaluate one subset from start to finish; meant to run in its own process.

    ``args_tuple`` holds, in order: subset_key, gt_file, img_root, device_id,
    model_path, processor_path, batch_size, threshold, output_dir.

    Returns:
        ``subset_key`` when the subset was processed (or had no valid
        samples), ``None`` when any exception occurred; the error and its
        traceback are printed in that case.
    """
    (subset_key, gt_file, img_root, device_id,
     model_path, processor_path, batch_size, threshold, output_dir) = args_tuple

    print(f"[Device {device_id}] Processing subset: {subset_key}")
    try:
        processor = load_processor(processor_path)
        session = load_om_model(model_path, device_id)

        samples = collect_samples(gt_file, img_root)
        if not samples:
            print(f"[Device {device_id}] No valid samples for {subset_key}, skipping.")
            return subset_key

        predictions = []
        sample_count = len(samples)
        progress = tqdm(total=sample_count, desc=f"[Device {device_id}] {subset_key}", unit='img')

        with progress as pbar:
            for start in range(0, sample_count, batch_size):
                chunk = samples[start:start + batch_size]
                pixel_values, input_ids, attn_mask, actual_bs = prepare_batch_inputs(
                    chunk, processor, batch_size
                )
                chunk_preds = process_batch_core(
                    session, processor, pixel_values, input_ids, attn_mask,
                    chunk, actual_bs, batch_size, threshold
                )
                predictions.extend(chunk_preds)
                pbar.update(len(chunk))

        # One prediction file per subset.
        os.makedirs(output_dir, exist_ok=True)
        output_file = os.path.join(output_dir, f'predictions_{subset_key}.json')
        with open(output_file, 'w') as out_fh:
            json.dump(predictions, out_fh)

        print(f"[Device {device_id}] Subset {subset_key} finished. Saved {len(predictions)} predictions -> {output_file}")
        return subset_key
    except Exception as exc:
        # Report the failure and signal it to the parent via None.
        import traceback
        print(f"[Device {device_id}] Subset {subset_key} failed: {exc}\n{traceback.format_exc()}")
        return None





def main():
    """Command-line entry point for the SAM3 OM model evaluation.

    Builds one task per subset of ``saco_gold_gts`` whose first ground truth
    file and image directory exist, assigns the tasks to the requested
    devices round-robin, and runs them with one process pool per device.

    Exits with status 1 when at least one subset failed, so calling scripts
    can detect an incomplete run.
    """
    import sys  # local import: only needed for the exit status below

    parser = argparse.ArgumentParser(description='SAM3 OM Model Evaluation')
    parser.add_argument('--dataset_root', type=str, required=True,
                        help='Root directory of the dataset')
    parser.add_argument('--model_path', type=str, required=True,
                        help='Path to the OM model file')
    parser.add_argument('--processor_path', type=str, required=True,
                        help='Path to the SAM3 processor (HuggingFace model)')
    parser.add_argument('--output_dir', type=str, default='./gold_predictions',
                        help='Output directory for prediction JSON files')
    parser.add_argument('--device', type=str, default='0',
                        help='Comma-separated list of device IDs (e.g., "0,1,2,3")')
    parser.add_argument('--batch_size', type=int, default=4,
                        help='Inference batch size per device')
    parser.add_argument('--procs_per_card', type=int, default=1,
                        help='Maximum number of concurrent processes per NPU')
    parser.add_argument('--threshold', type=float, default=0.5,
                        help='Score threshold for mask and box filtering')

    args = parser.parse_args()

    # Reject values that would otherwise only fail later, inside the workers
    # or when the process pools are created.
    if args.batch_size < 1:
        parser.error("--batch_size must be at least 1")
    if args.procs_per_card < 1:
        parser.error("--procs_per_card must be at least 1")

    devices = [int(d.strip()) for d in args.device.split(',') if d.strip()]
    if not devices:
        raise ValueError("At least one device must be specified.")

    print(f"Available devices: {devices}, max processes per card: {args.procs_per_card}")

    gt_dir = os.path.join(args.dataset_root, 'gt-annotations')

    # Build task list
    tasks = []
    dev_idx = 0
    for subset_key, gt_file_list in saco_gold_gts.items():
        # Only the first ground truth file of each subset is used here.
        gt_file = os.path.join(gt_dir, gt_file_list[0])
        if not os.path.exists(gt_file):
            print(f"Skipping {subset_key}: GT file does not exist")
            continue

        if subset_key == "sa1b":
            img_root = os.path.join(args.dataset_root, 'sa1b-images')
        else:
            img_root = os.path.join(args.dataset_root, 'metaclip-images')

        if not os.path.exists(img_root):
            print(f"Skipping {subset_key}: image directory does not exist")
            continue

        # Round-robin over the devices, counting only tasks that are kept.
        device_id = devices[dev_idx % len(devices)]
        dev_idx += 1

        tasks.append((
            subset_key, gt_file, img_root, device_id,
            args.model_path, args.processor_path,
            args.batch_size, args.threshold, args.output_dir
        ))

    if not tasks:
        print("No valid tasks to process.")
        return

    # Group tasks by device
    tasks_by_device = {d: [] for d in devices}
    for t in tasks:
        device_id = t[3]
        tasks_by_device[device_id].append(t)

    mp.set_start_method('spawn', force=True)

    executors = {}
    future_to_subset = {}
    failed_subsets = []

    try:
        for dev, dev_tasks in tasks_by_device.items():
            if not dev_tasks:
                continue
            print(f"Device {dev}: {len(dev_tasks)} task(s), max workers: {args.procs_per_card}")
            executor = ProcessPoolExecutor(
                max_workers=args.procs_per_card,
                mp_context=mp.get_context('spawn')
            )
            executors[dev] = executor
            for task in dev_tasks:
                future_to_subset[executor.submit(worker_process, task)] = task[0]

        # Wait for all tasks to complete
        for future in tqdm(as_completed(future_to_subset), total=len(future_to_subset),
                           desc="Overall progress"):
            subset_key = future_to_subset[future]
            try:
                result = future.result()
            except Exception as e:
                print(f"Task failed: {e}")
                failed_subsets.append(subset_key)
                continue
            if result:
                print(f"Completed: {result}")
            else:
                # worker_process returns None after catching an exception.
                failed_subsets.append(subset_key)
    finally:
        # Shutdown all executors, also when waiting was interrupted
        for executor in executors.values():
            executor.shutdown(wait=True)

    if failed_subsets:
        print(f"Finished with {len(failed_subsets)} failed subset(s): "
              f"{', '.join(sorted(failed_subsets))}")
        sys.exit(1)

    print("All subsets processed successfully.")


if __name__ == '__main__':
    main()