#ifndef SERVICES_WEBNN_PUBLIC_CPP_GRAPH_VALIDATION_UTILS_H_
#define SERVICES_WEBNN_PUBLIC_CPP_GRAPH_VALIDATION_UTILS_H_
#include <stddef.h>
#include <stdint.h>

#include <optional>
#include <string>
#include <string_view>
#include <variant>
#include <vector>

#include "base/component_export.h"
#include "base/containers/enum_set.h"
#include "base/containers/span.h"
#include "base/types/expected.h"
#include "services/webnn/public/cpp/context_properties.h"
#include "services/webnn/public/cpp/operand_descriptor.h"
namespace webnn {
// Layouts a conv2d filter operand may use.
// NOTE(review): the enumerator letters presumably abbreviate Output channels,
// Input channels, Height and Width in dimension order - confirm against the
// WebNN specification.
enum class Conv2dFilterOperandLayout {
  kOihw,
  kHwio,
  kOhwi,
  kIhwo,
};
// Layouts a convTranspose2d filter operand may use.
// NOTE(review): the enumerator letters presumably abbreviate Input channels,
// Output channels, Height and Width in dimension order - confirm against the
// WebNN specification.
enum class ConvTranspose2dFilterOperandLayout {
  kIohw,
  kHwoi,
  kOhwi,
};
// Selects which pool2d variant is being validated (see
// ValidatePool2dAndInferOutput, which takes a Pool2dKind).
enum class Pool2dKind {
  kAverage,
  kL2,
  kMax,
};
// Rounding choice stored in Pool2dAttributes::rounding_type.
// NOTE(review): presumably selects floor vs. ceil when a computed output size
// is fractional - confirm in the definition of the pool2d validator.
enum class RoundingType {
  kFloor,
  kCeil,
};
// Direction option carried by GruAttributes and LstmAttributes.
enum class RecurrentNetworkDirection {
  kForward,
  kBackward,
  kBoth,
};
// Padding mode passed to ValidatePadAndInferOutput.
enum class PaddingMode {
  kConstant,
  kEdge,
  kReflection,
};
// Identifies which reduction ValidateReduceAndInferOutput is asked to
// validate. Enumerators keep their implicit, consecutive values starting at 0.
enum class ReduceKind {
  kL1,
  kL2,
  kLogSum,
  kLogSumExp,
  kMax,
  kMean,
  kMin,
  kProduct,
  kSum,
  kSumSquare,
};
// A pair of values of type `T`, one named `height` and one named `width`.
// NOTE(review): neither member has a default initializer, so a
// default-initialized Size2d of an arithmetic `T` holds indeterminate values
// until both members are assigned.
template <typename T>
struct COMPONENT_EXPORT(WEBNN_PUBLIC_CPP) Size2d {
T height;
T width;
};
// Two Size2d<uint32_t> values named `beginning` and `ending`.
// NOTE(review): presumably the padding applied before and after each of the
// height/width axes - confirm against the validators that consume it.
// Like Size2d, the members carry no default initializers.
struct COMPONENT_EXPORT(WEBNN_PUBLIC_CPP) Padding2d {
Size2d<uint32_t> beginning;
Size2d<uint32_t> ending;
};
// Options consumed by ValidateBatchNormalizationAndInferOutput.
// The type is move-only: both copy operations are deleted.
struct COMPONENT_EXPORT(WEBNN_PUBLIC_CPP) BatchNormalizationAttributes {
  BatchNormalizationAttributes();

  BatchNormalizationAttributes(const BatchNormalizationAttributes&) = delete;
  BatchNormalizationAttributes& operator=(const BatchNormalizationAttributes&) =
      delete;

  BatchNormalizationAttributes(BatchNormalizationAttributes&& other);
  BatchNormalizationAttributes& operator=(BatchNormalizationAttributes&& other);

  ~BatchNormalizationAttributes();

  // Descriptor of the scale operand, when one is supplied.
  std::optional<OperandDescriptor> scale;
  // Descriptor of the bias operand, when one is supplied.
  std::optional<OperandDescriptor> bias;
  // Axis option; defaults to 1.
  uint32_t axis = 1;
  // Operator label; empty by default.
  std::string label;
};
// Options shared by Conv2dAttributes and ConvTranspose2dAttributes, which both
// derive from this struct. The type is move-only: both copy operations are
// deleted.
struct COMPONENT_EXPORT(WEBNN_PUBLIC_CPP) Conv2dAttributesBase {
  Conv2dAttributesBase();
  ~Conv2dAttributesBase();

  Conv2dAttributesBase(Conv2dAttributesBase&& other);
  Conv2dAttributesBase& operator=(Conv2dAttributesBase&& other);

  Conv2dAttributesBase(const Conv2dAttributesBase&) = delete;
  Conv2dAttributesBase& operator=(const Conv2dAttributesBase&) = delete;

