import numpy as np
import cv2
import torch.nn.functional as F
import torch
from rknn.api import RKNN
import base64
import json
from flask import Flask, request, jsonify
import io
from PIL import Image
import argparse
# Side length, in pixels, of the square image fed to the encoder: the longer
# image side is scaled to this value and the remainder is zero-padded.
IMG_SIZE = 448
# Mask values strictly greater than this are treated as foreground.
MASK_THRESHOLD = 0.0
# Flask application that serves the /segment endpoint.
app = Flask(__name__)
# Process-wide MobileSAM instance; assigned by init_service() and read by the
# request handler. Stays None until init_service() has succeeded.
mobilesam_instance = None
def get_preprocess_shape(oldh, oldw):
    """Compute the resized dimensions whose longer side equals IMG_SIZE.

    Both sides are multiplied by the same factor and rounded half-up to the
    nearest integer, so the aspect ratio is preserved.

    Args:
        oldh: Original image height in pixels.
        oldw: Original image width in pixels.

    Returns:
        A ``(new_width, new_height)`` tuple — note that width comes first.
    """
    longest_side = max(oldh, oldw)
    ratio = IMG_SIZE * 1.0 / longest_side
    resized_h = int(oldh * ratio + 0.5)
    resized_w = int(oldw * ratio + 0.5)
    return (resized_w, resized_h)
