# Last commit: "add openvino" by Your Name
# Commit 1def4e0e, created 2025-12-15 (see repository history)
# -*- coding: utf-8 -*-
import numpy as np
import cv2
import torch.nn.functional as F
import torch
from rknn.api import RKNN
import base64
import json
import io
from PIL import Image
import argparse
import sys
import traceback
from typing import List, Optional, Union

from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel

# ----------------------------
# Global Config
# ----------------------------
# Encoder input resolution: images are resized so their longer side equals this
# value and then zero-padded on the bottom/right to an IMG_SIZE x IMG_SIZE canvas.
IMG_SIZE: int = 448
# Upsampled mask values strictly greater than this are treated as foreground.
MASK_THRESHOLD: float = 0.0

# Global model instance; populated by init_service().
mobilesam_runner: Optional["MobileSAMRunner"] = None

app = FastAPI(title="MobileSAM RKNN API", description="基于 RKNN 的 MobileSAM 分割服务")


# ----------------------------
# Pydantic Models for Validation
# ----------------------------
class SegmentRequest(BaseModel):
    """JSON body accepted by ``POST /segment`` (validated by pydantic)."""

    # Base64-encoded image file bytes (anything PIL's Image.open can read).
    image_base64: str
    # Prompt points as [[x, y], ...] in original-image pixel coordinates.
    point_coords: List[List[float]]
    # One integer label per prompt point.
    point_labels: List[int]
    # Optional mask forwarded to the decoder's mask input (a "history" mask);
    # None means no mask is supplied.
    mask_input: Union[List, None] = None


# ----------------------------
# Utility Functions (same logic as Flask version)
# ----------------------------
def get_preprocess_shape(oldh, oldw):
    """Return the ``(width, height)`` an image is resized to before padding.

    The longer side is scaled to ``IMG_SIZE`` and the aspect ratio is kept;
    both sides are rounded half-up to whole pixels.

    Args:
        oldh: original image height in pixels.
        oldw: original image width in pixels.

    Returns:
        tuple: ``(new_width, new_height)`` -- note the width comes first.
    """
    ratio = IMG_SIZE * 1.0 / max(oldh, oldw)
    scaled_w = int(oldw * ratio + 0.5)
    scaled_h = int(oldh * ratio + 0.5)
    return scaled_w, scaled_h


