mod embeddings;
use embeddings::*;
mod attention;
use attention::*;
mod encoder;
use encoder::*;
mod layer;
use layer::*;
mod output_layer;
use output_layer::*;
mod self_attention;
use self_attention::*;
mod self_output;
use self_output::*;
mod intermediate_layer;
use intermediate_layer::*;
pub mod qwen;
pub use qwen::QwenEmbeddingModel;
use fusor::{Device, Result, Tensor, VarBuilder};
use serde::Deserialize;
use std::fmt::Debug;
/// Activation function selected by the model configuration.
///
/// Deserialized from the lowercase variant name (`"gelu"` or `"relu"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum HiddenAct {
    /// Dispatches to `Tensor::gelu` in [`HiddenActLayer::forward`].
    Gelu,
    /// Dispatches to `Tensor::relu` in [`HiddenActLayer::forward`].
    Relu,
}
/// Applies a [`HiddenAct`] activation to a rank-3 `f32` tensor while a
/// tracing span is entered.
struct HiddenActLayer {
    /// Which activation to apply.
    act: HiddenAct,
    /// Span entered for the duration of each `forward` call.
    span: tracing::Span,
}
impl HiddenActLayer {
    /// Builds a layer for `act`, together with a TRACE-level span named
    /// `"hidden-act"`.
    fn new(act: HiddenAct) -> Self {
        Self {
            act,
            span: tracing::span!(tracing::Level::TRACE, "hidden-act"),
        }
    }

    /// Applies the configured activation to `xs` and returns the result.
    fn forward(&self, xs: &Tensor<3, f32>) -> Tensor<3, f32> {
        // The guard keeps the span entered until this function returns.
        let _guard = self.span.enter();
        let activation = self.act;
        match activation {
            HiddenAct::Relu => xs.relu(),
            HiddenAct::Gelu => xs.gelu(),
        }
    }
}
/// Kind of position embedding named in the configuration.
///
/// Deserialized from the lowercase variant name. `Absolute` is the only
/// variant and also the `Default`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
enum PositionEmbeddingType {
    /// Spelled `"absolute"` in serialized form.
    #[default]
    Absolute,
}
/// Deserializable model configuration passed to [`BertModel::load`].
///
/// Most fields are consumed by the loaders in sibling modules that are not
/// visible from this file; their meaning is presumably the one their names
/// suggest (BERT-style hyperparameters) — confirm against those loaders.
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
    vocab_size: usize,
    hidden_size: usize,
    num_hidden_layers: usize,
    num_attention_heads: usize,
    intermediate_size: usize,
    /// Activation function selector; see [`HiddenAct`].
    hidden_act: HiddenAct,
    max_position_embeddings: usize,
    type_vocab_size: usize,
    initializer_range: f64,
    layer_norm_eps: f64,
    pad_token_id: usize,
    /// Falls back to `PositionEmbeddingType::default()` when the key is absent.
    #[serde(default)]
    position_embedding_type: PositionEmbeddingType,
    /// Falls back to `false` when the key is absent.
    #[serde(default)]
    use_cache: bool,
    /// Optional model type name. [`BertModel::load`] uses it as a weight-name
    /// prefix when loading with un-prefixed names fails.
    model_type: Option<String>,
}
/// A loaded model: an embeddings stage followed by an encoder stage.
pub struct BertModel {
    /// Turns input ids and token type ids into the encoder input.
    embeddings: BertEmbeddings,
    /// Consumes the embeddings output and an optional attention mask.
    encoder: BertEncoder,
    /// Clone of the device handle passed to [`BertModel::load`].
    pub(crate) device: Device,
    /// Span entered for the duration of each `forward` call.
    span: tracing::Span,
}
impl BertModel {
    /// Loads the embeddings and the encoder from `vb`.
    ///
    /// Two naming layouts are tried:
    ///
    /// 1. Both parts are loaded directly from `vb`.
    /// 2. If either part fails and `config.model_type` is set, both parts are
    ///    loaded again from `vb.pp("{model_type}.embeddings")` and
    ///    `vb.pp("{model_type}.encoder")` (presumably `pp` scopes the builder
    ///    to a name prefix — confirm against the `VarBuilder` API).
    ///
    /// Within each attempt both loaders always run, even when the first one
    /// has already failed.
    ///
    /// # Errors
    ///
    /// Returns the error from the first attempt (the embeddings error if the
    /// embeddings failed, otherwise the encoder error) when `model_type` is
    /// `None` or the prefixed attempt fails as well. Errors from the prefixed
    /// attempt are discarded.
    pub fn load(device: &Device, vb: &mut VarBuilder, config: &Config) -> Result<Self> {
        // Attempt 1: un-prefixed names.
        let root_embeddings = BertEmbeddings::load(device, vb, config);
        let root_encoder = BertEncoder::load(device, vb, config);

        let (embeddings, encoder) = match (root_embeddings, root_encoder) {
            (Ok(embeddings), Ok(encoder)) => (embeddings, encoder),
            (Err(err), _) | (_, Err(err)) => {
                // Without a model type there is no alternative prefix to try.
                let model_type = match config.model_type.as_ref() {
                    Some(model_type) => model_type,
                    None => return Err(err),
                };

                // Attempt 2: names prefixed with the model type.
                let prefixed_embeddings = BertEmbeddings::load(
                    device,
                    &mut vb.pp(format!("{model_type}.embeddings")),
                    config,
                );
                let prefixed_encoder = BertEncoder::load(
                    device,
                    &mut vb.pp(format!("{model_type}.encoder")),
                    config,
                );

                match (prefixed_embeddings, prefixed_encoder) {
                    (Ok(embeddings), Ok(encoder)) => (embeddings, encoder),
                    // Report the original failure, not the retry's.
                    _ => return Err(err),
                }
            }
        };

        let device = device.clone();
        let span = tracing::span!(tracing::Level::TRACE, "model");
        Ok(Self {
            embeddings,
            encoder,
            device,
            span,
        })
    }

    /// Runs the embeddings stage on `input_ids` and `token_type_ids`, then
    /// feeds its output and `attention_mask` to the encoder, returning the
    /// encoder output.
    pub fn forward(
        &self,
        input_ids: &Tensor<2, u32>,
        token_type_ids: &Tensor<2, u32>,
        attention_mask: Option<&Tensor<2, u32>>,
    ) -> Tensor<3, f32> {
        // The guard keeps the "model" span entered until this function returns.
        let _guard = self.span.enter();
        let embedded = self.embeddings.forward(input_ids, token_type_ids);
        self.encoder.forward(&embedded, attention_mask)
    }

    /// Maximum sequence length, as reported by the embeddings stage.
    pub(crate) fn max_seq_len(&self) -> usize {
        self.embeddings.max_seq_len()
    }

    /// Embedding dimension, as reported by the embeddings stage.
    pub(crate) fn embedding_dim(&self) -> usize {
        self.embeddings.embedding_dim()
    }
}