"""
Prestop graceful shutdown for Kubernetes PreStop hook.
Flow:
1. Read NodeManager port from user_config.json (api_config.node_manager_port)
2. POST /node-manager/pause to local NodeManager → endpoints set to PAUSED
3. Read engine_mgmt_addrs from the pause response
4. Poll engine /metrics locally (num_requests_waiting / running) until
   all drain to zero, an engine becomes unreachable, or --max-wait elapses,
   then exit (Pod terminates).
Config is read from CONFIG_PATH or CONFIGMAP_PATH env var (same pattern as probe.py).
"""
import json
import logging
import os
import re
import subprocess
import sys
import time
from motor.common.utils.net import format_address
# Configure root logging once, at import time; each record carries a timestamp and level.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Port returned by get_nm_port() when the config holds no valid node_manager_port.
DEFAULT_NM_PORT = 1026
# Default for --max-wait: upper bound, in seconds, on the drain-polling loop in main().
DEFAULT_MAX_WAIT_SECONDS = 15
# Default for --poll-interval: seconds slept between metric polls in main().
DEFAULT_POLL_INTERVAL = 3
# Timeout, in seconds, that send_pause() passes to the pause POST.
HTTP_TIMEOUT = 10.0
def get_val_by_key_path(config, key_path):
    """Walk *config* along a dot-separated *key_path* and return what is found.

    Args:
        config: Nested mapping to search (any non-dict yields None).
        key_path: Keys joined by ".", e.g. ``"a.b.c"``.

    Returns:
        The value stored at the path, or None as soon as a step is missing or
        the current node is not a dict.
    """
    node = config
    for part in key_path.split('.'):
        if isinstance(node, dict) and part in node:
            node = node[part]
        else:
            return None
    return node
# Candidate config locations, in lookup order. The environment is read once, at
# import time; unset variables become "" and are skipped by load_config().
CONFIG_SEARCH_PATHS = [
    os.environ.get("CONFIGMAP_PATH", ""),
    os.environ.get("CONFIG_PATH", ""),
]
def load_config():
    """Load user_config.json from the first usable search path.

    Each non-empty entry of CONFIG_SEARCH_PATHS is tried in order (configmap
    mount first, then CONFIG_PATH). An entry may name the JSON file itself or
    a directory that contains ``user_config.json``.

    Returns:
        The parsed config dict, or None when no candidate yields a JSON object.
    """
    for base_path in CONFIG_SEARCH_PATHS:
        if not base_path:
            continue
        user_config_path = base_path
        if os.path.isdir(user_config_path):
            user_config_path = os.path.join(user_config_path, "user_config.json")
        if not os.path.exists(user_config_path):
            continue
        try:
            with open(user_config_path, "r", encoding="utf-8") as f:
                config = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            logger.warning("Failed to load config from %s: %s", user_config_path, e)
            continue
        if isinstance(config, dict):
            logger.info("Loaded config from: %s", user_config_path)
            return config
        # Valid JSON but not an object: say so, otherwise the final
        # "not found" error below is the only (misleading) trace.
        logger.warning(
            "Ignoring config from %s: top-level JSON is %s, expected an object",
            user_config_path,
            type(config).__name__,
        )
    logger.error("user_config.json not found in any search path")
    return None
# Dotted key paths that get_nm_port() probes, in order, for the NodeManager port.
# NOTE(review): "nodemanger" is spelled this way in every key — presumably it
# mirrors the config schema; confirm before "correcting" it.
NM_PORT_PATHS = [
    "motor_nodemanger_config.api_config.node_manager_port",
    "motor_engine_prefill_config.motor_nodemanger_config.api_config.node_manager_port",
    "motor_engine_decode_config.motor_nodemanger_config.api_config.node_manager_port",
]
def get_nm_port(config):
    """Return the NodeManager port found in *config*, else DEFAULT_NM_PORT.

    The key paths in NM_PORT_PATHS are probed in order; the first value that
    is an int within 1024..65535 wins. When none qualifies, the fallback is
    logged and the default port is returned.
    """
    candidates = (get_val_by_key_path(config, key_path) for key_path in NM_PORT_PATHS)
    for candidate in candidates:
        if not isinstance(candidate, int):
            continue
        if candidate < 1024 or candidate > 65535:
            continue
        return candidate
    logger.info("node_manager_port not in config, using default %d", DEFAULT_NM_PORT)
    return DEFAULT_NM_PORT
