//! Parallel execution utilities using std::thread::scope
//!
//! This module provides structured parallelism for CPU tensor operations.
//! Work is split evenly among threads once and then joined, which is
//! better suited for predictable linear algebra workloads than work-stealing.
/// Get the number of threads to use for parallel operations.
///
/// Returns 1 on wasm32 targets (no threading support).
/// On other targets, returns the available parallelism or 1 if unavailable.
#[inline]
pub fn num_threads() -> usize {
#[cfg(target_arch = "wasm32")]
{
1
}
#[cfg(not(target_arch = "wasm32"))]
{
std::thread::available_parallelism()
.map(|n| n.get())
.unwrap_or(1)
}
}
/// Execute a function in parallel over pairs of input/output chunks.
///
/// Useful for operations like softmax where each input row maps to an output row.
///
/// # Arguments
/// * `input` - Input data slice
/// * `output` - Output data slice (must be same length as input)
/// * `chunk_size` - Size of each chunk to process together
/// * `f` - Function receiving (chunk_index, input_chunk, output_chunk)
#[inline]
pub fn parallel_zip_chunks_mut<T, U, F>(input: &[T], output: &mut [U], chunk_size: usize, f: F)
where
T: Sync,
U: Send,
F: Fn(usize, &[T], &mut [U]) + Send + Sync,
{
assert_eq!(
input.len(),
output.len(),
"Input and output must have same length"
);
if input.is_empty() || chunk_size == 0 {
return;
}
let n_threads = num_threads();
let total_chunks = input.len().div_ceil(chunk_size);
// If single-threaded or very small workload, run sequentially
if n_threads == 1 || total_chunks <= 1 {
for (i, (in_chunk, out_chunk)) in input
.chunks(chunk_size)
.zip(output.chunks_mut(chunk_size))
.enumerate()
{
f(i, in_chunk, out_chunk);
}
return;
}
// Distribute chunks evenly among threads
let chunks_per_thread = total_chunks.div_ceil(n_threads);
let elements_per_thread = chunks_per_thread * chunk_size;
std::thread::scope(|scope| {
let mut remaining_in = input;
let mut remaining_out = output;
let mut chunk_offset = 0;
for thread_id in 0..n_threads {
if remaining_in.is_empty() {
break;
}
let this_size = if thread_id == n_threads - 1 {
remaining_in.len()
} else {
elements_per_thread.min(remaining_in.len())
};
let (thread_in, rest_in) = remaining_in.split_at(this_size);
let (thread_out, rest_out) = remaining_out.split_at_mut(this_size);
remaining_in = rest_in;
remaining_out = rest_out;
let current_chunk_offset = chunk_offset;
chunk_offset += this_size.div_ceil(chunk_size);
let f_ref = &f;
scope.spawn(move || {
for (i, (in_chunk, out_chunk)) in thread_in
.chunks(chunk_size)
.zip(thread_out.chunks_mut(chunk_size))
.enumerate()
{
f_ref(current_chunk_offset + i, in_chunk, out_chunk);
}
});
}
});
}