use std::error::Error;
use fusor_gguf::GgufValue;
use nom::{
AsChar, IResult, Input, Parser as _,
branch::alt,
bytes::complete::tag,
character::complete::{char, digit1, multispace0, one_of},
combinator::{map, opt, recognize},
error::ErrorKind,
multi::separated_list0,
sequence::delimited,
};
pub(crate) fn parse_key_val<T>(
s: &str,
) -> Result<(T, GgufValue), Box<dyn Error + Send + Sync + 'static>>
where
T: std::str::FromStr,
T::Err: Error + Send + Sync + 'static,
{
let pos = s
.find('=')
.ok_or_else(|| format!("invalid KEY=value: no `=` found in `{s}`"))?;
Ok((
s[..pos].parse()?,
parse_gguf_value(&s[pos + 1..])
.map_err(|e| format!("invalid KEY=value: {e}"))?
.1,
))
}
fn parse_int(input: &str) -> IResult<&str, GgufValue> {
let (input, (num_str, suffix)) = (
recognize((opt(one_of("+-")), digit1)),
alt((
tag("u8"),
tag("i8"),
tag("u16"),
tag("i16"),
tag("u32"),
tag("i32"),
tag("u64"),
tag("i64"),
)),
)
.parse(input)?;
let value = match suffix {
"u8" => GgufValue::U8(num_str.parse().unwrap()),
"i8" => GgufValue::I8(num_str.parse().unwrap()),
"u16" => GgufValue::U16(num_str.parse().unwrap()),
"i16" => GgufValue::I16(num_str.parse().unwrap()),
"u32" => GgufValue::U32(num_str.parse().unwrap()),
"i32" => GgufValue::I32(num_str.parse().unwrap()),
"u64" => GgufValue::U64(num_str.parse().unwrap()),
"i64" => GgufValue::I64(num_str.parse().unwrap()),
_ => unreachable!(),
};
Ok((input, value))
}
fn parse_float(input: &str) -> IResult<&str, GgufValue> {
let (input, (num_str, suffix)) = (
recognize((opt(one_of("+-")), digit1, opt((char('.'), digit1)))),
opt(alt((tag("f32"), tag("f64")))),
)
.parse(input)?;
let value = match suffix {
Some("f32") => GgufValue::F32(num_str.parse().unwrap()),
Some("f64") => GgufValue::F64(num_str.parse().unwrap()),
None => {
if num_str.contains('.') {
GgufValue::F64(num_str.parse().unwrap())
} else {
GgufValue::I32(num_str.parse().unwrap())
}
}
_ => unreachable!(),
};
Ok((input, value))
}
fn parse_bool(input: &str) -> IResult<&str, GgufValue> {
alt((
map(tag("true"), |_| GgufValue::Bool(true)),
map(tag("false"), |_| GgufValue::Bool(false)),
))
.parse(input)
}
fn parse_string(input: &str) -> IResult<&str, GgufValue> {
let (input, s) = input.split_at_position1_complete(
|item| !item.is_alphanum() && item != '.',
ErrorKind::AlphaNumeric,
)?;
Ok((input, GgufValue::String(s.to_string().into_boxed_str())))
}
fn parse_array(input: &str) -> IResult<&str, GgufValue> {
let (input, elems) = delimited(
char('['),
separated_list0(
delimited(multispace0, char(','), multispace0),
parse_gguf_value,
),
char(']'),
)
.parse(input)?;
Ok((input, GgufValue::Array(elems.into_boxed_slice())))
}
fn parse_gguf_value(input: &str) -> IResult<&str, GgufValue> {
delimited(
multispace0,
alt((
parse_array,
parse_bool,
parse_int,
parse_float,
parse_string,
)),
multispace0,
)
.parse(input)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_parse_values() {
assert_eq!(parse_gguf_value("0").unwrap().1, GgufValue::I32(0));
assert_eq!(parse_gguf_value("-100").unwrap().1, GgufValue::I32(-100));
assert_eq!(parse_gguf_value("0u8").unwrap().1, GgufValue::U8(0));
assert_eq!(parse_gguf_value("1.1").unwrap().1, GgufValue::F64(1.1));
assert_eq!(parse_gguf_value("1.1f32").unwrap().1, GgufValue::F32(1.1));
assert_eq!(parse_gguf_value("true").unwrap().1, GgufValue::Bool(true));
assert_eq!(
parse_gguf_value("hello").unwrap().1,
GgufValue::String("hello".into())
);
assert_eq!(
parse_gguf_value("hello.world").unwrap().1,
GgufValue::String("hello.world".into())
);
if let GgufValue::Array(vals) = parse_gguf_value("[hello, world]").unwrap().1 {
assert_eq!(vals[0], GgufValue::String("hello".into()));
assert_eq!(vals[1], GgufValue::String("world".into()));
} else {
panic!("Expected array");
}
}
}