def _http_post_json(url, timeout=10):
    """POST an empty JSON object to *url* using the ``curl`` binary.

    curl is asked to append a newline followed by the HTTP status code
    (``%{http_code}``) after the response body; that trailer is split off
    again here.

    Args:
        url: Target URL.
        timeout: Seconds used for curl's connect and total timeouts; the
            subprocess itself is allowed two extra seconds.

    Returns:
        ``(status_code, body_text)`` when an HTTP response was received
        (``body_text`` may be empty); ``(None, None)`` when curl could not be
        run, produced no output, or received no HTTP response;
        ``(None, raw_output)`` when the output has no parseable status trailer.
    """
    command = [
        "curl",
        "-s",
        "-w",
        "\n%{http_code}",
        "-X",
        "POST",
        url,
        "-H",
        "Content-Type: application/json",
        "-d",
        "{}",
        "--connect-timeout",
        str(timeout),
        "--max-time",
        str(timeout),
    ]
    try:
        result = subprocess.run(
            command,
            capture_output=True,
            text=True,
            timeout=timeout + 2,
            check=False,
        )
    except (OSError, ValueError, subprocess.SubprocessError) as e:
        # curl missing, invalid arguments/undecodable output, or hard timeout.
        logger.error("curl POST failed: %s", e)
        return None, None

    # Split on the LAST newline *before* stripping: stripping first would
    # swallow the separator of an empty body and make the status code look
    # like the body.
    body, newline, code = result.stdout.rpartition("\n")
    code = code.strip()
    if not newline or not code.isdecimal():
        raw_output = result.stdout.strip()
        return None, (raw_output or None)

    status = int(code)
    if status == 0:
        # curl writes 000 when it never received an HTTP response
        # (connection refused, timeout, DNS failure, ...).
        logger.error("curl POST got no HTTP response (curl exit code %s)", result.returncode)
        return None, None
    return status, body.strip()
def _http_get_text(url, timeout=5):
    """GET *url* using the ``curl`` binary and return the response body.

    Args:
        url: Target URL.
        timeout: Seconds used for curl's connect and total timeouts; the
            subprocess itself is allowed two extra seconds.

    Returns:
        The body text, or None when curl fails, the server answers with an
        HTTP error status, or the body is empty.
    """
    command = [
        "curl",
        "-s",
        # -S keeps error text on stderr despite -s, so the debug log below
        # has something to show.
        "-S",
        # -f turns HTTP 4xx/5xx into a non-zero exit code; without it an
        # error page would be returned as if it were the requested content.
        "-f",
        "-X",
        "GET",
        url,
        "--connect-timeout",
        str(timeout),
        "--max-time",
        str(timeout),
    ]
    try:
        result = subprocess.run(
            command,
            capture_output=True,
            text=True,
            timeout=timeout + 2,
            check=False,
        )
    except (OSError, ValueError, subprocess.SubprocessError) as e:
        # curl missing, invalid arguments/undecodable output, or hard timeout.
        logger.debug("curl GET failed: %s", e)
        return None
    if result.returncode == 0 and result.stdout:
        return result.stdout
    logger.debug("curl GET failed: returncode=%s, stderr=%s", result.returncode, result.stderr[:200])
    return None
def send_pause(node_manager_url, config):
    """POST /node-manager/pause to the local NodeManager.

    Args:
        node_manager_url: Base URL of the NodeManager, without a trailing path.
        config: Loaded user config. Not used here; kept so existing callers
            keep working.

    Returns:
        The parsed JSON response as a dict, or None when the request fails,
        the status is not 200, the body is empty or not valid JSON, or the
        JSON is not an object.
    """
    url = f"{node_manager_url}/node-manager/pause"
    status, body = _http_post_json(url, timeout=HTTP_TIMEOUT)
    if status != 200 or not body:
        logger.error("Pause failed: status=%s, body=%s", status, body)
        return None
    try:
        result = json.loads(body)
    except json.JSONDecodeError as e:
        logger.error("Failed to parse pause response: %s", e)
        return None
    if not isinstance(result, dict):
        # Callers use dict methods on the result; a JSON list/string/number
        # would otherwise crash them.
        logger.error("Unexpected pause response (expected a JSON object): %s", result)
        return None
    logger.info("Pause response: %s", result)
    return result
def _parse_sample_value(line):
    """Return the non-negative integer value of one metrics sample line, or None.

    Expected layout: ``name[{labels}] value [timestamp]``. The value is the
    first field after the label block (or after the name when there are no
    labels), so an optional trailing timestamp is not mistaken for it.
    """
    _, closing_brace, tail = line.rpartition("}")
    fields = tail.split() if closing_brace else line.split()[1:]
    if not fields:
        return None
    try:
        # float() accepts plain, decimal and scientific notation; int() raises
        # on NaN (ValueError) and on +/-Inf (OverflowError).
        value = int(float(fields[0]))
    except (ValueError, OverflowError):
        return None
    # A request-count gauge cannot be negative; never let one cancel out
    # real in-flight requests in the sums.
    return max(value, 0)


