| 随机数算子重构优化
Co-authored-by: fenglin28<fenglin28@huawei.com>
# message auto-generated for no-merge-commit merge:
!949 merge hh into master
随机数算子重构优化
Created-by: guankarl
Commit-by: fenglin28
Merged-by: cann-robot
Description: ## 描述
<!--在这里详细描述你的改动,包括改动的原因和所采取的方法。-->
随机数算子重构优化,使用统一模板实现。
## 关联的Issue
<!-- 如果这个PR是为了解决特定的Issue,请在这里提供Issue链接。-->
https://gitcode.com/cann/ops-math/issues/552
<!-- 如果这个PR是为了解决特定的问题单,请在这里描述问题单单号。-->
## 测试
<!--描述进行了哪些测试来验证你的改动。包括但不限于二级冒烟、算子泛化等。-->
## 文档更新
<!--如果这个PR包含文档的更新,请在这里指出。例如:更新了README.md文件。-->
## 类型标签
<!-- [x] 表示选中 -->
- [ ] Bug修复
- [ ] 新特性
- [ ] 性能优化
- [ ] 文档更新
- [x ] 其他,请描述:
随机数算子重构,总共分为 5 个部分:
算子infershape/inferdtype
实现了公共逻辑提取CommonInferShape,不同算子根据算子原型配置对应mode,新算子开发10行代码搞定
示例:
static graphStatus InferShapeRandomUniformV2(gert::InferShapeContext* context)
{
const std::unordered_map<std::string, size_t>& input_map = {
{"shape", RANDOM_UNIFORM_V2_X}, {"offset", RANDOM_UNIFORM_V2_OFFSET}};
const std::unordered_map<std::string, size_t>& output_map = {
{"y", RANDOM_UNIFORM_V2_Y}, {"offset", RANDOM_UNIFORM_V2_OFFSET}};
int32_t mode = RANDOM_UNIFORM_V2_MODE_TYPE;
return ops::common::CommonInferShape(context, input_map, output_map, mode);
}
IMPL_OP_INFERSHAPE(RandomUniformV2).InferShape(InferShapeRandomUniformV2);
算子信息库注册def
算子的不同输入如果有多组dtype时候,在算子信息库的枚举中,可能会有几十种组合,手写容易漏掉或者重复,且不好检查,通过提供RandomDtypeFmtGen类,调用对应接口即可获取全部组合类型。
示例:
gen.GetSequence("inOutType")
this->Input("x")
.ParamType(REQUIRED)
.DataType({gen.GetSequence("inOutType")})
.Format({baseFormatSeq})
.UnknownShapeFormat({baseFormatSeq});
融合规则
有状态算子转无状态算子,存在范式的融合规则,通过提取公共类,定义公共接口,后续类似融合规则无需重新开发,直接调用对应接口即可。
class FusionRandomUtils {
public:
FusionRandomUtils() = default;
~FusionRandomUtils() = default;
// FusionRandomUtils(const FusionRandomUtils&) = delete;
// FusionRandomUtils& operator=(const FusionRandomUtils&) = delete;
/**
* @param graph 计算图对象
* @param opDesc 算子描述指针
* @param offsetNode 输出:创建的offset Variable节点
* @param shapeDesc Shape输入的Tensor描述
* @param outputDesc 主输出Y的Tensor描述
* @param opNode 基准算子节点(用于拼接新节点名)
* @param fusionPassName 日志标识(用于日志溯源)
* @param fusedOpType 融合算子类型(用于日志/错误信息溯源)
* @return Status SUCCESS/FAILED/PARAM_INVALID
*/
static bool CheckSocVersion(const std::string& fusedOpType);
static Status AddVariableNode(
ge::ComputeGraph& graph, ge::NodePtr opNode, const ge::GeTensorDesc& offsetDesc,
ge::NodePtr& newNode, const uint8_t* dataPtr, size_t size, const std::string& fusionPassName);
static Status CreateInputOpDesc(
ge::OpDescPtr& opDesc, ge::GeTensorDesc shapeDesc, ge::GeTensorDesc offsetDesc, const std::string& fusedOpType);
static Status CreateOutputOpDesc(
ge::OpDescPtr& opDesc, ge::GeTensorDesc outputDesc, ge::GeTensorDesc offsetDesc, const std::string& fusedOpType);
static Status AddOpNodeAndDesc(ge::ComputeGraph& graph, ge::OpDescPtr& opDesc, ge::NodePtr& offsetNode,
ge::GeTensorDesc shapeDesc, ge::GeTensorDesc outputDesc, ge::NodePtr opNode,
const std::string& fusionPassName, const std::string& fusedOpType);
static Status UpdateAttr(ge::NodePtr opNode, int64_t seed, int64_t seed2, ge::DataType dtype,
const std::string& fusionPassName, const std::string& fusedOpType);
static Status CreateNode(ge::ComputeGraph& graph, ge::NodePtr opNode, std::vector<ge::NodePtr>& fusionNodes,
ge::NodePtr& opV2Node, const std::string& fusionPassName, const std::string& fusedOpType,
const std::set<ge::DataType> aicoreDtypeSupportList);
static Status RemoveNode(ge::NodePtr node, ge::ComputeGraph& graph, const std::string& fusedOpType);
static Status ReplaceNode(ge::NodePtr oldNode, ge::NodePtr newNode, ge::ComputeGraph& graph, const std::string& fusedOpType, const std::string& fusionPassName);
};
算子tiling
抽象并统一流程,更加适合随机数算子,统一tilingData,按照配置做检查。
统一流程:
ge::graphStatus RandomTilingArch35::DoTiling()
{
opName_ = context_->GetNodeName();
OP_LOGD(opName_, "Start tiling for op: %s", opName_.c_str());
// 步骤1:校验输入输出和属性
auto ret = CheckInputsOutputsAndAttrs();
if (ret != ge::GRAPH_SUCCESS) {
OP_LOGE(opName_, "Check inputs/outputs/attrs failed");
return ret;
}
// 步骤2: 获取硬件信息
ret = GetPlatformInfo();
if (ret != ge::GRAPH_SUCCESS) {
OP_LOGE(opName_, "Get platform info failed");
return ret;
}
// 步骤3: 填充TilingData
ret = FillUnifiedTilingData();
if (ret != ge::GRAPH_SUCCESS) {
OP_LOGE(opName_, "Fill tiling data failed");
return ret;
}
// 步骤4:计算tilingKey和workspace
ret = CalcTilingKeyAndWorkspace();
if (ret != ge::GRAPH_SUCCESS) {
OP_LOGE(opName_, "Calc tiling key/workspace failed");
return ret;
}
// 步骤5:后置处理(可选)
ret = UniqueProcess();
if (ret != ge::GRAPH_SUCCESS) {
OP_LOGE(opName_, "Unique process failed");
return ret;
}
// 步骤6:写入context
ret = WriteBackToContext();
if (ret != ge::GRAPH_SUCCESS) {
OP_LOGE(opName_, "Write tiling data to context failed");
return ret;
}
// 步骤7:调用dump函数
auto info = tilingData_.DumpTilingInfo();
OP_LOGI("RandomTiling", "%s", info.str().c_str());
OP_LOGD(opName_, "Tiling success for op: %s", opName_);
return ge::GRAPH_SUCCESS;
}
统一tilingData
int64_t usedCoreNum = 0;
int64_t normalCoreProNum = 0;
int64_t tailCoreProNum = 0;
int64_t singleBufferSize = 0;
uint32_t key[2] = {0};
uint32_t counter[4] = {0};
int64_t outputSize = 0;
int64_t probTensorSize = 0;
int64_t sharedTmpBufSize = 0;
当前tilingData较少,后续如有需要,按需添加
按照配置做检查:
OpTilingConfig config;
config.inputCheckRules = {
// 输入索引: dtype列表,shapeSize,dim_num
{0, {{ge::DT_INT32, ge::DT_INT64}, -1, {1}, nullptr}}, // shape
{1, {{ge::DT_INT64}, 1, {}, nullptr}}, // offset
};
config.outputCheckRules = {
// 输出索引: dtype列表,shapeSize,dim_num
{0, {{ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16}, -1, {1,2,3,4,5,6,7,8}, nullptr}}
}; // y
算子kernel
kernel抽取了RandomKernelBaseOp类,专门管理tilingData,内置了Skip函数和生成随机数函数,其他算子只需调用VarsInit接口即可
See merge request: cann/ops-math!949 | 4 个月前 |