from collections import deque
import serving_cast.stime as stime
# Module-level logger obtained through the project's stime helper
# (not referenced by the code in this module as shown).
logger = stime.get_logger(__name__)
def get_estimated_communication_time(num_bytes: int, bandwidth: int, rate: float):
    """Return how long moving ``num_bytes`` takes over a link.

    The link's usable throughput is its nominal ``bandwidth`` scaled by the
    achieved fraction ``rate``; the result is ``num_bytes`` divided by that
    throughput, in whatever time unit ``bandwidth`` is expressed per.

    Raises:
        ZeroDivisionError: if ``bandwidth * rate`` is zero.
    """
    effective_bandwidth = bandwidth * rate
    return num_bytes / effective_bandwidth
class Channel(stime.Task):
    """A link that serves queued transfer requests one at a time, in FIFO order.

    Each request occupies the channel for
    ``get_estimated_communication_time(num_bytes, bandwidth, rate)`` time units,
    advanced through ``stime.elapse`` (presumably simulated time — confirm
    against stime). After that time has elapsed, the request's callback (if
    any) is invoked.

    Args:
        bandwidth: Nominal bandwidth of the link.
        rate: Fraction of ``bandwidth`` that is actually achieved.
    """

    def __init__(self, bandwidth: int, rate: float):
        super().__init__()
        # Pending requests, oldest first:
        # [num_bytes, callback, callback_args, callback_kwargs].
        self.bytes_list = deque()
        self.bandwidth = bandwidth
        self.rate = rate

    def send_bytes(self, num_bytes: int, callback, *callback_args, **callback_kwargs) -> None:
        """Queue a transfer of ``num_bytes`` and wake the channel task.

        ``callback(*callback_args, **callback_kwargs)`` is called once the
        transfer's duration has elapsed; pass ``None`` for no callback.
        """
        self.bytes_list.append([num_bytes, callback, callback_args, callback_kwargs])
        self.notify()

    def process(self):
        """Serve queued transfers forever, waiting whenever the queue is empty.

        Presumably driven by the ``stime.Task`` scheduler — confirm against stime.

        Raises:
            RuntimeError: if a callback raises; the original exception is
                chained as ``__cause__``.
        """
        while True:
            # Re-check after every wake-up: a wait() that returns while the
            # queue is still empty must not reach popleft(), which would raise
            # IndexError on an empty deque.
            while not self.bytes_list:
                self.wait()
            num_bytes, callback, callback_args, callback_kwargs = self.bytes_list.popleft()
            estimated_duration = get_estimated_communication_time(num_bytes, self.bandwidth, self.rate)
            stime.elapse(estimated_duration)
            if callback is not None:
                try:
                    callback(*callback_args, **callback_kwargs)
                except Exception as e:
                    raise RuntimeError(f"Error in callback: {e}") from e
class CommunicationManager:
    """Owns the host-to-device and device-to-device channels and validates requests.

    Args:
        communication_config: Object exposing ``host2device_bandwidth``,
            ``host2device_rate``, ``device2device_bandwidth`` and
            ``device2device_rate`` attributes.

    Raises:
        ValueError: if a bandwidth is not strictly positive or a rate is not
            in (0, 1]. NaN is rejected for both.
    """

    def __init__(self, communication_config):
        # Validate both links before constructing any channel.
        self._check_link_config(
            "host2device",
            communication_config.host2device_bandwidth,
            communication_config.host2device_rate,
        )
        self._check_link_config(
            "device2device",
            communication_config.device2device_bandwidth,
            communication_config.device2device_rate,
        )
        self.host2device_channel = Channel(
            bandwidth=communication_config.host2device_bandwidth,
            rate=communication_config.host2device_rate,
        )
        self.device2device_channel = Channel(
            bandwidth=communication_config.device2device_bandwidth,
            rate=communication_config.device2device_rate,
        )

    @staticmethod
    def _check_link_config(name: str, bandwidth, rate) -> None:
        """Raise ValueError unless ``bandwidth > 0`` and ``0 < rate <= 1``."""
        # Written as negated in-range tests so NaN (for which every comparison
        # is False) is rejected instead of slipping through `x <= 0 or x > 1`.
        if not bandwidth > 0:
            raise ValueError(f"{name}_bandwidth should be positive")
        if not 0 < rate <= 1:
            raise ValueError(f"{name}_rate should be in (0, 1]")

    @staticmethod
    def _check_num_bytes(num_bytes) -> None:
        """Raise ValueError unless ``num_bytes`` is a positive, non-bool int."""
        # bool is a subclass of int, so True/False must be excluded explicitly.
        if isinstance(num_bytes, bool) or not isinstance(num_bytes, int) or num_bytes <= 0:
            raise ValueError("num_bytes should be a positive int")

    @staticmethod
    def async_send(target_channel, num_bytes: int, callback, *callback_args, **callback_kwargs) -> None:
        """Queue ``num_bytes`` on ``target_channel`` without waiting for it.

        ``callback(*callback_args, **callback_kwargs)`` runs once the transfer
        completes (``None`` for no callback).
        """
        target_channel.send_bytes(num_bytes, callback, *callback_args, **callback_kwargs)

    @staticmethod
    def sync_send(target_channel, num_bytes: int) -> None:
        """Queue ``num_bytes`` on ``target_channel`` and wait for completion.

        The calling task suspends in ``wait()`` until the transfer's completion
        callback calls ``notify()`` on it.
        """
        # NOTE(review): a single wait() assumes nothing else notifies this task
        # while the transfer is in flight; an unrelated notify would return early.
        current_task = stime.current_task()

        def callback():
            current_task.notify()

        target_channel.send_bytes(num_bytes, callback)
        current_task.wait()

    def host2device_sync(self, num_bytes: int) -> None:
        """Transfer ``num_bytes`` host-to-device, blocking until done.

        Raises:
            ValueError: if ``num_bytes`` is not a positive int.
        """
        self._check_num_bytes(num_bytes)
        self.sync_send(self.host2device_channel, num_bytes)

    def host2device_async(self, num_bytes: int, callback, *callback_args, **callback_kwargs) -> None:
        """Queue a host-to-device transfer; ``callback`` runs on completion.

        Raises:
            ValueError: if ``num_bytes`` is not a positive int.
        """
        self._check_num_bytes(num_bytes)
        self.async_send(
            self.host2device_channel,
            num_bytes,
            callback,
            *callback_args,
            **callback_kwargs,
        )

    def device2device_sync(self, num_bytes: int) -> None:
        """Transfer ``num_bytes`` device-to-device, blocking until done.

        Raises:
            ValueError: if ``num_bytes`` is not a positive int.
        """
        self._check_num_bytes(num_bytes)
        self.sync_send(self.device2device_channel, num_bytes)

    def device2device_async(self, num_bytes: int, callback, *callback_args, **callback_kwargs) -> None:
        """Queue a device-to-device transfer; ``callback`` runs on completion.

        Raises:
            ValueError: if ``num_bytes`` is not a positive int.
        """
        self._check_num_bytes(num_bytes)
        self.async_send(
            self.device2device_channel,
            num_bytes,
            callback,
            *callback_args,
            **callback_kwargs,
        )