def img_preprocess(img, ori_shape):
    """Turn a BGR image into the encoder's input batch.

    Steps: BGR -> RGB, resize so the longer side is ``IMG_SIZE``, zero-pad
    on the bottom/right to ``IMG_SIZE x IMG_SIZE``, add a leading batch
    axis and cast to float32.

    Args:
        img: 3-channel BGR image array.
        ori_shape: ``(height, width)`` of ``img``.

    Returns:
        tuple: ``(batch, (new_w, new_h), (pad_w, pad_h), scale)`` where
        ``batch`` has shape ``(1, IMG_SIZE, IMG_SIZE, 3)``.
    """
    rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    src_h, src_w = ori_shape
    new_w, new_h = get_preprocess_shape(src_h, src_w)
    pad_h = IMG_SIZE - new_h
    pad_w = IMG_SIZE - new_w
    ratio = IMG_SIZE * 1.0 / max(src_h, src_w)

    resized = cv2.resize(rgb, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
    # Padding only on the bottom/right keeps (0, 0) aligned with the original image.
    canvas = cv2.copyMakeBorder(resized, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=(0, 0, 0))
    batch = np.expand_dims(canvas, 0).astype(np.float32)
    return batch, (new_w, new_h), (pad_w, pad_h), ratio


def coords_preprocess(coords, new_size, scale):
    """Map prompt coordinates from original-image pixels into the resized image.

    Columns ``[..., 0]`` (x) and ``[..., 1]`` (y) are multiplied by ``scale``
    and then clamped to ``[0, width - 1]`` / ``[0, height - 1]`` of the
    resized, un-padded image.

    Args:
        coords: array whose last axis holds ``(x, y)``; it is not modified.
        new_size: ``(width, height)`` of the resized image.
        scale: resize factor applied to the original image.

    Returns:
        A new array with the same shape and dtype as ``coords``.
    """
    width, height = new_size
    scaled = coords.copy()
    for axis, upper in ((0, width - 1), (1, height - 1)):
        scaled[..., axis] = scaled[..., axis] * scale
        scaled[..., axis] = np.clip(scaled[..., axis], 0, upper)
    return scaled


def postprocess(masks, ori_shape, new_size, pad_info):
    """Upsample decoder masks back to the original image resolution.

    Args:
        masks: 4-D ``(N, C, H, W)`` numpy array of mask scores from the decoder
            (4-D is required by bilinear ``F.interpolate``).
        ori_shape: ``(height, width)`` of the original image.
        new_size: ``(width, height)`` of the resized, un-padded image.
        pad_info: unused; cropping to ``new_size`` already drops the padding.

    Returns:
        numpy array of shape ``(N, C, ori_h, ori_w)``.
    """
    valid_w, valid_h = new_size
    target_hw = (ori_shape[0], ori_shape[1])

    scores = torch.from_numpy(masks)
    # 1) Back to the padded square canvas the encoder saw.
    scores = F.interpolate(scores, (IMG_SIZE, IMG_SIZE), mode="bilinear", align_corners=False)
    # 2) Drop the bottom/right zero padding.
    scores = scores[..., :valid_h, :valid_w]
    # 3) Stretch the valid region to the original resolution.
    scores = F.interpolate(scores, target_hw, mode="bilinear", align_corners=False)
    return scores.numpy()


def draw(images, masks, coords, labels, color=(144, 144, 30), alpha=0.5):
    """Overlay a binary mask and the prompts on an image; return a base64 JPEG.

    Args:
        images: image of shape ``(H, W, 3)``; the /segment handler passes BGR,
            so the circle colours below are BGR.
        masks: mask reshaped to ``(H, W)``; non-zero pixels are tinted.
        coords: iterable of ``(x, y)`` points in image pixel coordinates.
        labels: one label per point. 0 draws a red circle, 1 a green circle,
            2 / 3 set the top-left / bottom-right corners of a green
            rectangle (drawn only if both are present). Other labels are ignored.
        color: tint colour for masked pixels (same channel order as ``images``).
        alpha: weight of the original pixels in the blend; the tint gets
            ``1 - alpha``.

    Returns:
        str: base64 text of the JPEG-encoded result.

    Raises:
        RuntimeError: if OpenCV reports that JPEG encoding failed.
    """
    h, w = masks.shape[-2:]
    mask_hw1 = masks.reshape(h, w, 1)
    color_arr = np.array(color).reshape(1, 1, -1).astype(np.uint8)
    mask_image = mask_hw1 * color_arr
    # Select on the mask itself: testing the tinted image (``mask_image != 0``)
    # would leave a channel unblended wherever ``color`` has a 0 component.
    blended = np.where(
        mask_hw1 != 0,
        cv2.addWeighted(images.astype(np.float32), alpha, mask_image.astype(np.float32), 1 - alpha, 0),
        images,
    )

    top_left = bottom_right = None
    for coord, label in zip(coords, labels):
        x, y = int(coord[0]), int(coord[1])
        if label == 0:
            cv2.circle(blended, (x, y), 12, color=(0, 0, 255), thickness=2)
        elif label == 1:
            cv2.circle(blended, (x, y), 12, color=(0, 255, 0), thickness=2)
        elif label == 2:
            top_left = (x, y)
        elif label == 3:
            bottom_right = (x, y)

    if top_left is not None and bottom_right is not None:
        cv2.rectangle(blended, top_left, bottom_right, (0, 255, 0), 2)

    ok, buffer = cv2.imencode('.jpg', blended.astype(np.uint8))
    if not ok:
        raise RuntimeError("Failed to encode result image as JPEG")
    return base64.b64encode(buffer).decode('utf-8')


# ----------------------------
# MobileSAM Runner Class
# ----------------------------
class MobileSAMRunner:
    """Runs the MobileSAM encoder and decoder ``.rknn`` models.

    Both models are loaded and their runtimes initialised in the constructor;
    call :meth:`release` to free them.
    """

    # Shape of the decoder's mask input, both for the all-zero placeholder and
    # for user-supplied masks.
    MASK_INPUT_SHAPE = (1, 1, 112, 112)

    def __init__(self, encoder_path: str, decoder_path: str, target_platform: str):
        """Load and initialise the encoder and decoder.

        Raises:
            RuntimeError: if loading or runtime initialisation returns a
                non-zero code. An encoder that was already initialised is
                released before the error propagates.
        """
        self.encoder_rknn = None
        self.decoder_rknn = None
        self.encoder_rknn = self._init_rknn(encoder_path, target_platform)
        try:
            self.decoder_rknn = self._init_rknn(decoder_path, target_platform)
        except Exception:
            # Don't leak the encoder runtime if the decoder cannot be set up.
            self.release()
            raise

    def _init_rknn(self, model_path: str, target: str):
        """Create an RKNN handle, load ``model_path`` and init its runtime on ``target``.

        The handle is released before re-raising if any step fails.
        """
        rknn = RKNN(verbose=False)
        try:
            ret = rknn.load_rknn(model_path)
            if ret != 0:
                raise RuntimeError(f"Load RKNN model failed! path={model_path}, ret={ret}")
            ret = rknn.init_runtime(target=target)
            if ret != 0:
                raise RuntimeError(f"Init RKNN runtime failed! target={target}, ret={ret}")
        except Exception:
            rknn.release()
            raise
        return rknn

    def pad_to_16bytes(self, data, pad_value=0.0):
        """Pad along axis 1 until the float32 buffer size is a multiple of 16 bytes.

        Only 2-D (e.g. ``[1, N]``) and 3-D (e.g. ``[1, N, 2]``) inputs are
        padded; other ranks are only converted.

        Args:
            data: array-like, converted to float32.
            pad_value: value written into the appended entries.

        Returns:
            A C-contiguous float32 array.
        """
        data = np.array(data, dtype=np.float32)
        if data.ndim not in (2, 3) or data.nbytes % 16 == 0:
            return np.ascontiguousarray(data)
        # Bytes added per extra entry on axis 1 (e.g. 8 for one float32 (x, y) pair).
        step = data.nbytes // data.shape[1]
        extra = 0
        while (data.shape[1] + extra) * step % 16:
            extra += 1
        pad_shape = list(data.shape)
        pad_shape[1] = extra
        pad = np.full(pad_shape, pad_value, dtype=np.float32)
        return np.ascontiguousarray(np.concatenate([data, pad], axis=1))

    def _pad_prompts(self, point_coords, point_labels):
        """Align the coordinate buffer to 16 bytes and pad labels to the same point count.

        Extra points get coordinates ``(0, 0)`` and label ``-1``. NOTE(review):
        -1 is the padding ("not a point") label in the reference SAM ONNX
        decoder, whereas 0 is a negative click; confirm the RKNN export keeps
        this behaviour.

        Args:
            point_coords: ``[1, N, 2]`` preprocessed coordinates.
            point_labels: ``[1, N]`` labels.

        Returns:
            tuple: contiguous float32 ``(coords [1, M, 2], labels [1, M])`` with ``M >= N``.
        """
        coords = self.pad_to_16bytes(point_coords, pad_value=0.0)
        labels = np.array(point_labels, dtype=np.float32)
        missing = coords.shape[1] - labels.shape[1]
        if missing > 0:
            filler = np.full((labels.shape[0], missing), -1.0, dtype=np.float32)
            labels = np.concatenate([labels, filler], axis=1)
        return coords, np.ascontiguousarray(labels)

    def run(self, img, point_coords, point_labels, ori_shape, mask_input=None):
        """Encode ``img`` and decode masks for the given prompts.

        Args:
            img: BGR image, as expected by ``img_preprocess``.
            point_coords: ``(N, 2)`` points in original-image pixels.
            point_labels: ``(N,)`` labels.
            ori_shape: ``(height, width)`` of ``img``.
            mask_input: optional decoder mask; must contain exactly as many
                values as ``MASK_INPUT_SHAPE`` and is reshaped to it.

        Returns:
            tuple: ``(iou_predictions, low_res_masks, new_size, pad_info)``.
            These are the first two decoder outputs plus the resize/padding
            info from ``img_preprocess``.

        Raises:
            ValueError: if ``mask_input`` is not numeric or has the wrong size.
        """
        # Validate the optional mask before the comparatively costly encoder pass.
        if mask_input is not None:
            mask_input = np.asarray(mask_input, dtype=np.float32)
            expected = int(np.prod(self.MASK_INPUT_SHAPE))
            if mask_input.size != expected:
                raise ValueError(
                    f"mask_input must contain {expected} values "
                    f"(shape {self.MASK_INPUT_SHAPE}), got shape {mask_input.shape}"
                )
            mask_input = mask_input.reshape(self.MASK_INPUT_SHAPE)
            has_mask_input = np.ones((1, 1), dtype=np.float32)
        else:
            mask_input = np.zeros(self.MASK_INPUT_SHAPE, dtype=np.float32)
            has_mask_input = np.zeros((1, 1), dtype=np.float32)

        # Encoder
        img_input, new_size, pad_info, scale = img_preprocess(img, ori_shape)
        img_embeds = self.encoder_rknn.inference(inputs=[img_input])[0]

        # Decoder prompts, mapped into the resized-image coordinate frame.
        point_coords = coords_preprocess(point_coords[None, :, :], new_size, scale)
        point_labels = point_labels[None, :].astype(np.float32)

        # Key step (per the original author): 16-byte alignment of the prompt
        # buffers avoids RKNN_ERR_PARAM_INVALID. Labels are padded to the same
        # length as the coordinates so the decoder sees matching point counts.
        point_coords, point_labels = self._pad_prompts(point_coords, point_labels)
        mask_input = np.ascontiguousarray(mask_input)
        has_mask_input = np.ascontiguousarray(has_mask_input)
        img_embeds = np.ascontiguousarray(img_embeds)

        outputs = self.decoder_rknn.inference(
            inputs=[img_embeds, point_coords, point_labels, mask_input, has_mask_input],
            data_format='NCHW'
        )

        iou_predictions, low_res_masks = outputs[0], outputs[1]
        return iou_predictions, low_res_masks, new_size, pad_info

    def release(self):
        """Release both RKNN runtimes; calling it again is a no-op."""
        for attr in ("encoder_rknn", "decoder_rknn"):
            handle = getattr(self, attr, None)
            if handle is not None:
                handle.release()
                setattr(self, attr, None)


# ----------------------------
# API Endpoint
# ----------------------------
def _decode_image_base64(image_base64: str) -> np.ndarray:
    """Decode a base64 image (raw or ``data:...;base64,`` URL) into a BGR array.

    Raises:
        Exception: whatever base64/PIL/OpenCV raise for undecodable input.
        ValueError: if the decoded image has a zero dimension.
    """
    payload = image_base64
    if payload.startswith("data:"):
        # Data URL: keep only the part after the first comma.
        _, _, payload = payload.partition(",")
    img_data = base64.b64decode(payload)
    img_pil = Image.open(io.BytesIO(img_data)).convert("RGB")
    img = cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR)
    if img.shape[0] == 0 or img.shape[1] == 0:
        raise ValueError("Invalid image size")
    return img