def get_engine_metrics(engine_mgmt_addr):
    """Fetch an engine's /metrics page and total its queued and running requests.

    Every sample line containing ``num_requests_waiting`` or
    ``num_requests_running`` contributes its value (truncated to int) to the
    matching total; comment and blank lines are ignored.

    Args:
        engine_mgmt_addr: ``host:port`` of the engine management endpoint.

    Returns:
        ``{"waiting": int, "running": int}``, or None when the metrics page
        could not be fetched.
    """
    url = f"http://{engine_mgmt_addr}/metrics"
    text = _http_get_text(url, timeout=10)
    if not text:
        return None
    totals = {"waiting": 0, "running": 0}
    for raw_line in text.splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#"):
            continue
        # Filter by metric name first so unrelated lines are never parsed.
        if "num_requests_waiting" in line:
            bucket = "waiting"
        elif "num_requests_running" in line:
            bucket = "running"
        else:
            continue
        value = _parse_sample_value(line)
        if value is not None:
            totals[bucket] += value
    return totals
def _wait_for_drain(engine_mgmt_addrs, max_wait, poll_interval):
    """Poll all engines until no request is waiting or running.

    Returns when every engine reports zero active requests, when any engine's
    metrics cannot be fetched, or once *max_wait* seconds have elapsed.
    """
    # Monotonic clock: the deadline must not move if the wall clock is adjusted.
    start_time = time.monotonic()
    deadline = start_time + max_wait
    while time.monotonic() < deadline:
        total_waiting = 0
        total_running = 0
        for addr in engine_mgmt_addrs:
            metrics = get_engine_metrics(addr)
            if metrics is None:
                logger.debug("Engine %s unreachable", addr)
                logger.warning("Engine unreachable, exiting")
                return
            total_waiting += metrics["waiting"]
            total_running += metrics["running"]
        total_active = total_waiting + total_running
        logger.info("active=%d (waiting=%d, running=%d)", total_active, total_waiting, total_running)
        if total_active == 0:
            logger.info("All requests drained after %.1fs", time.monotonic() - start_time)
            return
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        # Never sleep past the deadline, and never pass a negative value to
        # time.sleep (it raises ValueError).
        time.sleep(min(max(poll_interval, 0), remaining))
    logger.info("Timeout after %.1fs, stopping anyway", time.monotonic() - start_time)


def main():
    """Run the prestop sequence: pause the NodeManager, then wait for engines to drain.

    Usage: python prestop.py [--max-wait SECONDS] [--poll-interval SECONDS]

    Always exits with status 0, including when the config cannot be loaded or
    the pause request fails.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Prestop graceful shutdown")
    parser.add_argument(
        "--max-wait",
        type=int,
        default=DEFAULT_MAX_WAIT_SECONDS,
        help=f"Max wait seconds (default: {DEFAULT_MAX_WAIT_SECONDS})",
    )
    parser.add_argument(
        "--poll-interval",
        type=int,
        default=DEFAULT_POLL_INTERVAL,
        help=f"Poll interval seconds (default: {DEFAULT_POLL_INTERVAL})",
    )
    args = parser.parse_args()

    config = load_config()
    if config is None:
        logger.error("Failed to load config, exiting")
        sys.exit(0)

    nm_port = get_nm_port(config)
    pod_ip = os.environ.get("POD_IP", "127.0.0.1")
    nm_url = f"http://{format_address(pod_ip, nm_port)}"
    logger.info("NodeManager URL: %s", nm_url)

    logger.info("Sending pause to NodeManager at %s", nm_url)
    response = send_pause(nm_url, config)
    # Covers both a failed request (None) and a JSON payload that is not an object.
    if not isinstance(response, dict):
        logger.warning("Pause request failed, exiting anyway")
        sys.exit(0)

    engine_mgmt_addrs = response.get("engine_mgmt_addrs", [])
    if isinstance(engine_mgmt_addrs, str):
        # A bare string would otherwise be iterated character by character.
        engine_mgmt_addrs = [engine_mgmt_addrs]
    if not isinstance(engine_mgmt_addrs, (list, tuple)) or not engine_mgmt_addrs:
        logger.error("No engine_mgmt_addrs in pause response, exiting")
        sys.exit(0)
    logger.info("Engine mgmt addresses: %s", engine_mgmt_addrs)

    logger.info("Polling engine metrics...")
    _wait_for_drain(engine_mgmt_addrs, args.max_wait, args.poll_interval)
    sys.exit(0)


if __name__ == "__main__":
    main()