  // Padding2d and Size2d have no initializers of their own, so the three
  // members below are given explicit defaults here; otherwise a freshly
  // constructed object could expose indeterminate values.
  // NOTE(review): zero padding and unit strides/dilations are intended to
  // mirror the WebNN conv2d option defaults - confirm against the spec.
  Padding2d padding{};
  Size2d<uint32_t> strides{1, 1};
  Size2d<uint32_t> dilations{1, 1};
  // Groups option; defaults to 1.
  uint32_t groups = 1;
  // Layout of the input operand; defaults to kNchw.
  InputOperandLayout input_layout = InputOperandLayout::kNchw;
  // Descriptor of the bias operand, when one is supplied.
  std::optional<OperandDescriptor> bias_operand;
  // Operator label; empty by default.
  std::string label = "";
};
// Options consumed by ValidateConv2dAndInferOutput: everything in
// Conv2dAttributesBase plus the filter layout. Move-only, like its base.
struct COMPONENT_EXPORT(WEBNN_PUBLIC_CPP) Conv2dAttributes
    : Conv2dAttributesBase {
  Conv2dAttributes();

  Conv2dAttributes(const Conv2dAttributes&) = delete;
  Conv2dAttributes& operator=(const Conv2dAttributes&) = delete;

  Conv2dAttributes(Conv2dAttributes&& other);
  Conv2dAttributes& operator=(Conv2dAttributes&& other);

  ~Conv2dAttributes();

  // Layout of the filter operand; defaults to kOihw.
  Conv2dFilterOperandLayout filter_layout = Conv2dFilterOperandLayout::kOihw;
};
// Options consumed by ValidateConvTranspose2dAndInferOutput: everything in
// Conv2dAttributesBase plus the transposed-convolution specific fields.
// Move-only, like its base.
struct COMPONENT_EXPORT(WEBNN_PUBLIC_CPP) ConvTranspose2dAttributes
    : Conv2dAttributesBase {
  ConvTranspose2dAttributes();
  ~ConvTranspose2dAttributes();

  ConvTranspose2dAttributes(ConvTranspose2dAttributes&& other);
  ConvTranspose2dAttributes& operator=(ConvTranspose2dAttributes&& other);

  ConvTranspose2dAttributes(const ConvTranspose2dAttributes&) = delete;
  ConvTranspose2dAttributes& operator=(const ConvTranspose2dAttributes&) =
      delete;

  // Value-initialized to {0, 0}: Size2d has no initializers of its own, so
  // without this the member could be read while indeterminate.
  Size2d<uint32_t> output_padding{};
  // Explicit output sizes, when supplied.
  std::optional<Size2d<uint32_t>> output_sizes;
  // Layout of the filter operand; defaults to kIohw.
  ConvTranspose2dFilterOperandLayout filter_layout =
      ConvTranspose2dFilterOperandLayout::kIohw;
};
// Options consumed by ValidateGemmAndInferOutput.
// The type is move-only: both copy operations are deleted.
struct COMPONENT_EXPORT(WEBNN_PUBLIC_CPP) GemmAttributes {
  GemmAttributes();

  GemmAttributes(const GemmAttributes&) = delete;
  GemmAttributes& operator=(const GemmAttributes&) = delete;

  GemmAttributes(GemmAttributes&& other);
  GemmAttributes& operator=(GemmAttributes&& other);

  ~GemmAttributes();

  // Descriptor of the optional third operand, when one is supplied.
  std::optional<OperandDescriptor> c_operand;
  // Scalar options; both default to 1.
  float alpha = 1.0f;
  float beta = 1.0f;
  // Transpose flags for the `a` and `b` operands; both default to false.
  bool a_transpose = false;
  bool b_transpose = false;
  // Operator label; empty by default.
  std::string label;
};
// Options consumed by ValidateGruAndInferOutput.
// The type is move-only: both copy operations are deleted.
struct COMPONENT_EXPORT(WEBNN_PUBLIC_CPP) GruAttributes {
  GruAttributes();
  ~GruAttributes();

  GruAttributes(GruAttributes&& other);
  GruAttributes& operator=(GruAttributes&& other);

  GruAttributes(const GruAttributes&) = delete;
  GruAttributes& operator=(const GruAttributes&) = delete;

  // Descriptors of the optional operands, when supplied.
  std::optional<OperandDescriptor> bias;
  std::optional<OperandDescriptor> recurrent_bias;
  std::optional<OperandDescriptor> initial_hidden_state;

