# Last commit: "add openvino" (author: Your Name)
# Commit 1def4e0e, created 2025-12-15
# -*- coding: utf-8 -*-
import numpy as np
import cv2
import torch.nn.functional as F
import torch
from rknn.api import RKNN
import base64
import json
from flask import Flask, request, jsonify
import io
from PIL import Image
import argparse

# Global configuration
# IMG_SIZE: side length (pixels) of the square frame the image is resized and
# zero-padded to before inference; the longest image side is scaled to it.
IMG_SIZE = 448
# MASK_THRESHOLD: the /segment handler binarises post-processed masks with
# `masks > MASK_THRESHOLD`.
MASK_THRESHOLD = 0.0
# Flask application object that the /segment route is registered on.
app = Flask(__name__)

# Global MobileSAM instance (avoid repeated initialization).
# Stays None until init_service() assigns it.
mobilesam_instance = None

# -------------------------- Utility Functions --------------------------
def get_preprocess_shape(oldh, oldw):
    """Return the resized image size whose longest side equals IMG_SIZE.

    Both sides are multiplied by the same factor (aspect ratio preserved) and
    rounded to the nearest integer.

    Args:
        oldh: original image height in pixels.
        oldw: original image width in pixels.

    Returns:
        Tuple ``(width, height)`` -- note the order, chosen to match the
        ``dsize`` argument of ``cv2.resize``.

    Raises:
        ZeroDivisionError: if both ``oldh`` and ``oldw`` are 0.
    """
    longest_side = max(oldh, oldw)
    ratio = IMG_SIZE * 1.0 / longest_side
    # Adding 0.5 before truncation rounds to the nearest integer.
    resized_h = int(oldh * ratio + 0.5)
    resized_w = int(oldw * ratio + 0.5)
    return (resized_w, resized_h)

def img_preprocess(img, ori_shape):
    """Convert, resize and zero-pad an image into an IMG_SIZE x IMG_SIZE batch.

    Args:
        img: image array accepted by ``cv2.cvtColor(..., COLOR_BGR2RGB)``
            (presumably H x W x 3 BGR, uint8 -- confirm against callers).
        ori_shape: ``(height, width)`` of ``img``.

    Returns:
        Tuple ``(batch, (new_w, new_h), (pad_w, pad_h), scale)`` where
        ``batch`` is the float32 image with a leading batch axis of size 1,
        ``(new_w, new_h)`` is the size after resizing, ``(pad_w, pad_h)`` is
        the padding added on the right/bottom, and ``scale`` is the resize
        factor applied to both axes.
    """
    rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    src_h, src_w = ori_shape
    dst_w, dst_h = get_preprocess_shape(src_h, src_w)
    pad_bottom = IMG_SIZE - dst_h
    pad_right = IMG_SIZE - dst_w
    scale = IMG_SIZE * 1.0 / max(src_h, src_w)

    resized = cv2.resize(rgb, (dst_w, dst_h), interpolation=cv2.INTER_LINEAR)
    # Pad only on the bottom/right so that pixel (0, 0) stays at the origin
    # and coordinates need nothing more than a multiplication by `scale`.
    padded = cv2.copyMakeBorder(
        resized, 0, pad_bottom, 0, pad_right, cv2.BORDER_CONSTANT, value=(0, 0, 0)
    )
    batch = np.expand_dims(padded, axis=0).astype(np.float32)

    return batch, (dst_w, dst_h), (pad_right, pad_bottom), scale

def coords_preprocess(coords, ori_shape, new_size, scale):
    """Map prompt coordinates from original-image pixels to the resized frame.

    Each coordinate is multiplied by ``scale`` and then clamped so that it
    lies inside the resized (un-padded) image area.

    Args:
        coords: array-like whose last axis holds ``(x, y)`` pairs.
        ori_shape: not used by the computation; kept for interface
            compatibility. May be ``None``.
        new_size: ``(width, height)`` of the resized image.
        scale: resize factor that was applied to the image.

    Returns:
        A new float32 ``numpy.ndarray`` with the same shape as ``coords``.
        The input is never modified, so callers can keep using their
        original-pixel coordinates (e.g. for drawing) afterwards.
    """
    neww, newh = new_size

    # np.array() copies by default: callers often pass a view such as
    # `points[None, :, :]`, and writing into it would silently rescale the
    # caller's own array.
    scaled = np.array(coords, dtype=np.float32)
    scaled[..., 0] = np.clip(scaled[..., 0] * scale, 0, neww - 1)
    scaled[..., 1] = np.clip(scaled[..., 1] * scale, 0, newh - 1)
    return scaled

