// SPDX-License-Identifier: Mulan PSL v2
/*
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This software is licensed under Mulan PSL v2.
 * You can use this software according to the terms and conditions of the Mulan PSL v2.
 * You may obtain a copy of Mulan PSL v2 at:
 *         http://license.coscl.org.cn/MulanPSL2
 *
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
 * See the Mulan PSL v2 for more details.
 */

//! File format and serialization support

use std::fs::{File, OpenOptions};
use std::io::{BufReader, BufWriter, Read, Seek, SeekFrom, Write};
use std::path::Path;

use super::RecorderError;
use crate::types::xgpu_recorder::v1::{RecordBlock, RecorderFooter, RecorderHeader};
use prost::Message;

/// Tag byte that precedes every frame in a recording file.
///
/// The explicit discriminants are the values written to disk (`as u8`), so
/// existing variants must keep their numbers.
#[derive(Debug, Clone, Copy, PartialEq)]
#[repr(u8)]
pub enum MessageType {
    /// File header frame.
    Header = 1,
    /// Record block frame.
    Block = 2,
    /// File footer frame.
    Footer = 3,
}
/// Sink for the parts of a recording file: a header, record blocks, and a footer.
///
/// Implementors must be `Send + Sync` (supertrait bound).
/// NOTE(review): the header -> blocks -> footer ordering is implied by the
/// method set but not enforced by this trait; confirm with callers.
pub trait RecordWriter: Send + Sync {
    /// Write the file header.
    fn write_header(&mut self, header: &RecorderHeader) -> Result<(), RecorderError>;

    /// Write one record block.
    fn write_block(&mut self, block: &RecordBlock) -> Result<(), RecorderError>;

    /// Write the file footer.
    fn write_footer(&mut self, footer: &RecorderFooter) -> Result<(), RecorderError>;

    /// Flush any buffered output to the underlying sink.
    ///
    /// Use [`sync`](RecordWriter::sync) when the data must also be durable.
    fn flush(&mut self) -> Result<(), RecorderError>;

    /// Return the current write offset in bytes.
    ///
    /// Takes `&mut self` because an implementation may need to seek (and
    /// possibly flush) to answer.
    fn position(&mut self) -> Result<u64, RecorderError>;

    /// Persist written data to durable storage.
    ///
    /// NOTE(review): callers presumably expect data still sitting in a write
    /// buffer to be included; confirm each implementation flushes before syncing.
    fn sync(&mut self) -> Result<(), RecorderError>;
}

/// Source for the parts of a recording file produced by a [`RecordWriter`].
///
/// Implementors must be `Send + Sync` (supertrait bound).
pub trait RecordReader: Send + Sync {
    /// Read the file header.
    fn read_header(&mut self) -> Result<RecorderHeader, RecorderError>;

    /// Read the next record block.
    ///
    /// Returns `Ok(None)` when no further blocks are available.
    fn read_block(&mut self) -> Result<Option<RecordBlock>, RecorderError>;

    /// Read the file footer from the current read position.
    fn read_footer(&mut self) -> Result<RecorderFooter, RecorderError>;

    /// Read the file footer by locating it from the end of the file, without
    /// first reading through the header and blocks.
    ///
    /// NOTE(review): whether the current read position is preserved is up to
    /// the implementation; confirm before interleaving with `read_block`.
    fn read_footer_from_end(&mut self) -> Result<RecorderFooter, RecorderError>;

    /// Seek to the absolute byte offset `pos`, returning the resulting position.
    fn seek(&mut self, pos: u64) -> Result<u64, RecorderError>;

    /// Return the current read offset in bytes.
    fn position(&mut self) -> Result<u64, RecorderError>;

    /// Return the total size of the underlying file in bytes.
    fn file_size(&mut self) -> Result<u64, RecorderError>;
}

/// Protobuf format writer.
///
/// Emits frames of the form `[type: u8][len: u32 BE][body: len bytes]`.
/// Footer frames are followed by an extra `u32 BE` trailer holding the size of
/// the whole footer frame, so a reader can locate the footer from the end of
/// the file.
pub struct ProtobufWriter {
    pub writer: BufWriter<File>,
}

