"""Parallel processing tools to help speed up certain tasks like data
preprocessing.
Authors
* Sylvain de Langen 2023
"""
import itertools
import multiprocessing
from collections import deque
from concurrent.futures import Executor, ProcessPoolExecutor
from threading import Condition
from typing import Any, Callable, Iterable, Optional
from tqdm.auto import tqdm
def _chunk_process_wrapper(fn, chunk):
return list(map(fn, chunk))
class CancelFuturesOnExit:
"""Context manager that .cancel()s all elements of a list upon exit.
This is used to abort futures faster when raising an exception."""
def __init__(self, future_list):
self.future_list = future_list
def __enter__(self):
pass
def __exit__(self, _type, _value, _traceback):
for future in self.future_list:
future.cancel()
class _ParallelMapper:
"""Internal class for `parallel_map`, arguments match the constructor's."""
def __init__(
self,
fn: Callable[[Any], Any],
source: Iterable[Any],
process_count: int,
chunk_size: int,
queue_size: int,
executor: Optional[Executor],
progress_bar: bool,
progress_bar_kwargs: dict,
):
self.future_chunks = deque()
self.cv = Condition()
self.just_finished_count = 0
"""Number of jobs that were just done processing, guarded by
`self.cv`."""
self.remote_exception = None
"""Set by a worker when it encounters an exception, guarded by
`self.cv`."""
self.fn = fn
self.source = source
self.process_count = process_count
self.chunk_size = chunk_size
self.queue_size = queue_size
self.executor = executor
self.known_len = len(source) if hasattr(source, "__len__") else None
self.source_it = iter(source)
self.depleted_source = False
if progress_bar:
tqdm_final_kwargs = {"total": self.known_len}
tqdm_final_kwargs.update(progress_bar_kwargs)
self.pbar = tqdm(**tqdm_final_kwargs)
else:
self.pbar = None
def run(self):
"""Spins up an executor (if none were provided), then yields all
processed chunks in order."""
with CancelFuturesOnExit(self.future_chunks):
if self.executor is not None:
yield from self._map_all()
else:
with ProcessPoolExecutor(
max_workers=self.process_count
) as pool:
self.executor = pool
yield from self._map_all()
def _bump_processed_count(self, future):
"""Notifies the main thread of the finished job, bumping the number of
jobs it should requeue. Updates the progress bar based on the returned
chunk length.
Arguments
---------
future: concurrent.futures.Future
A future holding a processed chunk (of type `list`).
"""
if future.cancelled():
return
future_exception = future.exception()
with self.cv:
if future_exception is not None:
self.remote_exception = future_exception
self.just_finished_count += 1
self.cv.notify()
if future_exception is None:
if self.pbar is not None:
self.pbar.update(len(future.result()))
def _enqueue_job(self):
"""Pulls a chunk from the source iterable and submits it to the
pool; must be run from the main thread.
Returns
-------
`True` if any job was submitted (that is, if there was any chunk
left to process), `False` otherwise.
"""
chunk = list(itertools.islice(self.source_it, self.chunk_size))
if len(chunk) == 0:
self.depleted_source = True
return False
future = self.executor.submit(_chunk_process_wrapper, self.fn, chunk)
future.add_done_callback(self._bump_processed_count)
self.future_chunks.append(future)
return True
def _map_all(self):
"""Performs all the parallel mapping logic.
Arguments
---------
executor: concurrent.futures.Executor
The executor to `.invoke` all jobs on. The executor is NOT shut down
at the end of processing.
"""
for _ in range(self.queue_size):
if not self._enqueue_job():
break
while (not self.depleted_source) or (len(self.future_chunks) != 0):
with self.cv:
while self.just_finished_count == 0:
self.cv.wait()
if self.remote_exception is not None:
raise self.remote_exception
to_queue_count = self.just_finished_count
self.just_finished_count = 0
for _ in range(to_queue_count):
if not self._enqueue_job():
break
while len(self.future_chunks) != 0 and self.future_chunks[0].done():
yield from self.future_chunks.popleft().result()
if self.pbar is not None:
self.pbar.close()
def parallel_map(
fn: Callable[[Any], Any],
source: Iterable[Any],
process_count: int = multiprocessing.cpu_count(),
chunk_size: int = 8,
queue_size: int = 128,
executor: Optional[Executor] = None,
progress_bar: bool = True,
progress_bar_kwargs: dict = {"smoothing": 0.02},
):
"""Maps iterable items with a function, processing chunks of items in
parallel with multiple processes and displaying progress with tqdm.
Processed elements will always be returned in the original, correct order.
Unlike `ProcessPoolExecutor.map`, elements are produced AND consumed lazily.
Arguments
---------
fn
The function that is called for every element in the source list.
The output is an iterator over the source list after fn(elem) is called.
source: Iterable
Iterator whose elements are passed through the mapping function.
process_count: int
The number of processes to spawn. Ignored if a custom executor is
provided.
For CPU-bound tasks, it is generally not useful to exceed logical core
count.
For IO-bound tasks, it may make sense to as to limit the amount of time
spent in iowait.
chunk_size: int
How many elements are fed to the worker processes at once. A value of 8
is generally fine. Low values may increase overhead and reduce CPU
occupancy.
queue_size: int
Number of chunks to be waited for on the main process at a time.
Low values increase the chance of the queue being starved, forcing
workers to idle.
Very high values may cause high memory usage, especially if the source
iterable yields large objects.
executor: Optional[Executor]
Allows providing an existing executor (preferably a
ProcessPoolExecutor). If None (the default), a process pool will be
spawned for this mapping task and will be shut down after.
progress_bar: bool
Whether to show a tqdm progress bar.
progress_bar_kwargs: dict
A dict of keyword arguments that is forwarded to tqdm when
`progress_bar == True`. Allows overriding the defaults or e.g.
specifying `total` when it cannot be inferred from the source iterable.
"""
mapper = _ParallelMapper(
fn,
source,
process_count,
chunk_size,
queue_size,
executor,
progress_bar,
progress_bar_kwargs,
)
yield from mapper.run()