import numpy as np
import cv2
import torch.nn.functional as F
import torch
from rknn.api import RKNN
import base64
import json
import io
from PIL import Image
import argparse
import sys
import traceback
from typing import List, Optional, Union
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
# Side length (pixels) of the square encoder input: images are resized so the
# longer side equals this value, then zero-padded on the bottom/right.
IMG_SIZE: int = 448
# Upsampled decoder mask values strictly greater than this count as foreground.
MASK_THRESHOLD: float = 0.0
# Set by init_service() before the server starts; None means "not loaded".
mobilesam_runner = None
# Description text: "MobileSAM segmentation service based on RKNN".
app = FastAPI(
    title="MobileSAM RKNN API",
    description="基于 RKNN 的 MobileSAM 分割服务",
)
class SegmentRequest(BaseModel):
    """JSON body accepted by ``POST /segment``."""

    # Base64-encoded image file bytes (any format PIL can open).
    image_base64: str
    # Prompt points as [x, y] pairs in original-image pixel coordinates.
    point_coords: List[List[float]]
    # One label per point; draw() renders 0/1 as red/green circles and uses
    # 2/3 as the top-left/bottom-right corners of a box.
    point_labels: List[int]
    # Optional low-resolution mask prompt forwarded to the decoder.
    mask_input: Optional[List] = None
def get_preprocess_shape(oldh, oldw, target_size=None):
    """Return the ``(width, height)`` an image is resized to before padding.

    The longer side is scaled to ``target_size`` and the shorter side keeps
    the aspect ratio, rounded half-up.

    Args:
        oldh: Original image height in pixels.
        oldw: Original image width in pixels.
        target_size: Length of the longer side after resizing. Defaults to
            the module-level ``IMG_SIZE`` when omitted.

    Returns:
        ``(new_width, new_height)`` as ints. Note the (w, h) order, which is
        the order ``cv2.resize`` expects for its ``dsize`` argument.

    Raises:
        ValueError: If either dimension is not positive.
    """
    if target_size is None:
        target_size = IMG_SIZE
    if oldh <= 0 or oldw <= 0:
        # Avoid ZeroDivisionError / negative scales on degenerate images.
        raise ValueError(f"Image dimensions must be positive, got h={oldh}, w={oldw}")
    scale = target_size * 1.0 / max(oldh, oldw)
    newh = int(oldh * scale + 0.5)
    neww = int(oldw * scale + 0.5)
    return (neww, newh)
def img_preprocess(img, ori_shape):
    """Turn a BGR image into the padded float32 batch fed to the encoder.

    Steps: BGR -> RGB, resize so the longer side equals ``IMG_SIZE``, then
    zero-pad on the bottom and right to ``IMG_SIZE x IMG_SIZE``.

    Args:
        img: 3-channel image in BGR order (the caller passes OpenCV-style
            uint8 data).
        ori_shape: ``(height, width)`` of ``img``.

    Returns:
        ``(batch, (new_w, new_h), (pad_w, pad_h), scale)`` where ``batch``
        has shape ``(1, IMG_SIZE, IMG_SIZE, 3)`` and dtype float32.
    """
    rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    src_h, src_w = ori_shape
    resize_w, resize_h = get_preprocess_shape(src_h, src_w)
    pad_bottom = IMG_SIZE - resize_h
    pad_right = IMG_SIZE - resize_w
    ratio = IMG_SIZE * 1.0 / max(src_h, src_w)

    resized = cv2.resize(rgb, (resize_w, resize_h), interpolation=cv2.INTER_LINEAR)
    padded = cv2.copyMakeBorder(
        resized,
        0,
        pad_bottom,
        0,
        pad_right,
        cv2.BORDER_CONSTANT,
        value=(0, 0, 0),
    )
    # Add the leading batch axis and convert to float32 (this copies).
    batch = padded[np.newaxis].astype(np.float32)
    return batch, (resize_w, resize_h), (pad_right, pad_bottom), ratio