impl ProtobufWriter {
    /// Create `path` (truncating any existing file) and wrap it in a `BufWriter`.
    ///
    /// # Errors
    /// Returns the I/O error from opening the file, converted into `RecorderError`.
    pub fn new<P: AsRef<Path>>(path: P) -> Result<Self, RecorderError> {
        let file = OpenOptions::new()
            .create(true)
            .write(true)
            .truncate(true)
            .open(path)?;

        Ok(Self {
            writer: BufWriter::new(file),
        })
    }

    /// Write one length-prefixed Protobuf message preceded by its type tag.
    ///
    /// # Errors
    /// - `RecorderError::ProtobufEncoding` if prost fails to encode `msg`.
    /// - `RecorderError::FormatError` if the encoded body, or the footer frame
    ///   size, does not fit in a 4-byte length field. Nothing is written in
    ///   that case.
    /// - Any I/O error from the underlying writer (converted via `?`).
    ///
    /// NOTE(review): `ProtobufReader` in this module rejects bodies larger than
    /// 100 MiB; this writer does not enforce that limit, so an oversized block
    /// would be written but could not be read back.
    fn write_message<M>(&mut self, msg: &M, msg_type: MessageType) -> Result<(), RecorderError>
    where
        M: prost::Message,
    {
        let mut buf = Vec::with_capacity(msg.encoded_len());
        msg.encode(&mut buf)
            .map_err(RecorderError::ProtobufEncoding)?;

        // The length prefix is only 4 bytes: refuse instead of silently
        // truncating with `as u32`, which would corrupt the frame.
        let len = u32::try_from(buf.len()).map_err(|_| RecorderError::FormatError {
            message: format!("Message too large to encode: {} bytes", buf.len()),
        })?;

        // Footer trailer = size of the whole footer frame: type byte + length
        // prefix + body + the trailer itself. Computed before anything is
        // written so an overflow cannot leave a half-written frame behind.
        let trailer = if msg_type == MessageType::Footer {
            let frame_len = len
                .checked_add(1 + 4 + 4)
                .ok_or_else(|| RecorderError::FormatError {
                    message: format!("Footer too large to encode: {} bytes", len),
                })?;
            Some(frame_len)
        } else {
            None
        };

        self.writer.write_all(&[msg_type as u8])?;
        self.writer.write_all(&len.to_be_bytes())?;
        self.writer.write_all(&buf)?;
        if let Some(frame_len) = trailer {
            self.writer.write_all(&frame_len.to_be_bytes())?;
        }

        Ok(())
    }
}

impl RecordWriter for ProtobufWriter {
    /// Write the header as a `MessageType::Header` frame.
    fn write_header(&mut self, header: &RecorderHeader) -> Result<(), RecorderError> {
        self.write_message(header, MessageType::Header)
    }

    /// Write one block as a `MessageType::Block` frame.
    fn write_block(&mut self, block: &RecordBlock) -> Result<(), RecorderError> {
        self.write_message(block, MessageType::Block)
    }

    /// Write the footer as a `MessageType::Footer` frame plus its length trailer.
    fn write_footer(&mut self, footer: &RecorderFooter) -> Result<(), RecorderError> {
        self.write_message(footer, MessageType::Footer)
    }

    /// Push buffered bytes to the file (not necessarily to disk; see `sync`).
    fn flush(&mut self) -> Result<(), RecorderError> {
        self.writer.flush().map_err(RecorderError::Io)
    }

    /// Current write offset in bytes.
    ///
    /// `BufWriter`'s `Seek` implementation writes out its internal buffer
    /// before seeking, so this call presumably also flushes buffered bytes.
    fn position(&mut self) -> Result<u64, RecorderError> {
        self.writer.stream_position().map_err(RecorderError::Io)
    }

    /// Flush the buffer, then ask the OS to persist the file's data and metadata.
    fn sync(&mut self) -> Result<(), RecorderError> {
        // `sync_all` only covers bytes the `File` has already received; anything
        // still sitting in the `BufWriter` would otherwise be left out.
        self.writer.flush().map_err(RecorderError::Io)?;
        self.writer.get_ref().sync_all().map_err(RecorderError::Io)
    }
}

/// Protobuf format reader.
///
/// Reads frames of the form `[type: u8][len: u32 BE][body: len bytes]`.
/// The 4-byte trailer that follows a footer frame is not skipped by the
/// sequential read helpers in this impl.
pub struct ProtobufReader {
    reader: BufReader<File>,
}