def _parse_prompts(req: SegmentRequest):
    """Convert the request's prompt lists to arrays, rejecting bad input with HTTP 400.

    Returns:
        tuple: ``(point_coords, point_labels)``. These are a float32 ``(N, 2)``
        array of finite values and an int32 ``(N,)`` array, with ``N >= 1``.
    """
    if not req.point_coords:
        raise HTTPException(status_code=400, detail="point_coords cannot be empty")
    try:
        point_coords = np.array(req.point_coords, dtype=np.float32)
        point_labels = np.array(req.point_labels, dtype=np.int32)
    except (ValueError, TypeError, OverflowError) as e:
        # e.g. ragged rows such as [[1, 2], [3]], or labels outside the int32 range
        raise HTTPException(status_code=400, detail=f"Invalid point data: {e}") from e

    if point_coords.ndim != 2 or point_coords.shape[1] != 2:
        raise HTTPException(status_code=400, detail="point_coords must be (N, 2)")
    if point_labels.ndim != 1:
        raise HTTPException(status_code=400, detail="point_labels must be (N,)")
    if len(point_coords) != len(point_labels):
        raise HTTPException(status_code=400, detail="point_coords and point_labels length mismatch")
    # JSON parsing accepts NaN/Infinity; they would break drawing and inference.
    if not np.isfinite(point_coords).all():
        raise HTTPException(status_code=400, detail="point_coords must be finite numbers")
    return point_coords, point_labels


