空白そらしろfeat: anyhow
22893e82创建于 2025年6月4日历史提交
// Copyright (c) 2025, Huawei Technologies Co., Ltd.
// All rights reserved.
//
// Licensed under the Apache License, Version 2.0  (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.


use std::{
    fs::File,
    io::{BufReader, Read},
};

use anyhow::{Context as _, Result};
use prost::Message;
use smartstring::alias::String;

pub mod geir;
pub use geir::*;

pub mod mindir;
pub use mindir::*;

pub mod onnx;
pub use onnx::*;

use crate::{AttrValue, AttrValue::*};

/// ## Note
/// This generic function is designed for parsing proto buffer files,
/// using bytes parsing, which is very high performance.
///
/// We may never not support proto buffer text files.
/// For comparison:
/// - `30-line` [`pbtxt`] file (using [`protobuf`]) takes `1.4ms`,
/// - `1800-line` [`pb`] file (using [`prost`]) takes `480μs`.
///
/// Instead, we recommend using the provided [`python`] script or writing one yourself,
/// which only takes several lines to convert the pbtxt format to pb format
///
/// ```py
/// message = ModelDef()
/// with open(pbtxt_file_path, 'r') as f:
///     text_format.Merge(f.read(), message)
/// with open(binary_file_path, 'wb') as f:
///     f.write(message.SerializeToString())
/// ```
///
/// ## Performance Improvement
/// If you parse binary directly without using prost,
/// which can reduce memory allocation and loop process
fn parse_pb<T>(path: &str) -> Result<T>
where
    T: Default + Message,
{
    let file = File::open(path)?;
    let mut reader = BufReader::new(file);
    let mut buffer = vec![];
    reader.read_to_end(&mut buffer)?;

    let model: T = Message::decode(&*buffer)?;

    Ok(model)
}

/// Formats a tensor description into a single display string.
///
/// ## Explanation
/// A `Tensor` usually has two attributes, `dtype` and `dims`, and sometimes
/// additional ones such as `device` and `layout`.
/// If this structure were sent to the front-end as-is, most of the time the
/// front-end would only concatenate it into a string, which has the following
/// disadvantages:
/// - Increased communication data volume
/// - A deeper object hierarchy (which may sharply increase JS GC time)
/// - Slow string concatenation that must be re-run on every render
///
/// Implementors therefore produce the already-formatted string up front.
///
/// ## Note
/// However, if the edge width ever needs to be adjusted according to the
/// data dimensions, this design needs to be reconsidered.
trait TensorFormatter {
    /// Returns the formatted description of this tensor.
    fn fmt(&self) -> String;
}

/// Converts a list of tensors into an [`AttrValue`] using each tensor's
/// [`TensorFormatter::fmt`] output.
///
/// - Exactly one tensor produces `StringLike` holding its formatted string.
/// - Zero or several tensors produce `TensorVals` with one formatted string
///   per tensor, in input order (empty for an empty slice).
///
/// This function never panics, including on an empty slice.
fn format_tensors<T>(tensors: &[T]) -> AttrValue
where
    T: TensorFormatter,
{
    match tensors {
        [single] => StringLike(single.fmt()),
        // Zero or several tensors: emit a (possibly empty) list instead of
        // indexing into a slice that may have no elements.
        _ => TensorVals(tensors.iter().map(|t| t.fmt()).collect()),
    }
}