* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This software is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
use std::fs::{File, OpenOptions};
use std::io::{BufReader, BufWriter, Read, Seek, SeekFrom, Write};
use std::path::Path;
use super::RecorderError;
use crate::types::xgpu_recorder::v1::{RecordBlock, RecorderFooter, RecorderHeader};
use prost::Message;
/// Tag byte that prefixes every record in a recorder file.
///
/// The discriminant is the exact value written to disk, so the numbers must
/// never be changed or reused.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MessageType {
    /// Record carrying a `RecorderHeader`.
    Header = 1,
    /// Record carrying a `RecordBlock`.
    Block = 2,
    /// Record carrying a `RecorderFooter`; followed on disk by a trailing
    /// big-endian `u32` holding the total footer record length.
    Footer = 3,
}
/// Output side of the recorder file format.
///
/// Implementors must be `Send + Sync`. `ProtobufWriter` in this module is one
/// implementation; it emits each `write_*` call as a separately framed record.
pub trait RecordWriter: Send + Sync {
/// Writes the recording header record.
/// NOTE(review): presumably expected to be the first record — confirm with callers.
fn write_header(&mut self, header: &RecorderHeader) -> Result<(), RecorderError>;
/// Writes one data block record.
fn write_block(&mut self, block: &RecordBlock) -> Result<(), RecorderError>;
/// Writes the footer record.
/// NOTE(review): presumably expected to be the last record — confirm with callers.
fn write_footer(&mut self, footer: &RecorderFooter) -> Result<(), RecorderError>;
/// Pushes any buffered output to the underlying sink.
fn flush(&mut self) -> Result<(), RecorderError>;
/// Returns the current write position (a byte offset for `ProtobufWriter`).
fn position(&mut self) -> Result<u64, RecorderError>;
/// Asks the underlying storage to persist written data
/// (`ProtobufWriter` calls `File::sync_all`).
fn sync(&mut self) -> Result<(), RecorderError>;
}
/// Input side of the recorder file format.
///
/// Implementors must be `Send + Sync`. `ProtobufReader` in this module is one
/// implementation reading the framing produced by `ProtobufWriter`.
pub trait RecordReader: Send + Sync {
/// Reads the next record and requires it to be the header.
fn read_header(&mut self) -> Result<RecorderHeader, RecorderError>;
/// Reads the next data block; `Ok(None)` means no further blocks
/// (for `ProtobufReader`: the footer or end of file was reached).
fn read_block(&mut self) -> Result<Option<RecordBlock>, RecorderError>;
/// Reads the next record and requires it to be the footer.
fn read_footer(&mut self) -> Result<RecorderFooter, RecorderError>;
/// Reads the footer by locating it from the end of the file rather than
/// from the current position. `ProtobufReader` restores the previous read
/// position after a successful read.
fn read_footer_from_end(&mut self) -> Result<RecorderFooter, RecorderError>;
/// Seeks to the absolute byte offset `pos` and returns the new position.
fn seek(&mut self, pos: u64) -> Result<u64, RecorderError>;
/// Returns the current read position as a byte offset.
fn position(&mut self) -> Result<u64, RecorderError>;
/// Returns the total size of the underlying file in bytes.
fn file_size(&mut self) -> Result<u64, RecorderError>;
}
/// Writes recorder records to a file as
/// `[type: u8][len: u32 big-endian][protobuf payload]` frames
/// (footer frames carry an extra trailing big-endian `u32`).
pub struct ProtobufWriter {
/// Buffered handle to the output file; bytes reach the file only when the
/// buffer fills or is flushed.
pub writer: BufWriter<File>,
}
impl ProtobufWriter {
    /// Size of the framing that surrounds a footer payload on disk:
    /// 1-byte type tag + 4-byte length prefix + 4-byte trailing record length.
    const FOOTER_FRAMING_LEN: u32 = 1 + 4 + 4;