@app.post("/segment")
async def segment(request: Request):
    """Segment an image from prompts and return the rendered overlay.

    The body must match ``SegmentRequest``. On success the response is
    ``{"code": 0, "msg": "success", "data": {"result_base64", "confidence"}}``.
    ``result_base64`` is the JPEG produced by ``draw`` and ``confidence`` is
    the maximum decoder score.

    Raises:
        HTTPException: 400 for malformed JSON, images or prompts; 500 when the
            model is not initialised or inference fails.
    """
    runner = mobilesam_runner
    if runner is None:
        raise HTTPException(status_code=500, detail="Model not initialized")

    try:
        body = await request.json()
        req = SegmentRequest(**body)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

    try:
        img = _decode_image_base64(req.image_base64)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Image decode error: {e}") from e
    ori_shape = img.shape[:2]

    point_coords, point_labels = _parse_prompts(req)

    # NOTE(review): run() is synchronous, so it executes on the event-loop
    # thread and other requests (including /health) wait until it finishes.
    try:
        scores, low_res_masks, new_size, pad_info = runner.run(
            img=img,
            point_coords=point_coords,
            point_labels=point_labels,
            ori_shape=ori_shape,
            mask_input=req.mask_input,
        )
    except Exception as e:
        print(f"Inference error: {e}\n{traceback.format_exc()}")
        raise HTTPException(status_code=500, detail=f"Inference failed: {e}") from e

    # Upsample to the original resolution, binarise, and keep the highest-scoring mask.
    masks = postprocess(low_res_masks, ori_shape, new_size, pad_info)
    masks = masks > MASK_THRESHOLD
    best_idx = np.argmax(scores)
    final_mask = masks[0, best_idx]

    # Draw on the original image, using the original (unscaled) prompt coordinates.
    result_b64 = draw(img, final_mask, point_coords, point_labels)

    return JSONResponse({
        "code": 0,
        "msg": "success",
        "data": {
            "result_base64": result_b64,
            "confidence": float(np.max(scores))
        }
    })