def coords_preprocess(coords, new_size, scale):
    """Map prompt coordinates from original-image pixels to encoder-input pixels.

    Args:
        coords: Array-like of shape ``(..., 2)`` holding ``[x, y]`` pairs in
            original-image pixel coordinates. Integer input is promoted to
            float64 so scaled values are not truncated; floating input keeps
            its dtype.
        new_size: ``(new_w, new_h)`` of the resized (unpadded) image.
        scale: Factor applied to both axes (the resize ratio).

    Returns:
        A new floating-point array of the same shape. Values are scaled by
        ``scale`` and clipped to ``[0, new_w - 1]`` (x) and
        ``[0, new_h - 1]`` (y). The input is never modified.
    """
    out = np.array(coords, copy=True)
    if not np.issubdtype(out.dtype, np.floating):
        # Assigning scaled floats back into an int array would truncate them.
        out = out.astype(np.float64)
    neww, newh = new_size
    out[..., 0] = np.clip(out[..., 0] * scale, 0, neww - 1)
    out[..., 1] = np.clip(out[..., 1] * scale, 0, newh - 1)
    return out
def postprocess(masks, ori_shape, new_size, pad_info):
    """Upsample decoder low-resolution masks back to the original image size.

    The masks are first resized to the padded ``IMG_SIZE x IMG_SIZE`` encoder
    frame. The bottom/right padding is then cropped away using ``new_size``,
    and the remaining region is resized to ``ori_shape``.

    Args:
        masks: Float NumPy array, assumed ``(N, C, h, w)``
            (``F.interpolate`` with a 2-D output size needs 4-D input).
        ori_shape: ``(height, width)`` of the original image.
        new_size: ``(new_w, new_h)`` of the resized, unpadded image.
        pad_info: Not used here; kept for interface compatibility.

    Returns:
        NumPy array ``(N, C, height, width)`` of interpolated, un-thresholded
        mask values.
    """
    valid_w, valid_h = new_size
    out_size = (ori_shape[0], ori_shape[1])

    tensor = torch.from_numpy(masks)
    in_frame = F.interpolate(tensor, (IMG_SIZE, IMG_SIZE), mode='bilinear', align_corners=False)
    unpadded = in_frame[..., :valid_h, :valid_w]
    restored = F.interpolate(unpadded, out_size, mode='bilinear', align_corners=False)
    return restored.numpy()
def draw(images, masks, coords, labels, color=(144, 144, 30), alpha=0.5):
    """Blend a binary mask and the prompts onto an image and JPEG-encode it.

    Args:
        images: HxWx3 uint8 image in BGR order (OpenCV convention).
        masks: Mask whose last two dims are ``(H, W)``. It is reshaped to
            ``(H, W, 1)``, so it must hold exactly H*W elements.
            Non-zero/True means "inside the mask".
        coords: Iterable of ``[x, y]`` prompt coordinates in image pixels.
        labels: One label per coordinate:
            0 -> red circle, 1 -> green circle,
            2 / 3 -> top-left / bottom-right corner of a green rectangle
            (drawn only when both corners are present).
        color: BGR colour blended into masked pixels.
        alpha: Weight of the original pixel inside the mask; the mask colour
            gets ``1 - alpha``.

    Returns:
        Base64 text (str) of the resulting JPEG.

    Raises:
        RuntimeError: If OpenCV fails to encode the image.
    """
    h, w = masks.shape[-2:]
    # Decide "inside the mask" from the mask itself, not from the coloured
    # image: with a colour that has a zero channel (e.g. pure red), testing
    # `mask_image != 0` per channel left that channel unblended.
    inside = np.asarray(masks).reshape(h, w, 1).astype(bool)
    color_arr = np.array(color).reshape(1, 1, -1).astype(np.uint8)
    mask_image = inside * color_arr

    base = images.astype(np.float32)
    overlay = cv2.addWeighted(base, alpha, mask_image.astype(np.float32), 1 - alpha, 0)
    # Clip before the uint8 cast so out-of-range alphas cannot wrap around.
    blended = np.clip(np.where(inside, overlay, base), 0, 255).astype(np.uint8)

    top_left = bottom_right = None
    for coord, label in zip(coords, labels):
        x, y = int(coord[0]), int(coord[1])
        if label == 0:
            cv2.circle(blended, (x, y), 12, color=(0, 0, 255), thickness=2)
        elif label == 1:
            cv2.circle(blended, (x, y), 12, color=(0, 255, 0), thickness=2)
        elif label == 2:
            top_left = (x, y)
        elif label == 3:
            bottom_right = (x, y)
    if top_left is not None and bottom_right is not None:
        cv2.rectangle(blended, top_left, bottom_right, (0, 255, 0), 2)

    ok, buffer = cv2.imencode('.jpg', blended)
    if not ok:
        raise RuntimeError("cv2.imencode failed to encode the result image")
    return base64.b64encode(buffer).decode('utf-8')
