#include "services/webnn/dml/graph_impl_dml.h"
#ifdef UNSAFE_BUFFERS_BUILD
#pragma allow_unsafe_buffers
#endif
#include <winerror.h>
#include <algorithm>
#include <array>
#include <iterator>
#include <limits>
#include <numeric>
#include <variant>
#include "base/bits.h"
#include "base/check.h"
#include "base/containers/fixed_flat_set.h"
#include "base/containers/span.h"
#include "base/feature_list.h"
#include "base/memory/ptr_util.h"
#include "base/memory/raw_ptr.h"
#include "base/metrics/histogram_macros.h"
#include "base/notreached.h"
#include "base/numerics/safe_conversions.h"
#include "base/strings/strcat.h"
#include "base/strings/string_number_conversions.h"
#include "base/strings/string_util.h"
#include "base/strings/stringprintf.h"
#include "base/task/thread_pool.h"
#include "base/trace_event/trace_event.h"
#include "base/types/expected.h"
#include "base/types/expected_macros.h"
#include "base/types/optional_ref.h"
#include "mojo/public/cpp/bindings/self_owned_associated_receiver.h"
#include "services/webnn/dml/adapter.h"
#include "services/webnn/dml/command_queue.h"
#include "services/webnn/dml/command_recorder.h"
#include "services/webnn/dml/context_impl_dml.h"
#include "services/webnn/dml/error.h"
#include "services/webnn/dml/graph_builder_dml.h"
#include "services/webnn/dml/tensor_desc.h"
#include "services/webnn/dml/tensor_impl_dml.h"
#include "services/webnn/dml/utils.h"
#include "services/webnn/error.h"
#include "services/webnn/public/cpp/graph_validation_utils.h"
#include "services/webnn/public/cpp/operand_descriptor.h"
#include "services/webnn/public/cpp/webnn_types.h"
#include "services/webnn/public/mojom/webnn_device.mojom.h"
#include "services/webnn/public/mojom/webnn_error.mojom.h"
#include "services/webnn/webnn_constant_operand.h"
#include "services/webnn/webnn_context_impl.h"
#include "services/webnn/webnn_utils.h"
#include "third_party/abseil-cpp/absl/container/flat_hash_map.h"
#include "third_party/abseil-cpp/absl/container/flat_hash_set.h"
#include "third_party/fp16/src/include/fp16.h"
namespace webnn::dml {
namespace {
// Feature flag (enabled by default) named for graph fusion. Its readers are
// outside this chunk -- presumably it gates the operator-fusion helpers in
// this file; confirm at the use site.
BASE_FEATURE(kApplyGraphFusion, base::FEATURE_ENABLED_BY_DEFAULT);
using Microsoft::WRL::ComPtr;
using mojom::Operand;
using mojom::OperandPtr;
using mojom::Operation;
// Lookup table from an operand id to the graph builder's NodeOutput recorded
// for that operand.
using IdToNodeOutputMap = absl::flat_hash_map<OperandId, const NodeOutput*>;
// The DirectML floating-point tensor data types: float32 and float16.
static constexpr auto kDmlFloatDataTypes =
base::MakeFixedFlatSet<DML_TENSOR_DATA_TYPE>(
{DML_TENSOR_DATA_TYPE_FLOAT32, DML_TENSOR_DATA_TYPE_FLOAT16});
// Converts `value` into a DML_SCALAR_UNION whose active member matches the
// DirectML tensor data type `type`.
//
// Float16 is written through the `UInt16` member using the result of
// MLNumber::AsFloat16() (presumably the raw half-precision bit pattern --
// confirm against MLNumber). Any data type not listed below, including the
// 4-bit integer types, reaches NOTREACHED().
DML_SCALAR_UNION ToScalarUnion(const MLNumber& value,
DML_TENSOR_DATA_TYPE type) {
switch (type) {
case DML_TENSOR_DATA_TYPE_FLOAT32:
return DML_SCALAR_UNION{.Float32 = value.AsFloat32()};
case DML_TENSOR_DATA_TYPE_FLOAT16:
return DML_SCALAR_UNION{.UInt16 = value.AsFloat16()};
case DML_TENSOR_DATA_TYPE_INT8:
return DML_SCALAR_UNION{.Int8 = value.AsInt8()};
case DML_TENSOR_DATA_TYPE_UINT8:
return DML_SCALAR_UNION{.UInt8 = value.AsUint8()};
case DML_TENSOR_DATA_TYPE_INT64:
return DML_SCALAR_UNION{.Int64 = value.AsInt64()};
case DML_TENSOR_DATA_TYPE_UINT64:
return DML_SCALAR_UNION{.UInt64 = value.AsUint64()};
case DML_TENSOR_DATA_TYPE_INT32:
return DML_SCALAR_UNION{.Int32 = value.AsInt32()};
case DML_TENSOR_DATA_TYPE_UINT32:
return DML_SCALAR_UNION{.UInt32 = value.AsUint32()};
default:
NOTREACHED() << "[WebNN] This data type is not supported.";
}
}
// Maps a WebNN OperandDataType onto the DirectML tensor data type of the same
// name. The switch deliberately has no `default` label, so every enumerator
// handled here is listed explicitly.
DML_TENSOR_DATA_TYPE GetTensorDataType(OperandDataType type) {
switch (type) {
case OperandDataType::kFloat32:
return DML_TENSOR_DATA_TYPE_FLOAT32;
case OperandDataType::kFloat16:
return DML_TENSOR_DATA_TYPE_FLOAT16;
case OperandDataType::kInt8:
return DML_TENSOR_DATA_TYPE_INT8;
case OperandDataType::kUint8:
return DML_TENSOR_DATA_TYPE_UINT8;
case OperandDataType::kInt64:
return DML_TENSOR_DATA_TYPE_INT64;
case OperandDataType::kUint64:
return DML_TENSOR_DATA_TYPE_UINT64;
case OperandDataType::kInt32:
return DML_TENSOR_DATA_TYPE_INT32;
case OperandDataType::kUint32:
return DML_TENSOR_DATA_TYPE_UINT32;
case OperandDataType::kInt4:
return DML_TENSOR_DATA_TYPE_INT4;
case OperandDataType::kUint4:
return DML_TENSOR_DATA_TYPE_UINT4;
}
}
// Maps a DirectML tensor data type back to the WebNN OperandDataType of the
// same name (the reverse direction of GetTensorDataType()). DirectML data
// types without a case below reach NOTREACHED().
OperandDataType DmlDataTypeToOperand(DML_TENSOR_DATA_TYPE type) {
switch (type) {
case DML_TENSOR_DATA_TYPE_FLOAT32:
return OperandDataType::kFloat32;
case DML_TENSOR_DATA_TYPE_FLOAT16:
return OperandDataType::kFloat16;
case DML_TENSOR_DATA_TYPE_INT8:
return OperandDataType::kInt8;
case DML_TENSOR_DATA_TYPE_UINT8:
return OperandDataType::kUint8;
case DML_TENSOR_DATA_TYPE_INT64:
return OperandDataType::kInt64;
case DML_TENSOR_DATA_TYPE_UINT64:
return OperandDataType::kUint64;
case DML_TENSOR_DATA_TYPE_INT32:
return OperandDataType::kInt32;
case DML_TENSOR_DATA_TYPE_UINT32:
return OperandDataType::kUint32;
case DML_TENSOR_DATA_TYPE_INT4:
return OperandDataType::kInt4;
case DML_TENSOR_DATA_TYPE_UINT4:
return OperandDataType::kUint4;
default:
NOTREACHED() << "[WebNN] This data type is not supported.";
}
}
// Maps a mojom reduce kind to the corresponding DML_REDUCE_FUNCTION value.
// The names mostly line up one-to-one; the two that differ are
// kMean -> DML_REDUCE_FUNCTION_AVERAGE and
// kProduct -> DML_REDUCE_FUNCTION_MULTIPLY.
//
// NOTE(review): the function name misspells "Function" as "Funtion". Renaming
// it requires updating callers that are outside this chunk.
DML_REDUCE_FUNCTION MapReduceKindToReduceFuntion(mojom::Reduce::Kind kind) {
switch (kind) {
case mojom::Reduce::Kind::kL1:
return DML_REDUCE_FUNCTION_L1;
case mojom::Reduce::Kind::kL2:
return DML_REDUCE_FUNCTION_L2;
case mojom::Reduce::Kind::kLogSum:
return DML_REDUCE_FUNCTION_LOG_SUM;
case mojom::Reduce::Kind::kLogSumExp:
return DML_REDUCE_FUNCTION_LOG_SUM_EXP;
case mojom::Reduce::Kind::kMax:
return DML_REDUCE_FUNCTION_MAX;
case mojom::Reduce::Kind::kMean:
return DML_REDUCE_FUNCTION_AVERAGE;
case mojom::Reduce::Kind::kMin:
return DML_REDUCE_FUNCTION_MIN;
case mojom::Reduce::Kind::kProduct:
return DML_REDUCE_FUNCTION_MULTIPLY;
case mojom::Reduce::Kind::kSum:
return DML_REDUCE_FUNCTION_SUM;
case mojom::Reduce::Kind::kSumSquare:
return DML_REDUCE_FUNCTION_SUM_SQUARE;
}
}
// CHECKs that `data_type` is contained in the input data-type set that
// `data_type_limits` declares for the reduce operation `kind`. Each reduce
// kind consults its own `reduce_*_input.data_types` entry; the function has
// no effect other than the CHECK.
void CheckInputDataTypeForReduce(const DataTypeLimits& data_type_limits,
mojom::Reduce::Kind kind,
OperandDataType data_type) {
switch (kind) {
case mojom::Reduce::Kind::kL1:
CHECK(data_type_limits.reduce_l1_input.data_types.Has(data_type));
break;
case mojom::Reduce::Kind::kL2:
CHECK(data_type_limits.reduce_l2_input.data_types.Has(data_type));
break;
case mojom::Reduce::Kind::kLogSum:
CHECK(data_type_limits.reduce_log_sum_input.data_types.Has(data_type));
break;
case mojom::Reduce::Kind::kLogSumExp:
CHECK(
data_type_limits.reduce_log_sum_exp_input.data_types.Has(data_type));
break;
case mojom::Reduce::Kind::kMax:
CHECK(data_type_limits.reduce_max_input.data_types.Has(data_type));
break;
case mojom::Reduce::Kind::kMean:
CHECK(data_type_limits.reduce_mean_input.data_types.Has(data_type));
break;
case mojom::Reduce::Kind::kMin:
CHECK(data_type_limits.reduce_min_input.data_types.Has(data_type));
break;
case mojom::Reduce::Kind::kProduct:
CHECK(data_type_limits.reduce_product_input.data_types.Has(data_type));
break;
case mojom::Reduce::Kind::kSum:
CHECK(data_type_limits.reduce_sum_input.data_types.Has(data_type));
break;
case mojom::Reduce::Kind::kSumSquare:
CHECK(data_type_limits.reduce_sum_square_input.data_types.Has(data_type));
break;
}
}
// Maps the mojom recurrent-network direction onto the DirectML enum.
// kForward and kBackward map to the same-named values; kBoth maps to
// DML_RECURRENT_NETWORK_DIRECTION_BIDIRECTIONAL.
DML_RECURRENT_NETWORK_DIRECTION MojoRecurrentNetworkDirectionToDml(
mojom::RecurrentNetworkDirection direction) {
switch (direction) {
case mojom::RecurrentNetworkDirection::kForward:
return DML_RECURRENT_NETWORK_DIRECTION_FORWARD;
case mojom::RecurrentNetworkDirection::kBackward:
return DML_RECURRENT_NETWORK_DIRECTION_BACKWARD;
case mojom::RecurrentNetworkDirection::kBoth:
return DML_RECURRENT_NETWORK_DIRECTION_BIDIRECTIONAL;
}
}
// Builds an error via CreateError(error_code, error_message, label) and
// returns it in the "unexpected" state, so callers whose return type is
// base::expected<void, mojom::ErrorPtr> can `return` the result directly.
base::expected<void, mojom::ErrorPtr> CreateUnexpectedError(
    mojom::Error::Code error_code,
    const std::string& error_message,
    std::string_view label) {
  auto error = CreateError(error_code, error_message, label);
  return base::unexpected(std::move(error));
}
// Lays the constant operands out back to back, in the iteration order of
// `constant_operands`, rounding every operand's byte size up to
// DML_MINIMUM_BUFFER_TENSOR_ALIGNMENT.
//
// Returns the total (aligned) byte length together with the [Begin, End)
// range assigned to each operand id, or std::nullopt (after logging an
// error) if the running total overflows size_t.
std::optional<AlignedByteLength<OperandId>> CalculateAlignedByteLength(
    const base::flat_map<OperandId, std::unique_ptr<WebNNConstantOperand>>&
        constant_operands) {
  base::CheckedNumeric<size_t> running_total(0);
  absl::flat_hash_map<OperandId, D3D12_RANGE> ranges;

  for (const auto& [id, operand] : constant_operands) {
    D3D12_RANGE range = {};
    // The range starts where the previous operand's aligned range ended.
    range.Begin = running_total.ValueOrDie();

    running_total += base::bits::AlignUp<size_t>(
        operand->ByteSpan().size(), DML_MINIMUM_BUFFER_TENSOR_ALIGNMENT);
    if (!running_total.IsValid()) {
      LOG(ERROR) << "[WebNN] Failed to calculate the total byte length.";
      return std::nullopt;
    }

    // `End` is the aligned end, so the range may be larger than the data.
    range.End = running_total.ValueOrDie();
    ranges[id] = range;
  }

  return AlignedByteLength<OperandId>{
      .total_byte_length = running_total.ValueOrDie(),
      .key_to_d3d12_range_map = std::move(ranges)};
}
// A pair of D3D12 resources used together when constants are staged through
// an upload buffer: UploadAndCreateConstantBufferBinding() maps and fills
// `upload_buffer`, binds `default_buffer`, and hands both to
// UploadBufferWithBarrier().
struct UploadAndDefaultBuffers {
// Buffer that is mapped and written with the constant data.
ComPtr<ID3D12Resource> upload_buffer;
// Buffer referenced by the resulting DML_BUFFER_BINDINGs.
ComPtr<ID3D12Resource> default_buffer;
};
// Copies every constant operand's bytes into a mapped D3D12 buffer and builds
// a DML_BUFFER_BINDING per operand id.
//
// `buffer_variant` selects one of two paths:
//  * a single ComPtr<ID3D12Resource>: that buffer is both mapped and bound,
//    and is then handed to `command_recorder->ReferenceCommandResources()`;
//  * UploadAndDefaultBuffers: the upload buffer is mapped and filled, the
//    default buffer is bound, and UploadBufferWithBarrier() is called with
//    both buffers and `aligned_byte_length.total_byte_length`.
//
// `aligned_byte_length` must contain a range for every id in
// `constant_operands` (looked up with `.at()`).
//
// Returns the id -> binding map. The Map() call is wrapped in
// RETURN_UNEXPECTED_IF_FAILED, which presumably returns the failing HRESULT
// as the unexpected value -- confirm against the macro definition.
base::expected<absl::flat_hash_map<OperandId, DML_BUFFER_BINDING>, HRESULT>
UploadAndCreateConstantBufferBinding(
CommandRecorder* command_recorder,
const base::flat_map<OperandId, std::unique_ptr<WebNNConstantOperand>>&
constant_operands,
const AlignedByteLength<OperandId>& aligned_byte_length,
std::variant<UploadAndDefaultBuffers, ComPtr<ID3D12Resource>>
buffer_variant) {
void* mapped_buffer = nullptr;
ID3D12Resource* buffer_to_map = nullptr;
ID3D12Resource* buffer_to_bind = nullptr;
ComPtr<ID3D12Resource> cpu_buffer;
ComPtr<ID3D12Resource> upload_buffer;
ComPtr<ID3D12Resource> default_buffer;
// Take ownership of the resource(s) out of the variant. Moving the contained
// value does not change which alternative the variant holds, so the second
// std::holds_alternative check further down is still meaningful.
if (std::holds_alternative<ComPtr<ID3D12Resource>>(buffer_variant)) {
cpu_buffer = std::move(std::get<ComPtr<ID3D12Resource>>(buffer_variant));
buffer_to_map = cpu_buffer.Get();
buffer_to_bind = buffer_to_map;
} else {
upload_buffer = std::move(
std::get<UploadAndDefaultBuffers>(buffer_variant).upload_buffer);
default_buffer = std::move(
std::get<UploadAndDefaultBuffers>(buffer_variant).default_buffer);
buffer_to_map = upload_buffer.Get();
buffer_to_bind = default_buffer.Get();
}
CHECK(buffer_to_map);
CHECK(buffer_to_bind);
RETURN_UNEXPECTED_IF_FAILED(buffer_to_map->Map(0, nullptr, &mapped_buffer));
absl::flat_hash_map<OperandId, DML_BUFFER_BINDING> key_to_buffer_binding_map;
for (auto& [operand_id, constant_operand] : constant_operands) {
const auto& d3d12_range =
aligned_byte_length.key_to_d3d12_range_map.at(operand_id);
// The destination starts at the operand's `Begin` offset and is sized by the
// descriptor's PackedByteLength().
// NOTE(review): the source is ByteSpan(); this presumably relies on both
// sizes being equal -- confirm against base::span::copy_from's contract.
auto mapped_buffer_span =
base::span(static_cast<uint8_t*>(mapped_buffer) + d3d12_range.Begin,
constant_operand->descriptor().PackedByteLength());
mapped_buffer_span.copy_from(constant_operand->ByteSpan());
// The binding covers the whole [Begin, End) range recorded for the operand,
// not just the bytes copied above.
auto size_in_bytes = d3d12_range.End - d3d12_range.Begin;
key_to_buffer_binding_map[operand_id] =
DML_BUFFER_BINDING{.Buffer = buffer_to_bind,
.Offset = d3d12_range.Begin,
.SizeInBytes = size_in_bytes};
}
buffer_to_map->Unmap(0, nullptr);
if (std::holds_alternative<ComPtr<ID3D12Resource>>(buffer_variant)) {
CHECK(cpu_buffer);
// Single-buffer path: pass ownership of the buffer to the command recorder.
command_recorder->ReferenceCommandResources(std::move(cpu_buffer));
} else {
CHECK(default_buffer);
CHECK(upload_buffer);
// Two-buffer path: both buffers are moved into UploadBufferWithBarrier().
UploadBufferWithBarrier(command_recorder, std::move(default_buffer),
std::move(upload_buffer),
aligned_byte_length.total_byte_length);
}
return key_to_buffer_binding_map;
}
// Returns the operand stored at index `operand_id.value()` of `operands`.
// The index is bounds-checked by std::vector::at().
const Operand& GetOperand(const std::vector<OperandPtr>& operands,
                          OperandId operand_id) {
  const OperandPtr& operand_ptr = operands.at(operand_id.value());
  return *operand_ptr;
}
uint32_t CreateInputNode(const std::vector<OperandPtr>& operands,
OperandId input_id,
GraphBuilderDml& graph_builder,
IdToNodeOutputMap& id_to_node_output_map) {
const Operand& operand = GetOperand(operands, input_id);
CHECK_EQ(operand.kind, Operand::Kind::kInput);
TensorDesc input_tensor_desc(
GetTensorDataType(operand.descriptor.data_type()), DML_TENSOR_FLAG_NONE,
operand.descriptor.shape());
const InputNode* input_node = graph_builder.CreateInputNode();
CHECK(input_node);
const NodeOutput* node_output =
graph_builder.CreateNodeOutput(input_node, std::move(input_tensor_desc));
CHECK(node_output);
id_to_node_output_map[input_id] = std::move(node_output);
return input_node->GetGraphInputIndex();
}
// Adds the constant operand `operand_id` to the graph and records its node
// output in `id_to_node_output_map` (the id must not be present yet).
//
// Two representations are used:
//  * When the adapter supports DML_FEATURE_LEVEL_6_2 and the constant has
//    exactly one physical element, the constant's data is moved into a graph
//    constant node and its entry is erased from `constant_operands`.
//  * Otherwise the constant becomes a graph input node flagged
//    DML_TENSOR_FLAG_OWNED_BY_DML, and its graph input index is stored in
//    `constant_id_to_input_index_map`.
void CreateConstantNode(
    Adapter* adapter,
    OperandId operand_id,
    base::flat_map<OperandId, std::unique_ptr<WebNNConstantOperand>>&
        constant_operands,
    GraphBuilderDml& graph_builder,
    IdToNodeOutputMap& id_to_node_output_map,
    absl::flat_hash_map<OperandId, uint32_t>& constant_id_to_input_index_map) {
  // Copy the descriptor up front: the owning operand may be moved away below.
  const OperandDescriptor descriptor =
      constant_operands.at(operand_id)->descriptor();

  const bool use_dml_constant_node =
      adapter->IsDMLFeatureLevelSupported(DML_FEATURE_LEVEL_6_2) &&
      CalculatePhysicalElementCount(descriptor.shape()) == 1;

  const Node* node = nullptr;
  DML_TENSOR_FLAGS tensor_flags = DML_TENSOR_FLAG_NONE;
  if (use_dml_constant_node) {
    auto operand_it = constant_operands.find(operand_id);
    node = graph_builder.CreateConstantNode(std::move(operand_it->second));
    constant_operands.erase(operand_it);
  } else {
    const InputNode* input_node = graph_builder.CreateInputNode();
    constant_id_to_input_index_map[operand_id] =
        input_node->GetGraphInputIndex();
    node = input_node;
    tensor_flags = DML_TENSOR_FLAG_OWNED_BY_DML;
  }

  TensorDesc tensor_desc(GetTensorDataType(descriptor.data_type()),
                         tensor_flags, descriptor.shape());
  const NodeOutput* node_output =
      graph_builder.CreateNodeOutput(node, std::move(tensor_desc));
  const bool inserted =
      id_to_node_output_map.try_emplace(operand_id, node_output).second;
  CHECK(inserted);
}
// Looks up the node output recorded for `operand_id`. CHECK-fails when the
// id has no entry or when the stored pointer is null, so the returned
// pointer is never null.
const NodeOutput* GetNodeOutputForOperand(
    const IdToNodeOutputMap& id_to_node_output_map,
    OperandId operand_id) {
  const auto entry = id_to_node_output_map.find(operand_id);
  CHECK(entry != id_to_node_output_map.end());

  const NodeOutput* node_output = entry->second;
  CHECK(node_output);
  return node_output;
}
// Optional-aware variant of GetNodeOutputForOperand(): returns nullptr when
// `operand_id` is empty, otherwise forwards to GetNodeOutputForOperand().
const NodeOutput* GetOptionalNodeOutputForOperand(
    const IdToNodeOutputMap& id_to_node_output_map,
    std::optional<OperandId> operand_id) {
  if (!operand_id.has_value()) {
    return nullptr;
  }
  return GetNodeOutputForOperand(id_to_node_output_map, *operand_id);
}
// Returns the address of the DML_TENSOR_DESC exposed by `tensor_desc`, or
// nullptr when no tensor descriptor was supplied. The pointer refers into
// the referenced TensorDesc, so it is only usable while that object lives.
const DML_TENSOR_DESC* GetOptionalDmlTensorDescPtr(
    base::optional_ref<const TensorDesc> tensor_desc) {
  if (!tensor_desc.has_value()) {
    return nullptr;
  }
  return &tensor_desc->GetDMLTensorDesc();
}
// Appends a new constant operand holding the single scalar `value` to
// `graph_info->operands` and registers its data in `constant_operands`.
//
// The operand has `rank` dimensions, each of size 1, and data type
// `data_type`, which must be kFloat32 or kFloat16 (any other type logs an
// error and hits NOTREACHED()). For kFloat16 the value is converted with
// fp16_ieee_from_fp32_value() before being stored.
//
// Returns the id of the new operand, which is its index in
// `graph_info->operands`.
OperandId BuildConstantOperandForFloatValue(
    const ContextProperties& context_properties,
    mojom::GraphInfoPtr& graph_info,
    base::flat_map<OperandId, std::unique_ptr<WebNNConstantOperand>>&
        constant_operands,
    OperandDataType data_type,
    size_t rank,
    float value) {
  const std::vector<uint32_t> scalar_shape(rank, 1);
  auto descriptor = *OperandDescriptor::Create(context_properties, data_type,
                                               scalar_shape, "constant");

  auto new_operand =
      Operand::New(Operand::Kind::kConstant, descriptor, std::nullopt);
  // The id is the position the operand is about to occupy.
  OperandId new_operand_id(graph_info->operands.size());
  graph_info->operands.push_back(std::move(new_operand));

  base::HeapArray<uint8_t> bytes;
  switch (data_type) {
    case OperandDataType::kFloat32: {
      bytes = base::HeapArray<uint8_t>::CopiedFrom(
          base::byte_span_from_ref(base::allow_nonunique_obj, value));
      break;
    }
    case OperandDataType::kFloat16: {
      uint16_t half_bits = fp16_ieee_from_fp32_value(value);
      bytes = base::HeapArray<uint8_t>::CopiedFrom(
          base::byte_span_from_ref(half_bits));
      break;
    }
    default:
      LOG(ERROR) << "[WebNN] The data type must be one of the floating point "
                    "data types.";
      NOTREACHED();
  }

  const bool inserted =
      constant_operands
          .try_emplace(new_operand_id,
                       std::make_unique<WebNNConstantOperand>(
                           std::move(descriptor), std::move(bytes)))
          .second;
  CHECK(inserted);
  return new_operand_id;
}
// Builds a TensorDesc from the data type and shape declared by the operand
// `output_id`.
const TensorDesc CreateOutputTensorDesc(const std::vector<OperandPtr>& operands,
                                        OperandId output_id) {
  const auto& descriptor = GetOperand(operands, output_id).descriptor;
  return TensorDesc(GetTensorDataType(descriptor.data_type()),
                    descriptor.shape());
}
// Adds a DirectML ARGMIN or ARGMAX operator node for `arg_min_max` and
// records its output in `id_to_node_output_map` (the output id must not be
// present yet).
//
// The operator itself is described with an output shape equal to the input
// shape with the reduced axis set to 1, i.e. it keeps the input's rank. The
// node output recorded for later operations uses the shape declared by the
// output operand instead. (Presumably DirectML requires equal input/output
// ranks for this operator -- see the DML_ARGMAX_OPERATOR_DESC documentation.)
void CreateOperatorNodeForArgMinMax(const std::vector<OperandPtr>& operands,
                                    const mojom::ArgMinMaxPtr& arg_min_max,
                                    GraphBuilderDml& graph_builder,
                                    IdToNodeOutputMap& id_to_node_output_map) {
  const NodeOutput* input = GetNodeOutputForOperand(
      id_to_node_output_map, arg_min_max->input_operand_id);
  const auto& input_tensor_desc = input->GetTensorDesc();

  const OperandId output_id = arg_min_max->output_operand_id;
  const TensorDesc output_tensor_desc =
      CreateOutputTensorDesc(operands, output_id);

  const uint32_t axis = arg_min_max->axis;
  std::vector<uint32_t> output_dimensions = input_tensor_desc.GetDimensions();
  CHECK_LT(axis, output_dimensions.size());
  output_dimensions[axis] = 1u;
  TensorDesc new_output_tensor_desc(output_tensor_desc.GetDataType(),
                                    std::move(output_dimensions));

  DML_OPERATOR_TYPE operator_type = DML_OPERATOR_INVALID;
  switch (arg_min_max->kind) {
    case mojom::ArgMinMax_Kind::kMin: {
      operator_type = DML_OPERATOR_ARGMIN;
      break;
    }
    case mojom::ArgMinMax_Kind::kMax: {
      operator_type = DML_OPERATOR_ARGMAX;
      break;
    }
  }

  // The same descriptor struct is passed for both operator types, which is
  // only valid while the ARGMIN and ARGMAX descriptors share one layout.
  static_assert(
      sizeof(DML_ARGMIN_OPERATOR_DESC) == sizeof(DML_ARGMAX_OPERATOR_DESC),
      "ARGMIN and ARGMAX operator descriptors must have the same layout.");

  const std::array<const uint32_t, 1> axes = {axis};
  DML_ARGMAX_OPERATOR_DESC operator_desc = {};
  operator_desc.InputTensor = &input_tensor_desc.GetDMLTensorDesc();
  operator_desc.OutputTensor = &new_output_tensor_desc.GetDMLTensorDesc();
  operator_desc.AxisCount = axes.size();
  operator_desc.Axes = axes.data();
  operator_desc.AxisDirection =
      DML_AXIS_DIRECTION::DML_AXIS_DIRECTION_INCREASING;

  std::array<const NodeOutput*, 1> inputs = {input};
  const GraphNode* arg_min_max_node = graph_builder.CreateOperatorNode(
      operator_type, &operator_desc, inputs, arg_min_max->label);

  const NodeOutput* output =
      graph_builder.CreateNodeOutput(arg_min_max_node, output_tensor_desc);
  CHECK(id_to_node_output_map.try_emplace(output_id, output).second);
}
// Holds the descriptor of exactly one DirectML activation operator and can
// expose it as a type-erased DML_OPERATOR_DESC.
struct ActivationOperatorDesc {
  std::variant<DML_ACTIVATION_ELU_OPERATOR_DESC,
               DML_ACTIVATION_HARD_SIGMOID_OPERATOR_DESC,
               DML_ACTIVATION_LEAKY_RELU_OPERATOR_DESC,
               DML_ACTIVATION_LINEAR_OPERATOR_DESC,
               DML_ACTIVATION_RELU_OPERATOR_DESC,
               DML_ACTIVATION_SIGMOID_OPERATOR_DESC,
               DML_ACTIVATION_SOFTMAX1_OPERATOR_DESC,
               DML_ACTIVATION_SOFTPLUS_OPERATOR_DESC,
               DML_ACTIVATION_SOFTSIGN_OPERATOR_DESC,
               DML_ACTIVATION_TANH_OPERATOR_DESC>
      desc;

  // Returns a DML_OPERATOR_DESC whose type tag matches the alternative held
  // by `desc` and whose descriptor pointer refers to that alternative. The
  // pointer aims into this object, so the result must not outlive it (nor a
  // reassignment of `desc`).
  DML_OPERATOR_DESC GetActivationDmlDesc() const {
    if (const auto* elu =
            std::get_if<DML_ACTIVATION_ELU_OPERATOR_DESC>(&desc)) {
      return {DML_OPERATOR_ACTIVATION_ELU, elu};
    }
    if (const auto* hard_sigmoid =
            std::get_if<DML_ACTIVATION_HARD_SIGMOID_OPERATOR_DESC>(&desc)) {
      return {DML_OPERATOR_ACTIVATION_HARD_SIGMOID, hard_sigmoid};
    }
    if (const auto* leaky_relu =
            std::get_if<DML_ACTIVATION_LEAKY_RELU_OPERATOR_DESC>(&desc)) {
      return {DML_OPERATOR_ACTIVATION_LEAKY_RELU, leaky_relu};
    }
    if (const auto* linear =
            std::get_if<DML_ACTIVATION_LINEAR_OPERATOR_DESC>(&desc)) {
      return {DML_OPERATOR_ACTIVATION_LINEAR, linear};
    }
    if (const auto* relu =
            std::get_if<DML_ACTIVATION_RELU_OPERATOR_DESC>(&desc)) {
      return {DML_OPERATOR_ACTIVATION_RELU, relu};
    }
    if (const auto* sigmoid =
            std::get_if<DML_ACTIVATION_SIGMOID_OPERATOR_DESC>(&desc)) {
      return {DML_OPERATOR_ACTIVATION_SIGMOID, sigmoid};
    }
    if (const auto* softmax =
            std::get_if<DML_ACTIVATION_SOFTMAX1_OPERATOR_DESC>(&desc)) {
      return {DML_OPERATOR_ACTIVATION_SOFTMAX1, softmax};
    }
    if (const auto* softplus =
            std::get_if<DML_ACTIVATION_SOFTPLUS_OPERATOR_DESC>(&desc)) {
      return {DML_OPERATOR_ACTIVATION_SOFTPLUS, softplus};
    }
    if (const auto* softsign =
            std::get_if<DML_ACTIVATION_SOFTSIGN_OPERATOR_DESC>(&desc)) {
      return {DML_OPERATOR_ACTIVATION_SOFTSIGN, softsign};
    }
    if (const auto* tanh =
            std::get_if<DML_ACTIVATION_TANH_OPERATOR_DESC>(&desc)) {
      return {DML_OPERATOR_ACTIVATION_TANH, tanh};
    }
    NOTREACHED() << "The activation type is not supported.";
  }
};
// Returns an ActivationOperatorDesc holding a value-initialized DirectML
// descriptor for the given recurrent-network activation (relu, sigmoid or
// tanh). The tensor pointers inside the descriptor are left null here.
ActivationOperatorDesc CreateOperatorDescForActivation(
mojom::RecurrentNetworkActivation activation) {
switch (activation) {
case mojom::RecurrentNetworkActivation::kRelu:
return ActivationOperatorDesc{.desc =
DML_ACTIVATION_RELU_OPERATOR_DESC{}};
case mojom::RecurrentNetworkActivation::kSigmoid:
return ActivationOperatorDesc{.desc =
DML_ACTIVATION_SIGMOID_OPERATOR_DESC{}};
case mojom::RecurrentNetworkActivation::kTanh:
return ActivationOperatorDesc{.desc =
DML_ACTIVATION_TANH_OPERATOR_DESC{}};
}
}
// Returns the standalone activation that the map associates with
// `operation`, or std::nullopt when `operation` has no entry.
std::optional<const Operation*> GetFusibleActivationFromOperation(
    const absl::flat_hash_map<const Operation*,
                              raw_ptr<const Operation, CtnExperimental>>&
        operation_to_fusible_standalone_activation_map,
    const Operation* operation) {
  if (const auto entry =
          operation_to_fusible_standalone_activation_map.find(operation);
      entry != operation_to_fusible_standalone_activation_map.end()) {
    return entry->second;
  }
  return std::nullopt;
}
// If the map holds a fusible transpose operation keyed by `input_id`,
// returns that transpose's own input operand id; otherwise returns
// std::nullopt.
std::optional<OperandId> GetFusibleTransposeInputId(
    const absl::flat_hash_map<OperandId,
                              raw_ptr<const Operation, CtnExperimental>>&
        output_id_to_fusible_transpose_map,
    OperandId input_id) {
  if (const auto entry = output_id_to_fusible_transpose_map.find(input_id);
      entry != output_id_to_fusible_transpose_map.end()) {
    return entry->second->get_transpose()->input_operand_id;
  }
  return std::nullopt;
}
// Returns true when `binary` is an element-wise add whose output operand has
// a floating-point data type (float32 or float16).
bool CanElementWiseBinarySupportFusion(
    const mojom::ElementWiseBinaryPtr& binary,
    const std::vector<OperandPtr>& operands) {
  // Resolve the output operand first so that an invalid id is caught
  // regardless of the operation kind.
  const OperandDataType output_type =
      GetOperand(operands, binary->output_operand_id).descriptor.data_type();

  if (binary->kind != mojom::ElementWiseBinary::Kind::kAdd) {
    return false;
  }
  const bool is_float_output = output_type == OperandDataType::kFloat32 ||
                               output_type == OperandDataType::kFloat16;
  return is_float_output;
}
bool CanFuseStandaloneActivation(const Operation* operation,
const std::vector<OperandPtr>& operands) {
switch (operation->which()) {
case Operation::Tag::kElementWiseBinary:
return CanElementWiseBinarySupportFusion(
operation->get_element_wise_binary(), operands);
case Operation::Tag::kConv2d:
case Operation::Tag::kBatchNormalization:
case Operation::Tag::kGemm:
case Operation::Tag::kInstanceNormalization:
case Operation::Tag::kLayerNormalization:
case Operation::Tag::kMatmul:
return true;
default:
return false;
}
}
// Returns the output operand id of `operation` when it is one of the fusible
// activations (elu, hardSigmoid, leakyRelu, linear, relu, sigmoid, softplus,
// softsign, tanh); returns std::nullopt for every other operation.
std::optional<OperandId> GetFusibleActivationOutputId(
    const mojom::Operation& operation) {
  std::optional<OperandId> output_id;
  switch (operation.which()) {
    case mojom::Operation::Tag::kElu:
      output_id = operation.get_elu()->output_operand_id;
      break;
    case mojom::Operation::Tag::kHardSigmoid:
      output_id = operation.get_hard_sigmoid()->output_operand_id;
      break;
    case mojom::Operation::Tag::kLeakyRelu:
      output_id = operation.get_leaky_relu()->output_operand_id;
      break;
    case mojom::Operation::Tag::kLinear:
      output_id = operation.get_linear()->output_operand_id;
      break;
    case mojom::Operation::Tag::kRelu:
      output_id = operation.get_relu()->output_operand_id;
      break;
    case mojom::Operation::Tag::kSigmoid:
      output_id = operation.get_sigmoid()->output_operand_id;
      break;
    case mojom::Operation::Tag::kSoftplus:
      output_id = operation.get_softplus()->output_operand_id;
      break;
    case mojom::Operation::Tag::kSoftsign:
      output_id = operation.get_softsign()->output_operand_id;
      break;
    case mojom::Operation::Tag::kTanh:
      output_id = operation.get_tanh()->output_operand_id;
      break;
    default:
      // Not a fusible activation: leave `output_id` empty.
      break;
  }
  return output_id;
}
// Returns `original_label`, falling back to `default_label` when the
// original is empty. The result views the caller's data; nothing is copied.
std::string_view GetOperatorLabel(std::string_view original_label,
                                  std::string_view default_label) {
  if (original_label.empty()) {
    return default_label;
  }
  return original_label;
}
// Returns the label of a fusible activation operation: the label stored on
// the operation when it is non-empty, otherwise a fixed per-kind default
// name. The view may point into `operation`, so it must not outlive it.
// Operations that are not fusible activations hit NOTREACHED().
std::string_view GetFusibleActivationLabel(const mojom::Operation& operation) {
  std::string_view user_label;
  std::string_view fallback_label;
  switch (operation.which()) {
    case mojom::Operation::Tag::kElu:
      user_label = operation.get_elu()->label;
      fallback_label = "elu";
      break;
    case mojom::Operation::Tag::kHardSigmoid:
      user_label = operation.get_hard_sigmoid()->label;
      fallback_label = "hard_sigmoid";
      break;
    case mojom::Operation::Tag::kLeakyRelu:
      user_label = operation.get_leaky_relu()->label;
      fallback_label = "leaky_relu";
      break;
    case mojom::Operation::Tag::kLinear:
      user_label = operation.get_linear()->label;
      fallback_label = "linear";
      break;
    case mojom::Operation::Tag::kRelu:
      user_label = operation.get_relu()->label;
      fallback_label = "relu";
      break;
    case mojom::Operation::Tag::kSigmoid:
      user_label = operation.get_sigmoid()->label;
      fallback_label = "sigmoid";
      break;
    case mojom::Operation::Tag::kSoftplus:
      user_label = operation.get_softplus()->label;
      fallback_label = "softplus";
      break;
    case mojom::Operation::Tag::kSoftsign:
      user_label = operation.get_softsign()->label;
      fallback_label = "softsign";
      break;
    case mojom::Operation::Tag::kTanh:
      user_label = operation.get_tanh()->label;
      fallback_label = "tanh";
      break;
    default:
      NOTREACHED() << "The operation is not a fusible activation.";
  }
  return GetOperatorLabel(user_label, fallback_label);
}
// Builds the label for an operator with a fused activation, in the form
// "<operator label>+<activation label>". The operator label is
// `original_label`, or `default_label` when the original is empty; the
// activation label comes from GetFusibleActivationLabel().
std::string GetFusedOperatorLabel(std::string_view original_label,
                                  std::string_view default_label,
                                  const mojom::Operation& fusible_activation) {
  const std::string_view operator_label =
      GetOperatorLabel(original_label, default_label);
  const std::string_view activation_label =
      GetFusibleActivationLabel(fusible_activation);
  return base::JoinString({operator_label, activation_label}, "+");
}
// Translates a fusible activation operation into the matching DirectML
// activation descriptor, copying across the parameters the operation carries
// (alpha for elu/leakyRelu, alpha and beta for hardSigmoid/linear). Softplus
// always uses a steepness of 1. Tensor pointers in the descriptor are left
// null. `activation` must be a fusible activation.
ActivationOperatorDesc CreateOperatorDescForFusibleActivation(
    const mojom::Operation& activation) {
  CHECK(GetFusibleActivationOutputId(activation));

  switch (activation.which()) {
    case mojom::Operation::Tag::kElu: {
      DML_ACTIVATION_ELU_OPERATOR_DESC elu_desc = {};
      elu_desc.Alpha = activation.get_elu()->alpha;
      return ActivationOperatorDesc{.desc = elu_desc};
    }
    case mojom::Operation::Tag::kHardSigmoid: {
      const auto& hard_sigmoid = activation.get_hard_sigmoid();
      DML_ACTIVATION_HARD_SIGMOID_OPERATOR_DESC hard_sigmoid_desc = {};
      hard_sigmoid_desc.Alpha = hard_sigmoid->alpha;
      hard_sigmoid_desc.Beta = hard_sigmoid->beta;
      return ActivationOperatorDesc{.desc = hard_sigmoid_desc};
    }
    case mojom::Operation::Tag::kLeakyRelu: {
      DML_ACTIVATION_LEAKY_RELU_OPERATOR_DESC leaky_relu_desc = {};
      leaky_relu_desc.Alpha = activation.get_leaky_relu()->alpha;
      return ActivationOperatorDesc{.desc = leaky_relu_desc};
    }
    case mojom::Operation::Tag::kLinear: {
      const auto& linear = activation.get_linear();
      DML_ACTIVATION_LINEAR_OPERATOR_DESC linear_desc = {};
      linear_desc.Alpha = linear->alpha;
      linear_desc.Beta = linear->beta;
      return ActivationOperatorDesc{.desc = linear_desc};
    }
    case mojom::Operation::Tag::kRelu: {
      return ActivationOperatorDesc{.desc =
                                        DML_ACTIVATION_RELU_OPERATOR_DESC{}};
    }
    case mojom::Operation::Tag::kSigmoid: {
      return ActivationOperatorDesc{
          .desc = DML_ACTIVATION_SIGMOID_OPERATOR_DESC{}};
    }
    case mojom::Operation::Tag::kSoftplus: {
      DML_ACTIVATION_SOFTPLUS_OPERATOR_DESC softplus_desc = {};
      softplus_desc.Steepness = 1.0f;
      return ActivationOperatorDesc{.desc = softplus_desc};
    }
    case mojom::Operation::Tag::kSoftsign: {
      return ActivationOperatorDesc{
          .desc = DML_ACTIVATION_SOFTSIGN_OPERATOR_DESC{}};
    }
    case mojom::Operation::Tag::kTanh: {
      return ActivationOperatorDesc{.desc =
                                        DML_ACTIVATION_TANH_OPERATOR_DESC{}};
    }
    default:
      NOTREACHED() << "The operation is not a fusible activation.";
  }
}
// The operand ids an operation reads and writes, as filled in by
// RetrieveOperationConnectivity().
struct OperationConnectivity {
// Ids of the operands the operation consumes (optional inputs are only
// listed when present).
std::vector<OperandId> input_ids;
// Ids of the operands the operation produces.
std::vector<OperandId> output_ids;
};
void RetrieveOperationConnectivity(
const Operation* operation,
OperationConnectivity& out_operation_connectivity) {
std::vector<OperandId>& input_ids = out_operation_connectivity.input_ids;
std::vector<OperandId>& output_ids = out_operation_connectivity.output_ids;
input_ids.clear();
output_ids.clear();
switch (operation->which()) {
case Operation::Tag::kArgMinMax: {
const auto& arg_min_max = operation->get_arg_min_max();
input_ids = {arg_min_max->input_operand_id};
output_ids = {arg_min_max->output_operand_id};
break;
}
case Operation::Tag::kBatchNormalization: {
const auto& batch_norm = operation->get_batch_normalization();
input_ids = {batch_norm->input_operand_id, batch_norm->mean_operand_id,
batch_norm->variance_operand_id};
auto& scale_operand_id = batch_norm->scale_operand_id;
if (scale_operand_id) {
input_ids.push_back(scale_operand_id.value());
}
auto& bias_operand_id = batch_norm->bias_operand_id;
if (bias_operand_id) {
input_ids.push_back(bias_operand_id.value());
}
output_ids = {batch_norm->output_operand_id};
break;
}
case Operation::Tag::kClamp: {
const auto& clamp = operation->get_clamp();
input_ids = {clamp->input_operand_id};
output_ids = {clamp->output_operand_id};
break;
}
case Operation::Tag::kConcat: {
const auto& concat = operation->get_concat();
input_ids = {concat->input_operand_ids};
output_ids = {concat->output_operand_id};
break;
}
case Operation::Tag::kConv2d: {
const auto& conv2d = operation->get_conv2d();
input_ids = {conv2d->input_operand_id, conv2d->filter_operand_id};
auto& bias_operand_id = conv2d->bias_operand_id;
if (bias_operand_id) {
input_ids.push_back(bias_operand_id.value());
}
output_ids = {conv2d->output_operand_id};
break;
}
case Operation::Tag::kCumulativeSum: {
const auto& cumulative_sum = operation->get_cumulative_sum();
input_ids = {cumulative_sum->input_operand_id};
output_ids = {cumulative_sum->output_operand_id};
break;
}
case Operation::Tag::kDequantizeLinear: {
const auto& dequantize_linear = operation->get_dequantize_linear();
input_ids = {dequantize_linear->input_operand_id,
dequantize_linear->scale_operand_id,
dequantize_linear->zero_point_operand_id};
output_ids = {dequantize_linear->output_operand_id};
break;
}
case Operation::Tag::kElementWiseBinary: {
const auto& binary = operation->get_element_wise_binary();
input_ids = {binary->lhs_operand_id, binary->rhs_operand_id};
output_ids = {binary->output_operand_id};
break;
}
case Operation::Tag::kElu: {
const auto& elu = operation->get_elu();
input_ids = {elu->input_operand_id};
output_ids = {elu->output_operand_id};
break;
}
case Operation::Tag::kElementWiseUnary: {
const auto& unary = operation->get_element_wise_unary();
input_ids = {unary->input_operand_id};
output_ids = {unary->output_operand_id};
break;
}
case Operation::Tag::kExpand: {
const auto& expand = operation->get_expand();
input_ids = {expand->input_operand_id};
output_ids = {expand->output_operand_id};
break;
}
case Operation::Tag::kGather: {
const auto& gather = operation->get_gather();
input_ids = {gather->input_operand_id, gather->indices_operand_id};
output_ids = {gather->output_operand_id};
break;
}
case Operation::Tag::kGatherElements: {
const auto& gather_elements = operation->get_gather_elements();
input_ids = {gather_elements->input_operand_id,
gather_elements->indices_operand_id};
output_ids = {gather_elements->output_operand_id};
break;
}
case Operation::Tag::kGatherNd: {
const auto& gather_nd = operation->get_gather_nd();
input_ids = {gather_nd->input_operand_id, gather_nd->indices_operand_id};
output_ids = {gather_nd->output_operand_id};
break;
}
case Operation::Tag::kGelu: {
const auto& gelu = operation->get_gelu();
input_ids = {gelu->input_operand_id};
output_ids = {gelu->output_operand_id};
break;
}
case Operation::Tag::kGemm: {
const auto& gemm = operation->get_gemm();
input_ids = {gemm->a_operand_id, gemm->b_operand_id};
auto& c_operand_id = gemm->c_operand_id;
if (c_operand_id) {
input_ids.push_back(c_operand_id.value());
}
output_ids = {gemm->output_operand_id};
break;
}
case Operation::Tag::kGru: {
const auto& gru = operation->get_gru();
input_ids = {gru->input_operand_id, gru->weight_operand_id,
gru->recurrent_weight_operand_id};
auto& bias_operand_id = gru->bias_operand_id;
if (bias_operand_id) {
input_ids.push_back(bias_operand_id.value());
}
auto& recurrent_bias_operand_id = gru->recurrent_bias_operand_id;
if (recurrent_bias_operand_id) {
input_ids.push_back(recurrent_bias_operand_id.value());
}
auto& initial_hidden_state_operand_id =
gru->initial_hidden_state_operand_id;
if (initial_hidden_state_operand_id) {
input_ids.push_back(initial_hidden_state_operand_id.value());
}
output_ids = {gru->output_operand_ids};
break;
}
case Operation::Tag::kGruCell: {
const auto& gru_cell = operation->get_gru_cell();
input_ids = {gru_cell->input_operand_id, gru_cell->weight_operand_id,
gru_cell->recurrent_weight_operand_id,
gru_cell->hidden_state_operand_id};
auto& bias_operand_id = gru_cell->bias_operand_id;
if (bias_operand_id) {
input_ids.push_back(bias_operand_id.value());
}
auto& recurrent_bias_operand_id = gru_cell->recurrent_bias_operand_id;
if (recurrent_bias_operand_id) {
input_ids.push_back(recurrent_bias_operand_id.value());
}
output_ids = {gru_cell->output_operand_id};
break;
}
case Operation::Tag::kHardSigmoid: {
const auto& hard_sgmoid = operation->get_hard_sigmoid();
input_ids = {hard_sgmoid->input_operand_id};
output_ids = {hard_sgmoid->output_operand_id};
break;
}
case Operation::Tag::kHardSwish: {
const auto& hard_swish = operation->get_hard_swish();
input_ids = {hard_swish->input_operand_id};
output_ids = {hard_swish->output_operand_id};
break;
}
case Operation::Tag::kInstanceNormalization: {
const auto& instance_norm = operation->get_instance_normalization();
input_ids = {instance_norm->input_operand_id};
auto& scale_operand_id = instance_norm->scale_operand_id;
if (scale_operand_id) {
input_ids.push_back(scale_operand_id.value());
}
auto& bias_operand_id = instance_norm->bias_operand_id;
if (bias_operand_id) {
input_ids.push_back(bias_operand_id.value());
}
output_ids = {instance_norm->output_operand_id};
break;
}
case Operation::Tag::kLayerNormalization: {
const auto& layer_norm = operation->get_layer_normalization();
input_ids = {layer_norm->input_operand_id};
auto& scale_operand_id = layer_norm->scale_operand_id;
if (scale_operand_id) {
input_ids.push_back(scale_operand_id.value());
}
auto& bias_operand_id = layer_norm->bias_operand_id;
if (bias_operand_id) {
input_ids.push_back(bias_operand_id.value());
}
output_ids = {layer_norm->output_operand_id};
break;
}
case Operation::Tag::kLeakyRelu: {
const auto& leaky_relu = operation->get_leaky_relu();
input_ids = {leaky_relu->input_operand_id};
output_ids = {leaky_relu->output_operand_id};
break;
}
case Operation::Tag::kLinear: {
const auto& linear = operation->get_linear();
input_ids = {linear->input_operand_id};
output_ids = {linear->output_operand_id};
break;
}
case Operation::Tag::kLstm: {
const auto& lstm = operation->get_lstm();
input_ids = {lstm->input_operand_id, lstm->weight_operand_id,
lstm->recurrent_weight_operand_id};
auto& bias_operand_id = lstm->bias_operand_id;
if (bias_operand_id) {
input_ids.push_back(bias_operand_id.value());
}
auto& recurrent_bias_operand_id = lstm->recurrent_bias_operand_id;
if (recurrent_bias_operand_id) {
input_ids.push_back(recurrent_bias_operand_id.value());
}
auto& peephole_weight_operand_id = lstm->peephole_weight_operand_id;
if (peephole_weight_operand_id) {
input_ids.push_back(peephole_weight_operand_id.value());
}
auto& initial_hidden_state_operand_id =
lstm->initial_hidden_state_operand_id;
if (initial_hidden_state_operand_id) {
input_ids.push_back(initial_hidden_state_operand_id.value());
}
auto& initial_cell_state_operand_id = lstm->initial_cell_state_operand_id;
if (initial_cell_state_operand_id) {
input_ids.push_back(initial_cell_state_operand_id.value());
}
output_ids = {lstm->output_operand_ids};
break;
}
case Operation::Tag::kLstmCell: {
const auto& lstm_cell = operation->get_lstm_cell();
input_ids = {lstm_cell->input_operand_id, lstm_cell->weight_operand_id,
lstm_cell->recurrent_weight_operand_id,
lstm_cell->hidden_state_operand_id,
lstm_cell->cell_state_operand_id};
auto& bias_operand_id = lstm_cell->bias_operand_id;
if (bias_operand_id) {
input_ids.push_back(bias_operand_id.value());
}
auto& recurrent_bias_operand_id = lstm_cell->recurrent_bias_operand_id;
if (recurrent_bias_operand_id) {
input_ids.push_back(recurrent_bias_operand_id.value());
}
auto& peephole_weight_operand_id = lstm_cell->peephole_weight_operand_id;
if (peephole_weight_operand_id) {
input_ids.push_back(peephole_weight_operand_id.value());
}
output_ids = {lstm_cell->output_operand_ids};
break;
}
case Operation::Tag::kMatmul: {
const auto& matmul = operation->get_matmul();
input_ids = {matmul->a_operand_id, matmul->b_operand_id};
output_ids = {matmul->output_operand_id};
break;
}
case Operation::Tag::kPad: {
const auto& pad = operation->get_pad();
input_ids = {pad->input_operand_id};
output_ids = {pad->output_operand_id};
break;
}
case Operation::Tag::kPool2d: {
const auto& pool2d = operation->get_pool2d();
input_ids = {pool2d->input_operand_id};
output_ids = {pool2d->output_operand_id};
break;
}
case Operation::Tag::kPrelu: {
const auto& prelu = operation->get_prelu();
input_ids = {prelu->input_operand_id, prelu->slope_operand_id};
output_ids = {prelu->output_operand_id};
break;
}
case Operation::Tag::kQuantizeLinear: {
const auto& quantize_linear = operation->get_quantize_linear();
input_ids = {quantize_linear->input_operand_id,
quantize_linear->scale_operand_id,
quantize_linear->zero_point_operand_id};
output_ids = {quantize_linear->output_operand_id};
break;
}
case Operation::Tag::kReduce: {
const auto& reduce = operation->get_reduce();
input_ids = {reduce->input_operand_id};
output_ids = {reduce->output_operand_id};
break;
}
case Operation::Tag::kRelu: {
const auto& relu = operation->get_relu();
input_ids = {relu->input_operand_id};
output_ids = {relu->output_operand_id};
break;
}
case Operation::Tag::kResample2d: {
const auto& resample2d = operation->get_resample2d();
input_ids = {resample2d->input_operand_id};
output_ids = {resample2d->output_operand_id};
break;
}
case Operation::Tag::kReshape: {
const auto& reshape = operation->get_reshape();
input_ids = {reshape->input_operand_id};
output_ids = {reshape->output_operand_id};
break;
}
case Operation::Tag::kReverse: {
const auto& reverse = operation->get_reverse();
input_ids = {reverse->input_operand_id};
output_ids = {reverse->output_operand_id};
break;
}
case Operation::Tag::kScatterElements: {
const auto& scatter_elements = operation->get_scatter_elements();
input_ids = {scatter_elements->input_operand_id,
scatter_elements->indices_operand_id,
scatter_elements->updates_operand_id};
output_ids = {scatter_elements->output_operand_id};
break;
}
case Operation::Tag::kScatterNd: {
const auto& scatter_nd = operation->get_scatter_nd();
input_ids = {scatter_nd->input_operand_id, scatter_nd->indices_operand_id,
scatter_nd->updates_operand_id};
output_ids = {scatter_nd->output_operand_id};
break;
}
case Operation::Tag::kSigmoid: {
const auto& sigmoid = operation->get_sigmoid();
input_ids = {sigmoid->input_operand_id};
output_ids = {sigmoid->output_operand_id};
break;
}
case Operation::Tag::kSlice: {
const auto& slice = operation->get_slice();
input_ids = {slice->input_operand_id};
output_ids = {slice->output_operand_id};
break;
}
case Operation::Tag::kSoftmax: {
const auto& softmax = operation->get_softmax();
input_ids = {softmax->input_operand_id};
output_ids = {softmax->output_operand_id};
break;
}
case Operation::Tag::kSoftplus: {
const auto& softplus = operation->get_softplus();
input_ids = {softplus->input_operand_id};
output_ids = {softplus->output_operand_id};
break;
}
case Operation::Tag::kSoftsign: {
const auto& softsign = operation->get_softsign();
input_ids = {softsign->input_operand_id};
output_ids = {softsign->output_operand_id};
break;
}
case Operation::Tag::kSplit: {
const auto& split = operation->get_split();
input_ids = {split->input_operand_id};
output_ids = {split->output_operand_ids};
break;
}
case Operation::Tag::kTanh: {
const auto& tanh = operation->get_tanh();
input_ids = {tanh->input_operand_id};
output_ids = {tanh->output_operand_id};
break;
}
case Operation::Tag::kTile: {
const auto& tile = operation->get_tile();
input_ids = {tile->input_operand_id};
output_ids = {tile->output_operand_id};
break;
}
case Operation::Tag::kTranspose: {
const auto& transpose = operation->get_transpose();
input_ids = {transpose->input_operand_id};
output_ids = {transpose->output_operand_id};
break;
}
case Operation::Tag::kTriangular: {
const auto& triangular = operation->get_triangular();
input_ids = {triangular->input_operand_id};
output_ids = {triangular->output_operand_id};
break;
}
case Operation::Tag::kWhere: {
const auto& where = operation->get_where();
input_ids = {where->condition_operand_id, where->true_value_operand_id,
where->false_value_operand_id};
output_ids = {where->output_operand_id};
break;
}
}
}
struct GraphFusionInfo {
absl::flat_hash_map<const Operation*,
raw_ptr<const Operation, CtnExperimental>>
operation_to_fusible_standalone_activation_map;
absl::flat_hash_map<OperandId, raw_ptr<const Operation, CtnExperimental>>
output_id_to_fusible_transpose_map;
absl::flat_hash_set<const Operation*> fusible_operations_set;
};
// Scans the operations of `graph_info` and records which standalone
// activations and which transposes can be fused into a neighbouring operation.
// Returns an empty `GraphFusionInfo` when the `kApplyGraphFusion` feature is
// disabled.
GraphFusionInfo GetGraphFusionInfo(const mojom::GraphInfoPtr& graph_info) {
if (!base::FeatureList::IsEnabled(kApplyGraphFusion)) {
return GraphFusionInfo();
}
// Fusible activations visited so far, keyed by their single input operand id.
absl::flat_hash_map<OperandId, raw_ptr<const Operation, CtnExperimental>>
input_id_to_activation_map;
// Matmul operations visited so far, keyed by each of their two input operand
// ids. `try_emplace` keeps the entry inserted first when several matmuls read
// the same operand.
absl::flat_hash_map<OperandId, const Operation*> input_id_to_matmul_map;
GraphFusionInfo graph_fusion_info;
// Use count per operand, indexed by the operand id value. A use is either an
// appearance in the graph outputs or being an input of an operation that has
// already been visited by the loop below.
base::FixedArray<uint32_t> operand_id_to_use_count_map(
graph_info->operands.size(), 0);
for (OperandId graph_output_id : graph_info->output_operands) {
++operand_id_to_use_count_map[graph_output_id.value()];
}
// Reused for every operation; refilled by `RetrieveOperationConnectivity()`.
OperationConnectivity operation_connectivity;
// Walk the operations from the last one to the first one.
// NOTE(review): presumably the operations are stored in topological order, so
// all consumers of an operand are counted before its producer is examined -
// confirm.
for (size_t operation_index = graph_info->operations.size();
operation_index-- > 0;) {
const auto& operation = graph_info->operations[operation_index];
RetrieveOperationConnectivity(
operation.get(),
operation_connectivity);
// The use counts must be updated before the fusion checks below read them.
for (OperandId input_id : operation_connectivity.input_ids) {
++operand_id_to_use_count_map[input_id.value()];
}
if (GetFusibleActivationOutputId(*operation)) {
// Remember this activation so that the operation producing its input can
// pick it up later in the walk.
CHECK_EQ(operation_connectivity.input_ids.size(), 1U);
input_id_to_activation_map.try_emplace(
operation_connectivity.input_ids[0], operation.get());
} else if (CanFuseStandaloneActivation(operation.get(),
graph_info->operands)) {
CHECK_EQ(operation_connectivity.output_ids.size(), 1U);
OperandId output_id = operation_connectivity.output_ids[0];
const auto activation_iterator =
input_id_to_activation_map.find(output_id);
// Fuse only when the output has exactly one use and that use is a recorded
// fusible activation.
if (operand_id_to_use_count_map[output_id.value()] == 1 &&
activation_iterator != input_id_to_activation_map.end()) {
const auto* activation = activation_iterator->second.get();
graph_fusion_info.fusible_operations_set.insert(activation);
graph_fusion_info
.operation_to_fusible_standalone_activation_map[operation.get()] =
activation;
}
}
switch (operation->which()) {
case Operation::Tag::kMatmul: {
// Register the matmul under both of its inputs so that a transpose
// producing either of them can be detected.
CHECK_EQ(operation_connectivity.input_ids.size(), 2U);
input_id_to_matmul_map.try_emplace(operation_connectivity.input_ids[0],
operation.get());
input_id_to_matmul_map.try_emplace(operation_connectivity.input_ids[1],
operation.get());
break;
}
case Operation::Tag::kTranspose: {
CHECK_EQ(operation_connectivity.output_ids.size(), 1U);
OperandId output_id = operation_connectivity.output_ids[0];
// Candidate only if the output feeds a recorded matmul and has exactly one
// use.
if (!input_id_to_matmul_map.contains(output_id) ||
operand_id_to_use_count_map[output_id.value()] != 1) {
break;
}
const mojom::TransposePtr& transpose = operation->get_transpose();
const mojom::Operand& input_operand =
GetOperand(graph_info->operands, transpose->input_operand_id);
uint32_t input_rank = input_operand.descriptor.shape().size();
// Swapping the last two axes needs at least two dimensions.
if (input_rank < 2) {
break;
}
// Build the permutation that is the identity except for the last two axes
// being swapped; only a transpose with exactly this permutation is fusible.
std::vector<uint32_t> swap_last_two_axes(input_rank);
std::iota(swap_last_two_axes.begin(), swap_last_two_axes.end(), 0);
std::swap(swap_last_two_axes[input_rank - 2],
swap_last_two_axes[input_rank - 1]);
if (swap_last_two_axes == transpose->permutation) {
graph_fusion_info.fusible_operations_set.insert(operation.get());
graph_fusion_info.output_id_to_fusible_transpose_map[output_id] =
operation.get();
}
break;
}
default: {
break;
}
}
}
// The set must contain exactly one operation per entry of the two maps.
CHECK_EQ(
graph_fusion_info.operation_to_fusible_standalone_activation_map.size() +
graph_fusion_info.output_id_to_fusible_transpose_map.size(),
graph_fusion_info.fusible_operations_set.size());
return graph_fusion_info;
}
// Adds a DML batch normalization operator node for `operation` to
// `graph_builder` and registers its output in `id_to_node_output_map`.
//
// The mean, variance, scale and bias tensors must have rank 1; each of them is
// made broadcast compatible with the input along `axis`. When the operation
// has no scale or bias operand, a constant operand (value 1.0 for scale, 0 for
// bias) is built and added to the graph instead. When a standalone activation
// is recorded for `operation` in
// `operation_to_fusible_standalone_activation_map`, it is fused into the node
// and the node output is registered under the activation's output operand id.
void CreateOperatorNodeForBatchNormalization(
    Adapter* adapter,
    const ContextProperties& context_properties,
    const Operation* operation,
    const absl::flat_hash_map<const Operation*,
                              raw_ptr<const Operation, CtnExperimental>>&
        operation_to_fusible_standalone_activation_map,
    mojom::GraphInfoPtr& graph_info,
    base::flat_map<OperandId, std::unique_ptr<WebNNConstantOperand>>&
        constant_operands,
    GraphBuilderDml& graph_builder,
    IdToNodeOutputMap& id_to_node_output_map,
    absl::flat_hash_map<OperandId, uint32_t>& constant_id_to_input_index_map) {
  const auto& batch_normalization = operation->get_batch_normalization();
  const auto& operands = graph_info->operands;

  OperandId input_id = batch_normalization->input_operand_id;
  const Operand& input_operand = GetOperand(operands, input_id);
  CHECK(context_properties.data_type_limits.batch_normalization_input.Supports(
      input_operand.descriptor));
  const NodeOutput* input =
      GetNodeOutputForOperand(id_to_node_output_map, input_id);
  const TensorDesc& input_tensor_desc = input->GetTensorDesc();
  const auto input_rank = input_tensor_desc.GetDimensions().size();

  OperandId output_id = batch_normalization->output_operand_id;
  const Operand& output_operand = GetOperand(operands, output_id);
  OperandDataType data_type = output_operand.descriptor.data_type();
  CHECK(context_properties.data_type_limits.batch_normalization_input.data_types
            .Has(data_type));
  // Deliberately not `const`: the descriptor is moved into the node output at
  // the end of this function, and moving from a `const` object would copy.
  TensorDesc output_tensor_desc(GetTensorDataType(data_type),
                                output_operand.descriptor.shape());

  const NodeOutput* mean = GetNodeOutputForOperand(
      id_to_node_output_map, batch_normalization->mean_operand_id);
  auto mean_tensor_desc = mean->GetTensorDesc();
  auto mean_rank = mean_tensor_desc.GetDimensions().size();
  CHECK_EQ(mean_rank, 1U);

  // All rank-1 parameter tensors are broadcast along the same single axis.
  auto axis = batch_normalization->axis;
  uint32_t axes[1] = {axis};
  mean_tensor_desc.MakeBroadcastCompatible(input_rank, axes);

  const NodeOutput* variance = GetNodeOutputForOperand(
      id_to_node_output_map, batch_normalization->variance_operand_id);
  auto variance_tensor_desc = variance->GetTensorDesc();
  auto variance_rank = variance_tensor_desc.GetDimensions().size();
  CHECK_EQ(variance_rank, 1U);
  variance_tensor_desc.MakeBroadcastCompatible(input_rank, axes);

  // Use the provided scale operand, or build a constant with value 1.0.
  OperandId scale_operand_id;
  if (batch_normalization->scale_operand_id.has_value()) {
    scale_operand_id = batch_normalization->scale_operand_id.value();
  } else {
    scale_operand_id = BuildConstantOperandForFloatValue(
        context_properties, graph_info, constant_operands, data_type,
        1, 1.0);
    CreateConstantNode(adapter, scale_operand_id, constant_operands,
                       graph_builder, id_to_node_output_map,
                       constant_id_to_input_index_map);
  }
  const NodeOutput* scale =
      GetNodeOutputForOperand(id_to_node_output_map, scale_operand_id);
  auto scale_tensor_desc = scale->GetTensorDesc();
  auto scale_rank = scale_tensor_desc.GetDimensions().size();
  CHECK_EQ(scale_rank, 1U);
  scale_tensor_desc.MakeBroadcastCompatible(input_rank, axes);

  // Use the provided bias operand, or build a constant with value 0.
  OperandId bias_operand_id;
  if (batch_normalization->bias_operand_id.has_value()) {
    bias_operand_id = batch_normalization->bias_operand_id.value();
  } else {
    bias_operand_id = BuildConstantOperandForFloatValue(
        context_properties, graph_info, constant_operands, data_type,
        1, 0);
    CreateConstantNode(adapter, bias_operand_id, constant_operands,
                       graph_builder, id_to_node_output_map,
                       constant_id_to_input_index_map);
  }
  const NodeOutput* bias =
      GetNodeOutputForOperand(id_to_node_output_map, bias_operand_id);
  auto bias_tensor_desc = bias->GetTensorDesc();
  auto bias_rank = bias_tensor_desc.GetDimensions().size();
  CHECK_EQ(bias_rank, 1U);
  bias_tensor_desc.MakeBroadcastCompatible(input_rank, axes);

  std::array<const NodeOutput*, 5> inputs = {input, mean, variance, scale,
                                             bias};

  // Fuse the following standalone activation, if one was recorded. The fused
  // node then produces the activation's output operand.
  std::optional<const Operation*> fusible_activation =
      GetFusibleActivationFromOperation(
          operation_to_fusible_standalone_activation_map, operation);
  std::optional<ActivationOperatorDesc> activation_operator_desc;
  std::optional<DML_OPERATOR_DESC> activation_dml_desc;
  std::string label = batch_normalization->label;
  if (fusible_activation) {
    activation_operator_desc =
        CreateOperatorDescForFusibleActivation(*fusible_activation.value());
    output_id =
        GetFusibleActivationOutputId(*fusible_activation.value()).value();
    activation_dml_desc = activation_operator_desc->GetActivationDmlDesc();
    label = GetFusedOperatorLabel(label, "batch_normalization",
                                  *fusible_activation.value());
  }

  DML_BATCH_NORMALIZATION_OPERATOR_DESC batch_normalization_operator_desc{
      .InputTensor = &input_tensor_desc.GetDMLTensorDesc(),
      .MeanTensor = &mean_tensor_desc.GetDMLTensorDesc(),
      .VarianceTensor = &variance_tensor_desc.GetDMLTensorDesc(),
      .ScaleTensor = &scale_tensor_desc.GetDMLTensorDesc(),
      .BiasTensor = &bias_tensor_desc.GetDMLTensorDesc(),
      .OutputTensor = &output_tensor_desc.GetDMLTensorDesc(),
      .Spatial = true,
      .Epsilon = batch_normalization->epsilon,
      .FusedActivation =
          activation_dml_desc ? &activation_dml_desc.value() : nullptr,
  };

  const GraphNode* batch_normalization_node = graph_builder.CreateOperatorNode(
      DML_OPERATOR_BATCH_NORMALIZATION, &batch_normalization_operator_desc,
      inputs, label);
  // The operator desc (which points into `output_tensor_desc`) is no longer
  // needed here, so the descriptor can be moved into the node output.
  const NodeOutput* output = graph_builder.CreateNodeOutput(
      batch_normalization_node, std::move(output_tensor_desc), 0);
  CHECK(id_to_node_output_map.try_emplace(output_id, output).second);
}
// Adds a DML clip operator node for `clamp` to `graph_builder` and registers
// its output in `id_to_node_output_map`.
//
// DML_OPERATOR_ELEMENT_WISE_CLIP1 is used when the adapter supports DML
// feature level 5.0; otherwise DML_OPERATOR_ELEMENT_WISE_CLIP is used, with
// the min/max values read from the clamp according to the output data type.
void CreateOperatorNodeForClamp(Adapter* adapter,
                                const ContextProperties& context_properties,
                                const std::vector<OperandPtr>& operands,
                                const mojom::ClampPtr& clamp,
                                GraphBuilderDml& graph_builder,
                                IdToNodeOutputMap& id_to_node_output_map) {
  const NodeOutput* clamp_input =
      GetNodeOutputForOperand(id_to_node_output_map, clamp->input_operand_id);
  const auto& source_desc = clamp_input->GetTensorDesc();
  CHECK(context_properties.data_type_limits.clamp_input.data_types.Has(
      DmlDataTypeToOperand(source_desc.GetDataType())));

  OperandId output_id = clamp->output_operand_id;
  auto result_desc = CreateOutputTensorDesc(operands, output_id);
  const auto result_data_type = result_desc.GetDataType();

  const GraphNode* clamp_node = nullptr;
  std::array<const NodeOutput*, 1> node_inputs = {clamp_input};
  if (!adapter->IsDMLFeatureLevelSupported(DML_FEATURE_LEVEL_5_0)) {
    // Older feature levels: the CLIP operator stores Min/Max directly in the
    // operator desc.
    DML_ELEMENT_WISE_CLIP_OPERATOR_DESC clip_desc = {};
    clip_desc.InputTensor = &source_desc.GetDMLTensorDesc();
    clip_desc.OutputTensor = &result_desc.GetDMLTensorDesc();
    clip_desc.ScaleBias = nullptr;
    switch (result_data_type) {
      case DML_TENSOR_DATA_TYPE_FLOAT32:
        clip_desc.Min = clamp->min_value.AsFloat32();
        clip_desc.Max = clamp->max_value.AsFloat32();
        break;
      case DML_TENSOR_DATA_TYPE_FLOAT16:
        // NOTE(review): confirm that the value returned by AsFloat16() is
        // meant to be assigned to Min/Max without further conversion.
        clip_desc.Min = clamp->min_value.AsFloat16();
        clip_desc.Max = clamp->max_value.AsFloat16();
        break;
      case DML_TENSOR_DATA_TYPE_INT8:
        clip_desc.Min = clamp->min_value.AsInt8();
        clip_desc.Max = clamp->max_value.AsInt8();
        break;
      case DML_TENSOR_DATA_TYPE_UINT8:
        clip_desc.Min = clamp->min_value.AsUint8();
        clip_desc.Max = clamp->max_value.AsUint8();
        break;
      case DML_TENSOR_DATA_TYPE_INT32:
        clip_desc.Min = clamp->min_value.AsInt32();
        clip_desc.Max = clamp->max_value.AsInt32();
        break;
      case DML_TENSOR_DATA_TYPE_UINT32:
        clip_desc.Min = clamp->min_value.AsUint32();
        clip_desc.Max = clamp->max_value.AsUint32();
        break;
      default:
        NOTREACHED() << "[WebNN] This data type is not supported.";
    }
    clamp_node = graph_builder.CreateOperatorNode(
        DML_OPERATOR_ELEMENT_WISE_CLIP, &clip_desc, node_inputs, clamp->label);
  } else {
    // Feature level 5.0 and above: CLIP1 takes Min/Max as scalar unions typed
    // by the output data type.
    DML_ELEMENT_WISE_CLIP1_OPERATOR_DESC clip1_desc{
        .InputTensor = &source_desc.GetDMLTensorDesc(),
        .OutputTensor = &result_desc.GetDMLTensorDesc(),
        .ScaleBias = nullptr,
        .MinMaxDataType = result_data_type,
        .Min = ToScalarUnion(clamp->min_value, result_data_type),
        .Max = ToScalarUnion(clamp->max_value, result_data_type)};
    clamp_node = graph_builder.CreateOperatorNode(
        DML_OPERATOR_ELEMENT_WISE_CLIP1, &clip1_desc, node_inputs,
        clamp->label);
  }

  const NodeOutput* clamp_output =
      graph_builder.CreateNodeOutput(clamp_node, std::move(result_desc), 0);
  CHECK(id_to_node_output_map.try_emplace(output_id, clamp_output).second);
}
// Adds a DML join operator node that concatenates the inputs of `concat`
// along `concat->axis`, and registers its output in `id_to_node_output_map`.
void CreateOperatorNodeForConcat(const ContextProperties& context_properties,
                                 const std::vector<OperandPtr>& operands,
                                 const mojom::ConcatPtr& concat,
                                 GraphBuilderDml& graph_builder,
                                 IdToNodeOutputMap& id_to_node_output_map) {
  const auto& input_operand_ids = concat->input_operand_ids;

  // Every input must satisfy the concat data type limits of this context.
  const auto is_supported_input = [&](OperandId input_operand_id) {
    return context_properties.data_type_limits.concat_inputs.Supports(
        GetOperand(operands, input_operand_id).descriptor);
  };
  CHECK(std::ranges::all_of(input_operand_ids, is_supported_input));

  // Gather the node outputs and the DML tensor descs of all inputs, in order.
  const size_t input_count = input_operand_ids.size();
  base::FixedArray<const NodeOutput*> node_inputs(input_count);
  base::FixedArray<DML_TENSOR_DESC> dml_input_descs(input_count);
  size_t slot = 0;
  for (const auto& input_operand_id : input_operand_ids) {
    const NodeOutput* node_output =
        GetNodeOutputForOperand(id_to_node_output_map, input_operand_id);
    node_inputs[slot] = node_output;
    dml_input_descs[slot] = node_output->GetTensorDesc().GetDMLTensorDesc();
    ++slot;
  }

  OperandId output_id = concat->output_operand_id;
  auto result_desc = CreateOutputTensorDesc(operands, output_id);

  DML_JOIN_OPERATOR_DESC join_desc{
      .InputCount = base::checked_cast<uint32_t>(dml_input_descs.size()),
      .InputTensors = dml_input_descs.data(),
      .OutputTensor = &result_desc.GetDMLTensorDesc(),
      .Axis = concat->axis};
  const GraphNode* join_node = graph_builder.CreateOperatorNode(
      DML_OPERATOR_JOIN, &join_desc, node_inputs, concat->label);

  const NodeOutput* join_output =
      graph_builder.CreateNodeOutput(join_node, std::move(result_desc), 0);
  CHECK(id_to_node_output_map.try_emplace(output_id, join_output).second);
}
// Adds a DML convolution operator node for the conv2d carried by `operation`
// to `graph_builder` and registers its output in `id_to_node_output_map`.
//
// A direct conv2d maps to the forward convolution direction and a transposed
// one to the backward direction. An optional bias is passed to DML with shape
// {1, bias_dims[0], 1, 1}. When a standalone activation is recorded for
// `operation` in `operation_to_fusible_standalone_activation_map`, it is fused
// into the node and the node output is registered under the activation's
// output operand id.
void CreateOperatorNodeForConv2d(
    const ContextProperties& context_properties,
    const std::vector<OperandPtr>& operands,
    const Operation* operation,
    const absl::flat_hash_map<const Operation*,
                              raw_ptr<const Operation, CtnExperimental>>&
        operation_to_fusible_standalone_activation_map,
    GraphBuilderDml& graph_builder,
    IdToNodeOutputMap& id_to_node_output_map) {
  const auto& conv2d = operation->get_conv2d();
  const Operand& input_operand = GetOperand(operands, conv2d->input_operand_id);
  const Operand& filter_operand =
      GetOperand(operands, conv2d->filter_operand_id);

  // Validate the input/filter data types and pick the DML direction.
  DML_CONVOLUTION_DIRECTION conv2d_direction;
  switch (conv2d->kind) {
    case mojom::Conv2d::Kind::kDirect: {
      CHECK(context_properties.data_type_limits.conv2d_input.SupportsAll(
          {input_operand.descriptor, filter_operand.descriptor}));
      conv2d_direction =
          DML_CONVOLUTION_DIRECTION::DML_CONVOLUTION_DIRECTION_FORWARD;
      break;
    }
    case mojom::Conv2d::Kind::kTransposed: {
      CHECK(context_properties.data_type_limits.conv_transpose2d_input
                .SupportsAll(
                    {input_operand.descriptor, filter_operand.descriptor}));
      conv2d_direction =
          DML_CONVOLUTION_DIRECTION::DML_CONVOLUTION_DIRECTION_BACKWARD;
      break;
    }
  }

  OperandId output_id = conv2d->output_operand_id;
  auto output_tensor_desc = CreateOutputTensorDesc(operands, output_id);
  CHECK_EQ(output_tensor_desc.GetDimensions().size(), 4u);

  const NodeOutput* input =
      GetNodeOutputForOperand(id_to_node_output_map, conv2d->input_operand_id);
  auto input_tensor_desc = input->GetTensorDesc();
  const NodeOutput* filter =
      GetNodeOutputForOperand(id_to_node_output_map, conv2d->filter_operand_id);
  auto filter_tensor_desc = filter->GetTensorDesc();
  std::vector<const NodeOutput*> inputs = {input, filter};

  std::optional<TensorDesc> reshaped_bias_tensor_desc;
  auto& bias_operand_id = conv2d->bias_operand_id;
  if (bias_operand_id) {
    const Operand& bias_operand = GetOperand(operands, *bias_operand_id);
    if (conv2d->kind == mojom::Conv2d::Kind::kDirect) {
      CHECK(context_properties.data_type_limits.conv2d_bias.Supports(
          bias_operand.descriptor));
    } else {
      CHECK(context_properties.data_type_limits.conv_transpose2d_bias.Supports(
          bias_operand.descriptor));
    }

    const NodeOutput* bias_node_output =
        GetNodeOutputForOperand(id_to_node_output_map, bias_operand_id.value());
    const auto& bias_tensor_desc = bias_node_output->GetTensorDesc();
    const auto& bias_dims = bias_tensor_desc.GetDimensions();

    // Describe the bias as a 4-D tensor {1, bias_dims[0], 1, 1} and expose it
    // as an additional output of the node that produces the bias.
    std::vector<uint32_t> reshaped_bias_dims = {1, bias_dims[0], 1, 1};
    reshaped_bias_tensor_desc =
        TensorDesc(bias_tensor_desc.GetDataType(), bias_tensor_desc.GetFlags(),
                   std::move(reshaped_bias_dims));
    const NodeOutput* reshaped_bias_node_output =
        graph_builder.CreateNodeOutput(&bias_node_output->GetNode(),
                                       reshaped_bias_tensor_desc.value());
    inputs.push_back(reshaped_bias_node_output);
  }

  // Per-spatial-dimension attributes, ordered {height, width}.
  std::array<uint32_t, 2> strides = {conv2d->strides->height,
                                     conv2d->strides->width};
  std::array<uint32_t, 2> dilations = {conv2d->dilations->height,
                                       conv2d->dilations->width};
  std::array<uint32_t, 2> start_padding = {conv2d->padding->beginning->height,
                                           conv2d->padding->beginning->width};
  std::array<uint32_t, 2> end_padding = {conv2d->padding->ending->height,
                                         conv2d->padding->ending->width};
  std::array<uint32_t, 2> default_out_padding = {0, 0};

  // Fuse the following standalone activation, if one was recorded. The fused
  // node then produces the activation's output operand.
  std::optional<const Operation*> fusible_activation =
      GetFusibleActivationFromOperation(
          operation_to_fusible_standalone_activation_map, operation);
  std::optional<ActivationOperatorDesc> activation_operator_desc;
  std::optional<DML_OPERATOR_DESC> activation_dml_desc;
  std::string label = conv2d->label;
  if (fusible_activation) {
    activation_operator_desc =
        CreateOperatorDescForFusibleActivation(*fusible_activation.value());
    output_id =
        GetFusibleActivationOutputId(*fusible_activation.value()).value();
    activation_dml_desc = activation_operator_desc->GetActivationDmlDesc();
    label = GetFusedOperatorLabel(label, "conv2d", *fusible_activation.value());
  }

  DML_CONVOLUTION_OPERATOR_DESC conv2d_operator_desc{
      .InputTensor = &input_tensor_desc.GetDMLTensorDesc(),
      .FilterTensor = &filter_tensor_desc.GetDMLTensorDesc(),
      .BiasTensor = GetOptionalDmlTensorDescPtr(reshaped_bias_tensor_desc),
      .OutputTensor = &output_tensor_desc.GetDMLTensorDesc(),
      .Mode = DML_CONVOLUTION_MODE_CROSS_CORRELATION,
      .Direction = conv2d_direction,
      .DimensionCount =
          2u, /*Determines the size of the Strides, Dilations, StartPadding,
                 EndPadding, and OutputPadding arrays.*/
      .Strides = strides.data(),
      .Dilations = dilations.data(),
      .StartPadding = start_padding.data(),
      .EndPadding = end_padding.data(),
      .OutputPadding = default_out_padding.data(),
      .GroupCount = conv2d->groups,
      .FusedActivation =
          activation_dml_desc ? &activation_dml_desc.value() : nullptr,
  };

  const GraphNode* conv2d_node = graph_builder.CreateOperatorNode(
      DML_OPERATOR_CONVOLUTION, &conv2d_operator_desc, inputs, label);
  const NodeOutput* output = graph_builder.CreateNodeOutput(
      conv2d_node, std::move(output_tensor_desc), 0);
  CHECK(id_to_node_output_map.try_emplace(output_id, output).second);
}
// Adds a DML cumulative summation operator node for `cumulative_sum` to
// `graph_builder` and registers its output in `id_to_node_output_map`.
//
// `cumulative_sum->reversed` selects the decreasing axis direction and
// `cumulative_sum->exclusive` is forwarded as `HasExclusiveSum`.
void CreateOperatorNodeForCumulativeSum(
    const ContextProperties& context_properties,
    const std::vector<OperandPtr>& operands,
    const mojom::CumulativeSumPtr& cumulative_sum,
    GraphBuilderDml& graph_builder,
    IdToNodeOutputMap& id_to_node_output_map) {
  const NodeOutput* input = GetNodeOutputForOperand(
      id_to_node_output_map, cumulative_sum->input_operand_id);
  const auto& input_tensor_desc = input->GetTensorDesc();
  CHECK(context_properties.data_type_limits.cumulative_sum_input.data_types.Has(
      DmlDataTypeToOperand(input_tensor_desc.GetDataType())));

  OperandId output_id = cumulative_sum->output_operand_id;
  // Deliberately not `const`: the descriptor is moved into the node output
  // below, and moving from a `const` object would silently copy it.
  auto output_tensor_desc = CreateOutputTensorDesc(operands, output_id);

  const uint32_t axis = cumulative_sum->axis;
  DML_AXIS_DIRECTION axis_direction =
      cumulative_sum->reversed
          ? DML_AXIS_DIRECTION::DML_AXIS_DIRECTION_DECREASING
          : DML_AXIS_DIRECTION::DML_AXIS_DIRECTION_INCREASING;

  DML_CUMULATIVE_SUMMATION_OPERATOR_DESC cumulative_sum_operator_desc{
      .InputTensor = &input_tensor_desc.GetDMLTensorDesc(),
      .OutputTensor = &output_tensor_desc.GetDMLTensorDesc(),
      .Axis = axis,
      .AxisDirection = axis_direction,
      .HasExclusiveSum = cumulative_sum->exclusive};

  std::array<const NodeOutput*, 1> inputs = {input};
  const std::string& label = cumulative_sum->label;
  const GraphNode* cumulative_sum_node = graph_builder.CreateOperatorNode(
      DML_OPERATOR_CUMULATIVE_SUMMATION, &cumulative_sum_operator_desc, inputs,
      label);

  // The operator desc (which points into `output_tensor_desc`) is no longer
  // needed here, so the descriptor can be moved into the node output.
  const NodeOutput* output = graph_builder.CreateNodeOutput(
      cumulative_sum_node, std::move(output_tensor_desc));
  CHECK(id_to_node_output_map.try_emplace(output_id, output).second);
}
// Creates an operator node of type `kOperatorType` for a DML operator whose
// desc `OperatorDesc` has an `InputTensor` and an `OutputTensor` field. All
// other fields of the desc are left zero-initialized. `input` is the single
// input of the node.
template <typename OperatorDesc, DML_OPERATOR_TYPE kOperatorType>
const GraphNode* CreateUnaryOperator(const TensorDesc& input_tensor,
                                     const TensorDesc& output_tensor,
                                     const NodeOutput* input,
                                     GraphBuilderDml& graph_builder,
                                     std::string_view label = "") {
  OperatorDesc operator_desc = {};
  operator_desc.InputTensor = &input_tensor.GetDMLTensorDesc();
  operator_desc.OutputTensor = &output_tensor.GetDMLTensorDesc();

  std::array<const NodeOutput*, 1> node_inputs = {input};
  return graph_builder.CreateOperatorNode(kOperatorType, &operator_desc,
                                          node_inputs, label);
}
// Creates an operator node of type `operator_type` for a DML operator whose
// desc `OperatorDesc` has `ATensor`, `BTensor` and `OutputTensor` fields. All
// other fields of the desc are left zero-initialized. `inputs` are the inputs
// of the node.
template <typename OperatorDesc>
const GraphNode* CreateBinaryOperator(const TensorDesc& a_tensor,
                                      const TensorDesc& b_tensor,
                                      const TensorDesc& output_tensor,
                                      GraphBuilderDml& graph_builder,
                                      DML_OPERATOR_TYPE operator_type,
                                      base::span<const NodeOutput*> inputs,
                                      std::string_view label) {
  OperatorDesc operator_desc = {};
  operator_desc.ATensor = &a_tensor.GetDMLTensorDesc();
  operator_desc.BTensor = &b_tensor.GetDMLTensorDesc();
  operator_desc.OutputTensor = &output_tensor.GetDMLTensorDesc();

  return graph_builder.CreateOperatorNode(operator_type, &operator_desc, inputs,
                                          label);
}
// Appends a DML element-wise identity node after `input` and returns its
// output. The identity reads `input` through `input_tensor_desc` (or through
// the tensor desc of `input` itself when `input_tensor_desc` is null) and
// writes a tensor with the same data type and dimensions but with
// DML_TENSOR_FLAG_NONE as flags.
const NodeOutput* AppendIdentityNode(
    GraphBuilderDml& graph_builder,
    const NodeOutput* input,
    const TensorDesc* input_tensor_desc = nullptr) {
  CHECK(input);
  const TensorDesc& source_desc =
      input_tensor_desc ? *input_tensor_desc : input->GetTensorDesc();

  TensorDesc result_desc(source_desc.GetDataType(), DML_TENSOR_FLAG_NONE,
                         source_desc.GetDimensions());
  const GraphNode* identity_node =
      CreateUnaryOperator<DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC,
                          DML_OPERATOR_ELEMENT_WISE_IDENTITY>(
          source_desc, result_desc, input, graph_builder);
  return graph_builder.CreateNodeOutput(identity_node, std::move(result_desc));
}
// Reshapes `input` to `new_shape` by describing it with a tensor desc that
// keeps the data type and flags of `input` but uses `new_shape` as dimensions,
// and feeding that description through an identity node. Returns the output
// of the identity node.
const NodeOutput* CreateReshapeNode(GraphBuilderDml& graph_builder,
                                    const NodeOutput* input,
                                    base::span<const uint32_t> new_shape) {
  CHECK(input);
  const auto& original_desc = input->GetTensorDesc();
  std::vector<uint32_t> target_dimensions(new_shape.begin(), new_shape.end());
  const TensorDesc reshaped_desc(original_desc.GetDataType(),
                                 original_desc.GetFlags(),
                                 std::move(target_dimensions));
  return AppendIdentityNode(graph_builder, input, &reshaped_desc);
}
// Expands `input` to `new_shape` with a DML element-wise identity node and
// returns the output of that node. The input tensor desc is broadcast to
// `new_shape` when its dimensions differ; the output tensor desc has the input
// data type and `new_shape` as dimensions.
const NodeOutput* CreateExpandNode(GraphBuilderDml& graph_builder,
                                   const NodeOutput* input,
                                   base::span<const uint32_t> new_shape,
                                   std::string_view label) {
  CHECK(input);
  // Work on a copy so that broadcasting does not alter the desc of `input`.
  auto input_tensor_desc = input->GetTensorDesc();
  if (input_tensor_desc.GetDimensions() != new_shape) {
    input_tensor_desc.BroadcastTo(new_shape);
  }

  // Deliberately not `const`: the descriptor is moved into the node output
  // below, and moving from a `const` object would silently copy it.
  auto expand_tensor_desc =
      TensorDesc(input_tensor_desc.GetDataType(),
                 std::vector<uint32_t>(new_shape.begin(), new_shape.end()));
  const GraphNode* identity_node =
      CreateUnaryOperator<DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC,
                          DML_OPERATOR_ELEMENT_WISE_IDENTITY>(
          input_tensor_desc, expand_tensor_desc, input, graph_builder, label);

  const NodeOutput* expand_node = graph_builder.CreateNodeOutput(
      identity_node, std::move(expand_tensor_desc));
  return expand_node;
}
// Repeats every element of `node` along `axis` `block_size` times.
//
// The tensor is reshaped to {product of dimensions before `axis`,
// dimension at `axis`, 1, product of dimensions after `axis`}, expanded to
// `block_size` along the third dimension, and reshaped back to the input
// dimensions with the dimension at `axis` multiplied by `block_size`.
// Returns an error when one of the two products does not fit in a uint32_t.
base::expected<const NodeOutput*, mojom::ErrorPtr> BlockwiseExpandAlongAxis(
    const NodeOutput* node,
    GraphBuilderDml& graph_builder,
    uint32_t axis,
    uint32_t block_size,
    std::string_view label) {
  auto source_desc = node->GetTensorDesc();
  auto source_dims = source_desc.GetDimensions();

  // Overflow-checked product of a range of dimensions.
  const auto checked_product = [](auto first, auto last) {
    return std::accumulate(first, last, base::CheckedNumeric<uint32_t>(1),
                           std::multiplies());
  };
  const auto make_overflow_error = [&label]() {
    return base::unexpected(CreateError(
        mojom::Error::Code::kUnknownError,
        "The shape values are too large for block-wise quantization emulation.",
        label));
  };

  std::array<uint32_t, 4> collapsed_dims;

  base::CheckedNumeric<uint32_t> leading_product =
      checked_product(source_dims.begin(), source_dims.begin() + axis);
  if (!leading_product.IsValid()) {
    return make_overflow_error();
  }
  collapsed_dims[0] = leading_product.ValueOrDie();
  collapsed_dims[1] = source_dims[axis];
  // Placeholder dimension that is expanded to `block_size` below.
  collapsed_dims[2] = 1;

  base::CheckedNumeric<uint32_t> trailing_product =
      checked_product(source_dims.begin() + axis + 1, source_dims.end());
  if (!trailing_product.IsValid()) {
    return make_overflow_error();
  }
  collapsed_dims[3] = trailing_product.ValueOrDie();

  const NodeOutput* collapsed_node =
      CreateReshapeNode(graph_builder, node, collapsed_dims);

  auto expanded_dims = collapsed_dims;
  expanded_dims[2] = block_size;
  const NodeOutput* expanded_node =
      CreateExpandNode(graph_builder, collapsed_node, expanded_dims, label);

  // Merge the repeated elements back into the dimension at `axis`.
  auto result_dims = source_dims;
  result_dims[axis] = block_size * source_dims[axis];
  return CreateReshapeNode(graph_builder, expanded_node, result_dims);
}
// Builds the DirectML operator node for a WebNN quantizeLinear or
// dequantizeLinear operation and registers its output in
// `id_to_node_output_map`.
//
// `DML_OPERATOR_DESC` selects the descriptor layout that gets filled in:
//   * DML_QUANTIZE_OPERATOR_DESC / DML_DEQUANTIZE_OPERATOR_DESC receive scale
//     and zero point through an array of quantization tensors. When the input
//     rank is below 4, the input, scale, zero point and output descriptions
//     are all raised to rank 4.
//   * The DML_ELEMENT_WISE_*_LINEAR descriptors receive scale and zero point
//     as separate tensors. They are first expanded block-wise along every
//     axis where the input and scale extents differ and neither is 1, then
//     broadcast to the output shape.
//
// Returns the error produced by `BlockwiseExpandAlongAxis` if an expansion
// fails, and `base::ok()` otherwise.
template <typename DML_OPERATOR_DESC, typename DequantizeOrQuantizeLinearPtr>
  requires((std::is_same_v<DequantizeOrQuantizeLinearPtr,
                           mojom::DequantizeLinearPtr> ||
            std::is_same_v<DequantizeOrQuantizeLinearPtr,
                           mojom::QuantizeLinearPtr>) &&
           (std::is_same_v<DML_OPERATOR_DESC, DML_QUANTIZE_OPERATOR_DESC> ||
            std::is_same_v<DML_OPERATOR_DESC, DML_DEQUANTIZE_OPERATOR_DESC> ||
            std::is_same_v<DML_OPERATOR_DESC,
                           DML_ELEMENT_WISE_DEQUANTIZE_LINEAR_OPERATOR_DESC> ||
            std::is_same_v<DML_OPERATOR_DESC,
                           DML_ELEMENT_WISE_QUANTIZE_LINEAR_OPERATOR_DESC>))
base::expected<void, mojom::ErrorPtr>
CreateOperatorNodeForDequantizeOrQuantizeLinear(
    const ContextProperties& context_properties,
    const std::vector<OperandPtr>& operands,
    const DequantizeOrQuantizeLinearPtr& operation_ptr,
    GraphBuilderDml& graph_builder,
    DML_OPERATOR_TYPE operator_type,
    IdToNodeOutputMap& id_to_node_output_map) {
  // True for the descriptors that take scale/zero point as a tensor array.
  constexpr bool kTakesQuantizationTensorArray =
      std::is_same_v<DML_OPERATOR_DESC, DML_QUANTIZE_OPERATOR_DESC> ||
      std::is_same_v<DML_OPERATOR_DESC, DML_DEQUANTIZE_OPERATOR_DESC>;
  constexpr bool kIsDequantize =
      std::is_same_v<DequantizeOrQuantizeLinearPtr, mojom::DequantizeLinearPtr>;

  // Tensor descriptions are copied because they may be reshaped/broadcast
  // below without touching the producing nodes.
  const NodeOutput* data_node = GetNodeOutputForOperand(
      id_to_node_output_map, operation_ptr->input_operand_id);
  TensorDesc data_desc = data_node->GetTensorDesc();

  const NodeOutput* scale_node = GetNodeOutputForOperand(
      id_to_node_output_map, operation_ptr->scale_operand_id);
  TensorDesc scale_desc = scale_node->GetTensorDesc();

  const NodeOutput* zero_point_node = GetNodeOutputForOperand(
      id_to_node_output_map, operation_ptr->zero_point_operand_id);
  TensorDesc zero_point_desc = zero_point_node->GetTensorDesc();

  const OperandId result_id = operation_ptr->output_operand_id;
  const TensorDesc result_desc = CreateOutputTensorDesc(operands, result_id);
  // Only used by the tensor-array descriptors, which may need rank padding.
  TensorDesc padded_result_desc = result_desc;
  const auto& result_dims = result_desc.GetDimensions();
  const std::string& label = operation_ptr->label;

  if constexpr (kTakesQuantizationTensorArray) {
    const auto data_rank = data_desc.GetDimensions().size();
    const auto scale_rank = scale_desc.GetDimensions().size();
    CHECK_EQ(scale_rank, data_rank);
    if (data_rank < 4) {
      for (TensorDesc* desc :
           {&data_desc, &scale_desc, &zero_point_desc, &padded_result_desc}) {
        desc->EnsureMinimumRank(4, TensorDesc::Alignment::kTrailing);
      }
    }
  } else {
    const auto data_dims = data_desc.GetDimensions();
    auto scale_dims = scale_desc.GetDimensions();
    CHECK_EQ(data_dims.size(), scale_dims.size());
    for (size_t index = 0; index < scale_dims.size(); index++) {
      const bool needs_blockwise_expansion =
          data_dims[index] != scale_dims[index] && data_dims[index] != 1 &&
          scale_dims[index] != 1;
      if (!needs_blockwise_expansion) {
        continue;
      }

      uint32_t block_size = data_dims[index] / scale_dims[index];
      uint32_t axis = index;

      ASSIGN_OR_RETURN(scale_node,
                       BlockwiseExpandAlongAxis(scale_node, graph_builder,
                                                axis, block_size, label));
      scale_desc = scale_node->GetTensorDesc();
      scale_dims = scale_desc.GetDimensions();

      ASSIGN_OR_RETURN(zero_point_node,
                       BlockwiseExpandAlongAxis(zero_point_node, graph_builder,
                                                axis, block_size, label));
      zero_point_desc = zero_point_node->GetTensorDesc();
    }

    if (scale_desc.GetDimensions() != result_dims) {
      scale_desc.BroadcastTo(result_dims);
      zero_point_desc.BroadcastTo(result_dims);
    }
  }

  // The data types must already have been validated against the limits
  // advertised by the context.
  const auto& limits = context_properties.data_type_limits;
  if constexpr (kIsDequantize) {
    CHECK(limits.dequantize_linear_input.data_types.Has(
        DmlDataTypeToOperand(data_desc.GetDataType())));
    CHECK(limits.dequantize_linear_scale.data_types.Has(
        DmlDataTypeToOperand(scale_desc.GetDataType())));
    CHECK(limits.dequantize_linear_zero_point.data_types.Has(
        DmlDataTypeToOperand(zero_point_desc.GetDataType())));
  } else {
    CHECK(limits.quantize_linear_input.data_types.Has(
        DmlDataTypeToOperand(data_desc.GetDataType())));
    CHECK(limits.quantize_linear_input.data_types.Has(
        DmlDataTypeToOperand(scale_desc.GetDataType())));
    CHECK(limits.quantize_linear_zero_point.data_types.Has(
        DmlDataTypeToOperand(zero_point_desc.GetDataType())));
  }

  DML_OPERATOR_DESC operator_desc;
  // Declared at function scope: `operator_desc` may point into this array
  // until the operator node has been created.
  std::array<DML_TENSOR_DESC, 2> quantization_tensors = {
      scale_desc.GetDMLTensorDesc(), zero_point_desc.GetDMLTensorDesc()};
  if constexpr (kTakesQuantizationTensorArray) {
    operator_desc = {
        .InputTensor = &data_desc.GetDMLTensorDesc(),
        .QuantizationType = DML_QUANTIZATION_TYPE_SCALE_ZERO_POINT,
        .QuantizationTensorCount = quantization_tensors.size(),
        .QuantizationTensors = quantization_tensors.data(),
        .OutputTensor = &padded_result_desc.GetDMLTensorDesc()};
  } else {
    operator_desc = {.InputTensor = &data_desc.GetDMLTensorDesc(),
                     .ScaleTensor = &scale_desc.GetDMLTensorDesc(),
                     .ZeroPointTensor = &zero_point_desc.GetDMLTensorDesc(),
                     .OutputTensor = &result_desc.GetDMLTensorDesc()};
  }

  std::array<const NodeOutput*, 3> node_inputs = {data_node, scale_node,
                                                  zero_point_node};
  const GraphNode* quantize_node = graph_builder.CreateOperatorNode(
      operator_type, &operator_desc, node_inputs, label);

  // The registered output keeps the un-padded output description.
  const NodeOutput* result_node =
      graph_builder.CreateNodeOutput(quantize_node, std::move(result_desc));
  CHECK(id_to_node_output_map.try_emplace(result_id, result_node).second);
  return base::ok();
}
template <typename OperatorDesc,
DML_OPERATOR_TYPE operator_type,
typename Operation>
void CreateOperatorNodeForUnary(const std::vector<OperandPtr>& operands,
const Operation& operation,
GraphBuilderDml& graph_builder,
IdToNodeOutputMap& id_to_node_output_map) {
const NodeOutput* input = GetNodeOutputForOperand(
id_to_node_output_map, operation->input_operand_id);
const auto& input_tensor_desc = input->GetTensorDesc();
OperandId output_id = operation->output_operand_id;
const auto output_tensor_desc = CreateOutputTensorDesc(operands, output_id);
const GraphNode* unary_node =
CreateUnaryOperator<OperatorDesc, operator_type>(
input_tensor_desc, output_tensor_desc, input, graph_builder,
operation->label);
const NodeOutput* output = graph_builder.CreateNodeOutput(
unary_node, std::move(output_tensor_desc), 0);
CHECK(id_to_node_output_map.try_emplace(output_id, output).second);
}
// Creates the DirectML node for a WebNN element-wise binary operation and
// registers its output in `id_to_node_output_map`.
//
// Both inputs are broadcast to the output shape when their shapes differ from
// it. Two kinds need special handling:
//   * kAdd: when `operation_to_fusible_standalone_activation_map` yields an
//     activation for `operation`, an ADD1 operator with that activation fused
//     in is emitted, and the result is registered under the activation's
//     output operand id instead of the add's.
//   * kNotEqual: emitted as a LOGICAL_EQUALS node followed by a LOGICAL_NOT
//     node.
void CreateOperatorNodeForBinary(
    const ContextProperties& context_properties,
    const std::vector<OperandPtr>& operands,
    const Operation* operation,
    const absl::flat_hash_map<const Operation*,
                              raw_ptr<const Operation, CtnExperimental>>&
        operation_to_fusible_standalone_activation_map,
    GraphBuilderDml& graph_builder,
    IdToNodeOutputMap& id_to_node_output_map) {
  const auto& binary = operation->get_element_wise_binary();
  const auto& limits = context_properties.data_type_limits;

  // Descriptions are copied so they can be broadcast locally.
  const NodeOutput* lhs =
      GetNodeOutputForOperand(id_to_node_output_map, binary->lhs_operand_id);
  auto lhs_desc = lhs->GetTensorDesc();
  const NodeOutput* rhs =
      GetNodeOutputForOperand(id_to_node_output_map, binary->rhs_operand_id);
  auto rhs_desc = rhs->GetTensorDesc();

  // Not const: a fused add re-targets the result to the activation's output.
  OperandId result_id = binary->output_operand_id;
  const auto result_desc = CreateOutputTensorDesc(operands, result_id);
  auto result_dims = result_desc.GetDimensions();

  if (lhs_desc.GetDimensions() != result_dims) {
    lhs_desc.BroadcastTo(result_dims);
  }
  if (rhs_desc.GetDimensions() != result_dims) {
    rhs_desc.BroadcastTo(result_dims);
  }

  CHECK_EQ(lhs_desc.GetDataType(), rhs_desc.GetDataType());
  const OperandDataType operand_type =
      DmlDataTypeToOperand(lhs_desc.GetDataType());

  // Copied rather than referenced: the fused add path rewrites the label.
  std::string label = binary->label;
  const GraphNode* node = nullptr;
  std::array<const NodeOutput*, 2> node_inputs = {lhs, rhs};

  switch (binary->kind) {
    case mojom::ElementWiseBinary::Kind::kAdd: {
      CHECK(limits.add_input.data_types.Has(operand_type));
      std::optional<const Operation*> fusible_activation =
          GetFusibleActivationFromOperation(
              operation_to_fusible_standalone_activation_map, operation);
      if (!fusible_activation) {
        node = CreateBinaryOperator<DML_ELEMENT_WISE_ADD_OPERATOR_DESC>(
            lhs_desc, rhs_desc, result_desc, graph_builder,
            DML_OPERATOR_ELEMENT_WISE_ADD, node_inputs, label);
        break;
      }

      const Operation& activation = *fusible_activation.value();
      ActivationOperatorDesc activation_operator_desc =
          CreateOperatorDescForFusibleActivation(activation);
      DML_OPERATOR_DESC activation_dml_desc =
          activation_operator_desc.GetActivationDmlDesc();
      DML_ELEMENT_WISE_ADD1_OPERATOR_DESC add1_desc{
          .ATensor = &lhs_desc.GetDMLTensorDesc(),
          .BTensor = &rhs_desc.GetDMLTensorDesc(),
          .OutputTensor = &result_desc.GetDMLTensorDesc(),
          .FusedActivation = &activation_dml_desc,
      };
      label = GetFusedOperatorLabel(label, "add", activation);
      node = graph_builder.CreateOperatorNode(DML_OPERATOR_ELEMENT_WISE_ADD1,
                                              &add1_desc, node_inputs, label);
      result_id = GetFusibleActivationOutputId(activation).value();
      break;
    }
    case mojom::ElementWiseBinary::Kind::kDiv: {
      CHECK(limits.div_input.data_types.Has(operand_type));
      node = CreateBinaryOperator<DML_ELEMENT_WISE_DIVIDE_OPERATOR_DESC>(
          lhs_desc, rhs_desc, result_desc, graph_builder,
          DML_OPERATOR_ELEMENT_WISE_DIVIDE, node_inputs, label);
      break;
    }
    case mojom::ElementWiseBinary::Kind::kMax: {
      CHECK(limits.max_input.data_types.Has(operand_type));
      node = CreateBinaryOperator<DML_ELEMENT_WISE_MAX_OPERATOR_DESC>(
          lhs_desc, rhs_desc, result_desc, graph_builder,
          DML_OPERATOR_ELEMENT_WISE_MAX, node_inputs, label);
      break;
    }
    case mojom::ElementWiseBinary::Kind::kMin: {
      CHECK(limits.min_input.data_types.Has(operand_type));
      node = CreateBinaryOperator<DML_ELEMENT_WISE_MIN_OPERATOR_DESC>(
          lhs_desc, rhs_desc, result_desc, graph_builder,
          DML_OPERATOR_ELEMENT_WISE_MIN, node_inputs, label);
      break;
    }
    case mojom::ElementWiseBinary::Kind::kMul: {
      CHECK(limits.mul_input.data_types.Has(operand_type));
      node = CreateBinaryOperator<DML_ELEMENT_WISE_MULTIPLY_OPERATOR_DESC>(
          lhs_desc, rhs_desc, result_desc, graph_builder,
          DML_OPERATOR_ELEMENT_WISE_MULTIPLY, node_inputs, label);
      break;
    }
    case mojom::ElementWiseBinary::Kind::kSub: {
      CHECK(limits.sub_input.data_types.Has(operand_type));
      node = CreateBinaryOperator<DML_ELEMENT_WISE_SUBTRACT_OPERATOR_DESC>(
          lhs_desc, rhs_desc, result_desc, graph_builder,
          DML_OPERATOR_ELEMENT_WISE_SUBTRACT, node_inputs, label);
      break;
    }
    case mojom::ElementWiseBinary::Kind::kPow: {
      CHECK(limits.pow_input.data_types.Has(operand_type));
      // POW names its operands input/exponent rather than A/B, so the
      // descriptor is filled in by hand.
      DML_ELEMENT_WISE_POW_OPERATOR_DESC pow_desc{
          .InputTensor = &lhs_desc.GetDMLTensorDesc(),
          .ExponentTensor = &rhs_desc.GetDMLTensorDesc(),
          .OutputTensor = &result_desc.GetDMLTensorDesc()};
      node = graph_builder.CreateOperatorNode(DML_OPERATOR_ELEMENT_WISE_POW,
                                              &pow_desc, node_inputs, label);
      break;
    }
    case mojom::ElementWiseBinary::Kind::kEqual: {
      CHECK(limits.equal_input.data_types.Has(operand_type));
      node =
          CreateBinaryOperator<DML_ELEMENT_WISE_LOGICAL_EQUALS_OPERATOR_DESC>(
              lhs_desc, rhs_desc, result_desc, graph_builder,
              DML_OPERATOR_ELEMENT_WISE_LOGICAL_EQUALS, node_inputs, label);
      break;
    }
    case mojom::ElementWiseBinary::Kind::kGreater: {
      CHECK(limits.greater_input.data_types.Has(operand_type));
      node = CreateBinaryOperator<
          DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OPERATOR_DESC>(
          lhs_desc, rhs_desc, result_desc, graph_builder,
          DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN, node_inputs, label);
      break;
    }
    case mojom::ElementWiseBinary::Kind::kGreaterOrEqual: {
      CHECK(limits.greater_or_equal_input.data_types.Has(operand_type));
      node = CreateBinaryOperator<
          DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL_OPERATOR_DESC>(
          lhs_desc, rhs_desc, result_desc, graph_builder,
          DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL, node_inputs,
          label);
      break;
    }
    case mojom::ElementWiseBinary::Kind::kLesser: {
      CHECK(limits.lesser_input.data_types.Has(operand_type));
      node = CreateBinaryOperator<
          DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OPERATOR_DESC>(
          lhs_desc, rhs_desc, result_desc, graph_builder,
          DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN, node_inputs, label);
      break;
    }
    case mojom::ElementWiseBinary::Kind::kLesserOrEqual: {
      CHECK(limits.lesser_or_equal_input.data_types.Has(operand_type));
      node = CreateBinaryOperator<
          DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL_OPERATOR_DESC>(
          lhs_desc, rhs_desc, result_desc, graph_builder,
          DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL, node_inputs,
          label);
      break;
    }
    case mojom::ElementWiseBinary::Kind::kNotEqual: {
      CHECK(limits.not_equal_input.data_types.Has(operand_type));
      // First compute equal(a, b) into an intermediate tensor that has the
      // output's data type and shape, then negate it.
      const TensorDesc equals_desc =
          TensorDesc(result_desc.GetDataType(), result_dims);
      const GraphNode* equals_node =
          CreateBinaryOperator<DML_ELEMENT_WISE_LOGICAL_EQUALS_OPERATOR_DESC>(
              lhs_desc, rhs_desc, equals_desc, graph_builder,
              DML_OPERATOR_ELEMENT_WISE_LOGICAL_EQUALS, node_inputs, label);
      const NodeOutput* equals_output =
          graph_builder.CreateNodeOutput(equals_node, equals_desc);
      node = CreateUnaryOperator<DML_ELEMENT_WISE_LOGICAL_NOT_OPERATOR_DESC,
                                 DML_OPERATOR_ELEMENT_WISE_LOGICAL_NOT>(
          equals_desc, result_desc, equals_output, graph_builder, label);
      break;
    }
    case mojom::ElementWiseBinary::Kind::kLogicalAnd: {
      CHECK(limits.logical_and_input.data_types.Has(operand_type));
      node = CreateBinaryOperator<DML_ELEMENT_WISE_LOGICAL_AND_OPERATOR_DESC>(
          lhs_desc, rhs_desc, result_desc, graph_builder,
          DML_OPERATOR_ELEMENT_WISE_LOGICAL_AND, node_inputs, label);
      break;
    }
    case mojom::ElementWiseBinary::Kind::kLogicalOr: {
      CHECK(limits.logical_or_input.data_types.Has(operand_type));
      node = CreateBinaryOperator<DML_ELEMENT_WISE_LOGICAL_OR_OPERATOR_DESC>(
          lhs_desc, rhs_desc, result_desc, graph_builder,
          DML_OPERATOR_ELEMENT_WISE_LOGICAL_OR, node_inputs, label);
      break;
    }
    case mojom::ElementWiseBinary::Kind::kLogicalXor: {
      CHECK(limits.logical_xor_input.data_types.Has(operand_type));
      node = CreateBinaryOperator<DML_ELEMENT_WISE_LOGICAL_XOR_OPERATOR_DESC>(
          lhs_desc, rhs_desc, result_desc, graph_builder,
          DML_OPERATOR_ELEMENT_WISE_LOGICAL_XOR, node_inputs, label);
      break;
    }
  }

  const NodeOutput* result =
      graph_builder.CreateNodeOutput(node, std::move(result_desc), 0);
  CHECK(id_to_node_output_map.try_emplace(result_id, result).second);
}
// Creates a DML_OPERATOR_PADDING node for a WebNN pad operation and registers
// its output in `id_to_node_output_map`. The WebNN padding mode is mapped to
// the corresponding DML_PADDING_MODE; only the constant mode supplies a
// padding value, the other modes leave it at 0.
void CreateOperatorNodeForPad(const ContextProperties& context_properties,
                              const std::vector<OperandPtr>& operands,
                              const mojom::PadPtr& pad,
                              GraphBuilderDml& graph_builder,
                              IdToNodeOutputMap& id_to_node_output_map) {
  const NodeOutput* source =
      GetNodeOutputForOperand(id_to_node_output_map, pad->input_operand_id);
  const auto& source_desc = source->GetTensorDesc();
  CHECK(context_properties.data_type_limits.pad_input.data_types.Has(
      DmlDataTypeToOperand(source_desc.GetDataType())));

  const OperandId result_id = pad->output_operand_id;
  const auto& result_desc = CreateOutputTensorDesc(operands, result_id);

  DML_PADDING_MODE dml_mode;
  float fill_value = 0;
  switch (pad->mode->which()) {
    case mojom::PaddingMode::Tag::kConstant:
      dml_mode = DML_PADDING_MODE::DML_PADDING_MODE_CONSTANT;
      fill_value = pad->mode->get_constant()->value.AsFloat32();
      break;
    case mojom::PaddingMode::Tag::kEdge:
      dml_mode = DML_PADDING_MODE::DML_PADDING_MODE_EDGE;
      break;
    case mojom::PaddingMode::Tag::kReflection:
      dml_mode = DML_PADDING_MODE::DML_PADDING_MODE_REFLECTION;
      break;
  }

  // The descriptor carries a single dimension count for both padding arrays.
  const auto& leading_padding = pad->beginning_padding;
  const auto& trailing_padding = pad->ending_padding;
  CHECK_EQ(leading_padding.size(), trailing_padding.size());

  DML_PADDING_OPERATOR_DESC padding_desc = {
      .InputTensor = &source_desc.GetDMLTensorDesc(),
      .OutputTensor = &result_desc.GetDMLTensorDesc(),
      .PaddingMode = dml_mode,
      .PaddingValue = fill_value,
      .DimensionCount = static_cast<uint32_t>(leading_padding.size()),
      .StartPadding = leading_padding.data(),
      .EndPadding = trailing_padding.data()};

  std::array<const NodeOutput*, 1> node_inputs = {source};
  const GraphNode* padding_node = graph_builder.CreateOperatorNode(
      DML_OPERATOR_PADDING, &padding_desc, {node_inputs}, pad->label);

  const NodeOutput* result =
      graph_builder.CreateNodeOutput(padding_node, std::move(result_desc));
  CHECK(id_to_node_output_map.try_emplace(result_id, result).second);
}
// Creates the DirectML pooling node for a WebNN pool2d operation and registers
// its output in `id_to_node_output_map`.
//
// Mapping of WebNN kinds to DirectML operators:
//   * kAveragePool2d -> DML_OPERATOR_AVERAGE_POOLING
//   * kL2Pool2d      -> DML_OPERATOR_LP_POOLING with P = 2
//   * kMaxPool2d     -> DML_OPERATOR_MAX_POOLING when all dilations are 1,
//                       otherwise DML_OPERATOR_MAX_POOLING2
//
// The average and L2 descriptors used here have no dilations field, so a
// request with any dilation other than 1 is rejected with
// `kNotSupportedError` instead of building a node that ignores the dilations.
// Returns `base::ok()` on success.
base::expected<void, mojom::ErrorPtr> CreateOperatorNodeForPool2d(
    const ContextProperties& context_properties,
    const std::vector<OperandPtr>& operands,
    const mojom::Pool2dPtr& pool2d,
    GraphBuilderDml& graph_builder,
    IdToNodeOutputMap& id_to_node_output_map) {
  const NodeOutput* input =
      GetNodeOutputForOperand(id_to_node_output_map, pool2d->input_operand_id);
  auto input_tensor_desc = input->GetTensorDesc();

  OperandId output_id = pool2d->output_operand_id;
  auto output_tensor_desc = CreateOutputTensorDesc(operands, output_id);

  // All spatial attributes are laid out as {height, width}.
  std::array<uint32_t, 2> strides = {pool2d->strides->height,
                                     pool2d->strides->width};
  std::array<uint32_t, 2> dilations = {pool2d->dilations->height,
                                       pool2d->dilations->width};
  std::array<uint32_t, 2> window_dimensions = {
      pool2d->window_dimensions->height, pool2d->window_dimensions->width};
  std::array<uint32_t, 2> start_padding = {pool2d->padding->beginning->height,
                                           pool2d->padding->beginning->width};
  std::array<uint32_t, 2> end_padding = {pool2d->padding->ending->height,
                                         pool2d->padding->ending->width};
  const bool has_dilations = dilations[0] != 1 || dilations[1] != 1;

  std::array<const NodeOutput*, 1> inputs = {input};
  const GraphNode* pool2d_node = nullptr;
  const std::string& label = pool2d->label;
  switch (pool2d->kind) {
    case mojom::Pool2d::Kind::kAveragePool2d: {
      CHECK(context_properties.data_type_limits.average_pool2d_input.data_types
                .Has(DmlDataTypeToOperand(input_tensor_desc.GetDataType())));
      if (has_dilations) {
        return base::unexpected(CreateError(
            mojom::Error::Code::kNotSupportedError,
            "Dilations are not supported for average pooling operator.",
            label));
      }
      DML_AVERAGE_POOLING_OPERATOR_DESC average_pooling_desc = {
          .InputTensor = &input_tensor_desc.GetDMLTensorDesc(),
          .OutputTensor = &output_tensor_desc.GetDMLTensorDesc(),
          .DimensionCount =
              base::checked_cast<uint32_t>(window_dimensions.size()),
          .Strides = strides.data(),
          .WindowSize = window_dimensions.data(),
          .StartPadding = start_padding.data(),
          .EndPadding = end_padding.data(),
          .IncludePadding = false};
      pool2d_node = graph_builder.CreateOperatorNode(
          DML_OPERATOR_AVERAGE_POOLING, &average_pooling_desc, inputs, label);
      break;
    }
    case mojom::Pool2d::Kind::kL2Pool2d: {
      CHECK(context_properties.data_type_limits.l2_pool2d_input.data_types.Has(
          DmlDataTypeToOperand(input_tensor_desc.GetDataType())));
      // DML_LP_POOLING_OPERATOR_DESC cannot express dilations; report the
      // limitation like the average pooling path does rather than dropping
      // the attribute silently.
      if (has_dilations) {
        return base::unexpected(CreateError(
            mojom::Error::Code::kNotSupportedError,
            "Dilations are not supported for l2 pooling operator.", label));
      }
      DML_LP_POOLING_OPERATOR_DESC l2_pooling_desc = {
          .InputTensor = &input_tensor_desc.GetDMLTensorDesc(),
          .OutputTensor = &output_tensor_desc.GetDMLTensorDesc(),
          .DimensionCount =
              base::checked_cast<uint32_t>(window_dimensions.size()),
          .Strides = strides.data(),
          .WindowSize = window_dimensions.data(),
          .StartPadding = start_padding.data(),
          .EndPadding = end_padding.data(),
          .P = 2};
      pool2d_node = graph_builder.CreateOperatorNode(
          DML_OPERATOR_LP_POOLING, &l2_pooling_desc, inputs, label);
      break;
    }
    case mojom::Pool2d::Kind::kMaxPool2d: {
      CHECK(context_properties.data_type_limits.max_pool2d_input.data_types.Has(
          DmlDataTypeToOperand(input_tensor_desc.GetDataType())));
      if (!has_dilations) {
        DML_MAX_POOLING_OPERATOR_DESC max_pooling_desc = {
            .InputTensor = &input_tensor_desc.GetDMLTensorDesc(),
            .OutputTensor = &output_tensor_desc.GetDMLTensorDesc(),
            .DimensionCount =
                base::checked_cast<uint32_t>(window_dimensions.size()),
            .Strides = strides.data(),
            .WindowSize = window_dimensions.data(),
            .StartPadding = start_padding.data(),
            .EndPadding = end_padding.data()};
        pool2d_node = graph_builder.CreateOperatorNode(
            DML_OPERATOR_MAX_POOLING, &max_pooling_desc, inputs, label);
      } else {
        // MAX_POOLING2 is the variant that accepts dilations; the optional
        // indices output is not requested.
        DML_MAX_POOLING2_OPERATOR_DESC max_pooling2_desc = {
            .InputTensor = &input_tensor_desc.GetDMLTensorDesc(),
            .OutputTensor = &output_tensor_desc.GetDMLTensorDesc(),
            .OutputIndicesTensor = nullptr,
            .DimensionCount =
                base::checked_cast<uint32_t>(window_dimensions.size()),
            .Strides = strides.data(),
            .WindowSize = window_dimensions.data(),
            .StartPadding = start_padding.data(),
            .EndPadding = end_padding.data(),
            .Dilations = dilations.data()};
        pool2d_node = graph_builder.CreateOperatorNode(
            DML_OPERATOR_MAX_POOLING2, &max_pooling2_desc, inputs, label);
      }
      break;
    }
    default:
      LOG(ERROR) << "[WebNN] Invalid Pool2d operator type";
      NOTREACHED();
  }

  const NodeOutput* output = graph_builder.CreateNodeOutput(
      pool2d_node, std::move(output_tensor_desc), 0);
  CHECK(id_to_node_output_map.try_emplace(output_id, output).second);
  return base::ok();
}
// Creates a DML_OPERATOR_ACTIVATION_PARAMETERIZED_RELU node for a WebNN prelu
// operation and registers its output in `id_to_node_output_map`.
//
// `context_properties` is taken by const reference, like the other operator
// builders in this file, so that no copy is made per call. The slope tensor
// description is broadcast to the output shape when the shapes differ.
void CreateOperatorNodeForPrelu(const ContextProperties& context_properties,
                                const std::vector<OperandPtr>& operands,
                                const mojom::PreluPtr& prelu,
                                GraphBuilderDml& graph_builder,
                                IdToNodeOutputMap& id_to_node_output_map) {
  const NodeOutput* input =
      GetNodeOutputForOperand(id_to_node_output_map, prelu->input_operand_id);
  const auto& input_tensor_desc = input->GetTensorDesc();
  CHECK(context_properties.data_type_limits.prelu_input.data_types.Has(
      DmlDataTypeToOperand(input_tensor_desc.GetDataType())));

  const NodeOutput* slope =
      GetNodeOutputForOperand(id_to_node_output_map, prelu->slope_operand_id);
  // Copied because it may be broadcast below.
  auto slope_tensor_desc = slope->GetTensorDesc();
  CHECK_EQ(input_tensor_desc.GetDataType(), slope_tensor_desc.GetDataType());

  OperandId output_id = prelu->output_operand_id;
  const auto output_tensor_desc = CreateOutputTensorDesc(operands, output_id);
  const auto& output_dimensions = output_tensor_desc.GetDimensions();
  if (slope_tensor_desc.GetDimensions() != output_dimensions) {
    slope_tensor_desc.BroadcastTo(output_dimensions);
  }

  DML_ACTIVATION_PARAMETERIZED_RELU_OPERATOR_DESC prelu_desc{
      .InputTensor = &input_tensor_desc.GetDMLTensorDesc(),
      .SlopeTensor = &slope_tensor_desc.GetDMLTensorDesc(),
      .OutputTensor = &output_tensor_desc.GetDMLTensorDesc()};

  const std::string& label = prelu->label;
  std::array<const NodeOutput*, 2> inputs = {input, slope};
  const GraphNode* prelu_node = graph_builder.CreateOperatorNode(
      DML_OPERATOR_ACTIVATION_PARAMETERIZED_RELU, &prelu_desc, inputs, label);

  const NodeOutput* node_output =
      graph_builder.CreateNodeOutput(prelu_node, std::move(output_tensor_desc));
  CHECK(id_to_node_output_map.try_emplace(output_id, node_output).second);
}
// Creates a DML_OPERATOR_SCATTER node for a WebNN scatterElements operation
// and registers its output in `id_to_node_output_map`. The input and updates
// data types are both checked against the `scatter_elements_input` limits;
// the indices data type is checked against `scatter_elements_indices`.
void CreateOperatorNodeForScatterElements(
    const ContextProperties& context_properties,
    const std::vector<OperandPtr>& operands,
    const mojom::ScatterElementsPtr& scatter_elements,
    GraphBuilderDml& graph_builder,
    IdToNodeOutputMap& id_to_node_output_map) {
  const auto& limits = context_properties.data_type_limits;

  const NodeOutput* data = GetNodeOutputForOperand(
      id_to_node_output_map, scatter_elements->input_operand_id);
  TensorDesc data_desc = data->GetTensorDesc();
  CHECK(limits.scatter_elements_input.data_types.Has(
      DmlDataTypeToOperand(data_desc.GetDataType())));

  const NodeOutput* index_tensor = GetNodeOutputForOperand(
      id_to_node_output_map, scatter_elements->indices_operand_id);
  TensorDesc index_desc = index_tensor->GetTensorDesc();
  CHECK(limits.scatter_elements_indices.data_types.Has(
      DmlDataTypeToOperand(index_desc.GetDataType())));

  const NodeOutput* update_tensor = GetNodeOutputForOperand(
      id_to_node_output_map, scatter_elements->updates_operand_id);
  TensorDesc update_desc = update_tensor->GetTensorDesc();
  CHECK(limits.scatter_elements_input.data_types.Has(
      DmlDataTypeToOperand(update_desc.GetDataType())));

  const OperandId result_id = scatter_elements->output_operand_id;
  const TensorDesc result_desc = CreateOutputTensorDesc(operands, result_id);

  DML_SCATTER_OPERATOR_DESC scatter_desc{
      .InputTensor = &data_desc.GetDMLTensorDesc(),
      .IndicesTensor = &index_desc.GetDMLTensorDesc(),
      .UpdatesTensor = &update_desc.GetDMLTensorDesc(),
      .OutputTensor = &result_desc.GetDMLTensorDesc(),
      .Axis = scatter_elements->axis};

  std::array<const NodeOutput*, 3> node_inputs = {data, index_tensor,
                                                  update_tensor};
  const GraphNode* scatter_node = graph_builder.CreateOperatorNode(
      DML_OPERATOR_SCATTER, &scatter_desc, node_inputs,
      scatter_elements->label);

  const NodeOutput* result =
      graph_builder.CreateNodeOutput(scatter_node, std::move(result_desc), 0);
  CHECK(id_to_node_output_map.try_emplace(result_id, result).second);
}
// Creates a DML_OPERATOR_SCATTER_ND node for a WebNN scatterND operation and
// registers its output in `id_to_node_output_map`.
//
// The input, indices, updates and output descriptions handed to DirectML are
// all raised to the largest rank among the four. The original ranks are still
// passed via `InputDimensionCount` / `IndicesDimensionCount`, and the
// registered node output keeps the un-padded output description.
void CreateOperatorNodeForScatterND(const ContextProperties& context_properties,
                                    const std::vector<OperandPtr>& operands,
                                    const mojom::ScatterNDPtr& scatter_nd,
                                    GraphBuilderDml& graph_builder,
                                    IdToNodeOutputMap& id_to_node_output_map) {
  const auto& limits = context_properties.data_type_limits;

  const NodeOutput* data = GetNodeOutputForOperand(
      id_to_node_output_map, scatter_nd->input_operand_id);
  TensorDesc data_desc = data->GetTensorDesc();
  CHECK(limits.scatter_nd_input.data_types.Has(
      DmlDataTypeToOperand(data_desc.GetDataType())));

  const NodeOutput* index_tensor = GetNodeOutputForOperand(
      id_to_node_output_map, scatter_nd->indices_operand_id);
  TensorDesc index_desc = index_tensor->GetTensorDesc();
  CHECK(limits.scatter_nd_indices.data_types.Has(
      DmlDataTypeToOperand(index_desc.GetDataType())));

  const NodeOutput* update_tensor = GetNodeOutputForOperand(
      id_to_node_output_map, scatter_nd->updates_operand_id);
  TensorDesc update_desc = update_tensor->GetTensorDesc();
  CHECK(limits.scatter_nd_updates.data_types.Has(
      DmlDataTypeToOperand(update_desc.GetDataType())));

  const OperandId result_id = scatter_nd->output_operand_id;
  const TensorDesc unpadded_result_desc =
      CreateOutputTensorDesc(operands, result_id);

  // Capture the original ranks before any padding happens.
  const size_t data_rank = data_desc.GetDimensions().size();
  const size_t index_rank = index_desc.GetDimensions().size();
  const size_t update_rank = update_desc.GetDimensions().size();
  const size_t result_rank = unpadded_result_desc.GetDimensions().size();
  const size_t common_rank =
      std::max({data_rank, index_rank, update_rank, result_rank});

  TensorDesc padded_result_desc = unpadded_result_desc;
  const auto pad_to_common_rank = [common_rank](TensorDesc& desc) {
    desc.EnsureMinimumRank(common_rank, TensorDesc::Alignment::kTrailing);
  };
  pad_to_common_rank(data_desc);
  pad_to_common_rank(index_desc);
  pad_to_common_rank(update_desc);
  pad_to_common_rank(padded_result_desc);

  DML_SCATTER_ND_OPERATOR_DESC scatter_desc{
      .InputTensor = &data_desc.GetDMLTensorDesc(),
      .IndicesTensor = &index_desc.GetDMLTensorDesc(),
      .UpdatesTensor = &update_desc.GetDMLTensorDesc(),
      .OutputTensor = &padded_result_desc.GetDMLTensorDesc(),
      .InputDimensionCount = base::checked_cast<uint32_t>(data_rank),
      .IndicesDimensionCount = base::checked_cast<uint32_t>(index_rank)};

  std::array<const NodeOutput*, 3> node_inputs = {data, index_tensor,
                                                  update_tensor};
  const GraphNode* scatter_node = graph_builder.CreateOperatorNode(
      DML_OPERATOR_SCATTER_ND, &scatter_desc, node_inputs, scatter_nd->label);

  const NodeOutput* result = graph_builder.CreateNodeOutput(
      scatter_node, std::move(unpadded_result_desc), 0);
  CHECK(id_to_node_output_map.try_emplace(result_id, result).second);
}
// Creates a DML_OPERATOR_SLICE node for a WebNN slice operation and registers
// its output in `id_to_node_output_map`.
//
// Per dimension, the offset and stride come from `slice->ranges` and the size
// is taken from the output operand's shape. The output id must not have been
// registered before; like the other operator builders in this file, this is
// enforced with a CHECK rather than silently replacing an existing entry.
void CreateOperatorNodeForSlice(const std::vector<OperandPtr>& operands,
                                const mojom::SlicePtr& slice,
                                GraphBuilderDml& graph_builder,
                                IdToNodeOutputMap& id_to_node_output_map) {
  const NodeOutput* input =
      GetNodeOutputForOperand(id_to_node_output_map, slice->input_operand_id);
  const TensorDesc& input_tensor_desc = input->GetTensorDesc();
  const std::vector<uint32_t>& input_dimensions =
      input_tensor_desc.GetDimensions();
  const size_t input_rank = input_dimensions.size();

  const OperandId output_id = slice->output_operand_id;
  const TensorDesc& output_tensor_desc =
      CreateOutputTensorDesc(operands, output_id);
  const std::vector<uint32_t>& output_dimensions =
      output_tensor_desc.GetDimensions();

  // One range per input dimension, and slicing never changes the rank.
  CHECK_EQ(input_rank, output_dimensions.size());
  CHECK_EQ(input_rank, slice->ranges.size());

  base::FixedArray<uint32_t> starts(input_rank);
  base::FixedArray<uint32_t> sizes(input_rank);
  base::FixedArray<uint32_t> strides(input_rank);
  for (size_t i = 0; i < input_rank; ++i) {
    starts[i] = slice->ranges[i].start;
    sizes[i] = output_dimensions[i];
    strides[i] = slice->ranges[i].stride;
  }

  DML_SLICE_OPERATOR_DESC slice_operator_desc{
      .InputTensor = &input_tensor_desc.GetDMLTensorDesc(),
      .OutputTensor = &output_tensor_desc.GetDMLTensorDesc(),
      .DimensionCount = static_cast<UINT>(input_dimensions.size()),
      .Offsets = starts.data(),
      .Sizes = sizes.data(),
      .Strides = strides.data(),
  };

  std::array<const NodeOutput*, 1> input_node_output = {input};
  const GraphNode* slice_node =
      graph_builder.CreateOperatorNode(DML_OPERATOR_SLICE, &slice_operator_desc,
                                       input_node_output, slice->label);

  const NodeOutput* slice_output =
      graph_builder.CreateNodeOutput(slice_node, std::move(output_tensor_desc));
  CHECK(id_to_node_output_map.try_emplace(output_id, slice_output).second);
}
// Creates a DML_OPERATOR_SPLIT node for a WebNN split operation and registers
// one node output per entry of `split->output_operand_ids`, using the entry's
// position as the node's output index.
void CreateOperatorNodeForSplit(const std::vector<OperandPtr>& operands,
                                const mojom::SplitPtr& split,
                                GraphBuilderDml& graph_builder,
                                IdToNodeOutputMap& id_to_node_output_map) {
  const NodeOutput* source =
      GetNodeOutputForOperand(id_to_node_output_map, split->input_operand_id);
  const auto& source_desc = source->GetTensorDesc();

  // `piece_descs` owns the output descriptions; `piece_dml_descs` holds the
  // DirectML views that are passed to the operator descriptor.
  const size_t piece_total = split->output_operand_ids.size();
  base::FixedArray<TensorDesc> piece_descs(piece_total);
  base::FixedArray<DML_TENSOR_DESC> piece_dml_descs(piece_total);
  for (size_t piece = 0; piece < piece_total; ++piece) {
    piece_descs[piece] =
        CreateOutputTensorDesc(operands, split->output_operand_ids[piece]);
    piece_dml_descs[piece] = piece_descs[piece].GetDMLTensorDesc();
  }

  auto piece_count = base::checked_cast<uint32_t>(piece_dml_descs.size());
  DML_SPLIT_OPERATOR_DESC split_operator_desc{
      .InputTensor = &source_desc.GetDMLTensorDesc(),
      .OutputCount = piece_count,
      .OutputTensors = piece_dml_descs.data(),
      .Axis = split->axis};

  const std::string& label = split->label;
  std::array<const NodeOutput*, 1> node_inputs = {source};
  const GraphNode* split_node = graph_builder.CreateOperatorNode(
      DML_OPERATOR_SPLIT, &split_operator_desc, node_inputs, label);

  // The descriptions are only moved out after the operator node exists.
  for (uint32_t piece = 0; piece < piece_count; ++piece) {
    OperandId piece_id = split->output_operand_ids[piece];
    const auto* piece_output = graph_builder.CreateNodeOutput(
        split_node, std::move(piece_descs[piece]), piece);
    CHECK(id_to_node_output_map.try_emplace(piece_id, piece_output).second);
  }
}
// Creates the node for a WebNN neg operation and registers its output in
// `id_to_node_output_map`. Negation is expressed as a
// DML_OPERATOR_ELEMENT_WISE_IDENTITY operator with a scale of -1 and a bias
// of 0.
void CreateOperatorNodeForNeg(const std::vector<OperandPtr>& operands,
                              const mojom::ElementWiseUnaryPtr& operation,
                              GraphBuilderDml& graph_builder,
                              IdToNodeOutputMap& id_to_node_output_map) {
  const NodeOutput* source = GetNodeOutputForOperand(
      id_to_node_output_map, operation->input_operand_id);
  const auto& source_desc = source->GetTensorDesc();

  const OperandId result_id = operation->output_operand_id;
  const auto result_desc = CreateOutputTensorDesc(operands, result_id);

  DML_SCALE_BIAS negate{.Scale = -1.f, .Bias = 0.f};
  DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC negate_desc{
      .InputTensor = &source_desc.GetDMLTensorDesc(),
      .OutputTensor = &result_desc.GetDMLTensorDesc(),
      .ScaleBias = &negate};

  std::array<const NodeOutput*, 1> node_inputs = {source};
  const GraphNode* negate_node = graph_builder.CreateOperatorNode(
      DML_OPERATOR_ELEMENT_WISE_IDENTITY, &negate_desc, node_inputs,
      operation->label);

  const NodeOutput* result =
      graph_builder.CreateNodeOutput(negate_node, std::move(result_desc), 0);
  CHECK(id_to_node_output_map.try_emplace(result_id, result).second);
}
// Creates a DML_OPERATOR_ELEMENT_WISE_ROUND node configured with
// DML_ROUNDING_MODE_HALVES_TO_NEAREST_EVEN for a WebNN roundEven operation
// and registers its output in `id_to_node_output_map`.
void CreateOperatorNodeForRoundEven(const std::vector<OperandPtr>& operands,
                                    const mojom::ElementWiseUnaryPtr& operation,
                                    GraphBuilderDml& graph_builder,
                                    IdToNodeOutputMap& id_to_node_output_map) {
  const NodeOutput* source = GetNodeOutputForOperand(
      id_to_node_output_map, operation->input_operand_id);
  const TensorDesc& source_desc = source->GetTensorDesc();

  const OperandId result_id = operation->output_operand_id;
  const auto result_desc = CreateOutputTensorDesc(operands, result_id);

  DML_ELEMENT_WISE_ROUND_OPERATOR_DESC round_desc{
      .InputTensor = &source_desc.GetDMLTensorDesc(),
      .OutputTensor = &result_desc.GetDMLTensorDesc(),
      .RoundingMode =
          DML_ROUNDING_MODE::DML_ROUNDING_MODE_HALVES_TO_NEAREST_EVEN};

  std::array<const NodeOutput*, 1> node_inputs = {source};
  const GraphNode* round_node = graph_builder.CreateOperatorNode(
      DML_OPERATOR_ELEMENT_WISE_ROUND, &round_desc, node_inputs,
      operation->label);

  const NodeOutput* result =
      graph_builder.CreateNodeOutput(round_node, std::move(result_desc), 0);
  CHECK(id_to_node_output_map.try_emplace(result_id, result).second);
}
// Dispatches a WebNN element-wise unary operation to the helper that builds
// the matching DirectML node.
//
// For every kind, the data type of the already-built input node is first
// CHECKed against that operation's input data type limits from
// `context_properties`; the node itself is then created either through the
// generic `CreateOperatorNodeForUnary` template or through a dedicated helper
// (`neg`, `roundEven`).
void CreateOperatorNodeForElementWiseUnary(
    const ContextProperties& context_properties,
    const std::vector<OperandPtr>& operands,
    const mojom::ElementWiseUnaryPtr& operation,
    GraphBuilderDml& graph_builder,
    IdToNodeOutputMap& id_to_node_output_map) {
  using Kind = mojom::ElementWiseUnary::Kind;

  const NodeOutput* input_node = GetNodeOutputForOperand(
      id_to_node_output_map, operation->input_operand_id);
  const OperandDataType input_type =
      DmlDataTypeToOperand(input_node->GetTensorDesc().GetDataType());
  const auto& limits = context_properties.data_type_limits;

  switch (operation->kind) {
    case Kind::kAbs: {
      CHECK(limits.abs_input.data_types.Has(input_type));
      return CreateOperatorNodeForUnary<DML_ELEMENT_WISE_ABS_OPERATOR_DESC,
                                        DML_OPERATOR_ELEMENT_WISE_ABS>(
          operands, operation, graph_builder, id_to_node_output_map);
    }
    case Kind::kCast: {
      CHECK(limits.cast_input.data_types.Has(input_type));
      return CreateOperatorNodeForUnary<DML_CAST_OPERATOR_DESC,
                                        DML_OPERATOR_CAST>(
          operands, operation, graph_builder, id_to_node_output_map);
    }
    case Kind::kCeil: {
      CHECK(limits.ceil_input.data_types.Has(input_type));
      return CreateOperatorNodeForUnary<DML_ELEMENT_WISE_CEIL_OPERATOR_DESC,
                                        DML_OPERATOR_ELEMENT_WISE_CEIL>(
          operands, operation, graph_builder, id_to_node_output_map);
    }
    case Kind::kCos: {
      CHECK(limits.cos_input.data_types.Has(input_type));
      return CreateOperatorNodeForUnary<DML_ELEMENT_WISE_COS_OPERATOR_DESC,
                                        DML_OPERATOR_ELEMENT_WISE_COS>(
          operands, operation, graph_builder, id_to_node_output_map);
    }
    case Kind::kErf: {
      CHECK(limits.erf_input.data_types.Has(input_type));
      return CreateOperatorNodeForUnary<DML_ELEMENT_WISE_ERF_OPERATOR_DESC,
                                        DML_OPERATOR_ELEMENT_WISE_ERF>(
          operands, operation, graph_builder, id_to_node_output_map);
    }
    case Kind::kExp: {
      CHECK(limits.exp_input.data_types.Has(input_type));
      return CreateOperatorNodeForUnary<DML_ELEMENT_WISE_EXP_OPERATOR_DESC,
                                        DML_OPERATOR_ELEMENT_WISE_EXP>(
          operands, operation, graph_builder, id_to_node_output_map);
    }
    case Kind::kFloor: {
      CHECK(limits.floor_input.data_types.Has(input_type));
      return CreateOperatorNodeForUnary<DML_ELEMENT_WISE_FLOOR_OPERATOR_DESC,
                                        DML_OPERATOR_ELEMENT_WISE_FLOOR>(
          operands, operation, graph_builder, id_to_node_output_map);
    }
    case Kind::kIdentity: {
      CHECK(limits.identity_input.data_types.Has(input_type));
      return CreateOperatorNodeForUnary<
          DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC,
          DML_OPERATOR_ELEMENT_WISE_IDENTITY>(
          operands, operation, graph_builder, id_to_node_output_map);
    }
    case Kind::kLog: {
      CHECK(limits.log_input.data_types.Has(input_type));
      return CreateOperatorNodeForUnary<DML_ELEMENT_WISE_LOG_OPERATOR_DESC,
                                        DML_OPERATOR_ELEMENT_WISE_LOG>(
          operands, operation, graph_builder, id_to_node_output_map);
    }
    case Kind::kIsNaN: {
      CHECK(limits.is_nan_input.data_types.Has(input_type));
      return CreateOperatorNodeForUnary<DML_ELEMENT_WISE_IS_NAN_OPERATOR_DESC,
                                        DML_OPERATOR_ELEMENT_WISE_IS_NAN>(
          operands, operation, graph_builder, id_to_node_output_map);
    }
    case Kind::kIsInfinite: {
      CHECK(limits.is_infinite_input.data_types.Has(input_type));
      return CreateOperatorNodeForUnary<
          DML_ELEMENT_WISE_IS_INFINITY_OPERATOR_DESC,
          DML_OPERATOR_ELEMENT_WISE_IS_INFINITY>(
          operands, operation, graph_builder, id_to_node_output_map);
    }
    case Kind::kLogicalNot: {
      CHECK(limits.logical_not_input.data_types.Has(input_type));
      return CreateOperatorNodeForUnary<
          DML_ELEMENT_WISE_LOGICAL_NOT_OPERATOR_DESC,
          DML_OPERATOR_ELEMENT_WISE_LOGICAL_NOT>(
          operands, operation, graph_builder, id_to_node_output_map);
    }
    case Kind::kNeg: {
      CHECK(limits.neg_input.data_types.Has(input_type));
      return CreateOperatorNodeForNeg(operands, operation, graph_builder,
                                      id_to_node_output_map);
    }
    case Kind::kReciprocal: {
      CHECK(limits.reciprocal_input.data_types.Has(input_type));
      return CreateOperatorNodeForUnary<DML_ELEMENT_WISE_RECIP_OPERATOR_DESC,
                                        DML_OPERATOR_ELEMENT_WISE_RECIP>(
          operands, operation, graph_builder, id_to_node_output_map);
    }
    case Kind::kRoundEven: {
      CHECK(limits.round_even_input.data_types.Has(input_type));
      return CreateOperatorNodeForRoundEven(operands, operation, graph_builder,
                                            id_to_node_output_map);
    }
    case Kind::kSign: {
      CHECK(limits.sign_input.data_types.Has(input_type));
      return CreateOperatorNodeForUnary<DML_ELEMENT_WISE_SIGN_OPERATOR_DESC,
                                        DML_OPERATOR_ELEMENT_WISE_SIGN>(
          operands, operation, graph_builder, id_to_node_output_map);
    }
    case Kind::kSin: {
      CHECK(limits.sin_input.data_types.Has(input_type));
      return CreateOperatorNodeForUnary<DML_ELEMENT_WISE_SIN_OPERATOR_DESC,
                                        DML_OPERATOR_ELEMENT_WISE_SIN>(
          operands, operation, graph_builder, id_to_node_output_map);
    }
    case Kind::kSqrt: {
      CHECK(limits.sqrt_input.data_types.Has(input_type));
      return CreateOperatorNodeForUnary<DML_ELEMENT_WISE_SQRT_OPERATOR_DESC,
                                        DML_OPERATOR_ELEMENT_WISE_SQRT>(
          operands, operation, graph_builder, id_to_node_output_map);
    }
    case Kind::kTan: {
      CHECK(limits.tan_input.data_types.Has(input_type));
      return CreateOperatorNodeForUnary<DML_ELEMENT_WISE_TAN_OPERATOR_DESC,
                                        DML_OPERATOR_ELEMENT_WISE_TAN>(
          operands, operation, graph_builder, id_to_node_output_map);
    }
  }
}
// Builds the DML_OPERATOR_RESAMPLE node for the WebNN `resample2d` operation.
//
// DirectML is given one scale per input dimension. When the operation carries
// explicit scales they are written at the positions named by `axes` and every
// other dimension keeps a scale of 1; otherwise each scale is computed as
// output size / input size for that dimension. The node output is registered
// under the operation's output operand id.
void CreateOperatorNodeForResample2d(
    const ContextProperties& context_properties,
    const std::vector<OperandPtr>& operands,
    const mojom::Resample2dPtr& resample2d,
    GraphBuilderDml& graph_builder,
    IdToNodeOutputMap& id_to_node_output_map) {
  const NodeOutput* source = GetNodeOutputForOperand(
      id_to_node_output_map, resample2d->input_operand_id);
  const auto& source_desc = source->GetTensorDesc();
  CHECK(context_properties.data_type_limits.resample2d_input.data_types.Has(
      DmlDataTypeToOperand(source_desc.GetDataType())));

  const OperandId result_id = resample2d->output_operand_id;
  const auto result_desc = CreateOutputTensorDesc(operands, result_id);

  const auto& source_dims = source_desc.GetDimensions();
  const auto& result_dims = result_desc.GetDimensions();
  const size_t rank = source_dims.size();
  CHECK_EQ(rank, result_dims.size());

  // One scale per dimension, defaulting to 1 (no resampling).
  base::FixedArray<float> per_dim_scales(rank, 1);
  if (resample2d->scales) {
    const auto& resample_axes = resample2d->axes;
    const auto& given_scales = resample2d->scales.value();
    for (size_t k = 0; k < resample_axes.size(); ++k) {
      const auto axis = resample_axes[k];
      CHECK_LT(axis, per_dim_scales.size());
      per_dim_scales[axis] = given_scales[k];
    }
  } else {
    for (size_t dim = 0; dim < rank; ++dim) {
      per_dim_scales[dim] =
          base::checked_cast<float>(result_dims[dim]) / source_dims[dim];
    }
  }

  DML_INTERPOLATION_MODE interpolation;
  switch (resample2d->mode) {
    case mojom::Resample2d::InterpolationMode::kNearestNeighbor:
      interpolation = DML_INTERPOLATION_MODE_NEAREST_NEIGHBOR;
      break;
    case mojom::Resample2d::InterpolationMode::kLinear:
      interpolation = DML_INTERPOLATION_MODE_LINEAR;
      break;
  }

  DML_RESAMPLE_OPERATOR_DESC resample_desc = {
      .InputTensor = &source_desc.GetDMLTensorDesc(),
      .OutputTensor = &result_desc.GetDMLTensorDesc(),
      .InterpolationMode = interpolation,
      .ScaleCount = static_cast<uint32_t>(per_dim_scales.size()),
      .Scales = per_dim_scales.data()};

  std::array<const NodeOutput*, 1> node_inputs = {source};
  const GraphNode* resample_node = graph_builder.CreateOperatorNode(
      DML_OPERATOR_RESAMPLE, &resample_desc, node_inputs, resample2d->label);

  // `result_desc` is const, so it is copied into the node output.
  const NodeOutput* result =
      graph_builder.CreateNodeOutput(resample_node, result_desc, 0);
  CHECK(id_to_node_output_map.try_emplace(result_id, result).second);
}
// Builds the DML_OPERATOR_REDUCE node for a WebNN reduction operation.
//
// The input data type is validated against the limits for `reduce->kind`. The
// tensor description handed to DirectML as the operator's output keeps the
// input's rank, with every reduced axis set to size 1. The node output that is
// registered for the rest of the graph uses the output operand's own
// descriptor (`output_tensor_desc`) instead.
void CreateOperatorNodeForReduce(const ContextProperties& context_properties,
                                 const std::vector<OperandPtr>& operands,
                                 const mojom::ReducePtr& reduce,
                                 GraphBuilderDml& graph_builder,
                                 IdToNodeOutputMap& id_to_node_output_map) {
  const NodeOutput* input =
      GetNodeOutputForOperand(id_to_node_output_map, reduce->input_operand_id);
  const auto& input_tensor_desc = input->GetTensorDesc();
  CheckInputDataTypeForReduce(
      context_properties.data_type_limits, reduce->kind,
      DmlDataTypeToOperand(input_tensor_desc.GetDataType()));

  const OperandId output_id = reduce->output_operand_id;
  const auto output_tensor_desc = CreateOutputTensorDesc(operands, output_id);
  const auto& axes = reduce->axes;

  // Derive the DirectML-side output shape: same rank as the input, with each
  // reduced axis collapsed to 1.
  std::vector<uint32_t> output_dimensions = input_tensor_desc.GetDimensions();
  for (uint32_t axis : axes) {
    CHECK_LT(axis, output_dimensions.size());
    output_dimensions[axis] = 1u;
  }
  TensorDesc new_output_tensor_desc(output_tensor_desc.GetDataType(),
                                    std::move(output_dimensions));

  std::array<const NodeOutput*, 1> inputs = {input};
  // Each field is assigned in its own statement. The previous code ended two
  // of these lines with ',' and so chained them via the comma operator.
  DML_REDUCE_OPERATOR_DESC operator_desc = {};
  operator_desc.Function = MapReduceKindToReduceFuntion(reduce->kind);
  operator_desc.InputTensor = &input_tensor_desc.GetDMLTensorDesc();
  operator_desc.OutputTensor = &new_output_tensor_desc.GetDMLTensorDesc();
  operator_desc.AxisCount = base::checked_cast<uint32_t>(axes.size());
  operator_desc.Axes = axes.data();

  const GraphNode* reduce_node = graph_builder.CreateOperatorNode(
      DML_OPERATOR_REDUCE, &operator_desc, inputs, reduce->label);
  const NodeOutput* output =
      graph_builder.CreateNodeOutput(reduce_node, output_tensor_desc);
  CHECK(id_to_node_output_map.try_emplace(output_id, output).second);
}
// Builds the node(s) for the WebNN `reshape` operation.
//
// After CHECKing the input data type against the reshape limits, the input is
// reshaped via `CreateReshapeNode` to the shape stored in the output operand's
// descriptor, and the result is registered under the output operand id.
void CreateOperatorNodeForReshape(const ContextProperties& context_properties,
                                  const std::vector<OperandPtr>& operands,
                                  const mojom::ReshapePtr& reshape,
                                  GraphBuilderDml& graph_builder,
                                  IdToNodeOutputMap& id_to_node_output_map) {
  const NodeOutput* source =
      GetNodeOutputForOperand(id_to_node_output_map, reshape->input_operand_id);
  const OperandDataType source_type =
      DmlDataTypeToOperand(source->GetTensorDesc().GetDataType());
  CHECK(context_properties.data_type_limits.reshape_input.data_types.Has(
      source_type));

  const OperandId result_id = reshape->output_operand_id;
  const Operand& result_operand = GetOperand(operands, result_id);
  base::span<const uint32_t> target_shape = result_operand.descriptor.shape();

  const NodeOutput* reshaped =
      CreateReshapeNode(graph_builder, source, target_shape);
  CHECK(id_to_node_output_map.try_emplace(result_id, reshaped).second);
}
// Builds the node for the WebNN `reverse` operation using DML_OPERATOR_SLICE1.
//
// The slice window starts at offset 0 and spans the full input in every
// dimension; dimensions listed in `reverse.axes` get a window stride of -1
// while all others keep a stride of 1. The node output is registered under
// the operation's output operand id.
void CreateOperatorNodeForReverse(const ContextProperties& context_properties,
                                  const std::vector<OperandPtr>& operands,
                                  const mojom::Reverse& reverse,
                                  GraphBuilderDml& graph_builder,
                                  IdToNodeOutputMap& id_to_node_output_map) {
  const NodeOutput* source =
      GetNodeOutputForOperand(id_to_node_output_map, reverse.input_operand_id);
  const TensorDesc& source_desc = source->GetTensorDesc();
  const size_t rank = source_desc.GetDimensions().size();

  const OperandId result_id = reverse.output_operand_id;
  const TensorDesc result_desc = CreateOutputTensorDesc(operands, result_id);

  base::FixedArray<uint32_t> window_offsets(rank, 0);
  base::FixedArray<int32_t> window_strides(rank, 1);
  for (uint32_t reversed_axis : reverse.axes) {
    CHECK_LT(reversed_axis, rank);
    window_strides[reversed_axis] = -1;
  }

  DML_SLICE1_OPERATOR_DESC slice_desc{
      .InputTensor = &source_desc.GetDMLTensorDesc(),
      .OutputTensor = &result_desc.GetDMLTensorDesc(),
      .DimensionCount = base::checked_cast<uint32_t>(rank),
      .InputWindowOffsets = window_offsets.data(),
      .InputWindowSizes = source_desc.GetDimensions().data(),
      .InputWindowStrides = window_strides.data()};

  std::array<const NodeOutput*, 1> node_inputs = {source};
  const GraphNode* slice_node = graph_builder.CreateOperatorNode(
      DML_OPERATOR_SLICE1, &slice_desc, node_inputs, reverse.label);

  // `result_desc` is const, so it is copied into the node output.
  const NodeOutput* result =
      graph_builder.CreateNodeOutput(slice_node, result_desc);
  CHECK(id_to_node_output_map.try_emplace(result_id, result).second);
}
// Builds the DML_OPERATOR_ACTIVATION_ELU node for the WebNN `elu` operation,
// forwarding `elu->alpha` as the operator's Alpha, and registers the node
// output under the operation's output operand id.
void CreateOperatorNodeForElu(const std::vector<OperandPtr>& operands,
                              const mojom::EluPtr& elu,
                              GraphBuilderDml& graph_builder,
                              IdToNodeOutputMap& id_to_node_output_map) {
  const NodeOutput* source =
      GetNodeOutputForOperand(id_to_node_output_map, elu->input_operand_id);
  const auto& source_desc = source->GetTensorDesc();

  const OperandId result_id = elu->output_operand_id;
  const auto result_desc = CreateOutputTensorDesc(operands, result_id);

  DML_ACTIVATION_ELU_OPERATOR_DESC activation_desc{
      .InputTensor = &source_desc.GetDMLTensorDesc(),
      .OutputTensor = &result_desc.GetDMLTensorDesc(),
      .Alpha = elu->alpha};

  std::array<const NodeOutput*, 1> node_inputs = {source};
  const GraphNode* activation_node = graph_builder.CreateOperatorNode(
      DML_OPERATOR_ACTIVATION_ELU, &activation_desc, node_inputs, elu->label);

  // `result_desc` is const, so it is copied into the node output.
  const NodeOutput* result =
      graph_builder.CreateNodeOutput(activation_node, result_desc);
  CHECK(id_to_node_output_map.try_emplace(result_id, result).second);
}
// Builds the node(s) for the WebNN `expand` operation.
//
// The input data type is CHECKed against the expand limits, then
// `CreateExpandNode` expands the input to the dimensions of the output
// operand's tensor description. The result is registered under the
// operation's output operand id.
void CreateOperatorNodeForExpand(const ContextProperties& context_properties,
                                 const std::vector<OperandPtr>& operands,
                                 const mojom::ExpandPtr& expand,
                                 GraphBuilderDml& graph_builder,
                                 IdToNodeOutputMap& id_to_node_output_map) {
  const NodeOutput* source =
      GetNodeOutputForOperand(id_to_node_output_map, expand->input_operand_id);
  const OperandDataType source_type =
      DmlDataTypeToOperand(source->GetTensorDesc().GetDataType());
  CHECK(context_properties.data_type_limits.expand_input.data_types.Has(
      source_type));

  const OperandId result_id = expand->output_operand_id;
  const auto result_desc = CreateOutputTensorDesc(operands, result_id);

  const NodeOutput* expanded = CreateExpandNode(
      graph_builder, source, result_desc.GetDimensions(), expand->label);
  CHECK(id_to_node_output_map.try_emplace(result_id, expanded).second);
}
// Builds the DML_OPERATOR_GATHER node for the WebNN `gather` operation.
//
// The input and indices tensor descriptions are padded (trailing alignment) to
// a common rank of max(input rank, output rank). When the output rank is
// smaller than the input rank, the tensor description passed to DirectML as
// the output is rebuilt from the input dimensions with the gather axis set to
// 1. The axis is then shifted by (common rank - input rank).
//
// Returns an error if the shifted axis does not fit in uint32_t; otherwise
// registers the node output (using the operand's original output descriptor)
// under the operation's output operand id and returns ok.
base::expected<void, mojom::ErrorPtr> CreateOperatorNodeForGather(
    const ContextProperties& context_properties,
    const std::vector<OperandPtr>& operands,
    const mojom::GatherPtr& gather,
    GraphBuilderDml& graph_builder,
    IdToNodeOutputMap& id_to_node_output_map) {
  const auto& limits = context_properties.data_type_limits;

  const OperandId source_id = gather->input_operand_id;
  CHECK(limits.gather_input.Supports(
      GetOperand(operands, source_id).descriptor));
  const OperandId index_id = gather->indices_operand_id;
  CHECK(limits.gather_indices.Supports(
      GetOperand(operands, index_id).descriptor));

  // Work on copies: the descriptions are padded below.
  const NodeOutput* source =
      GetNodeOutputForOperand(id_to_node_output_map, source_id);
  TensorDesc source_desc = source->GetTensorDesc();
  const NodeOutput* index =
      GetNodeOutputForOperand(id_to_node_output_map, index_id);
  TensorDesc index_desc = index->GetTensorDesc();

  const OperandId result_id = gather->output_operand_id;
  const TensorDesc operand_result_desc =
      CreateOutputTensorDesc(operands, result_id);
  TensorDesc dml_result_desc = operand_result_desc;

  // Ranks are captured before any padding is applied.
  const size_t source_rank = source_desc.GetDimensions().size();
  const size_t result_rank = dml_result_desc.GetDimensions().size();
  const size_t common_rank = std::max(source_rank, result_rank);
  const size_t index_rank = index_desc.GetDimensions().size();

  source_desc.EnsureMinimumRank(common_rank, TensorDesc::Alignment::kTrailing);
  index_desc.EnsureMinimumRank(common_rank, TensorDesc::Alignment::kTrailing);

  uint32_t axis = gather->axis;
  if (result_rank < source_rank) {
    CHECK_EQ(index_rank, 1u);
    CHECK_EQ(result_rank, source_rank - 1);
    std::vector<uint32_t> kept_dims = source_desc.GetDimensions();
    CHECK_LT(axis, kept_dims.size());
    kept_dims[axis] = 1;
    dml_result_desc =
        TensorDesc(dml_result_desc.GetDataType(), std::move(kept_dims));
  }

  // Shift the axis by the difference between the common and input ranks.
  auto shifted_axis = base::MakeCheckedNum(common_rank) - source_rank +
                      base::checked_cast<size_t>(axis);
  const std::string& label = gather->label;
  if (!shifted_axis.AssignIfValid<uint32_t>(&axis)) {
    return base::unexpected(
        CreateError(mojom::Error::Code::kUnknownError,
                    "The axis of gather operator is too large.", label));
  }

  DML_GATHER_OPERATOR_DESC gather_desc{
      .InputTensor = &source_desc.GetDMLTensorDesc(),
      .IndicesTensor = &index_desc.GetDMLTensorDesc(),
      .OutputTensor = &dml_result_desc.GetDMLTensorDesc(),
      .Axis = axis,
      .IndexDimensions = base::checked_cast<uint32_t>(index_rank)};

  std::array<const NodeOutput*, 2> node_inputs = {source, index};
  const GraphNode* gather_node = graph_builder.CreateOperatorNode(
      DML_OPERATOR_GATHER, &gather_desc, node_inputs, label);

  // `operand_result_desc` is const, so it is copied into the node output.
  const NodeOutput* result =
      graph_builder.CreateNodeOutput(gather_node, operand_result_desc, 0);
  CHECK(id_to_node_output_map.try_emplace(result_id, result).second);
  return base::ok();
}
// Builds the DML_OPERATOR_GATHER_ELEMENTS node for the WebNN `gatherElements`
// operation.
//
// The input and indices operand descriptors are CHECKed against the context's
// limits, the existing tensor descriptions of both inputs are passed to
// DirectML unchanged together with `gather_elements->axis`, and the node
// output is registered under the operation's output operand id.
void CreateOperatorNodeForGatherElements(
    const ContextProperties& context_properties,
    const std::vector<OperandPtr>& operands,
    const mojom::GatherElementsPtr& gather_elements,
    GraphBuilderDml& graph_builder,
    IdToNodeOutputMap& id_to_node_output_map) {
  const auto& limits = context_properties.data_type_limits;

  const OperandId source_id = gather_elements->input_operand_id;
  CHECK(limits.gather_elements_input.Supports(
      GetOperand(operands, source_id).descriptor));
  const OperandId index_id = gather_elements->indices_operand_id;
  CHECK(limits.gather_elements_indices.Supports(
      GetOperand(operands, index_id).descriptor));

  const NodeOutput* source =
      GetNodeOutputForOperand(id_to_node_output_map, source_id);
  const TensorDesc& source_desc = source->GetTensorDesc();
  const NodeOutput* index =
      GetNodeOutputForOperand(id_to_node_output_map, index_id);
  const TensorDesc& index_desc = index->GetTensorDesc();

  const OperandId result_id = gather_elements->output_operand_id;
  const TensorDesc result_desc = CreateOutputTensorDesc(operands, result_id);

  DML_GATHER_ELEMENTS_OPERATOR_DESC dml_desc{
      .InputTensor = &source_desc.GetDMLTensorDesc(),
      .IndicesTensor = &index_desc.GetDMLTensorDesc(),
      .OutputTensor = &result_desc.GetDMLTensorDesc(),
      .Axis = gather_elements->axis};

  std::array<const NodeOutput*, 2> node_inputs = {source, index};
  const GraphNode* gather_node = graph_builder.CreateOperatorNode(
      DML_OPERATOR_GATHER_ELEMENTS, &dml_desc, node_inputs,
      gather_elements->label);

  // `result_desc` is const, so it is copied into the node output.
  const NodeOutput* result =
      graph_builder.CreateNodeOutput(gather_node, result_desc, 0);
  CHECK(id_to_node_output_map.try_emplace(result_id, result).second);
}
// Builds the DML_OPERATOR_GATHER_ND node for the WebNN `gatherND` operation.
//
// The input, indices and output tensor descriptions given to DirectML are all
// padded (trailing alignment) to the largest of their three ranks, while
// InputDimensionCount / IndicesDimensionCount carry the ranks measured before
// padding. The registered node output keeps the operand's unpadded output
// descriptor.
void CreateOperatorNodeForGatherND(const ContextProperties& context_properties,
                                   const std::vector<OperandPtr>& operands,
                                   const mojom::GatherNDPtr& gather_nd,
                                   GraphBuilderDml& graph_builder,
                                   IdToNodeOutputMap& id_to_node_output_map) {
  const auto& limits = context_properties.data_type_limits;

  const OperandId source_id = gather_nd->input_operand_id;
  CHECK(limits.gather_nd_input.Supports(
      GetOperand(operands, source_id).descriptor));
  const OperandId index_id = gather_nd->indices_operand_id;
  CHECK(limits.gather_nd_indices.Supports(
      GetOperand(operands, index_id).descriptor));

  // Work on copies: the descriptions are padded below.
  const NodeOutput* source =
      GetNodeOutputForOperand(id_to_node_output_map, source_id);
  TensorDesc source_desc = source->GetTensorDesc();
  const NodeOutput* index =
      GetNodeOutputForOperand(id_to_node_output_map, index_id);
  TensorDesc index_desc = index->GetTensorDesc();

  const OperandId result_id = gather_nd->output_operand_id;
  const TensorDesc operand_result_desc =
      CreateOutputTensorDesc(operands, result_id);

  // Ranks are captured before any padding is applied.
  const size_t source_rank = source_desc.GetDimensions().size();
  const size_t index_rank = index_desc.GetDimensions().size();
  const size_t result_rank = operand_result_desc.GetDimensions().size();
  const size_t widest_rank = std::max({source_rank, index_rank, result_rank});

  source_desc.EnsureMinimumRank(widest_rank, TensorDesc::Alignment::kTrailing);
  index_desc.EnsureMinimumRank(widest_rank, TensorDesc::Alignment::kTrailing);
  TensorDesc dml_result_desc = operand_result_desc;
  dml_result_desc.EnsureMinimumRank(widest_rank,
                                    TensorDesc::Alignment::kTrailing);

  DML_GATHER_ND_OPERATOR_DESC dml_desc{
      .InputTensor = &source_desc.GetDMLTensorDesc(),
      .IndicesTensor = &index_desc.GetDMLTensorDesc(),
      .OutputTensor = &dml_result_desc.GetDMLTensorDesc(),
      .InputDimensionCount = base::checked_cast<uint32_t>(source_rank),
      .IndicesDimensionCount = base::checked_cast<uint32_t>(index_rank)};

  std::array<const NodeOutput*, 2> node_inputs = {source, index};
  const GraphNode* gather_node = graph_builder.CreateOperatorNode(
      DML_OPERATOR_GATHER_ND, &dml_desc, node_inputs, gather_nd->label);

  // `operand_result_desc` is const, so it is copied into the node output.
  const NodeOutput* result =
      graph_builder.CreateNodeOutput(gather_node, operand_result_desc, 0);
  CHECK(id_to_node_output_map.try_emplace(result_id, result).second);
}
// Builds the node(s) for the WebNN `gelu` operation.
//
// When the adapter supports DML_FEATURE_LEVEL_5_1 a single
// DML_OPERATOR_ACTIVATION_GELU node is created. Otherwise the operation is
// decomposed into the following chain of element-wise nodes, all labelled with
// `gelu->label`:
//   sqrt(2) -> input / sqrt(2) -> erf(...) -> ... + 1 -> input * (...)
//   -> ... * 0.5
// The constants 2.0, 1.0 and 0.5 are created as new constant operands through
// `BuildConstantOperandForFloatValue` / `CreateConstantNode`, which is why
// `graph_info`, `constant_operands` and `constant_id_to_input_index_map` are
// taken by mutable reference. The final node output is registered under the
// operation's output operand id.
void CreateOperatorNodeForGelu(
const ContextProperties& context_properties,
Adapter* adapter,
const mojom::GeluPtr& gelu,
mojom::GraphInfoPtr& graph_info,
base::flat_map<OperandId, std::unique_ptr<WebNNConstantOperand>>&
constant_operands,
GraphBuilderDml& graph_builder,
IdToNodeOutputMap& id_to_node_output_map,
absl::flat_hash_map<OperandId, uint32_t>& constant_id_to_input_index_map) {
const auto& operands = graph_info->operands;
// Fast path: use the native GELU activation operator when available.
if (adapter->IsDMLFeatureLevelSupported(DML_FEATURE_LEVEL_5_1)) {
return CreateOperatorNodeForUnary<DML_ACTIVATION_GELU_OPERATOR_DESC,
DML_OPERATOR_ACTIVATION_GELU>(
operands, gelu, graph_builder, id_to_node_output_map);
}
// Fallback: emulate gelu with element-wise operators. The constants are
// created with the same data type as the input operand.
const Operand& input_operand = GetOperand(operands, gelu->input_operand_id);
const OperandDataType data_type = input_operand.descriptor.data_type();
// Step 1: constant 2.0 and its square root.
// NOTE(review): the literal `1` argument is presumably the rank of the
// constant - confirm against BuildConstantOperandForFloatValue.
OperandId constant_for_sqrt_operand_id = BuildConstantOperandForFloatValue(
context_properties, graph_info, constant_operands, data_type, 1,
2.0);
CreateConstantNode(adapter, constant_for_sqrt_operand_id, constant_operands,
graph_builder, id_to_node_output_map,
constant_id_to_input_index_map);
const NodeOutput* constant_for_sqrt_output = GetNodeOutputForOperand(
id_to_node_output_map, constant_for_sqrt_operand_id);
const TensorDesc sqrt_output_tensor_desc =
TensorDesc(GetTensorDataType(data_type), {1});
DML_ELEMENT_WISE_SQRT_OPERATOR_DESC sqrt_operator_desc{
.InputTensor =
&constant_for_sqrt_output->GetTensorDesc().GetDMLTensorDesc(),
.OutputTensor = &sqrt_output_tensor_desc.GetDMLTensorDesc(),
};
const std::string& label = gelu->label;
std::array<const NodeOutput*, 1> sqrt_inputs = {constant_for_sqrt_output};
const GraphNode* sqrt_node = graph_builder.CreateOperatorNode(
DML_OPERATOR_ELEMENT_WISE_SQRT, &sqrt_operator_desc, sqrt_inputs, label);
const NodeOutput* sqrt_output =
graph_builder.CreateNodeOutput(sqrt_node, sqrt_output_tensor_desc);
// Step 2: input / sqrt(2). The divisor's tensor description is broadcast to
// the input dimensions.
const NodeOutput* input =
GetNodeOutputForOperand(id_to_node_output_map, gelu->input_operand_id);
const TensorDesc& input_tensor_desc = input->GetTensorDesc();
const std::vector<uint32_t>& input_dimensions =
input_tensor_desc.GetDimensions();
TensorDesc div_divisor_tensor_desc = sqrt_output->GetTensorDesc();
div_divisor_tensor_desc.BroadcastTo(input_dimensions);
OperandId output_id = gelu->output_operand_id;
const auto output_tensor_desc = CreateOutputTensorDesc(operands, output_id);
// Every intermediate result reuses the final output tensor description.
const TensorDesc& div_output_tensor_desc = output_tensor_desc;
std::array<const NodeOutput*, 2> div_inputs = {input, sqrt_output};
const GraphNode* div_node =
CreateBinaryOperator<DML_ELEMENT_WISE_DIVIDE_OPERATOR_DESC>(
input_tensor_desc, div_divisor_tensor_desc, div_output_tensor_desc,
graph_builder, DML_OPERATOR_ELEMENT_WISE_DIVIDE, div_inputs, label);
const NodeOutput* div_output =
graph_builder.CreateNodeOutput(div_node, div_output_tensor_desc);
// Step 3: erf(input / sqrt(2)).
const TensorDesc& erf_output_tensor_desc = output_tensor_desc;
DML_ELEMENT_WISE_ERF_OPERATOR_DESC erf_operator_desc{
.InputTensor = &div_output->GetTensorDesc().GetDMLTensorDesc(),
.OutputTensor = &erf_output_tensor_desc.GetDMLTensorDesc(),
};
std::array<const NodeOutput*, 1> erf_inputs = {div_output};
const GraphNode* erf_node = graph_builder.CreateOperatorNode(
DML_OPERATOR_ELEMENT_WISE_ERF, &erf_operator_desc, erf_inputs, label);
const NodeOutput* erf_output =
graph_builder.CreateNodeOutput(erf_node, erf_output_tensor_desc);
// Step 4: erf(...) + 1.0, with the constant broadcast to the input
// dimensions.
OperandId constant_for_add_operand_id = BuildConstantOperandForFloatValue(
context_properties, graph_info, constant_operands, data_type, 1,
1.0);
CreateConstantNode(adapter, constant_for_add_operand_id, constant_operands,
graph_builder, id_to_node_output_map,
constant_id_to_input_index_map);
const NodeOutput* constant_for_add_output = GetNodeOutputForOperand(
id_to_node_output_map, constant_for_add_operand_id);
const TensorDesc& add_output_tensor_desc = output_tensor_desc;
TensorDesc constant_for_add_tensor_desc =
constant_for_add_output->GetTensorDesc();
constant_for_add_tensor_desc.BroadcastTo(input_dimensions);
std::array<const NodeOutput*, 2> add_inputs = {erf_output,
constant_for_add_output};
const GraphNode* add_node =
CreateBinaryOperator<DML_ELEMENT_WISE_ADD_OPERATOR_DESC>(
erf_output_tensor_desc, constant_for_add_tensor_desc,
add_output_tensor_desc, graph_builder, DML_OPERATOR_ELEMENT_WISE_ADD,
add_inputs, label);
const NodeOutput* add_output =
graph_builder.CreateNodeOutput(add_node, add_output_tensor_desc);
// Step 5: input * (erf(...) + 1.0).
const TensorDesc& second_mul_output_tensor_desc = output_tensor_desc;
std::array<const NodeOutput*, 2> second_mul_inputs = {input, add_output};
const GraphNode* second_mul_node =
CreateBinaryOperator<DML_ELEMENT_WISE_MULTIPLY_OPERATOR_DESC>(
input_tensor_desc, add_output_tensor_desc,
second_mul_output_tensor_desc, graph_builder,
DML_OPERATOR_ELEMENT_WISE_MULTIPLY, second_mul_inputs, label);
const NodeOutput* second_mul_output = graph_builder.CreateNodeOutput(
second_mul_node, second_mul_output_tensor_desc);
// Step 6: multiply the previous result by the constant 0.5, again broadcast
// to the input dimensions. This node produces the operation's output.
OperandId constant_for_mul_operand_id = BuildConstantOperandForFloatValue(
context_properties, graph_info, constant_operands, data_type, 1,
0.5);
CreateConstantNode(adapter, constant_for_mul_operand_id, constant_operands,
graph_builder, id_to_node_output_map,
constant_id_to_input_index_map);
const NodeOutput* constant_for_mul_output = GetNodeOutputForOperand(
id_to_node_output_map, constant_for_mul_operand_id);
TensorDesc constant_for_mul_tensor_desc =
constant_for_mul_output->GetTensorDesc();
constant_for_mul_tensor_desc.BroadcastTo(input_dimensions);
std::array<const NodeOutput*, 2> mul_constant_inputs = {
second_mul_output, constant_for_mul_output};
const GraphNode* mul_constant_node =
CreateBinaryOperator<DML_ELEMENT_WISE_MULTIPLY_OPERATOR_DESC>(
second_mul_output_tensor_desc, constant_for_mul_tensor_desc,
output_tensor_desc, graph_builder, DML_OPERATOR_ELEMENT_WISE_MULTIPLY,
mul_constant_inputs, label);
// `output_tensor_desc` is const, so this std::move results in a copy.
const NodeOutput* node_output = graph_builder.CreateNodeOutput(
mul_constant_node, std::move(output_tensor_desc));
CHECK(id_to_node_output_map.try_emplace(output_id, node_output).second);
}
// Builds the DML_OPERATOR_GEMM node for the WebNN `gemm` operation.
//
// The A, B and (optional) C tensor descriptions, as well as the output
// description given to DirectML, are padded to at least rank 4 with trailing
// alignment. C is first broadcast to the output dimensions when its own
// dimensions differ. If `operation` has a fusible standalone activation, that
// activation is attached as the operator's FusedActivation, the node label is
// replaced by the fused label, and the node output is registered under the
// activation's output operand id instead of gemm's own.
void CreateOperatorNodeForGemm(
    const ContextProperties& context_properties,
    const std::vector<OperandPtr>& operands,
    const Operation* operation,
    const absl::flat_hash_map<const Operation*,
                              raw_ptr<const Operation, CtnExperimental>>&
        operation_to_fusible_standalone_activation_map,
    GraphBuilderDml& graph_builder,
    IdToNodeOutputMap& id_to_node_output_map) {
  const auto& gemm = operation->get_gemm();
  const auto& limits = context_properties.data_type_limits;

  const OperandId a_id = gemm->a_operand_id;
  const OperandId b_id = gemm->b_operand_id;
  CHECK(limits.gemm_a.SupportsAll({GetOperand(operands, a_id).descriptor,
                                   GetOperand(operands, b_id).descriptor}));

  // Work on copies of the A/B descriptions: they are padded below.
  const NodeOutput* a_output =
      GetNodeOutputForOperand(id_to_node_output_map, a_id);
  TensorDesc a_desc = a_output->GetTensorDesc();
  const NodeOutput* b_output =
      GetNodeOutputForOperand(id_to_node_output_map, b_id);
  TensorDesc b_desc = b_output->GetTensorDesc();
  std::vector<const NodeOutput*> node_inputs{a_output, b_output};

  OperandId result_id = gemm->output_operand_id;
  const auto result_desc = CreateOutputTensorDesc(operands, result_id);

  // Optional C input, broadcast to the output shape when necessary.
  std::optional<TensorDesc> c_desc;
  if (gemm->c_operand_id) {
    const OperandId c_id = gemm->c_operand_id.value();
    CHECK(limits.gemm_c.Supports(GetOperand(operands, c_id).descriptor));
    const NodeOutput* c_output =
        GetNodeOutputForOperand(id_to_node_output_map, c_id);
    c_desc = c_output->GetTensorDesc();
    node_inputs.push_back(c_output);

    const auto& result_dims = result_desc.GetDimensions();
    if (c_desc->GetDimensions() != result_dims) {
      c_desc->BroadcastTo(result_dims);
    }
  }

  // Pad every tensor description handed to DirectML to at least rank 4.
  a_desc.EnsureMinimumRank(4, TensorDesc::Alignment::kTrailing);
  b_desc.EnsureMinimumRank(4, TensorDesc::Alignment::kTrailing);
  if (c_desc) {
    c_desc->EnsureMinimumRank(4, TensorDesc::Alignment::kTrailing);
  }
  auto dml_result_desc = result_desc;
  dml_result_desc.EnsureMinimumRank(4, TensorDesc::Alignment::kTrailing);

  // `activation_dml_desc` is derived from `activation_operator_desc`, so both
  // stay alive until the operator node has been created.
  std::optional<const Operation*> fusible_activation =
      GetFusibleActivationFromOperation(
          operation_to_fusible_standalone_activation_map, operation);
  std::optional<ActivationOperatorDesc> activation_operator_desc;
  std::optional<DML_OPERATOR_DESC> activation_dml_desc;
  std::string label = gemm->label;
  if (fusible_activation) {
    const Operation& activation = *fusible_activation.value();
    activation_operator_desc =
        CreateOperatorDescForFusibleActivation(activation);
    activation_dml_desc = activation_operator_desc->GetActivationDmlDesc();
    result_id = GetFusibleActivationOutputId(activation).value();
    label = GetFusedOperatorLabel(label, "gemm", activation);
  }

  DML_GEMM_OPERATOR_DESC gemm_desc{
      .ATensor = &a_desc.GetDMLTensorDesc(),
      .BTensor = &b_desc.GetDMLTensorDesc(),
      .CTensor = GetOptionalDmlTensorDescPtr(c_desc),
      .OutputTensor = &dml_result_desc.GetDMLTensorDesc(),
      .TransA = (gemm->a_transpose) ? DML_MATRIX_TRANSFORM_TRANSPOSE
                                    : DML_MATRIX_TRANSFORM_NONE,
      .TransB = (gemm->b_transpose) ? DML_MATRIX_TRANSFORM_TRANSPOSE
                                    : DML_MATRIX_TRANSFORM_NONE,
      .Alpha = gemm->alpha,
      .Beta = gemm->beta,
      .FusedActivation =
          activation_dml_desc ? &activation_dml_desc.value() : nullptr,
  };

  const GraphNode* gemm_node = graph_builder.CreateOperatorNode(
      DML_OPERATOR_GEMM, &gemm_desc, node_inputs, label);

  // `result_desc` is const, so it is copied into the node output.
  const NodeOutput* result =
      graph_builder.CreateNodeOutput(gemm_node, result_desc, 0);
  CHECK(id_to_node_output_map.try_emplace(result_id, result).second);
}
// Returns `input` unchanged unless its tensor description carries the
// DML_TENSOR_FLAG_OWNED_BY_DML flag; in that case an identity node is appended
// after it via `AppendIdentityNode` and that node's output is returned.
// `input` must not be null.
const NodeOutput* AppendIdentityToConstantOperand(
    GraphBuilderDml& graph_builder,
    const NodeOutput* input) {
  CHECK(input);
  const bool owned_by_dml =
      (input->GetTensorDesc().GetFlags() & DML_TENSOR_FLAG_OWNED_BY_DML) != 0;
  if (owned_by_dml) {
    return AppendIdentityNode(graph_builder, input);
  }
  return input;
}
// Builds a DML_OPERATOR_GRU node for either a WebNN gru (mojom::GruPtr) or
// gruCell (mojom::GruCellPtr) operation and registers its output(s) in
// `id_to_node_output_map`.
//
// A gruCell is handled as a forward-only GRU that never returns the sequence
// and whose initial hidden state comes from `hidden_state_operand_id`.
//
// Returns base::ok() on success, or an error when
//  - 6 * hidden_size overflows uint32_t (kUnknownError), or
//  - the weight layout is not kZrn (kNotSupportedError).
// Note that helper nodes (identity, constant, join) may already have been
// added to `graph_builder` by the time an error is returned.
template <typename GruType>
requires(std::is_same_v<GruType, mojom::GruPtr> ||
std::is_same_v<GruType, mojom::GruCellPtr>)
base::expected<void, mojom::ErrorPtr> CreateOperatorNodeForGru(
Adapter* adapter,
const ContextProperties& context_properties,
const GruType& gru,
mojom::GraphInfoPtr& graph_info,
base::flat_map<OperandId, std::unique_ptr<WebNNConstantOperand>>&
constant_operands,
GraphBuilderDml& graph_builder,
IdToNodeOutputMap& id_to_node_output_map,
absl::flat_hash_map<OperandId, uint32_t>& constant_id_to_input_index_map) {
const auto& operands = graph_info->operands;
// Normalize the fields that differ between gru and gruCell into locals so the
// rest of the function can treat both uniformly.
mojom::Operation::Tag op_tag;
std::optional<OperandId> initial_hidden_state_operand_id;
bool return_sequence;
mojom::RecurrentNetworkDirection direction;
if constexpr (std::is_same_v<GruType, mojom::GruPtr>) {
CHECK(context_properties.data_type_limits.gru_input.SupportsAll(
{GetOperand(operands, gru->input_operand_id).descriptor,
GetOperand(operands, gru->weight_operand_id).descriptor,
GetOperand(operands, gru->recurrent_weight_operand_id).descriptor}));
op_tag = mojom::Operation::Tag::kGru;
initial_hidden_state_operand_id = gru->initial_hidden_state_operand_id;
return_sequence = gru->return_sequence;
direction = gru->direction;
} else {
CHECK(context_properties.data_type_limits.gru_cell_input.SupportsAll(
{GetOperand(operands, gru->input_operand_id).descriptor,
GetOperand(operands, gru->weight_operand_id).descriptor,
GetOperand(operands, gru->recurrent_weight_operand_id).descriptor}));
op_tag = mojom::Operation::Tag::kGruCell;
initial_hidden_state_operand_id = gru->hidden_state_operand_id;
// A cell is a single forward step with no sequence output.
return_sequence = false;
direction = mojom::RecurrentNetworkDirection::kForward;
}
// Input, weight and recurrent weight all follow the same recipe: look up the
// producing node output, route DML-owned tensors through an identity node
// (NOTE(review): presumably because DML_GRU cannot consume DML-owned constant
// tensors directly - confirm), then pad a copy of the tensor desc to rank 4
// with trailing alignment for use in the GRU operator desc.
const NodeOutput* input =
GetNodeOutputForOperand(id_to_node_output_map, gru->input_operand_id);
input = AppendIdentityToConstantOperand(graph_builder, input);
TensorDesc input_tensor_desc = input->GetTensorDesc();
input_tensor_desc.EnsureMinimumRank(4,
TensorDesc::Alignment::kTrailing);
const NodeOutput* weight =
GetNodeOutputForOperand(id_to_node_output_map, gru->weight_operand_id);
weight = AppendIdentityToConstantOperand(graph_builder, weight);
TensorDesc weight_tensor_desc = weight->GetTensorDesc();
weight_tensor_desc.EnsureMinimumRank( 4,
TensorDesc::Alignment::kTrailing);
const NodeOutput* recurrent_weight = GetNodeOutputForOperand(
id_to_node_output_map, gru->recurrent_weight_operand_id);
recurrent_weight =
AppendIdentityToConstantOperand(graph_builder, recurrent_weight);
TensorDesc recurrent_weight_tensor_desc = recurrent_weight->GetTensorDesc();
recurrent_weight_tensor_desc.EnsureMinimumRank(
4, TensorDesc::Alignment::kTrailing);
// `inputs` is positional: input, weight, recurrent weight, bias, initial
// hidden state, and one trailing slot that is always nullptr (matching the
// nullptr SequenceLengthsTensor in the operator desc below). Absent optional
// inputs are represented by nullptr entries.
std::vector<const NodeOutput*> inputs{input, weight, recurrent_weight};
const Operand& input_operand = GetOperand(operands, gru->input_operand_id);
const OperandDataType data_type = input_operand.descriptor.data_type();
const std::string& label = gru->label;
std::optional<TensorDesc> concatenated_bias_tensor_desc;
if (!gru->bias_operand_id.has_value() &&
!gru->recurrent_bias_operand_id.has_value()) {
// Neither bias supplied: no bias input at all.
inputs.push_back(nullptr);
} else {
// At least one bias is present. If exactly one is missing, synthesize a
// zero-valued constant to stand in for it.
std::optional<const NodeOutput*> zero_bias;
if (!gru->bias_operand_id.has_value() ||
!gru->recurrent_bias_operand_id.has_value()) {
// NOTE(review): the trailing arguments (1, 0) are presumably the constant's
// rank and fill value - confirm against BuildConstantOperandForFloatValue.
OperandId zero_bias_operand_id = BuildConstantOperandForFloatValue(
context_properties, graph_info, constant_operands, data_type,
1,
0);
CreateConstantNode(adapter, zero_bias_operand_id, constant_operands,
graph_builder, id_to_node_output_map,
constant_id_to_input_index_map);
zero_bias =
GetNodeOutputForOperand(id_to_node_output_map, zero_bias_operand_id);
}
const NodeOutput* bias =
gru->bias_operand_id.has_value()
? GetOptionalNodeOutputForOperand(id_to_node_output_map,
gru->bias_operand_id)
: zero_bias.value();
const NodeOutput* recurrent_bias =
gru->recurrent_bias_operand_id.has_value()
? GetOptionalNodeOutputForOperand(id_to_node_output_map,
gru->recurrent_bias_operand_id)
: zero_bias.value();
const uint32_t num_directions =
direction == mojom::RecurrentNetworkDirection::kBoth ? 2 : 1;
uint32_t hidden_size = gru->hidden_size;
// Each bias half is broadcast to [1, 1, num_directions, 3 * hidden_size].
// Overflow of 3 * hidden_size is treated as a programming error (CHECK).
auto checked_three_times_hidden_size =
base::MakeCheckedNum(hidden_size) * 3;
CHECK(checked_three_times_hidden_size.IsValid());
const std::array<uint32_t, 4> half_bias_dimensions = {
1, 1, num_directions, checked_three_times_hidden_size.ValueOrDie()};
TensorDesc bias_tensor_desc = bias->GetTensorDesc();
bias_tensor_desc.BroadcastTo(half_bias_dimensions);
TensorDesc recurrent_bias_tensor_desc = recurrent_bias->GetTensorDesc();
recurrent_bias_tensor_desc.BroadcastTo(half_bias_dimensions);
std::array<DML_TENSOR_DESC, 2> concat_input_tensor_descs = {
bias_tensor_desc.GetDMLTensorDesc(),
recurrent_bias_tensor_desc.GetDMLTensorDesc()};
// The joined bias is [1, 1, num_directions, 6 * hidden_size]. Unlike the
// 3x case above, overflow here is reported to the caller as an error.
auto checked_six_times_hidden_size = base::MakeCheckedNum(hidden_size) * 6;
if (!checked_six_times_hidden_size.IsValid()) {
return CreateUnexpectedError(
mojom::Error::Code::kUnknownError,
base::StringPrintf("The hidden size is too large for %s operator.",
OpTagToString(op_tag).c_str()),
label);
}
std::vector<uint32_t> concatenated_bias_dimensions = {
1, 1, num_directions, checked_six_times_hidden_size.ValueOrDie()};
concatenated_bias_tensor_desc = TensorDesc(
GetTensorDataType(data_type), std::move(concatenated_bias_dimensions));
// Join bias followed by recurrent bias along the last axis (axis 3).
DML_JOIN_OPERATOR_DESC concat_operator_desc{
.InputCount = concat_input_tensor_descs.size(),
.InputTensors = concat_input_tensor_descs.data(),
.OutputTensor = &concatenated_bias_tensor_desc->GetDMLTensorDesc(),
.Axis = 3};
std::array<const NodeOutput*, 2> bias_outputs = {bias, recurrent_bias};
const GraphNode* concat_node = graph_builder.CreateOperatorNode(
DML_OPERATOR_JOIN, &concat_operator_desc, bias_outputs, label);
const NodeOutput* concatenated_bias = graph_builder.CreateNodeOutput(
concat_node, concatenated_bias_tensor_desc.value(), 0);
inputs.push_back(concatenated_bias);
}
// Optional initial hidden state: validated against the data type limits,
// passed through the identity helper, and padded to rank 4 like the other
// inputs.
std::optional<TensorDesc> initial_hidden_state_tensor_desc;
if (initial_hidden_state_operand_id.has_value()) {
if constexpr (std::is_same_v<GruType, mojom::GruPtr>) {
CHECK(context_properties.data_type_limits.gru_input.Supports(
GetOperand(operands, *initial_hidden_state_operand_id).descriptor));
} else {
CHECK(context_properties.data_type_limits.gru_cell_input.Supports(
GetOperand(operands, *initial_hidden_state_operand_id).descriptor));
}
const NodeOutput* initial_hidden_state = GetNodeOutputForOperand(
id_to_node_output_map, initial_hidden_state_operand_id.value());
initial_hidden_state =
AppendIdentityToConstantOperand(graph_builder, initial_hidden_state);
initial_hidden_state_tensor_desc = initial_hidden_state->GetTensorDesc();
initial_hidden_state_tensor_desc->EnsureMinimumRank(
4, TensorDesc::Alignment::kTrailing);
inputs.push_back(initial_hidden_state);
} else {
inputs.push_back(nullptr);
}
// Final input slot is always empty; SequenceLengthsTensor is not used.
inputs.push_back(nullptr);
// gru exposes a list of outputs (hidden state first, then the optional
// sequence); gruCell exposes a single output operand.
std::vector<OperandId> output_ids;
OperandId output_hidden_state_id;
if constexpr (std::is_same<GruType, mojom::GruPtr>::value) {
output_ids = gru->output_operand_ids;
output_hidden_state_id = output_ids[0];
} else {
output_hidden_state_id = gru->output_operand_id;
}
TensorDesc output_hidden_state_tensor_desc =
CreateOutputTensorDesc(operands, output_hidden_state_id);
output_hidden_state_tensor_desc.EnsureMinimumRank(
4, TensorDesc::Alignment::kTrailing);
std::optional<OperandId> output_sequence_id;
std::optional<TensorDesc> output_sequence_tensor_desc;
if (return_sequence) {
// Only reachable for gru (return_sequence is forced to false for gruCell).
CHECK_EQ(output_ids.size(), 2u);
output_sequence_id = output_ids[1];
output_sequence_tensor_desc =
CreateOutputTensorDesc(operands, output_sequence_id.value());
}
// Only the kZrn weight layout is accepted; anything else is rejected.
if (gru->layout != mojom::GruWeightLayout::kZrn) {
return CreateUnexpectedError(
mojom::Error::Code::kNotSupportedError,
"The gru weight layout (rzn) is not supported.", label);
}
// For a bidirectional network the activation list is duplicated so that the
// second half (backward direction) mirrors the first half (forward).
const size_t number_of_activations =
direction == mojom::RecurrentNetworkDirection::kBoth
? gru->activations.size() * 2
: gru->activations.size();
base::FixedArray<ActivationOperatorDesc> activation_operator_descs(
number_of_activations);
for (size_t i = 0; i < gru->activations.size(); ++i) {
activation_operator_descs[i] =
CreateOperatorDescForActivation(gru->activations[i]);
if (direction == mojom::RecurrentNetworkDirection::kBoth) {
activation_operator_descs[gru->activations.size() + i] =
activation_operator_descs[i];
}
}
// Convert each activation desc into the DML_OPERATOR_DESC form that the GRU
// desc takes. `activation_operator_descs` stays alive until the GRU node is
// created below.
base::FixedArray<DML_OPERATOR_DESC> activation_dml_descs(
activation_operator_descs.size());
std::ranges::transform(
activation_operator_descs, std::begin(activation_dml_descs),
[](const auto& activation_operator_desc) {
return activation_operator_desc.GetActivationDmlDesc();
});
DML_GRU_OPERATOR_DESC gru_desc{
.InputTensor = &input_tensor_desc.GetDMLTensorDesc(),
.WeightTensor = &weight_tensor_desc.GetDMLTensorDesc(),
.RecurrenceTensor = &recurrent_weight_tensor_desc.GetDMLTensorDesc(),
.BiasTensor = GetOptionalDmlTensorDescPtr(concatenated_bias_tensor_desc),
.HiddenInitTensor =
GetOptionalDmlTensorDescPtr(initial_hidden_state_tensor_desc),
.SequenceLengthsTensor = nullptr,
.OutputSequenceTensor =
GetOptionalDmlTensorDescPtr(output_sequence_tensor_desc),
.OutputSingleTensor = &output_hidden_state_tensor_desc.GetDMLTensorDesc(),
.ActivationDescCount = static_cast<uint32_t>(activation_dml_descs.size()),
.ActivationDescs = activation_dml_descs.data(),
.Direction = MojoRecurrentNetworkDirectionToDml(direction),
.LinearBeforeReset = gru->reset_after};
const GraphNode* gru_node = graph_builder.CreateOperatorNode(
DML_OPERATOR_GRU, &gru_desc, inputs, label);
// The hidden state (OutputSingleTensor) is created with trailing argument 1
// and the sequence (OutputSequenceTensor) with 0 - presumably the node's
// output index, which would mirror the field order in the desc above.
const NodeOutput* output_hidden_state = graph_builder.CreateNodeOutput(
gru_node, output_hidden_state_tensor_desc, 1);
CHECK(id_to_node_output_map
.try_emplace(output_hidden_state_id, output_hidden_state)
.second);
if (return_sequence) {
const NodeOutput* output_sequence = graph_builder.CreateNodeOutput(
gru_node, output_sequence_tensor_desc.value(), 0);
CHECK(id_to_node_output_map
.try_emplace(output_sequence_id.value(), output_sequence)
.second);
}
return base::ok();
}
// Adds a DML_OPERATOR_ACTIVATION_HARD_SIGMOID node for `hard_sigmoid` to the
// graph and records its output under the operation's output operand id.
//
// The operation's alpha/beta are forwarded unchanged to the DirectML operator.
// CHECK-fails if the output operand id is already present in
// `id_to_node_output_map`.
void CreateOperatorNodeForHardSigmoid(
    const std::vector<OperandPtr>& operands,
    const mojom::HardSigmoidPtr& hard_sigmoid,
    GraphBuilderDml& graph_builder,
    IdToNodeOutputMap& id_to_node_output_map) {
  const NodeOutput* source = GetNodeOutputForOperand(
      id_to_node_output_map, hard_sigmoid->input_operand_id);
  const TensorDesc& source_desc = source->GetTensorDesc();

  const OperandId result_id = hard_sigmoid->output_operand_id;
  TensorDesc result_desc = CreateOutputTensorDesc(operands, result_id);

  DML_ACTIVATION_HARD_SIGMOID_OPERATOR_DESC dml_desc{};
  dml_desc.InputTensor = &source_desc.GetDMLTensorDesc();
  dml_desc.OutputTensor = &result_desc.GetDMLTensorDesc();
  dml_desc.Alpha = hard_sigmoid->alpha;
  dml_desc.Beta = hard_sigmoid->beta;

  std::array<const NodeOutput*, 1> node_inputs = {source};
  const GraphNode* node = graph_builder.CreateOperatorNode(
      DML_OPERATOR_ACTIVATION_HARD_SIGMOID, &dml_desc, node_inputs,
      hard_sigmoid->label);

  // `result_desc` is no longer needed locally once the node exists, so hand
  // it over to the node output.
  const NodeOutput* result =
      graph_builder.CreateNodeOutput(node, std::move(result_desc));
  const bool inserted =
      id_to_node_output_map.try_emplace(result_id, result).second;
  CHECK(inserted);
}
// Lowers a WebNN hardSwish operation and records the result under the
// operation's output operand id.
//
// Two lowerings exist:
//  - If the adapter supports DML_FEATURE_LEVEL_6_2, a single
//    DML_OPERATOR_ACTIVATION_HARD_SWISH node (Alpha = 1/6, Beta = 0.5).
//  - Otherwise an emulation made of two nodes: an element-wise clip to [0, 1]
//    with scale 1/6 and bias 0.5, followed by an element-wise multiply of the
//    original input with the clip result.
//
// CHECK-fails if the output operand id is already present in
// `id_to_node_output_map`.
void CreateOperatorNodeForHardSwish(Adapter* adapter,
                                    const std::vector<OperandPtr>& operands,
                                    const mojom::HardSwishPtr& hard_swish,
                                    GraphBuilderDml& graph_builder,
                                    IdToNodeOutputMap& id_to_node_output_map) {
  const float kScale = 1.0 / 6.0;
  const float kBias = 0.5;

  const NodeOutput* source = GetNodeOutputForOperand(
      id_to_node_output_map, hard_swish->input_operand_id);
  const TensorDesc& source_desc = source->GetTensorDesc();
  const OperandId result_id = hard_swish->output_operand_id;
  TensorDesc result_desc = CreateOutputTensorDesc(operands, result_id);
  const std::string& op_label = hard_swish->label;

  const NodeOutput* result = nullptr;
  if (adapter->IsDMLFeatureLevelSupported(DML_FEATURE_LEVEL_6_2)) {
    // Native operator available: one node does the whole job.
    std::array<const NodeOutput*, 1> native_inputs = {source};
    DML_ACTIVATION_HARD_SWISH_OPERATOR_DESC native_desc{};
    native_desc.InputTensor = &source_desc.GetDMLTensorDesc();
    native_desc.OutputTensor = &result_desc.GetDMLTensorDesc();
    native_desc.Alpha = kScale;
    native_desc.Beta = kBias;
    const GraphNode* native_node = graph_builder.CreateOperatorNode(
        DML_OPERATOR_ACTIVATION_HARD_SWISH, &native_desc, native_inputs,
        op_label);
    result = graph_builder.CreateNodeOutput(native_node, result_desc);
  } else {
    // Step 1: clip the scaled-and-biased input to [0, 1].
    DML_SCALE_BIAS clip_scale_bias = {.Scale = kScale, .Bias = kBias};
    DML_ELEMENT_WISE_CLIP_OPERATOR_DESC clip_desc{};
    clip_desc.InputTensor = &source_desc.GetDMLTensorDesc();
    clip_desc.OutputTensor = &result_desc.GetDMLTensorDesc();
    clip_desc.ScaleBias = &clip_scale_bias;
    clip_desc.Min = 0;
    clip_desc.Max = 1;
    std::array<const NodeOutput*, 1> clip_inputs = {source};
    const GraphNode* clip_node = graph_builder.CreateOperatorNode(
        DML_OPERATOR_ELEMENT_WISE_CLIP, &clip_desc, clip_inputs, op_label);
    const NodeOutput* clipped =
        graph_builder.CreateNodeOutput(clip_node, result_desc, 0);
    const TensorDesc& clipped_desc = clipped->GetTensorDesc();

    // Step 2: multiply the original input by the clipped value.
    std::array<const NodeOutput*, 2> product_inputs = {source, clipped};
    DML_ELEMENT_WISE_MULTIPLY_OPERATOR_DESC product_desc{};
    product_desc.ATensor = &source_desc.GetDMLTensorDesc();
    product_desc.BTensor = &clipped_desc.GetDMLTensorDesc();
    product_desc.OutputTensor = &result_desc.GetDMLTensorDesc();
    const GraphNode* product_node = graph_builder.CreateOperatorNode(
        DML_OPERATOR_ELEMENT_WISE_MULTIPLY, &product_desc, product_inputs,
        op_label);
    result = graph_builder.CreateNodeOutput(product_node, result_desc);
  }

  const bool inserted =
      id_to_node_output_map.try_emplace(result_id, result).second;
  CHECK(inserted);
}
// Lowers a WebNN instanceNormalization or layerNormalization operation to a
// single DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1 node and registers its
// output in `id_to_node_output_map`.
//
// Parameters of note:
//  - `mean_variance_axes`: forwarded to DirectML as the operator's Axes.
//  - `scale_bias_broadcast_axes`: the axes against which the optional
//    scale/bias tensor descs are made broadcast-compatible with the input.
//  - `operation` / `operation_to_fusible_standalone_activation_map`: used to
//    look up a standalone activation that can be fused into this node. When
//    one is found, the node output is registered under the activation's
//    output id rather than the normalization's.
//  - `op`: only used to build the error message.
//
// When exactly one of scale/bias is supplied, a constant operand is created
// for the missing one (value 1 for scale, 0 for bias) so that both are
// passed to DirectML together.
//
// Returns base::ok() on success, or kUnknownError if the number of axes does
// not fit in uint32_t.
template <typename NormalizationPtr>
  requires(std::is_same_v<NormalizationPtr, mojom::InstanceNormalizationPtr> ||
           std::is_same_v<NormalizationPtr, mojom::LayerNormalizationPtr>)
base::expected<void, mojom::ErrorPtr>
CreateOperatorNodeForMeanVarianceNormalization(
    Adapter* adapter,
    const ContextProperties& context_properties,
    const NormalizationPtr& normalization,
    const Operation* operation,
    const absl::flat_hash_map<const Operation*,
                              raw_ptr<const Operation, CtnExperimental>>&
        operation_to_fusible_standalone_activation_map,
    mojom::GraphInfoPtr& graph_info,
    base::flat_map<OperandId, std::unique_ptr<WebNNConstantOperand>>&
        constant_operands,
    GraphBuilderDml& graph_builder,
    IdToNodeOutputMap& id_to_node_output_map,
    absl::flat_hash_map<OperandId, uint32_t>& constant_id_to_input_index_map,
    base::span<const uint32_t> mean_variance_axes,
    base::span<const uint32_t> scale_bias_broadcast_axes,
    mojom::Operation::Tag op) {
  const auto& operands = graph_info->operands;
  OperandId input_id = normalization->input_operand_id;
  const Operand& input_operand = GetOperand(operands, input_id);
  const NodeOutput* input =
      GetNodeOutputForOperand(id_to_node_output_map, input_id);
  const auto& input_tensor_desc = input->GetTensorDesc();
  size_t input_rank = input_tensor_desc.GetDimensions().size();

  OperandId output_id = normalization->output_operand_id;
  const Operand& output_operand = GetOperand(operands, output_id);
  OperandDataType output_data_type = output_operand.descriptor.data_type();

  // Input descriptor and output data type must both be within the limits
  // advertised for the concrete normalization kind.
  if constexpr (std::is_same_v<NormalizationPtr,
                               mojom::InstanceNormalizationPtr>) {
    CHECK(context_properties.data_type_limits.instance_normalization_input
              .Supports(input_operand.descriptor));
    CHECK(context_properties.data_type_limits.instance_normalization_input
              .data_types.Has(output_data_type));
  } else {
    CHECK(
        context_properties.data_type_limits.layer_normalization_input.Supports(
            input_operand.descriptor));
    CHECK(context_properties.data_type_limits.layer_normalization_input
              .data_types.Has(output_data_type));
  }

  // Intentionally non-const: the descriptor is moved into the node output at
  // the end of the function, and std::move on a const object would silently
  // degrade to a copy.
  TensorDesc output_tensor_desc(GetTensorDataType(output_data_type),
                                output_operand.descriptor.shape());

  const NodeOutput* scale = GetOptionalNodeOutputForOperand(
      id_to_node_output_map, normalization->scale_operand_id);
  const NodeOutput* bias = GetOptionalNodeOutputForOperand(
      id_to_node_output_map, normalization->bias_operand_id);

  // Exactly one of scale/bias present: synthesize the missing one.
  // NOTE(review): the last two arguments to BuildConstantOperandForFloatValue
  // are presumably the constant's rank and fill value - confirm.
  if ((scale && !bias) || (!scale && bias)) {
    if (!scale) {
      OperandId scale_operand_id = BuildConstantOperandForFloatValue(
          context_properties, graph_info, constant_operands, output_data_type,
          scale_bias_broadcast_axes.size(),
          1.0);
      CreateConstantNode(adapter, scale_operand_id, constant_operands,
                         graph_builder, id_to_node_output_map,
                         constant_id_to_input_index_map);
      scale = GetNodeOutputForOperand(id_to_node_output_map, scale_operand_id);
    }
    if (!bias) {
      OperandId bias_operand_id = BuildConstantOperandForFloatValue(
          context_properties, graph_info, constant_operands, output_data_type,
          scale_bias_broadcast_axes.size(),
          0);
      CreateConstantNode(adapter, bias_operand_id, constant_operands,
                         graph_builder, id_to_node_output_map,
                         constant_id_to_input_index_map);
      bias = GetNodeOutputForOperand(id_to_node_output_map, bias_operand_id);
    }
  }

  // Copy of the label: it may be replaced by a fused-operator label below.
  std::string label = normalization->label;
  if (!base::MakeCheckedNum(mean_variance_axes.size()).IsValid<uint32_t>()) {
    return base::unexpected(CreateError(
        mojom::Error::Code::kUnknownError,
        OpTagToString(op) + ": The axes rank is too large.", label));
  }

  // Node inputs are the input followed by scale and bias when present.
  std::vector<const NodeOutput*> inputs = {input};
  std::optional<TensorDesc> scale_tensor_desc;
  std::optional<TensorDesc> bias_tensor_desc;
  if (scale) {
    inputs.push_back(scale);
    scale_tensor_desc = scale->GetTensorDesc();
    scale_tensor_desc->MakeBroadcastCompatible(input_rank,
                                               scale_bias_broadcast_axes);
  }
  if (bias) {
    inputs.push_back(bias);
    bias_tensor_desc = bias->GetTensorDesc();
    bias_tensor_desc->MakeBroadcastCompatible(input_rank,
                                              scale_bias_broadcast_axes);
  }

  std::optional<const Operation*> fusible_activation =
      GetFusibleActivationFromOperation(
          operation_to_fusible_standalone_activation_map, operation);
  std::optional<ActivationOperatorDesc> activation_operator_desc;
  std::optional<DML_OPERATOR_DESC> activation_dml_desc;
  if (fusible_activation) {
    activation_operator_desc =
        CreateOperatorDescForFusibleActivation(*fusible_activation.value());
    activation_dml_desc = activation_operator_desc->GetActivationDmlDesc();
    // The fused node produces the activation's output, so register under the
    // activation's output id.
    output_id =
        GetFusibleActivationOutputId(*fusible_activation.value()).value();
    // A default label is only supplied when the operation has no label.
    std::string_view default_label;
    if (label.empty()) {
      if constexpr (std::is_same_v<NormalizationPtr,
                                   mojom::InstanceNormalizationPtr>) {
        default_label = "instance_normalization";
      } else {
        default_label = "layer_normalization";
      }
    }
    label = GetFusedOperatorLabel(label, default_label,
                                  *fusible_activation.value());
  }

  DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_DESC
  normalization_operator_desc{
      .InputTensor = &input_tensor_desc.GetDMLTensorDesc(),
      .ScaleTensor = GetOptionalDmlTensorDescPtr(scale_tensor_desc),
      .BiasTensor = GetOptionalDmlTensorDescPtr(bias_tensor_desc),
      .OutputTensor = &output_tensor_desc.GetDMLTensorDesc(),
      .AxisCount = base::checked_cast<uint32_t>(mean_variance_axes.size()),
      .Axes = mean_variance_axes.data(),
      .NormalizeVariance = true,
      .Epsilon = normalization->epsilon,
      .FusedActivation =
          activation_dml_desc ? &activation_dml_desc.value() : nullptr,
  };
  const GraphNode* normalization_node = graph_builder.CreateOperatorNode(
      DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1, &normalization_operator_desc,
      inputs, label);

  // The operator desc above is not used past this point, so the output
  // descriptor can be moved into the node output.
  const NodeOutput* output = graph_builder.CreateNodeOutput(
      normalization_node, std::move(output_tensor_desc));
  CHECK(id_to_node_output_map.try_emplace(output_id, output).second);
  return base::ok();
}
// Adds a DML_OPERATOR_ACTIVATION_LEAKY_RELU node for `leaky_relu` to the graph
// and records its output under the operation's output operand id.
//
// The operation's alpha is forwarded unchanged to the DirectML operator.
// CHECK-fails if the output operand id is already present in
// `id_to_node_output_map`.
void CreateOperatorNodeForLeakyRelu(const std::vector<OperandPtr>& operands,
                                    const mojom::LeakyReluPtr& leaky_relu,
                                    GraphBuilderDml& graph_builder,
                                    IdToNodeOutputMap& id_to_node_output_map) {
  const NodeOutput* input = GetNodeOutputForOperand(
      id_to_node_output_map, leaky_relu->input_operand_id);
  const auto& input_tensor_desc = input->GetTensorDesc();

  const OperandId output_id = leaky_relu->output_operand_id;
  // Intentionally non-const (matching the sibling activation builders): the
  // descriptor is moved into the node output below, and std::move on a const
  // object would silently fall back to a copy.
  auto output_tensor_desc = CreateOutputTensorDesc(operands, output_id);

  DML_ACTIVATION_LEAKY_RELU_OPERATOR_DESC leaky_relu_desc{
      .InputTensor = &input_tensor_desc.GetDMLTensorDesc(),
      .OutputTensor = &output_tensor_desc.GetDMLTensorDesc(),
      .Alpha = leaky_relu->alpha};

  std::array<const NodeOutput*, 1> inputs = {input};
  const GraphNode* leaky_relu_node = graph_builder.CreateOperatorNode(
      DML_OPERATOR_ACTIVATION_LEAKY_RELU, &leaky_relu_desc, inputs,
      leaky_relu->label);

  // `leaky_relu_desc` is not used past this point, so the output descriptor
  // can be handed over to the node output.
  const NodeOutput* node_output = graph_builder.CreateNodeOutput(
      leaky_relu_node, std::move(output_tensor_desc));
  CHECK(id_to_node_output_map.try_emplace(output_id, node_output).second);
}
// Adds a DML_OPERATOR_ACTIVATION_LINEAR node for `linear` to the graph and
// records its output under the operation's output operand id.
//
// The input's data type must be within `linear_input` data type limits
// (CHECKed). The operation's alpha/beta are forwarded unchanged to the
// DirectML operator. CHECK-fails if the output operand id is already present
// in `id_to_node_output_map`.
void CreateOperatorNodeForLinear(const ContextProperties& context_properties,
                                 const std::vector<OperandPtr>& operands,
                                 const mojom::LinearPtr& linear,
                                 GraphBuilderDml& graph_builder,
                                 IdToNodeOutputMap& id_to_node_output_map) {
  const NodeOutput* source =
      GetNodeOutputForOperand(id_to_node_output_map, linear->input_operand_id);
  const TensorDesc& source_desc = source->GetTensorDesc();

  const OperandDataType source_type =
      DmlDataTypeToOperand(source_desc.GetDataType());
  CHECK(context_properties.data_type_limits.linear_input.data_types.Has(
      source_type));

  OperandId result_id = linear->output_operand_id;
  TensorDesc result_desc = CreateOutputTensorDesc(operands, result_id);

  DML_ACTIVATION_LINEAR_OPERATOR_DESC dml_desc{};
  dml_desc.InputTensor = &source_desc.GetDMLTensorDesc();
  dml_desc.OutputTensor = &result_desc.GetDMLTensorDesc();
  dml_desc.Alpha = linear->alpha;
  dml_desc.Beta = linear->beta;

  std::array<const NodeOutput*, 1> node_inputs = {source};
  const GraphNode* node = graph_builder.CreateOperatorNode(
      DML_OPERATOR_ACTIVATION_LINEAR, &dml_desc, node_inputs, linear->label);

  // The operator desc is done with `result_desc`; pass ownership on.
  const NodeOutput* result =
      graph_builder.CreateNodeOutput(node, std::move(result_desc));
  const bool inserted =
      id_to_node_output_map.try_emplace(result_id, result).second;
  CHECK(inserted);
}
template <typename LstmType>
requires(std::is_same_v<LstmType, mojom::Lstm> ||
std::is_same_v<LstmType, mojom::LstmCell>)
base::expected<void, mojom::ErrorPtr> CreateOperatorNodeForLstm(
Adapter* adapter,
const ContextProperties& context_properties,
const LstmType& lstm,
mojom::GraphInfoPtr& graph_info,
base::flat_map<OperandId, std::unique_ptr<WebNNConstantOperand>>&
constant_operands,
GraphBuilderDml& graph_builder,
IdToNodeOutputMap& id_to_node_output_map,
absl::flat_hash_map<OperandId, uint32_t>& constant_id_to_input_index_map) {
const std::string& label = lstm.label;
const auto& operands = graph_info->operands;
mojom::Operation::Tag op_tag;
std::optional<OperandId> initial_hidden_state_operand_id;
std::optional<OperandId> initial_cell_state_operand_id;
bool return_sequence;
mojom::RecurrentNetworkDirection direction;
if constexpr (std::is_same_v<LstmType, mojom::Lstm>) {
CHECK(context_properties.data_type_limits.lstm_input.SupportsAll(
{GetOperand(operands, lstm.input_operand_id).descriptor,
GetOperand(operands, lstm.weight_operand_id).descriptor,
GetOperand(operands, lstm.recurrent_weight_operand_id).descriptor}));
op_tag = mojom::Operation::Tag::kLstm;
initial_hidden_state_operand_id = lstm.initial_hidden_state_operand_id;
initial_cell_state_operand_id = lstm.initial_cell_state_operand_id;
return_sequence = lstm.return_sequence;
direction = lstm.direction;
} else {
CHECK(context_properties.data_type_limits.lstm_cell_input.SupportsAll(
{GetOperand(operands, lstm.input_operand_id).descriptor,
GetOperand(operands, lstm.weight_operand_id).descriptor,
GetOperand(operands, lstm.recurrent_weight_operand_id).descriptor}));
op_tag = mojom::Operation::Tag::kLstmCell;
initial_hidden_state_operand_id = lstm.hidden_state_operand_id;
initial_cell_state_operand_id = lstm.cell_state_operand_id;
return_sequence = false;
direction = mojom::RecurrentNetworkDirection::kForward;
}
const NodeOutput* input =
GetNodeOutputForOperand(id_to_node_output_map, lstm.input_operand_id);
input = AppendIdentityToConstantOperand(graph_builder, input);
TensorDesc input_tensor_desc = input->GetTensorDesc();
const DML_TENSOR_DATA_TYPE input_dml_data_type =
input_tensor_desc.GetDataType();
const OperandDataType input_data_type =
DmlDataTypeToOperand(input_dml_data_type);
input_tensor_desc.EnsureMinimumRank(4,
TensorDesc::Alignment::kTrailing);
const NodeOutput* weight =
GetNodeOutputForOperand(id_to_node_output_map, lstm.weight_operand_id);
weight = AppendIdentityToConstantOperand(graph_builder, weight);
TensorDesc weight_tensor_desc = weight->GetTensorDesc();
weight_tensor_desc.EnsureMinimumRank(4,
TensorDesc::Alignment::kTrailing);
const NodeOutput* recurrent_weight = GetNodeOutputForOperand(
id_to_node_output_map, lstm.recurrent_weight_operand_id);
recurrent_weight =
AppendIdentityToConstantOperand(graph_builder, recurrent_weight);
TensorDesc recurrent_weight_tensor_desc = recurrent_weight->GetTensorDesc();
recurrent_weight_tensor_desc.EnsureMinimumRank(
4, TensorDesc::Alignment::kTrailing);
const uint32_t direction_count =
direction == mojom::RecurrentNetworkDirection::kBoth ? 2 : 1;
const NodeOutput* weight_iofg = weight;
const NodeOutput* recurrent_weight_iofg = recurrent_weight;
if (lstm.layout == mojom::LstmWeightLayout::kIfgo) {
const uint32_t input_size = input_tensor_desc.GetDimensions().at(3);
std::vector<uint32_t> split_weight_output_dims = {
1, direction_count, lstm.hidden_size, input_size};
TensorDesc split_weight_output_tensor_desc(
input_dml_data_type, std::move(split_weight_output_dims));
std::array<DML_TENSOR_DESC, 4> split_weight_tensor_descs_dml;
split_weight_tensor_descs_dml.fill(
split_weight_output_tensor_desc.GetDMLTensorDesc());
DML_SPLIT_OPERATOR_DESC split_desc{
.InputTensor = &weight_tensor_desc.GetDMLTensorDesc(),
.OutputCount = 4,
.OutputTensors = split_weight_tensor_descs_dml.data(),
.Axis = 2};
std::array<const NodeOutput*, 1> split_weight_inputs = {weight};
const GraphNode* split_weight_node = graph_builder.CreateOperatorNode(
DML_OPERATOR_SPLIT, &split_desc, split_weight_inputs,
label + "_split_weight_ifgo");
const NodeOutput* split_weight_output_i = graph_builder.CreateNodeOutput(
split_weight_node, split_weight_output_tensor_desc, 0);
const NodeOutput* split_weight_output_f = graph_builder.CreateNodeOutput(
split_weight_node, split_weight_output_tensor_desc, 1);
const NodeOutput* split_weight_output_g = graph_builder.CreateNodeOutput(
split_weight_node, split_weight_output_tensor_desc, 2);
const NodeOutput* split_weight_output_o = graph_builder.CreateNodeOutput(
split_weight_node, split_weight_output_tensor_desc, 3);
DML_JOIN_OPERATOR_DESC concat_desc{
.InputCount = 4,
.InputTensors = split_weight_tensor_descs_dml.data(),
.OutputTensor = &weight_tensor_desc.GetDMLTensorDesc(),
.Axis = 2};
std::array<const NodeOutput*, 4> concat_weight_inputs = {
split_weight_output_i, split_weight_output_o, split_weight_output_f,
split_weight_output_g};
const GraphNode* concat_weight_node = graph_builder.CreateOperatorNode(
DML_OPERATOR_JOIN, &concat_desc, concat_weight_inputs,
label + "_concat_weight_iofg");
weight_iofg =
graph_builder.CreateNodeOutput(concat_weight_node, weight_tensor_desc);
std::vector<uint32_t> split_recurrent_weight_output_dims = {
1, direction_count, lstm.hidden_size, lstm.hidden_size};
TensorDesc split_recurrent_weight_output_tensor_desc(
input_dml_data_type, std::move(split_recurrent_weight_output_dims));
std::array<DML_TENSOR_DESC, 4> split_recurrent_weight_tensor_descs_dml;
split_recurrent_weight_tensor_descs_dml.fill(
split_recurrent_weight_output_tensor_desc.GetDMLTensorDesc());
split_desc.InputTensor = &recurrent_weight_tensor_desc.GetDMLTensorDesc();
split_desc.OutputTensors = split_recurrent_weight_tensor_descs_dml.data();
std::array<const NodeOutput*, 1> split_recurrent_weight_inputs = {
recurrent_weight};
const GraphNode* split_recurrent_weight_node =
graph_builder.CreateOperatorNode(
DML_OPERATOR_SPLIT, &split_desc, split_recurrent_weight_inputs,
label + "_split_recurrent_weight_ifgo");
const NodeOutput* split_recurrent_weight_output_i =
graph_builder.CreateNodeOutput(
split_recurrent_weight_node,
split_recurrent_weight_output_tensor_desc, 0);
const NodeOutput* split_recurrent_weight_output_f =
graph_builder.CreateNodeOutput(
split_recurrent_weight_node,
split_recurrent_weight_output_tensor_desc, 1);
const NodeOutput* split_recurrent_weight_output_g =
graph_builder.CreateNodeOutput(
split_recurrent_weight_node,
split_recurrent_weight_output_tensor_desc, 2);
const NodeOutput* split_recurrent_weight_output_o =
graph_builder.CreateNodeOutput(
split_recurrent_weight_node,
split_recurrent_weight_output_tensor_desc, 3);
concat_desc.InputTensors = split_recurrent_weight_tensor_descs_dml.data();
concat_desc.OutputTensor = &recurrent_weight_tensor_desc.GetDMLTensorDesc();
std::array<const NodeOutput*, 4> concat_recurrent_weight_inputs = {
split_recurrent_weight_output_i, split_recurrent_weight_output_o,
split_recurrent_weight_output_f, split_recurrent_weight_output_g};
const GraphNode* concat_recurrent_weight_node =
graph_builder.CreateOperatorNode(
DML_OPERATOR_JOIN, &concat_desc, concat_recurrent_weight_inputs,
label + "_concat_recurrent_weight_iofg");
recurrent_weight_iofg = graph_builder.CreateNodeOutput(
concat_recurrent_weight_node, recurrent_weight_tensor_desc);
}
const std::vector<OperandId>& output_ids = lstm.output_operand_ids;
const size_t output_count = output_ids.size();
CHECK_GE(output_count, 2u);
const OperandId output_hidden_state_id = output_ids[0];
const Operand& output_hidden_state_operand =
GetOperand(operands, output_hidden_state_id);
TensorDesc output_hidden_state_tensor_desc(
input_dml_data_type, output_hidden_state_operand.descriptor.shape());
output_hidden_state_tensor_desc.EnsureMinimumRank(
4, TensorDesc::Alignment::kTrailing);
const OperandId output_cell_state_id = output_ids[1];
TensorDesc output_cell_state_tensor_desc =
CreateOutputTensorDesc(operands, output_cell_state_id);
output_cell_state_tensor_desc.EnsureMinimumRank(
4, TensorDesc::Alignment::kTrailing);
std::optional<OperandId> output_sequence_id;
std::optional<TensorDesc> output_sequence_tensor_desc;
if (return_sequence) {
CHECK_EQ(output_count, 3u);
output_sequence_id = output_ids[2];
output_sequence_tensor_desc =
CreateOutputTensorDesc(operands, output_sequence_id.value());
}
const NodeOutput* bias = GetOptionalNodeOutputForOperand(
id_to_node_output_map, lstm.bias_operand_id);
const NodeOutput* recurrent_bias = GetOptionalNodeOutputForOperand(
id_to_node_output_map, lstm.recurrent_bias_operand_id);
if ((bias && !recurrent_bias) || (!bias && recurrent_bias)) {
OperandId bias_operand_id = BuildConstantOperandForFloatValue(
context_properties, graph_info, constant_operands, input_data_type,
1, 0);
CreateConstantNode(adapter, bias_operand_id, constant_operands,
graph_builder, id_to_node_output_map,
constant_id_to_input_index_map);
if (!bias) {
bias = GetNodeOutputForOperand(id_to_node_output_map, bias_operand_id);
}
if (!recurrent_bias) {
recurrent_bias =
GetNodeOutputForOperand(id_to_node_output_map, bias_operand_id);
}
}
CHECK((bias && recurrent_bias) || (!bias && !recurrent_bias));
std::vector<const NodeOutput*> inputs{input, weight_iofg,
recurrent_weight_iofg};
std::optional<TensorDesc> concatenated_bias_tensor_desc;
if (bias && recurrent_bias) {
auto checked_four_times_hidden_size =
base::MakeCheckedNum(lstm.hidden_size) * 4;
CHECK(checked_four_times_hidden_size.IsValid());
const std::array<uint32_t, 4> bias_dimensions = {
1, 1, direction_count, checked_four_times_hidden_size.ValueOrDie()};
TensorDesc bias_tensor_desc = bias->GetTensorDesc();
bias_tensor_desc.BroadcastTo(bias_dimensions);
TensorDesc recurrent_bias_tensor_desc = recurrent_bias->GetTensorDesc();
recurrent_bias_tensor_desc.BroadcastTo(bias_dimensions);
std::array<DML_TENSOR_DESC, 2> bias_dml_tensor_descs = {
bias_tensor_desc.GetDMLTensorDesc(),
recurrent_bias_tensor_desc.GetDMLTensorDesc()};
auto checked_eight_times_hidden_size = checked_four_times_hidden_size * 2;
if (!checked_eight_times_hidden_size.IsValid()) {
return CreateUnexpectedError(
mojom::Error::Code::kUnknownError,
base::StringPrintf("The hidden size is too large for %s operator.",
OpTagToString(op_tag).c_str()),
label);
}
std::vector<uint32_t> concatenated_dimensions = {
1, 1, direction_count, checked_eight_times_hidden_size.ValueOrDie()};
concatenated_bias_tensor_desc =
TensorDesc(input_dml_data_type, std::move(concatenated_dimensions));
DML_JOIN_OPERATOR_DESC concat_operator_desc{
.InputCount = static_cast<uint32_t>(bias_dml_tensor_descs.size()),
.InputTensors = bias_dml_tensor_descs.data(),
.OutputTensor = &concatenated_bias_tensor_desc->GetDMLTensorDesc(),
.Axis = 3};
std::array<const NodeOutput*, 2> biases = {bias, recurrent_bias};
const GraphNode* concat_node = graph_builder.CreateOperatorNode(
DML_OPERATOR_JOIN, &concat_operator_desc, biases,
label + "_concat_bias_and_recurrent");
const NodeOutput* concatenated_bias = graph_builder.CreateNodeOutput(
concat_node, concatenated_bias_tensor_desc.value(), 0);
const NodeOutput* concatenated_bias_iofg = concatenated_bias;
if (lstm.layout == mojom::LstmWeightLayout::kIfgo) {
std::vector<uint32_t> split_bias_output_dims = {1, 1, direction_count,
lstm.hidden_size};
TensorDesc split_bias_output_tensor_desc(
input_dml_data_type, std::move(split_bias_output_dims));
std::array<DML_TENSOR_DESC, 8> split_bias_tensor_descs_dml;
split_bias_tensor_descs_dml.fill(
split_bias_output_tensor_desc.GetDMLTensorDesc());
DML_SPLIT_OPERATOR_DESC split_desc{
.InputTensor = &concatenated_bias_tensor_desc->GetDMLTensorDesc(),
.OutputCount = 8,
.OutputTensors = split_bias_tensor_descs_dml.data(),
.Axis = 3};
std::array<const NodeOutput*, 1> split_bias_inputs = {concatenated_bias};
const GraphNode* split_bias_node = graph_builder.CreateOperatorNode(
DML_OPERATOR_SPLIT, &split_desc, split_bias_inputs,
label + "_split_bias_ifgo");
const NodeOutput* split_bias_output_i = graph_builder.CreateNodeOutput(
split_bias_node, split_bias_output_tensor_desc, 0);
const NodeOutput* split_bias_output_f = graph_builder.CreateNodeOutput(
split_bias_node, split_bias_output_tensor_desc, 1);
const NodeOutput* split_bias_output_g = graph_builder.CreateNodeOutput(
split_bias_node, split_bias_output_tensor_desc, 2);
const NodeOutput* split_bias_output_o = graph_builder.CreateNodeOutput(
split_bias_node, split_bias_output_tensor_desc, 3);
const NodeOutput* split_recurrent_bias_output_i =
graph_builder.CreateNodeOutput(split_bias_node,
split_bias_output_tensor_desc,
4);
const NodeOutput* split_recurrent_bias_output_f =
graph_builder.CreateNodeOutput(split_bias_node,
split_bias_output_tensor_desc,
5);
const NodeOutput* split_recurrent_bias_output_g =
graph_builder.CreateNodeOutput(split_bias_node,
split_bias_output_tensor_desc,
6);
const NodeOutput* split_recurrent_bias_output_o =
graph_builder.CreateNodeOutput(split_bias_node,
split_bias_output_tensor_desc,
7);
DML_JOIN_OPERATOR_DESC concat_bias_desc{
.InputCount = 8,
.InputTensors = split_bias_tensor_descs_dml.data(),
.OutputTensor = &concatenated_bias_tensor_desc->GetDMLTensorDesc(),
.Axis = 3};
std::array<const NodeOutput*, 8> concat_bias_inputs = {
split_bias_output_i, split_bias_output_o,
split_bias_output_f, split_bias_output_g,
split_recurrent_bias_output_i, split_recurrent_bias_output_o,
split_recurrent_bias_output_f, split_recurrent_bias_output_g};
const GraphNode* concat_bias_iofg_node = graph_builder.CreateOperatorNode(
DML_OPERATOR_JOIN, &concat_bias_desc, concat_bias_inputs,
label + "_concat_bias_iofg");
concatenated_bias_iofg = graph_builder.CreateNodeOutput(
concat_bias_iofg_node, concatenated_bias_tensor_desc.value());
}
inputs.push_back(concatenated_bias_iofg);
} else {
inputs.push_back(nullptr);
}
std::optional<TensorDesc> initial_hidden_state_tensor_desc;
if (initial_hidden_state_operand_id.has_value()) {
if constexpr (std::is_same_v<LstmType, mojom::Lstm>) {
CHECK(context_properties.data_type_limits.lstm_input.Supports(
GetOperand(operands, *initial_hidden_state_operand_id).descriptor));
} else {
CHECK(context_properties.data_type_limits.lstm_cell_input.Supports(
GetOperand(operands, *initial_hidden_state_operand_id).descriptor));
}
const NodeOutput* initial_hidden_state = GetNodeOutputForOperand(
id_to_node_output_map, initial_hidden_state_operand_id.value());
initial_hidden_state =
AppendIdentityToConstantOperand(graph_builder, initial_hidden_state);
inputs.push_back(initial_hidden_state);
initial_hidden_state_tensor_desc = initial_hidden_state->GetTensorDesc();
initial_hidden_state_tensor_desc->EnsureMinimumRank(
4, TensorDesc::Alignment::kTrailing);
} else {
inputs.push_back(nullptr);
}
std::optional<TensorDesc> initial_cell_state_tensor_desc;
if (initial_cell_state_operand_id.has_value()) {
if constexpr (std::is_same_v<LstmType, mojom::Lstm>) {
CHECK(context_properties.data_type_limits.lstm_input.Supports(
GetOperand(operands, *initial_cell_state_operand_id).descriptor));
} else {
CHECK(context_properties.data_type_limits.lstm_cell_input.Supports(
GetOperand(operands, *initial_cell_state_operand_id).descriptor));
}
const NodeOutput* initial_cell_state = GetNodeOutputForOperand(
id_to_node_output_map, initial_cell_state_operand_id.value());
initial_cell_state =
AppendIdentityToConstantOperand(graph_builder, initial_cell_state);
inputs.push_back(initial_cell_state);
initial_cell_state_tensor_desc = initial_cell_state->GetTensorDesc();
initial_cell_state_tensor_desc->EnsureMinimumRank(
4, TensorDesc::Alignment::kTrailing);
} else {
inputs.push_back(nullptr);
}
inputs.push_back(nullptr);
std::optional<TensorDesc> peephole_weight_tensor_desc;
if (lstm.peephole_weight_operand_id.has_value()) {
if constexpr (std::is_same_v<LstmType, mojom::Lstm>) {
CHECK(context_properties.data_type_limits.lstm_bias.Supports(
GetOperand(operands, *lstm.peephole_weight_operand_id).descriptor));
} else {
CHECK(context_properties.data_type_limits.lstm_cell_bias.Supports(
GetOperand(operands, *lstm.peephole_weight_operand_id).descriptor));
}
const NodeOutput* peephole_weight = GetNodeOutputForOperand(
id_to_node_output_map, lstm.peephole_weight_operand_id.value());
peephole_weight =
AppendIdentityToConstantOperand(graph_builder, peephole_weight);
inputs.push_back(peephole_weight);
peephole_weight_tensor_desc = peephole_weight->GetTensorDesc();
peephole_weight_tensor_desc->EnsureMinimumRank(
4, TensorDesc::Alignment::kTrailing);
}
const size_t number_of_activations =
direction == mojom::RecurrentNetworkDirection::kBoth
? lstm.activations.size() * 2
: lstm.activations.size();
base::FixedArray<ActivationOperatorDesc> activation_operator_descs(
number_of_activations);
for (size_t i = 0; i < lstm.activations.size(); ++i) {
activation_operator_descs[i] =
CreateOperatorDescForActivation(lstm.activations[i]);
if (direction == mojom::RecurrentNetworkDirection::kBoth) {
activation_operator_descs[lstm.activations.size() + i] =
activation_operator_descs[i];
}
}
base::FixedArray<DML_OPERATOR_DESC> activation_dml_descs(
activation_operator_descs.size());
std::ranges::transform(
activation_operator_descs, activation_dml_descs.begin(),
[](const auto& activation_operator_desc) {
return activation_operator_desc.GetActivationDmlDesc();
});
DML_LSTM_OPERATOR_DESC lstm_desc{
.InputTensor = &input_tensor_desc.GetDMLTensorDesc(),
.WeightTensor = &weight_tensor_desc.GetDMLTensorDesc(),
.RecurrenceTensor = &recurrent_weight_tensor_desc.GetDMLTensorDesc(),
.BiasTensor = GetOptionalDmlTensorDescPtr(concatenated_bias_tensor_desc),
.HiddenInitTensor =
GetOptionalDmlTensorDescPtr(initial_hidden_state_tensor_desc),
.CellMemInitTensor =
GetOptionalDmlTensorDescPtr(initial_cell_state_tensor_desc),
.SequenceLengthsTensor = nullptr,
.PeepholeTensor =
GetOptionalDmlTensorDescPtr(peephole_weight_tensor_desc),
.OutputSequenceTensor =
GetOptionalDmlTensorDescPtr(output_sequence_tensor_desc),
.OutputSingleTensor = &output_hidden_state_tensor_desc.GetDMLTensorDesc(),
.OutputCellSingleTensor =
&output_cell_state_tensor_desc.GetDMLTensorDesc(),
.ActivationDescCount = static_cast<uint32_t>(activation_dml_descs.size()),
.ActivationDescs = activation_dml_descs.data(),
.Direction = MojoRecurrentNetworkDirectionToDml(direction),
.ClipThreshold = 0,
.UseClipThreshold = FALSE,
.CoupleInputForget = FALSE};
const GraphNode* lstm_node = graph_builder.CreateOperatorNode(
DML_OPERATOR_LSTM, &lstm_desc, inputs, label);
if (return_sequence) {
const NodeOutput* output_sequence = graph_builder.CreateNodeOutput(
lstm_node, output_sequence_tensor_desc.value(), 0);
CHECK(id_to_node_output_map
.try_emplace(output_sequence_id.value(), output_sequence)
.second);
}
const NodeOutput* output_hidden_state = graph_builder.CreateNodeOutput(
lstm_node, output_hidden_state_tensor_desc, 1);
CHECK(id_to_node_output_map
.try_emplace(output_hidden_state_id, output_hidden_state)
.second);
const NodeOutput* output_cell_state = graph_builder.CreateNodeOutput(
lstm_node, output_cell_state_tensor_desc, 2);
CHECK(
id_to_node_output_map.try_emplace(output_cell_state_id, output_cell_state)
.second);
return base::ok();
}
// Adds a DML_OPERATOR_GEMM node implementing the matmul held by `operation`
// and registers its output in `id_to_node_output_map`.
//
// If either matmul input is found in `output_id_to_fusible_transpose_map`, the
// transpose's own input is fed to GEMM instead and the matching TransA/TransB
// flag is set. If `operation` has an entry in
// `operation_to_fusible_standalone_activation_map`, that activation is attached
// as the GEMM fused activation and the node output is registered under the
// activation's output id rather than the matmul's.
//
// Always returns base::ok() in the visible code; invariant violations are
// CHECK failures.
base::expected<void, mojom::ErrorPtr> CreateOperatorNodeForMatmul(
    const ContextProperties& context_properties,
    const std::vector<OperandPtr>& operands,
    const Operation* operation,
    const absl::flat_hash_map<const Operation*,
                              raw_ptr<const Operation, CtnExperimental>>&
        operation_to_fusible_standalone_activation_map,
    const absl::flat_hash_map<OperandId,
                              raw_ptr<const Operation, CtnExperimental>>&
        output_id_to_fusible_transpose_map,
    GraphBuilderDml& graph_builder,
    IdToNodeOutputMap& id_to_node_output_map) {
  const auto& matmul = operation->get_matmul();

  // Accumulates "<transpose label>+" prefixes for every fused transpose.
  std::string label;

  // If `operand_id` is produced by a fusible transpose, rewrites `operand_id`
  // to the transpose's input, appends the transpose's label (or
  // `default_label` when it has none) plus "+" to `label`, and returns true.
  const auto fold_fusible_transpose =
      [&](OperandId& operand_id, std::string_view default_label) -> bool {
    const std::optional<OperandId> transpose_input_id =
        GetFusibleTransposeInputId(output_id_to_fusible_transpose_map,
                                   operand_id);
    if (!transpose_input_id) {
      return false;
    }
    const std::string_view transpose_label =
        output_id_to_fusible_transpose_map.at(operand_id)
            ->get_transpose()
            ->label;
    base::StrAppend(
        &label,
        {transpose_label.empty() ? default_label : transpose_label, "+"});
    operand_id = transpose_input_id.value();
    return true;
  };

  // Input A.
  OperandId a_operand_id = matmul->a_operand_id;
  const bool transpose_a = fold_fusible_transpose(a_operand_id, "transpose_a");
  const NodeOutput* a_node_output =
      GetNodeOutputForOperand(id_to_node_output_map, a_operand_id);
  TensorDesc a_tensor_desc = a_node_output->GetTensorDesc();
  CHECK(kDmlFloatDataTypes.contains(a_tensor_desc.GetDataType()));

  // Input B.
  OperandId b_operand_id = matmul->b_operand_id;
  const bool transpose_b = fold_fusible_transpose(b_operand_id, "transpose_b");
  const NodeOutput* b_node_output =
      GetNodeOutputForOperand(id_to_node_output_map, b_operand_id);
  TensorDesc b_tensor_desc = b_node_output->GetTensorDesc();

  OperandId output_id = matmul->output_operand_id;
  const auto output_tensor_desc = CreateOutputTensorDesc(operands, output_id);
  const std::vector<uint32_t>& output_dims = output_tensor_desc.GetDimensions();
  const size_t output_rank = output_dims.size();

  // For outputs above rank 2, broadcast both inputs against the output shape.
  // NOTE(review): the trailing `2` presumably excludes the last two (matrix)
  // dimensions from broadcasting - confirm against TensorDesc::BroadcastTo.
  if (output_rank > 2) {
    a_tensor_desc.BroadcastTo(output_dims, 2);
    b_tensor_desc.BroadcastTo(output_dims, 2);
  }

  CHECK(context_properties.data_type_limits.matmul_input.data_types.Has(
      DmlDataTypeToOperand(a_tensor_desc.GetDataType())));
  CHECK_EQ(a_tensor_desc.GetDimensions().size(),
           b_tensor_desc.GetDimensions().size());
  CHECK_EQ(a_tensor_desc.GetDimensions().size(), output_rank);

  // With no fused transpose the matmul label is used verbatim; otherwise it is
  // appended after the transpose prefixes, defaulting to "matmul".
  if (label.empty()) {
    label = matmul->label;
  } else {
    base::StrAppend(&label, {matmul->label.empty() ? "matmul" : matmul->label});
  }

  // The GEMM descriptor is given rank-4 tensor descs: higher ranks are
  // flattened to 4 and lower ranks are padded up to 4.
  auto gemm_output_tensor_desc = output_tensor_desc;
  if (output_rank > 4) {
    // When a desc cannot be flattened as-is, route the input through an
    // identity node and flatten that node's output desc instead.
    if (!a_tensor_desc.RightAlignedFlattenTo(4)) {
      a_node_output =
          AppendIdentityNode(graph_builder, a_node_output, &a_tensor_desc);
      a_tensor_desc = a_node_output->GetTensorDesc();
      CHECK(a_tensor_desc.RightAlignedFlattenTo(4));
    }
    if (!b_tensor_desc.RightAlignedFlattenTo(4)) {
      b_node_output =
          AppendIdentityNode(graph_builder, b_node_output, &b_tensor_desc);
      b_tensor_desc = b_node_output->GetTensorDesc();
      CHECK(b_tensor_desc.RightAlignedFlattenTo(4));
    }
    CHECK(gemm_output_tensor_desc.RightAlignedFlattenTo(4));
  } else if (output_rank < 4) {
    a_tensor_desc.EnsureMinimumRank(4, TensorDesc::Alignment::kTrailing);
    b_tensor_desc.EnsureMinimumRank(4, TensorDesc::Alignment::kTrailing);
    gemm_output_tensor_desc.EnsureMinimumRank(
        4, TensorDesc::Alignment::kTrailing);
  }

  // Optional fused activation. `activation_operator_desc` must outlive the
  // CreateOperatorNode() call because `activation_dml_desc` is derived from it.
  std::optional<const Operation*> fusible_activation =
      GetFusibleActivationFromOperation(
          operation_to_fusible_standalone_activation_map, operation);
  std::optional<ActivationOperatorDesc> activation_operator_desc;
  std::optional<DML_OPERATOR_DESC> activation_dml_desc;
  if (fusible_activation) {
    const Operation& activation = *fusible_activation.value();
    activation_operator_desc =
        CreateOperatorDescForFusibleActivation(activation);
    activation_dml_desc = activation_operator_desc->GetActivationDmlDesc();
    output_id = GetFusibleActivationOutputId(activation).value();
    label = GetFusedOperatorLabel(label, "matmul", activation);
  }

  const auto transform_a = transpose_a ? DML_MATRIX_TRANSFORM_TRANSPOSE
                                       : DML_MATRIX_TRANSFORM_NONE;
  const auto transform_b = transpose_b ? DML_MATRIX_TRANSFORM_TRANSPOSE
                                       : DML_MATRIX_TRANSFORM_NONE;
  DML_GEMM_OPERATOR_DESC gemm_desc{
      .ATensor = &a_tensor_desc.GetDMLTensorDesc(),
      .BTensor = &b_tensor_desc.GetDMLTensorDesc(),
      .CTensor = nullptr,
      .OutputTensor = &gemm_output_tensor_desc.GetDMLTensorDesc(),
      .TransA = transform_a,
      .TransB = transform_b,
      .Alpha = 1.0f,
      .Beta = 0.0f,
      .FusedActivation =
          activation_dml_desc ? &activation_dml_desc.value() : nullptr,
  };

  std::array<const NodeOutput*, 2> gemm_inputs{a_node_output, b_node_output};
  const GraphNode* gemm_node = graph_builder.CreateOperatorNode(
      DML_OPERATOR_GEMM, &gemm_desc, gemm_inputs, label);

  // The registered node output carries the original (un-adjusted) output desc.
  const NodeOutput* output =
      graph_builder.CreateNodeOutput(gemm_node, output_tensor_desc, 0);
  CHECK(id_to_node_output_map.try_emplace(output_id, output).second);
  return base::ok();
}
// Returns a node output that presents `input` permuted by `permutation`.
//
// A copy of the input's tensor desc is transposed with `permutation` and then
// handed to AppendIdentityNode() together with `input`. `input` must be
// non-null.
const NodeOutput* CreateTransposeNode(GraphBuilderDml& graph_builder,
                                      const NodeOutput* input,
                                      base::span<const uint32_t> permutation) {
  CHECK(input);
  TensorDesc permuted_tensor_desc = input->GetTensorDesc();
  permuted_tensor_desc.Transpose(permutation);
  return AppendIdentityNode(graph_builder, input, &permuted_tensor_desc);
}
// Adds the DML node(s) implementing `softmax` along `softmax->axis` and
// registers the result under `softmax->output_operand_id`.
//
// When the adapter supports DML_FEATURE_LEVEL_5_1 a single
// DML_OPERATOR_ACTIVATION_SOFTMAX1 node taking the axis directly is emitted.
// Otherwise the input is rearranged so DML_OPERATOR_ACTIVATION_SOFTMAX can be
// applied to a tensor of rank <= 2 whose last dimension is the softmax axis:
//   transpose(axis <-> last) -> reshape to 2-D -> softmax -> reshape back ->
//   transpose(axis <-> last).
// Steps that would be no-ops (axis already last, rank already <= 2) are
// skipped.
//
// Returns kNotSupportedError if the flattened leading dimension overflows
// uint32_t; base::ok() otherwise.
base::expected<void, mojom::ErrorPtr> CreateOperatorNodeForSoftmax(
    Adapter* adapter,
    const std::vector<OperandPtr>& operands,
    const mojom::SoftmaxPtr& softmax,
    GraphBuilderDml& graph_builder,
    IdToNodeOutputMap& id_to_node_output_map) {
  const NodeOutput* input =
      GetNodeOutputForOperand(id_to_node_output_map, softmax->input_operand_id);
  const auto& input_tensor_desc = input->GetTensorDesc();
  OperandId output_id = softmax->output_operand_id;
  const auto output_tensor_desc = CreateOutputTensorDesc(operands, output_id);
  const uint32_t axis = softmax->axis;
  const std::string& label = softmax->label;

  // Fast path: SOFTMAX1 accepts the axis list directly.
  if (adapter->IsDMLFeatureLevelSupported(DML_FEATURE_LEVEL_5_1)) {
    std::array<uint32_t, 1> softmax_axes = {axis};
    DML_ACTIVATION_SOFTMAX1_OPERATOR_DESC softmax1_desc{
        .InputTensor = &input_tensor_desc.GetDMLTensorDesc(),
        .OutputTensor = &output_tensor_desc.GetDMLTensorDesc(),
        .AxisCount = base::checked_cast<uint32_t>(softmax_axes.size()),
        .Axes = softmax_axes.data()};
    std::array<const NodeOutput*, 1> softmax1_inputs = {input};
    const GraphNode* softmax1_node = graph_builder.CreateOperatorNode(
        DML_OPERATOR_ACTIVATION_SOFTMAX1, &softmax1_desc, softmax1_inputs,
        label);
    const NodeOutput* softmax1_output =
        graph_builder.CreateNodeOutput(softmax1_node, output_tensor_desc);
    CHECK(id_to_node_output_map.try_emplace(output_id, softmax1_output).second);
    return base::ok();
  }

  // Fallback path.
  const uint32_t input_rank = input_tensor_desc.GetDimensions().size();
  const bool axis_is_last = axis == (input_rank - 1);

  // Identity permutation with `axis` and the last dimension exchanged. A swap
  // is its own inverse, so the same permutation undoes the transpose later.
  std::vector<uint32_t> swap_axis_and_last(input_rank);
  std::iota(swap_axis_and_last.begin(), swap_axis_and_last.end(), 0);
  if (!axis_is_last) {
    std::swap(swap_axis_and_last[axis], swap_axis_and_last[input_rank - 1]);
  }

  // Step 1: move the softmax axis to the last dimension.
  const NodeOutput* axis_last_output =
      axis_is_last
          ? input
          : CreateTransposeNode(graph_builder, input, swap_axis_and_last);

  // Step 2: collapse all leading dimensions into one when rank > 2.
  const bool needs_flatten =
      axis_last_output->GetTensorDesc().GetDimensions().size() > 2;
  const NodeOutput* flattened_output = axis_last_output;
  if (needs_flatten) {
    const std::vector<uint32_t>& axis_last_dims =
        axis_last_output->GetTensorDesc().GetDimensions();
    auto checked_leading_size = base::MakeCheckedNum<uint32_t>(1);
    for (size_t i = 0; i + 1 < axis_last_dims.size(); ++i) {
      checked_leading_size *= axis_last_dims[i];
    }
    // A checked numeric stays invalid once it overflows, so a single check
    // after the loop is sufficient.
    if (!checked_leading_size.IsValid<uint32_t>()) {
      return CreateUnexpectedError(
          mojom::Error::Code::kNotSupportedError,
          "For softmax impl: failed to reshape the input to 2-D tensor.",
          label);
    }
    std::vector<uint32_t> flattened_dims = {checked_leading_size.ValueOrDie(),
                                            axis_last_dims.back()};
    flattened_output =
        CreateReshapeNode(graph_builder, axis_last_output, flattened_dims);
  }

  // Step 3: softmax on the (at most 2-D) tensor.
  const TensorDesc& flattened_tensor_desc = flattened_output->GetTensorDesc();
  const TensorDesc softmax_2d_tensor_desc(
      flattened_tensor_desc.GetDataType(),
      flattened_tensor_desc.GetDimensions());
  DML_ACTIVATION_SOFTMAX_OPERATOR_DESC softmax_2d_desc{
      .InputTensor = &flattened_tensor_desc.GetDMLTensorDesc(),
      .OutputTensor = &softmax_2d_tensor_desc.GetDMLTensorDesc()};
  std::array<const NodeOutput*, 1> softmax_2d_inputs = {flattened_output};
  const GraphNode* softmax_2d_node = graph_builder.CreateOperatorNode(
      DML_OPERATOR_ACTIVATION_SOFTMAX, &softmax_2d_desc, softmax_2d_inputs,
      label);
  const NodeOutput* softmax_2d_output =
      graph_builder.CreateNodeOutput(softmax_2d_node, softmax_2d_tensor_desc);

  // Step 4: restore the shape the tensor had before flattening.
  const NodeOutput* unflattened_output = softmax_2d_output;
  if (needs_flatten) {
    unflattened_output = CreateReshapeNode(
        graph_builder, softmax_2d_output,
        axis_last_output->GetTensorDesc().GetDimensions());
  }

  // Step 5: move the last dimension back to `axis`.
  const NodeOutput* final_output =
      axis_is_last ? unflattened_output
                   : CreateTransposeNode(graph_builder, unflattened_output,
                                         swap_axis_and_last);

  CHECK(id_to_node_output_map.try_emplace(output_id, final_output).second);
  return base::ok();
}
// Adds a DML_OPERATOR_ACTIVATION_SOFTPLUS node (Steepness fixed at 1.0) for
// `softplus` and registers its output under `softplus->output_operand_id`.
// CHECK-fails if that output id is already present in
// `id_to_node_output_map`.
void CreateOperatorNodeForSoftplus(const std::vector<OperandPtr>& operands,
                                   const mojom::SoftplusPtr& softplus,
                                   GraphBuilderDml& graph_builder,
                                   IdToNodeOutputMap& id_to_node_output_map) {
  const NodeOutput* input = GetNodeOutputForOperand(id_to_node_output_map,
                                                    softplus->input_operand_id);
  const OperandId output_id = softplus->output_operand_id;
  const auto output_tensor_desc = CreateOutputTensorDesc(operands, output_id);

  DML_ACTIVATION_SOFTPLUS_OPERATOR_DESC operator_desc{
      .InputTensor = &input->GetTensorDesc().GetDMLTensorDesc(),
      .OutputTensor = &output_tensor_desc.GetDMLTensorDesc(),
      .Steepness = 1.0};

  std::array<const NodeOutput*, 1> node_inputs = {input};
  const GraphNode* node = graph_builder.CreateOperatorNode(
      DML_OPERATOR_ACTIVATION_SOFTPLUS, &operator_desc, node_inputs,
      softplus->label);

  const NodeOutput* output =
      graph_builder.CreateNodeOutput(node, output_tensor_desc);
  CHECK(id_to_node_output_map.try_emplace(output_id, output).second);
}
// Adds a DML_OPERATOR_TILE node that repeats the input according to
// `tile->repetitions` and registers its output under
// `tile->output_operand_id`.
//
// CHECK-fails if the input data type is not listed in
// `context_properties.data_type_limits.tile_input`, or if the output id is
// already present in `id_to_node_output_map`.
void CreateOperatorNodeForTile(const ContextProperties& context_properties,
                               const std::vector<OperandPtr>& operands,
                               const mojom::TilePtr& tile,
                               GraphBuilderDml& graph_builder,
                               IdToNodeOutputMap& id_to_node_output_map) {
  const NodeOutput* input =
      GetNodeOutputForOperand(id_to_node_output_map, tile->input_operand_id);
  const auto& source_tensor_desc = input->GetTensorDesc();
  CHECK(context_properties.data_type_limits.tile_input.data_types.Has(
      DmlDataTypeToOperand(source_tensor_desc.GetDataType())));

  const OperandId output_id = tile->output_operand_id;
  const auto output_tensor_desc = CreateOutputTensorDesc(operands, output_id);

  // View over the repetition counts; the descriptor only stores a pointer, so
  // `tile` must stay alive until CreateOperatorNode() returns.
  base::span<const uint32_t> repeats = tile->repetitions;
  DML_TILE_OPERATOR_DESC operator_desc{
      .InputTensor = &source_tensor_desc.GetDMLTensorDesc(),
      .OutputTensor = &output_tensor_desc.GetDMLTensorDesc(),
      .RepeatsCount = base::checked_cast<uint32_t>(repeats.size()),
      .Repeats = repeats.data()};

  std::array<const NodeOutput*, 1> node_inputs = {input};
  const GraphNode* node = graph_builder.CreateOperatorNode(
      DML_OPERATOR_TILE, &operator_desc, node_inputs, tile->label);

  const NodeOutput* output =
      graph_builder.CreateNodeOutput(node, output_tensor_desc);
  CHECK(id_to_node_output_map.try_emplace(output_id, output).second);
}
// Adds the node implementing `transpose` (via CreateTransposeNode()) and
// registers its output under `transpose->output_operand_id`.
//
// A scalar WebNN operand (empty shape) is expected to be represented by a
// rank-1 DML tensor and to carry an empty permutation; in that case the
// permutation {0} is used instead.
//
// CHECK-fails if the input data type is not listed in
// `context_properties.data_type_limits.transpose_input`, if the scalar
// expectations above do not hold, or if the output id is already registered.
void CreateOperatorNodeForTranspose(const ContextProperties& context_properties,
                                    const std::vector<OperandPtr>& operands,
                                    const mojom::TransposePtr& transpose,
                                    GraphBuilderDml& graph_builder,
                                    IdToNodeOutputMap& id_to_node_output_map) {
  const NodeOutput* input = GetNodeOutputForOperand(
      id_to_node_output_map, transpose->input_operand_id);
  CHECK(context_properties.data_type_limits.transpose_input.data_types.Has(
      DmlDataTypeToOperand(input->GetTensorDesc().GetDataType())));

  const Operand& input_operand =
      GetOperand(operands, transpose->input_operand_id);
  const bool input_is_scalar = input_operand.descriptor.shape().empty();

  const NodeOutput* output = nullptr;
  if (input_is_scalar) {
    CHECK_EQ(input->GetTensorDesc().GetDimensions().size(), 1u);
    CHECK(transpose->permutation.empty());
    const std::vector<uint32_t> scalar_permutation = {0};
    output = CreateTransposeNode(graph_builder, input, scalar_permutation);
  } else {
    output = CreateTransposeNode(graph_builder, input, transpose->permutation);
  }

  OperandId output_id = transpose->output_operand_id;
  CHECK(id_to_node_output_map.try_emplace(output_id, output).second);
}
// Adds the DML node(s) implementing `triangular` (keep the upper or lower
// triangle of the last two dimensions relative to `triangular->diagonal`, zero
// the rest) and registers the result under `triangular->output_operand_id`.
//
// Strategy, in order:
//   1. Feature level >= 5.1 and rank <= 4: one DML_OPERATOR_DIAGONAL_MATRIX1
//      node that fills the discarded diagonals with zero.
//   2. The kept triangle lies entirely outside the matrix: multiply the input
//      by a zero-filled constant.
//   3. The kept triangle covers the whole matrix: alias the input node output.
//   4. Otherwise: build a bit mask (all-ones where the element is kept, zero
//      elsewhere) from a two-element constant and BIT_AND it with the input.
//      This appends a new constant operand to `graph_info->operands` and
//      `constant_operands`.
//
// Returns kUnknownError if the mask extents overflow uint32_t; base::ok()
// otherwise. Invariant violations are CHECK failures.
base::expected<void, mojom::ErrorPtr> CreateOperatorNodeForTriangular(
    const ContextProperties& context_properties,
    Adapter* adapter,
    const mojom::TriangularPtr& triangular,
    mojom::GraphInfoPtr& graph_info,
    base::flat_map<OperandId, std::unique_ptr<WebNNConstantOperand>>&
        constant_operands,
    GraphBuilderDml& graph_builder,
    IdToNodeOutputMap& id_to_node_output_map,
    absl::flat_hash_map<OperandId, uint32_t>& constant_id_to_input_index_map) {
  const NodeOutput* input = GetNodeOutputForOperand(
      id_to_node_output_map, triangular->input_operand_id);
  const auto& input_tensor_desc = input->GetTensorDesc();
  CHECK(context_properties.data_type_limits.triangular_input.data_types.Has(
      DmlDataTypeToOperand(input_tensor_desc.GetDataType())));
  const auto& operands = graph_info->operands;
  OperandId output_id = triangular->output_operand_id;
  auto output_tensor_desc = CreateOutputTensorDesc(operands, output_id);
  CHECK_EQ(input_tensor_desc.GetDimensions().size(),
           output_tensor_desc.GetDimensions().size());
  const auto& input_dimensions = input_tensor_desc.GetDimensions();
  const auto input_rank = input_dimensions.size();
  CHECK_GE(input_rank, 2U);
  bool upper = triangular->upper;
  int32_t diagonal = triangular->diagonal;
  const std::string& label = triangular->label;

  // Zero-initialized scalar used as the fill value in every path below.
  DML_SCALAR_UNION scalar_union = {};

  // Path 1: DIAGONAL_MATRIX1 writes `Value` (zero) on the diagonals in
  // [DiagonalFillBegin, DiagonalFillEnd), i.e. the ones to discard.
  // NOTE(review): `diagonal + 1` overflows when upper == false and
  // diagonal == INT32_MAX - confirm such a value is rejected by validation
  // before it reaches this point.
  if (adapter->IsDMLFeatureLevelSupported(DML_FEATURE_LEVEL_5_1) &&
      input_rank <= 4) {
    DML_DIAGONAL_MATRIX1_OPERATOR_DESC diagonal_matrix1_desc{
        .InputTensor = &input_tensor_desc.GetDMLTensorDesc(),
        .OutputTensor = &output_tensor_desc.GetDMLTensorDesc(),
        .ValueDataType = output_tensor_desc.GetDataType(),
        .Value = scalar_union,
        .DiagonalFillBegin =
            upper ? std::numeric_limits<int32_t>::min() : diagonal + 1,
        .DiagonalFillEnd =
            upper ? diagonal : std::numeric_limits<int32_t>::max()};
    std::array<const NodeOutput*, 1> inputs = {input};
    const GraphNode* diagonal_matrix1_node = graph_builder.CreateOperatorNode(
        DML_OPERATOR_DIAGONAL_MATRIX1, &diagonal_matrix1_desc, inputs, label);
    const NodeOutput* node_output = graph_builder.CreateNodeOutput(
        diagonal_matrix1_node, std::move(output_tensor_desc));
    CHECK(id_to_node_output_map.try_emplace(output_id, node_output).second);
    return base::ok();
  }

  const Operand& output_operand = GetOperand(operands, output_id);
  OperandDataType data_type = output_operand.descriptor.data_type();
  const uint32_t height = input_dimensions[input_rank - 2];
  const uint32_t width = input_dimensions[input_rank - 1];
  uint32_t longest_dimension_length = std::max(height, width);

  // |diagonal| as an unsigned value. Computed without negating the signed
  // value so that diagonal == INT32_MIN does not overflow.
  const uint32_t diagonal_magnitude = base::SafeUnsignedAbs(diagonal);
  const bool diagonal_out_of_range =
      diagonal_magnitude >= longest_dimension_length;

  // Path 2: the output is all zeros when the kept triangle lies entirely
  // outside the matrix:
  //   1. upper == true  and  diagonal >= longest_dimension_length
  //   2. upper == false and -diagonal >= longest_dimension_length
  if ((diagonal > 0 && diagonal_out_of_range && upper) ||
      (diagonal < 0 && diagonal_out_of_range && !upper)) {
    DML_FILL_VALUE_CONSTANT_OPERATOR_DESC fill_constant_operator_desc{
        .OutputTensor = &output_tensor_desc.GetDMLTensorDesc(),
        .ValueDataType = output_tensor_desc.GetDataType(),
        .Value = scalar_union,
    };
    const GraphNode* fill_constant_node = graph_builder.CreateOperatorNode(
        DML_OPERATOR_FILL_VALUE_CONSTANT, &fill_constant_operator_desc, {},
        label);
    // Pass a copy: `output_tensor_desc` is still needed for the multiply node
    // and its output below, so it must not be moved from here.
    const NodeOutput* constant = graph_builder.CreateNodeOutput(
        fill_constant_node, output_tensor_desc, 0);
    auto constant_tensor_desc = constant->GetTensorDesc();
    std::array<const NodeOutput*, 2> inputs = {input, constant};
    const GraphNode* mul_node =
        CreateBinaryOperator<DML_ELEMENT_WISE_MULTIPLY_OPERATOR_DESC>(
            input_tensor_desc, constant_tensor_desc, output_tensor_desc,
            graph_builder, DML_OPERATOR_ELEMENT_WISE_MULTIPLY, inputs, label);
    const NodeOutput* output =
        graph_builder.CreateNodeOutput(mul_node, output_tensor_desc);
    CHECK(id_to_node_output_map.try_emplace(output_id, output).second);
    return base::ok();
  }

  // Path 3: the output equals the input when the kept triangle covers the
  // whole matrix:
  //   1. upper == false and  diagonal >= longest_dimension_length
  //   2. upper == true  and -diagonal >= longest_dimension_length
  // No operator is added; a new node output on the input's node is registered
  // for `output_id`.
  if ((diagonal > 0 && diagonal_out_of_range && !upper) ||
      (diagonal < 0 && diagonal_out_of_range && upper)) {
    const Node& input_node = input->GetNode();
    const NodeOutput* output = graph_builder.CreateNodeOutput(
        &input_node, std::move(output_tensor_desc), input->GetOutputIndex());
    CHECK(id_to_node_output_map.try_emplace(output_id, output).second);
    return base::ok();
  }

  // Path 4: mask-based emulation.
  //
  // The two-element constant holds {lower_mask, upper_mask}. For upper == true
  // that is {0, all-ones}; for upper == false the values are swapped.
  uint64_t lower_mask = 0;
  uint64_t upper_mask = std::numeric_limits<uint64_t>::max();
  if (!upper) {
    std::swap(lower_mask, upper_mask);
  }

  // Pick an unsigned integer mask type with the same byte width as the data
  // type so BIT_AND can operate on the raw bits. The 16-bit mask is declared
  // to WebNN as kFloat16 and to DML as UINT16.
  OperandDataType webnn_mask_data_type;
  DML_TENSOR_DATA_TYPE dml_mask_data_type;
  base::HeapArray<uint8_t> buffer;
  switch (data_type) {
    case OperandDataType::kInt8:
    case OperandDataType::kUint8: {
      webnn_mask_data_type = OperandDataType::kUint8;
      dml_mask_data_type = DML_TENSOR_DATA_TYPE_UINT8;
      std::array<uint8_t, 2> values = {static_cast<uint8_t>(lower_mask),
                                       static_cast<uint8_t>(upper_mask)};
      buffer = base::HeapArray<uint8_t>::CopiedFrom(base::as_byte_span(values));
      break;
    }
    case OperandDataType::kFloat16: {
      webnn_mask_data_type = OperandDataType::kFloat16;
      dml_mask_data_type = DML_TENSOR_DATA_TYPE_UINT16;
      std::array<uint16_t, 2> values = {static_cast<uint16_t>(lower_mask),
                                        static_cast<uint16_t>(upper_mask)};
      buffer = base::HeapArray<uint8_t>::CopiedFrom(base::as_byte_span(values));
      break;
    }
    case OperandDataType::kFloat32:
    case OperandDataType::kInt32:
    case OperandDataType::kUint32: {
      webnn_mask_data_type = OperandDataType::kUint32;
      dml_mask_data_type = DML_TENSOR_DATA_TYPE_UINT32;
      std::array<uint32_t, 2> values = {static_cast<uint32_t>(lower_mask),
                                        static_cast<uint32_t>(upper_mask)};
      buffer = base::HeapArray<uint8_t>::CopiedFrom(base::as_byte_span(values));
      break;
    }
    case OperandDataType::kInt64:
    case OperandDataType::kUint64: {
      webnn_mask_data_type = OperandDataType::kUint64;
      dml_mask_data_type = DML_TENSOR_DATA_TYPE_UINT64;
      std::array<uint64_t, 2> values = {static_cast<uint64_t>(lower_mask),
                                        static_cast<uint64_t>(upper_mask)};
      buffer = base::HeapArray<uint8_t>::CopiedFrom(base::as_byte_span(values));
      break;
    }
    default:
      NOTREACHED() << "Unsupported data type.";
  }

  // Register the {1, 2, 1} mask constant as a new graph operand and create its
  // constant node.
  auto descriptor = *OperandDescriptor::Create(
      context_properties, webnn_mask_data_type,
      std::array<uint32_t, 3>{1, 2, 1}, "triangular");
  auto constant_operand = Operand::New(Operand::Kind::kConstant, descriptor,
                                       std::nullopt);
  OperandId constant_operand_id(graph_info->operands.size());
  graph_info->operands.push_back(std::move(constant_operand));
  CHECK(constant_operands
            .try_emplace(constant_operand_id,
                         std::make_unique<WebNNConstantOperand>(
                             descriptor, std::move(buffer)))
            .second);
  CreateConstantNode(adapter, constant_operand_id, constant_operands,
                     graph_builder, id_to_node_output_map,
                     constant_id_to_input_index_map);
  const NodeOutput* constant =
      GetNodeOutputForOperand(id_to_node_output_map, constant_operand_id);
  auto constant_tensor_desc = constant->GetTensorDesc();

  // Expand the constant to {mask_height, 2, mask_width}: each row is
  // `mask_width` copies of lower_mask followed by `mask_width` copies of
  // upper_mask.
  const auto mask_height = height;
  const auto checked_mask_width =
      (base::MakeCheckedNum<uint32_t>(longest_dimension_length) +
       std::min(diagonal_magnitude, longest_dimension_length)) *
      2;
  if (!checked_mask_width.IsValid<uint32_t>()) {
    return base::unexpected(CreateError(
        mojom::Error::Code::kUnknownError,
        "For triangular impl: the mask width is too large.", label));
  }
  const uint32_t mask_width = checked_mask_width.ValueOrDie();
  std::vector<uint32_t> expand_constant_dims = {mask_height, 2, mask_width};
  if (constant_tensor_desc.GetDimensions() != expand_constant_dims) {
    constant_tensor_desc.BroadcastTo(expand_constant_dims);
  }
  const auto expand_constant_tensor_desc = TensorDesc(
      constant_tensor_desc.GetDataType(), std::move(expand_constant_dims));
  const GraphNode* expand_constant_node =
      CreateUnaryOperator<DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC,
                          DML_OPERATOR_ELEMENT_WISE_IDENTITY>(
          constant_tensor_desc, expand_constant_tensor_desc, constant,
          graph_builder, label);
  // `expand_constant_tensor_desc` is const, so this std::move copies and the
  // desc stays usable below.
  const auto* expand_constant_output = graph_builder.CreateNodeOutput(
      expand_constant_node, std::move(expand_constant_tensor_desc));
  auto expand_constant_output_tensor_desc =
      expand_constant_output->GetTensorDesc();

  // Reinterpret the expanded buffer as a 2-D tensor
  // {mask_height, 2 * mask_width} whose row stride is one element shorter than
  // the row length. Each successive row therefore starts one element earlier
  // in the underlying data, which shifts the lower/upper boundary one column
  // to the right per row - a diagonal boundary.
  const auto checked_slice_input_width =
      base::MakeCheckedNum<uint32_t>(mask_width) * 2;
  if (!checked_slice_input_width.IsValid<uint32_t>()) {
    return base::unexpected(CreateError(
        mojom::Error::Code::kUnknownError,
        "For triangular impl: the input width for slice is too large.", label));
  }
  const uint32_t slice_input_width = checked_slice_input_width.ValueOrDie();
  std::vector<uint32_t> slice_input_dims = {mask_height, slice_input_width};
  const auto checked_slice_input_stride = checked_slice_input_width - 1;
  if (!checked_slice_input_stride.IsValid<uint32_t>()) {
    return base::unexpected(CreateError(
        mojom::Error::Code::kUnknownError,
        "For triangular impl: the input stride for slice is invalid.", label));
  }
  const uint32_t slice_input_stride = checked_slice_input_stride.ValueOrDie();
  std::vector<uint32_t> slice_input_strides = {slice_input_stride, 1};
  auto slice_input_tensor_desc =
      TensorDesc(expand_constant_output_tensor_desc.GetDataType(),
                 expand_constant_output_tensor_desc.GetFlags(),
                 std::move(slice_input_dims), std::move(slice_input_strides));
  slice_input_tensor_desc.SetTotalTensorSizeInBytes(
      expand_constant_output_tensor_desc.GetTotalTensorSizeInBytes());

  // Slice a {height, width} window out of the skewed view. The column offset
  // places the lower/upper boundary on the requested diagonal. `diagonal` is
  // converted to unsigned in the subtraction, so a negative diagonal moves the
  // window to the right via modular arithmetic.
  std::vector<uint32_t> slice_output_dims = {height, width};
  auto slice_output_tensor_desc = TensorDesc(
      expand_constant_tensor_desc.GetDataType(), std::move(slice_output_dims));
  std::array<uint32_t, 2> sizes = {height, width};
  std::array<uint32_t, 2> offset =
      upper ? std::array<uint32_t, 2>{0, mask_width - diagonal}
            : std::array<uint32_t, 2>{0, mask_width - diagonal - 1};
  std::array<uint32_t, 2> strides = {1, 1};
  DML_SLICE_OPERATOR_DESC slice_operator_desc{
      .InputTensor = &slice_input_tensor_desc.GetDMLTensorDesc(),
      .OutputTensor = &slice_output_tensor_desc.GetDMLTensorDesc(),
      .DimensionCount = 2,
      .Offsets = offset.data(),
      .Sizes = sizes.data(),
      .Strides = strides.data(),
  };
  std::array<const NodeOutput*, 1> input_for_slice = {expand_constant_output};
  const GraphNode* slice_node = graph_builder.CreateOperatorNode(
      DML_OPERATOR_SLICE, &slice_operator_desc, input_for_slice, label);
  const auto* slice_output = graph_builder.CreateNodeOutput(
      slice_node, std::move(slice_output_tensor_desc));
  // Re-assigned immediately after the move above, so the moved-from value is
  // never read.
  slice_output_tensor_desc = slice_output->GetTensorDesc();

  // Broadcast the 2-D mask across any leading (batch) dimensions of the input.
  if (slice_output_tensor_desc.GetDimensions() != input_dimensions) {
    slice_output_tensor_desc.BroadcastTo(input_dimensions);
  }

  // BIT_AND the input with the mask. All three tensor descs are declared with
  // the unsigned mask data type so the operator works on the raw bit patterns
  // of the input elements.
  TensorDesc bit_and_operator_input_tensor_desc =
      TensorDesc(dml_mask_data_type, input_tensor_desc.GetFlags(),
                 input_tensor_desc.GetDimensions());
  TensorDesc bit_and_operator_mask_tensor_desc =
      TensorDesc(dml_mask_data_type, slice_output_tensor_desc.GetFlags(),
                 slice_output_tensor_desc.GetDimensions(),
                 slice_output_tensor_desc.GetStrides());
  TensorDesc bit_and_operator_output_tensor_desc =
      TensorDesc(dml_mask_data_type, output_tensor_desc.GetFlags(),
                 output_tensor_desc.GetDimensions());
  DML_ELEMENT_WISE_BIT_AND_OPERATOR_DESC bit_and_operator_desc{
      .ATensor = &bit_and_operator_input_tensor_desc.GetDMLTensorDesc(),
      .BTensor = &bit_and_operator_mask_tensor_desc.GetDMLTensorDesc(),
      .OutputTensor = &bit_and_operator_output_tensor_desc.GetDMLTensorDesc()};
  std::array<const NodeOutput*, 2> inputs{input, slice_output};
  const GraphNode* bit_and_operator_node = graph_builder.CreateOperatorNode(
      DML_OPERATOR_ELEMENT_WISE_BIT_AND, &bit_and_operator_desc, inputs, label);
  // The registered output keeps the original (non-mask) data type.
  const NodeOutput* bit_and_operator_output =
      graph_builder.CreateNodeOutput(bit_and_operator_node, output_tensor_desc);
  CHECK(id_to_node_output_map.try_emplace(output_id, bit_and_operator_output)
            .second);
  return base::ok();
}
// Adds a DML_OPERATOR_ELEMENT_WISE_IF node that implements the `where`
// operation.
//
// The node outputs previously registered for the condition, true-value and
// false-value operands are looked up in `id_to_node_output_map`. Each of their
// tensor descriptors is copied and, when its dimensions differ from those of
// the output operand, broadcast to the output dimensions. The output of the
// new node is registered in `id_to_node_output_map` under the output operand
// id; the id must not already be present (CHECKed).
void CreateOperatorNodeForWhere(const std::vector<OperandPtr>& operands,
                                const mojom::WherePtr& where,
                                GraphBuilderDml& graph_builder,
                                IdToNodeOutputMap& id_to_node_output_map) {
  const NodeOutput* condition = GetNodeOutputForOperand(
      id_to_node_output_map, where->condition_operand_id);
  // Copies: the descriptors may be broadcast below without touching the
  // descriptors owned by the existing node outputs.
  auto condition_tensor_desc = condition->GetTensorDesc();

  const NodeOutput* true_value = GetNodeOutputForOperand(
      id_to_node_output_map, where->true_value_operand_id);
  auto true_value_tensor_desc = true_value->GetTensorDesc();

  const NodeOutput* false_value = GetNodeOutputForOperand(
      id_to_node_output_map, where->false_value_operand_id);
  auto false_value_tensor_desc = false_value->GetTensorDesc();

  OperandId output_id = where->output_operand_id;
  // Intentionally not `const`: the descriptor is moved into the node output at
  // the end of the function, and `std::move` on a const object would silently
  // fall back to a copy.
  auto output_tensor_desc = CreateOutputTensorDesc(operands, output_id);
  // Only read for comparison/broadcasting, and only before
  // `output_tensor_desc` is moved from, so no copy of the dimensions is needed.
  const auto& output_tensor_dims = output_tensor_desc.GetDimensions();

  if (condition_tensor_desc.GetDimensions() != output_tensor_dims) {
    condition_tensor_desc.BroadcastTo(output_tensor_dims);
  }
  if (true_value_tensor_desc.GetDimensions() != output_tensor_dims) {
    true_value_tensor_desc.BroadcastTo(output_tensor_dims);
  }
  if (false_value_tensor_desc.GetDimensions() != output_tensor_dims) {
    false_value_tensor_desc.BroadcastTo(output_tensor_dims);
  }

  // The operator description only borrows the tensor descriptors; they must
  // stay alive (and unmoved) until CreateOperatorNode() has returned.
  DML_ELEMENT_WISE_IF_OPERATOR_DESC where_operator_desc{
      .ConditionTensor = &condition_tensor_desc.GetDMLTensorDesc(),
      .ATensor = &true_value_tensor_desc.GetDMLTensorDesc(),
      .BTensor = &false_value_tensor_desc.GetDMLTensorDesc(),
      .OutputTensor = &output_tensor_desc.GetDMLTensorDesc()};

  std::array<const NodeOutput*, 3> inputs{condition, true_value, false_value};
  const GraphNode* where_node = graph_builder.CreateOperatorNode(
      DML_OPERATOR_ELEMENT_WISE_IF, &where_operator_desc, inputs, where->label);

  const NodeOutput* output = graph_builder.CreateNodeOutput(
      where_node, std::move(output_tensor_desc), 0);
  CHECK(id_to_node_output_map.try_emplace(output_id, output).second);
}
void HandleGraphCreationFailure(
const std::string& error_message,
WebNNContextImpl::CreateGraphImplCallback callback,
ContextImplDml* context,
HRESULT hr) {
std::move(callback).Run(base::unexpected(
CreateError(mojom::Error::Code::kUnknownError, error_message)));
context->HandleContextLostOrCrash(error_message, hr);
}
// Returns true when `named_tensors` and `prev_named_tensors` have the same
// number of entries and, walking both maps in iteration order, every pair of
// entries has an equal name and refers to the same tensor object (the weak
// pointer is compared through get()).
bool IsDispatchBindingValid(
    const base::flat_map<std::string, scoped_refptr<WebNNTensorImpl>>&
        named_tensors,
    const base::flat_map<std::string, base::WeakPtr<const WebNNTensorImpl>>&
        prev_named_tensors) {
  if (named_tensors.size() != prev_named_tensors.size()) {
    return false;
  }

  auto previous_it = prev_named_tensors.begin();
  for (const auto& [name, tensor] : named_tensors) {
    const auto& [prev_name, prev_tensor] = *previous_it;
    const bool same_binding =
        name == prev_name && tensor == prev_tensor.get();
    if (!same_binding) {
      return false;
    }
    ++previous_it;
  }
  return true;
}
}
// Thread-safe ref-counted holder of a graph's persistent D3D12 buffer together
// with the DML binding structures that describe it.
//
// The type is not copyable. The stored binding description keeps a pointer to
// the `persistent_buffer_binding_` member of the same object (see the
// constructor definition), so its address has to stay stable.
class GraphImplDml::PersistentResource final
: public base::RefCountedThreadSafe<PersistentResource> {
public:
// Factory; instances are only created through this function because the
// constructor is private.
static scoped_refptr<PersistentResource> Create(
uint64_t persistent_buffer_byte_length,
Microsoft::WRL::ComPtr<ID3D12Resource> persistent_buffer);
PersistentResource(const PersistentResource&) = delete;
PersistentResource& operator=(const PersistentResource&) = delete;
// Returns a copy of the binding description. Its `Desc` pointer refers to a
// member of this object, so the returned value must not be used after this
// PersistentResource has been destroyed.
DML_BINDING_DESC persistent_buffer_binding_desc() const {
return persistent_buffer_binding_desc_;
}
private:
// Lets the ref-counting base class reach the private destructor.
friend class base::RefCountedThreadSafe<PersistentResource>;
PersistentResource(uint64_t persistent_buffer_byte_length,
Microsoft::WRL::ComPtr<ID3D12Resource> persistent_buffer);
~PersistentResource();
// Keeps the buffer resource alive for as long as this object exists.
Microsoft::WRL::ComPtr<ID3D12Resource> persistent_buffer_;
// Buffer binding covering `persistent_buffer_`; filled in by the constructor.
DML_BUFFER_BINDING persistent_buffer_binding_;
// Binding description pointing at `persistent_buffer_binding_`.
DML_BINDING_DESC persistent_buffer_binding_desc_;
};
// Bundle of the resources allocated for executing a compiled graph: a
// descriptor heap and, when a non-zero temporary size is requested, a
// temporary buffer with its DML binding structures.
//
// Neither copyable nor movable: when engaged, `temporary_buffer_binding_desc`
// holds a pointer to the value stored inside `temporary_buffer_binding` (see
// the constructor definition), which would dangle if the object were relocated.
struct GraphImplDml::GraphResources {
GraphResources(Microsoft::WRL::ComPtr<ID3D12DescriptorHeap> descriptor_heap,
uint64_t temporary_buffer_byte_length,
Microsoft::WRL::ComPtr<ID3D12Resource> temporary_resource);
~GraphResources();
GraphResources(const GraphResources&) = delete;
GraphResources& operator=(const GraphResources&) = delete;
GraphResources(GraphResources&&) = delete;
GraphResources& operator=(GraphResources&&) = delete;
Microsoft::WRL::ComPtr<ID3D12DescriptorHeap> descriptor_heap;
// May be null; the constructor only requires it when the temporary buffer
// byte length is greater than zero.
Microsoft::WRL::ComPtr<ID3D12Resource> temporary_buffer;
// Both optionals are left empty when the temporary buffer byte length is 0.
std::optional<DML_BUFFER_BINDING> temporary_buffer_binding;
std::optional<DML_BINDING_DESC> temporary_buffer_binding_desc;
};
// Out-of-line definitions of GraphBufferBindingInfo's special member
// functions. All of them are compiler-generated: default construction,
// destruction, copy construction/assignment and move construction/assignment.
GraphImplDml::GraphBufferBindingInfo::GraphBufferBindingInfo() = default;
GraphImplDml::GraphBufferBindingInfo::~GraphBufferBindingInfo() = default;
// Copy operations.
GraphImplDml::GraphBufferBindingInfo::GraphBufferBindingInfo(
const GraphBufferBindingInfo&) = default;
GraphImplDml::GraphBufferBindingInfo&
GraphImplDml::GraphBufferBindingInfo::operator=(const GraphBufferBindingInfo&) =
default;
// Move operations.
GraphImplDml::GraphBufferBindingInfo::GraphBufferBindingInfo(
GraphBufferBindingInfo&&) = default;
GraphImplDml::GraphBufferBindingInfo&
GraphImplDml::GraphBufferBindingInfo::operator=(GraphBufferBindingInfo&&) =
default;
// Builds a ref-counted PersistentResource around `persistent_buffer`.
// Preconditions (CHECKed): the byte length is greater than zero and the buffer
// pointer is non-null.
scoped_refptr<GraphImplDml::PersistentResource>
GraphImplDml::PersistentResource::Create(
    uint64_t persistent_buffer_byte_length,
    ComPtr<ID3D12Resource> persistent_buffer) {
  CHECK_GT(persistent_buffer_byte_length, 0u);
  CHECK_NE(persistent_buffer.Get(), nullptr);

  // The constructor is private, hence the explicit `new` wrapped right away.
  scoped_refptr<PersistentResource> resource =
      base::WrapRefCounted(new PersistentResource(
          persistent_buffer_byte_length, std::move(persistent_buffer)));
  return resource;
}
GraphImplDml::PersistentResource::PersistentResource(
uint64_t persistent_buffer_byte_length,
ComPtr<ID3D12Resource> persistent_buffer)
: persistent_buffer_(std::move(persistent_buffer)) {
persistent_buffer_binding_ =
DML_BUFFER_BINDING{.Buffer = persistent_buffer_.Get(),
.Offset = 0,
.SizeInBytes = persistent_buffer_byte_length};
persistent_buffer_binding_desc_ = DML_BINDING_DESC{
.Type = DML_BINDING_TYPE_BUFFER, .Desc = &persistent_buffer_binding_};
}
// Compiler-generated destruction; the members release what they hold.
GraphImplDml::PersistentResource::~PersistentResource() = default;
// Stores the descriptor heap and the temporary buffer. When
// `temporary_buffer_byte_length` is non-zero the buffer must be non-null
// (CHECKed) and the temporary binding plus its binding description are
// populated; otherwise both optionals stay empty.
GraphImplDml::GraphResources::GraphResources(
    ComPtr<ID3D12DescriptorHeap> descriptor_heap,
    uint64_t temporary_buffer_byte_length,
    ComPtr<ID3D12Resource> temporary_resource)
    : descriptor_heap(std::move(descriptor_heap)),
      temporary_buffer(std::move(temporary_resource)) {
  // Nothing to bind when no temporary storage is needed.
  if (temporary_buffer_byte_length == 0) {
    return;
  }

  CHECK_NE(temporary_buffer.Get(), nullptr);
  temporary_buffer_binding.emplace(
      DML_BUFFER_BINDING{.Buffer = temporary_buffer.Get(),
                         .Offset = 0,
                         .SizeInBytes = temporary_buffer_byte_length});
  // The description points into the optional above, which is why this struct
  // must never be copied or moved.
  temporary_buffer_binding_desc.emplace(
      DML_BINDING_DESC{.Type = DML_BINDING_TYPE_BUFFER,
                       .Desc = &temporary_buffer_binding.value()});
}
// Compiler-generated destruction; the members release what they hold.
GraphImplDml::GraphResources::~GraphResources() = default;
// Allocates the resources needed to execute `compiled_operator` on `adapter`:
// a descriptor heap sized from the operator's binding properties and, if the
// operator reports a non-zero temporary resource size, a default buffer of
// that size.
//
// Returns the populated GraphResources on success.
// NOTE(review): RETURN_UNEXPECTED_IF_FAILED is defined outside this file
// section; presumably it returns the failing HRESULT as the unexpected value.
base::expected<std::unique_ptr<GraphImplDml::GraphResources>, HRESULT>
GraphImplDml::AllocateGraphResources(Adapter* adapter,
IDMLCompiledOperator* compiled_operator) {
TRACE_EVENT0("gpu", "GraphImplDml::AllocateGraphResources");
DML_BINDING_PROPERTIES execution_binding_properties =
compiled_operator->GetBindingProperties();
ComPtr<ID3D12DescriptorHeap> descriptor_heap;
RETURN_UNEXPECTED_IF_FAILED(CreateDescriptorHeap(
adapter->d3d12_device(),
execution_binding_properties.RequiredDescriptorCount,
L"WebNN_Descriptor_Heap_For_Execution", descriptor_heap));
// The temporary buffer is optional: it stays null when the operator does not
// ask for temporary storage.
ComPtr<ID3D12Resource> temporary_buffer;
uint64_t temporary_buffer_byte_length =
execution_binding_properties.TemporaryResourceSize;
if (temporary_buffer_byte_length > 0) {
RETURN_UNEXPECTED_IF_FAILED(CreateDefaultBuffer(
adapter->d3d12_device(), temporary_buffer_byte_length,
L"WebNN_Temporary_Buffer_For_Execution", temporary_buffer));
}
return base::WrapUnique(new GraphResources(std::move(descriptor_heap),
temporary_buffer_byte_length,
std::move(temporary_buffer)));
}
// Constructs the graph object from fully prepared pieces. `receiver`,
// `context`, `compute_resource_info` and `devices` are forwarded to the
// WebNNGraphImpl base class; everything else is moved into the corresponding
// member. No other work is performed here.
GraphImplDml::GraphImplDml(
mojo::PendingAssociatedReceiver<mojom::WebNNGraph> receiver,
scoped_refptr<Adapter> adapter,
base::WeakPtr<WebNNContextImpl> context,
std::unique_ptr<CommandRecorder> command_recorder,
scoped_refptr<PersistentResource> persistent_resource,
ComPtr<IDMLCompiledOperator> compiled_operator,
ComputeResourceInfo compute_resource_info,
GraphBufferBindingInfo graph_buffer_binding_info,
std::unique_ptr<GraphResources> graph_resources,
std::vector<mojom::Device> devices)
: WebNNGraphImpl(std::move(receiver),
std::move(context),
std::move(compute_resource_info),
std::move(devices)),
persistent_resource_(std::move(persistent_resource)),
adapter_(std::move(adapter)),
command_recorder_(std::move(command_recorder)),
compiled_operator_(std::move(compiled_operator)),
graph_buffer_binding_info_(std::move(graph_buffer_binding_info)),
graph_resources_(std::move(graph_resources)) {}
// Compiler-generated destruction; the members release what they hold.
GraphImplDml::~GraphImplDml() = default;
// Compiles the graph held by `graph_builder` with the given execution `flags`
// and returns whatever GraphBuilderDml::Compile() produces. The whole call is
// wrapped in a "gpu" trace event.
base::expected<ComPtr<IDMLCompiledOperator>, HRESULT>
GraphImplDml::CompileOnBackgroundThread(GraphBuilderDml graph_builder,
                                        DML_EXECUTION_FLAGS flags) {
  TRACE_EVENT0("gpu", "dml::GraphImplDml::CompileOnBackgroundThread");
  base::expected<ComPtr<IDMLCompiledOperator>, HRESULT> compilation_outcome =
      graph_builder.Compile(flags);
  return compilation_outcome;
}
// Executes the commands recorded in `init_command_recorder_for_npu` and then
// waits on its command queue via WaitSync(). Returns S_OK when both steps
// succeed.
// NOTE(review): RETURN_IF_FAILED is defined outside this file section;
// presumably it returns the failing HRESULT immediately.
// The recorder is owned by this function and destroyed when it returns.
HRESULT GraphImplDml::ExecuteAndWaitSyncOnBackgroundThread(
std::unique_ptr<CommandRecorder> init_command_recorder_for_npu) {
TRACE_EVENT0("gpu",
"dml::GraphImplDml::ExecuteAndWaitSyncOnBackgroundThread");
RETURN_IF_FAILED(init_command_recorder_for_npu->Execute());
RETURN_IF_FAILED(init_command_recorder_for_npu->command_queue()->WaitSync());
return S_OK;
}
// Continuation that runs once graph compilation has produced
// `compilation_result`.
//
// On success it records the operator initialization work (constant upload,
// input bindings for constants, optional persistent buffer), executes it and
// arranges for OnInitializationComplete() to run afterwards. On every failure
// path `callback` is run with an error and the function returns early; most
// failures additionally go through HandleGraphCreationFailure(), which also
// notifies `context`.
void GraphImplDml::OnCompilationComplete(
mojo::PendingAssociatedReceiver<mojom::WebNNGraph> receiver,
scoped_refptr<Adapter> adapter,
base::WeakPtr<ContextImplDml> context,
WebNNContextImpl::CreateGraphImplCallback callback,
absl::flat_hash_map<OperandId, uint32_t> constant_id_to_input_index_map,
GraphBufferBindingInfo graph_buffer_binding_info,
ComputeResourceInfo compute_resource_info,
base::flat_map<OperandId, std::unique_ptr<WebNNConstantOperand>>
constant_operands,
base::flat_map<OperandId, WebNNTensorImpl*> constant_tensor_operands,
base::expected<ComPtr<IDMLCompiledOperator>, HRESULT> compilation_result) {
TRACE_EVENT0("gpu", "dml::GraphImplDml::OnCompilationComplete");
// `context` is a weak pointer; it may have been invalidated while the
// compilation was in flight.
if (!context) {
std::move(callback).Run(base::unexpected(CreateError(
mojom::Error::Code::kUnknownError,
"Failed to create graph because the context was destroyed.")));
return;
}
if (!compilation_result.has_value()) {
// DXGI_ERROR_UNSUPPORTED on an NPU adapter is only reported to the caller;
// any other compilation failure is additionally forwarded to the context
// through HandleGraphCreationFailure().
if (adapter->IsNPU() &&
compilation_result.error() == DXGI_ERROR_UNSUPPORTED) {
LOG(ERROR)
<< "[WebNN] Failed to compile graph on NPU. Model is not supported.";
std::move(callback).Run(base::unexpected(CreateError(
mojom::Error::Code::kUnknownError,
"Failed to compile graph on NPU. Model is not supported.")));
} else {
HandleGraphCreationFailure("Failed to compile the graph.",
std::move(callback), context.get(),
compilation_result.error());
}
return;
}
ComPtr<IDMLCompiledOperator> compiled_operator =
std::move(compilation_result.value());
// NPU adapters record initialization on a dedicated command queue.
CommandQueue* command_queue = adapter->IsNPU()
? adapter->init_command_queue_for_npu()
: adapter->command_queue();
ASSIGN_OR_RETURN(
std::unique_ptr<CommandRecorder> initialization_command_recorder,
CommandRecorder::Create(command_queue, adapter->dml_device()),
&HandleGraphCreationFailure,
"Failed to create command recorder for graph initialization.",
std::move(callback), context.get());
HRESULT hr = initialization_command_recorder->Open();
if (FAILED(hr)) {
HandleGraphCreationFailure("Failed to open the command recorder.",
std::move(callback), context.get(), hr);
return;
}
// One binding slot per graph input buffer, all starting out empty (null
// buffer, zero size). Only the slots that belong to constants are filled in
// below.
std::vector<DML_BUFFER_BINDING> input_buffer_binding(
graph_buffer_binding_info.input_buffer_binding_count,
DML_BUFFER_BINDING{.Buffer = nullptr, .Offset = 0, .SizeInBytes = 0});
if (!constant_operands.empty()) {
std::optional<AlignedByteLength<OperandId>>
aligned_byte_length_of_constants =
CalculateAlignedByteLength(constant_operands);
if (!aligned_byte_length_of_constants) {
// Reported to the caller only; the context is not notified on this path.
std::move(callback).Run(base::unexpected(CreateError(
mojom::Error::Code::kUnknownError,
"Failed to calculate the aligned byte length of constants.")));
return;
}
size_t total_byte_length_of_constants =
aligned_byte_length_of_constants.value().total_byte_length;
// Holds either an upload/default buffer pair or a single buffer, depending
// on whether the adapter reports UMA.
std::variant<UploadAndDefaultBuffers, ComPtr<ID3D12Resource>>
buffer_variant;
if (adapter->IsUMA()) {
// UMA adapters: a single custom upload buffer holds the constants.
ComPtr<ID3D12Resource> cpu_buffer;
hr = CreateCustomUploadBuffer(
adapter->d3d12_device(), total_byte_length_of_constants,
L"WebNN_Custom_Upload_Buffer_Constants", cpu_buffer);
if (FAILED(hr)) {
HandleGraphCreationFailure(
"Failed to create custom upload buffer for constants.",
std::move(callback), context.get(), hr);
return;
}
buffer_variant = std::move(cpu_buffer);
} else {
// Non-UMA adapters: an upload buffer and a default buffer of the same size
// are created and handed over as a pair.
ComPtr<ID3D12Resource> upload_buffer;
hr = CreateUploadBuffer(adapter->d3d12_device(),
total_byte_length_of_constants,
L"WebNN_Upload_Buffer_Constants", upload_buffer);
if (FAILED(hr)) {
HandleGraphCreationFailure(
"Failed to create upload buffer for constants.",
std::move(callback), context.get(), hr);
return;
}
ComPtr<ID3D12Resource> default_buffer;
hr = CreateDefaultBuffer(
adapter->d3d12_device(), total_byte_length_of_constants,
L"WebNN_Default_Buffer_Constants", default_buffer);
if (FAILED(hr)) {
HandleGraphCreationFailure(
"Failed to create default input buffer for constants.",
std::move(callback), context.get(), hr);
return;
}
buffer_variant =
UploadAndDefaultBuffers{.upload_buffer = std::move(upload_buffer),
.default_buffer = std::move(default_buffer)};
}
// The type/name pair is parenthesized because it contains a comma, which
// would otherwise be split into separate macro arguments.
ASSIGN_OR_RETURN(
(absl::flat_hash_map<OperandId, DML_BUFFER_BINDING>
constant_buffer_binding),
UploadAndCreateConstantBufferBinding(
initialization_command_recorder.get(), constant_operands,
aligned_byte_length_of_constants.value(),
std::move(buffer_variant)),
&HandleGraphCreationFailure, "Failed to upload constant weight data.",
std::move(callback), context.get());
// Place each constant's binding into the slot of the graph input it was
// assigned to; every constant must have an input index (CHECKed).
for (auto& [constant_id, buffer_binding] : constant_buffer_binding) {
const auto graph_input_index_iterator =
constant_id_to_input_index_map.find(constant_id);
CHECK(graph_input_index_iterator != constant_id_to_input_index_map.end());
input_buffer_binding[graph_input_index_iterator->second] =
std::move(buffer_binding);
}
}
// Constants backed by tensors are bound directly to the tensor's buffer,
// starting at offset 0 and spanning its packed byte length.
for (auto& [constant_id, constant_tensor] : constant_tensor_operands) {
TensorImplDml* constant_tensor_impl =
static_cast<TensorImplDml*>(constant_tensor);
const auto graph_input_index_iterator =
constant_id_to_input_index_map.find(constant_id);
CHECK(graph_input_index_iterator != constant_id_to_input_index_map.end());
input_buffer_binding[graph_input_index_iterator->second] =
DML_BUFFER_BINDING{
.Buffer = constant_tensor_impl->buffer(),
.Offset = 0,
.SizeInBytes = constant_tensor_impl->PackedByteLength()};
}
// Both structures below only point at local storage, so they must outlive
// the InitializeOperator() call that consumes them.
DML_BUFFER_ARRAY_BINDING input_buffer_array_binding{
.BindingCount = base::checked_cast<uint32_t>(input_buffer_binding.size()),
.Bindings = input_buffer_binding.data()};
DML_BINDING_DESC input_buffer_binding_desc = {DML_BINDING_TYPE_BUFFER_ARRAY,
&input_buffer_array_binding};
// A persistent buffer is only allocated when the compiled operator reports
// a non-zero persistent resource size; otherwise both variables stay empty.
scoped_refptr<PersistentResource> persistent_resource;
std::optional<DML_BINDING_DESC> persistent_buffer_binding_desc;
DML_BINDING_PROPERTIES execution_binding_properties =
compiled_operator->GetBindingProperties();
uint64_t persistent_buffer_size =
execution_binding_properties.PersistentResourceSize;
if (persistent_buffer_size) {
ComPtr<ID3D12Resource> persistent_buffer;
hr = CreateDefaultBuffer(adapter->d3d12_device(), persistent_buffer_size,
L"WebNN_Default_Persistent_Buffer",
persistent_buffer);
if (FAILED(hr)) {
HandleGraphCreationFailure(
"Failed to create the default buffer for persistent resource.",
std::move(callback), context.get(), hr);
return;
}
persistent_resource = PersistentResource::Create(
persistent_buffer_size, std::move(persistent_buffer));
CHECK(persistent_resource);
persistent_buffer_binding_desc =
persistent_resource->persistent_buffer_binding_desc();
}
hr = initialization_command_recorder->InitializeOperator(
compiled_operator.Get(), input_buffer_binding_desc,
persistent_buffer_binding_desc);
if (FAILED(hr)) {
HandleGraphCreationFailure("Failed to initialize the operator.",
std::move(callback), context.get(), hr);
return;
}
hr = initialization_command_recorder->Close();
if (FAILED(hr)) {
HandleGraphCreationFailure("Failed to close the command list.",
std::move(callback), context.get(), hr);
return;
}
// NPU: execution and the blocking wait are posted to the adapter's NPU init
// task runner; OnInitializationComplete() is bound as the reply. Ownership
// of the recorder moves into the posted task.
if (adapter->IsNPU()) {
adapter->init_task_runner_for_npu()->PostTaskAndReplyWithResult(
FROM_HERE,
base::BindOnce(&GraphImplDml::ExecuteAndWaitSyncOnBackgroundThread,
std::move(initialization_command_recorder)),
base::BindOnce(
&GraphImplDml::OnInitializationComplete, std::move(receiver),
std::move(adapter), std::move(context),
std::move(persistent_resource), std::move(compiled_operator),
std::move(compute_resource_info),
std::move(graph_buffer_binding_info), std::move(callback)));
return;
}
// Non-NPU: execute right here, then let the command queue invoke
// OnInitializationComplete() through WaitAsync().
hr = initialization_command_recorder->Execute();
if (FAILED(hr)) {
HandleGraphCreationFailure("Failed to execute the command list.",
std::move(callback), context.get(), hr);
return;
}
initialization_command_recorder->command_queue()->WaitAsync(base::BindOnce(
&GraphImplDml::OnInitializationComplete, std::move(receiver),
std::move(adapter), std::move(context), std::move(persistent_resource),
std::move(compiled_operator), std::move(compute_resource_info),
std::move(graph_buffer_binding_info), std::move(callback)));
}
// Final step of graph creation: records the dispatch command list for
// `compiled_operator` and hands a new GraphImplDml to `callback`.
//
// Steps:
//   1. Bail out with an error if `context` has been invalidated.
//   2. Create the dispatch command recorder on the adapter's command queue and
//      allocate the execution resources (descriptor heap, temporary buffer).
//   3. Open the recorder, record the operator execution with the persistent
//      and temporary bindings, and close the recorder.
//   4. Run `callback` with the constructed graph, tagged with the device kind
//      (NPU or GPU) of `adapter`.
//
// Every failure after step 1 goes through HandleGraphCreationFailure(), which
// runs `callback` with an error and notifies the context.
void GraphImplDml::CreateWebNNGraphImpl(
    mojo::PendingAssociatedReceiver<mojom::WebNNGraph> receiver,
    scoped_refptr<Adapter> adapter,
    base::WeakPtr<ContextImplDml> context,
    scoped_refptr<PersistentResource> persistent_resource,
    ComPtr<IDMLCompiledOperator> compiled_operator,
    ComputeResourceInfo compute_resource_info,
    GraphBufferBindingInfo graph_buffer_binding_info,
    WebNNContextImpl::CreateGraphImplCallback callback) {
  if (!context) {
    std::move(callback).Run(base::unexpected(CreateError(
        mojom::Error::Code::kUnknownError,
        "Failed to create graph because the context was destroyed.")));
    return;
  }

  ASSIGN_OR_RETURN(
      std::unique_ptr<CommandRecorder> command_recorder_for_dispatch,
      CommandRecorder::Create(adapter->command_queue(), adapter->dml_device()),
      &HandleGraphCreationFailure,
      "Failed to create the command recorder for dispatch.",
      std::move(callback), context.get());

  ASSIGN_OR_RETURN(
      std::unique_ptr<GraphResources> graph_resources,
      AllocateGraphResources(adapter.get(), compiled_operator.Get()),
      &HandleGraphCreationFailure,
      "Failed to create the graph resource for dispatch.", std::move(callback),
      context.get());

  HRESULT hr = command_recorder_for_dispatch->Open();
  if (FAILED(hr)) {
    HandleGraphCreationFailure("Failed to open the command recorder.",
                               std::move(callback), context.get(), hr);
    return;
  }

  // The persistent binding is optional: it is only present when a persistent
  // resource was created for this graph.
  std::optional<DML_BINDING_DESC> persistent_buffer_binding_desc;
  if (persistent_resource) {
    persistent_buffer_binding_desc =
        persistent_resource->persistent_buffer_binding_desc();
  }

  hr = command_recorder_for_dispatch->ExecuteOperator(
      compiled_operator, graph_resources->descriptor_heap,
      persistent_buffer_binding_desc,
      graph_resources->temporary_buffer_binding_desc);
  if (FAILED(hr)) {
    HandleGraphCreationFailure(
        "Failed to record graph execution for late binding.",
        std::move(callback), context.get(), hr);
    return;
  }

  hr = command_recorder_for_dispatch->Close();
  if (FAILED(hr)) {
    HandleGraphCreationFailure("Failed to close the command recorder.",
                               std::move(callback), context.get(), hr);
    return;
  }

  // Read everything that is needed from `adapter` *before* the call below
  // moves it. Function arguments are evaluated in an unspecified order, so
  // querying `adapter` inside the same argument list as `std::move(adapter)`
  // would only be correct as long as the callee takes forwarding references.
  std::vector<mojom::Device> devices(
      {adapter->IsNPU() ? mojom::Device::kNpu : mojom::Device::kGpu});

  std::move(callback).Run(base::MakeRefCounted<GraphImplDml>(
      std::move(receiver), std::move(adapter), context->AsWeakPtr(),
      std::move(command_recorder_for_dispatch), std::move(persistent_resource),
      std::move(compiled_operator), std::move(compute_resource_info),
      std::move(graph_buffer_binding_info), std::move(graph_resources),
      std::move(devices)));
}
// Invoked with the result `hr` of the graph initialization work.
//
// If `context` has been invalidated, `callback` receives a "context was
// destroyed" error. If `hr` signals success, all arguments are forwarded to
// CreateWebNNGraphImpl(). Otherwise the failure is reported through
// HandleGraphCreationFailure().
void GraphImplDml::OnInitializationComplete(
    mojo::PendingAssociatedReceiver<mojom::WebNNGraph> receiver,
    scoped_refptr<Adapter> adapter,
    base::WeakPtr<ContextImplDml> context,
    scoped_refptr<PersistentResource> persistent_resource,
    ComPtr<IDMLCompiledOperator> compiled_operator,
    ComputeResourceInfo compute_resource_info,
    GraphBufferBindingInfo graph_buffer_binding_info,
    WebNNContextImpl::CreateGraphImplCallback callback,
    HRESULT hr) {
  TRACE_EVENT0("gpu", "dml::GraphImplDml::OnInitializationComplete");

  if (!context) {
    auto context_destroyed_error = CreateError(
        mojom::Error::Code::kUnknownError,
        "Failed to create graph because the context was destroyed.");
    std::move(callback).Run(
        base::unexpected(std::move(context_destroyed_error)));
    return;
  }

  if (SUCCEEDED(hr)) {
    CreateWebNNGraphImpl(
        std::move(receiver), std::move(adapter), std::move(context),
        std::move(persistent_resource), std::move(compiled_operator),
        std::move(compute_resource_info), std::move(graph_buffer_binding_info),
        std::move(callback));
    return;
  }

  HandleGraphCreationFailure(
      "Failed to wait for the initialization to complete.",
      std::move(callback), context.get(), hr);
}
base::expected<void, mojom::ErrorPtr> GraphImplDml::CreateAndBuildInternal(
const ContextProperties& context_properties,
scoped_refptr<Adapter> adapter,
mojom::GraphInfoPtr& graph_info,
base::flat_map<OperandId, std::unique_ptr<WebNNConstantOperand>>&
constant_operands,
const base::flat_map<OperandId, WebNNTensorImpl*>& constant_tensor_operands,
GraphBuilderDml& graph_builder,
absl::flat_hash_map<OperandId, uint32_t>& constant_id_to_input_index_map,
GraphBufferBindingInfo& graph_buffer_binding_info) {
IdToNodeOutputMap id_to_node_output_map;
const auto& operands = graph_info->operands;
for (OperandId input_id : graph_info->input_operands) {
uint32_t graph_input_index = CreateInputNode(
operands, input_id, graph_builder, id_to_node_output_map);
const Operand& operand = GetOperand(operands, input_id);
graph_buffer_binding_info
.graph_input_name_to_index_map[operand.name.value()] =
graph_input_index;
}
base::flat_set<OperandId> constant_ids;
constant_ids.reserve(constant_operands.size());
for (const auto& [constant_id, _] : constant_operands) {
constant_ids.insert(constant_id);
}
for (OperandId constant_id : constant_ids) {
CreateConstantNode(adapter.get(), constant_id, constant_operands,
graph_builder, id_to_node_output_map,
constant_id_to_input_index_map);
}
for (const auto& [constant_id, tensor_impl] : constant_tensor_operands) {
const Node* node = graph_builder.CreateInputNode();
constant_id_to_input_index_map[constant_id] =
node->AsInputNode()->GetGraphInputIndex();
TensorDesc tensor_desc(GetTensorDataType(tensor_impl->data_type()),
DML_TENSOR_FLAG_OWNED_BY_DML, tensor_impl->shape());
const NodeOutput* output =
graph_builder.CreateNodeOutput(node, std::move(tensor_desc));
CHECK(id_to_node_output_map.try_emplace(constant_id, output).second);
}
GraphFusionInfo graph_fusion_info = GetGraphFusionInfo(graph_info);
for (auto& operation : graph_info->operations) {
if (graph_fusion_info.fusible_operations_set.contains(operation.get())) {
continue;
}
base::expected<void, mojom::ErrorPtr> create_operator_result;
switch (operation->which()) {
case Operation::Tag::kArgMinMax: {
CreateOperatorNodeForArgMinMax(operands, operation->get_arg_min_max(),
graph_builder, id_to_node_output_map);
break;
}
case mojom::Operation::Tag::kBatchNormalization: {
CreateOperatorNodeForBatchNormalization(
adapter.get(), context_properties, operation.get(),
graph_fusion_info.operation_to_fusible_standalone_activation_map,
graph_info, constant_operands, graph_builder, id_to_node_output_map,
constant_id_to_input_index_map);
break;
}
case Operation::Tag::kClamp: {
CreateOperatorNodeForClamp(adapter.get(), context_properties, operands,
operation->get_clamp(), graph_builder,
id_to_node_output_map);
break;
}
case Operation::Tag::kConcat: {
CreateOperatorNodeForConcat(context_properties, operands,
operation->get_concat(), graph_builder,
id_to_node_output_map);
break;
}
case Operation::Tag::kConv2d: {
CreateOperatorNodeForConv2d(
context_properties, operands, operation.get(),
graph_fusion_info.operation_to_fusible_standalone_activation_map,
graph_builder, id_to_node_output_map);
break;
}
case Operation::Tag::kCumulativeSum: {
CreateOperatorNodeForCumulativeSum(
context_properties, operands, operation->get_cumulative_sum(),
graph_builder, id_to_node_output_map);
break;
}
case Operation::Tag::kDequantizeLinear: {
if (adapter->IsDMLFeatureLevelSupported(DML_FEATURE_LEVEL_6_3)) {
create_operator_result =
CreateOperatorNodeForDequantizeOrQuantizeLinear<
DML_DEQUANTIZE_OPERATOR_DESC>(
context_properties, operands,
operation->get_dequantize_linear(), graph_builder,
DML_OPERATOR_DEQUANTIZE, id_to_node_output_map);
} else {
create_operator_result =
CreateOperatorNodeForDequantizeOrQuantizeLinear<
DML_ELEMENT_WISE_DEQUANTIZE_LINEAR_OPERATOR_DESC>(
context_properties, operands,
operation->get_dequantize_linear(), graph_builder,
DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR,
id_to_node_output_map);
}
break;
}
case mojom::Operation::Tag::kElementWiseBinary: {
CreateOperatorNodeForBinary(
context_properties, operands, operation.get(),
graph_fusion_info.operation_to_fusible_standalone_activation_map,
graph_builder, id_to_node_output_map);
break;
}
case Operation::Tag::kElu: {
CreateOperatorNodeForElu(operands, operation->get_elu(), graph_builder,
id_to_node_output_map);
break;
}
case mojom::Operation::Tag::kElementWiseUnary: {
CreateOperatorNodeForElementWiseUnary(
context_properties, operands, operation->get_element_wise_unary(),
graph_builder, id_to_node_output_map);
break;
}
case Operation::Tag::kExpand: {
CreateOperatorNodeForExpand(context_properties, operands,
operation->get_expand(), graph_builder,
id_to_node_output_map);
break;
}
case mojom::Operation::Tag::kGather: {
create_operator_result = CreateOperatorNodeForGather(
context_properties, operands, operation->get_gather(),
graph_builder, id_to_node_output_map);
break;
}
case mojom::Operation::Tag::kGatherElements: {
CreateOperatorNodeForGatherElements(
context_properties, operands, operation->get_gather_elements(),
graph_builder, id_to_node_output_map);
break;
}
case mojom::Operation::Tag::kGatherNd: {
CreateOperatorNodeForGatherND(context_properties, operands,
operation->get_gather_nd(), graph_builder,
id_to_node_output_map);
break;
}
case mojom::Operation::Tag::kGelu: {
CreateOperatorNodeForGelu(
context_properties, adapter.get(), operation->get_gelu(),
graph_info, constant_operands, graph_builder, id_to_node_output_map,
constant_id_to_input_index_map);
break;
}
case mojom::Operation::Tag::kGemm: {
CreateOperatorNodeForGemm(
context_properties, operands, operation.get(),
graph_fusion_info.operation_to_fusible_standalone_activation_map,
graph_builder, id_to_node_output_map);
break;
}
case mojom::Operation::Tag::kGru: {
create_operator_result = CreateOperatorNodeForGru<mojom::GruPtr>(
adapter.get(), context_properties, operation->get_gru(), graph_info,
constant_operands, graph_builder, id_to_node_output_map,
constant_id_to_input_index_map);
break;
}
case mojom::Operation::Tag::kGruCell: {
create_operator_result = CreateOperatorNodeForGru<mojom::GruCellPtr>(
adapter.get(), context_properties, operation->get_gru_cell(),
graph_info, constant_operands, graph_builder, id_to_node_output_map,
constant_id_to_input_index_map);
break;
}
case mojom::Operation::Tag::kHardSigmoid: {
CreateOperatorNodeForHardSigmoid(operands,
operation->get_hard_sigmoid(),
graph_builder, id_to_node_output_map);
break;
}
case mojom::Operation::Tag::kHardSwish: {
CreateOperatorNodeForHardSwish(adapter.get(), operands,
operation->get_hard_swish(),
graph_builder, id_to_node_output_map);
break;
}
case Operation::Tag::kInstanceNormalization: {
CHECK_EQ(context_properties.input_operand_layout,
InputOperandLayout::kNchw);
std::array<uint32_t, 2> mean_variance_axes = {2, 3};
std::array<uint32_t, 1> scale_bias_broadcast_axes = {1};
create_operator_result = CreateOperatorNodeForMeanVarianceNormalization(
adapter.get(), context_properties,
operation->get_instance_normalization(), operation.get(),
graph_fusion_info.operation_to_fusible_standalone_activation_map,
graph_info, constant_operands, graph_builder, id_to_node_output_map,
constant_id_to_input_index_map, mean_variance_axes,
scale_bias_broadcast_axes, Operation::Tag::kInstanceNormalization);
break;
}
case Operation::Tag::kLayerNormalization: {
const auto& layer_normalization = operation->get_layer_normalization();
const auto axes = layer_normalization->axes;
create_operator_result = CreateOperatorNodeForMeanVarianceNormalization(
adapter.get(), context_properties, layer_normalization,
operation.get(),
graph_fusion_info.operation_to_fusible_standalone_activation_map,
graph_info, constant_operands, graph_builder, id_to_node_output_map,
constant_id_to_input_index_map, axes, axes,
Operation::Tag::kLayerNormalization);
break;
}
case Operation::Tag::kLeakyRelu: {
CreateOperatorNodeForLeakyRelu(operands, operation->get_leaky_relu(),
graph_builder, id_to_node_output_map);
break;
}
case Operation::Tag::kLinear: {
CreateOperatorNodeForLinear(context_properties, operands,
operation->get_linear(), graph_builder,
id_to_node_output_map);
break;
}
case Operation::Tag::kLstm: {
create_operator_result = CreateOperatorNodeForLstm<mojom::Lstm>(
adapter.get(), context_properties, *operation->get_lstm(),
graph_info, constant_operands, graph_builder, id_to_node_output_map,
constant_id_to_input_index_map);
break;
}
case Operation::Tag::kLstmCell: {
create_operator_result = CreateOperatorNodeForLstm<mojom::LstmCell>(
adapter.get(), context_properties, *operation->get_lstm_cell(),
graph_info, constant_operands, graph_builder, id_to_node_output_map,
constant_id_to_input_index_map);
break;
}
case mojom::Operation::Tag::kMatmul: {
create_operator_result = CreateOperatorNodeForMatmul(
context_properties, operands, operation.get(),
graph_fusion_info.operation_to_fusible_standalone_activation_map,
graph_fusion_info.output_id_to_fusible_transpose_map, graph_builder,
id_to_node_output_map);
break;
}
case Operation::Tag::kPad: {
CreateOperatorNodeForPad(context_properties, operands,
operation->get_pad(), graph_builder,
id_to_node_output_map);
break;
}
case Operation::Tag::kPool2d: {
create_operator_result = CreateOperatorNodeForPool2d(
context_properties, operands, operation->get_pool2d(),
graph_builder, id_to_node_output_map);
break;
}
case Operation::Tag::kPrelu: {
CreateOperatorNodeForPrelu(context_properties, operands,
operation->get_prelu(), graph_builder,
id_to_node_output_map);
break;
}
case Operation::Tag::kQuantizeLinear: {
if (adapter->IsDMLFeatureLevelSupported(DML_FEATURE_LEVEL_6_3)) {
create_operator_result =
CreateOperatorNodeForDequantizeOrQuantizeLinear<
DML_QUANTIZE_OPERATOR_DESC>(
context_properties, operands,
operation->get_quantize_linear(), graph_builder,
DML_OPERATOR_QUANTIZE, id_to_node_output_map);
} else {
create_operator_result =
CreateOperatorNodeForDequantizeOrQuantizeLinear<
DML_ELEMENT_WISE_QUANTIZE_LINEAR_OPERATOR_DESC>(
context_properties, operands,
operation->get_quantize_linear(), graph_builder,
DML_OPERATOR_ELEMENT_WISE_QUANTIZE_LINEAR,
id_to_node_output_map);
}
break;
}
case Operation::Tag::kReduce: {
CreateOperatorNodeForReduce(context_properties, operands,
operation->get_reduce(), graph_builder,
id_to_node_output_map);
break;
}
case Operation::Tag::kRelu: {
CreateOperatorNodeForUnary<DML_ACTIVATION_RELU_OPERATOR_DESC,
DML_OPERATOR_ACTIVATION_RELU>(
operands, operation->get_relu(), graph_builder,
id_to_node_output_map);
break;
}
case Operation::Tag::kResample2d: {
CreateOperatorNodeForResample2d(context_properties, operands,
operation->get_resample2d(),
graph_builder, id_to_node_output_map);
break;
}
case Operation::Tag::kReshape: {
CreateOperatorNodeForReshape(context_properties, operands,
operation->get_reshape(), graph_builder,
id_to_node_output_map);
break;
}
case Operation::Tag::kReverse: {
CreateOperatorNodeForReverse(context_properties, operands,
*operation->get_reverse(), graph_builder,
id_to_node_output_map);
break;
}
case mojom::Operation::Tag::kScatterElements: {
CreateOperatorNodeForScatterElements(
context_properties, operands, operation->get_scatter_elements(),
graph_builder, id_to_node_output_map);
break;
}
case mojom::Operation::Tag::kScatterNd: {
CreateOperatorNodeForScatterND(context_properties, operands,
operation->get_scatter_nd(),
graph_builder, id_to_node_output_map);
break;
}
case Operation::Tag::kSigmoid: {
CreateOperatorNodeForUnary<DML_ACTIVATION_SIGMOID_OPERATOR_DESC,
DML_OPERATOR_ACTIVATION_SIGMOID>(
operands, operation->get_sigmoid(), graph_builder,
id_to_node_output_map);
break;
}
case Operation::Tag::kSlice: {
CreateOperatorNodeForSlice(operands, operation->get_slice(),
graph_builder, id_to_node_output_map);
break;
}
case Operation::Tag::kSoftmax: {
create_operator_result = CreateOperatorNodeForSoftmax(
adapter.get(), operands, operation->get_softmax(), graph_builder,
id_to_node_output_map);
break;
}
case mojom::Operation::Tag::kSoftplus: {
CreateOperatorNodeForSoftplus(operands, operation->get_softplus(),
graph_builder, id_to_node_output_map);
break;
}
case Operation::Tag::kSoftsign: {
CreateOperatorNodeForUnary<DML_ACTIVATION_SOFTSIGN_OPERATOR_DESC,
DML_OPERATOR_ACTIVATION_SOFTSIGN>(
operands, operation->get_softsign(), graph_builder,
id_to_node_output_map);
break;
}
case mojom::Operation::Tag::kSplit: {
CreateOperatorNodeForSplit(operands, operation->get_split(),
graph_builder, id_to_node_output_map);
break;
}
case Operation::Tag::kTanh: {
CreateOperatorNodeForUnary<DML_ACTIVATION_TANH_OPERATOR_DESC,
DML_OPERATOR_ACTIVATION_TANH>(
operands, operation->get_tanh(), graph_builder,
id_to_node_output_map);
break;
}
case Operation::Tag::kTile: {
CreateOperatorNodeForTile(context_properties, operands,
operation->get_tile(), graph_builder,
id_to_node_output_map);
break;
}
case Operation::Tag::kTranspose: {
CreateOperatorNodeForTranspose(context_properties, operands,
operation->get_transpose(),
graph_builder, id_to_node_output_map);
break;
}
case mojom::Operation::Tag::kTriangular: {
create_operator_result = CreateOperatorNodeForTriangular(
context_properties, adapter.get(), operation->get_triangular(),
graph_info, constant_operands, graph_builder, id_to_node_output_map,
constant_id_to_input_index_map);
break;
}
case Operation::Tag::kWhere: {
CreateOperatorNodeForWhere(operands, operation->get_where(),
graph_builder, id_to_node_output_map);
break;
}
default: {
std::string error_message = NotSupportedOperatorError(*operation);
create_operator_result = base::unexpected(CreateError(
mojom::Error::Code::kNotSupportedError, std::move(error_message)));
}
}
if (!create_operator_result.has_value()) {
return create_operator_result;
}
}
for (auto& output_id : graph_info->output_operands) {
const auto output_iterator = id_to_node_output_map.find(output_id);
CHECK(output_iterator != id_to_node_output_map.end());
const NodeOutput* output = output_iterator->second;
CHECK(output);
if (output->GetNode().GetType() == Node::Type::kInput) {
output = AppendIdentityNode(graph_builder, output);
}
std::string name = GetOperand(operands, output_id).name.value();
graph_buffer_binding_info.graph_output_name_to_index_map[std::move(name)] =
graph_builder.CreateOutputEdge(output);
}
graph_buffer_binding_info.input_buffer_binding_count =
constant_id_to_input_index_map.size() +
graph_buffer_binding_info.graph_input_name_to_index_map.size();
return base::ok();
}
// Builds the DirectML graph description for `graph_info` and schedules its
// compilation.
//
// The graph is first assembled synchronously through CreateAndBuildInternal().
// If that fails, `callback` is run immediately with the error and nothing is
// posted. Otherwise CompileOnBackgroundThread is posted to the thread pool and
// its result is forwarded to OnCompilationComplete together with everything
// needed to finish creating the graph (receiver, adapter, context, callback,
// binding info, compute resource info and the constant operands).
//
// `disable_dml_meta_commands_for_gpu` only takes effect when the adapter is
// not an NPU.
void GraphImplDml::CreateAndBuild(
    mojo::PendingAssociatedReceiver<mojom::WebNNGraph> receiver,
    scoped_refptr<Adapter> adapter,
    base::WeakPtr<ContextImplDml> context,
    mojom::GraphInfoPtr graph_info,
    ComputeResourceInfo compute_resource_info,
    base::flat_map<OperandId, std::unique_ptr<WebNNConstantOperand>>
        constant_operands,
    base::flat_map<OperandId, WebNNTensorImpl*> constant_tensor_operands,
    WebNNContextImpl::CreateGraphImplCallback callback,
    const bool disable_dml_meta_commands_for_gpu) {
  TRACE_EVENT0("gpu", "dml::GraphImplDml::CreateAndBuild");

  // State filled in while the graph description is being built.
  GraphBuilderDml graph_builder(adapter->dml_device());
  absl::flat_hash_map<OperandId, uint32_t> constant_id_to_input_index_map;
  GraphBufferBindingInfo graph_buffer_binding_info;

  base::expected<void, mojom::ErrorPtr> build_result =
      GraphImplDml::CreateAndBuildInternal(
          context->properties(), adapter, graph_info, constant_operands,
          constant_tensor_operands, graph_builder,
          constant_id_to_input_index_map, graph_buffer_binding_info);
  if (!build_result.has_value()) {
    // Building failed: report the error to the caller and stop here.
    std::move(callback).Run(base::unexpected(std::move(build_result.error())));
    return;
  }

  // Meta commands are only ever disabled on non-NPU adapters.
  const bool disable_meta_commands =
      disable_dml_meta_commands_for_gpu && !adapter->IsNPU();
  DML_EXECUTION_FLAGS flags = DML_EXECUTION_FLAG_DESCRIPTORS_VOLATILE;
  if (disable_meta_commands) {
    flags |= DML_EXECUTION_FLAG_DISABLE_META_COMMANDS;
  }

  // The compilation task takes ownership of the graph builder.
  auto compile_task =
      base::BindOnce(&GraphImplDml::CompileOnBackgroundThread,
                     std::move(graph_builder), flags);

  // The reply receives the compilation result plus all remaining state.
  auto on_compiled = base::BindOnce(
      &GraphImplDml::OnCompilationComplete, std::move(receiver),
      std::move(adapter), std::move(context), std::move(callback),
      std::move(constant_id_to_input_index_map),
      std::move(graph_buffer_binding_info), std::move(compute_resource_info),
      std::move(constant_operands), std::move(constant_tensor_operands));

  base::ThreadPool::PostTaskAndReplyWithResult(
      FROM_HERE,
      {base::TaskPriority::USER_BLOCKING,
       base::TaskShutdownBehavior::CONTINUE_ON_SHUTDOWN},
      std::move(compile_task), std::move(on_compiled));
}
// Common failure path for dispatch: discards the command recorder and the
// cached input/output tensor bindings, then reports `error_message` and `hr`
// to the owning DirectML context via HandleContextLostOrCrash().
void GraphImplDml::HandleDispatchFailure(std::string_view error_message,
                                         HRESULT hr) {
  // Drop all per-dispatch cached state so nothing stale is reused afterwards.
  command_recorder_.reset();
  previous_input_tensors_.clear();
  previous_output_tensors_.clear();

  // The context must still exist at this point.
  CHECK(context_);
  auto* dml_context = static_cast<ContextImplDml*>(context_.get());
  dml_context->HandleContextLostOrCrash(error_message, hr);
}
// Takes ownership of a set of buffer bindings and their binding descriptors by
// moving both arguments into the members of the same name.
//
// The two are kept together because the callers in this file
// (CreateAndCacheInputBindings / CreateAndCacheOutputBindings) build
// descriptors whose `Desc` pointer refers to an element of `buffer_bindings`.
GraphImplDml::IoBindings::IoBindings(
std::vector<DML_BUFFER_BINDING> buffer_bindings,
base::FixedArray<DML_BINDING_DESC> buffer_binding_desc)
: buffer_bindings(std::move(buffer_bindings)),
buffer_binding_desc(std::move(buffer_binding_desc)) {}
// Out-of-line defaulted destructor; the members clean themselves up.
GraphImplDml::IoBindings::~IoBindings() = default;
// Builds the DirectML input bindings for a dispatch and remembers which
// tensors were bound.
//
// One binding slot is created per entry of
// `graph_buffer_binding_info_.input_buffer_binding_count`. Every slot starts
// out as an empty binding (null buffer, DML_BINDING_TYPE_NONE); only the slots
// that correspond to an entry of `named_inputs` are filled in with the
// tensor's buffer and packed byte length.
//
// As a side effect, `previous_input_tensors_` is updated with a weak pointer
// to each bound tensor, keyed by input name.
//
// Returns the buffer bindings together with the descriptors that point at
// them.
GraphImplDml::IoBindings GraphImplDml::CreateAndCacheInputBindings(
    const base::flat_map<std::string, scoped_refptr<WebNNTensorImpl>>&
        named_inputs) {
  TRACE_EVENT0("gpu", "dml::GraphImplDml::CreateAndCacheInputBindings");

  std::vector<DML_BUFFER_BINDING> graph_input_buffer_bindings(
      graph_buffer_binding_info_.input_buffer_binding_count,
      DML_BUFFER_BINDING{.Buffer = nullptr, .Offset = 0, .SizeInBytes = 0});
  previous_input_tensors_.reserve(named_inputs.size());
  base::FixedArray<DML_BINDING_DESC> input_buffer_binding_desc(
      graph_buffer_binding_info_.input_buffer_binding_count,
      DML_BINDING_DESC{.Type = DML_BINDING_TYPE_NONE, .Desc = nullptr});

  for (const auto& [name, input_tensor] : named_inputs) {
    TensorImplDml* input_tensor_impl =
        static_cast<TensorImplDml*>(input_tensor.get());

    // `name` is already a std::string, so it is used directly as the lookup
    // key instead of materialising a temporary copy for every tensor.
    const size_t graph_input_index =
        graph_buffer_binding_info_.graph_input_name_to_index_map.at(name);

    graph_input_buffer_bindings[graph_input_index] = DML_BUFFER_BINDING{
        .Buffer = input_tensor_impl->buffer(),
        .Offset = 0,
        .SizeInBytes = input_tensor_impl->PackedByteLength()};

    // The descriptor points at the vector element. The address stays valid
    // after the vector is moved into the returned IoBindings because moving a
    // std::vector transfers its heap storage rather than copying elements.
    input_buffer_binding_desc[graph_input_index] = {
        DML_BINDING_TYPE_BUFFER,
        &graph_input_buffer_bindings[graph_input_index]};

    previous_input_tensors_[name] = input_tensor_impl->GetWeakPtr();
  }

  return IoBindings(std::move(graph_input_buffer_bindings),
                    std::move(input_buffer_binding_desc));
}
// Builds the DirectML output bindings for a dispatch and remembers which
// tensors were bound.
//
// One binding slot is created per entry of
// `graph_buffer_binding_info_.graph_output_name_to_index_map`. Every slot
// starts out as an empty binding (null buffer, DML_BINDING_TYPE_NONE) and is
// then filled in from the matching entry of `named_outputs` with the tensor's
// buffer and packed byte length.
//
// As a side effect, `previous_output_tensors_` is updated with a weak pointer
// to each bound tensor, keyed by output name.
//
// Returns the buffer bindings together with the descriptors that point at
// them.
GraphImplDml::IoBindings GraphImplDml::CreateAndCacheOutputBindings(
    const base::flat_map<std::string, scoped_refptr<WebNNTensorImpl>>&
        named_outputs) {
  TRACE_EVENT0("gpu", "dml::GraphImplDml::CreateAndCacheOutputBindings");

  const size_t output_buffer_binding_count =
      graph_buffer_binding_info_.graph_output_name_to_index_map.size();
  std::vector<DML_BUFFER_BINDING> graph_output_buffer_bindings(
      output_buffer_binding_count,
      DML_BUFFER_BINDING{.Buffer = nullptr, .Offset = 0, .SizeInBytes = 0});
  base::FixedArray<DML_BINDING_DESC> output_buffer_binding_desc(
      output_buffer_binding_count,
      DML_BINDING_DESC{.Type = DML_BINDING_TYPE_NONE, .Desc = nullptr});
  previous_output_tensors_.reserve(named_outputs.size());

  for (const auto& [name, output_tensor] : named_outputs) {
    TensorImplDml* output_tensor_impl =
        static_cast<TensorImplDml*>(output_tensor.get());

    // `name` is already a std::string, so it is used directly as the lookup
    // key instead of materialising a temporary copy for every tensor.
    const size_t graph_output_index =
        graph_buffer_binding_info_.graph_output_name_to_index_map.at(name);

    graph_output_buffer_bindings[graph_output_index] = DML_BUFFER_BINDING{
        .Buffer = output_tensor_impl->buffer(),
        .Offset = 0,
        .SizeInBytes = output_tensor_impl->PackedByteLength()};

    // The descriptor points at the vector element. The address stays valid
    // after the vector is moved into the returned IoBindings because moving a
    // std::vector transfers its heap storage rather than copying elements.
    output_buffer_binding_desc[graph_output_index] = {
        DML_BINDING_TYPE_BUFFER,
        &graph_output_buffer_bindings[graph_output_index]};

    previous_output_tensors_[name] = output_tensor_impl->GetWeakPtr();
  }

  return IoBindings(std::move(graph_output_buffer_bindings),
                    std::move(output_buffer_binding_desc));
}
// Executes the compiled graph with the given input and output tensors.
//
// Steps, in order:
//   1. Make sure a command recorder and a set of graph resources exist,
//      creating either one on demand.
//   2. If either was just created, (re-)record the command list that executes
//      the compiled operator.
//   3. Bind inputs/outputs when the commands were just recorded or when the
//      tensors differ from the ones cached by the previous dispatch.
//   4. Submit the work, stamp every tensor with the submitted fence value and
//      wait asynchronously; OnDispatchComplete receives the graph resources.
//
// Any failure is routed through HandleDispatchFailure() and aborts the
// dispatch.
void GraphImplDml::DispatchImpl(
    base::flat_map<std::string, scoped_refptr<WebNNTensorImpl>> named_inputs,
    base::flat_map<std::string, scoped_refptr<WebNNTensorImpl>> named_outputs) {
  TRACE_EVENT0("gpu", "dml::GraphImplDml::DispatchImpl");

  // Set whenever the recorder or the graph resources are newly created.
  bool must_record_commands = false;

  if (!command_recorder_) {
    ASSIGN_OR_RETURN(command_recorder_,
                     CommandRecorder::Create(adapter_->command_queue(),
                                             adapter_->dml_device()),
                     &GraphImplDml::HandleDispatchFailure, this,
                     "Failed to create the command recorder.");
    must_record_commands = true;
  }

  // Take the cached graph resources if there are any; otherwise allocate.
  std::unique_ptr<GraphResources> graph_resources = std::move(graph_resources_);
  if (!graph_resources) {
    base::expected<std::unique_ptr<GraphResources>, HRESULT> allocation =
        AllocateGraphResources(adapter_.get(), compiled_operator_.Get());
    if (!allocation.has_value()) {
      HandleDispatchFailure("Failed to allocate graph resources.",
                            std::move(allocation.error()));
      return;
    }
    graph_resources = std::move(allocation.value());
    must_record_commands = true;
  }
  CHECK(graph_resources);

  if (must_record_commands) {
    if (const HRESULT hr = command_recorder_->Open(); FAILED(hr)) {
      HandleDispatchFailure("Failed to open the command recorder.", hr);
      return;
    }

    // The persistent binding is only supplied when a persistent resource
    // exists.
    std::optional<DML_BINDING_DESC> persistent_buffer_binding_desc;
    if (persistent_resource_) {
      persistent_buffer_binding_desc =
          persistent_resource_->persistent_buffer_binding_desc();
    }

    if (const HRESULT hr = command_recorder_->ExecuteOperator(
            compiled_operator_.Get(), graph_resources->descriptor_heap,
            persistent_buffer_binding_desc,
            graph_resources->temporary_buffer_binding_desc);
        FAILED(hr)) {
      HandleDispatchFailure("Failed to record execute operator.", hr);
      return;
    }

    if (const HRESULT hr = command_recorder_->Close(); FAILED(hr)) {
      HandleDispatchFailure("Failed to close the command recorder.", hr);
      return;
    }
  }

  // Freshly recorded commands always need both bindings. Otherwise a binding
  // is only redone when the tensors no longer match the cached ones. The
  // short-circuit keeps the validity checks from running after a re-record.
  const bool rebind_inputs =
      must_record_commands ||
      !IsDispatchBindingValid(named_inputs, previous_input_tensors_);
  const bool rebind_outputs =
      must_record_commands ||
      !IsDispatchBindingValid(named_outputs, previous_output_tensors_);

  if (rebind_inputs) {
    IoBindings input_bindings = CreateAndCacheInputBindings(named_inputs);
    if (const HRESULT hr =
            command_recorder_->BindInputs(input_bindings.buffer_binding_desc);
        FAILED(hr)) {
      HandleDispatchFailure("Failed to bind inputs.", hr);
      return;
    }
  }

  if (rebind_outputs) {
    IoBindings output_bindings = CreateAndCacheOutputBindings(named_outputs);
    if (const HRESULT hr =
            command_recorder_->BindOutputs(output_bindings.buffer_binding_desc);
        FAILED(hr)) {
      HandleDispatchFailure("Failed to bind outputs.", hr);
      return;
    }
  }

  if (const HRESULT hr = command_recorder_->Execute(); FAILED(hr)) {
    HandleDispatchFailure("Failed to execute the command recorder.", hr);
    return;
  }

  CommandQueue* command_queue = command_recorder_->command_queue();
  const uint64_t last_submitted_fence_value =
      command_queue->GetLastFenceValue();

  // Record the submission fence value on every tensor involved and register
  // its buffer with the queue (presumably to keep it alive until the GPU work
  // has completed).
  const auto track_submission = [&](const auto& named_tensors) {
    for (const auto& [name, tensor] : named_tensors) {
      auto* dml_tensor = static_cast<TensorImplDml*>(tensor.get());
      dml_tensor->SetLastSubmissionFenceValue(last_submitted_fence_value);
      command_queue->ReferenceUntilCompleted(dml_tensor->buffer());
    }
  };
  track_submission(named_inputs);
  track_submission(named_outputs);

  // Ownership of the graph resources travels with the completion callback.
  command_queue->WaitAsync(base::BindOnce(&GraphImplDml::OnDispatchComplete,
                                          weak_factory_.GetWeakPtr(),
                                          std::move(graph_resources)));
}
void GraphImplDml::OnDispatchComplete(
std::unique_ptr<GraphResources> graph_resources,
HRESULT hr) {
TRACE_EVENT0("gpu", "dml::GraphImplDml::OnDispatchComplete");
if (FAILED(hr)) {
HandleDispatchFailure("Failed to wait for the dispatch to complete.", hr);
return;
}
}
}