use candle_core::Tensor;
/// A pair of [`TensorCache`]s, one for keys (`k`) and one for values (`v`).
///
/// Both halves are created with the same `concat_dim` / `max_seq_len` and are
/// appended to and reset together by the methods on this type.
#[derive(Debug, Clone)]
pub struct KvCache {
    /// Cache holding the key tensors.
    key: TensorCache,
    /// Cache holding the value tensors.
    value: TensorCache,
}
impl KvCache {
    /// Builds an empty key/value cache pair.
    ///
    /// Both halves concatenate along `concat_dim` and retain at most `max_seq_len`
    /// entries along that dimension.
    pub fn new(concat_dim: usize, max_seq_len: usize) -> Self {
        let make_cache = || TensorCache::new(concat_dim, max_seq_len);
        Self {
            key: make_cache(),
            value: make_cache(),
        }
    }

    /// Returns the valid cached keys.
    ///
    /// Returns `None` when the key cache has no buffer, e.g. before the first
    /// append or after [`KvCache::reset`].
    pub fn k(&self) -> candle_core::Result<Option<Tensor>> {
        self.k_cache().current_data()
    }

    /// Borrows the underlying key cache.
    pub fn k_cache(&self) -> &TensorCache {
        &self.key
    }

    /// Mutably borrows the underlying key cache.
    pub fn k_cache_mut(&mut self) -> &mut TensorCache {
        &mut self.key
    }

    /// Returns the valid cached values.
    ///
    /// Returns `None` when the value cache has no buffer, e.g. before the first
    /// append or after [`KvCache::reset`].
    pub fn v(&self) -> candle_core::Result<Option<Tensor>> {
        self.v_cache().current_data()
    }

    /// Borrows the underlying value cache.
    pub fn v_cache(&self) -> &TensorCache {
        &self.value
    }

    /// Mutably borrows the underlying value cache.
    pub fn v_cache_mut(&mut self) -> &mut TensorCache {
        &mut self.value
    }

    /// Clears both the key cache and the value cache.
    pub fn reset(&mut self) {
        self.k_cache_mut().reset();
        self.v_cache_mut().reset();
    }

    /// Appends `k` to the key cache, then `v` to the value cache.
    ///
    /// Returns `(keys, values)` as produced by each cache's `append`.
    ///
    /// # Errors
    /// If the key append fails, nothing is modified. If the value append fails, the
    /// key cache has already been updated, so the two halves are out of step.
    pub fn append(&mut self, k: &Tensor, v: &Tensor) -> candle_core::Result<(Tensor, Tensor)> {
        let keys = self.key.append(k)?;
        let values = self.value.append(v)?;
        Ok((keys, values))
    }
}
/// Growable buffer that accumulates tensors along `concat_dim`.
///
/// At most `max_seq_len` entries are retained along that dimension. Once the cap is
/// exceeded, entries at the front are dropped.
///
/// NOTE(review): `append` writes into the existing buffer in place via `slice_set`.
/// candle `Tensor` clones share storage, so a clone of this cache presumably aliases
/// the same buffer, and appending to two clones independently may overwrite each
/// other's entries. Verify this before using `Clone` to fork a cache.
#[derive(Debug, Clone)]
pub struct TensorCache {
    /// Backing buffer, or `None` before the first allocation or after `reset`.
    /// It may be longer than the valid data: growth pads it with zeros.
    all_data: Option<Tensor>,
    /// Index along `concat_dim` where the valid data starts inside `all_data`
    /// (0 in every visible code path).
    start_offset: usize,
    /// Number of valid entries along `concat_dim`.
    current_seq_len: usize,
    /// Length of `all_data` along `concat_dim` (valid entries plus zero padding).
    allocated_seq_len: usize,
    /// Dimension along which appended tensors are concatenated.
    concat_dim: usize,
    /// Maximum number of entries retained along `concat_dim`.
    max_seq_len: usize,
}
impl TensorCache {
    /// Creates an empty cache.
    ///
    /// The cache concatenates along `concat_dim` and retains at most `max_seq_len`
    /// entries along that dimension. No buffer is allocated until an append needs one.
    pub fn new(concat_dim: usize, max_seq_len: usize) -> Self {
        Self {
            all_data: None,
            start_offset: 0,
            current_seq_len: 0,
            allocated_seq_len: 0,
            concat_dim,
            max_seq_len,
        }
    }

    /// Returns the raw backing buffer.
    ///
    /// The buffer may extend past the valid entries into zero-filled space reserved
    /// by preallocation. Use [`TensorCache::current_data`] to get only the valid entries.
    pub fn all_data(&self) -> &Option<Tensor> {
        &self.all_data
    }

    /// Returns a contiguous tensor holding only the valid cached entries.
    ///
    /// Returns `None` when no buffer is allocated (before the first append or after
    /// [`TensorCache::reset`]).
    ///
    /// # Errors
    /// Propagates errors from `narrow` / `contiguous`.
    pub fn current_data(&self) -> candle_core::Result<Option<Tensor>> {
        let data = match self.all_data.as_ref() {
            None => None,
            Some(d) => Some(
                d.narrow(self.concat_dim, self.start_offset, self.current_seq_len)?
                    .contiguous()?,
            ),
        };
        Ok(data)
    }