    /// Creates (or truncates) the file at `path` and wraps it in a `BufWriter`.
    ///
    /// # Errors
    /// Returns an error if the file cannot be opened for writing.
    pub fn new<P: AsRef<Path>>(path: P) -> Result<Self, RecorderError> {
        let file = OpenOptions::new()
            .create(true)
            .write(true)
            .truncate(true)
            .open(path)?;
        Ok(Self {
            writer: BufWriter::new(file),
        })
    }

    /// Encodes `msg` and writes it as one framed record:
    /// `[msg_type: u8][len: u32 BE][payload]`.
    ///
    /// Footer records get an extra trailing `u32` BE holding the total record
    /// size (framing + payload), so a reader can find the footer by seeking
    /// backwards from the end of the file.
    ///
    /// # Errors
    /// - `ProtobufEncoding` if protobuf encoding fails.
    /// - `FormatError` if the payload does not fit the 4-byte length prefix
    ///   (or the footer trailer would overflow `u32`); nothing is written then.
    /// - `Io` if writing to the buffered file fails.
    fn write_message<M>(&mut self, msg: &M, msg_type: MessageType) -> Result<(), RecorderError>
    where
        M: prost::Message,
    {
        let mut buf = Vec::with_capacity(msg.encoded_len());
        msg.encode(&mut buf)
            .map_err(RecorderError::ProtobufEncoding)?;

        // The length prefix is only 4 bytes; a plain `as u32` cast would silently
        // truncate and leave a corrupt, unreadable stream behind.
        if buf.len() > u32::MAX as usize {
            return Err(RecorderError::FormatError {
                message: format!("Message too large to encode: {} bytes", buf.len()),
            });
        }
        let len = buf.len() as u32;

        // Compute the footer trailer up front so an overflow is reported before
        // any bytes of a half-record are written.
        let footer_trailer = if msg_type == MessageType::Footer {
            let total = len
                .checked_add(Self::FOOTER_FRAMING_LEN)
                .ok_or_else(|| RecorderError::FormatError {
                    message: format!("Footer too large to encode: {} bytes", len),
                })?;
            Some(total)
        } else {
            None
        };

        self.writer.write_all(&[msg_type as u8])?;
        self.writer.write_all(&len.to_be_bytes())?;
        self.writer.write_all(&buf)?;
        if let Some(total) = footer_trailer {
            self.writer.write_all(&total.to_be_bytes())?;
        }
        Ok(())
    }
}
impl RecordWriter for ProtobufWriter {
    /// Writes the header as a `MessageType::Header` record.
    fn write_header(&mut self, header: &RecorderHeader) -> Result<(), RecorderError> {
        self.write_message(header, MessageType::Header)
    }

    /// Writes one block as a `MessageType::Block` record.
    fn write_block(&mut self, block: &RecordBlock) -> Result<(), RecorderError> {
        self.write_message(block, MessageType::Block)
    }

    /// Writes the footer as a `MessageType::Footer` record plus its trailing length.
    fn write_footer(&mut self, footer: &RecorderFooter) -> Result<(), RecorderError> {
        self.write_message(footer, MessageType::Footer)
    }

    /// Flushes the `BufWriter` into the underlying file.
    fn flush(&mut self) -> Result<(), RecorderError> {
        self.writer.flush().map_err(RecorderError::Io)
    }

    /// Returns the current byte offset in the output file.
    /// `BufWriter`'s `Seek` impl writes out its buffer before reporting the position.
    fn position(&mut self) -> Result<u64, RecorderError> {
        self.writer.stream_position().map_err(RecorderError::Io)
    }