def postprocess(masks, ori_shape, new_size, pad_info):
    """Resize low-resolution masks back to the original image resolution.

    The masks are first upsampled to the IMG_SIZE x IMG_SIZE model frame,
    cropped to the region that contained the resized image (dropping the
    bottom/right padding), and finally resampled to ``ori_shape``.

    Args:
        masks: numpy array of masks; bilinear ``F.interpolate`` with a
            2-element size requires a 4-D input (presumably B x C x h x w
            float -- confirm against the decoder output).
        ori_shape: ``(height, width)`` of the original image.
        new_size: ``(width, height)`` of the resized, un-padded image.
        pad_info: ``(pad_w, pad_h)``; unpacked but otherwise unused.

    Returns:
        numpy array whose last two axes are ``ori_shape[0]`` x ``ori_shape[1]``.
    """
    crop_w, crop_h = new_size
    _pad_w, _pad_h = pad_info  # not needed: the crop below removes the padding

    model_frame = F.interpolate(
        torch.from_numpy(masks),
        (IMG_SIZE, IMG_SIZE),
        mode='bilinear',
        align_corners=False,
    )
    unpadded = model_frame[..., :crop_h, :crop_w]

    target_hw = (ori_shape[0], ori_shape[1])
    restored = F.interpolate(unpadded, target_hw, mode='bilinear', align_corners=False)
    return restored.numpy()

def draw(images, masks, coords, labels, color=(144, 144, 30)):
    """Overlay a mask and the prompt markers on an image; return base64 JPEG.

    Args:
        images: image array to draw on (presumably H x W x 3 in OpenCV's BGR
            order -- confirm against callers). The argument itself is not
            modified; drawing happens on a blended copy.
        masks: mask whose last two axes are (H, W) and which reshapes to
            ``(H, W, 1)``; non-zero entries mark the segmented region.
        coords: iterable of ``(x, y)`` prompt coordinates.
        labels: iterable of prompt labels, parallel to ``coords``:
            0 -> circle in colour (0, 0, 255), 1 -> circle in colour
            (0, 255, 0), 2 -> first box corner, 3 -> opposite box corner.
        color: colour the mask is tinted with.

    Returns:
        The rendered image encoded as JPEG and then base64, as a ``str``.
    """
    blend = 0.5
    mask_h, mask_w = masks.shape[-2:]
    tint = np.array(color)

    # Per-pixel tint: zero outside the mask, `color` inside it.
    overlay = masks.reshape(mask_h, mask_w, 1) * tint.reshape(1, 1, -1).astype(np.uint8)
    blended = cv2.addWeighted(
        images.astype(np.float32), blend, overlay.astype(np.float32), (1 - blend), 0
    )
    # Keep the untouched pixel wherever the overlay is zero.
    images = np.where(overlay != 0, blended, images)

    x1 = y1 = x2 = y2 = None
    for coord, label in zip(coords, labels):
        if label == 0:
            cv2.circle(images, tuple(int(v) for v in coord), 12, color=(0, 0, 255), thickness=2)
        elif label == 1:
            cv2.circle(images, tuple(int(v) for v in coord), 12, color=(0, 255, 0), thickness=2)
        elif label in [2, 3]:
            if label == 2:
                x1, y1 = map(int, coord)
            else:
                x2, y2 = map(int, coord)
            # Draw as soon as both corners of the box are known.
            if all(v is not None for v in (x1, y1, x2, y2)):
                cv2.rectangle(images, (x1, y1), (x2, y2), (0, 255, 0), 2)

    _, jpeg_buf = cv2.imencode('.jpg', images)
    return base64.b64encode(jpeg_buf).decode('utf-8')