    /// Drops the backing buffer and clears all bookkeeping.
    ///
    /// The next non-empty append reallocates the buffer.
    pub fn reset(&mut self) {
        self.all_data = None;
        self.current_seq_len = 0;
        self.allocated_seq_len = 0;
        self.start_offset = 0;
    }

    /// Appends `v` along `concat_dim` and returns a contiguous tensor of the data to use.
    ///
    /// * **Within the cap.** If the total length stays within `max_seq_len`, `v` is
    ///   written in place into the buffer. The buffer first grows to the next power of
    ///   two (capped at `max_seq_len`) if needed. The result is the valid cached data:
    ///   previous entries followed by `v`.
    /// * **Over the cap.** The oldest entries are evicted so the cache keeps the most
    ///   recent `max_seq_len` entries. The result is the surviving old entries followed
    ///   by all of `v`. It is longer than `max_seq_len` only when `v` alone is.
    ///
    /// # Errors
    /// Propagates candle errors. Examples: `concat_dim` is out of range for `v`, or `v`
    /// does not match the cached tensor's other dimensions, dtype or device.
    pub fn append(&mut self, v: &Tensor) -> candle_core::Result<Tensor> {
        let v = v.contiguous()?;
        let seq_len = v.dim(self.concat_dim)?;
        let size_required_for_append = self.current_seq_len + seq_len;
        let data = if size_required_for_append > self.max_seq_len {
            self.append_evicting(&v, size_required_for_append)?
        } else {
            self.append_in_place(&v, size_required_for_append)?
        };
        data.contiguous()
    }

    /// Over-capacity path: drops the oldest valid entries and rebuilds the buffer.
    ///
    /// Precondition: `size_required = current_seq_len + len(v) > max_seq_len`.
    fn append_evicting(&mut self, v: &Tensor, size_required: usize) -> candle_core::Result<Tensor> {
        let max_seq_len = self.max_seq_len;
        // How many of the oldest entries must go so that at most `max_seq_len` remain.
        let evict = size_required - max_seq_len;
        let mut tensors = Vec::with_capacity(2);
        if let Some(all_data) = self.all_data.as_ref() {
            // Only [start_offset, start_offset + current_seq_len) holds real data. The rest
            // of the buffer is zero padding from preallocation and must never be kept.
            // When `v` alone exceeds the cap, every old entry is evicted (keep == 0).
            let keep = self.current_seq_len.saturating_sub(evict);
            if keep > 0 {
                tensors.push(all_data.narrow(
                    self.concat_dim,
                    self.start_offset + evict,
                    keep,
                )?);
            }
        }
        tensors.push(v.clone());
        let all_data = Tensor::cat(&tensors, self.concat_dim)?;
        // all_data_len == max_seq_len, unless `v` alone is longer than the cap.
        let all_data_len = all_data.dim(self.concat_dim)?;
        // Store a contiguous window, so later in-place writes via `slice_set` are
        // always done on a contiguous buffer.
        let window = all_data
            .narrow(self.concat_dim, all_data_len - max_seq_len, max_seq_len)?
            .contiguous()?;
        self.all_data = Some(window);
        self.start_offset = 0;
        self.current_seq_len = max_seq_len;
        self.allocated_seq_len = max_seq_len;
        // Return every entry of `v` (not just the stored window), so the caller gets a
        // result covering the whole appended chunk.
        Ok(all_data)
    }

    /// Within-capacity path: writes `v` after the valid data, growing the buffer first if
    /// it is too small.
    ///
    /// Precondition: `size_required = current_seq_len + len(v) <= max_seq_len`.
    fn append_in_place(&mut self, v: &Tensor, size_required: usize) -> candle_core::Result<Tensor> {
        let current_allocated_size = self.allocated_seq_len;
        if size_required > current_allocated_size {
            // Grow geometrically to amortize reallocations, but never past the cap.
            let new_cache_max_seq_len = size_required.next_power_of_two().min(self.max_seq_len);
            tracing::trace!(
                "Extending Tensor cache from {current_allocated_size} to {new_cache_max_seq_len}"
            );
            let mut tensors = Vec::with_capacity(2);
            if let Some(existing) = self.all_data.as_ref() {
                tensors.push(existing.clone());
            }
            // Zero padding reserving room for this append and future ones.
            let mut shape = v.shape().dims().to_vec();
            shape[self.concat_dim] = new_cache_max_seq_len - current_allocated_size;
            tensors.push(Tensor::zeros(shape.as_slice(), v.dtype(), v.device())?);
            self.all_data = Some(Tensor::cat(&tensors, self.concat_dim)?);
            self.allocated_seq_len = new_cache_max_seq_len;
        }
        let buffer = match self.all_data.as_ref() {
            Some(buffer) => buffer,
            // Only reachable when `v` is empty and nothing has been allocated yet. There is
            // nothing to store, and the (empty) input is the whole result.
            None => return Ok(v.clone()),
        };
        buffer.slice_set(v, self.concat_dim, self.current_seq_len)?;
        self.current_seq_len = size_required;
        buffer.narrow(self.concat_dim, self.start_offset, self.current_seq_len)
    }
}