def img_preprocess(img, ori_shape):
    """Turn a BGR image into the padded, batched float32 encoder input.

    The image is converted to RGB, resized so that its longer side equals
    IMG_SIZE, zero-padded on the bottom and right up to IMG_SIZE x IMG_SIZE,
    and finally given a leading batch axis.

    Args:
        img: Image array in BGR channel order.
        ori_shape: ``(height, width)`` of ``img``.

    Returns:
        A tuple ``(batch, (new_w, new_h), (pad_w, pad_h), scale)`` where
        ``batch`` is the float32 array with a leading axis of length 1,
        ``(new_w, new_h)`` is the size after resizing, ``(pad_w, pad_h)`` is
        the padding added on the right/bottom, and ``scale`` is the resize
        factor ``IMG_SIZE / max(height, width)``.
    """
    rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    src_h, src_w = ori_shape
    target_w, target_h = get_preprocess_shape(src_h, src_w)
    pad_bottom = IMG_SIZE - target_h
    pad_right = IMG_SIZE - target_w
    ratio = IMG_SIZE * 1.0 / max(src_h, src_w)

    resized = cv2.resize(rgb, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
    # Padding goes on the bottom/right only, so the picture stays anchored at
    # the top-left corner of the square canvas.
    padded = cv2.copyMakeBorder(
        resized, 0, pad_bottom, 0, pad_right, cv2.BORDER_CONSTANT, value=(0, 0, 0)
    )

    batch = np.expand_dims(padded, axis=0).astype(np.float32)
    return batch, (target_w, target_h), (pad_right, pad_bottom), ratio
def coords_preprocess(coords, ori_shape, new_size, scale):
    """Map prompt coordinates from original-image pixels to the resized image.

    The x and y components (the first two entries of the last axis) are
    multiplied by ``scale`` and then clipped to the resized image bounds.

    The input is never modified: the work is done on a private float32 copy.
    This matters because callers may pass a view of an array they still need
    in original-image coordinates afterwards.

    Args:
        coords: Array-like whose last axis holds ``(x, y)`` pairs.
        ori_shape: Unused; accepted only for interface compatibility. It may
            be ``None``.
        new_size: ``(width, height)`` of the resized image.
        scale: Factor applied to both x and y.

    Returns:
        A new float32 ``numpy.ndarray`` with the same shape as ``coords``.
    """
    neww, newh = new_size
    # np.array() copies by default, which keeps the caller's data intact.
    scaled = np.array(coords, dtype=np.float32)
    scaled[..., 0] = np.clip(scaled[..., 0] * scale, 0, neww - 1)
    scaled[..., 1] = np.clip(scaled[..., 1] * scale, 0, newh - 1)
    return scaled
def postprocess(masks, ori_shape, new_size, pad_info):
    """Resize decoder masks back to the original image resolution.

    The masks are first upsampled bilinearly to IMG_SIZE x IMG_SIZE, then the
    top-left ``new_h x new_w`` region is kept (dropping the padded area), and
    that region is resized bilinearly to the original image size.

    Args:
        masks: NumPy array of decoder masks; the bilinear resize with a 2-D
            target size requires a 4-D input (presumably (N, C, H, W) —
            confirm against the decoder output).
        ori_shape: Sequence whose first two items are the original
            ``(height, width)``.
        new_size: ``(width, height)`` of the resized, unpadded image.
        pad_info: ``(pad_w, pad_h)`` pair; unpacked but otherwise unused.

    Returns:
        NumPy array of masks whose last two axes are
        ``(ori_shape[0], ori_shape[1])``.
    """
    resized_w, resized_h = new_size
    _pad_w, _pad_h = pad_info

    logits = torch.from_numpy(masks)
    square_hw = (IMG_SIZE, IMG_SIZE)
    logits = F.interpolate(logits, square_hw, mode='bilinear', align_corners=False)

    # Keep only the area that held real image content before padding.
    logits = logits[..., :resized_h, :resized_w]

    target_hw = (ori_shape[0], ori_shape[1])
    logits = F.interpolate(logits, target_hw, mode='bilinear', align_corners=False)
    return logits.numpy()
def draw(images, masks, coords, labels, color=(144, 144, 30)):
    """Render the mask overlay and prompt markers, returning a base64 JPEG.

    Pixels selected by the mask are alpha-blended 50/50 with ``color``. Each
    prompt is then drawn on top: label 0 as a red circle, label 1 as a green
    circle, and labels 2 / 3 as the two opposite corners of a green rectangle.
    Prompts with any other label are ignored.

    Args:
        images: Image array of shape (H, W, C) in BGR order; not modified.
        masks: Array with exactly H * W elements whose last two axes are
            (H, W); non-zero / True entries mark the overlay area.
        coords: Iterable of ``(x, y)`` positions in ``images`` pixel space.
        labels: Iterable of labels, one per entry in ``coords``.
        color: Overlay colour, one value per channel of ``images``.

    Returns:
        The JPEG-encoded result as a base64 ``str``.

    Raises:
        RuntimeError: If JPEG encoding fails.
    """
    alpha = 0.5
    h, w = masks.shape[-2:]

    # Select pixels by the mask itself rather than by "mask * colour != 0",
    # so a colour containing a zero channel is still blended on every channel.
    selected = np.asarray(masks).reshape(h, w, 1).astype(bool)
    overlay_color = np.asarray(color, dtype=np.float32).reshape(1, 1, -1)
    base = images.astype(np.float32)
    blended = alpha * base + (1 - alpha) * overlay_color
    canvas = np.where(selected, blended, base)
    # Explicit 8-bit, contiguous canvas for the OpenCV drawing/encoding calls.
    canvas = np.ascontiguousarray(np.clip(np.rint(canvas), 0, 255).astype(np.uint8))

    x1 = y1 = x2 = y2 = None
    for coord, label in zip(coords, labels):
        if label == 0:
            cv2.circle(canvas, tuple(map(int, coord)), 12, color=(0, 0, 255), thickness=2)
        elif label == 1:
            cv2.circle(canvas, tuple(map(int, coord)), 12, color=(0, 255, 0), thickness=2)
        elif label in [2, 3]:
            if label == 2:
                x1, y1 = map(int, coord)
            else:
                x2, y2 = map(int, coord)
            # Draw as soon as both corners are known.
            if None not in (x1, y1, x2, y2):
                cv2.rectangle(canvas, (x1, y1), (x2, y2), (0, 255, 0), 2)

    ok, buffer = cv2.imencode('.jpg', canvas)
    if not ok:
        raise RuntimeError("JPEG encoding of the result image failed")
    return base64.b64encode(buffer).decode('utf-8')
class MobileSAM():
    """Pair of RKNN runtimes (image encoder + mask decoder) for MobileSAM.

    Both models are loaded and their runtimes initialised in the constructor.
    Call ``release()`` when the instance is no longer needed.
    """

    def __init__(self, encoder_path, decoder_path, target_platform):
        """Load both RKNN models and initialise their runtimes.

        Args:
            encoder_path: Path of the encoder ``.rknn`` model.
            decoder_path: Path of the decoder ``.rknn`` model.
            target_platform: Value passed as ``target`` to ``init_runtime``.

        Raises:
            RuntimeError: If loading a model or initialising a runtime fails.
        """
        self.encoder = encoder_path
        self.decoder = decoder_path
        self.target = target_platform
        # Pre-set both handles so release() works on a half-built instance.
        self.encoder_rknn = None
        self.decoder_rknn = None
        try:
            self.encoder_rknn = self._init_rknn(self.encoder)
            self.decoder_rknn = self._init_rknn(self.decoder)
        except Exception:
            # Do not leak the encoder runtime when the decoder fails to start.
            self.release()
            raise

    def _init_rknn(self, model_path):
        """Create an RKNN object, load ``model_path`` and start its runtime.

        Returns:
            The ready-to-use RKNN object.

        Raises:
            RuntimeError: If ``load_rknn`` or ``init_runtime`` returns non-zero.
        """
        rknn = RKNN(verbose=False)
        try:
            ret = rknn.load_rknn(model_path)
            if ret != 0:
                raise RuntimeError(f"Load RKNN model failed! path={model_path}, ret={ret}")
            ret = rknn.init_runtime(target=self.target)
            if ret != 0:
                raise RuntimeError(f"Init RKNN runtime failed! target={self.target}, ret={ret}")
        except Exception:
            # Free the partially initialised object before propagating.
            rknn.release()
            raise
        return rknn

    def encoder_run(self, img, ori_shape):
        """Preprocess ``img`` and run the encoder.

        Args:
            img: BGR image array.
            ori_shape: ``(height, width)`` of ``img``.

        Returns:
            ``(embeddings, new_size, pad_info, scale)`` where ``embeddings``
            is the first encoder output and the remaining items come straight
            from ``img_preprocess``.
        """
        img_input, new_size, pad_info, scale = img_preprocess(img, ori_shape)
        outputs = self.encoder_rknn.inference(inputs=[img_input])[0]
        return outputs, new_size, pad_info, scale

    def decoder_run(self, img_embeds, point_coords, point_labels, new_size, scale,
                    mask_input=None, ori_shape=None):
        """Run the decoder for one set of prompts.

        Args:
            img_embeds: Encoder output for the image.
            point_coords: (N, 2) array of ``(x, y)`` prompts in original-image
                pixels. The caller's array is never modified.
            point_labels: (N,) array of prompt labels.
            new_size: ``(width, height)`` of the resized image.
            scale: Resize factor used for the image.
            mask_input: Optional previous mask; when omitted an all-zero
                (1, 1, 112, 112) placeholder is sent with ``has_mask_input=0``.
            ori_shape: Optional ``(height, width)`` of the original image,
                forwarded to ``coords_preprocess``.

        Returns:
            The first two decoder outputs, in order.
        """
        # Hand coords_preprocess a private copy: the caller still needs the
        # coordinates in original-image space (e.g. for drawing) afterwards.
        coords_batch = np.array(point_coords, dtype=np.float32)[None, :, :]
        point_coords = coords_preprocess(
            coords_batch, ori_shape=ori_shape, new_size=new_size, scale=scale
        )
        point_labels = np.asarray(point_labels)[None, :].astype(np.float32)

        if mask_input is not None:
            mask_input = np.array(mask_input).astype(np.float32)
            has_mask_input = np.ones(1, dtype=np.float32)
        else:
            mask_input = np.zeros((1, 1, 112, 112), dtype=np.float32)
            has_mask_input = np.zeros(1, dtype=np.float32)

        point_coords = self.pad_to_16bytes(point_coords)
        point_labels = self.pad_to_16bytes(point_labels)

        outputs = self.decoder_rknn.inference(
            inputs=[img_embeds, point_coords, point_labels, mask_input, has_mask_input],
            data_format='NCHW'
        )
        return outputs[0], outputs[1]

    def pad_to_16bytes(self, data):
        """Zero-pad ``data`` along axis 1 towards a 16-byte-multiple size.

        ``data`` is converted to float32 first. For 3-D input the padding is
        appended as zero ``(x, y)`` rows (assumes the last axis has length 2);
        for 2-D input zero columns are appended. Arrays of any other rank are
        returned unpadded.

        NOTE(review): in both branches the padding is sliced to at most
        ``data.shape[1]`` entries, so a 2-D input can end up short of a
        16-byte multiple (e.g. one label becomes two), and the padded label
        count can differ from the padded point count (two points stay two,
        two labels become four). Padded labels are 0. Confirm both against
        the decoder model's expected input shapes and label convention.

        Args:
            data: Array-like of numbers.

        Returns:
            A float32 ``numpy.ndarray``.
        """
        data = np.array(data, dtype=np.float32)
        pad_num = (16 - data.nbytes % 16) % 16
        if pad_num > 0:
            pad_data = np.zeros((pad_num // data.itemsize,), dtype=data.dtype)
            if len(data.shape) == 3:
                pad_data = pad_data.reshape(1, -1, 2)[:, :data.shape[1], :]
                data = np.concatenate([data, pad_data], axis=1)
            elif len(data.shape) == 2:
                pad_data = pad_data.reshape(1, -1)[:, :data.shape[1]]
                data = np.concatenate([data, pad_data], axis=1)
        return data

    def run(self, img, point_coords, point_labels, ori_shape, mask_input=None):
        """Encode ``img`` and decode masks for the given prompts.

        Args:
            img: BGR image array.
            point_coords: (N, 2) prompt coordinates in original-image pixels.
            point_labels: (N,) prompt labels.
            ori_shape: ``(height, width)`` of ``img``.
            mask_input: Optional previous mask passed to the decoder.

        Returns:
            ``(first_output, second_output, new_size, pad_info)`` where the
            first two items are the decoder outputs in order (the caller
            treats them as scores and low-resolution masks).
        """
        img_embeds, new_size, pad_info, scale = self.encoder_run(img, ori_shape)
        iou_predictions, low_res_masks = self.decoder_run(
            img_embeds, point_coords, point_labels, new_size, scale, mask_input,
            ori_shape=ori_shape
        )
        return iou_predictions, low_res_masks, new_size, pad_info

    def release(self):
        """Release both RKNN objects; safe to call more than once."""
        if self.encoder_rknn is not None:
            self.encoder_rknn.release()
            self.encoder_rknn = None
        if self.decoder_rknn is not None:
            self.decoder_rknn.release()
            self.decoder_rknn = None
@app.route('/segment', methods=['POST'])
def segment_image():
    """Handle ``POST /segment``: segment an image from point/box prompts.

    Expected JSON body:
        image_base64: Base64-encoded image file.
        point_coords: List of ``[x, y]`` pairs, shape (N, 2), N >= 1.
        point_labels: List of N labels.
        mask_input: Optional previous mask forwarded to the decoder.

    Returns:
        A ``(response, status)`` pair. On success the JSON carries
        ``code == 0`` and ``data`` with ``result_base64`` (annotated JPEG) and
        ``confidence`` (highest score). Invalid input yields 400, a service
        that has not been initialised yields 503, and inference or unexpected
        failures yield 500, all with ``code == -1``.
    """
    try:
        # silent=True: a missing or malformed JSON body gives None here and is
        # reported as a 400 below instead of surfacing as a server error.
        data = request.get_json(silent=True)
        if not data or not isinstance(data, dict):
            return jsonify({"code": -1, "msg": "Missing JSON data"}), 400

        for param in ("image_base64", "point_coords", "point_labels"):
            if param not in data:
                return jsonify({"code": -1, "msg": f"Missing required parameter: {param}"}), 400

        try:
            img_data = base64.b64decode(data["image_base64"])
            with Image.open(io.BytesIO(img_data)) as pil_img:
                # Normalise grayscale / palette / alpha images to 3-channel
                # RGB so the colour conversion below always has valid input.
                rgb = np.array(pil_img.convert("RGB"))
            img = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
            ori_shape = img.shape[:2]
            if ori_shape[0] == 0 or ori_shape[1] == 0:
                return jsonify({"code": -1, "msg": "Invalid image size"}), 400
        except Exception as e:
            return jsonify({"code": -1, "msg": f"Decode image failed: {str(e)}"}), 400

        try:
            point_coords = np.array(data["point_coords"], dtype=np.float32)
            point_labels = np.array(data["point_labels"], dtype=np.float32)
        except Exception as e:
            return jsonify({"code": -1, "msg": f"Parse point data failed: {str(e)}"}), 400
        if point_coords.ndim != 2 or point_coords.shape[1] != 2:
            return jsonify({"code": -1, "msg": "point_coords must be (N, 2) format"}), 400
        if point_labels.ndim != 1:
            return jsonify({"code": -1, "msg": "point_labels must be (N,) format"}), 400
        if len(point_coords) != len(point_labels):
            return jsonify({"code": -1, "msg": "point_coords and point_labels length mismatch"}), 400
        if len(point_coords) == 0:
            return jsonify({"code": -1, "msg": "point_coords cannot be empty"}), 400

        mask_input = data.get("mask_input")
        if mask_input is not None:
            try:
                mask_input = np.array(mask_input, dtype=np.float32)
            except Exception as e:
                return jsonify({"code": -1, "msg": f"Invalid mask_input: {str(e)}"}), 400

        if mobilesam_instance is None:
            return jsonify({"code": -1, "msg": "Inference failed: service not initialized"}), 503

        try:
            scores, low_res_masks, new_size, pad_info = mobilesam_instance.run(
                img=img,
                point_coords=point_coords,
                point_labels=point_labels,
                ori_shape=ori_shape,
                mask_input=mask_input
            )
        except Exception as e:
            return jsonify({"code": -1, "msg": f"Inference failed: {str(e)}"}), 500

        masks = postprocess(low_res_masks, ori_shape, new_size, pad_info)
        masks = masks > MASK_THRESHOLD
        # Keep only the candidate with the highest score.
        best_mask_idx = np.argmax(scores)
        masks = masks[:, best_mask_idx, :, :]
        result_base64 = draw(img, masks, point_coords, point_labels)

        return jsonify({
            "code": 0,
            "msg": "success",
            "data": {
                "result_base64": result_base64,
                "confidence": float(np.max(scores))
            }
        }), 200
    except Exception as e:
        import traceback
        # Log the full traceback on the server only; returning it to the
        # client would expose internal paths and code details.
        print(f"Server error: {str(e)}\n{traceback.format_exc()}")
        return jsonify({"code": -1, "msg": f"Server error: {str(e)}"}), 500
def init_service(encoder_path, decoder_path, target_platform):
    """Create the MobileSAM instance and store it in ``mobilesam_instance``.

    Args:
        encoder_path: Path of the encoder RKNN model.
        decoder_path: Path of the decoder RKNN model.
        target_platform: Target platform handed to MobileSAM.

    Raises:
        Exception: Whatever MobileSAM construction raises; the failure is
            printed first and then re-raised unchanged.
    """
    global mobilesam_instance
    try:
        service = MobileSAM(encoder_path, decoder_path, target_platform)
        mobilesam_instance = service
        print("MobileSAM service initialized successfully")
    except Exception as err:
        print(f"Init service failed: {err!s}")
        raise
if __name__ == '__main__':
    # Command-line entry point: parse options, load the models, serve HTTP.
    parser = argparse.ArgumentParser(description='MobileSAM RKNN HTTP Service')
    parser.add_argument('--encoder', type=str, required=True, help="MobileSAM encoder RKNN model path")
    parser.add_argument('--decoder', type=str, required=True, help="MobileSAM decoder RKNN model path")
    parser.add_argument('--target', type=str, required=True, help="Target platform (e.g., rk3588)")
    parser.add_argument('--host', type=str, default='0.0.0.0', help="HTTP service host")
    parser.add_argument('--port', type=int, default=5000, help="HTTP service port")
    args = parser.parse_args()

    init_service(args.encoder, args.decoder, args.target)
    try:
        app.run(host=args.host, port=args.port, debug=False, use_reloader=False)
    finally:
        # Free the RKNN runtimes when the server stops, whether it returned
        # normally or was interrupted by an exception.
        if mobilesam_instance is not None:
            mobilesam_instance.release()