impl ProtobufReader {
    /// Largest message body (100 MiB) the reader will allocate for. A larger
    /// length prefix is treated as a format error rather than trusted.
    const MAX_MESSAGE_LEN: usize = 100 * 1024 * 1024;

    /// Open `path` for reading and wrap it in a `BufReader`.
    ///
    /// # Errors
    /// Returns the I/O error from opening the file, converted into `RecorderError`.
    pub fn new<P: AsRef<Path>>(path: P) -> Result<Self, RecorderError> {
        let file = File::open(path)?;

        Ok(Self {
            reader: BufReader::new(file),
        })
    }

    /// Read the next frame and decode its body as `M`.
    ///
    /// The body is decoded as `M` whatever the frame's type tag says; callers
    /// inspect the returned `MessageType` afterwards.
    ///
    /// Returns `Ok(None)` on a clean end of stream (EOF before the type byte).
    ///
    /// # Errors
    /// Everything `read_any_message` reports, plus `ProtobufDecoding` when the
    /// body is not a valid `M`.
    fn read_message<M>(&mut self) -> Result<Option<(MessageType, M)>, RecorderError>
    where
        M: Message + Default,
    {
        match self.read_any_message()? {
            Some((msg_type, buf)) => {
                let msg = M::decode(&buf[..]).map_err(RecorderError::ProtobufDecoding)?;
                Ok(Some((msg_type, msg)))
            }
            None => Ok(None),
        }
    }

    /// Read the next frame's type tag and raw body without decoding it.
    ///
    /// Returns `Ok(None)` only when the stream ends exactly at a frame boundary.
    /// EOF inside a frame surfaces as the `UnexpectedEof` error from `read_exact`.
    ///
    /// # Errors
    /// - `FormatError` for an unknown type tag or a length above `MAX_MESSAGE_LEN`.
    /// - I/O errors from the underlying reader, including truncated frames.
    fn read_any_message(&mut self) -> Result<Option<(MessageType, Vec<u8>)>, RecorderError> {
        let mut type_buf = [0u8; 1];
        match self.reader.read_exact(&mut type_buf) {
            Ok(()) => {}
            // EOF before the tag byte means there are no more frames.
            Err(ref e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
            Err(e) => return Err(RecorderError::Io(e)),
        }
        let msg_type = Self::parse_message_type(type_buf[0])?;

        let mut len_buf = [0u8; 4];
        self.reader.read_exact(&mut len_buf)?;
        let len = u32::from_be_bytes(len_buf) as usize;
        // Validate before allocating so a corrupt prefix cannot force a huge allocation.
        if len > Self::MAX_MESSAGE_LEN {
            return Err(RecorderError::FormatError {
                message: format!("Message too large: {} bytes", len),
            });
        }

        let mut buf = vec![0u8; len];
        self.reader.read_exact(&mut buf)?;

        Ok(Some((msg_type, buf)))
    }

    /// Map an on-disk type tag byte to its `MessageType`.
    ///
    /// # Errors
    /// `FormatError` for any tag other than 1, 2 or 3.
    fn parse_message_type(tag: u8) -> Result<MessageType, RecorderError> {
        match tag {
            1 => Ok(MessageType::Header),
            2 => Ok(MessageType::Block),
            3 => Ok(MessageType::Footer),
            t => Err(RecorderError::FormatError {
                message: format!("Unknown message type: {}", t),
            }),
        }
    }
}

impl RecordReader for ProtobufReader {
    /// Read the next frame and require it to be a header.
    ///
    /// NOTE: the body is decoded as a header before the tag is checked, so a
    /// non-header frame may surface as `ProtobufDecoding` instead of
    /// `FormatError`.
    ///
    /// # Errors
    /// `FormatError` if the stream is empty or the next frame is not a header.
    fn read_header(&mut self) -> Result<RecorderHeader, RecorderError> {
        match self.read_message::<RecorderHeader>()? {
            Some((MessageType::Header, header)) => Ok(header),
            Some((msg_type, _)) => Err(RecorderError::FormatError {
                message: format!("Expected header, found {:?}", msg_type),
            }),
            None => Err(RecorderError::FormatError {
                message: "Missing file header".to_string(),
            }),
        }
    }