  // The scalar options below carry default member initializers so that a
  // freshly constructed object never exposes indeterminate values.
  // NOTE(review): `false` / `kForward` are intended to mirror the WebNN gru
  // option defaults - confirm against the spec.
  bool return_sequence = false;
  RecurrentNetworkDirection direction = RecurrentNetworkDirection::kForward;
  // Starts at 0 ("not set"); callers are expected to assign the real count.
  uint32_t activation_count = 0;
  // Operator label; empty by default.
  std::string label = "";
};
// Options consumed by ValidateGruCellAndInferOutput.
// The type is move-only: both copy operations are deleted.
struct COMPONENT_EXPORT(WEBNN_PUBLIC_CPP) GruCellAttributes {
  GruCellAttributes();
  ~GruCellAttributes();

  GruCellAttributes(GruCellAttributes&& other);
  GruCellAttributes& operator=(GruCellAttributes&& other);

  GruCellAttributes(const GruCellAttributes&) = delete;
  GruCellAttributes& operator=(const GruCellAttributes&) = delete;

  // Descriptors of the optional operands, when supplied.
  std::optional<OperandDescriptor> bias;
  std::optional<OperandDescriptor> recurrent_bias;
  // Initialized to 0 ("not set") so that a freshly constructed object never
  // exposes an indeterminate value; callers are expected to assign the real
  // count.
  uint32_t activation_count = 0;
  // Operator label; empty by default.
  std::string label = "";
};
// Options consumed by ValidateInstanceNormalizationAndInferOutput.
// The type is move-only: both copy operations are deleted.
struct COMPONENT_EXPORT(WEBNN_PUBLIC_CPP) InstanceNormalizationAttributes {
  InstanceNormalizationAttributes();

  InstanceNormalizationAttributes(const InstanceNormalizationAttributes&) =
      delete;
  InstanceNormalizationAttributes& operator=(
      const InstanceNormalizationAttributes&) = delete;

  InstanceNormalizationAttributes(InstanceNormalizationAttributes&& other);
  InstanceNormalizationAttributes& operator=(
      InstanceNormalizationAttributes&& other);

  ~InstanceNormalizationAttributes();

  // Descriptor of the scale operand, when one is supplied.
  std::optional<OperandDescriptor> scale;
  // Descriptor of the bias operand, when one is supplied.
  std::optional<OperandDescriptor> bias;
  // Layout of the input operand; defaults to kNchw.
  InputOperandLayout layout = InputOperandLayout::kNchw;
  // Operator label; empty by default.
  std::string label;
};
// Options consumed by ValidateLayerNormalizationAndInferOutput.
// The type is move-only: both copy operations are deleted.
struct COMPONENT_EXPORT(WEBNN_PUBLIC_CPP) LayerNormalizationAttributes {
  LayerNormalizationAttributes();

  LayerNormalizationAttributes(const LayerNormalizationAttributes&) = delete;
  LayerNormalizationAttributes& operator=(const LayerNormalizationAttributes&) =
      delete;

  LayerNormalizationAttributes(LayerNormalizationAttributes&& other);
  LayerNormalizationAttributes& operator=(LayerNormalizationAttributes&& other);

  ~LayerNormalizationAttributes();

  // Descriptor of the scale operand, when one is supplied.
  std::optional<OperandDescriptor> scale;
  // Descriptor of the bias operand, when one is supplied.
  std::optional<OperandDescriptor> bias;
  // Operator label; empty by default.
  std::string label;
};
// Options consumed by ValidateLstmAndInferOutput.
// The type is move-only: both copy operations are deleted.
struct COMPONENT_EXPORT(WEBNN_PUBLIC_CPP) LstmAttributes {
  LstmAttributes();
  ~LstmAttributes();

  LstmAttributes(LstmAttributes&& other);
  LstmAttributes& operator=(LstmAttributes&& other);

  LstmAttributes(const LstmAttributes&) = delete;
  LstmAttributes& operator=(const LstmAttributes&) = delete;

  // Descriptors of the optional operands, when supplied.
  std::optional<OperandDescriptor> bias;
  std::optional<OperandDescriptor> recurrent_bias;
  std::optional<OperandDescriptor> peephole_weight;
  std::optional<OperandDescriptor> initial_hidden_state;
  std::optional<OperandDescriptor> initial_cell_state;

  // The scalar options below carry default member initializers so that a
  // freshly constructed object never exposes indeterminate values.
  // Starts at 0 ("not set"); callers are expected to assign the real count.
  // NOTE(review): this is size_t while GruAttributes::activation_count is
  // uint32_t - consider unifying the two in a follow-up.
  size_t activation_count = 0;
  // NOTE(review): `false` / `kForward` are intended to mirror the WebNN lstm
  // option defaults - confirm against the spec.
  bool return_sequence = false;
  RecurrentNetworkDirection direction = RecurrentNetworkDirection::kForward;
  // Operator label; empty by default.
  std::string label = "";
};
// Options consumed by ValidateLstmCellAndInferOutput.
// The type is move-only: both copy operations are deleted.
struct COMPONENT_EXPORT(WEBNN_PUBLIC_CPP) LstmCellAttributes {
  LstmCellAttributes();
  ~LstmCellAttributes();