    /// Makes everything written so far durable on disk.
    fn sync(&mut self) -> Result<(), RecorderError> {
        // `sync_all` only persists bytes the `File` already has. Anything still
        // sitting in the `BufWriter` must be pushed down first, or it is silently
        // excluded from the sync.
        self.writer.flush().map_err(RecorderError::Io)?;
        self.writer.get_mut().sync_all().map_err(RecorderError::Io)
    }
}
/// Reads recorder records framed as
/// `[type: u8][len: u32 big-endian][protobuf payload]` from a file.
pub struct ProtobufReader {
/// Buffered handle to the input file.
reader: BufReader<File>,
}
impl ProtobufReader {
    /// Upper bound on a single record payload. A corrupt length prefix must not
    /// be able to trigger a multi-gigabyte allocation.
    const MAX_MESSAGE_LEN: usize = 100 * 1024 * 1024;

    /// Opens the file at `path` for buffered reading.
    ///
    /// # Errors
    /// Returns an error if the file cannot be opened.
    pub fn new<P: AsRef<Path>>(path: P) -> Result<Self, RecorderError> {
        let file = File::open(path)?;
        Ok(Self {
            reader: BufReader::new(file),
        })
    }

    /// Maps an on-disk type tag to its `MessageType`.
    fn parse_message_type(tag: u8) -> Result<MessageType, RecorderError> {
        match tag {
            1 => Ok(MessageType::Header),
            2 => Ok(MessageType::Block),
            3 => Ok(MessageType::Footer),
            t => Err(RecorderError::FormatError {
                message: format!("Unknown message type: {}", t),
            }),
        }
    }

    /// Reads the next record and decodes its payload as `M`, whatever its tag.
    ///
    /// Returns `Ok(None)` on a clean end of file at a record boundary. The caller
    /// is responsible for checking that the returned tag is the expected one.
    fn read_message<M>(&mut self) -> Result<Option<(MessageType, M)>, RecorderError>
    where
        M: Message + Default,
    {
        match self.read_any_message()? {
            Some((msg_type, buf)) => {
                let msg = M::decode(&buf[..]).map_err(RecorderError::ProtobufDecoding)?;
                Ok(Some((msg_type, msg)))
            }
            None => Ok(None),
        }
    }

    /// Reads the next record's tag and raw payload without decoding it.
    ///
    /// Returns `Ok(None)` if the file ends exactly at a record boundary. EOF in
    /// the middle of a record (truncated file) surfaces as an I/O error.
    ///
    /// # Errors
    /// - `FormatError` for an unknown tag or a payload over `MAX_MESSAGE_LEN`.
    /// - `Io` for read failures, including truncated records.
    fn read_any_message(&mut self) -> Result<Option<(MessageType, Vec<u8>)>, RecorderError> {
        let mut type_buf = [0u8; 1];
        match self.reader.read_exact(&mut type_buf) {
            Ok(()) => {}
            // No tag byte at all: the stream ended cleanly.
            Err(ref e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
            Err(e) => return Err(RecorderError::Io(e)),
        }
        let msg_type = Self::parse_message_type(type_buf[0])?;

        let mut len_buf = [0u8; 4];
        self.reader.read_exact(&mut len_buf)?;
        let len = u32::from_be_bytes(len_buf) as usize;
        if len > Self::MAX_MESSAGE_LEN {
            return Err(RecorderError::FormatError {
                message: format!("Message too large: {} bytes", len),
            });
        }

        let mut buf = vec![0u8; len];
        self.reader.read_exact(&mut buf)?;
        Ok(Some((msg_type, buf)))
    }
}
impl RecordReader for ProtobufReader {
    /// Reads the next record and requires it to be the header.
    fn read_header(&mut self) -> Result<RecorderHeader, RecorderError> {
        match self.read_message::<RecorderHeader>()? {
            Some((MessageType::Header, header)) => Ok(header),
            Some((msg_type, _)) => Err(RecorderError::FormatError {
                message: format!("Expected header, found {:?}", msg_type),
            }),
            None => Err(RecorderError::FormatError {
                message: "Missing file header".to_string(),
            }),
        }
    }

