ModelEvaluator
产品支持情况
| 产品 |
是否支持 |
| Ascend 950PR/Ascend 950DT |
√ |
| Atlas A3 训练系列产品/Atlas A3 推理系列产品 |
√ |
| Atlas A2 训练系列产品/Atlas A2 推理系列产品 |
√ |
功能说明
针对某一个模型,根据模型的bin类型输入数据,提供一个Python实例,可对该模型执行校准和推理的评估器。
函数原型
class ModelEvaluator(AutoCalibrationEvaluatorBase):
def __init__(self, data_dir, input_shape, data_types):
参数说明
参数名
|
输入/输出
|
说明
|
data_dir
|
输入
|
含义:与模型匹配的bin格式数据集路径。
数据类型:string
参数值格式:"data/input1/;data/input2/"
使用约束:
- 路径支持大小写字母(a-z,A-Z)、数字(0-9)、下划线(_)、中划线(-)、句点(.)、中文字符。
- 若模型有多个输入,且每个输入有多个batch数据,则不同的输入数据必须存储在不同的目录中,目录中文件的名称必须按照升序排序。所有的输入数据路径必须放在双引号中,节点中间使用英文分号分隔。
- 单个bin文件中存储的数组shape需要和input_shape中输入的shape相匹配,例如:单张图片bin存储的数组shape为1x224x224x3,则input_shape中输入的必须为1x224x224x3;如需多个bin做量化,则可通过调整batch_num取值实现。
|
input_shape
|
输入
|
含义:模型输入的shape信息。
数据类型:string
参数值格式:"input_name1:n1,c1,h1,w1;input_name2:n2,c2,h2,w2"。
使用约束:指定的节点必须放在双引号中,节点中间使用英文分号分隔。
|
data_types
|
输入
|
含义:输入数据的类型。
数据类型:string
参数值格式:"float32;float64"
使用约束:若模型有多个输入,且数据类型不同,则需要分别指定不同输入的数据类型,指定的输入数据类型必须按照输入节点顺序依次放在双引号中,所有的输入数据类型必须放在双引号中,中间使用英文分号分隔。
|
返回值说明
一个Python实例。
调用示例
import amct_pytorch as amct
evaluator = amct.ModelEvaluator(
data_dir="./data/input_bin/",
input_shape="input:32,3,224,224",
data_types="float32")