# -------------------------- MobileSAM Class --------------------------
class MobileSAM():
    """MobileSAM image encoder + mask decoder running on the RKNN runtime.

    Two RKNN models are loaded at construction time and kept for the
    lifetime of the object; call :meth:`release` when done.
    """

    def __init__(self, encoder_path, decoder_path, target_platform):
        """Load both models and initialise their runtimes.

        Args:
            encoder_path: path of the encoder ``.rknn`` model.
            decoder_path: path of the decoder ``.rknn`` model.
            target_platform: value passed as ``target`` to
                ``RKNN.init_runtime``.

        Raises:
            RuntimeError: if a model cannot be loaded or its runtime cannot
                be initialised. Any runtime that was already initialised is
                released before the error propagates.
        """
        self.encoder = encoder_path
        self.decoder = decoder_path
        self.target = target_platform
        # Define both handles up front so release() works on a half-built object.
        self.encoder_rknn = None
        self.decoder_rknn = None
        try:
            self.encoder_rknn = self._init_rknn(self.encoder)
            self.decoder_rknn = self._init_rknn(self.decoder)
        except Exception:
            # Do not leak the encoder runtime when the decoder fails to load.
            self.release()
            raise

    def _init_rknn(self, model_path):
        """Create an RKNN object, load ``model_path`` and init its runtime.

        Raises:
            RuntimeError: when ``load_rknn`` or ``init_runtime`` returns a
                non-zero status.
        """
        rknn = RKNN(verbose=False)
        ret = rknn.load_rknn(model_path)
        if ret != 0:
            raise RuntimeError(f"Load RKNN model failed! path={model_path}, ret={ret}")
        ret = rknn.init_runtime(target=self.target)
        if ret != 0:
            raise RuntimeError(f"Init RKNN runtime failed! target={self.target}, ret={ret}")
        return rknn

    def encoder_run(self, img, ori_shape):
        """Preprocess ``img`` and run the encoder.

        Returns:
            Tuple ``(embeddings, new_size, pad_info, scale)``: the first
            encoder output followed by the values produced by
            ``img_preprocess``.
        """
        img_input, new_size, pad_info, scale = img_preprocess(img, ori_shape)
        outputs = self.encoder_rknn.inference(inputs=[img_input])[0]
        return outputs, new_size, pad_info, scale

    def decoder_run(self, img_embeds, point_coords, point_labels, new_size, scale,
                    mask_input=None, ori_shape=None):
        """Run the mask decoder for one set of prompts.

        Args:
            img_embeds: encoder output for the image.
            point_coords: ``(N, 2)`` prompt coordinates in original-image
                pixels. Not modified.
            point_labels: ``(N,)`` numpy array of prompt labels.
            new_size: ``(width, height)`` of the resized image.
            scale: resize factor that was applied to the image.
            mask_input: optional mask prompt; when omitted a zero mask of
                shape ``(1, 1, 112, 112)`` is sent with ``has_mask_input=0``.
            ori_shape: original ``(height, width)``, forwarded to
                ``coords_preprocess``. Optional for backward compatibility.

        Returns:
            Tuple of the first two decoder outputs.
        """
        # Work on a private float32 copy: `point_coords[None]` alone would be
        # a view, and rescaling must never write through to the caller's
        # array (the caller still needs original-pixel coordinates).
        point_coords = np.array(point_coords, dtype=np.float32)[None, :, :]
        point_coords = coords_preprocess(
            point_coords, ori_shape=ori_shape, new_size=new_size, scale=scale
        )
        point_labels = point_labels[None, :].astype(np.float32)

        if mask_input is not None:
            mask_input = np.array(mask_input).astype(np.float32)
            has_mask_input = np.ones(1, dtype=np.float32)
        else:
            mask_input = np.zeros((1, 1, 112, 112), dtype=np.float32)
            has_mask_input = np.zeros(1, dtype=np.float32)

        point_coords = self.pad_to_16bytes(point_coords)
        point_labels = self.pad_to_16bytes(point_labels)

        outputs = self.decoder_rknn.inference(
            inputs=[img_embeds, point_coords, point_labels, mask_input, has_mask_input],
            data_format='NCHW'
        )
        return outputs[0], outputs[1]

    def pad_to_16bytes(self, data):
        """Append zero entries along axis 1 of a 2-D or 3-D float32 array.

        The number of padding elements is derived from how many bytes are
        missing to the next multiple of 16, but is capped at
        ``data.shape[1]`` entries. Arrays of any other rank are returned
        unpadded (only converted to float32).

        NOTE(review): because of the cap the result is not always 16-byte
        aligned (one label becomes two labels = 8 bytes), and for some point
        counts coordinates and labels end up with different lengths (N=2:
        coords stay at 2 points while labels grow to 4). The padding entries
        are also plain zeros. Confirm all of this against what the RKNN
        decoder model actually expects before changing it.
        """
        data = np.array(data, dtype=np.float32)
        pad_num = (16 - data.nbytes % 16) % 16
        if pad_num > 0:
            pad_data = np.zeros((pad_num // data.itemsize,), dtype=data.dtype)
            if len(data.shape) == 3:
                pad_data = pad_data.reshape(1, -1, 2)[:, :data.shape[1], :]
                data = np.concatenate([data, pad_data], axis=1)
            elif len(data.shape) == 2:
                pad_data = pad_data.reshape(1, -1)[:, :data.shape[1]]
                data = np.concatenate([data, pad_data], axis=1)
        return data

    def run(self, img, point_coords, point_labels, ori_shape, mask_input=None):
        """Encode ``img`` and decode masks for the given prompts.

        Returns:
            Tuple ``(decoder_output_0, decoder_output_1, new_size, pad_info)``.
        """
        img_embeds, new_size, pad_info, scale = self.encoder_run(img, ori_shape)
        iou_predictions, low_res_masks = self.decoder_run(
            img_embeds, point_coords, point_labels, new_size, scale, mask_input,
            ori_shape=ori_shape,
        )
        return iou_predictions, low_res_masks, new_size, pad_info

    def release(self):
        """Release both RKNN runtimes; safe to call more than once."""
        for attr in ("encoder_rknn", "decoder_rknn"):
            runtime = getattr(self, attr, None)
            # Clear the handle first so a repeated call cannot double-release.
            setattr(self, attr, None)
            if runtime is not None:
                runtime.release()

# -------------------------- HTTP API --------------------------
@app.route('/segment', methods=['POST'])
def segment_image():
    """Handle ``POST /segment``: segment an image from point/box prompts.

    Request JSON:
        image_base64: base64-encoded image file (any format PIL can open).
        point_coords: ``(N, 2)`` list of ``[x, y]`` in original-image pixels.
        point_labels: ``(N,)`` list of prompt labels.
        mask_input: optional mask prompt forwarded to the decoder.

    Returns:
        ``(response, status)``. On success ``code`` is 0 and ``data`` holds
        ``result_base64`` (rendered JPEG, base64) and ``confidence`` (highest
        score). Invalid input yields 400, an uninitialised service 503 and
        inference/server failures 500, all with ``code`` -1.
    """
    global mobilesam_instance
    try:
        # silent=True: a malformed or non-JSON body gives None instead of
        # raising, so it is answered with the 400 below rather than a 500.
        data = request.get_json(silent=True)
        if not data or not isinstance(data, dict):
            return jsonify({"code": -1, "msg": "Missing JSON data"}), 400

        required_params = ["image_base64", "point_coords", "point_labels"]
        for param in required_params:
            if param not in data:
                return jsonify({"code": -1, "msg": f"Missing required parameter: {param}"}), 400

        img_base64 = data["image_base64"]
        try:
            img_data = base64.b64decode(img_base64)
            # Normalise to 3-channel RGB first: grayscale, palette and RGBA
            # images would otherwise make the RGB->BGR conversion fail.
            pil_img = Image.open(io.BytesIO(img_data)).convert("RGB")
            img = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
            ori_shape = img.shape[:2]
            if ori_shape[0] == 0 or ori_shape[1] == 0:
                return jsonify({"code": -1, "msg": "Invalid image size"}), 400
        except Exception as e:
            return jsonify({"code": -1, "msg": f"Decode image failed: {str(e)}"}), 400

        try:
            point_coords = np.array(data["point_coords"], dtype=np.float32)
            point_labels = np.array(data["point_labels"], dtype=np.float32)
            if len(point_coords.shape) != 2 or point_coords.shape[1] != 2:
                return jsonify({"code": -1, "msg": "point_coords must be (N, 2) format"}), 400
            if len(point_labels.shape) != 1:
                return jsonify({"code": -1, "msg": "point_labels must be (N,) format"}), 400
            if len(point_coords) != len(point_labels):
                return jsonify({"code": -1, "msg": "point_coords and point_labels length mismatch"}), 400
            if len(point_coords) == 0:
                return jsonify({"code": -1, "msg": "point_coords cannot be empty"}), 400
        except Exception as e:
            return jsonify({"code": -1, "msg": f"Parse point data failed: {str(e)}"}), 400

        mask_input = data.get("mask_input")
        if mask_input is not None:
            try:
                mask_input = np.array(mask_input, dtype=np.float32)
            except Exception as e:
                return jsonify({"code": -1, "msg": f"Invalid mask_input: {str(e)}"}), 400

        if mobilesam_instance is None:
            return jsonify({"code": -1, "msg": "Service not initialized"}), 503

        try:
            scores, low_res_masks, new_size, pad_info = mobilesam_instance.run(
                img=img,
                # Hand inference its own copy: the markers drawn below must
                # use the untouched original-pixel coordinates.
                point_coords=point_coords.copy(),
                point_labels=point_labels,
                ori_shape=ori_shape,
                mask_input=mask_input
            )
        except Exception as e:
            return jsonify({"code": -1, "msg": f"Inference failed: {str(e)}"}), 500

        masks = postprocess(low_res_masks, ori_shape, new_size, pad_info)
        masks = masks > MASK_THRESHOLD
        # Keep only the candidate mask with the highest score.
        best_mask_idx = np.argmax(scores)
        masks = masks[:, best_mask_idx, :, :]

        result_base64 = draw(img, masks, point_coords, point_labels)

        return jsonify({
            "code": 0,
            "msg": "success",
            "data": {
                "result_base64": result_base64,
                "confidence": float(np.max(scores))
            }
        }), 200

    except Exception as e:
        import traceback
        # The traceback is logged on the server only; returning it would
        # expose file paths and internals to any HTTP client.
        print(f"Server error: {str(e)}\n{traceback.format_exc()}")
        return jsonify({"code": -1, "msg": f"Server error: {str(e)}"}), 500

# -------------------------- Service Initialization --------------------------
def init_service(encoder_path, decoder_path, target_platform):
    """Build the shared MobileSAM instance used by the HTTP handlers.

    Args:
        encoder_path: path of the encoder RKNN model.
        decoder_path: path of the decoder RKNN model.
        target_platform: target platform passed through to MobileSAM.

    Raises:
        Exception: whatever ``MobileSAM(...)`` raises; the failure is printed
            and then re-raised, leaving the global instance unchanged.
    """
    global mobilesam_instance
    try:
        instance = MobileSAM(encoder_path, decoder_path, target_platform)
    except Exception as e:
        print(f"Init service failed: {str(e)}")
        raise
    mobilesam_instance = instance
    print("MobileSAM service initialized successfully")

if __name__ == '__main__':
    # Command-line entry point: load the models, then serve until interrupted.
    parser = argparse.ArgumentParser(description='MobileSAM RKNN HTTP Service')
    parser.add_argument('--encoder', type=str, required=True, help="MobileSAM encoder RKNN model path")
    parser.add_argument('--decoder', type=str, required=True, help="MobileSAM decoder RKNN model path")
    parser.add_argument('--target', type=str, required=True, help="Target platform (e.g., rk3588)")
    parser.add_argument('--host', type=str, default='0.0.0.0', help="HTTP service host")
    parser.add_argument('--port', type=int, default=5000, help="HTTP service port")
    args = parser.parse_args()

    init_service(args.encoder, args.decoder, args.target)

    try:
        # The reloader is disabled so the models are not loaded a second time
        # in a child process.
        app.run(host=args.host, port=args.port, debug=False, use_reloader=False)
    finally:
        # Free the RKNN runtimes when the server stops (including Ctrl-C).
        if mobilesam_instance is not None:
            mobilesam_instance.release()