  LstmCellAttributes(LstmCellAttributes&& other);
  LstmCellAttributes& operator=(LstmCellAttributes&& other);

  LstmCellAttributes(const LstmCellAttributes&) = delete;
  LstmCellAttributes& operator=(const LstmCellAttributes&) = delete;

  // Descriptors of the optional operands, when supplied.
  std::optional<OperandDescriptor> bias;
  std::optional<OperandDescriptor> recurrent_bias;
  std::optional<OperandDescriptor> peephole_weight;
  // Initialized to 0 ("not set") so that a freshly constructed object never
  // exposes an indeterminate value; callers are expected to assign the real
  // count.
  size_t activation_count = 0;
  // Operator label; empty by default.
  std::string label = "";
};
// Options consumed by ValidatePool2dAndInferOutput.
// The type is move-only: both copy operations are deleted.
struct COMPONENT_EXPORT(WEBNN_PUBLIC_CPP) Pool2dAttributes {
  Pool2dAttributes();
  ~Pool2dAttributes();

  Pool2dAttributes(Pool2dAttributes&& other);
  Pool2dAttributes& operator=(Pool2dAttributes&& other);

  Pool2dAttributes(const Pool2dAttributes&) = delete;
  Pool2dAttributes& operator=(const Pool2dAttributes&) = delete;

  // Pooling window dimensions, when supplied.
  std::optional<Size2d<uint32_t>> window_dimensions;
  // Padding2d and Size2d have no initializers of their own, so the three
  // members below are given explicit defaults here; otherwise a freshly
  // constructed object could expose indeterminate values.
  // NOTE(review): zero padding and unit strides/dilations are intended to
  // mirror the WebNN pool2d option defaults - confirm against the spec.
  Padding2d padding{};
  Size2d<uint32_t> strides{1, 1};
  Size2d<uint32_t> dilations{1, 1};
  // Layout of the input operand; defaults to kNchw.
  InputOperandLayout layout = InputOperandLayout::kNchw;
  // Rounding option; defaults to kFloor.
  RoundingType rounding_type = RoundingType::kFloor;
  // Explicit output sizes, when supplied.
  std::optional<Size2d<uint32_t>> output_sizes;
  // Operator label; empty by default.
  std::string label = "";
};
// Options consumed by ValidateSliceAndInferOutput.
// The type is move-only: both copy operations are deleted.
struct COMPONENT_EXPORT(WEBNN_PUBLIC_CPP) SliceAttributes {
  SliceAttributes();

  SliceAttributes(const SliceAttributes&) = delete;
  SliceAttributes& operator=(const SliceAttributes&) = delete;

  SliceAttributes(SliceAttributes&& other);
  SliceAttributes& operator=(SliceAttributes&& other);

  ~SliceAttributes();

  // NOTE(review): presumably one entry per input dimension in each of the
  // three vectors - confirm in the slice validator.
  std::vector<uint32_t> starts;
  std::vector<uint32_t> sizes;
  std::vector<uint32_t> strides;
  // Operator label; empty by default.
  std::string label;
};
// Options consumed by ValidateSplitAndInferOutput.
struct COMPONENT_EXPORT(WEBNN_PUBLIC_CPP) SplitAttribute {
  // Either a single uint32_t or a span of uint32_t values.
  // NOTE(review): the span alternative is a non-owning view, so the storage it
  // refers to must outlive this object.
  std::variant<uint32_t, base::span<const uint32_t>> splits;
  // Axis option; defaults to 0.
  uint32_t axis = 0;
  // Operator label; empty by default.
  std::string label;
};
// Calculates a conv2d output size from the given input size, filter size,
// beginning/ending padding, stride and dilation (summary inferred from the
// name and parameters; the definition is not in this header).
// Returns the size as a double, or an error std::string.
// NOTE(review): the result is a double although every numeric input is
// uint32_t - presumably so the caller can apply its own rounding; confirm.
base::expected<double, std::string> COMPONENT_EXPORT(WEBNN_PUBLIC_CPP)
CalculateConv2dOutputSize(uint32_t input_size,
uint32_t filter_size,
uint32_t beginning_padding,
uint32_t ending_padding,
uint32_t stride,
uint32_t dilation,
std::string_view label);
// Validates an argMin/argMax operator on `input` along `axis` and infers its
// output (summary inferred from the name and signature; see the definition).
// `keep_dimensions` defaults to false.
// Returns the inferred OperandDescriptor, or an error std::string.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
WEBNN_PUBLIC_CPP)
ValidateArgMinMaxAndInferOutput(const ContextProperties& context_properties,
const OperandDescriptor& input,
std::string_view label,
uint32_t axis,
OperandDataType output_data_type,
bool keep_dimensions = false);
// Validates a batchNormalization operator given the `input`, `mean` and
// `variance` descriptors plus `attributes`, and infers its output (summary
// inferred from the name and signature; see the definition).
// Returns the inferred OperandDescriptor, or an error std::string.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
WEBNN_PUBLIC_CPP)
ValidateBatchNormalizationAndInferOutput(
const ContextProperties& context_properties,
const OperandDescriptor& input,
const OperandDescriptor& mean,
const OperandDescriptor& variance,
const BatchNormalizationAttributes& attributes);
// Validates a cast operator on `input` with the requested `output_data_type`
// and infers its output (summary inferred from the name and signature; see
// the definition).
// Returns the inferred OperandDescriptor, or an error std::string.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
WEBNN_PUBLIC_CPP)
ValidateCastAndInferOutput(const ContextProperties& context_properties,
const OperandDescriptor& input,
OperandDataType output_data_type,
std::string_view label);
// Validates a concat operator over the descriptors in `input` along `axis`
// and infers its output (summary inferred from the name and signature; see
// the definition).
// Returns the inferred OperandDescriptor, or an error std::string.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
    WEBNN_PUBLIC_CPP)
    ValidateConcatAndInferOutput(const ContextProperties& context_properties,
                                 const std::vector<OperandDescriptor>& input,
                                 uint32_t axis,
                                 std::string_view label);