@app.get("/health")
def health():
    """Report service status and whether the global model runner is loaded."""
    loaded = mobilesam_runner is not None
    return {"status": "OK", "model_loaded": loaded}


# ----------------------------
# Service Initialization
# ----------------------------
def init_service(encoder_path: str, decoder_path: str, target_platform: str):
    """Create the global ``MobileSAMRunner``; exit the process with status 1 on failure.

    Args:
        encoder_path: path to the encoder ``.rknn`` file.
        decoder_path: path to the decoder ``.rknn`` file.
        target_platform: RKNN target passed to ``init_runtime`` (e.g. ``rk3588``).
    """
    global mobilesam_runner
    try:
        mobilesam_runner = MobileSAMRunner(encoder_path, decoder_path, target_platform)
    except Exception as exc:
        print(f"❌ Failed to initialize model: {exc}")
        traceback.print_exc()
        sys.exit(1)
    else:
        print("✅ MobileSAM RKNN service initialized successfully")


# ----------------------------
# Main Entry
# ----------------------------
if __name__ == "__main__":
    # Command-line entry point: load both models, then serve the API with uvicorn.
    parser = argparse.ArgumentParser(description='MobileSAM RKNN FastAPI Service')
    parser.add_argument('--encoder', type=str, required=True, help="Encoder .rknn path")
    parser.add_argument('--decoder', type=str, required=True, help="Decoder .rknn path")
    parser.add_argument('--target', type=str, required=True, help="Target platform (e.g., rk3588)")
    parser.add_argument('--host', type=str, default='0.0.0.0', help="Bind address (default: 0.0.0.0)")
    parser.add_argument('--port', type=int, default=8000, help="Bind port (default: 8000)")
    args = parser.parse_args()

    init_service(args.encoder, args.decoder, args.target)

    import uvicorn

    try:
        uvicorn.run(app, host=args.host, port=args.port, log_level="info")
    finally:
        # Free the RKNN runtimes once uvicorn.run returns or raises.
        if mobilesam_runner is not None:
            mobilesam_runner.release()