    /// Reads the next block. Returns `Ok(None)` at end of file, or when a footer
    /// record is encountered (the footer record is consumed in that case).
    fn read_block(&mut self) -> Result<Option<RecordBlock>, RecorderError> {
        match self.read_any_message()? {
            Some((MessageType::Block, buf)) => {
                let block = RecordBlock::decode(&buf[..])
                    .map_err(|e| -> RecorderError { RecorderError::ProtobufDecoding(e) })?;
                Ok(Some(block))
            }
            Some((MessageType::Header, _)) => Err(RecorderError::FormatError {
                message: "Unexpected header in block stream".to_string(),
            }),
            Some((MessageType::Footer, _)) | None => Ok(None),
        }
    }

    /// Reads the next record and requires it to be the footer.
    fn read_footer(&mut self) -> Result<RecorderFooter, RecorderError> {
        match self.read_message::<RecorderFooter>()? {
            Some((MessageType::Footer, footer)) => Ok(footer),
            Some((msg_type, _)) => Err(RecorderError::FormatError {
                message: format!("Expected footer, found {:?}", msg_type),
            }),
            None => Err(RecorderError::FormatError {
                message: "Missing file footer".to_string(),
            }),
        }
    }

    /// Locates the footer using the trailing big-endian `u32` at the end of the
    /// file (total footer record length), decodes it, and restores the previous
    /// read position. The position is restored on both success and failure.
    ///
    /// # Errors
    /// `FormatError` if the trailing length cannot describe a footer record in
    /// this file (e.g. the recording was never finalized), plus any I/O or
    /// decoding error from reading the footer itself.
    fn read_footer_from_end(&mut self) -> Result<RecorderFooter, RecorderError> {
        // Smallest footer record: empty payload + tag(1) + len(4) + trailer(4).
        const MIN_FOOTER_RECORD_LEN: u64 = 1 + 4 + 4;

        // Seeks to the footer and decodes it; leaves the cursor inside the footer region.
        fn locate_and_read(this: &mut ProtobufReader) -> Result<RecorderFooter, RecorderError> {
            let trailer_pos = this.reader.seek(SeekFrom::End(-4))?;
            let mut len_buf = [0u8; 4];
            this.reader.read_exact(&mut len_buf)?;
            let footer_len = u64::from(u32::from_be_bytes(len_buf));
            let file_size = trailer_pos + 4;
            // Reject lengths that cannot be a real footer instead of seeking to a
            // bogus offset and reporting a confusing decode/seek error.
            if footer_len < MIN_FOOTER_RECORD_LEN || footer_len > file_size {
                return Err(RecorderError::FormatError {
                    message: format!(
                        "Invalid footer length {} for file of {} bytes",
                        footer_len, file_size
                    ),
                });
            }
            this.reader.seek(SeekFrom::Start(file_size - footer_len))?;
            this.read_footer()
        }

        let current_pos = self.position()?;
        let result = locate_and_read(self);
        let restored = self.reader.seek(SeekFrom::Start(current_pos));
        // Prefer reporting the footer error over a secondary restore failure.
        let footer = result?;
        restored?;
        Ok(footer)
    }

    /// Seeks to absolute byte offset `pos`; returns the new position.
    fn seek(&mut self, pos: u64) -> Result<u64, RecorderError> {
        self.reader
            .seek(SeekFrom::Start(pos))
            .map_err(RecorderError::Io)
    }

    /// Returns the current logical read offset (accounts for buffered bytes).
    fn position(&mut self) -> Result<u64, RecorderError> {
        self.reader.stream_position().map_err(RecorderError::Io)
    }

    /// Returns the file size in bytes.
    fn file_size(&mut self) -> Result<u64, RecorderError> {
        // Ask the OS via metadata rather than seeking to the end and back. This
        // keeps the BufReader buffer intact and cannot lose the position on error.
        Ok(self.reader.get_ref().metadata()?.len())
    }
}