// Validates a conv2d operator given the `input` and `filter` descriptors plus
// `attributes`, and infers its output (summary inferred from the name and
// signature; see the definition).
// Returns the inferred OperandDescriptor, or an error std::string.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
WEBNN_PUBLIC_CPP)
ValidateConv2dAndInferOutput(const ContextProperties& context_properties,
const OperandDescriptor& input,
const OperandDescriptor& filter,
const Conv2dAttributes& attributes);
// Validates a convTranspose2d operator given the `input` and `filter`
// descriptors plus `attributes`, and infers its output (summary inferred from
// the name and signature; see the definition).
// Returns the inferred OperandDescriptor, or an error std::string.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
WEBNN_PUBLIC_CPP)
ValidateConvTranspose2dAndInferOutput(
const ContextProperties& context_properties,
const OperandDescriptor& input,
const OperandDescriptor& filter,
const ConvTranspose2dAttributes& attributes);
// Validates a cumulativeSum operator on `input` along `axis` and infers its
// output (summary inferred from the name and signature; see the definition).
// Returns the inferred OperandDescriptor, or an error std::string.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
    WEBNN_PUBLIC_CPP)
    ValidateCumulativeSumAndInferOutput(
        const ContextProperties& context_properties,
        const OperandDescriptor& input,
        uint32_t axis,
        std::string_view label);
// Validates a dequantizeLinear operator given the `input`, `scale` and
// `zero_point` descriptors, and infers its output (summary inferred from the
// name and signature; see the definition).
// Returns the inferred OperandDescriptor, or an error std::string.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
WEBNN_PUBLIC_CPP)
ValidateDequantizeLinearAndInferOutput(
const ContextProperties& context_properties,
const OperandDescriptor& input,
const OperandDescriptor& scale,
const OperandDescriptor& zero_point,
std::string_view label);
// Validates a quantizeLinear operator given the `input`, `scale` and
// `zero_point` descriptors, and infers its output (summary inferred from the
// name and signature; see the definition).
// Returns the inferred OperandDescriptor, or an error std::string.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
WEBNN_PUBLIC_CPP)
ValidateQuantizeLinearAndInferOutput(
const ContextProperties& context_properties,
const OperandDescriptor& input,
const OperandDescriptor& scale,
const OperandDescriptor& zero_point,
std::string_view label);
// Validates an expand operator that takes `input` to `new_shape` and infers
// its output (summary inferred from the name and signature; see the
// definition).
// Returns the inferred OperandDescriptor, or an error std::string.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
WEBNN_PUBLIC_CPP)
ValidateExpandAndInferOutput(const ContextProperties& context_properties,
const OperandDescriptor& input,
base::span<const uint32_t> new_shape,
std::string_view label);
// Validates a gather operator given the `input` and `indices` descriptors and
// `axis`, and infers its output (summary inferred from the name and
// signature; see the definition).
// Returns the inferred OperandDescriptor, or an error std::string.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
    WEBNN_PUBLIC_CPP)
    ValidateGatherAndInferOutput(const ContextProperties& context_properties,
                                 const OperandDescriptor& input,
                                 const OperandDescriptor& indices,
                                 uint32_t axis,
                                 std::string_view label);
// Validates a gatherElements operator given the `input` and `indices`
// descriptors and `axis`, and infers its output (summary inferred from the
// name and signature; see the definition).
// Returns the inferred OperandDescriptor, or an error std::string.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
    WEBNN_PUBLIC_CPP)
    ValidateGatherElementsAndInferOutput(
        const ContextProperties& context_properties,
        const OperandDescriptor& input,
        const OperandDescriptor& indices,
        uint32_t axis,
        std::string_view label);