class MobileSAMRunner:
    """Run MobileSAM's image encoder and mask decoder as two RKNN models.

    Both ``.rknn`` files are loaded and their runtimes initialised for
    ``target_platform`` at construction time. Call :meth:`release` to free
    them.

    NOTE(review): ``run`` takes no lock. The FastAPI handler calls it from an
    ``async def`` (on the event-loop thread), which serialises calls today.
    Add a lock before calling it from worker threads.
    """

    # Shape of the decoder's mask_input tensor. This is the shape the all-zero
    # placeholder has always used (112 == 448 // 4).
    MASK_INPUT_SHAPE = (1, 1, 112, 112)
    # Label used for padding points. SAM's ONNX decoder treats -1 as
    # "not a point", and the official SAM ONNX example pads with (0, 0) / -1.
    # Assumes the RKNN export kept that behaviour -- TODO confirm.
    PAD_POINT_LABEL = -1.0
    # Points are padded to a multiple of this count. Then point_coords
    # (8 bytes/point) and point_labels (4 bytes/point) are both 16-byte
    # multiples AND still describe the same number of points.
    POINT_PAD_MULTIPLE = 4

    def __init__(self, encoder_path: str, decoder_path: str, target_platform: str):
        # Pre-set both handles so release() is safe if the second load fails.
        self.encoder_rknn = None
        self.decoder_rknn = None
        self.encoder_rknn = self._init_rknn(encoder_path, target_platform)
        try:
            self.decoder_rknn = self._init_rknn(decoder_path, target_platform)
        except Exception:
            # Do not leak the already-initialised encoder context.
            self.release()
            raise

    def _init_rknn(self, model_path: str, target: str):
        """Load an ``.rknn`` file and initialise its runtime on ``target``.

        Raises:
            RuntimeError: If ``load_rknn`` or ``init_runtime`` returns non-zero.
        """
        rknn = RKNN(verbose=False)
        try:
            ret = rknn.load_rknn(model_path)
            if ret != 0:
                raise RuntimeError(f"Load RKNN model failed! path={model_path}, ret={ret}")
            ret = rknn.init_runtime(target=target)
            if ret != 0:
                raise RuntimeError(f"Init RKNN runtime failed! target={target}, ret={ret}")
        except Exception:
            # Free the half-initialised context before propagating.
            rknn.release()
            raise
        return rknn

    def pad_to_16bytes(self, data):
        """Zero-pad ``data`` (as float32) along axis 1 to a 16-byte multiple.

        For 3-D input whole ``(x, y)`` rows are appended (assumes the last
        axis has size 2). For 2-D input columns are appended. Other ranks are
        returned unpadded.

        Kept for backward compatibility only. ``run`` uses ``_pad_points``
        instead, because padding coords and labels independently gives them
        different point counts, and a 0 label is a real (negative) prompt.
        """
        data = np.array(data, dtype=np.float32)
        pad_num = (16 - data.nbytes % 16) % 16
        if pad_num > 0:
            pad_data = np.zeros((pad_num // data.itemsize,), dtype=data.dtype)
            if len(data.shape) == 3:
                pad_data = pad_data.reshape(1, -1, 2)[:, :data.shape[1], :]
                data = np.concatenate([data, pad_data], axis=1)
            elif len(data.shape) == 2:
                pad_data = pad_data.reshape(1, -1)[:, :data.shape[1]]
                data = np.concatenate([data, pad_data], axis=1)
        return np.ascontiguousarray(data)

    def _pad_points(self, point_coords, point_labels):
        """Pad coords ``(B, N, 2)`` and labels ``(B, N)`` to a common point count.

        Returns:
            ``(coords, labels)`` as float32, C-contiguous arrays. The point
            count is rounded up to a multiple of ``POINT_PAD_MULTIPLE``.
            Padding points sit at (0, 0) with label ``PAD_POINT_LABEL``.

        Raises:
            ValueError: If the arrays are not ``(B, N, 2)`` / ``(B, N)``.
        """
        coords = np.asarray(point_coords, dtype=np.float32)
        labels = np.asarray(point_labels, dtype=np.float32)
        if (
            coords.ndim != 3
            or coords.shape[2] != 2
            or labels.ndim != 2
            or coords.shape[:2] != labels.shape
        ):
            raise ValueError(
                f"point_coords {coords.shape} and point_labels {labels.shape} "
                "must be (B, N, 2) and (B, N)"
            )
        extra = (-coords.shape[1]) % self.POINT_PAD_MULTIPLE
        if extra:
            batch = coords.shape[0]
            coords = np.concatenate(
                [coords, np.zeros((batch, extra, 2), dtype=np.float32)], axis=1
            )
            labels = np.concatenate(
                [labels, np.full((batch, extra), self.PAD_POINT_LABEL, dtype=np.float32)],
                axis=1,
            )
        return np.ascontiguousarray(coords), np.ascontiguousarray(labels)

    def run(self, img, point_coords, point_labels, ori_shape, mask_input=None):
        """Segment ``img`` from prompt points.

        Args:
            img: BGR image handed to ``img_preprocess``.
            point_coords: ``(N, 2)`` array of ``[x, y]`` in original-image pixels.
            point_labels: ``(N,)`` array of prompt labels.
            ori_shape: ``(height, width)`` of ``img``.
            mask_input: Optional mask prompt. It must contain 112*112 values
                and is reshaped to ``MASK_INPUT_SHAPE``.

        Returns:
            ``(iou_predictions, low_res_masks, new_size, pad_info)``: the first
            two decoder outputs plus the resize/pad info from preprocessing.

        Raises:
            ValueError: If ``mask_input`` has the wrong number of elements or
                the point arrays disagree in length.
            RuntimeError: If an RKNN inference call returns no usable outputs.
        """
        img_input, new_size, pad_info, scale = img_preprocess(img, ori_shape)
        enc_outputs = self.encoder_rknn.inference(inputs=[img_input])
        if not enc_outputs:
            raise RuntimeError("Encoder inference returned no outputs")
        img_embeds = np.ascontiguousarray(enc_outputs[0])

        coords = coords_preprocess(point_coords[None, :, :], new_size, scale)
        labels = point_labels[None, :].astype(np.float32)
        coords, labels = self._pad_points(coords, labels)

        if mask_input is not None:
            # reshape raises ValueError when the element count is wrong.
            mask_arr = np.asarray(mask_input, dtype=np.float32).reshape(self.MASK_INPUT_SHAPE)
            has_mask = np.ones((1, 1), dtype=np.float32)
        else:
            mask_arr = np.zeros(self.MASK_INPUT_SHAPE, dtype=np.float32)
            has_mask = np.zeros((1, 1), dtype=np.float32)

        outputs = self.decoder_rknn.inference(
            inputs=[img_embeds, coords, labels, np.ascontiguousarray(mask_arr), has_mask],
            data_format='NCHW',
        )
        if not outputs or len(outputs) < 2:
            raise RuntimeError("Decoder inference returned fewer than 2 outputs")
        iou_predictions, low_res_masks = outputs[0], outputs[1]
        return iou_predictions, low_res_masks, new_size, pad_info

    def release(self):
        """Release both RKNN contexts; safe to call more than once."""
        for attr in ("encoder_rknn", "decoder_rknn"):
            handle = getattr(self, attr, None)
            if handle is not None:
                handle.release()
                setattr(self, attr, None)
def _decode_image_b64(image_base64: str):
    """Decode a base64 image (optionally a ``data:...;base64,`` URI) to BGR.

    Raises:
        ValueError: Bad base64 (``binascii.Error``), or a zero-sized image.
        OSError: PIL cannot identify/read the image data.
    """
    payload = image_base64.strip()
    if payload.startswith("data:") and "," in payload:
        # Accept browser-style data URIs by dropping the "data:...;base64," header.
        payload = payload.split(",", 1)[1]
    img_data = base64.b64decode(payload)
    img_pil = Image.open(io.BytesIO(img_data)).convert("RGB")
    img = cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR)
    if img.shape[0] == 0 or img.shape[1] == 0:
        raise ValueError("Invalid image size")
    return img


def _parse_points(raw_coords, raw_labels):
    """Validate prompt points.

    Returns:
        ``(coords, labels)``: coords is an ``(N, 2)`` float32 array and
        labels is an ``(N,)`` int32 array, with N >= 1.

    Raises:
        HTTPException: 400 for any malformed prompt.
    """
    try:
        point_coords = np.array(raw_coords, dtype=np.float32)
        point_labels = np.array(raw_labels, dtype=np.int32)
    except (ValueError, TypeError, OverflowError) as e:
        # e.g. ragged lists like [[1, 2], [3]] or labels outside int32 range.
        raise HTTPException(status_code=400, detail=f"Invalid point prompt: {e}") from e
    if len(point_coords.shape) != 2 or point_coords.shape[1] != 2:
        raise HTTPException(status_code=400, detail="point_coords must be (N, 2)")
    if len(point_labels.shape) != 1:
        raise HTTPException(status_code=400, detail="point_labels must be (N,)")
    if len(point_coords) != len(point_labels):
        raise HTTPException(status_code=400, detail="point_coords and point_labels length mismatch")
    if len(point_coords) == 0:
        raise HTTPException(status_code=400, detail="point_coords cannot be empty")
    if not np.all(np.isfinite(point_coords)):
        # NaN/inf would otherwise crash later when drawing (int(nan)).
        raise HTTPException(status_code=400, detail="point_coords must be finite numbers")
    return point_coords, point_labels


@app.post("/segment")
async def segment(request: Request):
    """Segment the posted image and return it with the best mask overlaid.

    Error handling:
    - Malformed JSON, image data or prompts -> HTTP 400.
    - Model not loaded, or inference/post-processing failure -> HTTP 500.

    On success the response is ``{"code": 0, "msg": "success", "data":
    {"result_base64": <JPEG base64>, "confidence": <max IoU score>}}``.

    NOTE(review): inference runs synchronously inside this ``async`` handler
    and blocks the event loop while it runs.
    """
    runner = mobilesam_runner
    if runner is None:
        raise HTTPException(status_code=500, detail="Model not initialized")

    try:
        body = await request.json()
        req = SegmentRequest(**body)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

    try:
        img = _decode_image_b64(req.image_base64)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Image decode error: {e}") from e
    ori_shape = img.shape[:2]

    point_coords, point_labels = _parse_points(req.point_coords, req.point_labels)

    try:
        scores, low_res_masks, new_size, pad_info = runner.run(
            img=img,
            point_coords=point_coords,
            point_labels=point_labels,
            ori_shape=ori_shape,
            mask_input=req.mask_input,
        )
    except Exception as e:
        print(f"Inference error: {e}\n{traceback.format_exc()}")
        raise HTTPException(status_code=500, detail=f"Inference failed: {e}") from e

    try:
        masks = postprocess(low_res_masks, ori_shape, new_size, pad_info)
        masks = masks > MASK_THRESHOLD
        # Pick the mask the decoder scored highest.
        best_idx = np.argmax(scores)
        final_mask = masks[0, best_idx]
        result_b64 = draw(img, final_mask, point_coords, point_labels)
    except Exception as e:
        print(f"Postprocess error: {e}\n{traceback.format_exc()}")
        raise HTTPException(status_code=500, detail=f"Postprocess failed: {e}") from e

    return JSONResponse({
        "code": 0,
        "msg": "success",
        "data": {
            "result_base64": result_b64,
            "confidence": float(np.max(scores)),
        },
    })
@app.get("/health")
def health():
    """Report liveness and whether init_service() has loaded the runner."""
    loaded = mobilesam_runner is not None
    return {"status": "OK", "model_loaded": loaded}
def init_service(encoder_path: str, decoder_path: str, target_platform: str):
    """Build the module-level ``mobilesam_runner``, exiting the process on failure.

    Args:
        encoder_path: Path to the encoder ``.rknn`` file.
        decoder_path: Path to the decoder ``.rknn`` file.
        target_platform: RKNN runtime target name, e.g. ``rk3588``.

    Any exception (including one from the success message) is printed with
    its traceback, followed by ``sys.exit(1)``.
    """
    global mobilesam_runner
    try:
        mobilesam_runner = MobileSAMRunner(
            encoder_path,
            decoder_path,
            target_platform,
        )
        print("✅ MobileSAM RKNN service initialized successfully")
    except Exception as exc:
        print(f"❌ Failed to initialize model: {exc}")
        traceback.print_exc()
        sys.exit(1)
if __name__ == "__main__":
    cli = argparse.ArgumentParser(description='MobileSAM RKNN FastAPI Service')
    # The three required string options share everything except name and help.
    for flag, help_text in (
        ('--encoder', "Encoder .rknn path"),
        ('--decoder', "Decoder .rknn path"),
        ('--target', "Target platform (e.g., rk3588)"),
    ):
        cli.add_argument(flag, type=str, required=True, help=help_text)
    cli.add_argument('--host', type=str, default='0.0.0.0')
    cli.add_argument('--port', type=int, default=8000)
    opts = cli.parse_args()

    # Models are loaded first; init_service exits the process on failure, so
    # uvicorn is only imported and started after loading succeeded.
    init_service(opts.encoder, opts.decoder, opts.target)

    import uvicorn

    uvicorn.run(app, host=opts.host, port=opts.port, log_level="info")