import os
import time
from contextlib import contextmanager
from threading import Thread
from typing import Any
from uuid import uuid4
import psutil
import zmq
from prometheus_client import CollectorRegistry, Counter, Gauge, Histogram
from transfer_queue.utils.logging_utils import get_logger
from transfer_queue.utils.zmq_utils import (
ZMQMessage,
ZMQRequestType,
ZMQServerInfo,
create_zmq_socket,
format_zmq_address,
)
logger = get_logger(__name__)
TQ_METRICS_COLLECT_INTERVAL = int(os.environ.get("TQ_METRICS_COLLECT_INTERVAL", 10))
TQ_METRICS_STORAGE_TIMEOUT = int(os.environ.get("TQ_METRICS_STORAGE_TIMEOUT", 5))
class TQMetricsExporter:
"""Prometheus metrics exporter for TransferQueue.
Exposes an HTTP ``/metrics`` endpoint for Prometheus scraping and periodically
updates Prometheus gauges from a *snapshot* dict pushed by the controller.
**Decoupling strategy** — the exporter never accesses controller internals
directly. Instead, the controller calls ``update_controller_snapshot()``
periodically (from its own thread) to hand over a plain ``dict`` that the
background collection thread reads. This eliminates:
* Lock contention between the metrics thread and controller request threads.
* Risk of reading tensors that are being resized concurrently.
Storage-unit metrics are still collected via ZMQ (already process-isolated).
Lifecycle:
1. Created by ``TransferQueueController.start_metrics()`` when metrics are enabled.
2. ``start()`` launches the HTTP server and a background collection thread.
3. The controller calls ``update_controller_snapshot()`` on its own cadence.
4. The collection thread calls ``collect_controller_metrics`` (reads snapshot)
and ``collect_storage_metrics`` every ``TQ_METRICS_COLLECT_INTERVAL`` seconds.
Environment variables:
TQ_METRICS_COLLECT_INTERVAL Collection interval in seconds (default 10)
TQ_METRICS_STORAGE_TIMEOUT ZMQ timeout for storage queries (default 5s)
"""
def __init__(self, role: str = "controller"):
self._start_time = time.time()
self._process = psutil.Process()
self._role = role
self._storage_unit_infos: dict[str, ZMQServerInfo] = {}
self._zmq_ctx: zmq.Context | None = None
self._zmq_sockets: dict[str, zmq.Socket] = {}
self._known_partition_ids: set[str] = set()
self._known_production_labels: set[tuple[str, str]] = set()
self._known_consumption_labels: set[tuple[str, str]] = set()
self._metrics_endpoint: str = ""
self._controller_snapshot: dict[str, Any] = {}
self.registry = CollectorRegistry()
self._define_metrics()
@property
def endpoint(self) -> str:
"""The metrics HTTP endpoint address in ``host:port`` format."""
return self._metrics_endpoint
def _define_metrics(self) -> None:
r = self.registry
role = self._role
self.request_duration = Histogram(
f"tq_{role}_request_duration_seconds",
f"{role.capitalize()} request processing duration",
["op_type"],
buckets=[0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 5.0],
registry=r,
)
self.request_total = Counter(
f"tq_{role}_request",
f"Total number of requests processed by the {role}",
["op_type"],
registry=r,
)
self.request_errors_total = Counter(
f"tq_{role}_request_errors",
f"Total number of request errors in the {role}",
["op_type"],
registry=r,
)
self.request_samples_total = Counter(
f"tq_{role}_request_samples",
f"Total number of samples processed per operation type in the {role}",
["op_type"],
registry=r,
)
if role == "controller":
self._define_controller_metrics()
def _define_controller_metrics(self) -> None:
r = self.registry
self.controller_uptime = Gauge("tq_controller_uptime_seconds", "Controller uptime in seconds", registry=r)
self.controller_memory_rss = Gauge(
"tq_controller_memory_rss_bytes", "Controller process RSS memory in bytes", registry=r
)
self.partitions_total = Gauge("tq_partitions_total", "Total number of active partitions", registry=r)
self.partition_samples = Gauge(
"tq_partition_samples_total", "Number of active samples in a partition", ["partition_id"], registry=r
)
self.partition_production_progress = Gauge(
"tq_partition_production_progress",
"Production progress ratio (0.0-1.0)",
["partition_id", "task_name"],
registry=r,
)
self.partition_consumption_progress = Gauge(
"tq_partition_consumption_progress",
"Consumption progress ratio (0.0-1.0)",
["partition_id", "task_name"],
registry=r,
)
self.global_index_allocated = Gauge(
"tq_global_index_allocated_total", "Total allocated global indexes", registry=r
)
self.global_index_reusable = Gauge(
"tq_global_index_reusable_total", "Number of reusable global indexes", registry=r
)
self.storage_capacity = Gauge(
"tq_storage_capacity_total", "Storage unit capacity (max keys)", ["storage_unit_id"], registry=r
)
self.storage_active_keys = Gauge(
"tq_storage_active_keys_total", "Active keys in storage unit", ["storage_unit_id"], registry=r
)
self.storage_utilization = Gauge(
"tq_storage_utilization_ratio",
"Storage utilization ratio (active/capacity)",
["storage_unit_id"],
registry=r,
)
self.storage_memory_rss = Gauge(
"tq_storage_memory_rss_bytes", "Storage unit process RSS memory", ["storage_unit_id"], registry=r
)
self.storage_request_ops = Gauge(
"tq_storage_request_ops",
"Total requests processed by storage unit",
["storage_unit_id", "op_type"],
registry=r,
)
self.storage_request_latency_avg = Gauge(
"tq_storage_request_latency_avg",
"Average request latency in storage unit (seconds)",
["storage_unit_id", "op_type"],
registry=r,
)
self.storage_request_latency_p50 = Gauge(
"tq_storage_request_latency_p50",
"P50 request latency in storage unit (seconds)",
["storage_unit_id", "op_type"],
registry=r,
)
self.storage_request_latency_p99 = Gauge(
"tq_storage_request_latency_p99",
"P99 request latency in storage unit (seconds)",
["storage_unit_id", "op_type"],
registry=r,
)
@contextmanager
def measure(self, op_type: str):
"""Context manager that records request count and latency for *op_type*.
Prometheus errors are caught and logged so that metrics failures never
block or crash the controller's request processing loop.
Usage::
with metrics.measure("GET_META"):
result = self.get_metadata(...)
"""
try:
self.request_total.labels(op_type=op_type).inc()
except Exception:
logger.debug(f"Metrics: failed to increment counter for {op_type}", exc_info=True)
start = time.perf_counter()
try:
yield
except Exception:
try:
self.request_errors_total.labels(op_type=op_type).inc()
except Exception:
logger.debug(f"Metrics: failed to increment error counter for {op_type}", exc_info=True)
raise
finally:
elapsed = time.perf_counter() - start
try:
self.request_duration.labels(op_type=op_type).observe(elapsed)
except Exception:
logger.debug(f"Metrics: failed to observe duration for {op_type}", exc_info=True)
def record_samples(self, op_type: str, count: int) -> None:
"""Record the number of samples processed in a single operation."""
try:
self.request_samples_total.labels(op_type=op_type).inc(count)
except Exception:
logger.debug(f"Metrics: failed to record samples for {op_type}", exc_info=True)
def register_storage_units(self, storage_unit_infos: dict[str, ZMQServerInfo]) -> None:
"""Register SimpleStorageUnit ZMQ endpoints for metrics collection."""
self._storage_unit_infos.update(storage_unit_infos)
logger.info(f"Metrics exporter registered {len(storage_unit_infos)} storage units")
def update_controller_snapshot(self, snapshot: dict[str, Any]) -> None:
"""Replace the controller metrics snapshot (called from the controller thread).
The snapshot is a plain dict with no references to live controller objects,
so the metrics thread can read it without risk of concurrent mutation.
"""
self._controller_snapshot = snapshot
def collect_controller_metrics(self) -> None:
"""Update Prometheus gauges from the latest controller snapshot."""
self.controller_uptime.set(time.time() - self._start_time)
try:
self.controller_memory_rss.set(self._process.memory_info().rss)
except Exception:
pass
snapshot = self._controller_snapshot
partitions = snapshot.get("partitions", {})
current_pids = set(partitions.keys())
current_consumption_labels: set[tuple[str, str]] = set()
self.partitions_total.set(len(current_pids))
current_production_labels: set[tuple[str, str]] = set()
for pid, stats in partitions.items():
self.partition_samples.labels(partition_id=pid).set(stats["total_samples_num"])
for task_name, pstats in stats.get("production_statistics", {}).items():
self.partition_production_progress.labels(partition_id=pid, task_name=task_name).set(
pstats.get("production_progress", 0)
)
current_production_labels.add((pid, task_name))
for task_name, cstats in stats.get("consumption_statistics", {}).items():
self.partition_consumption_progress.labels(partition_id=pid, task_name=task_name).set(
cstats.get("consumption_progress", 0)
)
current_consumption_labels.add((pid, task_name))
for stale_pid in self._known_partition_ids - current_pids:
try:
self.partition_samples.remove(stale_pid)
except (KeyError, ValueError):
pass
for stale_pair in self._known_production_labels - current_production_labels:
try:
self.partition_production_progress.remove(*stale_pair)
except (KeyError, ValueError):
pass
for stale_pair in self._known_consumption_labels - current_consumption_labels:
try:
self.partition_consumption_progress.remove(*stale_pair)
except (KeyError, ValueError):
pass
self._known_partition_ids = current_pids
self._known_production_labels = current_production_labels
self._known_consumption_labels = current_consumption_labels
self.global_index_allocated.set(snapshot.get("global_index_allocated", 0))
self.global_index_reusable.set(snapshot.get("global_index_reusable", 0))
def collect_storage_metrics(self) -> None:
"""Query each registered SimpleStorageUnit for metrics via ZMQ."""
if not self._storage_unit_infos:
return
storage_snapshot = list(self._storage_unit_infos.items())
for su_id, su_info in storage_snapshot:
try:
metrics = self._query_storage_unit(su_info, su_id)
if metrics is None:
continue
label = metrics.get("storage_unit_id", su_id)
capacity = metrics.get("capacity", 0)
active = metrics.get("active_keys", 0)
self.storage_capacity.labels(storage_unit_id=label).set(capacity)
self.storage_active_keys.labels(storage_unit_id=label).set(active)
self.storage_utilization.labels(storage_unit_id=label).set(active / capacity if capacity > 0 else 0.0)
self.storage_memory_rss.labels(storage_unit_id=label).set(metrics.get("process_rss_bytes", 0))
for op_type, op_data in metrics.get("op_stats", {}).items():
self.storage_request_ops.labels(storage_unit_id=label, op_type=op_type).set(
op_data.get("request_count", 0)
)
self.storage_request_latency_avg.labels(storage_unit_id=label, op_type=op_type).set(
op_data.get("latency_avg", 0)
)
self.storage_request_latency_p50.labels(storage_unit_id=label, op_type=op_type).set(
op_data.get("latency_p50", 0)
)
self.storage_request_latency_p99.labels(storage_unit_id=label, op_type=op_type).set(
op_data.get("latency_p99", 0)
)
except Exception as e:
logger.warning(f"Failed to collect metrics from storage unit {su_id}: {e}")
def _get_or_create_socket(self, su_id: str, su_info: ZMQServerInfo) -> zmq.Socket:
"""Return a cached ZMQ DEALER socket for *su_id*, creating one if needed."""
if self._zmq_ctx is None:
self._zmq_ctx = zmq.Context()
sock = self._zmq_sockets.get(su_id)
if sock is not None and not sock.closed:
return sock
identity = f"metrics_collector_{uuid4().hex[:8]}".encode()
sock = create_zmq_socket(self._zmq_ctx, zmq.DEALER, su_info.ip, identity)
timeout_ms = TQ_METRICS_STORAGE_TIMEOUT * 1000
address = format_zmq_address(su_info.ip, su_info.ports["put_get_socket"])
sock.connect(address)
sock.setsockopt(zmq.RCVTIMEO, timeout_ms)
sock.setsockopt(zmq.SNDTIMEO, timeout_ms)
self._zmq_sockets[su_id] = sock
return sock
def _query_storage_unit(self, su_info: ZMQServerInfo, su_id: str) -> dict[str, Any] | None:
"""Send a synchronous GET_METRICS request to a single storage unit."""
try:
sock = self._get_or_create_socket(su_id, su_info)
request_msg = ZMQMessage.create(
request_type=ZMQRequestType.GET_METRICS,
sender_id="metrics_collector",
body={},
)
sock.send_multipart(request_msg.serialize())
response_frames = sock.recv_multipart(copy=False)
response_msg = ZMQMessage.deserialize(response_frames)
if response_msg.request_type == ZMQRequestType.METRICS_RESPONSE:
return response_msg.body
return None
except zmq.error.Again:
logger.debug(f"Timeout querying metrics from {su_id}")
return None
except Exception as e:
logger.warning(f"Error querying metrics from {su_id}: {e}")
sock = self._zmq_sockets.pop(su_id, None)
if sock and not sock.closed:
sock.close(linger=0)
return None
def start(self, node_ip: str = "0.0.0.0", port: int = 0) -> str:
"""Start the HTTP /metrics server and the background collection thread.
When *port* is ``0`` (the default), the OS assigns a free port
automatically — the actual port is read back from the server socket.
Args:
node_ip: The IP address of the node hosting the process.
port: HTTP port for the /metrics endpoint.
Returns:
The metrics endpoint address in ``host:port`` format.
"""
from prometheus_client import start_http_server
httpd, _thread = start_http_server(port=port, addr=node_ip, registry=self.registry)
actual_port = httpd.server_address[1]
self._metrics_endpoint = f"{node_ip}:{actual_port}"
logger.info(f"TQ Metrics HTTP server started on {self._metrics_endpoint}")
if self._role == "controller":
self._collect_thread = Thread(
target=self._collect_loop,
name="TQMetricsCollectorThread",
daemon=True,
)
self._collect_thread.start()
return self._metrics_endpoint
def _collect_loop(self) -> None:
"""Background loop that periodically collects controller and storage metrics."""
while True:
try:
self.collect_controller_metrics()
self.collect_storage_metrics()
except Exception as e:
logger.error(f"Metrics collection error: {e}")
time.sleep(TQ_METRICS_COLLECT_INTERVAL)