// Validates a gatherND operator given the `input` and `indices` descriptors
// and infers its output (summary inferred from the name and signature; see
// the definition).
// Returns the inferred OperandDescriptor, or an error std::string.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
WEBNN_PUBLIC_CPP)
ValidateGatherNDAndInferOutput(const ContextProperties& context_properties,
const OperandDescriptor& input,
const OperandDescriptor& indices,
std::string_view label);
// Validates a gemm operator given the `a` and `b` descriptors plus
// `attributes`, and infers its output (summary inferred from the name and
// signature; see the definition).
// Returns the inferred OperandDescriptor, or an error std::string.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
WEBNN_PUBLIC_CPP)
ValidateGemmAndInferOutput(const ContextProperties& context_properties,
const OperandDescriptor& a,
const OperandDescriptor& b,
const GemmAttributes& attributes);
// Validates a gru operator given the `input`, `weight` and `recurrent_weight`
// descriptors, `steps`, `hidden_size` and `attributes`, and infers its outputs
// (summary inferred from the name and signature; see the definition).
// Returns a vector of inferred OperandDescriptors, or an error std::string.
base::expected<std::vector<OperandDescriptor>, std::string> COMPONENT_EXPORT(
WEBNN_PUBLIC_CPP)
ValidateGruAndInferOutput(const ContextProperties& context_properties,
const OperandDescriptor& input,
const OperandDescriptor& weight,
const OperandDescriptor& recurrent_weight,
uint32_t steps,
uint32_t hidden_size,
const GruAttributes& attributes);
// Validates a gruCell operator given the `input`, `weight`,
// `recurrent_weight` and `hidden_state` descriptors, `hidden_size` and
// `attributes`, and infers its output (summary inferred from the name and
// signature; see the definition).
// Returns the inferred OperandDescriptor, or an error std::string.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
WEBNN_PUBLIC_CPP)
ValidateGruCellAndInferOutput(const ContextProperties& context_properties,
const OperandDescriptor& input,
const OperandDescriptor& weight,
const OperandDescriptor& recurrent_weight,
const OperandDescriptor& hidden_state,
uint32_t hidden_size,
const GruCellAttributes& attributes);
// Validates an instanceNormalization operator on `input` with `attributes`
// and infers its output (summary inferred from the name and signature; see
// the definition).
// Returns the inferred OperandDescriptor, or an error std::string.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
WEBNN_PUBLIC_CPP)
ValidateInstanceNormalizationAndInferOutput(
const ContextProperties& context_properties,
const OperandDescriptor& input,
const InstanceNormalizationAttributes& attributes);
// Validates a layerNormalization operator on `input` over `axes` with
// `attributes` and infers its output (summary inferred from the name and
// signature; see the definition).
// Returns the inferred OperandDescriptor, or an error std::string.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
WEBNN_PUBLIC_CPP)
ValidateLayerNormalizationAndInferOutput(
const ContextProperties& context_properties,
const OperandDescriptor& input,
base::span<const uint32_t> axes,
const LayerNormalizationAttributes& attributes);
// Validates an lstm operator given the `input`, `weight` and
// `recurrent_weight` descriptors, `steps`, `hidden_size` and `attributes`, and
// infers its outputs (summary inferred from the name and signature; see the
// definition).
// Returns a vector of inferred OperandDescriptors, or an error std::string.
base::expected<std::vector<OperandDescriptor>, std::string> COMPONENT_EXPORT(
    WEBNN_PUBLIC_CPP)
    ValidateLstmAndInferOutput(const ContextProperties& context_properties,
                               const OperandDescriptor& input,
                               const OperandDescriptor& weight,
                               const OperandDescriptor& recurrent_weight,
                               uint32_t steps,
                               uint32_t hidden_size,
                               const LstmAttributes& attributes);
// Validates an lstmCell operator given the `input`, `weight`,
// `recurrent_weight`, `hidden_state` and `cell_state` descriptors,
// `hidden_size` and `attributes`, and infers its outputs (summary inferred
// from the name and signature; see the definition).
// Returns a vector of inferred OperandDescriptors, or an error std::string.
base::expected<std::vector<OperandDescriptor>, std::string> COMPONENT_EXPORT(
    WEBNN_PUBLIC_CPP)
    ValidateLstmCellAndInferOutput(const ContextProperties& context_properties,
                                   const OperandDescriptor& input,
                                   const OperandDescriptor& weight,
                                   const OperandDescriptor& recurrent_weight,
                                   const OperandDescriptor& hidden_state,
                                   const OperandDescriptor& cell_state,
                                   uint32_t hidden_size,
                                   const LstmCellAttributes& attributes);