    /// Read the next block frame.
    ///
    /// Returns `Ok(None)` at end of stream or when a footer frame is reached.
    /// The footer frame is consumed in that case.
    ///
    /// # Errors
    /// `FormatError` if a header frame appears in the block stream.
    fn read_block(&mut self) -> Result<Option<RecordBlock>, RecorderError> {
        match self.read_any_message()? {
            Some((MessageType::Block, buf)) => {
                let block =
                    RecordBlock::decode(&buf[..]).map_err(RecorderError::ProtobufDecoding)?;
                Ok(Some(block))
            }
            Some((MessageType::Header, _)) => Err(RecorderError::FormatError {
                message: "Unexpected header in block stream".to_string(),
            }),
            Some((MessageType::Footer, _)) | None => Ok(None),
        }
    }

    /// Read the next frame and require it to be a footer.
    ///
    /// # Errors
    /// `FormatError` if the stream is empty or the next frame is not a footer.
    fn read_footer(&mut self) -> Result<RecorderFooter, RecorderError> {
        match self.read_message::<RecorderFooter>()? {
            Some((MessageType::Footer, footer)) => Ok(footer),
            Some((msg_type, _)) => Err(RecorderError::FormatError {
                message: format!("Expected footer, found {:?}", msg_type),
            }),
            None => Err(RecorderError::FormatError {
                message: "Missing file footer".to_string(),
            }),
        }
    }

    /// Locate and decode the footer using the `u32 BE` trailer stored in the
    /// last 4 bytes of the file.
    ///
    /// The trailer holds the size of the whole footer frame (type byte +
    /// length prefix + body + trailer), so the frame starts at
    /// `file_len - trailer`. The previous read position is restored on both
    /// success and failure.
    ///
    /// # Errors
    /// `FormatError` if the file is too short to hold a footer or the trailer
    /// does not describe a frame inside the file. Otherwise, any error from
    /// `read_footer` or from the underlying seeks and reads.
    fn read_footer_from_end(&mut self) -> Result<RecorderFooter, RecorderError> {
        // Size of the trailing frame-length field.
        const TRAILER_LEN: u64 = 4;
        // Smallest possible footer frame: type (1) + length (4) + empty body + trailer (4).
        const MIN_FOOTER_FRAME: u64 = 1 + 4 + TRAILER_LEN;

        let current_pos = self.position()?;

        let mut locate_and_read = || -> Result<RecorderFooter, RecorderError> {
            let file_len = self.reader.seek(SeekFrom::End(0))?;
            if file_len < MIN_FOOTER_FRAME {
                return Err(RecorderError::FormatError {
                    message: format!("File too small to contain a footer: {} bytes", file_len),
                });
            }

            self.reader.seek(SeekFrom::Start(file_len - TRAILER_LEN))?;
            let mut trailer = [0u8; 4];
            self.reader.read_exact(&mut trailer)?;
            let frame_len = u64::from(u32::from_be_bytes(trailer));
            if frame_len < MIN_FOOTER_FRAME || frame_len > file_len {
                return Err(RecorderError::FormatError {
                    message: format!(
                        "Invalid footer length in trailer: {} (file size {})",
                        frame_len, file_len
                    ),
                });
            }

            self.reader.seek(SeekFrom::Start(file_len - frame_len))?;
            self.read_footer()
        };
        let result = locate_and_read();

        // Restore even when locating/decoding failed, so the caller can keep
        // reading blocks from where it was.
        let restored = self.reader.seek(SeekFrom::Start(current_pos));
        let footer = result?;
        restored?;
        Ok(footer)
    }

    /// Seek to absolute byte offset `pos`, returning the new position.
    fn seek(&mut self, pos: u64) -> Result<u64, RecorderError> {
        self.reader
            .seek(SeekFrom::Start(pos))
            .map_err(RecorderError::Io)
    }

    /// Current logical read offset (accounts for data buffered by `BufReader`).
    fn position(&mut self) -> Result<u64, RecorderError> {
        self.reader.stream_position().map_err(RecorderError::Io)
    }

    /// Total file size in bytes; the read position is restored afterwards.
    fn file_size(&mut self) -> Result<u64, RecorderError> {
        let current_pos = self.position()?;
        let size = self.reader.seek(SeekFrom::End(0))?;
        self.reader.seek(SeekFrom::Start(current_pos))?;
        Ok(size)
    }
}