@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
+import dataclasses
from collections.abc import Iterable
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
@@ -9,6 +10,7 @@
from vllm.config import VllmConfig
from vllm.config.kv_transfer import KVTransferConfig
+from vllm.distributed.kv_events import KVCacheEvent, KVConnectorKVEvents
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
@@ -19,6 +21,9 @@
KVConnectorRole,
KVConnectorWorkerMetadata,
)
+from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector import (
+ LMCacheKVEvents,
+)
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorPromMetrics,
KVConnectorStats,
@@ -122,6 +127,91 @@
)
self._prom_metrics[connector_id].observe(stats_data["data"], engine_idx)
+class ConnectorKVCacheEvents(KVCacheEvent):
+ """KV cache events grouped by connector type."""
+
+ connector_name: str
+ events: list[KVCacheEvent]
+
+
+class MultiConnectorKVEvents(KVConnectorKVEvents):
+ """KV events for multiple connectors."""
+
+ def __init__(self, data: dict[str, KVConnectorKVEvents] | None = None):
+ self._data = data or {}
+
+ def _create_kv_events_for_connector(
+ self, connector_name: str
+ ) -> KVConnectorKVEvents | None:
+ """
+ Factory method to create appropriate KVConnectorKVEvents instance.
+ Args:
+ connector_name: Name of the connector requiring event tracking
+ Returns:
+ KVConnectorKVEvents instance appropriate for the connector type or
+ None if none available
+ """
+ if connector_name == "LMCacheConnectorV1":
+ return LMCacheKVEvents(num_workers=1)
+ else:
+ return None
+
+ def add_events(self, events: list[KVCacheEvent]) -> None:
+ for event in events:
+ assert isinstance(event, ConnectorKVCacheEvents)
+
+ connector_name = event.connector_name
+ if connector_name not in self._data:
+ if new_connector_events := self._create_kv_events_for_connector(
+ connector_name
+ ):
+ self._data[connector_name] = new_connector_events
+ self._data[connector_name].add_events(event.events)
+ # Continue to avoid incrementing workers as already set it in
+ # _create_kv_events_for_connector
+ continue
+ else:
+ logger.error(
+ "Unable to process events for connector "
+ "[%s] because it does not have an "
+ "events class to handle the events.",
+ connector_name,
+ )
+ continue
+
+ self._data[connector_name].add_events(event.events)
+ self._data[connector_name].increment_workers(1)
+
+ def increment_workers(self, count: int = 1) -> None:
+ pass
+
+ def get_all_events(self) -> list[KVCacheEvent]:
+ result: list[KVCacheEvent] = []
+ for connector_name, kv_events in self._data.items():
+ if connector_events := kv_events.get_all_events():
+ result.append(
+ ConnectorKVCacheEvents(
+ connector_name=connector_name,
+ events=connector_events,
+ )
+ )
+ return result
+
+ def get_connector_events(self, connector_name: str) -> KVConnectorKVEvents:
+ return self._data[connector_name]
+
+ """ The following methods are not implemented for MultiConnectorKVEvents because
+ the specific connector `KVConnectorKVEvents` methods are called instead.
+ """
+
+ def aggregate(self) -> KVConnectorKVEvents:
+ raise NotImplementedError
+
+ def get_number_of_workers(self) -> int:
+ raise NotImplementedError
+
+ def clear_events(self) -> None:
+ raise NotImplementedError
class MultiConnector(KVConnectorBase_V1):
"""
@@ -342,6 +432,21 @@
# multiple connectors, handling the case where only a subset of the
# requested connectors implements the 'get_kv_connector_kv_cache_events'
# WIP: https://github.com/vllm-project/vllm/pull/31811
+ def get_kv_connector_kv_cache_events(self) -> KVConnectorKVEvents | None:
+ """
+ Get KV connector events grouped by connector type.
+ Returns None if no connectors have events.
+ """
+ events_by_connector: dict[str, KVConnectorKVEvents] = {}
+
+ for c in self._connectors:
+ if connector_events := c.get_kv_connector_kv_cache_events():
+ events_by_connector[c.__class__.__name__] = connector_events
+
+ if not events_by_connector:
+ return None
+
+ return MultiConnectorKVEvents(data=events_by_connector)
#
# Scheduler-side methods
@@ -401,7 +506,7 @@
MultiKVConnectorWorkerMetadata,
)
multi_connector_worker_meta = connector_output.kv_connector_worker_meta
-
+ original_kv_cache_events = connector_output.kv_cache_events
try:
for i, c in enumerate(self._connectors):
if multi_connector_worker_meta is not None:
@@ -409,7 +514,24 @@
connector_output.kv_connector_worker_meta = (
multi_connector_worker_meta.metadata[i]
)
- c.update_connector_output(connector_output)
+
+ kv_cache_events_per_connector = None
+ if original_kv_cache_events is not None and isinstance(
+ original_kv_cache_events, MultiConnectorKVEvents
+ ):
+ try:
+ kv_cache_events_per_connector = (
+ original_kv_cache_events.get_connector_events(
+ c.__class__.__name__
+ )
+ )
+ except KeyError:
+ kv_cache_events_per_connector = None
+
+ connector_output_per_connector = dataclasses.replace(
+ connector_output, kv_cache_events=kv_cache_events_per_connector
+ )
+ c.update_connector_output(connector_output_per_connector)
finally:
# restore kv_connector_worker_meta
connector_output.kv_connector_worker_meta = multi_connector_worker_meta