// Validates a matmul operator given the `a` and `b` descriptors and infers
// its output (summary inferred from the name and signature; see the
// definition).
// Returns the inferred OperandDescriptor, or an error std::string.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
WEBNN_PUBLIC_CPP)
ValidateMatmulAndInferOutput(const ContextProperties& context_properties,
const OperandDescriptor& a,
const OperandDescriptor& b,
std::string_view label);
// Validates a pad operator on `input` with the given `beginning_padding`,
// `ending_padding` and `mode`, and infers its output (summary inferred from
// the name and signature; see the definition).
// Returns the inferred OperandDescriptor, or an error std::string.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
WEBNN_PUBLIC_CPP)
ValidatePadAndInferOutput(const ContextProperties& context_properties,
const OperandDescriptor& input,
base::span<const uint32_t> beginning_padding,
base::span<const uint32_t> ending_padding,
PaddingMode mode,
std::string_view label);
// Validates a pool2d operator of the given `kind` on `input` with
// `attributes` and infers its output (summary inferred from the name and
// signature; see the definition).
// Returns the inferred OperandDescriptor, or an error std::string.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
WEBNN_PUBLIC_CPP)
ValidatePool2dAndInferOutput(const ContextProperties& context_properties,
const OperandDescriptor& input,
const Pool2dAttributes& attributes,
Pool2dKind kind);
// Validates a prelu operator given the `input` and `slope` descriptors and
// infers its output (summary inferred from the name and signature; see the
// definition).
// Returns the inferred OperandDescriptor, or an error std::string.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
WEBNN_PUBLIC_CPP)
ValidatePreluAndInferOutput(const ContextProperties& context_properties,
const OperandDescriptor& input,
const OperandDescriptor& slope,
std::string_view label);
// Validates a reduction operator of the given `kind` on `input` over `axes`
// and infers its output (summary inferred from the name and signature; see
// the definition). `keepDimensions` defaults to false.
// Returns the inferred OperandDescriptor, or an error std::string.
// NOTE(review): `keepDimensions` is camelCase while
// ValidateArgMinMaxAndInferOutput spells the same option `keep_dimensions`.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
WEBNN_PUBLIC_CPP)
ValidateReduceAndInferOutput(const ContextProperties& context_properties,
ReduceKind kind,
const OperandDescriptor& input,
std::string_view label,
base::span<const uint32_t> axes,
bool keepDimensions = false);
// Validates a resample2d operator on `input` over `axes` and infers its
// output (summary inferred from the name and signature; see the definition).
// `scales_or_sizes` holds either a span of float values or a span of uint32_t
// values.
// Returns the inferred OperandDescriptor, or an error std::string.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
WEBNN_PUBLIC_CPP)
ValidateResample2dAndInferOutput(
const ContextProperties& context_properties,
const OperandDescriptor& input,
const std::variant<base::span<const float>, base::span<const uint32_t>>&
scales_or_sizes,
base::span<const uint32_t> axes,
std::string_view label);
// Validates a reverse operator on `input` over `axes` and infers its output
// (summary inferred from the name and signature; see the definition).
// Returns the inferred OperandDescriptor, or an error std::string.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
WEBNN_PUBLIC_CPP)
ValidateReverseAndInferOutput(const ContextProperties& context_properties,
const OperandDescriptor& input,
base::span<const uint32_t> axes,
std::string_view label);
// Validates a scatterElements operator given the `input`, `indices` and
// `updates` descriptors and `axis`, and infers its output (summary inferred
// from the name and signature; see the definition).
// Returns the inferred OperandDescriptor, or an error std::string.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
WEBNN_PUBLIC_CPP)
ValidateScatterElementsAndInferOutput(
const ContextProperties& context_properties,
const OperandDescriptor& input,
const OperandDescriptor& indices,
const OperandDescriptor& updates,
uint32_t axis,
std::string_view label);
// Validates a scatterND operator given the `input`, `indices` and `updates`
// descriptors and infers its output (summary inferred from the name and
// signature; see the definition).
// Returns the inferred OperandDescriptor, or an error std::string.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
WEBNN_PUBLIC_CPP)
ValidateScatterNDAndInferOutput(const ContextProperties& context_properties,
const OperandDescriptor& input,
const OperandDescriptor& indices,
const OperandDescriptor& updates,
std::string_view label);
// Validates a slice operator on `input` with `attributes` and infers its
// output (summary inferred from the name and signature; see the definition).
// Returns the inferred OperandDescriptor, or an error std::string.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
WEBNN_PUBLIC_CPP)
ValidateSliceAndInferOutput(const ContextProperties& context_properties,
const OperandDescriptor& input,
const SliceAttributes& attributes);
// Validates a softmax operator on `input` along `axis` and infers its output
// (summary inferred from the name and signature; see the definition).
// Returns the inferred OperandDescriptor, or an error std::string.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
WEBNN_PUBLIC_CPP)
ValidateSoftmaxAndInferOutput(const ContextProperties& context_properties,
const OperandDescriptor& input,
uint32_t axis,
std::string_view label);
// Validates a split operator on `input` with `attributes` and infers its
// outputs (summary inferred from the name and signature; see the definition).
// Returns a vector of inferred OperandDescriptors, or an error std::string.
base::expected<std::vector<OperandDescriptor>, std::string> COMPONENT_EXPORT(
WEBNN_PUBLIC_CPP)
ValidateSplitAndInferOutput(const ContextProperties& context_properties,
const OperandDescriptor& input,
const SplitAttribute& attributes);
// Validates a tile operator on `input` with the given `repetitions` and
// infers its output (summary inferred from the name and signature; see the
// definition).
// Returns the inferred OperandDescriptor, or an error std::string.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
WEBNN_PUBLIC_CPP)
ValidateTileAndInferOutput(const ContextProperties& context_properties,
const OperandDescriptor& input,
base::span<const uint32_t> repetitions,
std::string_view label);
// Validates a transpose operator on `input` with the given `permutation` and
// infers its output (summary inferred from the name and signature; see the
// definition).
// Returns the inferred OperandDescriptor, or an error std::string.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
WEBNN_PUBLIC_CPP)
ValidateTransposeAndInferOutput(const ContextProperties& context_properties,
const OperandDescriptor& input,
base::span<const uint32_t> permutation,
std::string_view label);
// Validates a triangular operator on `input` and infers its output (summary
// inferred from the name and signature; see the definition).
// Returns the inferred OperandDescriptor, or an error std::string.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
WEBNN_PUBLIC_CPP)
ValidateTriangularAndInferOutput(
const ContextProperties& context_properties,
const OperandDescriptor& input,
std::string_view label);
// Validates a where operator given the `condition`, `true_value` and
// `false_value` descriptors and infers its output (summary inferred from the
// name and signature; see the definition).
// Returns the inferred OperandDescriptor, or an error std::string.
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
WEBNN_PUBLIC_CPP)
ValidateWhereAndInferOutput(const ContextProperties& context_properties,
const OperandDescriptor& condition,
const OperandDescriptor& true_value,
const OperandDescriptor& false_value,
std::string_view label);
// Validates `axes` for an operand of the given `rank`.
// NOTE(review): the exact rules (e.g. range and uniqueness checks) live in
// the definition, which is not visible in this header.
// Returns an empty expected on success, or an error std::string.
base::expected<void, std::string> COMPONENT_EXPORT(WEBNN_PUBLIC_CPP)
ValidateAxes(base::span<const uint32_t> axes,
uint32_t rank,
std::string_view label);
// Validates `descriptor` against `context_properties` (summary inferred from
// the name and signature; see the definition).
// Returns an empty expected on success, or an error std::string.
// NOTE(review): `descriptor` is taken by value, whereas the other validators
// in this header take OperandDescriptor by const reference.
base::expected<void, std::string> COMPONENT_EXPORT(WEBNN_PUBLIC_CPP)
ValidateTensor(const ContextProperties& context_properties,
OperandDescriptor descriptor);
// Broadcasts the shapes `dims_lhs` and `dims_rhs` (summary inferred from the
// name and signature; see the definition). `bidirectional` defaults to true.
// Returns a vector of uint32_t dimensions, or std::nullopt - presumably when
// the two shapes cannot be broadcast; confirm in the definition.
std::optional<std::vector<uint32_t>> COMPONENT_EXPORT(WEBNN_PUBLIC_CPP)
BroadcastShapes(base::span<const uint32_t> dims_lhs,
base::span<const uint32_t> dims_rhs,
bool bidirectional = true);
// Calculates a convTranspose2d output size from the given input size, filter
// size, beginning/ending padding, stride, dilation and output padding (summary
// inferred from the name and parameters; the definition is not in this
// header).
// Returns the size as a uint32_t, or an error std::string.
base::expected<uint32_t, std::string> COMPONENT_EXPORT(WEBNN_PUBLIC_CPP)
    CalculateConvTranspose2dOutputSize(uint32_t input_size,
                                       uint32_t filter_size,
                                       uint32_t beginning_padding,
                                       uint32_t ending_padding,
                                       uint32_t stride,
                                       uint32_t dilation,
                                       uint32_t output_padding);
// Reports whether a conv2d with the given `input_channels`,
// `output_channels` and `groups` is a depthwise convolution.
// NOTE(review): the exact predicate lives in the definition, which is not
// visible in this header.
bool COMPONENT_EXPORT(WEBNN_PUBLIC_CPP)
IsDepthwiseConv2d(uint32_t input_channels,
uint32_t output_channels,
uint32_t groups);
}
#endif