#include "services/webnn/webnn_graph_impl.h"
#include <cmath>
#include <limits>
#include <memory>
#include "base/containers/contains.h"
#include "base/memory/weak_ptr.h"
#include "base/notreached.h"
#include "base/strings/strcat.h"
#include "base/strings/string_number_conversions.h"
#include "base/strings/stringprintf.h"
#include "base/test/bind.h"
#include "base/test/run_until.h"
#include "base/test/scoped_feature_list.h"
#include "base/test/task_environment.h"
#include "base/test/test_future.h"
#include "mojo/public/cpp/bindings/associated_remote.h"
#include "mojo/public/cpp/bindings/remote.h"
#include "mojo/public/cpp/bindings/self_owned_associated_receiver.h"
#include "mojo/public/cpp/system/functions.h"
#include "services/webnn/error.h"
#include "services/webnn/public/cpp/ml_tensor_usage.h"
#include "services/webnn/public/cpp/operand_descriptor.h"
#include "services/webnn/public/cpp/supported_data_types.h"
#include "services/webnn/public/cpp/webnn_errors.h"
#include "services/webnn/public/cpp/webnn_types.h"
#include "services/webnn/public/mojom/features.mojom-features.h"
#include "services/webnn/public/mojom/webnn_context_provider.mojom.h"
#include "services/webnn/public/mojom/webnn_device.mojom-data-view.h"
#include "services/webnn/public/mojom/webnn_graph.mojom.h"
#include "services/webnn/public/mojom/webnn_graph_builder.mojom.h"
#include "services/webnn/public/mojom/webnn_tensor.mojom.h"
#include "services/webnn/scoped_sequence.h"
#include "services/webnn/webnn_constant_operand.h"
#include "services/webnn/webnn_context_impl.h"
#include "services/webnn/webnn_context_provider_impl.h"
#include "services/webnn/webnn_tensor_impl.h"
#include "services/webnn/webnn_test_environment.h"
#include "services/webnn/webnn_test_utils.h"
#include "services/webnn/webnn_utils.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace webnn {
namespace {
// Test double for WebNNGraphImpl. It holds no compiled graph and performs no
// computation: DispatchImpl() is empty.
class FakeWebNNGraphImpl final : public WebNNGraphImpl {
 public:
  FakeWebNNGraphImpl(
      mojo::PendingAssociatedReceiver<mojom::WebNNGraph> receiver,
      base::WeakPtr<WebNNContextImpl> context,
      ComputeResourceInfo compute_resource_info)
      : WebNNGraphImpl(std::move(receiver),
                       std::move(context),
                       std::move(compute_resource_info),
                       {}) {}

  // Constructs a fake graph and passes it to `callback` synchronously.
  // `graph_info` is not inspected.
  static void CreateAndBuild(
      mojo::PendingAssociatedReceiver<mojom::WebNNGraph> receiver,
      base::WeakPtr<WebNNContextImpl> context,
      const mojom::GraphInfo& graph_info,
      ComputeResourceInfo compute_resource_info,
      WebNNContextImpl::CreateGraphImplCallback callback) {
    auto fake_graph = base::MakeRefCounted<FakeWebNNGraphImpl>(
        std::move(receiver), std::move(context),
        std::move(compute_resource_info));
    std::move(callback).Run(std::move(fake_graph));
  }

 private:
  ~FakeWebNNGraphImpl() override = default;

  // Intentionally empty: the fake never executes the graph.
  void DispatchImpl(
      base::flat_map<std::string, scoped_refptr<WebNNTensorImpl>> named_inputs,
      base::flat_map<std::string, scoped_refptr<WebNNTensorImpl>> named_outputs)
      override {}
};
// Test double for WebNNTensorImpl. Every data-path hook is inert and
// importing always reports failure.
class FakeWebNNTensorImpl final : public WebNNTensorImpl {
 public:
  FakeWebNNTensorImpl(
      mojo::PendingAssociatedReceiver<mojom::WebNNTensor> receiver,
      base::WeakPtr<WebNNContextImpl> context,
      mojom::TensorInfoPtr tensor_info)
      : WebNNTensorImpl(std::move(receiver),
                        std::move(context),
                        std::move(tensor_info)) {}

 private:
  ~FakeWebNNTensorImpl() override = default;

  // Drops `callback` without running it.
  void ReadTensorImpl(ReadTensorCallback /*callback*/) override {}

  // Discards the incoming buffer.
  void WriteTensorImpl(mojo_base::BigBuffer /*src_buffer*/) override {}

  // Import never succeeds for the fake.
  bool ImportTensorImpl() override { return false; }

  // Drops both `access` and `callback` without running the callback.
  void ExportTensorImpl(ScopedAccessPtr /*access*/,
                        ExportTensorCallback /*callback*/) override {}
};
class FakeWebNNContextImpl final : public WebNNContextImpl {
public:
FakeWebNNContextImpl(
mojo::PendingReceiver<mojom::WebNNContext> receiver,
base::WeakPtr<WebNNContextProviderImpl> context_provider,
gpu::CommandBufferId command_buffer_id,
std::unique_ptr<ScopedSequence> sequence,
scoped_refptr<gpu::MemoryTracker> memory_tracker,
scoped_refptr<base::SingleThreadTaskRunner> owning_task_runner,
gpu::SharedImageManager* shared_image_manager,
scoped_refptr<base::SingleThreadTaskRunner> main_task_runner)
: WebNNContextImpl(std::move(receiver),
std::move(context_provider),
GetContextPropertiesForTesting(),
mojom::CreateContextOptions::New(),
mojo::ScopedDataPipeConsumerHandle(),
mojo::ScopedDataPipeProducerHandle(),
command_buffer_id,
std::move(sequence),
std::move(memory_tracker),
std::move(owning_task_runner),
shared_image_manager,
std::move(main_task_runner)) {}
base::WeakPtr<WebNNContextImpl> AsWeakPtr() override {
DCHECK_CALLED_ON_VALID_SEQUENCE(gpu_sequence_checker_);
return weak_factory_.GetWeakPtr();
}
private:
~FakeWebNNContextImpl() override = default;
void CreateGraphImpl(
mojo::PendingAssociatedReceiver<mojom::WebNNGraph> receiver,
mojom::GraphInfoPtr graph_info,
WebNNGraphImpl::ComputeResourceInfo compute_resource_info,
base::flat_map<
OperandId,
std::unique_ptr<WebNNConstantOperand>> ,
base::flat_map<OperandId, WebNNTensorImpl*> ,
CreateGraphImplCallback callback) override {
FakeWebNNGraphImpl::CreateAndBuild(
std::move(receiver), AsWeakPtr(), *graph_info,
std::move(compute_resource_info), std::move(callback));
}
base::expected<scoped_refptr<WebNNTensorImpl>, mojom::ErrorPtr>
CreateTensorImpl(mojo::PendingAssociatedReceiver<mojom::WebNNTensor> receiver,
mojom::TensorInfoPtr tensor_info) override {
return base::MakeRefCounted<FakeWebNNTensorImpl>(
std::move(receiver), AsWeakPtr(), std::move(tensor_info));
}
base::expected<scoped_refptr<WebNNTensorImpl>, mojom::ErrorPtr>
CreateTensorFromSharedImageImpl(
mojo::PendingAssociatedReceiver<mojom::WebNNTensor> receiver,
mojom::TensorInfoPtr tensor_info,
WebNNTensorImpl::RepresentationPtr representation) override {
return base::unexpected(mojom::Error::New(
mojom::Error::Code::kNotSupportedError, "Not implemented"));
}
base::WeakPtrFactory<FakeWebNNContextImpl> weak_factory_{this};
};
// Backend installed through WebNNContextProviderImpl::SetBackendForTesting()
// so that each context request produces a FakeWebNNContextImpl and reports
// success to the caller immediately.
class FakeWebNNBackend : public WebNNContextProviderImpl::BackendForTesting {
 public:
  std::unique_ptr<WebNNContextImpl, OnTaskRunnerDeleter> CreateWebNNContext(
      base::WeakPtr<WebNNContextProviderImpl> context_provider_impl,
      mojom::CreateContextOptionsPtr /*options*/,
      gpu::CommandBufferId command_buffer_id,
      std::unique_ptr<ScopedSequence> sequence,
      scoped_refptr<gpu::MemoryTracker> memory_tracker,
      scoped_refptr<base::SingleThreadTaskRunner> owning_task_runner,
      gpu::SharedImageManager* shared_image_manager,
      scoped_refptr<base::SingleThreadTaskRunner> main_task_runner,
      mojom::WebNNContextProvider::CreateWebNNContextCallback callback)
      override {
    // The deleter takes its own reference to the owning task runner before
    // `owning_task_runner` is moved into the context below.
    OnTaskRunnerDeleter deleter(owning_task_runner);

    mojo::PendingRemote<mojom::WebNNContext> context_remote;
    std::unique_ptr<WebNNContextImpl, OnTaskRunnerDeleter> context_impl(
        new FakeWebNNContextImpl(
            context_remote.InitWithNewPipeAndPassReceiver(),
            std::move(context_provider_impl), command_buffer_id,
            std::move(sequence), std::move(memory_tracker),
            std::move(owning_task_runner), shared_image_manager,
            std::move(main_task_runner)),
        std::move(deleter));

    auto success = mojom::CreateContextSuccess::New(
        std::move(context_remote), context_impl->properties(),
        context_impl->handle(), mojo::ScopedDataPipeProducerHandle(),
        mojo::ScopedDataPipeConsumerHandle());
    std::move(callback).Run(
        mojom::CreateContextResult::NewSuccess(std::move(success)));
    return context_impl;
  }
};
// Result of CreateWebNNTensor(): a bound remote to the new tensor plus the
// token used to refer to that tensor, e.g. in WebNNGraph::Dispatch() calls
// made by ValidateDispatch().
struct CreateTensorSuccess {
// Populated by CreateWebNNTensor(). NOTE(review): wrapped in std::optional,
// presumably so callers can reset/move the remote out — confirm.
std::optional<mojo::AssociatedRemote<mojom::WebNNTensor>> webnn_tensor;
// Token identifying the tensor; this is what Dispatch() receives.
blink::WebNNTensorToken webnn_handle;
};
// Synchronously creates a tensor with the given `data_type` and `shape` on
// `webnn_context`, using default MLTensorUsage and an empty initial buffer.
// Assumes creation succeeds: get_success() is accessed unconditionally.
CreateTensorSuccess CreateWebNNTensor(
    mojo::Remote<mojom::WebNNContext>& webnn_context,
    OperandDataType data_type,
    std::vector<uint32_t> shape) {
  auto tensor_info = mojom::TensorInfo::New(
      OperandDescriptor::UnsafeCreateForTesting(data_type, shape),
      MLTensorUsage());

  base::test::TestFuture<mojom::CreateTensorResultPtr> future;
  webnn_context->CreateTensor(std::move(tensor_info), mojo_base::BigBuffer(0),
                              future.GetCallback());
  mojom::CreateTensorResultPtr result = future.Take();

  auto& success = result->get_success();
  mojo::AssociatedRemote<mojom::WebNNTensor> tensor_remote(
      std::move(success->tensor_remote));
  return CreateTensorSuccess{std::move(tensor_remote),
                             std::move(success->tensor_handle)};
}
// Synchronously requests a context with default options from
// `webnn_context_provider` and returns a remote bound to it. Assumes the
// request succeeds: get_success() is accessed unconditionally.
mojo::Remote<mojom::WebNNContext> CreateWebNNContext(
    mojo::Remote<mojom::WebNNContextProvider>& webnn_context_provider) {
  base::test::TestFuture<mojom::CreateContextResultPtr> future;
  webnn_context_provider->CreateWebNNContext(mojom::CreateContextOptions::New(),
                                             future.GetCallback());
  mojom::CreateContextResultPtr result = future.Take();
  return mojo::Remote<mojom::WebNNContext>(
      std::move(result->get_success()->context_remote));
}
// Builds `graph_info` on `webnn_context`, then dispatches it with the tensors
// in `inputs` / `outputs`. Returns false if the default process error handler
// fired during dispatch; any such error must be kBadMessageInvalidTensor.
// Assumes graph creation succeeds (the result's value() is used directly).
bool ValidateDispatch(
    mojo::Remote<mojom::WebNNContext>& webnn_context,
    mojom::GraphInfoPtr graph_info,
    base::flat_map<std::string, CreateTensorSuccess> inputs,
    base::flat_map<std::string, CreateTensorSuccess> outputs) {
  // Create the graph through a fresh graph builder.
  mojo::AssociatedRemote<mojom::WebNNGraphBuilder> graph_builder_remote;
  webnn_context->CreateGraphBuilder(
      graph_builder_remote.BindNewEndpointAndPassReceiver());
  base::test::TestFuture<
      base::expected<mojom::CreateGraphSuccessPtr, mojom::ErrorPtr>>
      create_graph_future;
  graph_builder_remote->CreateGraph(std::move(graph_info),
                                    create_graph_future.GetCallback());
  auto create_graph_result = create_graph_future.Take();
  mojo::AssociatedRemote<mojom::WebNNGraph> webnn_graph(
      std::move(create_graph_result.value()->graph_remote));

  // Record whether a bad message is reported while dispatching.
  bool valid = true;
  mojo::SetDefaultProcessErrorHandler(
      base::BindLambdaForTesting([&valid](const std::string& error_message) {
        EXPECT_EQ(error_message, kBadMessageInvalidTensor);
        valid = false;
      }));

  // Dispatch refers to tensors by token only.
  auto to_tokens =
      [](const base::flat_map<std::string, CreateTensorSuccess>& tensors) {
        base::flat_map<std::string, blink::WebNNTensorToken> tokens;
        for (const auto& [name, tensor] : tensors) {
          tokens.emplace(name, tensor.webnn_handle);
        }
        return tokens;
      };
  base::flat_map<std::string, blink::WebNNTensorToken> dispatch_inputs =
      to_tokens(inputs);
  base::flat_map<std::string, blink::WebNNTensorToken> dispatch_outputs =
      to_tokens(outputs);

  webnn_context.FlushForTesting();
  webnn_graph->Dispatch(dispatch_inputs, dispatch_outputs);
  webnn_graph.FlushForTesting();

  // Uninstall the handler so nothing references `valid` after returning.
  mojo::SetDefaultProcessErrorHandler(base::NullCallback());
  return valid;
}
// Operand data types iterated over by tests in this file. Declared constexpr
// because it is a `k`-named constant; it previously was a mutable global.
// NOTE(review): despite the name, confirm whether this is meant to list every
// OperandDataType value or only this subset.
constexpr OperandDataType kAllOperandDataTypes[] = {
    OperandDataType::kFloat32, OperandDataType::kFloat16,
    OperandDataType::kInt32,   OperandDataType::kInt8,
    OperandDataType::kUint8,
};
}
class WebNNGraphImplTest : public testing::Test {
public:
WebNNGraphImplTest(const WebNNGraphImplTest&) = delete;
WebNNGraphImplTest& operator=(const WebNNGraphImplTest&) = delete;
void SetUp() override {
WebNNContextProviderImpl::SetBackendForTesting(&backend_for_testing_);
webnn_test_environment_.BindWebNNContextProvider(
provider_remote_.BindNewPipeAndPassReceiver());
base::test::TestFuture<mojom::CreateContextResultPtr> create_context_future;
provider_remote_->CreateWebNNContext(mojom::CreateContextOptions::New(),
create_context_future.GetCallback());
mojom::CreateContextResultPtr create_context_result =
create_context_future.Take();
webnn_context_.Bind(
std::move(create_context_result->get_success()->context_remote));
}
void TearDown() override {
WebNNContextProviderImpl::SetBackendForTesting(nullptr);
}
mojo::AssociatedRemote<mojom::WebNNGraphBuilder> BindNewGraphBuilderRemote() {
mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote;
webnn_context_->CreateGraphBuilder(remote.BindNewEndpointAndPassReceiver());
return remote;
}
protected:
WebNNGraphImplTest()
: scoped_feature_list_(
webnn::mojom::features::kWebMachineLearningNeuralNetwork) {}
~WebNNGraphImplTest() override = default;
private:
base::test::ScopedFeatureList scoped_feature_list_;
base::test::TaskEnvironment task_environment_;
FakeWebNNBackend backend_for_testing_;
test::WebNNTestEnvironment webnn_test_environment_;
mojo::Remote<mojom::WebNNContextProvider> provider_remote_;
mojo::Remote<mojom::WebNNContext> webnn_context_;
};
// Data type and dimensions of one operand in a test graph. Testers pass these
// straight to GraphInfoBuilder::BuildInput() / BuildOutput().
struct OperandInfo {
OperandDataType type;
std::vector<uint32_t> dimensions;
};
struct ArgMinMaxTester {
mojom::ArgMinMax::Kind kind;
OperandInfo input;
uint32_t axis;
bool keep_dimensions = false;
OperandInfo output;
bool expected;
void Test(WebNNGraphImplTest& test) {
auto context_properties = GetContextPropertiesForTesting();
mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
test.BindNewGraphBuilderRemote();
GraphInfoBuilder builder(remote);
OperandId input_operand_id =
builder.BuildInput("input", input.dimensions, input.type);
OperandId output_operand_id =
builder.BuildOutput("output", output.dimensions, output.type);
builder.BuildArgMinMax(kind, input_operand_id, output_operand_id, axis,
keep_dimensions);
EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
}
};
TEST_F(WebNNGraphImplTest, ArgMinMaxTest) {
  const mojom::ArgMinMax::Kind kKinds[] = {mojom::ArgMinMax::Kind::kMin,
                                           mojom::ArgMinMax::Kind::kMax};
  for (const mojom::ArgMinMax::Kind kind : kKinds) {
    // Valid: reduce axis 0 of a float32 input, keeping it as size 1.
    ArgMinMaxTester{.kind = kind,
                    .input = {.type = OperandDataType::kFloat32,
                              .dimensions = {2, 3, 4, 5}},
                    .axis = 0,
                    .keep_dimensions = true,
                    .output = {.type = OperandDataType::kInt32,
                               .dimensions = {1, 3, 4, 5}},
                    .expected = true}
        .Test(*this);

    // Valid: reduce axis 1 of a float16 input and drop it.
    ArgMinMaxTester{
        .kind = kind,
        .input = {.type = OperandDataType::kFloat16,
                  .dimensions = {2, 3, 4, 5}},
        .axis = 1,
        .keep_dimensions = false,
        .output = {.type = OperandDataType::kInt32, .dimensions = {2, 4, 5}},
        .expected = true}
        .Test(*this);

    // Invalid: axis 4 is out of range for a rank-4 input.
    ArgMinMaxTester{.kind = kind,
                    .input = {.type = OperandDataType::kFloat32,
                              .dimensions = {2, 3, 4, 5}},
                    .axis = 4,
                    .keep_dimensions = true,
                    .output = {.type = OperandDataType::kInt32,
                               .dimensions = {2, 3, 4, 1}},
                    .expected = false}
        .Test(*this);

    // Invalid: a float32 output data type is rejected.
    ArgMinMaxTester{.kind = kind,
                    .input = {.type = OperandDataType::kFloat32,
                              .dimensions = {2, 3, 4, 5}},
                    .axis = 0,
                    .keep_dimensions = true,
                    .output = {.type = OperandDataType::kFloat32,
                               .dimensions = {1, 3, 4, 5}},
                    .expected = false}
        .Test(*this);

    // Invalid: keep_dimensions is false but the output keeps the reduced axis.
    ArgMinMaxTester{.kind = kind,
                    .input = {.type = OperandDataType::kFloat32,
                              .dimensions = {2, 3, 4, 5}},
                    .axis = 0,
                    .keep_dimensions = false,
                    .output = {.type = OperandDataType::kInt32,
                               .dimensions = {1, 3, 4, 5}},
                    .expected = false}
        .Test(*this);

    // Invalid: the output operand is the input operand itself.
    {
      auto context_properties = GetContextPropertiesForTesting();
      mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote =
          BindNewGraphBuilderRemote();
      GraphInfoBuilder graph(builder_remote);
      const OperandId input_id =
          graph.BuildInput("input", {2, 3, 4, 5}, OperandDataType::kInt32);
      graph.BuildArgMinMax(kind, input_id, input_id, /*axis=*/0,
                           /*keep_dimensions=*/true);
      EXPECT_FALSE(graph.IsValidGraphForTesting(context_properties));
    }
  }
}
struct ClampTester {
OperandInfo input;
struct ClampAttributes {
float min_value;
float max_value;
};
ClampAttributes attributes;
OperandInfo output;
bool expected;
void Test(WebNNGraphImplTest& test) {
auto context_properties = GetContextPropertiesForTesting();
mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
test.BindNewGraphBuilderRemote();
GraphInfoBuilder builder(remote);
OperandId input_operand_id =
builder.BuildInput("input", input.dimensions, input.type);
OperandId output_operand_id =
builder.BuildOutput("output", output.dimensions, output.type);
builder.BuildClamp(input_operand_id, output_operand_id,
attributes.min_value, attributes.max_value);
EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
}
};
TEST_F(WebNNGraphImplTest, ClampTest) {
  // Infinite bounds come from std::numeric_limits. The old `-1.0 / 0.0`
  // divided by zero, which is undefined behavior per [expr.mul] and is not a
  // constant expression (hence the static_cast that had been needed).
  constexpr float kInfinity = std::numeric_limits<float>::infinity();
  {
    // Valid: finite bounds on an int8 input.
    ClampTester{
        .input = {.type = OperandDataType::kInt8, .dimensions = {3, 4}},
        .attributes = {.min_value = 0.0, .max_value = 6.0},
        .output = {.type = OperandDataType::kInt8, .dimensions = {3, 4}},
        .expected = true}
        .Test(*this);
  }
  {
    // Valid: -infinity as the lower bound.
    ClampTester{
        .input = {.type = OperandDataType::kInt32, .dimensions = {2, 3, 4}},
        .attributes = {.min_value = -kInfinity, .max_value = 3.0},
        .output = {.type = OperandDataType::kInt32, .dimensions = {2, 3, 4}},
        .expected = true}
        .Test(*this);
  }
  {
    // Valid: +infinity as the upper bound.
    ClampTester{
        .input = {.type = OperandDataType::kInt32, .dimensions = {2, 3, 4}},
        .attributes = {.min_value = 0.0, .max_value = kInfinity},
        .output = {.type = OperandDataType::kInt32, .dimensions = {2, 3, 4}},
        .expected = true}
        .Test(*this);
  }
  {
    // Valid: omitted attributes value-initialize to min == max == 0.
    ClampTester{.input = {.type = OperandDataType::kFloat32,
                          .dimensions = {1, 2, 2, 7}},
                .output = {.type = OperandDataType::kFloat32,
                           .dimensions = {1, 2, 2, 7}},
                .expected = true}
        .Test(*this);
  }
  {
    // Invalid: NaN lower bound.
    ClampTester{
        .input = {.type = OperandDataType::kInt32, .dimensions = {2, 3, 4}},
        .attributes = {.min_value = NAN, .max_value = 3.0},
        .output = {.type = OperandDataType::kInt32, .dimensions = {2, 3, 4}},
        .expected = false}
        .Test(*this);
  }
  {
    // Invalid: NaN upper bound.
    ClampTester{
        .input = {.type = OperandDataType::kInt32, .dimensions = {2, 3, 4}},
        .attributes = {.min_value = -3.0, .max_value = NAN},
        .output = {.type = OperandDataType::kInt32, .dimensions = {2, 3, 4}},
        .expected = false}
        .Test(*this);
  }
  {
    // Invalid: min_value greater than max_value.
    ClampTester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {4, 2}},
        .attributes = {.min_value = 7.0, .max_value = 3.0},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {4, 2}},
        .expected = false}
        .Test(*this);
  }
  {
    // Invalid: output shape differs from the input shape.
    ClampTester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {4, 2}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {2}},
        .expected = false}
        .Test(*this);
  }
  {
    // Invalid: output data type differs from the input data type.
    ClampTester{.input = {.type = OperandDataType::kFloat32, .dimensions = {2}},
                .output = {.type = OperandDataType::kInt32, .dimensions = {2}},
                .expected = false}
        .Test(*this);
  }
}
struct HardSigmoidTester {
OperandInfo input;
std::optional<float> alpha;
std::optional<float> beta;
OperandInfo output;
bool expected;
void Test(WebNNGraphImplTest& test) {
auto context_properties = GetContextPropertiesForTesting();
mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
test.BindNewGraphBuilderRemote();
GraphInfoBuilder builder(remote);
OperandId input_operand_id =
builder.BuildInput("input", input.dimensions, input.type);
OperandId output_operand_id =
builder.BuildOutput("output", output.dimensions, output.type);
builder.BuildHardSigmoid(input_operand_id, output_operand_id, alpha, beta);
EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
}
};
TEST_F(WebNNGraphImplTest, HardSigmoidTest) {
  // Valid: default alpha/beta on a float32 input.
  HardSigmoidTester{
      .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 4}},
      .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 4}},
      .expected = true}
      .Test(*this);

  // Invalid: alpha is NaN.
  HardSigmoidTester{
      .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 3, 4}},
      .alpha = NAN,
      .beta = 0.5,
      .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 3, 4}},
      .expected = false}
      .Test(*this);

  // Invalid: beta is NaN.
  HardSigmoidTester{
      .input = {.type = OperandDataType::kFloat16, .dimensions = {2, 3, 4}},
      .alpha = 1.0,
      .beta = NAN,
      .output = {.type = OperandDataType::kFloat16, .dimensions = {2, 3, 4}},
      .expected = false}
      .Test(*this);

  // Invalid: output shape differs from the input shape.
  HardSigmoidTester{
      .input = {.type = OperandDataType::kFloat32, .dimensions = {4, 2}},
      .output = {.type = OperandDataType::kFloat32, .dimensions = {2}},
      .expected = false}
      .Test(*this);

  // Invalid: output data type differs from the input data type.
  HardSigmoidTester{
      .input = {.type = OperandDataType::kFloat32, .dimensions = {2}},
      .output = {.type = OperandDataType::kInt32, .dimensions = {2}},
      .expected = false}
      .Test(*this);
}
struct BatchNormalizationTester {
OperandInfo input;
OperandInfo mean;
OperandInfo variance;
std::optional<OperandInfo> scale;
std::optional<OperandInfo> bias;
struct BatchNormalizationAttributes {
std::optional<OperandId> scale_operand_id;
std::optional<OperandId> bias_operand_id;
uint32_t axis = 1;
float epsilon = 1e-5;
};
BatchNormalizationAttributes attributes;
OperandInfo output;
bool expected;
void Test(WebNNGraphImplTest& test) {
auto context_properties = GetContextPropertiesForTesting();
mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
test.BindNewGraphBuilderRemote();
GraphInfoBuilder builder(remote);
OperandId input_operand_id =
builder.BuildInput("input", input.dimensions, input.type);
OperandId mean_operand_id =
builder.BuildInput("mean", mean.dimensions, mean.type);
OperandId variance_operand_id =
builder.BuildInput("variance", variance.dimensions, variance.type);
OperandId output_operand_id =
builder.BuildOutput("output", output.dimensions, output.type);
if (scale) {
attributes.scale_operand_id =
builder.BuildInput("scale", scale->dimensions, scale->type);
}
if (bias) {
attributes.bias_operand_id =
builder.BuildInput("bias", bias->dimensions, bias->type);
}
builder.BuildBatchNormalization(input_operand_id, mean_operand_id,
variance_operand_id, output_operand_id,
std::move(attributes));
EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
}
};
TEST_F(WebNNGraphImplTest, BatchNormalizationTest) {
  {
    // Valid: default axis 1 with 1-D mean/variance of size 2.
    BatchNormalizationTester{
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 2, 3, 3}},
        .mean = {.type = OperandDataType::kFloat32, .dimensions = {2}},
        .variance = {.type = OperandDataType::kFloat32, .dimensions = {2}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 2, 3, 3}},
        .expected = true}
        .Test(*this);
  }
  {
    // Valid: axis 3 with mean/variance sized to input.dimensions[3].
    BatchNormalizationTester{
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 2, 3, 3}},
        .mean = {.type = OperandDataType::kFloat32, .dimensions = {3}},
        .variance = {.type = OperandDataType::kFloat32, .dimensions = {3}},
        .attributes = {.axis = 3},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 2, 3, 3}},
        .expected = true}
        .Test(*this);
  }
  {
    // Valid: with matching scale and bias operands.
    BatchNormalizationTester{
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 2, 3, 3}},
        .mean = {.type = OperandDataType::kFloat32, .dimensions = {2}},
        .variance = {.type = OperandDataType::kFloat32, .dimensions = {2}},
        .scale =
            OperandInfo{.type = OperandDataType::kFloat32, .dimensions = {2}},
        .bias =
            OperandInfo{.type = OperandDataType::kFloat32, .dimensions = {2}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 2, 3, 3}},
        .expected = true}
        .Test(*this);
  }
  {
    // Invalid: mean data type differs from the input data type.
    BatchNormalizationTester{
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 2, 3, 3}},
        .mean = {.type = OperandDataType::kInt32, .dimensions = {2}},
        .variance = {.type = OperandDataType::kFloat32, .dimensions = {2}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 2, 3, 3}},
        .expected = false}
        .Test(*this);
  }
  {
    // Invalid: mean size differs from input.dimensions[axis].
    BatchNormalizationTester{
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 2, 3, 3}},
        .mean = {.type = OperandDataType::kFloat32, .dimensions = {3}},
        .variance = {.type = OperandDataType::kFloat32, .dimensions = {2}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 2, 3, 3}},
        .expected = false}
        .Test(*this);
  }
  {
    // Invalid: variance data type differs from the input data type.
    // This case used to duplicate the mean-data-type case above verbatim,
    // which left variance type validation untested.
    BatchNormalizationTester{
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 2, 3, 3}},
        .mean = {.type = OperandDataType::kFloat32, .dimensions = {2}},
        .variance = {.type = OperandDataType::kInt32, .dimensions = {2}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 2, 3, 3}},
        .expected = false}
        .Test(*this);
  }
  {
    // Invalid: variance size differs from input.dimensions[axis].
    BatchNormalizationTester{
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 2, 3, 3}},
        .mean = {.type = OperandDataType::kFloat32, .dimensions = {2}},
        .variance = {.type = OperandDataType::kFloat32, .dimensions = {1}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 2, 3, 3}},
        .expected = false}
        .Test(*this);
  }
  {
    // Invalid: int32 input, mean and variance.
    BatchNormalizationTester{
        .input = {.type = OperandDataType::kInt32, .dimensions = {1, 2, 3, 3}},
        .mean = {.type = OperandDataType::kInt32, .dimensions = {2}},
        .variance = {.type = OperandDataType::kInt32, .dimensions = {2}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 2, 3, 3}},
        .expected = false}
        .Test(*this);
  }
  {
    // Invalid: axis 4 is out of range for a rank-4 input.
    BatchNormalizationTester{
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 2, 3, 3}},
        .mean = {.type = OperandDataType::kFloat32, .dimensions = {3}},
        .variance = {.type = OperandDataType::kFloat32, .dimensions = {3}},
        .attributes = {.axis = 4},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 2, 3, 3}},
        .expected = false}
        .Test(*this);
  }
  {
    // Invalid: scale data type differs from the input data type.
    BatchNormalizationTester{
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 2, 3, 3}},
        .mean = {.type = OperandDataType::kFloat32, .dimensions = {2}},
        .variance = {.type = OperandDataType::kFloat32, .dimensions = {2}},
        .scale =
            OperandInfo{.type = OperandDataType::kInt32, .dimensions = {2}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 2, 3, 3}},
        .expected = false}
        .Test(*this);
  }
  {
    // Invalid: scale size differs from input.dimensions[axis].
    BatchNormalizationTester{
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 2, 3, 3}},
        .mean = {.type = OperandDataType::kFloat32, .dimensions = {2}},
        .variance = {.type = OperandDataType::kFloat32, .dimensions = {2}},
        .scale =
            OperandInfo{.type = OperandDataType::kFloat32, .dimensions = {3}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 2, 3, 3}},
        .expected = false}
        .Test(*this);
  }
  {
    // Invalid: bias data type differs from the input data type.
    BatchNormalizationTester{
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 2, 3, 3}},
        .mean = {.type = OperandDataType::kFloat32, .dimensions = {2}},
        .variance = {.type = OperandDataType::kFloat32, .dimensions = {2}},
        .bias = OperandInfo{.type = OperandDataType::kInt32, .dimensions = {2}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 2, 3, 3}},
        .expected = false}
        .Test(*this);
  }
  {
    // Invalid: bias size differs from input.dimensions[axis].
    BatchNormalizationTester{
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 2, 3, 3}},
        .mean = {.type = OperandDataType::kFloat32, .dimensions = {2}},
        .variance = {.type = OperandDataType::kFloat32, .dimensions = {2}},
        .bias =
            OperandInfo{.type = OperandDataType::kFloat32, .dimensions = {3}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 2, 3, 3}},
        .expected = false}
        .Test(*this);
  }
  {
    // Invalid: output data type differs from the input data type. The bias
    // is valid here so that only the output type can cause the failure.
    BatchNormalizationTester{
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 2, 3, 3}},
        .mean = {.type = OperandDataType::kFloat32, .dimensions = {2}},
        .variance = {.type = OperandDataType::kFloat32, .dimensions = {2}},
        .bias =
            OperandInfo{.type = OperandDataType::kFloat32, .dimensions = {2}},
        .output = {.type = OperandDataType::kInt32, .dimensions = {1, 2, 3, 3}},
        .expected = false}
        .Test(*this);
  }
  {
    // Invalid: output shape differs from the input shape. The bias is valid
    // here so that only the output shape can cause the failure.
    BatchNormalizationTester{
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 2, 3, 3}},
        .mean = {.type = OperandDataType::kFloat32, .dimensions = {2}},
        .variance = {.type = OperandDataType::kFloat32, .dimensions = {2}},
        .bias =
            OperandInfo{.type = OperandDataType::kFloat32, .dimensions = {2}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 1, 3, 3}},
        .expected = false}
        .Test(*this);
  }
  {
    // Invalid: the output operand is the input operand.
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", {1, 2, 3, 4}, OperandDataType::kFloat32);
    OperandId mean_operand_id =
        builder.BuildInput("mean", {2}, OperandDataType::kFloat32);
    OperandId variance_operand_id =
        builder.BuildInput("variance", {2}, OperandDataType::kFloat32);
    builder.BuildBatchNormalization(
        input_operand_id, mean_operand_id, variance_operand_id,
        input_operand_id,
        BatchNormalizationTester::BatchNormalizationAttributes{});
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
  {
    // Invalid: the output operand is the mean operand.
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", {1, 2, 3, 4}, OperandDataType::kFloat32);
    OperandId mean_operand_id =
        builder.BuildInput("mean", {2}, OperandDataType::kFloat32);
    OperandId variance_operand_id =
        builder.BuildInput("variance", {2}, OperandDataType::kFloat32);
    builder.BuildBatchNormalization(
        input_operand_id, mean_operand_id, variance_operand_id, mean_operand_id,
        BatchNormalizationTester::BatchNormalizationAttributes{});
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
  {
    // Invalid: the output operand is the variance operand.
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", {1, 2, 3, 4}, OperandDataType::kFloat32);
    OperandId mean_operand_id =
        builder.BuildInput("mean", {2}, OperandDataType::kFloat32);
    OperandId variance_operand_id =
        builder.BuildInput("variance", {2}, OperandDataType::kFloat32);
    builder.BuildBatchNormalization(
        input_operand_id, mean_operand_id, variance_operand_id,
        variance_operand_id,
        BatchNormalizationTester::BatchNormalizationAttributes{});
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
}
struct ConcatTester {
std::vector<OperandInfo> inputs;
uint32_t axis;
OperandInfo output;
bool expected;
void Test(WebNNGraphImplTest& test) {
auto context_properties = GetContextPropertiesForTesting();
mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
test.BindNewGraphBuilderRemote();
GraphInfoBuilder builder(remote);
std::vector<OperandId> input_operand_ids;
input_operand_ids.reserve(inputs.size());
for (size_t i = 0; i < inputs.size(); ++i) {
input_operand_ids.push_back(
builder.BuildInput(base::StringPrintf("input%zu", i),
inputs[i].dimensions, inputs[i].type));
}
OperandId output_operand_id =
builder.BuildOutput("output", output.dimensions, output.type);
builder.BuildConcat(std::move(input_operand_ids), output_operand_id, axis);
EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
}
};
TEST_F(WebNNGraphImplTest, ConcatTest) {
  // Valid: three inputs along axis 1 (1 + 2 + 3 = 6).
  ConcatTester{
      .inputs =
          {{.type = OperandDataType::kFloat32, .dimensions = {3, 1, 5, 6}},
           {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5, 6}},
           {.type = OperandDataType::kFloat32, .dimensions = {3, 3, 5, 6}}},
      .axis = 1,
      .output = {.type = OperandDataType::kFloat32,
                 .dimensions = {3, 6, 5, 6}},
      .expected = true}
      .Test(*this);

  // Valid: a single input.
  ConcatTester{.inputs = {{.type = OperandDataType::kFloat32,
                           .dimensions = {3, 1, 5, 6}}},
               .axis = 1,
               .output = {.type = OperandDataType::kFloat32,
                          .dimensions = {3, 1, 5, 6}},
               .expected = true}
      .Test(*this);

  // Invalid: no inputs at all.
  ConcatTester{.inputs = {},
               .axis = 0,
               .output = {.type = OperandDataType::kInt32, .dimensions = {1}},
               .expected = false}
      .Test(*this);

  // Invalid: inputs have different data types.
  ConcatTester{.inputs = {{.type = OperandDataType::kFloat32,
                           .dimensions = {3, 1, 5, 6}},
                          {.type = OperandDataType::kInt32,
                           .dimensions = {3, 2, 5, 6}}},
               .axis = 1,
               .output = {.type = OperandDataType::kFloat32,
                          .dimensions = {3, 3, 5, 6}},
               .expected = false}
      .Test(*this);

  // Invalid: inputs have different ranks.
  ConcatTester{
      .inputs = {{.type = OperandDataType::kFloat32, .dimensions = {3, 1, 5}},
                 {.type = OperandDataType::kFloat32,
                  .dimensions = {3, 2, 5, 6}}},
      .axis = 1,
      .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 3, 5}},
      .expected = false}
      .Test(*this);

  // Invalid: axis 4 is out of range for rank-4 inputs.
  ConcatTester{.inputs = {{.type = OperandDataType::kFloat32,
                           .dimensions = {3, 1, 5, 6}},
                          {.type = OperandDataType::kFloat32,
                           .dimensions = {3, 1, 5, 6}}},
               .axis = 4,
               .output = {.type = OperandDataType::kFloat32,
                          .dimensions = {3, 1, 5, 12}},
               .expected = false}
      .Test(*this);

  // Invalid: inputs disagree on a non-concatenated dimension (6 vs 1).
  ConcatTester{.inputs = {{.type = OperandDataType::kFloat32,
                           .dimensions = {3, 1, 5, 6}},
                          {.type = OperandDataType::kFloat32,
                           .dimensions = {3, 1, 5, 1}}},
               .axis = 1,
               .output = {.type = OperandDataType::kFloat32,
                          .dimensions = {3, 2, 5, 7}},
               .expected = false}
      .Test(*this);

  // Invalid: output data type differs from the inputs' data type.
  ConcatTester{
      .inputs = {{.type = OperandDataType::kFloat32,
                  .dimensions = {3, 1, 5, 6}},
                 {.type = OperandDataType::kFloat32,
                  .dimensions = {3, 2, 5, 6}}},
      .axis = 1,
      .output = {.type = OperandDataType::kInt32, .dimensions = {3, 3, 5, 6}},
      .expected = false}
      .Test(*this);

  // Invalid: output size along axis 0 is 5, but the inputs sum to 3 + 1 = 4.
  ConcatTester{
      .inputs = {{.type = OperandDataType::kFloat32, .dimensions = {3, 1, 2}},
                 {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 1, 2}}},
      .axis = 0,
      .output = {.type = OperandDataType::kFloat32, .dimensions = {5, 1, 2}},
      .expected = false}
      .Test(*this);
}
struct Conv2dTester {
mojom::Conv2d::Kind type;
OperandInfo input;
OperandInfo filter;
struct Conv2dAttributes {
std::vector<uint32_t> padding = {0, 0, 0, 0};
std::vector<uint32_t> strides = {1, 1};
std::vector<uint32_t> dilations = {1, 1};
uint32_t groups = 1;
std::optional<OperandInfo> bias;
};
Conv2dAttributes attributes;
InputOperandLayout input_operand_layout = InputOperandLayout::kNchw;
OperandInfo output;
bool expected;
void Test(WebNNGraphImplTest& test) {
auto context_properties = GetContextPropertiesForTesting();
context_properties.input_operand_layout = input_operand_layout;
mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
test.BindNewGraphBuilderRemote();
GraphInfoBuilder builder(remote);
OperandId input_operand_id =
builder.BuildInput("input", input.dimensions, input.type);
OperandId filter_operand_id =
builder.BuildInput("filter", filter.dimensions, filter.type);
std::optional<OperandId> bias_operand_id;
if (attributes.bias) {
bias_operand_id = builder.BuildInput("bias", attributes.bias->dimensions,
attributes.bias->type);
}
OperandId output_operand_id =
builder.BuildOutput("output", output.dimensions, output.type);
builder.BuildConv2d(type, input_operand_id, filter_operand_id,
output_operand_id, std::move(attributes),
bias_operand_id);
EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
}
};
TEST_F(WebNNGraphImplTest, Conv2dTest) {
  {
    // Test conv2d with default attributes: a 5x5 input convolved with a 3x3
    // filter yields a 3x3 output.
    Conv2dTester{.type = mojom::Conv2d::Kind::kDirect,
                 .input = {.type = OperandDataType::kFloat32,
                           .dimensions = {1, 1, 5, 5}},
                 .filter = {.type = OperandDataType::kFloat32,
                            .dimensions = {1, 1, 3, 3}},
                 .output = {.type = OperandDataType::kFloat32,
                            .dimensions = {1, 1, 3, 3}},
                 .expected = true}
        .Test(*this);
  }
  {
    // Test float16 conv2d with padding = 1, which keeps the 5x5 spatial size.
    Conv2dTester{.type = mojom::Conv2d::Kind::kDirect,
                 .input = {.type = OperandDataType::kFloat16,
                           .dimensions = {1, 1, 5, 5}},
                 .filter = {.type = OperandDataType::kFloat16,
                            .dimensions = {1, 1, 3, 3}},
                 .attributes = {.padding = {1, 1, 1, 1}},
                 .output = {.type = OperandDataType::kFloat16,
                            .dimensions = {1, 1, 5, 5}},
                 .expected = true}
        .Test(*this);
  }
  {
    // Test float16 conv2d with padding = 1 and strides = 2.
    Conv2dTester{.type = mojom::Conv2d::Kind::kDirect,
                 .input = {.type = OperandDataType::kFloat16,
                           .dimensions = {1, 1, 5, 5}},
                 .filter = {.type = OperandDataType::kFloat16,
                            .dimensions = {1, 1, 3, 3}},
                 .attributes = {.padding = {1, 1, 1, 1}, .strides = {2, 2}},
                 .output = {.type = OperandDataType::kFloat16,
                            .dimensions = {1, 1, 3, 3}},
                 .expected = true}
        .Test(*this);
  }
  {
    // Test float16 conv2d with groups = 4 matching the 4 input channels.
    Conv2dTester{.type = mojom::Conv2d::Kind::kDirect,
                 .input = {.type = OperandDataType::kFloat16,
                           .dimensions = {1, 4, 2, 2}},
                 .filter = {.type = OperandDataType::kFloat16,
                            .dimensions = {4, 1, 2, 2}},
                 .attributes = {.groups = 4},
                 .output = {.type = OperandDataType::kFloat16,
                            .dimensions = {1, 4, 1, 1}},
                 .expected = true}
        .Test(*this);
  }
  {
    // Test float16 conv2d with an explicit NCHW input layout and two input
    // channels.
    Conv2dTester{.type = mojom::Conv2d::Kind::kDirect,
                 .input = {.type = OperandDataType::kFloat16,
                           .dimensions = {1, 2, 5, 5}},
                 .filter = {.type = OperandDataType::kFloat16,
                            .dimensions = {1, 2, 3, 3}},
                 .input_operand_layout = InputOperandLayout::kNchw,
                 .output = {.type = OperandDataType::kFloat16,
                            .dimensions = {1, 1, 3, 3}},
                 .expected = true}
        .Test(*this);
  }
  {
    // Test the invalid graph when the input is not a 4-D tensor.
    Conv2dTester{
        .type = mojom::Conv2d::Kind::kDirect,
        .input = {.type = OperandDataType::kFloat32, .dimensions = {1, 5, 5}},
        .filter = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 1, 3, 3}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 1, 3, 3}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when every operand is int8, a data type this
    // test expects conv2d to reject.
    Conv2dTester{
        .type = mojom::Conv2d::Kind::kDirect,
        .input = {.type = OperandDataType::kInt8, .dimensions = {1, 1, 5, 5}},
        .filter = {.type = OperandDataType::kInt8, .dimensions = {1, 1, 3, 3}},
        .output = {.type = OperandDataType::kInt8, .dimensions = {1, 1, 3, 3}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the filter is not a 4-D tensor.
    Conv2dTester{
        .type = mojom::Conv2d::Kind::kDirect,
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 1, 5, 5}},
        .filter = {.type = OperandDataType::kFloat32, .dimensions = {1, 3, 3}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 1, 3, 3}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the filter type doesn't match the input
    // type.
    Conv2dTester{
        .type = mojom::Conv2d::Kind::kDirect,
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 1, 5, 5}},
        .filter = {.type = OperandDataType::kInt32, .dimensions = {1, 1, 3, 3}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 1, 3, 3}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the bias type doesn't match the input type.
    Conv2dTester{
        .type = mojom::Conv2d::Kind::kDirect,
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 1, 5, 5}},
        .filter = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 1, 3, 3}},
        .attributes = {.bias = OperandInfo{.type = OperandDataType::kInt32,
                                           .dimensions = {1}}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 1, 3, 3}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the bias length (2) doesn't match the
    // number of output channels (1).
    Conv2dTester{
        .type = mojom::Conv2d::Kind::kDirect,
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 1, 5, 5}},
        .filter = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 1, 3, 3}},
        .attributes = {.bias = OperandInfo{.type = OperandDataType::kFloat32,
                                           .dimensions = {2}}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 1, 3, 3}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the single input channel is incompatible
    // with groups = 3. Every other operand is valid (in particular the output
    // type matches the input) so that only the groups check can reject it.
    Conv2dTester{.type = mojom::Conv2d::Kind::kDirect,
                 .input = {.type = OperandDataType::kFloat32,
                           .dimensions = {1, 1, 5, 5}},
                 .filter = {.type = OperandDataType::kFloat32,
                            .dimensions = {1, 1, 3, 3}},
                 .attributes = {.groups = 3},
                 .output = {.type = OperandDataType::kFloat32,
                            .dimensions = {1, 1, 3, 3}},
                 .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the output shape is not the expected
    // {1, 1, 3, 3}.
    Conv2dTester{.type = mojom::Conv2d::Kind::kDirect,
                 .input = {.type = OperandDataType::kFloat32,
                           .dimensions = {1, 1, 5, 5}},
                 .filter = {.type = OperandDataType::kFloat32,
                            .dimensions = {1, 1, 3, 3}},
                 .output = {.type = OperandDataType::kFloat32,
                            .dimensions = {1, 2, 1, 1}},
                 .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the output type doesn't match the input
    // type.
    Conv2dTester{.type = mojom::Conv2d::Kind::kDirect,
                 .input = {.type = OperandDataType::kFloat32,
                           .dimensions = {1, 1, 5, 5}},
                 .filter = {.type = OperandDataType::kFloat32,
                            .dimensions = {1, 1, 3, 3}},
                 .output = {.type = OperandDataType::kFloat16,
                            .dimensions = {1, 1, 3, 3}},
                 .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the output operand is the input operand.
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", {1, 1, 5, 5}, OperandDataType::kFloat32);
    OperandId filter_operand_id =
        builder.BuildInput("filter", {1, 1, 3, 3}, OperandDataType::kFloat32);
    builder.BuildConv2d(mojom::Conv2d::Kind::kDirect, input_operand_id,
                        filter_operand_id, input_operand_id,
                        Conv2dTester::Conv2dAttributes{}, std::nullopt);
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
  {
    // Test the invalid graph when the output operand is the filter operand.
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", {1, 1, 5, 5}, OperandDataType::kFloat32);
    OperandId filter_operand_id =
        builder.BuildInput("filter", {1, 1, 3, 3}, OperandDataType::kFloat32);
    builder.BuildConv2d(mojom::Conv2d::Kind::kDirect, input_operand_id,
                        filter_operand_id, filter_operand_id,
                        Conv2dTester::Conv2dAttributes{}, std::nullopt);
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
}
TEST_F(WebNNGraphImplTest, ConvTranspose2dTest) {
  {
    // Test convTranspose2d with default attributes: a 3x3 input and a 3x3
    // filter yield a 5x5 output.
    Conv2dTester{.type = mojom::Conv2d::Kind::kTransposed,
                 .input = {.type = OperandDataType::kFloat32,
                           .dimensions = {1, 1, 3, 3}},
                 .filter = {.type = OperandDataType::kFloat32,
                            .dimensions = {1, 1, 3, 3}},
                 .output = {.type = OperandDataType::kFloat32,
                            .dimensions = {1, 1, 5, 5}},
                 .expected = true}
        .Test(*this);
  }
  {
    // Test convTranspose2d with an NHWC input layout.
    Conv2dTester{.type = mojom::Conv2d::Kind::kTransposed,
                 .input = {.type = OperandDataType::kFloat32,
                           .dimensions = {1, 3, 3, 1}},
                 .filter = {.type = OperandDataType::kFloat32,
                            .dimensions = {1, 3, 3, 1}},
                 .input_operand_layout = InputOperandLayout::kNhwc,
                 .output = {.type = OperandDataType::kFloat32,
                            .dimensions = {1, 5, 5, 1}},
                 .expected = true}
        .Test(*this);
  }
  {
    // Test convTranspose2d with padding = 1.
    Conv2dTester{.type = mojom::Conv2d::Kind::kTransposed,
                 .input = {.type = OperandDataType::kFloat32,
                           .dimensions = {1, 1, 5, 5}},
                 .filter = {.type = OperandDataType::kFloat32,
                            .dimensions = {1, 1, 3, 3}},
                 .attributes = {.padding = {1, 1, 1, 1}},
                 .output = {.type = OperandDataType::kFloat32,
                            .dimensions = {1, 1, 5, 5}},
                 .expected = true}
        .Test(*this);
  }
  {
    // Test convTranspose2d with strides = 2 and two output channels.
    Conv2dTester{.type = mojom::Conv2d::Kind::kTransposed,
                 .input = {.type = OperandDataType::kFloat32,
                           .dimensions = {1, 1, 3, 3}},
                 .filter = {.type = OperandDataType::kFloat32,
                            .dimensions = {1, 2, 3, 3}},
                 .attributes = {.strides = {2, 2}},
                 .output = {.type = OperandDataType::kFloat32,
                            .dimensions = {1, 2, 7, 7}},
                 .expected = true}
        .Test(*this);
  }
  {
    // Test convTranspose2d with padding = 1 and strides = 2.
    Conv2dTester{.type = mojom::Conv2d::Kind::kTransposed,
                 .input = {.type = OperandDataType::kFloat32,
                           .dimensions = {1, 1, 3, 3}},
                 .filter = {.type = OperandDataType::kFloat32,
                            .dimensions = {1, 1, 3, 3}},
                 .attributes = {.padding = {1, 1, 1, 1}, .strides = {2, 2}},
                 .output = {.type = OperandDataType::kFloat32,
                            .dimensions = {1, 1, 5, 5}},
                 .expected = true}
        .Test(*this);
  }
  {
    // Test convTranspose2d with groups = 3, which makes the output channels
    // the filter output channels (1) times the groups.
    Conv2dTester{.type = mojom::Conv2d::Kind::kTransposed,
                 .input = {.type = OperandDataType::kFloat32,
                           .dimensions = {1, 1, 3, 3}},
                 .filter = {.type = OperandDataType::kFloat32,
                            .dimensions = {1, 1, 3, 3}},
                 .attributes = {.groups = 3},
                 .output = {.type = OperandDataType::kFloat32,
                            .dimensions = {1, 3, 5, 5}},
                 .expected = true}
        .Test(*this);
  }
  {
    // Test the invalid graph when the output shape is not as expected: a 5x5
    // input with a 3x3 filter yields 7x7, not 3x3. The output type matches the
    // input so that only the shape check can reject the graph.
    Conv2dTester{.type = mojom::Conv2d::Kind::kTransposed,
                 .input = {.type = OperandDataType::kFloat32,
                           .dimensions = {1, 1, 5, 5}},
                 .filter = {.type = OperandDataType::kFloat32,
                            .dimensions = {1, 1, 3, 3}},
                 .output = {.type = OperandDataType::kFloat32,
                            .dimensions = {1, 1, 3, 3}},
                 .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the input is not a 4-D tensor.
    Conv2dTester{
        .type = mojom::Conv2d::Kind::kTransposed,
        .input = {.type = OperandDataType::kFloat32, .dimensions = {1, 3, 3}},
        .filter = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 1, 3, 3}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 1, 5, 5}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the filter is not a 4-D tensor.
    Conv2dTester{
        .type = mojom::Conv2d::Kind::kTransposed,
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 1, 3, 3}},
        .filter = {.type = OperandDataType::kFloat32, .dimensions = {1, 3, 3}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 1, 5, 5}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the input channels (1) don't match the
    // filter input channels (3).
    Conv2dTester{.type = mojom::Conv2d::Kind::kTransposed,
                 .input = {.type = OperandDataType::kFloat32,
                           .dimensions = {1, 1, 3, 3}},
                 .filter = {.type = OperandDataType::kFloat32,
                            .dimensions = {3, 1, 3, 3}},
                 .attributes = {.groups = 3},
                 .output = {.type = OperandDataType::kFloat32,
                            .dimensions = {1, 3, 5, 5}},
                 .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the output channels (1) don't account for
    // groups = 3. The output type matches the input so that only the
    // channel/groups check can reject the graph.
    Conv2dTester{.type = mojom::Conv2d::Kind::kTransposed,
                 .input = {.type = OperandDataType::kFloat32,
                           .dimensions = {1, 1, 3, 3}},
                 .filter = {.type = OperandDataType::kFloat32,
                            .dimensions = {1, 1, 3, 3}},
                 .attributes = {.groups = 3},
                 .output = {.type = OperandDataType::kFloat32,
                            .dimensions = {1, 1, 5, 5}},
                 .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the filter type doesn't match the input
    // type.
    Conv2dTester{
        .type = mojom::Conv2d::Kind::kTransposed,
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 1, 3, 3}},
        .filter = {.type = OperandDataType::kInt32, .dimensions = {1, 1, 3, 3}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 1, 5, 5}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the bias type doesn't match the input type.
    Conv2dTester{
        .type = mojom::Conv2d::Kind::kTransposed,
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 1, 3, 3}},
        .filter = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 1, 3, 3}},
        .attributes = {.bias = OperandInfo{.type = OperandDataType::kInt32,
                                           .dimensions = {1}}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 1, 5, 5}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the bias length (2) doesn't match the
    // number of output channels (1).
    Conv2dTester{
        .type = mojom::Conv2d::Kind::kTransposed,
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 1, 3, 3}},
        .filter = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 1, 3, 3}},
        .attributes = {.bias = OperandInfo{.type = OperandDataType::kFloat32,
                                           .dimensions = {2}}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 1, 5, 5}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the output operand is the input operand.
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", {1, 1, 3, 3}, OperandDataType::kFloat32);
    OperandId filter_operand_id =
        builder.BuildInput("filter", {1, 1, 3, 3}, OperandDataType::kFloat32);
    builder.BuildConv2d(mojom::Conv2d::Kind::kTransposed, input_operand_id,
                        filter_operand_id, input_operand_id,
                        Conv2dTester::Conv2dAttributes{}, std::nullopt);
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
  {
    // Test the invalid graph when the output operand is the filter operand.
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", {1, 1, 3, 3}, OperandDataType::kFloat32);
    OperandId filter_operand_id =
        builder.BuildInput("filter", {1, 1, 3, 3}, OperandDataType::kFloat32);
    builder.BuildConv2d(mojom::Conv2d::Kind::kTransposed, input_operand_id,
                        filter_operand_id, filter_operand_id,
                        Conv2dTester::Conv2dAttributes{}, std::nullopt);
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
}
struct CumulativeSumTester {
OperandInfo input;
uint32_t axis;
std::optional<bool> exclusive;
std::optional<bool> reversed;
OperandInfo output;
bool expected;
void Test(WebNNGraphImplTest& test) {
auto context_properties = GetContextPropertiesForTesting();
mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
test.BindNewGraphBuilderRemote();
GraphInfoBuilder builder(remote);
OperandId input_operand_id =
builder.BuildInput("input", input.dimensions, input.type);
OperandId output_operand_id =
builder.BuildOutput("output", output.dimensions, output.type);
builder.BuildCumulativeSum(input_operand_id, output_operand_id, axis,
exclusive, reversed);
EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
}
};
TEST_F(WebNNGraphImplTest, CumulativeSumTest) {
  {
    // Test cumulativeSum along axis 0 with default options.
    CumulativeSumTester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 4}},
        .axis = 0,
        .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 4}},
        .expected = true}
        .Test(*this);
  }
  {
    // Test cumulativeSum with exclusive = true and reversed = true.
    CumulativeSumTester{.input = {.type = OperandDataType::kFloat32,
                                  .dimensions = {1, 2, 3, 4}},
                        .axis = 0,
                        .exclusive = true,
                        .reversed = true,
                        .output = {.type = OperandDataType::kFloat32,
                                   .dimensions = {1, 2, 3, 4}},
                        .expected = true}
        .Test(*this);
  }
  {
    // Test cumulativeSum along the last axis.
    CumulativeSumTester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 3, 4}},
        .axis = 2,
        .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 3, 4}},
        .expected = true}
        .Test(*this);
  }
  {
    // Test the invalid graph for a scalar input, which has no axis 0. The
    // output type matches the input so that only the axis/rank check can
    // reject the graph.
    CumulativeSumTester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {}},
        .axis = 0,
        .output = {.type = OperandDataType::kFloat32, .dimensions = {}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the axis is out of range for a 3-D input.
    // The output type matches the input so that only the axis check can
    // reject the graph.
    CumulativeSumTester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 3, 4}},
        .axis = 3,
        .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 3, 4}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the output type doesn't match the input
    // type.
    CumulativeSumTester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 3, 4}},
        .axis = 2,
        .output = {.type = OperandDataType::kFloat16, .dimensions = {2, 3, 4}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the output operand is the input operand.
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    uint32_t axis = 0;
    bool exclusive = false;
    bool reversed = false;
    OperandId input_operand_id =
        builder.BuildInput("input", {1, 1, 3, 3}, OperandDataType::kFloat32);
    builder.BuildCumulativeSum(input_operand_id, input_operand_id, axis,
                               exclusive, reversed);
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
}
struct DequantizeLinearTester {
OperandInfo input;
OperandInfo scale;
OperandInfo zero_point;
OperandInfo output;
bool expected;
void Test(WebNNGraphImplTest& test) {
auto context_properties = GetContextPropertiesForTesting();
mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
test.BindNewGraphBuilderRemote();
GraphInfoBuilder builder(remote);
OperandId input_operand_id =
builder.BuildInput("input", input.dimensions, input.type);
OperandId scale_operand_id =
builder.BuildInput("scale", scale.dimensions, scale.type);
OperandId zero_point_operand_id = builder.BuildInput(
"zero_point", zero_point.dimensions, zero_point.type);
OperandId output_operand_id =
builder.BuildOutput("output", output.dimensions, output.type);
builder.BuildDequantizeLinear(input_operand_id, scale_operand_id,
zero_point_operand_id, output_operand_id);
EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
}
};
TEST_F(WebNNGraphImplTest, DequantizeLinearTest) {
  {
    // Test dequantizeLinear with scale and zero point shaped like the input.
    DequantizeLinearTester{
        .input = {.type = OperandDataType::kInt8, .dimensions = {3, 2, 5}},
        .scale = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}},
        .zero_point = {.type = OperandDataType::kInt8, .dimensions = {3, 2, 5}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}},
        .expected = true}
        .Test(*this);
  }
  {
    // Test dequantizeLinear with scale and zero point broadcast along the
    // first two dimensions.
    DequantizeLinearTester{
        .input = {.type = OperandDataType::kInt8, .dimensions = {3, 2, 5}},
        .scale = {.type = OperandDataType::kFloat32, .dimensions = {1, 1, 5}},
        .zero_point = {.type = OperandDataType::kInt8, .dimensions = {1, 1, 5}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}},
        .expected = true}
        .Test(*this);
  }
  {
    // Test dequantizeLinear with scale and zero point broadcast along the
    // last two dimensions.
    DequantizeLinearTester{
        .input = {.type = OperandDataType::kInt8, .dimensions = {3, 2, 5}},
        .scale = {.type = OperandDataType::kFloat32, .dimensions = {3, 1, 1}},
        .zero_point = {.type = OperandDataType::kInt8, .dimensions = {3, 1, 1}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}},
        .expected = true}
        .Test(*this);
  }
  {
    // Test the invalid graph when scale and zero point have a lower rank than
    // the input.
    DequantizeLinearTester{
        .input = {.type = OperandDataType::kInt8, .dimensions = {3, 2, 5}},
        .scale = {.type = OperandDataType::kFloat32, .dimensions = {5}},
        .zero_point = {.type = OperandDataType::kInt8, .dimensions = {5}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when scale and zero point can't be broadcast to
    // the input shape.
    DequantizeLinearTester{
        .input = {.type = OperandDataType::kInt8, .dimensions = {3, 2, 5}},
        .scale = {.type = OperandDataType::kFloat32, .dimensions = {2}},
        .zero_point = {.type = OperandDataType::kInt8, .dimensions = {2}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the scale and zero point shapes differ.
    DequantizeLinearTester{
        .input = {.type = OperandDataType::kInt8, .dimensions = {3, 2, 5}},
        .scale = {.type = OperandDataType::kFloat32, .dimensions = {5}},
        .zero_point = {.type = OperandDataType::kInt8, .dimensions = {2}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the zero point type doesn't match the input
    // type.
    DequantizeLinearTester{
        .input = {.type = OperandDataType::kInt8, .dimensions = {3, 2, 5}},
        .scale = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}},
        .zero_point = {.type = OperandDataType::kUint8,
                       .dimensions = {3, 2, 5}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the output type doesn't match the scale
    // type.
    DequantizeLinearTester{
        .input = {.type = OperandDataType::kInt8, .dimensions = {3, 2, 5}},
        .scale = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}},
        .zero_point = {.type = OperandDataType::kInt8, .dimensions = {3, 2, 5}},
        .output = {.type = OperandDataType::kFloat16, .dimensions = {3, 2, 5}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the output shape doesn't match the input
    // shape. Scale, zero point and output type are otherwise valid so that
    // only the output shape check can reject the graph.
    DequantizeLinearTester{
        .input = {.type = OperandDataType::kInt8, .dimensions = {3, 2, 5}},
        .scale = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}},
        .zero_point = {.type = OperandDataType::kInt8, .dimensions = {3, 2, 5}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {5}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the output operand is the input operand.
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", {2, 3}, OperandDataType::kInt8);
    OperandId scale_operand_id =
        builder.BuildInput("scale", {2, 3}, OperandDataType::kFloat32);
    OperandId zero_point_operand_id =
        builder.BuildInput("zero_point", {2, 3}, OperandDataType::kInt8);
    builder.BuildDequantizeLinear(input_operand_id, scale_operand_id,
                                  zero_point_operand_id, input_operand_id);
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
  {
    // Test the invalid graph when the output operand is the scale operand.
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", {2, 3}, OperandDataType::kInt8);
    OperandId scale_operand_id =
        builder.BuildInput("scale", {2, 3}, OperandDataType::kFloat32);
    OperandId zero_point_operand_id =
        builder.BuildInput("zero_point", {2, 3}, OperandDataType::kInt8);
    builder.BuildDequantizeLinear(input_operand_id, scale_operand_id,
                                  zero_point_operand_id, scale_operand_id);
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
  {
    // Test the invalid graph when the output operand is the zero point
    // operand.
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", {2, 3}, OperandDataType::kInt8);
    OperandId scale_operand_id =
        builder.BuildInput("scale", {2, 3}, OperandDataType::kFloat32);
    OperandId zero_point_operand_id =
        builder.BuildInput("zero_point", {2, 3}, OperandDataType::kInt8);
    builder.BuildDequantizeLinear(input_operand_id, scale_operand_id,
                                  zero_point_operand_id, zero_point_operand_id);
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
}
struct ElementWiseBinaryTester {
mojom::ElementWiseBinary::Kind kind;
OperandInfo lhs;
OperandInfo rhs;
OperandInfo output;
bool expected;
static constexpr std::array<mojom::ElementWiseBinary::Kind, 16>
kAllBinaryOps = {
mojom::ElementWiseBinary::Kind::kAdd,
mojom::ElementWiseBinary::Kind::kSub,
mojom::ElementWiseBinary::Kind::kMul,
mojom::ElementWiseBinary::Kind::kDiv,
mojom::ElementWiseBinary::Kind::kPow,
mojom::ElementWiseBinary::Kind::kMax,
mojom::ElementWiseBinary::Kind::kMin,
mojom::ElementWiseBinary::Kind::kEqual,
mojom::ElementWiseBinary::Kind::kGreater,
mojom::ElementWiseBinary::Kind::kGreaterOrEqual,
mojom::ElementWiseBinary::Kind::kLesser,
mojom::ElementWiseBinary::Kind::kLesserOrEqual,
mojom::ElementWiseBinary::Kind::kNotEqual,
mojom::ElementWiseBinary::Kind::kLogicalAnd,
mojom::ElementWiseBinary::Kind::kLogicalOr,
mojom::ElementWiseBinary::Kind::kLogicalXor,
};
static OperandDataType GetValidInputType(mojom::ElementWiseBinary::Kind op) {
switch (op) {
case mojom::ElementWiseBinary::Kind::kAdd:
case mojom::ElementWiseBinary::Kind::kSub:
case mojom::ElementWiseBinary::Kind::kMul:
case mojom::ElementWiseBinary::Kind::kDiv:
case mojom::ElementWiseBinary::Kind::kPow:
case mojom::ElementWiseBinary::Kind::kMax:
case mojom::ElementWiseBinary::Kind::kMin:
case mojom::ElementWiseBinary::Kind::kEqual:
case mojom::ElementWiseBinary::Kind::kGreater:
case mojom::ElementWiseBinary::Kind::kGreaterOrEqual:
case mojom::ElementWiseBinary::Kind::kLesser:
case mojom::ElementWiseBinary::Kind::kLesserOrEqual:
case mojom::ElementWiseBinary::Kind::kNotEqual:
return OperandDataType::kFloat32;
case mojom::ElementWiseBinary::Kind::kLogicalAnd:
case mojom::ElementWiseBinary::Kind::kLogicalOr:
case mojom::ElementWiseBinary::Kind::kLogicalXor:
return OperandDataType::kUint8;
}
}
static OperandDataType GetValidOutputType(mojom::ElementWiseBinary::Kind op) {
switch (op) {
case mojom::ElementWiseBinary::Kind::kAdd:
case mojom::ElementWiseBinary::Kind::kSub:
case mojom::ElementWiseBinary::Kind::kMul:
case mojom::ElementWiseBinary::Kind::kDiv:
case mojom::ElementWiseBinary::Kind::kPow:
case mojom::ElementWiseBinary::Kind::kMax:
case mojom::ElementWiseBinary::Kind::kMin:
return OperandDataType::kFloat32;
case mojom::ElementWiseBinary::Kind::kEqual:
case mojom::ElementWiseBinary::Kind::kGreater:
case mojom::ElementWiseBinary::Kind::kGreaterOrEqual:
case mojom::ElementWiseBinary::Kind::kLesser:
case mojom::ElementWiseBinary::Kind::kLesserOrEqual:
case mojom::ElementWiseBinary::Kind::kNotEqual:
case mojom::ElementWiseBinary::Kind::kLogicalAnd:
case mojom::ElementWiseBinary::Kind::kLogicalOr:
case mojom::ElementWiseBinary::Kind::kLogicalXor:
return OperandDataType::kUint8;
}
}
void Test(WebNNGraphImplTest& test) {
auto context_properties = GetContextPropertiesForTesting();
mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
test.BindNewGraphBuilderRemote();
GraphInfoBuilder builder(remote);
OperandId lhs_operand_id =
builder.BuildInput("lhs", lhs.dimensions, lhs.type);
OperandId rhs_operand_id =
builder.BuildInput("rhs", rhs.dimensions, rhs.type);
OperandId output_operand_id =
builder.BuildOutput("output", output.dimensions, output.type);
builder.BuildElementWiseBinary(kind, lhs_operand_id, rhs_operand_id,
output_operand_id);
EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
}
void TestLogicalOperators(WebNNGraphImplTest& test) {
const mojom::ElementWiseBinary::Kind kLogicalOperators[] = {
mojom::ElementWiseBinary::Kind::kEqual,
mojom::ElementWiseBinary::Kind::kGreater,
mojom::ElementWiseBinary::Kind::kGreaterOrEqual,
mojom::ElementWiseBinary::Kind::kLesser,
mojom::ElementWiseBinary::Kind::kLesserOrEqual,
mojom::ElementWiseBinary::Kind::kNotEqual,
};
for (const auto& op : kLogicalOperators) {
kind = op;
Test(test);
}
}
};
// Runs the same set of valid and invalid shape/type cases against every
// element-wise binary operator, using the data types each operator accepts.
TEST_F(WebNNGraphImplTest, ElementWiseBinaryTest) {
  for (const mojom::ElementWiseBinary::Kind kind :
       ElementWiseBinaryTester::kAllBinaryOps) {
    const OperandDataType in_type =
        ElementWiseBinaryTester::GetValidInputType(kind);
    const OperandDataType out_type =
        ElementWiseBinaryTester::GetValidOutputType(kind);

    // Valid: both inputs broadcast into a 4-D output.
    ElementWiseBinaryTester{
        .kind = kind,
        .lhs = {.type = in_type, .dimensions = {8, 1, 6, 1}},
        .rhs = {.type = in_type, .dimensions = {7, 1, 5}},
        .output = {.type = out_type, .dimensions = {8, 7, 6, 5}},
        .expected = true}
        .Test(*this);

    // Valid: a 1-D rhs broadcasts against the trailing size-1 dimension.
    ElementWiseBinaryTester{
        .kind = kind,
        .lhs = {.type = in_type, .dimensions = {4, 2, 1}},
        .rhs = {.type = in_type, .dimensions = {4}},
        .output = {.type = out_type, .dimensions = {4, 2, 4}},
        .expected = true}
        .Test(*this);

    // Invalid: trailing dimensions 2 and 4 are not broadcastable.
    ElementWiseBinaryTester{
        .kind = kind,
        .lhs = {.type = in_type, .dimensions = {4, 2}},
        .rhs = {.type = in_type, .dimensions = {4}},
        .output = {.type = out_type, .dimensions = {4, 2}},
        .expected = false}
        .Test(*this);

    // Invalid: the output shape doesn't match the broadcast shape.
    ElementWiseBinaryTester{
        .kind = kind,
        .lhs = {.type = in_type, .dimensions = {4, 2}},
        .rhs = {.type = in_type, .dimensions = {4, 2}},
        .output = {.type = out_type, .dimensions = {2}},
        .expected = false}
        .Test(*this);

    // Invalid: the rhs type differs from the lhs type.
    ElementWiseBinaryTester{
        .kind = kind,
        .lhs = {.type = in_type, .dimensions = {2}},
        .rhs = {.type = OperandDataType::kInt64, .dimensions = {2}},
        .output = {.type = out_type, .dimensions = {2}},
        .expected = false}
        .Test(*this);

    // Invalid: the output type is not the one the operator produces.
    ElementWiseBinaryTester{
        .kind = kind,
        .lhs = {.type = in_type, .dimensions = {2}},
        .rhs = {.type = in_type, .dimensions = {2}},
        .output = {.type = OperandDataType::kInt64, .dimensions = {2}},
        .expected = false}
        .Test(*this);
  }
}
struct ElementWiseUnaryTester {
mojom::ElementWiseUnary::Kind kind;
OperandInfo input;
OperandInfo output;
bool expected;
void Test(WebNNGraphImplTest& test) {
auto context_properties = GetContextPropertiesForTesting();
mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
test.BindNewGraphBuilderRemote();
GraphInfoBuilder builder(remote);
OperandId input_operand_id =
builder.BuildInput("input", input.dimensions, input.type);
OperandId output_operand_id =
builder.BuildOutput("output", output.dimensions, output.type);
builder.BuildElementWiseUnary(kind, input_operand_id, output_operand_id);
EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
}
};
// Value-parameterized fixture over (operator with its supported input data
// types, input data type, output data type).
class ElementWiseUnaryDataTypeFixture
    : public WebNNGraphImplTest,
      public testing::WithParamInterface<
          std::tuple<std::pair<mojom::ElementWiseUnary::Kind,
                               std::vector<OperandDataType>>,
                     OperandDataType,
                     OperandDataType>> {
 public:
  // Names each instantiation "<operator>_<input type>_<output type>".
  struct PrintToStringParamName {
    template <class ParamType>
    std::string operator()(
        const testing::TestParamInfo<ParamType>& info) const {
      const auto& [op_and_types, input_type, output_type] = info.param;
      return base::StrCat({OpKindToString(op_and_types.first), "_",
                           DataTypeToString(input_type), "_",
                           DataTypeToString(output_type)});
    }
  };

  // Builds a unary graph whose input and output both have `dimensions`. The
  // graph is expected to be valid exactly when the operator lists the input
  // data type as supported and either the output data type equals the input
  // data type or the operator is cast.
  void TestDataTypeSupportWithDimensions(
      const std::vector<uint32_t>& dimensions) {
    const auto& [op_and_types, input_type, output_type] = GetParam();
    const mojom::ElementWiseUnary::Kind op_kind = op_and_types.first;

    // Cast is the only operator here allowed to change the data type.
    const bool allows_type_change =
        op_kind == mojom::ElementWiseUnary::Kind::kCast;
    const bool input_supported =
        base::Contains(op_and_types.second, input_type);
    const bool expected =
        input_supported && (allows_type_change || input_type == output_type);

    ElementWiseUnaryTester{
        .kind = op_kind,
        .input = {.type = input_type, .dimensions = dimensions},
        .output = {.type = output_type, .dimensions = dimensions},
        .expected = expected}
        .Test(*this);
  }
};
// Runs every parameter combination on a rank-4 operand.
TEST_P(ElementWiseUnaryDataTypeFixture, TestUnaryOperandDataTypeSupport) {
  const std::vector<uint32_t> rank4_shape = {1, 2, 3, 1};
  TestDataTypeSupportWithDimensions(rank4_shape);
}
// Runs every parameter combination on a scalar (empty-shape) operand.
TEST_P(ElementWiseUnaryDataTypeFixture, TestUnaryOperandScalarDataTypeSupport) {
  const std::vector<uint32_t> scalar_shape;
  TestDataTypeSupportWithDimensions(scalar_shape);
}
INSTANTIATE_TEST_SUITE_P(
WebNNGraphImplTest,
ElementWiseUnaryDataTypeFixture,
::testing::Combine(
::testing::ValuesIn({
std::make_pair(mojom::ElementWiseUnary::Kind::kLogicalNot,
std::vector<OperandDataType>{
OperandDataType::kUint8}),
std::make_pair(
mojom::ElementWiseUnary::Kind::kIdentity,
std::vector<OperandDataType>(kAllOperandDataTypes,
std::end(kAllOperandDataTypes))),
std::make_pair(mojom::ElementWiseUnary::Kind::kSqrt,
std::vector<OperandDataType>{
OperandDataType::kFloat16,
OperandDataType::kFloat32}),
std::make_pair(mojom::ElementWiseUnary::Kind::kErf,
std::vector<OperandDataType>{
OperandDataType::kFloat16,
OperandDataType::kFloat32}),
std::make_pair(mojom::ElementWiseUnary::Kind::kReciprocal,
std::vector<OperandDataType>{
OperandDataType::kFloat16,
OperandDataType::kFloat32}),
std::make_pair(
mojom::ElementWiseUnary::Kind::kCast,
std::vector<OperandDataType>(kAllOperandDataTypes,
std::end(kAllOperandDataTypes))),
}),
::testing::ValuesIn(kAllOperandDataTypes),
::testing::ValuesIn(kAllOperandDataTypes)),
ElementWiseUnaryDataTypeFixture::PrintToStringParamName());
TEST_F(WebNNGraphImplTest, ElementWiseUnaryTest) {
  using Kind = mojom::ElementWiseUnary::Kind;

  // Makes a case whose input and output share one data type and one shape.
  auto matching_io = [](Kind kind, OperandDataType type,
                        std::vector<uint32_t> dimensions, bool expected) {
    return ElementWiseUnaryTester{
        .kind = kind,
        .input = {.type = type, .dimensions = dimensions},
        .output = {.type = type, .dimensions = dimensions},
        .expected = expected};
  };

  // Cases run in the order listed.
  std::vector<ElementWiseUnaryTester> cases = {
      // Floating-point operands of rank 1 through 5 are accepted.
      matching_io(Kind::kAbs, OperandDataType::kFloat32, {1}, true),
      matching_io(Kind::kCeil, OperandDataType::kFloat16, {1}, true),
      matching_io(Kind::kCos, OperandDataType::kFloat32, {1, 2}, true),
      matching_io(Kind::kExp, OperandDataType::kFloat16, {1, 2}, true),
      matching_io(Kind::kFloor, OperandDataType::kFloat32, {1, 2, 3}, true),
      matching_io(Kind::kLog, OperandDataType::kFloat16, {1, 2, 3}, true),
      matching_io(Kind::kNeg, OperandDataType::kFloat32, {1, 2, 3, 4}, true),
      matching_io(Kind::kSin, OperandDataType::kFloat16, {1, 2, 3, 4}, true),
      matching_io(Kind::kTan, OperandDataType::kFloat32, {1, 2, 3, 4, 5},
                  true),
      // Integer data types these operators do not accept are rejected.
      matching_io(Kind::kAbs, OperandDataType::kUint32, {1, 2, 3, 4}, false),
      matching_io(Kind::kNeg, OperandDataType::kUint8, {1, 2, 3, 4}, false),
      matching_io(Kind::kCeil, OperandDataType::kUint32, {1, 2, 3, 4}, false),
      matching_io(Kind::kCos, OperandDataType::kUint32, {1, 2, 3, 4}, false),
      matching_io(Kind::kExp, OperandDataType::kUint8, {1, 2, 3, 4}, false),
      matching_io(Kind::kFloor, OperandDataType::kInt8, {1, 2, 3, 4}, false),
      matching_io(Kind::kLog, OperandDataType::kInt32, {1, 2, 3, 4}, false),
      matching_io(Kind::kSin, OperandDataType::kUint32, {1, 2, 3, 4}, false),
      matching_io(Kind::kTan, OperandDataType::kUint32, {1, 2, 3, 4}, false),
      // The output shape differs from the input shape.
      ElementWiseUnaryTester{
          .kind = Kind::kAbs,
          .input = {.type = OperandDataType::kFloat32,
                    .dimensions = {1, 2, 3, 4}},
          .output = {.type = OperandDataType::kFloat32,
                     .dimensions = {1, 2, 3, 4, 5}},
          .expected = false},
      // The output data type differs from the input data type.
      ElementWiseUnaryTester{
          .kind = Kind::kCeil,
          .input = {.type = OperandDataType::kFloat32,
                    .dimensions = {1, 2, 3, 4}},
          .output = {.type = OperandDataType::kFloat16,
                     .dimensions = {1, 2, 3, 4}},
          .expected = false},
      // Cast may change the data type, but a shape mismatch is rejected.
      ElementWiseUnaryTester{
          .kind = Kind::kCast,
          .input = {.type = OperandDataType::kUint8,
                    .dimensions = {1, 2, 3, 1}},
          .output = {.type = OperandDataType::kInt8,
                     .dimensions = {1, 2, 3, 2}},
          .expected = false},
  };

  for (ElementWiseUnaryTester& tester : cases) {
    tester.Test(*this);
  }
}
struct EluTester {
OperandInfo input;
OperandInfo output;
float alpha = 1.0;
bool expected;
void Test(WebNNGraphImplTest& test) {
auto context_properties = GetContextPropertiesForTesting();
mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
test.BindNewGraphBuilderRemote();
GraphInfoBuilder builder(remote);
OperandId input_operand_id =
builder.BuildInput("input", input.dimensions, input.type);
OperandId output_operand_id =
builder.BuildOutput("output", output.dimensions, output.type);
builder.BuildElu(input_operand_id, output_operand_id, alpha);
EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
}
};
TEST_F(WebNNGraphImplTest, EluTest) {
  {
    // Float32 input and output of the same shape, default alpha: accepted.
    EluTester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 6}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 6}},
        .expected = true}
        .Test(*this);
  }
  {
    // A NaN alpha is rejected. Use std::numeric_limits rather than the C
    // NAN macro so the value is typed as float.
    EluTester{.input = {.type = OperandDataType::kFloat32, .dimensions = {2}},
              .output = {.type = OperandDataType::kFloat32, .dimensions = {2}},
              .alpha = std::numeric_limits<float>::quiet_NaN(),
              .expected = false}
        .Test(*this);
  }
  {
    // The output shape differs from the input shape.
    EluTester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {4, 2}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {2}},
        .expected = false}
        .Test(*this);
  }
  {
    // The output data type differs from the input data type.
    EluTester{.input = {.type = OperandDataType::kFloat32, .dimensions = {2}},
              .output = {.type = OperandDataType::kInt32, .dimensions = {2}},
              .expected = false}
        .Test(*this);
  }
  {
    // An int32 input is rejected even when the output type matches.
    EluTester{.input = {.type = OperandDataType::kInt32, .dimensions = {2}},
              .output = {.type = OperandDataType::kInt32, .dimensions = {2}},
              .expected = false}
        .Test(*this);
  }
  {
    // Using the input operand as the elu output is rejected.
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", {2}, OperandDataType::kFloat32);
    builder.BuildElu(input_operand_id, input_operand_id, 1.0);
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
}
struct ExpandTester {
OperandInfo input;
OperandInfo output;
bool expected;
void Test(WebNNGraphImplTest& test) {
auto context_properties = GetContextPropertiesForTesting();
mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
test.BindNewGraphBuilderRemote();
GraphInfoBuilder builder(remote);
OperandId input_operand_id =
builder.BuildInput("input", input.dimensions, input.type);
OperandId output_operand_id =
builder.BuildOutput("output", output.dimensions, output.type);
builder.BuildExpand(input_operand_id, output_operand_id);
EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
}
};
TEST_F(WebNNGraphImplTest, ExpandTest) {
  // Cases run in the order listed.
  ExpandTester cases[] = {
      // Same shape in and out.
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {2, 6}},
       .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 6}},
       .expected = true},
      // The size-1 middle dimension grows to 4.
      {.input = {.type = OperandDataType::kInt32, .dimensions = {3, 1, 5}},
       .output = {.type = OperandDataType::kInt32, .dimensions = {3, 4, 5}},
       .expected = true},
      // A new leading dimension is prepended.
      {.input = {.type = OperandDataType::kInt32, .dimensions = {2, 5}},
       .output = {.type = OperandDataType::kInt32, .dimensions = {3, 2, 5}},
       .expected = true},
      // {3, 6, 2} cannot be expanded to {4, 3, 5}.
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {3, 6, 2}},
       .output = {.type = OperandDataType::kFloat32, .dimensions = {4, 3, 5}},
       .expected = false},
      // {5} cannot be expanded to {5, 4}.
      {.input = {.type = OperandDataType::kInt32, .dimensions = {5}},
       .output = {.type = OperandDataType::kInt32, .dimensions = {5, 4}},
       .expected = false},
      // The output data type differs from the input data type.
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {2}},
       .output = {.type = OperandDataType::kInt32, .dimensions = {2}},
       .expected = false},
  };
  for (ExpandTester& tester : cases) {
    tester.Test(*this);
  }

  {
    // Using the input operand as the expand output is rejected.
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", {2}, OperandDataType::kFloat32);
    builder.BuildExpand(input_operand_id, input_operand_id);
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
  {
    // Limit the context's expand input ranks to at most 4; expanding a rank-1
    // input to a rank-5 output is then expected to be rejected.
    auto context_properties = GetContextPropertiesForTesting();
    static constexpr SupportedRanks kRankLimit = SupportedRanks::UpTo(4);
    context_properties.data_type_limits.expand_input.ranks.IntersectWith(
        kRankLimit);
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", {2}, OperandDataType::kFloat32);
    OperandId output_operand_id = builder.BuildOutput(
        "output", {1, 1, 1, 1, 2}, OperandDataType::kFloat32);
    builder.BuildExpand(input_operand_id, output_operand_id);
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
}
struct GatherAttributes {
OperandInfo indices;
uint32_t axis;
};
struct GatherTester {
OperandInfo input;
GatherAttributes attributes;
OperandInfo output;
bool expected;
void Test(WebNNGraphImplTest& test) {
auto context_properties = GetContextPropertiesForTesting();
mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
test.BindNewGraphBuilderRemote();
GraphInfoBuilder builder(remote);
OperandId input_operand_id =
builder.BuildInput("input", input.dimensions, input.type);
OperandId indices_operand_id = builder.BuildInput(
"indices", attributes.indices.dimensions, attributes.indices.type);
OperandId output_operand_id =
builder.BuildOutput("output", output.dimensions, output.type);
builder.BuildGather(input_operand_id, indices_operand_id, output_operand_id,
attributes.axis);
EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
}
};
TEST_F(WebNNGraphImplTest, GatherTest) {
  // Cases run in the order listed.
  GatherTester cases[] = {
      // Gathering {6, 7} indices along axis 1 of {3, 4, 5} gives {3, 6, 7, 5}.
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {3, 4, 5}},
       .attributes = {.indices = {.type = OperandDataType::kUint32,
                                  .dimensions = {6, 7}},
                      .axis = 1},
       .output = {.type = OperandDataType::kFloat32,
                  .dimensions = {3, 6, 7, 5}},
       .expected = true},
      // Axis 3 on a rank-3 input.
      {.input = {.type = OperandDataType::kFloat16, .dimensions = {3, 4, 5}},
       .attributes = {.indices = {.type = OperandDataType::kUint32,
                                  .dimensions = {6, 7}},
                      .axis = 3},
       .output = {.type = OperandDataType::kFloat16,
                  .dimensions = {3, 4, 5, 6, 7}},
       .expected = false},
      // Float16 indices.
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {3, 4, 5}},
       .attributes = {.indices = {.type = OperandDataType::kFloat16,
                                  .dimensions = {6, 7}},
                      .axis = 1},
       .output = {.type = OperandDataType::kFloat32,
                  .dimensions = {3, 6, 7, 5}},
       .expected = false},
      // Float32 indices.
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {3, 4, 5}},
       .attributes = {.indices = {.type = OperandDataType::kFloat32,
                                  .dimensions = {6, 7}},
                      .axis = 1},
       .output = {.type = OperandDataType::kFloat32,
                  .dimensions = {3, 6, 7, 5}},
       .expected = false},
      // Output shape {3, 4, 6, 7, 5} instead of {3, 6, 7, 5}.
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {3, 4, 5}},
       .attributes = {.indices = {.type = OperandDataType::kUint32,
                                  .dimensions = {6, 7}},
                      .axis = 1},
       .output = {.type = OperandDataType::kFloat32,
                  .dimensions = {3, 4, 6, 7, 5}},
       .expected = false},
      // The output data type differs from the input data type.
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {3, 4, 5}},
       .attributes = {.indices = {.type = OperandDataType::kUint32,
                                  .dimensions = {6, 7}},
                      .axis = 1},
       .output = {.type = OperandDataType::kFloat16,
                  .dimensions = {3, 6, 7, 5}},
       .expected = false},
  };
  for (GatherTester& tester : cases) {
    tester.Test(*this);
  }

  {
    // Using the input operand as the gather output is rejected.
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", {2, 3}, OperandDataType::kFloat32);
    OperandId indices_operand_id =
        builder.BuildInput("indices", {2}, OperandDataType::kUint32);
    builder.BuildGather(input_operand_id, indices_operand_id, input_operand_id,
                        0);
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
  {
    // Using the indices operand as the gather output is rejected.
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", {3}, OperandDataType::kUint32);
    OperandId indices_operand_id =
        builder.BuildInput("indices", {3}, OperandDataType::kUint32);
    builder.BuildGather(input_operand_id, indices_operand_id,
                        indices_operand_id, 0);
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
}
struct GatherElementsTester {
OperandInfo input;
GatherAttributes attributes;
OperandInfo output;
bool expected;
void Test(WebNNGraphImplTest& test) {
auto context_properties = GetContextPropertiesForTesting();
mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
test.BindNewGraphBuilderRemote();
GraphInfoBuilder builder(remote);
OperandId input_operand_id =
builder.BuildInput("input", input.dimensions, input.type);
OperandId indices_operand_id = builder.BuildInput(
"indices", attributes.indices.dimensions, attributes.indices.type);
OperandId output_operand_id =
builder.BuildOutput("output", output.dimensions, output.type);
builder.BuildGatherElements(input_operand_id, indices_operand_id,
output_operand_id, attributes.axis);
EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
}
};
TEST_F(WebNNGraphImplTest, GatherElementsTest) {
  // Cases run in the order listed.
  GatherElementsTester cases[] = {
      // Indices {3, 4, 2, 6} along axis 2 of {3, 4, 5, 6}; the output takes
      // the indices shape.
      {.input = {.type = OperandDataType::kFloat32,
                 .dimensions = {3, 4, 5, 6}},
       .attributes = {.indices = {.type = OperandDataType::kUint32,
                                  .dimensions = {3, 4, 2, 6}},
                      .axis = 2},
       .output = {.type = OperandDataType::kFloat32,
                  .dimensions = {3, 4, 2, 6}},
       .expected = true},
      // Axis 3 on a rank-3 input.
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {3, 4, 5}},
       .attributes = {.indices = {.type = OperandDataType::kUint32,
                                  .dimensions = {3, 4, 5}},
                      .axis = 3},
       .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 4, 5}},
       .expected = false},
      // Rank-2 indices with a rank-3 input.
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {3, 4, 5}},
       .attributes = {.indices = {.type = OperandDataType::kUint32,
                                  .dimensions = {3, 4}},
                      .axis = 2},
       .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 4}},
       .expected = false},
      // Indices dimension 1 (3) differs from the input's (4) off the axis.
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {3, 4, 5}},
       .attributes = {.indices = {.type = OperandDataType::kUint32,
                                  .dimensions = {3, 3, 5}},
                      .axis = 2},
       .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 3, 5}},
       .expected = false},
      // Float16 indices.
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {3, 4, 5}},
       .attributes = {.indices = {.type = OperandDataType::kFloat16,
                                  .dimensions = {3, 4, 5}},
                      .axis = 0},
       .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 4, 5}},
       .expected = false},
      // The output shape does not match the {3, 1, 5} indices shape.
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {3, 4, 5}},
       .attributes = {.indices = {.type = OperandDataType::kUint32,
                                  .dimensions = {3, 1, 5}},
                      .axis = 1},
       .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 4, 5}},
       .expected = false},
      // The output data type differs from the input data type.
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {3, 4, 5}},
       .attributes = {.indices = {.type = OperandDataType::kUint32,
                                  .dimensions = {3, 1, 5}},
                      .axis = 1},
       .output = {.type = OperandDataType::kFloat16, .dimensions = {3, 1, 5}},
       .expected = false},
  };
  for (GatherElementsTester& tester : cases) {
    tester.Test(*this);
  }

  {
    // Using the input operand as the gatherElements output is rejected.
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", {2, 3}, OperandDataType::kFloat32);
    OperandId indices_operand_id =
        builder.BuildInput("indices", {2, 3}, OperandDataType::kUint32);
    builder.BuildGatherElements(input_operand_id, indices_operand_id,
                                input_operand_id, 0);
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
  {
    // Using the indices operand as the gatherElements output is rejected.
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", {3}, OperandDataType::kUint32);
    OperandId indices_operand_id =
        builder.BuildInput("indices", {3}, OperandDataType::kUint32);
    builder.BuildGatherElements(input_operand_id, indices_operand_id,
                                indices_operand_id, 0);
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
}
struct GatherNDTester {
OperandInfo input;
OperandInfo indices;
OperandInfo output;
bool expected;
void Test(WebNNGraphImplTest& test) {
auto context_properties = GetContextPropertiesForTesting();
mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
test.BindNewGraphBuilderRemote();
GraphInfoBuilder builder(remote);
OperandId input_operand_id =
builder.BuildInput("input", input.dimensions, input.type);
OperandId indices_operand_id =
builder.BuildInput("indices", indices.dimensions, indices.type);
OperandId output_operand_id =
builder.BuildOutput("output", output.dimensions, output.type);
builder.BuildGatherND(input_operand_id, indices_operand_id,
output_operand_id);
EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
}
};
TEST_F(WebNNGraphImplTest, GatherNDTest) {
  // Cases run in the order listed.
  GatherNDTester cases[] = {
      // Indices {3, 7, 2} into {3, 4, 5, 6} give a {3, 7, 5, 6} output.
      {.input = {.type = OperandDataType::kFloat32,
                 .dimensions = {3, 4, 5, 6}},
       .indices = {.type = OperandDataType::kUint32, .dimensions = {3, 7, 2}},
       .output = {.type = OperandDataType::kFloat32,
                  .dimensions = {3, 7, 5, 6}},
       .expected = true},
      // Scalar input.
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {}},
       .indices = {.type = OperandDataType::kUint32, .dimensions = {1, 2}},
       .output = {.type = OperandDataType::kFloat32, .dimensions = {}},
       .expected = false},
      // Scalar indices.
      {.input = {.type = OperandDataType::kFloat32,
                 .dimensions = {1, 2, 3, 4, 5}},
       .indices = {.type = OperandDataType::kUint32, .dimensions = {}},
       .output = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 2, 3, 4, 5}},
       .expected = false},
      // The last indices dimension (4) exceeds the input rank (3).
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {1, 2, 3}},
       .indices = {.type = OperandDataType::kUint32, .dimensions = {1, 4}},
       .output = {.type = OperandDataType::kFloat32, .dimensions = {1, 2, 3}},
       .expected = false},
      // Output shape {1, 1, 3, 4} instead of {1, 2, 3, 4}.
      {.input = {.type = OperandDataType::kFloat32,
                 .dimensions = {1, 2, 3, 4}},
       .indices = {.type = OperandDataType::kUint32, .dimensions = {1, 1}},
       .output = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 1, 3, 4}},
       .expected = false},
      // The output data type differs from the input data type.
      {.input = {.type = OperandDataType::kFloat32,
                 .dimensions = {1, 2, 3, 4}},
       .indices = {.type = OperandDataType::kUint32, .dimensions = {1, 1}},
       .output = {.type = OperandDataType::kFloat16,
                  .dimensions = {1, 2, 3, 4}},
       .expected = false},
  };
  for (GatherNDTester& tester : cases) {
    tester.Test(*this);
  }

  {
    // Using the input operand as the gatherND output is rejected.
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", {2, 3}, OperandDataType::kUint32);
    OperandId indices_operand_id =
        builder.BuildInput("indices", {2, 1}, OperandDataType::kUint32);
    builder.BuildGatherND(input_operand_id, indices_operand_id,
                          input_operand_id);
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
  {
    // Using the indices operand as the gatherND output is rejected.
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", {2, 1}, OperandDataType::kUint32);
    OperandId indices_operand_id =
        builder.BuildInput("indices", {2, 1}, OperandDataType::kUint32);
    builder.BuildGatherND(input_operand_id, indices_operand_id,
                          indices_operand_id);
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
}
struct GeluTester {
OperandInfo input;
OperandInfo output;
bool expected;
void Test(WebNNGraphImplTest& test) {
auto context_properties = GetContextPropertiesForTesting();
mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
test.BindNewGraphBuilderRemote();
GraphInfoBuilder builder(remote);
OperandId input_operand_id =
builder.BuildInput("input", input.dimensions, input.type);
OperandId output_operand_id =
builder.BuildOutput("output", output.dimensions, output.type);
builder.BuildGelu(input_operand_id, output_operand_id);
EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
}
};
TEST_F(WebNNGraphImplTest, GeluTest) {
  // Cases run in the order listed.
  GeluTester cases[] = {
      // Float32 input and output of the same shape.
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {2, 6, 4}},
       .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 6, 4}},
       .expected = true},
      // An int32 scalar is rejected.
      {.input = {.type = OperandDataType::kInt32, .dimensions = {}},
       .output = {.type = OperandDataType::kInt32, .dimensions = {}},
       .expected = false},
      // The output shape differs from the input shape.
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {4, 2}},
       .output = {.type = OperandDataType::kFloat32, .dimensions = {2}},
       .expected = false},
      // The output data type differs from the input data type.
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {2}},
       .output = {.type = OperandDataType::kInt32, .dimensions = {2}},
       .expected = false},
  };
  for (GeluTester& tester : cases) {
    tester.Test(*this);
  }

  {
    // Using the input operand as the gelu output is rejected.
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", {1}, OperandDataType::kFloat16);
    builder.BuildGelu(input_operand_id, input_operand_id);
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
}
struct GemmTester {
OperandInfo a;
OperandInfo b;
std::optional<OperandInfo> c;
struct GemmAttributes {
std::optional<OperandId> c_operand_id;
float alpha = 1.0;
float beta = 1.0;
bool a_transpose = false;
bool b_transpose = false;
};
GemmAttributes attributes;
OperandInfo output;
bool expected;
void Test(WebNNGraphImplTest& test) {
auto context_properties = GetContextPropertiesForTesting();
mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
test.BindNewGraphBuilderRemote();
GraphInfoBuilder builder(remote);
OperandId a_operand_id = builder.BuildInput("a", a.dimensions, a.type);
OperandId b_operand_id = builder.BuildInput("b", b.dimensions, b.type);
OperandId output_operand_id =
builder.BuildOutput("output", output.dimensions, output.type);
if (c) {
attributes.c_operand_id = builder.BuildInput("c", c->dimensions, c->type);
}
builder.BuildGemm(a_operand_id, b_operand_id, output_operand_id,
std::move(attributes));
EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
}
};
TEST_F(WebNNGraphImplTest, GemmTest) {
  {
    // Default attributes: {2, 3} x {3, 4} -> {2, 4}.
    GemmTester{
        .a = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}},
        .b = {.type = OperandDataType::kFloat32, .dimensions = {3, 4}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 4}},
        .expected = true}
        .Test(*this);
  }
  {
    // a_transpose: a {2, 3} is used as {3, 2}, so {3, 2} x {2, 4} -> {3, 4}.
    GemmTester{
        .a = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}},
        .b = {.type = OperandDataType::kFloat32, .dimensions = {2, 4}},
        .attributes = {.a_transpose = true},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 4}},
        .expected = true}
        .Test(*this);
  }
  {
    // b_transpose: b {4, 3} is used as {3, 4}, so {2, 3} x {3, 4} -> {2, 4}.
    GemmTester{
        .a = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}},
        .b = {.type = OperandDataType::kFloat32, .dimensions = {4, 3}},
        .attributes = {.b_transpose = true},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 4}},
        .expected = true}
        .Test(*this);
  }
  {
    // An optional c of shape {4} with a {2, 4} output is accepted.
    GemmTester{
        .a = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}},
        .b = {.type = OperandDataType::kFloat32, .dimensions = {3, 4}},
        .c = OperandInfo{.type = OperandDataType::kFloat32, .dimensions = {4}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 4}},
        .expected = true}
        .Test(*this);
  }
  {
    // Without transposes, a's second dimension (3) does not match b's first
    // dimension (2).
    GemmTester{
        .a = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}},
        .b = {.type = OperandDataType::kFloat32, .dimensions = {2, 4}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 4}},
        .expected = false}
        .Test(*this);
  }
  {
    // A c of shape {2, 3} with a {2, 4} output is rejected.
    GemmTester{
        .a = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}},
        .b = {.type = OperandDataType::kFloat32, .dimensions = {3, 4}},
        .c = OperandInfo{.type = OperandDataType::kFloat32,
                         .dimensions = {2, 3}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 4}},
        .expected = false}
        .Test(*this);
  }
  {
    // c is int32 while a and b are float32. Shapes are otherwise valid:
    // transposed a {2, 3} x transposed b {3, 4} -> {2, 4}.
    GemmTester{
        .a = {.type = OperandDataType::kFloat32, .dimensions = {3, 2}},
        .b = {.type = OperandDataType::kFloat32, .dimensions = {4, 3}},
        .c = OperandInfo{.type = OperandDataType::kInt32, .dimensions = {2, 4}},
        .attributes = {.a_transpose = true, .b_transpose = true},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 4}},
        .expected = false}
        .Test(*this);
  }
  {
    // int32 a, b and output are rejected.
    GemmTester{
        .a = {.type = OperandDataType::kInt32, .dimensions = {2, 3}},
        .b = {.type = OperandDataType::kInt32, .dimensions = {3, 4}},
        .output = {.type = OperandDataType::kInt32, .dimensions = {2, 4}},
        .expected = false}
        .Test(*this);
  }
  {
    // The output shape {3, 4} does not match the expected {2, 4}. a and b
    // share a data type so that the output shape is the only fault. The
    // previous version made b int32, which duplicated the next case and
    // never isolated the shape check.
    GemmTester{
        .a = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}},
        .b = {.type = OperandDataType::kFloat32, .dimensions = {3, 4}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 4}},
        .expected = false}
        .Test(*this);
  }
  {
    // a is float32 but b is int32.
    GemmTester{
        .a = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}},
        .b = {.type = OperandDataType::kInt32, .dimensions = {3, 4}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 4}},
        .expected = false}
        .Test(*this);
  }
  {
    // The output data type differs from the float32 inputs.
    GemmTester{
        .a = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}},
        .b = {.type = OperandDataType::kFloat32, .dimensions = {3, 4}},
        .output = {.type = OperandDataType::kInt32, .dimensions = {2, 4}},
        .expected = false}
        .Test(*this);
  }
}
struct GruTester {
struct GruAttributes {
std::optional<OperandId> bias_operand_id;
std::optional<OperandId> recurrent_bias_operand_id;
std::optional<OperandId> initial_hidden_state_operand_id;
bool reset_after = true;
bool return_sequence = false;
mojom::RecurrentNetworkDirection direction =
mojom::RecurrentNetworkDirection::kForward;
mojom::GruWeightLayout layout = mojom::GruWeightLayout::kZrn;
std::vector<mojom::RecurrentNetworkActivation> activations = {
mojom::RecurrentNetworkActivation::kSigmoid,
mojom::RecurrentNetworkActivation::kTanh};
};
OperandInfo input;
OperandInfo weight;
OperandInfo recurrent_weight;
uint32_t steps;
uint32_t hidden_size;
std::optional<OperandInfo> bias;
std::optional<OperandInfo> recurrent_bias;
std::optional<OperandInfo> initial_hidden_state;
GruAttributes attributes;
std::vector<OperandInfo> outputs;
bool expected;
void Test(WebNNGraphImplTest& test) {
auto context_properties = GetContextPropertiesForTesting();
mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
test.BindNewGraphBuilderRemote();
GraphInfoBuilder builder(remote);
OperandId input_operand_id =
builder.BuildInput("input", input.dimensions, input.type);
OperandId weight_operand_id =
builder.BuildInput("weight", weight.dimensions, weight.type);
OperandId recurrent_weight_operand_id = builder.BuildInput(
"recurrentWeight", recurrent_weight.dimensions, recurrent_weight.type);
std::vector<OperandId> output_operand_ids;
output_operand_ids.reserve(outputs.size());
for (size_t i = 0; i < outputs.size(); ++i) {
output_operand_ids.push_back(
builder.BuildOutput(base::StringPrintf("output%zu", i),
outputs[i].dimensions, outputs[i].type));
}
if (bias.has_value()) {
attributes.bias_operand_id =
builder.BuildInput("bias", bias->dimensions, bias->type);
}
if (recurrent_bias.has_value()) {
attributes.recurrent_bias_operand_id = builder.BuildInput(
"recurrentBias", recurrent_bias->dimensions, recurrent_bias->type);
}
if (initial_hidden_state.has_value()) {
attributes.initial_hidden_state_operand_id = builder.BuildInput(
"initialHiddenState", initial_hidden_state->dimensions,
initial_hidden_state->type);
}
builder.BuildGru(input_operand_id, weight_operand_id,
recurrent_weight_operand_id, std::move(output_operand_ids),
steps, hidden_size, std::move(attributes));
EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
}
};
TEST_F(WebNNGraphImplTest, GruTest) {
  const uint32_t steps = 2;
  const uint32_t batch_size = 1;
  const uint32_t input_size = 3;
  const uint32_t hidden_size = 4;

  // Bidirectional gru that supplies every optional operand and, with
  // return_sequence set, requests two outputs. Expected to be valid.
  {
    const uint32_t both_directions = 2;
    GruTester{
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {steps, batch_size, input_size}},
        .weight = {.type = OperandDataType::kFloat32,
                   .dimensions = {both_directions, 3 * hidden_size,
                                  input_size}},
        .recurrent_weight = {.type = OperandDataType::kFloat32,
                             .dimensions = {both_directions, 3 * hidden_size,
                                            hidden_size}},
        .steps = steps,
        .hidden_size = hidden_size,
        .bias = OperandInfo{.type = OperandDataType::kFloat32,
                            .dimensions = {both_directions, 3 * hidden_size}},
        .recurrent_bias =
            OperandInfo{.type = OperandDataType::kFloat32,
                        .dimensions = {both_directions, 3 * hidden_size}},
        .initial_hidden_state =
            OperandInfo{.type = OperandDataType::kFloat32,
                        .dimensions = {both_directions, batch_size,
                                       hidden_size}},
        .attributes = {.reset_after = true,
                       .return_sequence = true,
                       .direction = mojom::RecurrentNetworkDirection::kBoth},
        .outputs = {{.type = OperandDataType::kFloat32,
                     .dimensions = {both_directions, batch_size, hidden_size}},
                    {.type = OperandDataType::kFloat32,
                     .dimensions = {steps, both_directions, batch_size,
                                    hidden_size}}},
        .expected = true}
        .Test(*this);
  }

  // The remaining cases use a single forward direction and no optional
  // operands. `forward_gru` returns that baseline with one output of shape
  // {1, batch_size, hidden_size}; each case breaks one aspect of it.
  const uint32_t one_direction = 1;
  const auto forward_gru = [&] {
    return GruTester{
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {steps, batch_size, input_size}},
        .weight = {.type = OperandDataType::kFloat32,
                   .dimensions = {one_direction, 3 * hidden_size, input_size}},
        .recurrent_weight = {.type = OperandDataType::kFloat32,
                             .dimensions = {one_direction, 3 * hidden_size,
                                            hidden_size}},
        .steps = steps,
        .hidden_size = hidden_size,
        .outputs = {{.type = OperandDataType::kFloat32,
                     .dimensions = {one_direction, batch_size, hidden_size}}},
        .expected = false};
  };

  // Weight rows sized as 4 * hidden_size instead of 3 * hidden_size.
  {
    GruTester tester = forward_gru();
    tester.weight.dimensions = {one_direction, 4 * hidden_size, input_size};
    tester.Test(*this);
  }
  // Output's last dimension is 3 * hidden_size instead of hidden_size.
  {
    GruTester tester = forward_gru();
    tester.outputs[0].dimensions = {one_direction, batch_size,
                                    3 * hidden_size};
    tester.Test(*this);
  }
  // A second output is supplied although return_sequence is left false.
  {
    GruTester tester = forward_gru();
    tester.outputs.push_back(
        OperandInfo{.type = OperandDataType::kFloat32,
                    .dimensions = {steps, one_direction, batch_size,
                                   hidden_size}});
    tester.Test(*this);
  }

  // Using the initial hidden state operand as the gru output is rejected.
  {
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id = builder.BuildInput(
        "input", {steps, batch_size, input_size}, OperandDataType::kFloat32);
    OperandId weight_operand_id = builder.BuildInput(
        "weight", {one_direction, 3 * hidden_size, input_size},
        OperandDataType::kFloat32);
    OperandId recurrent_weight_operand_id = builder.BuildInput(
        "recurrentWeight", {one_direction, 3 * hidden_size, hidden_size},
        OperandDataType::kFloat32);
    OperandId initial_hidden_state_operand_id = builder.BuildInput(
        "initialHiddenState", {one_direction, batch_size, hidden_size},
        OperandDataType::kFloat32);
    builder.BuildGru(
        input_operand_id, weight_operand_id, recurrent_weight_operand_id,
        {initial_hidden_state_operand_id}, steps, hidden_size,
        GruTester::GruAttributes{.initial_hidden_state_operand_id =
                                     initial_hidden_state_operand_id});
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
}
struct GruCellTester {
struct GruCellAttributes {
std::optional<OperandId> bias_operand_id;
std::optional<OperandId> recurrent_bias_operand_id;
bool reset_after = true;
mojom::GruWeightLayout layout = mojom::GruWeightLayout::kZrn;
std::vector<mojom::RecurrentNetworkActivation> activations = {
mojom::RecurrentNetworkActivation::kSigmoid,
mojom::RecurrentNetworkActivation::kTanh};
};
OperandInfo input;
OperandInfo weight;
OperandInfo recurrent_weight;
OperandInfo hidden_state;
uint32_t hidden_size;
std::optional<OperandInfo> bias;
std::optional<OperandInfo> recurrent_bias;
GruCellAttributes attributes;
OperandInfo output;
bool expected;
void Test(WebNNGraphImplTest& test) {
auto context_properties = GetContextPropertiesForTesting();
mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
test.BindNewGraphBuilderRemote();
GraphInfoBuilder builder(remote);
OperandId input_operand_id =
builder.BuildInput("input", input.dimensions, input.type);
OperandId weight_operand_id =
builder.BuildInput("weight", weight.dimensions, weight.type);
OperandId recurrent_weight_operand_id = builder.BuildInput(
"recurrentWeight", recurrent_weight.dimensions, recurrent_weight.type);
OperandId hidden_state_operand_id = builder.BuildInput(
"hiddenState", hidden_state.dimensions, hidden_state.type);
if (bias.has_value()) {
attributes.bias_operand_id =
builder.BuildInput("bias", bias->dimensions, bias->type);
}
if (recurrent_bias.has_value()) {
attributes.recurrent_bias_operand_id = builder.BuildInput(
"recurrentBias", recurrent_bias->dimensions, recurrent_bias->type);
}
OperandId output_operand_id =
builder.BuildOutput("output", output.dimensions, output.type);
builder.BuildGruCell(input_operand_id, weight_operand_id,
recurrent_weight_operand_id, hidden_state_operand_id,
output_operand_id, hidden_size, std::move(attributes));
EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
}
};
TEST_F(WebNNGraphImplTest, GruCellTest) {
  const uint32_t batch_size = 2;
  const uint32_t input_size = 4;
  const uint32_t hidden_size = 6;

  // Returns a gruCell configuration in which every operand is well-formed and
  // validation is expected to succeed.
  const auto make_valid_tester = [&] {
    return GruCellTester{
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {batch_size, input_size}},
        .weight = {.type = OperandDataType::kFloat32,
                   .dimensions = {3 * hidden_size, input_size}},
        .recurrent_weight = {.type = OperandDataType::kFloat32,
                             .dimensions = {3 * hidden_size, hidden_size}},
        .hidden_state = {.type = OperandDataType::kFloat32,
                         .dimensions = {batch_size, hidden_size}},
        .hidden_size = hidden_size,
        .bias = OperandInfo{.type = OperandDataType::kFloat32,
                            .dimensions = {3 * hidden_size}},
        .recurrent_bias = OperandInfo{.type = OperandDataType::kFloat32,
                                      .dimensions = {3 * hidden_size}},
        .attributes = {.reset_after = true},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {batch_size, hidden_size}},
        .expected = true};
  };

  // Starts from the valid configuration, lets `corrupt` break one piece of
  // it, and expects the resulting graph to be rejected.
  const auto expect_invalid = [&](auto corrupt) {
    GruCellTester tester = make_valid_tester();
    corrupt(tester);
    tester.expected = false;
    tester.Test(*this);
  };

  // The untouched baseline is accepted.
  make_valid_tester().Test(*this);

  // Input: wrong data type, batch size 1 instead of 2, rank 1.
  expect_invalid(
      [](GruCellTester& t) { t.input.type = OperandDataType::kInt8; });
  expect_invalid(
      [&](GruCellTester& t) { t.input.dimensions = {1, input_size}; });
  expect_invalid([&](GruCellTester& t) { t.input.dimensions = {input_size}; });

  // Weight: wrong data type, 4 * hidden_size rows, rank 1.
  expect_invalid(
      [](GruCellTester& t) { t.weight.type = OperandDataType::kInt8; });
  expect_invalid([&](GruCellTester& t) {
    t.weight.dimensions = {4 * hidden_size, input_size};
  });
  expect_invalid(
      [&](GruCellTester& t) { t.weight.dimensions = {3 * hidden_size}; });

  // Recurrent weight: wrong data type, input_size columns, rank 1.
  expect_invalid([](GruCellTester& t) {
    t.recurrent_weight.type = OperandDataType::kInt8;
  });
  expect_invalid([&](GruCellTester& t) {
    t.recurrent_weight.dimensions = {3 * hidden_size, input_size};
  });
  expect_invalid([&](GruCellTester& t) {
    t.recurrent_weight.dimensions = {3 * hidden_size};
  });

  // hidden_size attribute disagrees with the operand shapes.
  expect_invalid([](GruCellTester& t) { t.hidden_size = 1000; });

  // Bias: wrong data type, 4 * hidden_size elements, rank 2.
  expect_invalid(
      [](GruCellTester& t) { t.bias->type = OperandDataType::kUint8; });
  expect_invalid(
      [&](GruCellTester& t) { t.bias->dimensions = {4 * hidden_size}; });
  expect_invalid([&](GruCellTester& t) {
    t.bias->dimensions = {3 * hidden_size, hidden_size};
  });

  // Recurrent bias: wrong data type, 4 * hidden_size elements, rank 2.
  expect_invalid([](GruCellTester& t) {
    t.recurrent_bias->type = OperandDataType::kUint8;
  });
  expect_invalid([&](GruCellTester& t) {
    t.recurrent_bias->dimensions = {4 * hidden_size};
  });
  expect_invalid([&](GruCellTester& t) {
    t.recurrent_bias->dimensions = {3 * hidden_size, hidden_size};
  });

  // Output: wrong data type, 3 * hidden_size columns, rank 1.
  expect_invalid(
      [](GruCellTester& t) { t.output.type = OperandDataType::kInt32; });
  expect_invalid([&](GruCellTester& t) {
    t.output.dimensions = {batch_size, 3 * hidden_size};
  });
  expect_invalid(
      [&](GruCellTester& t) { t.output.dimensions = {hidden_size}; });

  // Using the hidden state operand as the gruCell output is rejected.
  {
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id = builder.BuildInput(
        "input", {batch_size, input_size}, OperandDataType::kFloat32);
    OperandId weight_operand_id = builder.BuildInput(
        "weight", {3 * hidden_size, input_size}, OperandDataType::kFloat32);
    OperandId recurrent_weight_operand_id =
        builder.BuildInput("recurrentWeight", {3 * hidden_size, hidden_size},
                           OperandDataType::kFloat32);
    OperandId hidden_state_operand_id = builder.BuildInput(
        "hiddenState", {batch_size, hidden_size}, OperandDataType::kFloat32);
    builder.BuildGruCell(input_operand_id, weight_operand_id,
                         recurrent_weight_operand_id, hidden_state_operand_id,
                         hidden_state_operand_id, hidden_size,
                         GruCellTester::GruCellAttributes{.reset_after = true});
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
}
struct InstanceNormalizationTester {
OperandInfo input;
std::optional<OperandInfo> scale;
std::optional<OperandInfo> bias;
struct InstanceNormalizationAttributes {
std::optional<OperandId> scale_operand_id;
std::optional<OperandId> bias_operand_id;
float epsilon = 1e-5;
};
InstanceNormalizationAttributes attributes;
InputOperandLayout input_operand_layout = InputOperandLayout::kNchw;
OperandInfo output;
bool expected;
void Test(WebNNGraphImplTest& test) {
auto context_properties = GetContextPropertiesForTesting();
context_properties.input_operand_layout = input_operand_layout;
mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
test.BindNewGraphBuilderRemote();
GraphInfoBuilder builder(remote);
OperandId input_operand_id =
builder.BuildInput("input", input.dimensions, input.type);
OperandId output_operand_id =
builder.BuildOutput("output", output.dimensions, output.type);
if (scale) {
attributes.scale_operand_id =
builder.BuildInput("scale", scale->dimensions, scale->type);
}
if (bias) {
attributes.bias_operand_id =
builder.BuildInput("bias", bias->dimensions, bias->type);
}
builder.BuildInstanceNormalization(input_operand_id, output_operand_id,
std::move(attributes));
EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
}
};
TEST_F(WebNNGraphImplTest, InstanceNormalizationTest) {
  // Operand descriptions shared by the cases below. `float_4d` is the input
  // for most cases; dimension 1 has size 2 and the last dimension has size 3,
  // which is what the 1-D scale/bias operands are sized against under NCHW and
  // NHWC respectively in these cases.
  const OperandInfo float_4d = {.type = OperandDataType::kFloat32,
                                .dimensions = {1, 2, 3, 3}};
  const OperandInfo float_of_2 = {.type = OperandDataType::kFloat32,
                                  .dimensions = {2}};
  const OperandInfo float_of_3 = {.type = OperandDataType::kFloat32,
                                  .dimensions = {3}};
  const OperandInfo int_of_2 = {.type = OperandDataType::kInt32,
                                .dimensions = {2}};

  // Valid: no scale or bias.
  InstanceNormalizationTester{
      .input = float_4d, .output = float_4d, .expected = true}
      .Test(*this);
  // Valid: scale and bias of size 3 under NHWC.
  InstanceNormalizationTester{.input = float_4d,
                              .scale = float_of_3,
                              .bias = float_of_3,
                              .input_operand_layout = InputOperandLayout::kNhwc,
                              .output = float_4d,
                              .expected = true}
      .Test(*this);
  // Valid: scale and bias of size 2 under the default NCHW layout.
  InstanceNormalizationTester{.input = float_4d,
                              .scale = float_of_2,
                              .bias = float_of_2,
                              .output = float_4d,
                              .expected = true}
      .Test(*this);
  // Invalid: int32 scale.
  InstanceNormalizationTester{.input = float_4d,
                              .scale = int_of_2,
                              .output = float_4d,
                              .expected = false}
      .Test(*this);
  // Invalid: scale of size 3 under NCHW.
  InstanceNormalizationTester{.input = float_4d,
                              .scale = float_of_3,
                              .output = float_4d,
                              .expected = false}
      .Test(*this);
  // Invalid: int32 bias.
  InstanceNormalizationTester{.input = float_4d,
                              .bias = int_of_2,
                              .output = float_4d,
                              .expected = false}
      .Test(*this);
  // Invalid: bias of size 2 under NHWC.
  InstanceNormalizationTester{.input = float_4d,
                              .bias = float_of_2,
                              .input_operand_layout = InputOperandLayout::kNhwc,
                              .output = float_4d,
                              .expected = false}
      .Test(*this);
  // Invalid: output data type differs from the input.
  InstanceNormalizationTester{.input = float_4d,
                              .output = {.type = OperandDataType::kInt32,
                                         .dimensions = {1, 2, 3, 3}},
                              .expected = false}
      .Test(*this);
  // Invalid: output shape differs from the input.
  InstanceNormalizationTester{.input = float_4d,
                              .output = {.type = OperandDataType::kFloat32,
                                         .dimensions = {1, 1, 3, 3}},
                              .expected = false}
      .Test(*this);
  // Invalid: rank-3 input.
  InstanceNormalizationTester{
      .input = {.type = OperandDataType::kFloat32, .dimensions = {1, 2, 3}},
      .output = {.type = OperandDataType::kFloat32, .dimensions = {1, 2, 3}},
      .expected = false}
      .Test(*this);

  // Invalid: the input operand doubles as the output.
  {
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", {1, 2, 3, 4}, OperandDataType::kFloat32);
    builder.BuildInstanceNormalization(
        input_operand_id, input_operand_id,
        InstanceNormalizationTester::InstanceNormalizationAttributes{});
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
  // Invalid: the scale operand doubles as the output.
  {
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", {1, 2, 3, 4}, OperandDataType::kFloat32);
    OperandId scale_operand_id =
        builder.BuildInput("scale", {2}, OperandDataType::kFloat32);
    InstanceNormalizationTester::InstanceNormalizationAttributes attributes;
    attributes.scale_operand_id = scale_operand_id;
    builder.BuildInstanceNormalization(input_operand_id, scale_operand_id,
                                       std::move(attributes));
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
  // Invalid: the bias operand doubles as the output.
  {
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", {1, 2, 3, 4}, OperandDataType::kFloat32);
    OperandId bias_operand_id =
        builder.BuildInput("bias", {2}, OperandDataType::kFloat32);
    InstanceNormalizationTester::InstanceNormalizationAttributes attributes;
    attributes.bias_operand_id = bias_operand_id;
    builder.BuildInstanceNormalization(input_operand_id, bias_operand_id,
                                       std::move(attributes));
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
}
struct LayerNormalizationTester {
OperandInfo input;
std::optional<OperandInfo> scale;
std::optional<OperandInfo> bias;
struct LayerNormalizationAttributes {
std::optional<OperandId> scale_operand_id;
std::optional<OperandId> bias_operand_id;
std::vector<uint32_t> axes;
float epsilon = 1e-5;
};
LayerNormalizationAttributes attributes;
OperandInfo output;
bool expected;
void Test(WebNNGraphImplTest& test) {
auto context_properties = GetContextPropertiesForTesting();
mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
test.BindNewGraphBuilderRemote();
GraphInfoBuilder builder(remote);
OperandId input_operand_id =
builder.BuildInput("input", input.dimensions, input.type);
OperandId output_operand_id =
builder.BuildOutput("output", output.dimensions, output.type);
if (scale.has_value()) {
attributes.scale_operand_id =
builder.BuildInput("scale", scale->dimensions, scale->type);
}
if (bias.has_value()) {
attributes.bias_operand_id =
builder.BuildInput("bias", bias->dimensions, bias->type);
}
builder.BuildLayerNormalization(input_operand_id, output_operand_id,
std::move(attributes));
EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
}
};
TEST_F(WebNNGraphImplTest, LayerNormalizationTest) {
  // Operand descriptions reused across the cases below.
  const OperandInfo float32_scalar = {.type = OperandDataType::kFloat32,
                                      .dimensions = {}};
  const OperandInfo float32_2d = {.type = OperandDataType::kFloat32,
                                  .dimensions = {1, 2}};
  const OperandInfo float16_4d = {.type = OperandDataType::kFloat16,
                                  .dimensions = {1, 2, 3, 4}};
  // Matches the trailing two dimensions of `float16_4d`.
  const OperandInfo float16_trailing = {.type = OperandDataType::kFloat16,
                                        .dimensions = {3, 4}};

  // Valid: scalar input with no axes.
  LayerNormalizationTester{.input = float32_scalar,
                           .attributes = {.axes = {}},
                           .output = float32_scalar,
                           .expected = true}
      .Test(*this);
  // Valid: float16 input normalized over the last two axes with scale/bias.
  LayerNormalizationTester{.input = float16_4d,
                           .scale = float16_trailing,
                           .bias = float16_trailing,
                           .attributes = {.axes = {2, 3}},
                           .output = float16_4d,
                           .expected = true}
      .Test(*this);
  // Invalid: axis 0 on a scalar.
  LayerNormalizationTester{.input = float32_scalar,
                           .attributes = {.axes = {0}},
                           .output = float32_scalar,
                           .expected = false}
      .Test(*this);
  // Invalid: int64 input.
  LayerNormalizationTester{
      .input = {.type = OperandDataType::kInt64, .dimensions = {1}},
      .attributes = {.axes = {}},
      .output = {.type = OperandDataType::kInt64, .dimensions = {1}},
      .expected = false}
      .Test(*this);
  // Invalid: duplicated axis.
  LayerNormalizationTester{.input = float32_2d,
                           .attributes = {.axes = {0, 0}},
                           .output = float32_2d,
                           .expected = false}
      .Test(*this);
  // Invalid: axis equal to the input rank.
  LayerNormalizationTester{.input = float32_2d,
                           .attributes = {.axes = {2}},
                           .output = float32_2d,
                           .expected = false}
      .Test(*this);
  // Invalid: float32 bias with a float16 input.
  LayerNormalizationTester{.input = float16_4d,
                           .bias = OperandInfo{.type = OperandDataType::kFloat32,
                                               .dimensions = {3, 4}},
                           .attributes = {.axes = {2, 3}},
                           .output = float16_4d,
                           .expected = false}
      .Test(*this);
  // Invalid: scale shape does not match the normalized axes.
  LayerNormalizationTester{
      .input = float16_4d,
      .scale = OperandInfo{.type = OperandDataType::kFloat16,
                           .dimensions = {2, 3}},
      .attributes = {.axes = {2, 3}},
      .output = float16_4d,
      .expected = false}
      .Test(*this);
  // Invalid: output shape differs from the input.
  LayerNormalizationTester{.input = float16_4d,
                           .attributes = {.axes = {}},
                           .output = {.type = OperandDataType::kFloat16,
                                      .dimensions = {1, 2, 3, 3}},
                           .expected = false}
      .Test(*this);
  // Invalid: output data type differs from the input.
  LayerNormalizationTester{.input = float16_4d,
                           .attributes = {.axes = {}},
                           .output = {.type = OperandDataType::kFloat32,
                                      .dimensions = {1, 2, 3, 4}},
                           .expected = false}
      .Test(*this);

  // Invalid: the input operand doubles as the output.
  {
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", {1, 2, 3, 4}, OperandDataType::kFloat32);
    builder.BuildLayerNormalization(
        input_operand_id, input_operand_id,
        LayerNormalizationTester::LayerNormalizationAttributes{});
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
  // Invalid: the scale operand doubles as the output.
  {
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", {1, 2, 3, 4}, OperandDataType::kFloat32);
    OperandId scale_operand_id =
        builder.BuildInput("scale", {1, 2, 3, 4}, OperandDataType::kFloat32);
    LayerNormalizationTester::LayerNormalizationAttributes attributes;
    attributes.scale_operand_id = scale_operand_id;
    attributes.axes = {0, 1, 2, 3};
    builder.BuildLayerNormalization(input_operand_id, scale_operand_id,
                                    std::move(attributes));
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
  // Invalid: the bias operand doubles as the output.
  {
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", {1, 2, 3, 4}, OperandDataType::kFloat32);
    OperandId bias_operand_id =
        builder.BuildInput("bias", {1, 2, 3, 4}, OperandDataType::kFloat32);
    LayerNormalizationTester::LayerNormalizationAttributes attributes;
    attributes.bias_operand_id = bias_operand_id;
    attributes.axes = {0, 1, 2, 3};
    builder.BuildLayerNormalization(input_operand_id, bias_operand_id,
                                    std::move(attributes));
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
}
struct LstmTester {
struct LstmAttributes {
std::optional<OperandId> bias_operand_id;
std::optional<OperandId> recurrent_bias_operand_id;
std::optional<OperandId> peephole_weight_operand_id;
std::optional<OperandId> initial_hidden_state_operand_id;
std::optional<OperandId> initial_cell_state_operand_id;
bool return_sequence = false;
mojom::RecurrentNetworkDirection direction =
mojom::RecurrentNetworkDirection::kForward;
mojom::LstmWeightLayout layout = mojom::LstmWeightLayout::kIofg;
std::vector<mojom::RecurrentNetworkActivation> activations = {
mojom::RecurrentNetworkActivation::kSigmoid,
mojom::RecurrentNetworkActivation::kTanh,
mojom::RecurrentNetworkActivation::kTanh};
};
OperandInfo input;
OperandInfo weight;
OperandInfo recurrent_weight;
uint32_t steps;
uint32_t hidden_size;
std::optional<OperandInfo> bias;
std::optional<OperandInfo> recurrent_bias;
std::optional<OperandInfo> peephole_weight;
std::optional<OperandInfo> initial_hidden_state;
std::optional<OperandInfo> initial_cell_state;
LstmAttributes attributes;
std::vector<OperandInfo> outputs;
bool expected;
void Test(WebNNGraphImplTest& test) {
auto context_properties = GetContextPropertiesForTesting();
mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
test.BindNewGraphBuilderRemote();
GraphInfoBuilder builder(remote);
OperandId input_operand_id =
builder.BuildInput("input", input.dimensions, input.type);
OperandId weight_operand_id =
builder.BuildInput("weight", weight.dimensions, weight.type);
OperandId recurrent_weight_operand_id = builder.BuildInput(
"recurrentWeight", recurrent_weight.dimensions, recurrent_weight.type);
std::vector<OperandId> output_operand_ids;
output_operand_ids.reserve(outputs.size());
for (size_t i = 0; i < outputs.size(); ++i) {
output_operand_ids.push_back(
builder.BuildOutput(base::StringPrintf("output%zu", i),
outputs[i].dimensions, outputs[i].type));
}
if (bias.has_value()) {
attributes.bias_operand_id =
builder.BuildInput("bias", bias->dimensions, bias->type);
}
if (recurrent_bias.has_value()) {
attributes.recurrent_bias_operand_id = builder.BuildInput(
"recurrentBias", recurrent_bias->dimensions, recurrent_bias->type);
}
if (peephole_weight.has_value()) {
attributes.peephole_weight_operand_id = builder.BuildInput(
"peepholeWeight", peephole_weight->dimensions, peephole_weight->type);
}
if (initial_hidden_state.has_value()) {
attributes.initial_hidden_state_operand_id = builder.BuildInput(
"initialHiddenState", initial_hidden_state->dimensions,
initial_hidden_state->type);
}
if (initial_cell_state.has_value()) {
attributes.initial_cell_state_operand_id =
builder.BuildInput("initialCellState", initial_cell_state->dimensions,
initial_cell_state->type);
}
builder.BuildLstm(input_operand_id, weight_operand_id,
recurrent_weight_operand_id,
std::move(output_operand_ids), steps, hidden_size,
std::move(attributes));
EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
}
};
TEST_F(WebNNGraphImplTest, LstmTest) {
  // Sizes shared by the cases below.
  constexpr uint32_t kSteps = 2;
  constexpr uint32_t kBatchSize = 1;
  constexpr uint32_t kInputSize = 3;
  constexpr uint32_t kHiddenSize = 4;
  constexpr uint32_t kOneDirection = 1;
  constexpr uint32_t kBothDirections = 2;

  // Valid: a bidirectional lstm that supplies every optional operand and,
  // with return_sequence set, produces a third
  // [steps, directions, batch, hidden] output.
  {
    const OperandInfo state = {
        .type = OperandDataType::kFloat32,
        .dimensions = {kBothDirections, kBatchSize, kHiddenSize}};
    LstmTester{
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {kSteps, kBatchSize, kInputSize}},
        .weight = {.type = OperandDataType::kFloat32,
                   .dimensions = {kBothDirections, 4 * kHiddenSize,
                                  kInputSize}},
        .recurrent_weight = {.type = OperandDataType::kFloat32,
                             .dimensions = {kBothDirections, 4 * kHiddenSize,
                                            kHiddenSize}},
        .steps = kSteps,
        .hidden_size = kHiddenSize,
        .bias = OperandInfo{.type = OperandDataType::kFloat32,
                            .dimensions = {kBothDirections, 4 * kHiddenSize}},
        .recurrent_bias =
            OperandInfo{.type = OperandDataType::kFloat32,
                        .dimensions = {kBothDirections, 4 * kHiddenSize}},
        .peephole_weight =
            OperandInfo{.type = OperandDataType::kFloat32,
                        .dimensions = {kBothDirections, 3 * kHiddenSize}},
        .initial_hidden_state = state,
        .initial_cell_state = state,
        .attributes = {.return_sequence = true,
                       .direction = mojom::RecurrentNetworkDirection::kBoth},
        .outputs = {state, state,
                    {.type = OperandDataType::kFloat32,
                     .dimensions = {kSteps, kBothDirections, kBatchSize,
                                    kHiddenSize}}},
        .expected = true}
        .Test(*this);
  }

  // Invalid: the weight's last dimension (1000) does not match kInputSize.
  {
    const OperandInfo state = {
        .type = OperandDataType::kFloat32,
        .dimensions = {kOneDirection, kBatchSize, kHiddenSize}};
    LstmTester{
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {kSteps, kBatchSize, kInputSize}},
        .weight = {.type = OperandDataType::kFloat32,
                   .dimensions = {kOneDirection, 4 * kHiddenSize, 1000}},
        .recurrent_weight = {.type = OperandDataType::kFloat32,
                             .dimensions = {kOneDirection, 4 * kHiddenSize,
                                            kHiddenSize}},
        .steps = kSteps,
        .hidden_size = kHiddenSize,
        .outputs = {state, state},
        .expected = false}
        .Test(*this);
  }

  // Invalid: the second output's last dimension (1000) is not kHiddenSize.
  {
    const OperandInfo state = {
        .type = OperandDataType::kFloat32,
        .dimensions = {kOneDirection, kBatchSize, kHiddenSize}};
    LstmTester{
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {kSteps, kBatchSize, kInputSize}},
        .weight = {.type = OperandDataType::kFloat32,
                   .dimensions = {kOneDirection, 4 * kHiddenSize,
                                  kInputSize}},
        .recurrent_weight = {.type = OperandDataType::kFloat32,
                             .dimensions = {kOneDirection, 4 * kHiddenSize,
                                            kHiddenSize}},
        .steps = kSteps,
        .hidden_size = kHiddenSize,
        .outputs = {state,
                    {.type = OperandDataType::kFloat32,
                     .dimensions = {kOneDirection, kBatchSize, 1000}}},
        .expected = false}
        .Test(*this);
  }

  // Invalid: the recurrent weight input is also listed as an lstm output.
  {
    constexpr uint32_t kLargeBatchSize = 16;
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    const OperandId input_id = builder.BuildInput(
        "input", {kSteps, kLargeBatchSize, kInputSize},
        OperandDataType::kFloat32);
    const OperandId weight_id = builder.BuildInput(
        "weight", {kOneDirection, 4 * kHiddenSize, kInputSize},
        OperandDataType::kFloat32);
    const OperandId recurrent_weight_id = builder.BuildInput(
        "recurrentWeight", {kOneDirection, 4 * kHiddenSize, kHiddenSize},
        OperandDataType::kFloat32);
    const OperandId output_id = builder.BuildOutput(
        "output", {kOneDirection, kLargeBatchSize, kHiddenSize},
        OperandDataType::kFloat32);
    builder.BuildLstm(input_id, weight_id, recurrent_weight_id,
                      {output_id, recurrent_weight_id}, kSteps, kHiddenSize,
                      LstmTester::LstmAttributes{});
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }

  // Invalid: the initial cell state input is also listed as an lstm output.
  {
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    const OperandId input_id = builder.BuildInput(
        "input", {kSteps, kBatchSize, kInputSize}, OperandDataType::kFloat32);
    const OperandId weight_id = builder.BuildInput(
        "weight", {kOneDirection, 4 * kHiddenSize, kInputSize},
        OperandDataType::kFloat32);
    const OperandId recurrent_weight_id = builder.BuildInput(
        "recurrentWeight", {kOneDirection, 4 * kHiddenSize, kHiddenSize},
        OperandDataType::kFloat32);
    const OperandId initial_cell_state_id = builder.BuildInput(
        "initialCellState", {kOneDirection, kBatchSize, kHiddenSize},
        OperandDataType::kFloat32);
    const OperandId output_id = builder.BuildOutput(
        "output", {kOneDirection, kBatchSize, kHiddenSize},
        OperandDataType::kFloat32);
    builder.BuildLstm(
        input_id, weight_id, recurrent_weight_id,
        {initial_cell_state_id, output_id}, kSteps, kHiddenSize,
        LstmTester::LstmAttributes{.initial_cell_state_operand_id =
                                       initial_cell_state_id});
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
}
struct LstmCellTester {
struct LstmCellAttributes {
std::optional<OperandId> bias_operand_id;
std::optional<OperandId> recurrent_bias_operand_id;
std::optional<OperandId> peephole_weight_operand_id;
mojom::LstmWeightLayout layout = mojom::LstmWeightLayout::kIofg;
std::vector<mojom::RecurrentNetworkActivation> activations = {
mojom::RecurrentNetworkActivation::kSigmoid,
mojom::RecurrentNetworkActivation::kTanh,
mojom::RecurrentNetworkActivation::kTanh};
};
OperandInfo input;
OperandInfo weight;
OperandInfo recurrent_weight;
OperandInfo hidden_state;
OperandInfo cell_state;
uint32_t hidden_size;
std::optional<OperandInfo> bias;
std::optional<OperandInfo> recurrent_bias;
std::optional<OperandInfo> peephole_weight;
LstmCellAttributes attributes;
std::vector<OperandInfo> outputs;
bool expected;
void Test(WebNNGraphImplTest& test) {
auto context_properties = GetContextPropertiesForTesting();
mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
test.BindNewGraphBuilderRemote();
GraphInfoBuilder builder(remote);
OperandId input_operand_id =
builder.BuildInput("input", input.dimensions, input.type);
OperandId weight_operand_id =
builder.BuildInput("weight", weight.dimensions, weight.type);
OperandId recurrent_weight_operand_id = builder.BuildInput(
"recurrentWeight", recurrent_weight.dimensions, recurrent_weight.type);
OperandId hidden_state_operand_id = builder.BuildInput(
"hiddenState", hidden_state.dimensions, hidden_state.type);
OperandId cell_state_operand_id =
builder.BuildInput("cellState", cell_state.dimensions, cell_state.type);
std::vector<OperandId> output_operand_ids;
output_operand_ids.reserve(outputs.size());
for (size_t i = 0; i < outputs.size(); ++i) {
output_operand_ids.push_back(
builder.BuildOutput(base::StringPrintf("output%zu", i),
outputs[i].dimensions, outputs[i].type));
}
if (bias.has_value()) {
attributes.bias_operand_id =
builder.BuildInput("bias", bias->dimensions, bias->type);
}
if (recurrent_bias.has_value()) {
attributes.recurrent_bias_operand_id = builder.BuildInput(
"recurrentBias", recurrent_bias->dimensions, recurrent_bias->type);
}
if (peephole_weight.has_value()) {
attributes.peephole_weight_operand_id = builder.BuildInput(
"peepholeWeight", peephole_weight->dimensions, peephole_weight->type);
}
builder.BuildLstmCell(input_operand_id, weight_operand_id,
recurrent_weight_operand_id, hidden_state_operand_id,
cell_state_operand_id, std::move(output_operand_ids),
hidden_size, std::move(attributes));
EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
}
};
TEST_F(WebNNGraphImplTest, LstmCellTest) {
  const uint32_t batch_size = 15;
  const uint32_t input_size = 12;
  const uint32_t hidden_size = 20;
  // Descriptions that together form a valid float32 lstmCell. Each invalid
  // case below swaps one of them for a bad description.
  const OperandInfo valid_input = {.type = OperandDataType::kFloat32,
                                   .dimensions = {batch_size, input_size}};
  const OperandInfo valid_weight = {
      .type = OperandDataType::kFloat32,
      .dimensions = {4 * hidden_size, input_size}};
  const OperandInfo valid_recurrent_weight = {
      .type = OperandDataType::kFloat32,
      .dimensions = {4 * hidden_size, hidden_size}};
  const OperandInfo valid_hidden_state = {
      .type = OperandDataType::kFloat32,
      .dimensions = {batch_size, hidden_size}};
  const OperandInfo valid_cell_state = {
      .type = OperandDataType::kFloat32,
      .dimensions = {batch_size, hidden_size}};
  const OperandInfo valid_bias = {.type = OperandDataType::kFloat32,
                                  .dimensions = {4 * hidden_size}};
  const OperandInfo valid_recurrent_bias = {.type = OperandDataType::kFloat32,
                                            .dimensions = {4 * hidden_size}};
  const OperandInfo valid_peephole_weight = {
      .type = OperandDataType::kFloat32, .dimensions = {3 * hidden_size}};
  const std::vector<OperandInfo> valid_outputs = {
      {.type = OperandDataType::kFloat32,
       .dimensions = {batch_size, hidden_size}},
      {.type = OperandDataType::kFloat32,
       .dimensions = {batch_size, hidden_size}}};

  // Valid: every required and optional operand is well formed.
  {
    LstmCellTester{.input = valid_input,
                   .weight = valid_weight,
                   .recurrent_weight = valid_recurrent_weight,
                   .hidden_state = valid_hidden_state,
                   .cell_state = valid_cell_state,
                   .hidden_size = hidden_size,
                   .bias = valid_bias,
                   .recurrent_bias = valid_recurrent_bias,
                   .peephole_weight = valid_peephole_weight,
                   .outputs = valid_outputs,
                   .expected = true}
        .Test(*this);
  }
  // Invalid: uint32 input.
  {
    LstmCellTester{.input = {.type = OperandDataType::kUint32,
                             .dimensions = {batch_size, input_size}},
                   .weight = valid_weight,
                   .recurrent_weight = valid_recurrent_weight,
                   .hidden_state = valid_hidden_state,
                   .cell_state = valid_cell_state,
                   .hidden_size = hidden_size,
                   .outputs = valid_outputs,
                   .expected = false}
        .Test(*this);
  }
  // Invalid: float16 weight while the other operands are float32.
  {
    LstmCellTester{.input = valid_input,
                   .weight = {.type = OperandDataType::kFloat16,
                              .dimensions = {4 * hidden_size, input_size}},
                   .recurrent_weight = valid_recurrent_weight,
                   .hidden_state = valid_hidden_state,
                   .cell_state = valid_cell_state,
                   .hidden_size = hidden_size,
                   .outputs = valid_outputs,
                   .expected = false}
        .Test(*this);
  }
  // Invalid: recurrent weight is rank 1.
  {
    LstmCellTester{.input = valid_input,
                   .weight = valid_weight,
                   .recurrent_weight = {.type = OperandDataType::kFloat32,
                                        .dimensions = {4 * hidden_size}},
                   .hidden_state = valid_hidden_state,
                   .cell_state = valid_cell_state,
                   .hidden_size = hidden_size,
                   .outputs = valid_outputs,
                   .expected = false}
        .Test(*this);
  }
  // Invalid: hidden state's last dimension (1000) is not hidden_size.
  {
    LstmCellTester{.input = valid_input,
                   .weight = valid_weight,
                   .recurrent_weight = valid_recurrent_weight,
                   .hidden_state = {.type = OperandDataType::kFloat32,
                                    .dimensions = {batch_size, 1000}},
                   .cell_state = valid_cell_state,
                   .hidden_size = hidden_size,
                   .outputs = valid_outputs,
                   .expected = false}
        .Test(*this);
  }
  // Invalid: cell state has an extra trailing dimension.
  {
    LstmCellTester{
        .input = valid_input,
        .weight = valid_weight,
        .recurrent_weight = valid_recurrent_weight,
        .hidden_state = valid_hidden_state,
        .cell_state = {.type = OperandDataType::kFloat32,
                       .dimensions = {batch_size, hidden_size, 1000}},
        .hidden_size = hidden_size,
        .outputs = valid_outputs,
        .expected = false}
        .Test(*this);
  }
  // Invalid: uint32 bias.
  {
    LstmCellTester{.input = valid_input,
                   .weight = valid_weight,
                   .recurrent_weight = valid_recurrent_weight,
                   .hidden_state = valid_hidden_state,
                   .cell_state = valid_cell_state,
                   .hidden_size = hidden_size,
                   .bias = OperandInfo{.type = OperandDataType::kUint32,
                                       .dimensions = {4 * hidden_size}},
                   .outputs = valid_outputs,
                   .expected = false}
        .Test(*this);
  }
  // Invalid: recurrent bias length (1000) is not 4 * hidden_size.
  {
    LstmCellTester{
        .input = valid_input,
        .weight = valid_weight,
        .recurrent_weight = valid_recurrent_weight,
        .hidden_state = valid_hidden_state,
        .cell_state = valid_cell_state,
        .hidden_size = hidden_size,
        .recurrent_bias = OperandInfo{.type = OperandDataType::kFloat32,
                                      .dimensions = {1000}},
        .outputs = valid_outputs,
        .expected = false}
        .Test(*this);
  }
  // Invalid: int64 peephole weight.
  {
    LstmCellTester{
        .input = valid_input,
        .weight = valid_weight,
        .recurrent_weight = valid_recurrent_weight,
        .hidden_state = valid_hidden_state,
        .cell_state = valid_cell_state,
        .hidden_size = hidden_size,
        .peephole_weight = OperandInfo{.type = OperandDataType::kInt64,
                                       .dimensions = {3 * hidden_size}},
        .outputs = valid_outputs,
        .expected = false}
        .Test(*this);
  }
  // Invalid: int8 outputs.
  {
    LstmCellTester{.input = valid_input,
                   .weight = valid_weight,
                   .recurrent_weight = valid_recurrent_weight,
                   .hidden_state = valid_hidden_state,
                   .cell_state = valid_cell_state,
                   .hidden_size = hidden_size,
                   .outputs = {{.type = OperandDataType::kInt8,
                                .dimensions = {batch_size, hidden_size}},
                               {.type = OperandDataType::kInt8,
                                .dimensions = {batch_size, hidden_size}}},
                   .expected = false}
        .Test(*this);
  }
  // Invalid: the cell state input is also listed as an lstmCell output.
  {
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id = builder.BuildInput(
        "input", {batch_size, input_size}, OperandDataType::kFloat32);
    OperandId weight_operand_id = builder.BuildInput(
        "weight", {4 * hidden_size, input_size}, OperandDataType::kFloat32);
    OperandId recurrent_weight_operand_id =
        builder.BuildInput("recurrentWeight", {4 * hidden_size, hidden_size},
                           OperandDataType::kFloat32);
    OperandId hidden_state_operand_id = builder.BuildInput(
        "hiddenState", {batch_size, hidden_size}, OperandDataType::kFloat32);
    OperandId cell_state_operand_id = builder.BuildInput(
        "cellState", {batch_size, hidden_size}, OperandDataType::kFloat32);
    OperandId output_operand_id = builder.BuildOutput(
        "output", {batch_size, hidden_size}, OperandDataType::kFloat32);
    // Use the lstmCell attribute type (not the lstm one) so this case stays
    // in sync with LstmCellTester if the two attribute structs diverge.
    builder.BuildLstmCell(input_operand_id, weight_operand_id,
                          recurrent_weight_operand_id, hidden_state_operand_id,
                          cell_state_operand_id,
                          {cell_state_operand_id, output_operand_id},
                          hidden_size, LstmCellTester::LstmCellAttributes{});
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
}
struct MatmulTester {
OperandInfo a;
OperandInfo b;
OperandInfo output;
bool expected;
void Test(WebNNGraphImplTest& test) {
auto context_properties = GetContextPropertiesForTesting();
mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
test.BindNewGraphBuilderRemote();
GraphInfoBuilder builder(remote);
OperandId a_operand_id = builder.BuildInput("a", a.dimensions, a.type);
OperandId b_operand_id = builder.BuildInput("b", b.dimensions, b.type);
OperandId output_operand_id =
builder.BuildOutput("output", output.dimensions, output.type);
builder.BuildMatmul(a_operand_id, b_operand_id, output_operand_id);
EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
}
};
TEST_F(WebNNGraphImplTest, MatmulTest) {
  std::vector<MatmulTester> cases = {
      // Valid: [2, 3] x [3, 4] -> [2, 4].
      MatmulTester{
          .a = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}},
          .b = {.type = OperandDataType::kFloat32, .dimensions = {3, 4}},
          .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 4}},
          .expected = true},
      // Valid: `a` is broadcast over the [2, 3] batch dimensions of `b`.
      MatmulTester{
          .a = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}},
          .b = {.type = OperandDataType::kFloat32, .dimensions = {2, 3, 3, 4}},
          .output = {.type = OperandDataType::kFloat32,
                     .dimensions = {2, 3, 2, 4}},
          .expected = true},
      // Valid: batch dimensions [2] and [3, 1] broadcast to [3, 2].
      MatmulTester{
          .a = {.type = OperandDataType::kFloat32, .dimensions = {2, 2, 3}},
          .b = {.type = OperandDataType::kFloat32, .dimensions = {3, 1, 3, 4}},
          .output = {.type = OperandDataType::kFloat32,
                     .dimensions = {3, 2, 2, 4}},
          .expected = true},
      // Invalid: `a` has rank 1.
      MatmulTester{
          .a = {.type = OperandDataType::kFloat32, .dimensions = {3}},
          .b = {.type = OperandDataType::kFloat32, .dimensions = {3, 4}},
          .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 4}},
          .expected = false},
      // Invalid: inner dimensions differ (2 vs 3).
      MatmulTester{
          .a = {.type = OperandDataType::kFloat32, .dimensions = {3, 2}},
          .b = {.type = OperandDataType::kFloat32, .dimensions = {3, 4}},
          .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 4}},
          .expected = false},
      // Invalid: batch dimensions 3 and 2 cannot be broadcast.
      MatmulTester{
          .a = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 3}},
          .b = {.type = OperandDataType::kFloat32, .dimensions = {2, 3, 4}},
          .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 4}},
          .expected = false},
      // Invalid: uint8 operands.
      MatmulTester{
          .a = {.type = OperandDataType::kUint8, .dimensions = {2, 3}},
          .b = {.type = OperandDataType::kUint8, .dimensions = {3, 4}},
          .output = {.type = OperandDataType::kUint8, .dimensions = {2, 4}},
          .expected = false},
      // Invalid: output shape should be [2, 4].
      MatmulTester{
          .a = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}},
          .b = {.type = OperandDataType::kFloat32, .dimensions = {3, 4}},
          .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 4}},
          .expected = false},
      // Invalid: `a` and `b` have different data types.
      MatmulTester{
          .a = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}},
          .b = {.type = OperandDataType::kInt32, .dimensions = {3, 4}},
          .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 4}},
          .expected = false},
      // Invalid: output data type differs from the inputs.
      MatmulTester{
          .a = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}},
          .b = {.type = OperandDataType::kFloat32, .dimensions = {3, 4}},
          .output = {.type = OperandDataType::kInt32, .dimensions = {2, 4}},
          .expected = false},
  };
  for (MatmulTester& tester : cases) {
    tester.Test(*this);
  }

  // Invalid: the matmul output is the graph input `a` itself.
  {
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    const OperandId lhs =
        builder.BuildInput("a", {2, 3}, OperandDataType::kFloat32);
    const OperandId rhs =
        builder.BuildInput("b", {3, 4}, OperandDataType::kFloat32);
    builder.BuildMatmul(lhs, rhs, lhs);
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
}
struct PadTester {
OperandInfo input;
std::vector<uint32_t> beginning_padding;
std::vector<uint32_t> ending_padding;
mojom::PaddingMode::Tag mode = mojom::PaddingMode::Tag::kConstant;
float value = 0;
OperandInfo output;
bool expected;
void Test(WebNNGraphImplTest& test) {
auto context_properties = GetContextPropertiesForTesting();
mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
test.BindNewGraphBuilderRemote();
GraphInfoBuilder builder(remote);
OperandId input_operand_id =
builder.BuildInput("input", input.dimensions, input.type);
OperandId output_operand_id =
builder.BuildOutput("output", output.dimensions, output.type);
builder.BuildPad(input_operand_id, output_operand_id, beginning_padding,
ending_padding, mode, value);
EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
}
};
TEST_F(WebNNGraphImplTest, PadTest) {
  const OperandInfo input_2x3 = {.type = OperandDataType::kFloat32,
                                 .dimensions = {2, 3}};
  // [2 + 1 + 1, 3 + 2 + 2] for the {1, 2} / {1, 2} padding used below.
  const OperandInfo output_4x7 = {.type = OperandDataType::kFloat32,
                                  .dimensions = {4, 7}};

  std::vector<PadTester> cases = {
      // Valid: constant padding with the default value.
      PadTester{.input = input_2x3,
                .beginning_padding = {1, 2},
                .ending_padding = {1, 2},
                .output = output_4x7,
                .expected = true},
      // Valid: edge padding.
      PadTester{.input = input_2x3,
                .beginning_padding = {1, 2},
                .ending_padding = {1, 2},
                .mode = mojom::PaddingMode::Tag::kEdge,
                .output = output_4x7,
                .expected = true},
      // Valid: constant padding with a non-zero value.
      PadTester{.input = input_2x3,
                .beginning_padding = {1, 2},
                .ending_padding = {1, 2},
                .value = 1,
                .output = output_4x7,
                .expected = true},
      // Valid: constant padding with NaN.
      PadTester{.input = input_2x3,
                .beginning_padding = {1, 2},
                .ending_padding = {1, 2},
                .value = NAN,
                .output = output_4x7,
                .expected = true},
      // Invalid: beginning_padding is shorter than the input rank.
      PadTester{.input = input_2x3,
                .beginning_padding = {1},
                .ending_padding = {1, 2},
                .output = output_4x7,
                .expected = false},
      // Invalid: ending_padding is longer than the input rank.
      PadTester{.input = input_2x3,
                .beginning_padding = {1, 0},
                .ending_padding = {1, 2, 0},
                .output = output_4x7,
                .expected = false},
  };
  for (PadTester& tester : cases) {
    tester.Test(*this);
  }

  // Invalid: the pad output is the graph input itself.
  {
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    const OperandId input_id =
        builder.BuildInput("input", {2, 3}, OperandDataType::kFloat32);
    builder.BuildPad(input_id, input_id, {1, 1}, {1, 1},
                     mojom::PaddingMode::Tag::kConstant, 0);
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
}
// Builds a graph holding a single pool2d operation and checks that the
// graph's validity matches `expected`.
struct Pool2dTester {
  OperandInfo input;
  struct Pool2dAttributes {
    std::vector<uint32_t> window_dimensions;
    std::vector<uint32_t> padding = {0, 0, 0, 0};
    std::vector<uint32_t> strides = {1, 1};
    std::vector<uint32_t> dilations = {1, 1};
  };
  Pool2dAttributes attributes;
  // Layout written into the context properties before validation.
  InputOperandLayout input_operand_layout = InputOperandLayout::kNchw;
  OperandInfo output;
  bool expected;

  // Runs the check for the average, l2 and max pool2d kinds in turn. All
  // three runs share this tester's `attributes`, so the single-kind overload
  // below must leave them intact.
  void Test(WebNNGraphImplTest& test) {
    Test(test, mojom::Pool2d::Kind::kAveragePool2d);
    Test(test, mojom::Pool2d::Kind::kL2Pool2d);
    Test(test, mojom::Pool2d::Kind::kMaxPool2d);
  }

  // Runs the check for a single pool2d `kind`.
  void Test(WebNNGraphImplTest& test, mojom::Pool2d::Kind kind) {
    auto context_properties = GetContextPropertiesForTesting();
    context_properties.input_operand_layout = input_operand_layout;
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        test.BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", input.dimensions, input.type);
    OperandId output_operand_id =
        builder.BuildOutput("output", output.dimensions, output.type);
    // Pass `attributes` as an lvalue rather than std::move(attributes). This
    // method can run several times on the same tester. If BuildPool2d ever
    // took its attributes by value, a move here would leave later runs with
    // moved-from (unspecified, typically empty) vectors.
    builder.BuildPool2d(kind, input_operand_id, output_operand_id, attributes);
    EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
  }
};
TEST_F(WebNNGraphImplTest, Pool2dTest) {
  // Pool2dTester::Test(*this) checks each of these for the average, l2 and
  // max pool2d kinds.
  std::vector<Pool2dTester> cases = {
      // Valid: a 1x1 window with stride 1 keeps the 4x4 spatial size.
      Pool2dTester{.input = {.type = OperandDataType::kFloat32,
                             .dimensions = {1, 3, 4, 4}},
                   .attributes = {.window_dimensions = {1, 1},
                                  .strides = {1, 1}},
                   .output = {.type = OperandDataType::kFloat32,
                              .dimensions = {1, 3, 4, 4}},
                   .expected = true},
      // Valid: (5 - 2) / 2 + 1 = 2.5; the rounded-up size 3 is accepted.
      Pool2dTester{.input = {.type = OperandDataType::kFloat32,
                             .dimensions = {1, 3, 5, 5}},
                   .attributes = {.window_dimensions = {2, 2},
                                  .strides = {2, 2}},
                   .output = {.type = OperandDataType::kFloat32,
                              .dimensions = {1, 3, 3, 3}},
                   .expected = true},
      // Valid: (7 + 1 + 1 - 4) / 2 + 1 = 3.5, rounded down to 3 (float16).
      Pool2dTester{.input = {.type = OperandDataType::kFloat16,
                             .dimensions = {1, 3, 7, 7}},
                   .attributes = {.window_dimensions = {4, 4},
                                  .padding = {1, 1, 1, 1},
                                  .strides = {2, 2}},
                   .output = {.type = OperandDataType::kFloat16,
                              .dimensions = {1, 3, 3, 3}},
                   .expected = true},
      // Valid: the same 3.5 rounded up to 4 (float32).
      Pool2dTester{.input = {.type = OperandDataType::kFloat32,
                             .dimensions = {1, 3, 7, 7}},
                   .attributes = {.window_dimensions = {4, 4},
                                  .padding = {1, 1, 1, 1},
                                  .strides = {2, 2}},
                   .output = {.type = OperandDataType::kFloat32,
                              .dimensions = {1, 3, 4, 4}},
                   .expected = true},
      // Valid: NHWC layout; a 5x5 input with a 3x3 window gives 3x3.
      Pool2dTester{.input = {.type = OperandDataType::kFloat16,
                             .dimensions = {1, 5, 5, 2}},
                   .attributes = {.window_dimensions = {3, 3},
                                  .strides = {1, 1}},
                   .input_operand_layout = InputOperandLayout::kNhwc,
                   .output = {.type = OperandDataType::kFloat16,
                              .dimensions = {1, 3, 3, 2}},
                   .expected = true},
      // Invalid: the input is rank 3.
      Pool2dTester{
          .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 5, 5}},
          .attributes = {.window_dimensions = {5, 5},
                         .padding = {2, 2, 2, 2},
                         .strides = {1, 1}},
          .output = {.type = OperandDataType::kFloat32,
                     .dimensions = {3, 5, 5}},
          .expected = false},
      // Invalid: zero window dimensions.
      Pool2dTester{.input = {.type = OperandDataType::kFloat32,
                             .dimensions = {1, 3, 4, 4}},
                   .attributes = {.window_dimensions = {0, 0},
                                  .strides = {1, 1}},
                   .output = {.type = OperandDataType::kFloat32,
                              .dimensions = {1, 3, 4, 4}},
                   .expected = false},
      // Invalid: zero strides.
      Pool2dTester{.input = {.type = OperandDataType::kFloat32,
                             .dimensions = {1, 3, 4, 4}},
                   .attributes = {.window_dimensions = {1, 1},
                                  .strides = {0, 0}},
                   .output = {.type = OperandDataType::kFloat32,
                              .dimensions = {1, 3, 4, 4}},
                   .expected = false},
      // Invalid: zero dilations.
      Pool2dTester{.input = {.type = OperandDataType::kFloat32,
                             .dimensions = {1, 3, 4, 4}},
                   .attributes = {.window_dimensions = {1, 1},
                                  .strides = {1, 1},
                                  .dilations = {0, 0}},
                   .output = {.type = OperandDataType::kFloat32,
                              .dimensions = {1, 3, 4, 4}},
                   .expected = false},
      // Invalid: the output has 2 channels while the input has 3.
      Pool2dTester{.input = {.type = OperandDataType::kFloat32,
                             .dimensions = {1, 3, 4, 4}},
                   .attributes = {.window_dimensions = {4, 4},
                                  .strides = {1, 1}},
                   .output = {.type = OperandDataType::kFloat32,
                              .dimensions = {1, 2, 1, 1}},
                   .expected = false},
      // Invalid: float16 output for a float32 input.
      Pool2dTester{.input = {.type = OperandDataType::kFloat32,
                             .dimensions = {1, 3, 4, 4}},
                   .attributes = {.window_dimensions = {4, 4},
                                  .strides = {1, 1}},
                   .output = {.type = OperandDataType::kFloat16,
                              .dimensions = {1, 3, 1, 1}},
                   .expected = false},
  };
  for (Pool2dTester& tester : cases) {
    tester.Test(*this);
  }

  // Invalid: int32 operands for average pooling only.
  Pool2dTester{
      .input = {.type = OperandDataType::kInt32, .dimensions = {1, 3, 4, 4}},
      .attributes = {.window_dimensions = {4, 4}, .strides = {1, 1}},
      .output = {.type = OperandDataType::kInt32, .dimensions = {1, 3, 1, 1}},
      .expected = false}
      .Test(*this, mojom::Pool2d::Kind::kAveragePool2d);

  // Invalid: int8 operands for l2 pooling only.
  Pool2dTester{
      .input = {.type = OperandDataType::kInt8, .dimensions = {1, 3, 4, 4}},
      .attributes = {.window_dimensions = {4, 4}, .strides = {1, 1}},
      .output = {.type = OperandDataType::kInt8, .dimensions = {1, 3, 1, 1}},
      .expected = false}
      .Test(*this, mojom::Pool2d::Kind::kL2Pool2d);
}
struct PreluTester {
OperandInfo input;
OperandInfo slope;
OperandInfo output;
bool expected;
void Test(WebNNGraphImplTest& test) {
auto context_properties = GetContextPropertiesForTesting();
mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
test.BindNewGraphBuilderRemote();
GraphInfoBuilder builder(remote);
OperandId input_operand_id =
builder.BuildInput("input", input.dimensions, input.type);
OperandId slope_operand_id =
builder.BuildInput("slope", slope.dimensions, slope.type);
OperandId output_operand_id =
builder.BuildOutput("output", output.dimensions, output.type);
builder.BuildPrelu(input_operand_id, slope_operand_id, output_operand_id);
EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
}
};
TEST_F(WebNNGraphImplTest, PreluTest) {
  std::vector<PreluTester> cases = {
      // Valid: slope has the same shape as the input.
      PreluTester{
          .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}},
          .slope = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}},
          .output = {.type = OperandDataType::kFloat32,
                     .dimensions = {3, 2, 5}},
          .expected = true},
      // Valid: slope [3, 1, 5] broadcasts to the input shape.
      PreluTester{
          .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}},
          .slope = {.type = OperandDataType::kFloat32, .dimensions = {3, 1, 5}},
          .output = {.type = OperandDataType::kFloat32,
                     .dimensions = {3, 2, 5}},
          .expected = true},
      // Invalid: slope [3, 5] does not broadcast to [3, 2, 5].
      PreluTester{
          .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}},
          .slope = {.type = OperandDataType::kFloat32, .dimensions = {3, 5}},
          .output = {.type = OperandDataType::kFloat32,
                     .dimensions = {3, 2, 5}},
          .expected = false},
      // Valid: int32 operands.
      PreluTester{
          .input = {.type = OperandDataType::kInt32, .dimensions = {3, 2, 5}},
          .slope = {.type = OperandDataType::kInt32, .dimensions = {3, 2, 5}},
          .output = {.type = OperandDataType::kInt32, .dimensions = {3, 2, 5}},
          .expected = true},
      // Valid: float16 operands.
      PreluTester{
          .input = {.type = OperandDataType::kFloat16, .dimensions = {3, 2, 5}},
          .slope = {.type = OperandDataType::kFloat16, .dimensions = {3, 2, 5}},
          .output = {.type = OperandDataType::kFloat16,
                     .dimensions = {3, 2, 5}},
          .expected = true},
      // Valid: int8 operands.
      PreluTester{
          .input = {.type = OperandDataType::kInt8, .dimensions = {3, 2, 5}},
          .slope = {.type = OperandDataType::kInt8, .dimensions = {3, 2, 5}},
          .output = {.type = OperandDataType::kInt8, .dimensions = {3, 2, 5}},
          .expected = true},
      // Invalid: slope data type differs from the input.
      PreluTester{
          .input = {.type = OperandDataType::kFloat16, .dimensions = {3, 2, 5}},
          .slope = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}},
          .output = {.type = OperandDataType::kFloat16,
                     .dimensions = {3, 2, 5}},
          .expected = false},
      // Invalid: uint32 operands.
      PreluTester{
          .input = {.type = OperandDataType::kUint32, .dimensions = {3, 2, 5}},
          .slope = {.type = OperandDataType::kUint32, .dimensions = {3, 2, 5}},
          .output = {.type = OperandDataType::kUint32,
                     .dimensions = {3, 2, 5}},
          .expected = false},
      // Invalid: output data type differs from the input.
      PreluTester{
          .input = {.type = OperandDataType::kFloat16, .dimensions = {3, 2, 5}},
          .slope = {.type = OperandDataType::kFloat16, .dimensions = {3, 2, 5}},
          .output = {.type = OperandDataType::kFloat32,
                     .dimensions = {3, 2, 5}},
          .expected = false},
      // Invalid: output shape differs from the input.
      PreluTester{
          .input = {.type = OperandDataType::kFloat16, .dimensions = {3, 2, 5}},
          .slope = {.type = OperandDataType::kFloat16, .dimensions = {3, 2, 5}},
          .output = {.type = OperandDataType::kFloat16,
                     .dimensions = {3, 2, 6}},
          .expected = false},
  };
  for (PreluTester& tester : cases) {
    tester.Test(*this);
  }

  // Invalid: the prelu output is the graph input itself.
  {
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    const OperandId input_id =
        builder.BuildInput("input", {2, 3}, OperandDataType::kFloat32);
    const OperandId slope_id =
        builder.BuildInput("slope", {2, 3}, OperandDataType::kFloat32);
    builder.BuildPrelu(input_id, slope_id, input_id);
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }

  // Invalid: the slope is the graph output operand, which is also used as the
  // prelu output.
  {
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    const OperandId input_id =
        builder.BuildInput("input", {2, 3}, OperandDataType::kFloat32);
    const OperandId output_id =
        builder.BuildOutput("output", {2, 3}, OperandDataType::kFloat32);
    builder.BuildPrelu(input_id, output_id, output_id);
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
}
struct QuantizeLinearTester {
OperandInfo input;
OperandInfo scale;
OperandInfo zero_point;
OperandInfo output;
bool expected;
void Test(WebNNGraphImplTest& test) {
auto context_properties = GetContextPropertiesForTesting();
mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
test.BindNewGraphBuilderRemote();
GraphInfoBuilder builder(remote);
OperandId input_operand_id =
builder.BuildInput("input", input.dimensions, input.type);
OperandId scale_operand_id =
builder.BuildInput("scale", scale.dimensions, scale.type);
OperandId zero_point_operand_id = builder.BuildInput(
"zero_point", zero_point.dimensions, zero_point.type);
OperandId output_operand_id =
builder.BuildOutput("output", output.dimensions, output.type);
builder.BuildQuantizeLinear(input_operand_id, scale_operand_id,
zero_point_operand_id, output_operand_id);
EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
}
};
TEST_F(WebNNGraphImplTest, QuantizeLinearTest) {
  {
    // Valid: scale and zero point have the same shape as the input.
    QuantizeLinearTester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}},
        .scale = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}},
        .zero_point = {.type = OperandDataType::kInt8, .dimensions = {3, 2, 5}},
        .output = {.type = OperandDataType::kInt8, .dimensions = {3, 2, 5}},
        .expected = true}
        .Test(*this);
  }
  {
    // Valid: rank-3 scale / zero point of shape {1, 1, 5}.
    QuantizeLinearTester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}},
        .scale = {.type = OperandDataType::kFloat32, .dimensions = {1, 1, 5}},
        .zero_point = {.type = OperandDataType::kInt8, .dimensions = {1, 1, 5}},
        .output = {.type = OperandDataType::kInt8, .dimensions = {3, 2, 5}},
        .expected = true}
        .Test(*this);
  }
  {
    // Valid: rank-3 scale / zero point of shape {3, 1, 1}.
    QuantizeLinearTester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}},
        .scale = {.type = OperandDataType::kFloat32, .dimensions = {3, 1, 1}},
        .zero_point = {.type = OperandDataType::kInt8, .dimensions = {3, 1, 1}},
        .output = {.type = OperandDataType::kInt8, .dimensions = {3, 2, 5}},
        .expected = true}
        .Test(*this);
  }
  {
    // Invalid: rank-1 scale / zero point for a rank-3 input. Compared with the
    // {1, 1, 5} case above, only the rank differs.
    QuantizeLinearTester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}},
        .scale = {.type = OperandDataType::kFloat32, .dimensions = {5}},
        .zero_point = {.type = OperandDataType::kInt8, .dimensions = {5}},
        .output = {.type = OperandDataType::kInt8, .dimensions = {3, 2, 5}},
        .expected = false}
        .Test(*this);
  }
  {
    // Invalid: rank-2 scale / zero point for a rank-3 input.
    QuantizeLinearTester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}},
        .scale = {.type = OperandDataType::kFloat32, .dimensions = {3, 5}},
        .zero_point = {.type = OperandDataType::kInt8, .dimensions = {3, 5}},
        .output = {.type = OperandDataType::kInt8, .dimensions = {3, 2, 5}},
        .expected = false}
        .Test(*this);
  }
  {
    // Invalid: scale and zero point shapes differ. Each shape on its own is
    // accepted by a case above, so only the mismatch is under test (a rank-1
    // scale would already be rejected by the rank check).
    QuantizeLinearTester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}},
        .scale = {.type = OperandDataType::kFloat32, .dimensions = {1, 1, 5}},
        .zero_point = {.type = OperandDataType::kInt8, .dimensions = {3, 1, 1}},
        .output = {.type = OperandDataType::kInt8, .dimensions = {3, 2, 5}},
        .expected = false}
        .Test(*this);
  }
  {
    // Invalid: scale data type (float16) differs from the input data type
    // (float32). Shapes match a valid case so only the data type is tested.
    QuantizeLinearTester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}},
        .scale = {.type = OperandDataType::kFloat16, .dimensions = {1, 1, 5}},
        .zero_point = {.type = OperandDataType::kInt8, .dimensions = {1, 1, 5}},
        .output = {.type = OperandDataType::kInt8, .dimensions = {3, 2, 5}},
        .expected = false}
        .Test(*this);
  }
  {
    // Invalid: output data type (uint8) differs from the zero point data type
    // (int8). Shapes match a valid case so only the data type is tested.
    QuantizeLinearTester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}},
        .scale = {.type = OperandDataType::kFloat32, .dimensions = {1, 1, 5}},
        .zero_point = {.type = OperandDataType::kInt8, .dimensions = {1, 1, 5}},
        .output = {.type = OperandDataType::kUint8, .dimensions = {3, 2, 5}},
        .expected = false}
        .Test(*this);
  }
  {
    // Invalid: output shape differs from the input shape. Every other operand
    // matches a valid case so only the output shape is tested.
    QuantizeLinearTester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}},
        .scale = {.type = OperandDataType::kFloat32, .dimensions = {1, 1, 5}},
        .zero_point = {.type = OperandDataType::kInt8, .dimensions = {1, 1, 5}},
        .output = {.type = OperandDataType::kInt8, .dimensions = {5}},
        .expected = false}
        .Test(*this);
  }
  {
    // Invalid: the input operand is also used as the op's output.
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", {2, 3}, OperandDataType::kFloat32);
    OperandId scale_operand_id =
        builder.BuildInput("scale", {2, 3}, OperandDataType::kFloat32);
    OperandId zero_point_operand_id =
        builder.BuildInput("zero_point", {2, 3}, OperandDataType::kInt8);
    builder.BuildQuantizeLinear(input_operand_id, scale_operand_id,
                                zero_point_operand_id, input_operand_id);
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
  {
    // Invalid: the scale operand is also used as the op's output.
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", {2, 3}, OperandDataType::kFloat32);
    OperandId scale_operand_id =
        builder.BuildInput("scale", {2, 3}, OperandDataType::kFloat32);
    OperandId zero_point_operand_id =
        builder.BuildInput("zero_point", {2, 3}, OperandDataType::kInt8);
    builder.BuildQuantizeLinear(input_operand_id, scale_operand_id,
                                zero_point_operand_id, scale_operand_id);
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
  {
    // Invalid: the zero point operand is also used as the op's output.
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", {2, 3}, OperandDataType::kFloat32);
    OperandId scale_operand_id =
        builder.BuildInput("scale", {2, 3}, OperandDataType::kFloat32);
    OperandId zero_point_operand_id =
        builder.BuildInput("zero_point", {2, 3}, OperandDataType::kInt8);
    builder.BuildQuantizeLinear(input_operand_id, scale_operand_id,
                                zero_point_operand_id, zero_point_operand_id);
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
}
struct ReduceTester {
mojom::Reduce::Kind kind;
OperandInfo input;
std::vector<uint32_t> axes;
bool keep_dimensions = false;
OperandInfo output;
bool expected;
void Test(WebNNGraphImplTest& test) {
auto context_properties = GetContextPropertiesForTesting();
mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
test.BindNewGraphBuilderRemote();
GraphInfoBuilder builder(remote);
OperandId input_operand_id =
builder.BuildInput("input", input.dimensions, input.type);
OperandId output_operand_id =
builder.BuildOutput("output", output.dimensions, output.type);
builder.BuildReduce(kind, input_operand_id, output_operand_id, axes,
keep_dimensions);
EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
}
};
TEST_F(WebNNGraphImplTest, ReduceTest) {
  // Table of single-op reduce graphs; each entry is run through
  // ReduceTester::Test in declaration order.
  ReduceTester cases[] = {
      // Valid: L1 over axes {0, 2}, keeping reduced dimensions.
      {.kind = mojom::Reduce::Kind::kL1,
       .input = {.type = OperandDataType::kFloat32,
                 .dimensions = {2, 3, 4, 5}},
       .axes = {0, 2},
       .keep_dimensions = true,
       .output = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 3, 1, 5}},
       .expected = true},
      // Valid: L1 on int32 data.
      {.kind = mojom::Reduce::Kind::kL1,
       .input = {.type = OperandDataType::kInt32, .dimensions = {2, 3, 4, 5}},
       .axes = {0, 2},
       .keep_dimensions = true,
       .output = {.type = OperandDataType::kInt32, .dimensions = {1, 3, 1, 5}},
       .expected = true},
      // Valid: L2 over one axis, dropping it from the output.
      {.kind = mojom::Reduce::Kind::kL2,
       .input = {.type = OperandDataType::kFloat32,
                 .dimensions = {2, 3, 4, 5}},
       .axes = {2},
       .keep_dimensions = false,
       .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 3, 5}},
       .expected = true},
      // Valid: min over every axis yields a scalar.
      {.kind = mojom::Reduce::Kind::kMin,
       .input = {.type = OperandDataType::kFloat32,
                 .dimensions = {2, 3, 4, 5}},
       .axes = {0, 1, 2, 3},
       .output = {.type = OperandDataType::kFloat32, .dimensions = {}},
       .expected = true},
      // Valid: min over every axis on int64 data.
      {.kind = mojom::Reduce::Kind::kMin,
       .input = {.type = OperandDataType::kInt64, .dimensions = {2, 3, 4, 5}},
       .axes = {0, 1, 2, 3},
       .output = {.type = OperandDataType::kInt64, .dimensions = {}},
       .expected = true},
      // Valid: sum over every axis on int64 data.
      {.kind = mojom::Reduce::Kind::kSum,
       .input = {.type = OperandDataType::kInt64, .dimensions = {2, 3, 4, 5}},
       .axes = {0, 1, 2, 3},
       .output = {.type = OperandDataType::kInt64, .dimensions = {}},
       .expected = true},
      // Valid: empty axes leaves the shape unchanged.
      {.kind = mojom::Reduce::Kind::kMin,
       .input = {.type = OperandDataType::kFloat32,
                 .dimensions = {2, 3, 4, 5}},
       .axes = {},
       .output = {.type = OperandDataType::kFloat32,
                  .dimensions = {2, 3, 4, 5}},
       .expected = true},
      // Invalid: more axes than the input rank.
      {.kind = mojom::Reduce::Kind::kMax,
       .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}},
       .axes = {0, 1, 2},
       .keep_dimensions = false,
       .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}},
       .expected = false},
      // Invalid: duplicated axis.
      {.kind = mojom::Reduce::Kind::kMean,
       .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}},
       .axes = {1, 1},
       .keep_dimensions = false,
       .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}},
       .expected = false},
      // Invalid: axis out of range for a rank-2 input.
      {.kind = mojom::Reduce::Kind::kSum,
       .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}},
       .axes = {2},
       .keep_dimensions = false,
       .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}},
       .expected = false},
      // Invalid: reduced dimension kept in the output although
      // keep_dimensions is false.
      {.kind = mojom::Reduce::Kind::kProduct,
       .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}},
       .axes = {0},
       .keep_dimensions = false,
       .output = {.type = OperandDataType::kFloat32, .dimensions = {1, 3}},
       .expected = false},
      // Invalid: output type differs from the input type.
      {.kind = mojom::Reduce::Kind::kLogSum,
       .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}},
       .axes = {0},
       .keep_dimensions = false,
       .output = {.type = OperandDataType::kFloat16, .dimensions = {3}},
       .expected = false},
      // Invalid: logSum on int32 data.
      {.kind = mojom::Reduce::Kind::kLogSum,
       .input = {.type = OperandDataType::kInt32, .dimensions = {2, 3}},
       .axes = {0},
       .keep_dimensions = false,
       .output = {.type = OperandDataType::kInt32, .dimensions = {3}},
       .expected = false},
      // Invalid: logSumExp on int32 data.
      {.kind = mojom::Reduce::Kind::kLogSumExp,
       .input = {.type = OperandDataType::kInt32, .dimensions = {2, 3}},
       .axes = {0},
       .keep_dimensions = false,
       .output = {.type = OperandDataType::kInt32, .dimensions = {3}},
       .expected = false},
      // Invalid: L2 on int32 data.
      {.kind = mojom::Reduce::Kind::kL2,
       .input = {.type = OperandDataType::kInt32, .dimensions = {2, 3}},
       .axes = {0},
       .keep_dimensions = false,
       .output = {.type = OperandDataType::kInt32, .dimensions = {3}},
       .expected = false},
      // Invalid: mean on int32 data.
      {.kind = mojom::Reduce::Kind::kMean,
       .input = {.type = OperandDataType::kInt32, .dimensions = {2, 3}},
       .axes = {0},
       .keep_dimensions = false,
       .output = {.type = OperandDataType::kInt32, .dimensions = {3}},
       .expected = false},
      // Invalid: product on int8 data.
      {.kind = mojom::Reduce::Kind::kProduct,
       .input = {.type = OperandDataType::kInt8, .dimensions = {2, 3}},
       .axes = {0},
       .keep_dimensions = false,
       .output = {.type = OperandDataType::kInt8, .dimensions = {3}},
       .expected = false},
      // Invalid: L1 on uint8 data.
      {.kind = mojom::Reduce::Kind::kL1,
       .input = {.type = OperandDataType::kUint8, .dimensions = {2, 3}},
       .axes = {0},
       .keep_dimensions = false,
       .output = {.type = OperandDataType::kUint8, .dimensions = {3}},
       .expected = false},
      // Invalid: sum on uint8 data.
      {.kind = mojom::Reduce::Kind::kSum,
       .input = {.type = OperandDataType::kUint8, .dimensions = {2, 3}},
       .axes = {0},
       .keep_dimensions = false,
       .output = {.type = OperandDataType::kUint8, .dimensions = {3}},
       .expected = false},
      // Invalid: sumSquare on int8 data.
      {.kind = mojom::Reduce::Kind::kSumSquare,
       .input = {.type = OperandDataType::kInt8, .dimensions = {2, 3}},
       .axes = {0},
       .keep_dimensions = false,
       .output = {.type = OperandDataType::kInt8, .dimensions = {3}},
       .expected = false},
      // Invalid: float32 input with an int32 output.
      {.kind = mojom::Reduce::Kind::kLogSum,
       .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}},
       .axes = {0},
       .keep_dimensions = false,
       .output = {.type = OperandDataType::kInt32, .dimensions = {3}},
       .expected = false},
  };
  for (ReduceTester& test_case : cases) {
    test_case.Test(*this);
  }

  {
    // Invalid: the same operand is passed as both the reduce input and its
    // output.
    auto properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder graph_builder(builder_remote);
    OperandId input_id =
        graph_builder.BuildInput("input", {2, 3}, OperandDataType::kFloat32);
    graph_builder.BuildReduce(mojom::Reduce::Kind::kSumSquare, input_id,
                              input_id, {0}, false);
    EXPECT_FALSE(graph_builder.IsValidGraphForTesting(properties));
  }
}
struct ReluTester {
OperandInfo input;
OperandInfo output;
bool expected;
void Test(WebNNGraphImplTest& test) {
auto context_properties = GetContextPropertiesForTesting();
mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
test.BindNewGraphBuilderRemote();
GraphInfoBuilder builder(remote);
OperandId input_operand_id =
builder.BuildInput("input", input.dimensions, input.type);
OperandId output_operand_id =
builder.BuildOutput("output", output.dimensions, output.type);
builder.BuildRelu(input_operand_id, output_operand_id);
EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
}
};
TEST_F(WebNNGraphImplTest, ReluTest) {
  // Table of single-op relu graphs; each entry is run through
  // ReluTester::Test in declaration order.
  ReluTester cases[] = {
      // Valid: rank-3 float32 input with a matching output.
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {2, 6, 4}},
       .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 6, 4}},
       .expected = true},
      // Valid: rank-4 float32 input with a matching output.
      {.input = {.type = OperandDataType::kFloat32,
                 .dimensions = {1, 5, 3, 7}},
       .output = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 5, 3, 7}},
       .expected = true},
      // Invalid: uint32 input and output.
      {.input = {.type = OperandDataType::kUint32, .dimensions = {4, 2}},
       .output = {.type = OperandDataType::kUint32, .dimensions = {4, 2}},
       .expected = false},
      // Invalid: output shape differs from the input shape.
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {4, 2}},
       .output = {.type = OperandDataType::kFloat32, .dimensions = {2}},
       .expected = false},
      // Invalid: output type differs from the input type.
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {2}},
       .output = {.type = OperandDataType::kInt32, .dimensions = {2}},
       .expected = false},
  };
  for (ReluTester& test_case : cases) {
    test_case.Test(*this);
  }
}
struct Resample2dTester {
OperandInfo input;
struct Resample2dAttributes {
mojom::Resample2d::InterpolationMode mode =
mojom::Resample2d::InterpolationMode::kNearestNeighbor;
std::optional<std::vector<float>> scales;
std::vector<uint32_t> axes = {2, 3};
};
Resample2dAttributes attributes;
OperandInfo output;
bool expected;
void Test(WebNNGraphImplTest& test) {
auto context_properties = GetContextPropertiesForTesting();
mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
test.BindNewGraphBuilderRemote();
GraphInfoBuilder builder(remote);
OperandId input_operand_id =
builder.BuildInput("input", input.dimensions, input.type);
OperandId output_operand_id =
builder.BuildOutput("output", output.dimensions, output.type);
builder.BuildResample2d(input_operand_id, output_operand_id, attributes);
EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
}
};
TEST_F(WebNNGraphImplTest, Resample2dTest) {
  // Table of single-op resample2d graphs; each entry is run through
  // Resample2dTester::Test in declaration order.
  Resample2dTester cases[] = {
      // Valid: default attributes with an output that keeps the input size.
      {.input = {.type = OperandDataType::kFloat32,
                 .dimensions = {1, 1, 2, 4}},
       .output = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 1, 2, 4}},
       .expected = true},
      // Valid: linear upsampling by 2 along axes {1, 2}.
      {.input = {.type = OperandDataType::kFloat32,
                 .dimensions = {1, 2, 4, 1}},
       .attributes = {.mode = mojom::Resample2d::InterpolationMode::kLinear,
                      .scales = std::vector<float>{2, 2},
                      .axes = {1, 2}},
       .output = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 4, 8, 1}},
       .expected = true},
      // Valid: the same upsampling on float16 data.
      {.input = {.type = OperandDataType::kFloat16,
                 .dimensions = {1, 2, 4, 1}},
       .attributes = {.mode = mojom::Resample2d::InterpolationMode::kLinear,
                      .scales = std::vector<float>{2, 2},
                      .axes = {1, 2}},
       .output = {.type = OperandDataType::kFloat16,
                  .dimensions = {1, 4, 8, 1}},
       .expected = true},
      // Valid: non-integral scale 2.2 (4 * 2.2 = 8.8) with an output size of 8.
      {.input = {.type = OperandDataType::kFloat32,
                 .dimensions = {1, 2, 4, 1}},
       .attributes = {.mode = mojom::Resample2d::InterpolationMode::kLinear,
                      .scales = std::vector<float>{2, 2.2},
                      .axes = {1, 2}},
       .output = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 4, 8, 1}},
       .expected = true},
      // Invalid: output type differs from the input type.
      {.input = {.type = OperandDataType::kFloat32,
                 .dimensions = {1, 1, 2, 4}},
       .output = {.type = OperandDataType::kFloat16,
                  .dimensions = {1, 1, 4, 8}},
       .expected = false},
      // Invalid: int32 input and output.
      {.input = {.type = OperandDataType::kInt32, .dimensions = {1, 1, 2, 4}},
       .output = {.type = OperandDataType::kInt32, .dimensions = {1, 1, 4, 8}},
       .expected = false},
      // Invalid: rank-3 input.
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {1, 1, 2}},
       .output = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 1, 2, 4}},
       .expected = false},
      // Invalid: rank-3 output.
      {.input = {.type = OperandDataType::kFloat32,
                 .dimensions = {1, 1, 2, 4}},
       .output = {.type = OperandDataType::kFloat32, .dimensions = {1, 1, 2}},
       .expected = false},
      // Invalid: output dimension 1 does not equal 2 * 2.
      {.input = {.type = OperandDataType::kFloat32,
                 .dimensions = {1, 2, 4, 1}},
       .attributes = {.mode = mojom::Resample2d::InterpolationMode::kLinear,
                      .scales = std::vector<float>{2, 2},
                      .axes = {1, 2}},
       .output = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 5, 8, 1}},
       .expected = false},
      // Invalid: huge scale on the first resampled axis of a large input.
      {.input = {.type = OperandDataType::kFloat32,
                 .dimensions = {1, 1, 34902, 23243}},
       .attributes = {.mode = mojom::Resample2d::InterpolationMode::kLinear,
                      .scales = std::vector<float>{232433, 4}},
       .output = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 1, 2, 4}},
       .expected = false},
      // Invalid: scale 0.02 shrinks a size-2 axis below one element.
      {.input = {.type = OperandDataType::kFloat32,
                 .dimensions = {1, 1, 2, 4}},
       .attributes = {.mode = mojom::Resample2d::InterpolationMode::kLinear,
                      .scales = std::vector<float>{0.02, 0.8}},
       .output = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 1, 2, 4}},
       .expected = false},
      // Invalid: huge scale on the second resampled axis of a large input.
      {.input = {.type = OperandDataType::kFloat32,
                 .dimensions = {1, 1, 34902, 23243}},
       .attributes = {.mode = mojom::Resample2d::InterpolationMode::kLinear,
                      .scales = std::vector<float>{20, 434324}},
       .output = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 1, 2, 4}},
       .expected = false},
      // Invalid: scale 0.1 shrinks a size-4 axis below one element.
      {.input = {.type = OperandDataType::kFloat32,
                 .dimensions = {1, 1, 2, 4}},
       .attributes = {.mode = mojom::Resample2d::InterpolationMode::kLinear,
                      .scales = std::vector<float>{0.7, 0.1}},
       .output = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 1, 2, 4}},
       .expected = false},
      // Invalid: negative scale.
      {.input = {.type = OperandDataType::kFloat32,
                 .dimensions = {1, 1, 2, 4}},
       .attributes = {.scales = std::vector<float>{1.0, -2.0}},
       .output = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 2, 4, 4}},
       .expected = false},
      // Valid: no scales, resampling along axes {1, 3}.
      {.input = {.type = OperandDataType::kFloat32,
                 .dimensions = {1, 1, 2, 4}},
       .attributes = {.axes = {1, 3}},
       .output = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 2, 2, 8}},
       .expected = true},
      // Invalid: dimension 1 is not a resampled axis but differs between
      // input and output.
      {.input = {.type = OperandDataType::kFloat32,
                 .dimensions = {1, 1, 2, 4}},
       .attributes = {.scales = std::vector<float>{2, 2}, .axes = {2, 3}},
       .output = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 2, 4, 8}},
       .expected = false},
      // Invalid: same as above without explicit scales.
      {.input = {.type = OperandDataType::kFloat32,
                 .dimensions = {1, 1, 2, 4}},
       .attributes = {.axes = {2, 3}},
       .output = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 2, 4, 8}},
       .expected = false},
  };
  for (Resample2dTester& test_case : cases) {
    test_case.Test(*this);
  }

  {
    // Invalid: the same operand is passed as both the resample2d input and
    // its output.
    auto properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder graph_builder(builder_remote);
    OperandId input_id = graph_builder.BuildInput("input", {1, 1, 2, 4},
                                                  OperandDataType::kFloat32);
    graph_builder.BuildResample2d(input_id, input_id,
                                  Resample2dTester::Resample2dAttributes{});
    EXPECT_FALSE(graph_builder.IsValidGraphForTesting(properties));
  }
}
struct ReshapeTester {
OperandInfo input;
OperandInfo output;
bool expected;
void Test(WebNNGraphImplTest& test) {
auto context_properties = GetContextPropertiesForTesting();
mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
test.BindNewGraphBuilderRemote();
GraphInfoBuilder builder(remote);
OperandId input_operand_id =
builder.BuildInput("input", input.dimensions, input.type);
OperandId output_operand_id =
builder.BuildOutput("output", output.dimensions, output.type);
builder.BuildReshape(input_operand_id, output_operand_id);
EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
}
};
TEST_F(WebNNGraphImplTest, ReshapeTest) {
  // Table of single-op reshape graphs; each entry is run through
  // ReshapeTester::Test in declaration order.
  ReshapeTester cases[] = {
      // Valid: {2, 4} -> {8}.
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {2, 4}},
       .output = {.type = OperandDataType::kFloat32, .dimensions = {8}},
       .expected = true},
      // Valid: int32 {1, 3, 2, 1} -> {1, 6}.
      {.input = {.type = OperandDataType::kInt32, .dimensions = {1, 3, 2, 1}},
       .output = {.type = OperandDataType::kInt32, .dimensions = {1, 6}},
       .expected = true},
      // Invalid: element count (24 vs 15) and data type both differ.
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {2, 3, 4}},
       .output = {.type = OperandDataType::kInt32, .dimensions = {3, 5}},
       .expected = false},
      // Invalid: output type differs from the input type.
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {2}},
       .output = {.type = OperandDataType::kInt32, .dimensions = {2}},
       .expected = false},
  };
  for (ReshapeTester& test_case : cases) {
    test_case.Test(*this);
  }

  {
    // Invalid: a rank-5 output when the context limits reshape to rank 4.
    auto properties = GetContextPropertiesForTesting();
    static constexpr SupportedRanks kRankLimit = SupportedRanks::UpTo(4);
    properties.data_type_limits.reshape_input.ranks.IntersectWith(kRankLimit);
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder graph_builder(builder_remote);
    OperandId input_id =
        graph_builder.BuildInput("input", {2}, OperandDataType::kFloat32);
    OperandId output_id = graph_builder.BuildOutput(
        "output", {2, 1, 1, 1, 1}, OperandDataType::kFloat32);
    graph_builder.BuildReshape(input_id, output_id);
    EXPECT_FALSE(graph_builder.IsValidGraphForTesting(properties));
  }
}
// Describes one reverse graph: `input` is reversed along `axes` into
// `output`, and the builder's validity verdict is compared with `expected`.
struct ReverseTester {
  OperandInfo input;
  OperandInfo output;
  std::vector<uint32_t> axes;
  bool expected;

  void Test(WebNNGraphImplTest& test) {
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        test.BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", input.dimensions, input.type);
    OperandId output_operand_id =
        builder.BuildOutput("output", output.dimensions, output.type);
    // Pass `axes` without std::move (as ReduceTester does): moving out of a
    // data member would leave this tester with empty axes, so a second call
    // to Test() would silently validate a different graph.
    builder.BuildReverse(input_operand_id, output_operand_id, axes);
    EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
  }
};
TEST_F(WebNNGraphImplTest, ReverseTest) {
  // Table of single-op reverse graphs; each entry is run through
  // ReverseTester::Test in declaration order.
  ReverseTester cases[] = {
      // Valid: reverse along three distinct in-range axes.
      {.input = {.type = OperandDataType::kFloat32,
                 .dimensions = {1, 2, 3, 4}},
       .output = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 2, 3, 4}},
       .axes = {0, 1, 2},
       .expected = true},
      // Invalid: duplicated axis.
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {2, 3, 4}},
       .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 3, 4}},
       .axes = {1, 1, 2},
       .expected = false},
      // Invalid: axis out of range for a rank-3 input.
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {2, 3, 4}},
       .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 3, 4}},
       .axes = {4},
       .expected = false},
      // Invalid: output type differs from the input type.
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {2, 4}},
       .output = {.type = OperandDataType::kInt32, .dimensions = {2, 4}},
       .axes = {0},
       .expected = false},
  };
  for (ReverseTester& test_case : cases) {
    test_case.Test(*this);
  }

  {
    // Invalid: the same operand is passed as both the reverse input and its
    // output.
    auto properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder graph_builder(builder_remote);
    OperandId input_id =
        graph_builder.BuildInput("input", {3, 3}, OperandDataType::kFloat32);
    graph_builder.BuildReverse(input_id, input_id, {1});
    EXPECT_FALSE(graph_builder.IsValidGraphForTesting(properties));
  }
}
struct ScatterElementsTester {
OperandInfo input;
OperandInfo indices;
OperandInfo updates;
OperandInfo output;
uint32_t axis = 0;
bool expected;
void Test(WebNNGraphImplTest& test) {
auto context_properties = GetContextPropertiesForTesting();
mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
test.BindNewGraphBuilderRemote();
GraphInfoBuilder builder(remote);
OperandId input_operand_id =
builder.BuildInput("input", input.dimensions, input.type);
OperandId indices_operand_id =
builder.BuildInput("indices", indices.dimensions, indices.type);
OperandId updates_operand_id =
builder.BuildInput("updates", updates.dimensions, updates.type);
OperandId output_operand_id =
builder.BuildOutput("output", output.dimensions, output.type);
builder.BuildScatterElements(input_operand_id, indices_operand_id,
updates_operand_id, output_operand_id, axis);
EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
}
};
TEST_F(WebNNGraphImplTest, ScatterElementsTest) {
  // Table of single-op scatterElements graphs; each entry is run through
  // ScatterElementsTester::Test in declaration order.
  ScatterElementsTester cases[] = {
      // Valid: scatter along axis 0.
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {3, 3}},
       .indices = {.type = OperandDataType::kUint32, .dimensions = {2, 3}},
       .updates = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}},
       .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 3}},
       .axis = 0,
       .expected = true},
      // Valid: scatter along axis 1.
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {1, 5}},
       .indices = {.type = OperandDataType::kUint32, .dimensions = {1, 2}},
       .updates = {.type = OperandDataType::kFloat32, .dimensions = {1, 2}},
       .output = {.type = OperandDataType::kFloat32, .dimensions = {1, 5}},
       .axis = 1,
       .expected = true},
      // Invalid: axis out of range for a rank-2 input.
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {3, 3}},
       .indices = {.type = OperandDataType::kUint32, .dimensions = {2, 3}},
       .updates = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}},
       .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 3}},
       .axis = 2,
       .expected = false},
      // Invalid: updates type differs from the input type.
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {3, 3}},
       .indices = {.type = OperandDataType::kUint32, .dimensions = {2, 3}},
       .updates = {.type = OperandDataType::kFloat16, .dimensions = {2, 3}},
       .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 3}},
       .expected = false},
      // Invalid: scalar operands.
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {}},
       .indices = {.type = OperandDataType::kUint32, .dimensions = {}},
       .updates = {.type = OperandDataType::kFloat32, .dimensions = {}},
       .output = {.type = OperandDataType::kFloat32, .dimensions = {}},
       .expected = false},
      // Invalid: indices rank differs from the input rank.
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {3, 3}},
       .indices = {.type = OperandDataType::kUint32, .dimensions = {2, 3, 3}},
       .updates = {.type = OperandDataType::kFloat32, .dimensions = {2, 3, 3}},
       .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 3}},
       .expected = false},
      // Invalid: indices dimension 1 differs from the input along a
      // non-scatter axis.
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {3, 3}},
       .indices = {.type = OperandDataType::kUint32, .dimensions = {2, 4}},
       .updates = {.type = OperandDataType::kFloat32, .dimensions = {2, 4}},
       .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 3}},
       .axis = 0,
       .expected = false},
      // Invalid: indices dimension 0 differs from the input along a
      // non-scatter axis.
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {3, 3}},
       .indices = {.type = OperandDataType::kUint32, .dimensions = {2, 2}},
       .updates = {.type = OperandDataType::kFloat32, .dimensions = {2, 2}},
       .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 3}},
       .axis = 1,
       .expected = false},
      // Invalid: updates shape differs from the indices shape.
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {3, 3}},
       .indices = {.type = OperandDataType::kUint32, .dimensions = {2, 3}},
       .updates = {.type = OperandDataType::kFloat32, .dimensions = {2, 4}},
       .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 3}},
       .expected = false},
      // Invalid: output shape differs from the input shape.
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {3, 3}},
       .indices = {.type = OperandDataType::kUint32, .dimensions = {2, 3}},
       .updates = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}},
       .output = {.type = OperandDataType::kFloat32, .dimensions = {4, 4}},
       .expected = false},
      // Invalid: output type differs from the input type.
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {3, 3}},
       .indices = {.type = OperandDataType::kUint32, .dimensions = {2, 3}},
       .updates = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}},
       .output = {.type = OperandDataType::kFloat16, .dimensions = {3, 3}},
       .expected = false},
  };
  for (ScatterElementsTester& test_case : cases) {
    test_case.Test(*this);
  }

  {
    // Invalid: the input operand is also used as the op's output.
    auto properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder graph_builder(builder_remote);
    OperandId input_id =
        graph_builder.BuildInput("input", {3, 3}, OperandDataType::kFloat32);
    OperandId indices_id =
        graph_builder.BuildInput("indices", {2, 3}, OperandDataType::kUint32);
    OperandId updates_id =
        graph_builder.BuildInput("updates", {2, 3}, OperandDataType::kFloat32);
    graph_builder.BuildScatterElements(input_id, indices_id, updates_id,
                                       input_id, 0);
    EXPECT_FALSE(graph_builder.IsValidGraphForTesting(properties));
  }
}
struct ScatterNDTester {
OperandInfo input;
OperandInfo indices;
OperandInfo updates;
OperandInfo output;
bool expected;
void Test(WebNNGraphImplTest& test) {
auto context_properties = GetContextPropertiesForTesting();
mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
test.BindNewGraphBuilderRemote();
GraphInfoBuilder builder(remote);
OperandId input_operand_id =
builder.BuildInput("input", input.dimensions, input.type);
OperandId indices_operand_id =
builder.BuildInput("indices", indices.dimensions, indices.type);
OperandId updates_operand_id =
builder.BuildInput("updates", updates.dimensions, updates.type);
OperandId output_operand_id =
builder.BuildOutput("output", output.dimensions, output.type);
builder.BuildScatterND(input_operand_id, indices_operand_id,
updates_operand_id, output_operand_id);
EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
}
};
TEST_F(WebNNGraphImplTest, ScatterNDTest) {
  // Table of single-op scatterND graphs; each entry is run through
  // ScatterNDTester::Test in declaration order.
  ScatterNDTester cases[] = {
      // Valid: two index tuples of length 1, each replacing a {4, 4} slice.
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {4, 4, 4}},
       .indices = {.type = OperandDataType::kUint32, .dimensions = {2, 1}},
       .updates = {.type = OperandDataType::kFloat32, .dimensions = {2, 4, 4}},
       .output = {.type = OperandDataType::kFloat32, .dimensions = {4, 4, 4}},
       .expected = true},
      // Invalid: updates type differs from the input type.
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {4, 4, 4}},
       .indices = {.type = OperandDataType::kUint32, .dimensions = {2, 1}},
       .updates = {.type = OperandDataType::kFloat16, .dimensions = {2, 4, 4}},
       .output = {.type = OperandDataType::kFloat32, .dimensions = {4, 4, 4}},
       .expected = false},
      // Invalid: scalar input.
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {}},
       .indices = {.type = OperandDataType::kUint32, .dimensions = {2, 1}},
       .updates = {.type = OperandDataType::kFloat32, .dimensions = {2, 4, 4}},
       .output = {.type = OperandDataType::kFloat32, .dimensions = {4, 4, 4}},
       .expected = false},
      // Invalid: scalar indices.
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {4, 4, 4}},
       .indices = {.type = OperandDataType::kUint32, .dimensions = {}},
       .updates = {.type = OperandDataType::kFloat32, .dimensions = {2, 4, 4}},
       .output = {.type = OperandDataType::kFloat32, .dimensions = {4, 4, 4}},
       .expected = false},
      // Invalid: index tuple length 4 exceeds the input rank 3.
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {4, 4, 4}},
       .indices = {.type = OperandDataType::kUint32, .dimensions = {2, 4}},
       .updates = {.type = OperandDataType::kFloat32, .dimensions = {2, 4, 4}},
       .output = {.type = OperandDataType::kFloat32, .dimensions = {4, 4, 4}},
       .expected = false},
      // Invalid: updates shape does not match the scattered slices.
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {4, 4, 4}},
       .indices = {.type = OperandDataType::kUint32, .dimensions = {2, 1}},
       .updates = {.type = OperandDataType::kFloat32, .dimensions = {2, 3, 4}},
       .output = {.type = OperandDataType::kFloat32, .dimensions = {4, 4, 4}},
       .expected = false},
      // Invalid: output shape differs from the input shape.
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {4, 4, 4}},
       .indices = {.type = OperandDataType::kUint32, .dimensions = {2, 1}},
       .updates = {.type = OperandDataType::kFloat32, .dimensions = {2, 4, 4}},
       .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 4, 4}},
       .expected = false},
      // Invalid: output type differs from the input type.
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {4, 4, 4}},
       .indices = {.type = OperandDataType::kUint32, .dimensions = {2, 1}},
       .updates = {.type = OperandDataType::kFloat32, .dimensions = {2, 4, 4}},
       .output = {.type = OperandDataType::kFloat16, .dimensions = {4, 4, 4}},
       .expected = false},
  };
  for (ScatterNDTester& test_case : cases) {
    test_case.Test(*this);
  }

  {
    // Invalid: the input operand is also used as the op's output.
    auto properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder graph_builder(builder_remote);
    OperandId input_id = graph_builder.BuildInput("input", {4, 4, 4},
                                                  OperandDataType::kFloat32);
    OperandId indices_id =
        graph_builder.BuildInput("indices", {2, 1}, OperandDataType::kUint32);
    OperandId updates_id = graph_builder.BuildInput(
        "updates", {2, 4, 4}, OperandDataType::kFloat32);
    graph_builder.BuildScatterND(input_id, indices_id, updates_id, input_id);
    EXPECT_FALSE(graph_builder.IsValidGraphForTesting(properties));
  }
}
struct SliceTester {
struct SliceAttributes {
std::vector<uint32_t> starts;
std::vector<uint32_t> sizes;
std::vector<uint32_t> strides;
};
OperandInfo input;
SliceAttributes attributes;
OperandInfo output;
bool expected;
void Test(WebNNGraphImplTest& test) {
auto context_properties = GetContextPropertiesForTesting();
mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
test.BindNewGraphBuilderRemote();
GraphInfoBuilder builder(remote);
OperandId input_operand_id =
builder.BuildInput("input", input.dimensions, input.type);
OperandId output_operand_id =
builder.BuildOutput("output", output.dimensions, output.type);
builder.BuildSlice(input_operand_id, output_operand_id, attributes.starts,
attributes.sizes, attributes.strides);
EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
}
};
TEST_F(WebNNGraphImplTest, SliceTest) {
  // Valid: the whole 4x4 input is selected.
  {
    SliceTester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {4, 4}},
        .attributes = {.starts = {0, 0}, .sizes = {4, 4}, .strides = {1, 1}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {4, 4}},
        .expected = true}
        .Test(*this);
  }
  // Valid: the top-left 2x2 window is selected.
  {
    SliceTester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {4, 4}},
        .attributes = {.starts = {0, 0}, .sizes = {2, 2}, .strides = {1, 1}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 2}},
        .expected = true}
        .Test(*this);
  }
  // Valid: the bottom-right 2x2 window is selected.
  {
    SliceTester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {4, 4}},
        .attributes = {.starts = {2, 2}, .sizes = {2, 2}, .strides = {1, 1}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 2}},
        .expected = true}
        .Test(*this);
  }
  // Valid: a stride of 2 over a size-4 window picks every other element, so
  // each output dimension is 4 / 2 = 2. Without this case, strided output-shape
  // inference had no positive coverage.
  {
    SliceTester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {4, 4}},
        .attributes = {.starts = {0, 0}, .sizes = {4, 4}, .strides = {2, 2}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 2}},
        .expected = true}
        .Test(*this);
  }
  // Invalid: a size-1 window with stride 2 yields one element per dimension,
  // so the {2, 2} output shape is wrong.
  {
    SliceTester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 2}},
        .attributes = {.starts = {1, 0}, .sizes = {1, 1}, .strides = {2, 2}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 2}},
        .expected = false}
        .Test(*this);
  }
  // Invalid: the output shape {2, 1} does not match the {1, 1} slice.
  {
    SliceTester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 2}},
        .attributes = {.starts = {0, 0}, .sizes = {1, 1}, .strides = {1, 1}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 1}},
        .expected = false}
        .Test(*this);
  }
  // Invalid: the window starting at 2 with size 4 runs past the end of a
  // dimension of size 4 (2 + 4 > 4).
  {
    SliceTester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {4, 4}},
        .attributes = {.starts = {2, 2}, .sizes = {4, 4}, .strides = {1, 1}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {4, 4}},
        .expected = false}
        .Test(*this);
  }
  // Invalid: starts/sizes/strides have one entry but the input has rank 2.
  {
    SliceTester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {4, 4}},
        .attributes = {.starts = {0}, .sizes = {4}, .strides = {1}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {4, 4}},
        .expected = false}
        .Test(*this);
  }
  // Invalid: the output data type (float32) differs from the input (float16).
  {
    SliceTester{
        .input = {.type = OperandDataType::kFloat16, .dimensions = {4, 4}},
        .attributes = {.starts = {0, 0}, .sizes = {4, 4}, .strides = {1, 1}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {4, 4}},
        .expected = false}
        .Test(*this);
  }
}
// Selects which single-input floating-point operator FloatingPointUnaryTester
// adds to the graph under test.
enum class FloatingPointUnaryKind {
  kHardSwish,  // Built with GraphInfoBuilder::BuildHardSwish().
  kLeakyRelu,  // Built with GraphInfoBuilder::BuildLeakyRelu().
  kLinear,     // Built with GraphInfoBuilder::BuildLinear().
  kSigmoid,    // Built with GraphInfoBuilder::BuildSigmoid().
  kTanh,       // Built with GraphInfoBuilder::BuildTanh().
};
// Builds a graph containing one single-input floating-point operator from
// `input` to `output` and checks that graph validation returns `expected`.
// Test(test) runs the same combination against every FloatingPointUnaryKind.
struct FloatingPointUnaryTester {
  OperandInfo input;
  OperandInfo output;
  // Expected result of IsValidGraphForTesting(), for every operator kind.
  bool expected;

  void Test(WebNNGraphImplTest& test) {
    Test(test, FloatingPointUnaryKind::kHardSwish);
    Test(test, FloatingPointUnaryKind::kLeakyRelu);
    Test(test, FloatingPointUnaryKind::kLinear);
    Test(test, FloatingPointUnaryKind::kSigmoid);
    Test(test, FloatingPointUnaryKind::kTanh);
  }

  void Test(WebNNGraphImplTest& test, FloatingPointUnaryKind kind) {
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        test.BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", input.dimensions, input.type);
    OperandId output_operand_id =
        builder.BuildOutput("output", output.dimensions, output.type);
    // One OperandInfo combination is checked against five operators. Record
    // which one is being built so a failing expectation names it; otherwise
    // all five failures look identical.
    const char* kind_name = "";
    switch (kind) {
      case FloatingPointUnaryKind::kHardSwish:
        kind_name = "hardSwish";
        builder.BuildHardSwish(input_operand_id, output_operand_id);
        break;
      case FloatingPointUnaryKind::kLeakyRelu:
        kind_name = "leakyRelu";
        builder.BuildLeakyRelu(input_operand_id, output_operand_id,
                               /*alpha=*/1.0);
        break;
      case FloatingPointUnaryKind::kLinear:
        kind_name = "linear";
        builder.BuildLinear(input_operand_id, output_operand_id, /*alpha=*/1.0,
                            /*beta=*/0.0);
        break;
      case FloatingPointUnaryKind::kSigmoid:
        kind_name = "sigmoid";
        builder.BuildSigmoid(input_operand_id, output_operand_id);
        break;
      case FloatingPointUnaryKind::kTanh:
        kind_name = "tanh";
        builder.BuildTanh(input_operand_id, output_operand_id);
        break;
    }
    EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected)
        << "operator: " << kind_name;
  }
};
TEST_F(WebNNGraphImplTest, FloatingPointUnaryTest) {
  // Valid: float32 input and output of identical shape.
  {
    FloatingPointUnaryTester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 6}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 6}},
        .expected = true}
        .Test(*this);
  }
  // Valid: float16 input and output of identical shape.
  {
    FloatingPointUnaryTester{
        .input = {.type = OperandDataType::kFloat16, .dimensions = {2, 6, 4}},
        .output = {.type = OperandDataType::kFloat16, .dimensions = {2, 6, 4}},
        .expected = true}
        .Test(*this);
  }
  // Invalid: the output shape differs from the input shape.
  {
    FloatingPointUnaryTester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {4, 2}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {2}},
        .expected = false}
        .Test(*this);
  }
  // Invalid: the output data type differs from the input data type.
  {
    FloatingPointUnaryTester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {2}},
        .output = {.type = OperandDataType::kInt32, .dimensions = {2}},
        .expected = false}
        .Test(*this);
  }
  // Invalid: every operator in this group rejects int32 operands.
  {
    FloatingPointUnaryTester{
        .input = {.type = OperandDataType::kInt32, .dimensions = {2}},
        .output = {.type = OperandDataType::kInt32, .dimensions = {2}},
        .expected = false}
        .Test(*this);
  }
  // Invalid: hardSwish writes its result into the graph input operand.
  {
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", {2}, OperandDataType::kFloat32);
    builder.BuildHardSwish(input_operand_id, input_operand_id);
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
  // Invalid: leakyRelu writes its result into the graph input operand.
  {
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", {2}, OperandDataType::kFloat32);
    builder.BuildLeakyRelu(input_operand_id, input_operand_id, /*alpha=*/1.0);
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
  // Invalid: leakyRelu with a NaN alpha.
  {
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", {2}, OperandDataType::kFloat32);
    OperandId output_operand_id =
        builder.BuildOutput("output", {2}, OperandDataType::kFloat32);
    builder.BuildLeakyRelu(input_operand_id, output_operand_id, /*alpha=*/NAN);
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
  // Invalid: linear writes its result into the graph input operand.
  {
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", {2}, OperandDataType::kFloat32);
    builder.BuildLinear(input_operand_id, input_operand_id, /*alpha=*/1.0,
                        /*beta=*/0.0);
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
  // Invalid: linear with a NaN alpha.
  {
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", {2}, OperandDataType::kFloat32);
    OperandId output_operand_id =
        builder.BuildOutput("output", {2}, OperandDataType::kFloat32);
    builder.BuildLinear(input_operand_id, output_operand_id, /*alpha=*/NAN,
                        /*beta=*/0.0);
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
  // Invalid: linear with a NaN beta.
  {
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", {2}, OperandDataType::kFloat32);
    OperandId output_operand_id =
        builder.BuildOutput("output", {2}, OperandDataType::kFloat32);
    builder.BuildLinear(input_operand_id, output_operand_id, /*alpha=*/1.0,
                        /*beta=*/NAN);
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
  // Invalid: sigmoid writes its result into the graph input operand.
  {
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", {2}, OperandDataType::kFloat32);
    builder.BuildSigmoid(input_operand_id, input_operand_id);
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
  // Invalid: tanh writes its result into the graph input operand.
  {
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", {2}, OperandDataType::kFloat32);
    builder.BuildTanh(input_operand_id, input_operand_id);
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
}
struct SoftmaxTester {
OperandInfo input;
OperandInfo output;
uint32_t axis;
bool expected;
void Test(WebNNGraphImplTest& test) {
auto context_properties = GetContextPropertiesForTesting();
mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
test.BindNewGraphBuilderRemote();
GraphInfoBuilder builder(remote);
OperandId input_operand_id =
builder.BuildInput("input", input.dimensions, input.type);
OperandId output_operand_id =
builder.BuildOutput("output", output.dimensions, output.type);
builder.BuildSoftmax(input_operand_id, output_operand_id, axis);
EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
}
};
TEST_F(WebNNGraphImplTest, SoftmaxTest) {
  // Valid: float32, softmax over the last axis of a rank-2 input.
  {
    SoftmaxTester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 2}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 2}},
        .axis = 1,
        .expected = true}
        .Test(*this);
  }
  // Valid: float16 operands.
  {
    SoftmaxTester{
        .input = {.type = OperandDataType::kFloat16, .dimensions = {1, 4}},
        .output = {.type = OperandDataType::kFloat16, .dimensions = {1, 4}},
        .axis = 1,
        .expected = true}
        .Test(*this);
  }
  // Valid: softmax over the last axis of a rank-4 input.
  {
    SoftmaxTester{.input = {.type = OperandDataType::kFloat32,
                            .dimensions = {1, 1, 4, 2}},
                  .output = {.type = OperandDataType::kFloat32,
                             .dimensions = {1, 1, 4, 2}},
                  .axis = 3,
                  .expected = true}
        .Test(*this);
  }
  // Invalid: int32 operands are rejected.
  {
    SoftmaxTester{
        .input = {.type = OperandDataType::kInt32, .dimensions = {2, 3}},
        .output = {.type = OperandDataType::kInt32, .dimensions = {2, 3}},
        .axis = 1,
        .expected = false}
        .Test(*this);
  }
  // Invalid: axis 2 is out of range for a rank-2 input.
  {
    SoftmaxTester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 5}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 5}},
        .axis = 2,
        .expected = false}
        .Test(*this);
  }
  // Invalid: the output shape differs from the input shape.
  {
    SoftmaxTester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {4, 2}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {2}},
        .axis = 1,
        .expected = false}
        .Test(*this);
  }
  // Invalid: the output data type differs from the input data type.
  {
    SoftmaxTester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 5}},
        .output = {.type = OperandDataType::kFloat16, .dimensions = {2, 5}},
        .axis = 1,
        .expected = false}
        .Test(*this);
  }
  // Invalid: softmax writes its result into the graph input operand. This
  // matches the in-place coverage the other activation tests in this file
  // already have.
  {
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", {2, 2}, OperandDataType::kFloat32);
    builder.BuildSoftmax(input_operand_id, input_operand_id, /*axis=*/1);
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
}
struct SoftplusTester {
OperandInfo input;
OperandInfo output;
bool expected;
void Test(WebNNGraphImplTest& test) {
auto context_properties = GetContextPropertiesForTesting();
mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
test.BindNewGraphBuilderRemote();
GraphInfoBuilder builder(remote);
OperandId input_operand_id =
builder.BuildInput("input", input.dimensions, input.type);
OperandId output_operand_id =
builder.BuildOutput("output", output.dimensions, output.type);
builder.BuildSoftplus(input_operand_id, output_operand_id);
EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
}
};
TEST_F(WebNNGraphImplTest, SoftplusTest) {
  using OperandDataType::kFloat16;
  using OperandDataType::kFloat32;
  using OperandDataType::kInt32;

  // Valid: float32 input and output of identical shape.
  {
    SoftplusTester{.input = {.type = kFloat32, .dimensions = {2, 2}},
                   .output = {.type = kFloat32, .dimensions = {2, 2}},
                   .expected = true}
        .Test(*this);
  }
  // Invalid: int32 operands are rejected.
  {
    SoftplusTester{.input = {.type = kInt32, .dimensions = {4, 2}},
                   .output = {.type = kInt32, .dimensions = {4, 2}},
                   .expected = false}
        .Test(*this);
  }
  // Invalid: the output shape differs from the input shape.
  {
    SoftplusTester{.input = {.type = kFloat32, .dimensions = {4, 2}},
                   .output = {.type = kFloat32, .dimensions = {2}},
                   .expected = false}
        .Test(*this);
  }
  // Invalid: the output data type differs from the input data type.
  {
    SoftplusTester{.input = {.type = kFloat32, .dimensions = {2, 5}},
                   .output = {.type = kFloat16, .dimensions = {2, 5}},
                   .expected = false}
        .Test(*this);
  }
  // Invalid: softplus writes its result into the graph input operand.
  {
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> graph_builder_remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder graph_builder(graph_builder_remote);
    const OperandId input_id =
        graph_builder.BuildInput("input", {4, 6}, kFloat32);
    graph_builder.BuildSoftplus(input_id, input_id);
    EXPECT_FALSE(graph_builder.IsValidGraphForTesting(context_properties));
  }
}
struct SoftsignTester {
OperandInfo input;
OperandInfo output;
bool expected;
void Test(WebNNGraphImplTest& test) {
auto context_properties = GetContextPropertiesForTesting();
mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
test.BindNewGraphBuilderRemote();
GraphInfoBuilder builder(remote);
OperandId input_operand_id =
builder.BuildInput("input", input.dimensions, input.type);
OperandId output_operand_id =
builder.BuildOutput("output", output.dimensions, output.type);
builder.BuildSoftsign(input_operand_id, output_operand_id);
EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
}
};
TEST_F(WebNNGraphImplTest, SoftsignTest) {
  using OperandDataType::kFloat16;
  using OperandDataType::kFloat32;
  using OperandDataType::kInt32;

  // Valid: float32 input and output of identical shape.
  {
    SoftsignTester{.input = {.type = kFloat32, .dimensions = {2, 4}},
                   .output = {.type = kFloat32, .dimensions = {2, 4}},
                   .expected = true}
        .Test(*this);
  }
  // Invalid: int32 operands are rejected.
  {
    SoftsignTester{.input = {.type = kInt32, .dimensions = {4, 2}},
                   .output = {.type = kInt32, .dimensions = {4, 2}},
                   .expected = false}
        .Test(*this);
  }
  // Invalid: the output shape differs from the input shape.
  {
    SoftsignTester{.input = {.type = kFloat32, .dimensions = {4, 2}},
                   .output = {.type = kFloat32, .dimensions = {2}},
                   .expected = false}
        .Test(*this);
  }
  // Invalid: the output data type differs from the input data type.
  {
    SoftsignTester{.input = {.type = kFloat32, .dimensions = {2, 5}},
                   .output = {.type = kFloat16, .dimensions = {2, 5}},
                   .expected = false}
        .Test(*this);
  }
  // Invalid: softsign writes its result into the graph input operand.
  {
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> graph_builder_remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder graph_builder(graph_builder_remote);
    const OperandId input_id =
        graph_builder.BuildInput("input", {4, 6}, kFloat32);
    graph_builder.BuildSoftsign(input_id, input_id);
    EXPECT_FALSE(graph_builder.IsValidGraphForTesting(context_properties));
  }
}
struct SplitTester {
OperandInfo input;
std::vector<OperandInfo> outputs;
uint32_t axis = 0;
bool expected;
void Test(WebNNGraphImplTest& test) {
auto context_properties = GetContextPropertiesForTesting();
mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
test.BindNewGraphBuilderRemote();
GraphInfoBuilder builder(remote);
OperandId input_operand_id =
builder.BuildInput("input", input.dimensions, input.type);
std::vector<OperandId> output_operand_ids;
for (size_t i = 0; i < outputs.size(); ++i) {
output_operand_ids.push_back(
builder.BuildOutput("output" + base::NumberToString(i),
outputs[i].dimensions, outputs[i].type));
}
builder.BuildSplit(input_operand_id, output_operand_ids, axis);
EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
}
};
TEST_F(WebNNGraphImplTest, ValidateSplitTest) {
  using OperandDataType::kFloat16;
  using OperandDataType::kFloat32;

  // Valid: {2, 2} split along axis 0 (the default) into two {1, 2} halves.
  {
    SplitTester{.input = {.type = kFloat32, .dimensions = {2, 2}},
                .outputs = {{.type = kFloat32, .dimensions = {1, 2}},
                            {.type = kFloat32, .dimensions = {1, 2}}},
                .expected = true}
        .Test(*this);
  }
  // Valid: {2, 2} split along axis 1 into two {2, 1} halves.
  {
    SplitTester{.input = {.type = kFloat32, .dimensions = {2, 2}},
                .outputs = {{.type = kFloat32, .dimensions = {2, 1}},
                            {.type = kFloat32, .dimensions = {2, 1}}},
                .axis = 1,
                .expected = true}
        .Test(*this);
  }
  // Invalid: one output has a different data type from the input.
  {
    SplitTester{.input = {.type = kFloat32, .dimensions = {2, 2}},
                .outputs = {{.type = kFloat32, .dimensions = {1, 2}},
                            {.type = kFloat16, .dimensions = {1, 2}}},
                .expected = false}
        .Test(*this);
  }
  // Invalid: output sizes along axis 1 sum to 5, not the input's 6.
  {
    SplitTester{.input = {.type = kFloat32, .dimensions = {2, 6}},
                .outputs = {{.type = kFloat32, .dimensions = {2, 1}},
                            {.type = kFloat32, .dimensions = {2, 2}},
                            {.type = kFloat32, .dimensions = {2, 2}}},
                .axis = 1,
                .expected = false}
        .Test(*this);
  }
  // Invalid: output sizes along axis 1 sum to 7, not the input's 6.
  {
    SplitTester{.input = {.type = kFloat32, .dimensions = {2, 6}},
                .outputs = {{.type = kFloat32, .dimensions = {2, 1}},
                            {.type = kFloat32, .dimensions = {2, 2}},
                            {.type = kFloat32, .dimensions = {2, 4}}},
                .axis = 1,
                .expected = false}
        .Test(*this);
  }
  // Invalid: axis 2 is out of range for a rank-2 input.
  {
    SplitTester{.input = {.type = kFloat32, .dimensions = {2, 2}},
                .outputs = {{.type = kFloat32, .dimensions = {1, 2}},
                            {.type = kFloat32, .dimensions = {1, 2}}},
                .axis = 2,
                .expected = false}
        .Test(*this);
  }
  // Invalid: the non-split dimension of the outputs (2, 3, 1) does not match
  // the input's 6.
  {
    SplitTester{.input = {.type = kFloat32, .dimensions = {4, 6}},
                .outputs = {{.type = kFloat32, .dimensions = {1, 2}},
                            {.type = kFloat32, .dimensions = {2, 3}},
                            {.type = kFloat32, .dimensions = {1, 1}}},
                .expected = false}
        .Test(*this);
  }
  // Invalid: the first split writes into the graph input operand. The second
  // split gives the graph a named output.
  {
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> graph_builder_remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder graph_builder(graph_builder_remote);
    const OperandId input_id =
        graph_builder.BuildInput("input", {4, 6}, kFloat32);
    graph_builder.BuildSplit(input_id, {input_id}, /*axis=*/0);
    graph_builder.BuildSplit(
        input_id, {graph_builder.BuildOutput("output", {4, 6}, kFloat32)},
        /*axis=*/0);
    EXPECT_FALSE(graph_builder.IsValidGraphForTesting(context_properties));
  }
}
struct TileTester {
OperandInfo input;
std::vector<uint32_t> repetitions;
OperandInfo output;
bool expected;
void Test(WebNNGraphImplTest& test) {
auto context_properties = GetContextPropertiesForTesting();
mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
test.BindNewGraphBuilderRemote();
GraphInfoBuilder builder(remote);
OperandId input_operand_id =
builder.BuildInput("input", input.dimensions, input.type);
OperandId output_operand_id =
builder.BuildOutput("output", output.dimensions, output.type);
builder.BuildTile(input_operand_id, output_operand_id,
std::move(repetitions));
EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
}
};
TEST_F(WebNNGraphImplTest, TileTest) {
  // Valid: each output dimension is input dimension * repetition
  // ({1*2, 2*3, 3*1, 4*2} = {2, 6, 3, 8}).
  {
    TileTester{.input = {.type = OperandDataType::kFloat32,
                         .dimensions = {1, 2, 3, 4}},
               .repetitions = {2, 3, 1, 2},
               .output = {.type = OperandDataType::kFloat32,
                          .dimensions = {2, 6, 3, 8}},
               .expected = true}
        .Test(*this);
  }
  // Invalid: repetitions is empty but the input has rank 3.
  {
    TileTester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {1, 2, 3}},
        .repetitions = {},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {1, 2, 3}},
        .expected = false}
        .Test(*this);
  }
  // Invalid: repetitions has 4 entries but the input has rank 3.
  {
    TileTester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {1, 2, 3}},
        .repetitions = {1, 1, 2, 2},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 2, 3, 3}},
        .expected = false}
        .Test(*this);
  }
  // Invalid: repetitions contains 0.
  // NOTE(review): the output shape {1, 2, 3, 3} is also not the tiled shape
  // {0, 2, 6, 8}, so this case is rejected even without a zero-repetition
  // check.
  {
    TileTester{.input = {.type = OperandDataType::kFloat32,
                         .dimensions = {1, 2, 3, 4}},
               .repetitions = {0, 1, 2, 2},
               .output = {.type = OperandDataType::kFloat32,
                          .dimensions = {1, 2, 3, 3}},
               .expected = false}
        .Test(*this);
  }
  // Invalid: the tiled dimension 34902 * 232433 does not fit in uint32_t.
  {
    TileTester{.input = {.type = OperandDataType::kFloat32,
                         .dimensions = {1, 2, 34902, 4}},
               .repetitions = {1, 1, 232433, 2},
               .output = {.type = OperandDataType::kFloat32,
                          .dimensions = {1, 2, 2, 3}},
               .expected = false}
        .Test(*this);
  }
  // Invalid: the output has rank 3 but the tiled shape has rank 4.
  {
    TileTester{
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 2, 3, 4}},
        .repetitions = {2, 1, 2, 3},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {1, 2, 3}},
        .expected = false}
        .Test(*this);
  }
  // Invalid: the output data type differs from the input data type. All-one
  // repetitions keep the output shape equal to the input shape, so only the
  // type check can reject this graph. (Previously this case also used a zero
  // repetition and a wrong shape, which masked the type check.)
  {
    TileTester{.input = {.type = OperandDataType::kFloat32,
                         .dimensions = {1, 2, 3, 4}},
               .repetitions = {1, 1, 1, 1},
               .output = {.type = OperandDataType::kFloat16,
                          .dimensions = {1, 2, 3, 4}},
               .expected = false}
        .Test(*this);
  }
  // Invalid: tile writes its result into the graph input operand.
  {
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", {4, 6}, OperandDataType::kFloat32);
    builder.BuildTile(input_operand_id, input_operand_id,
                      std::vector<uint32_t>{1, 2});
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
}
struct TransposeTester {
OperandInfo input;
std::vector<uint32_t> permutation;
OperandInfo output;
bool expected;
void Test(WebNNGraphImplTest& test) {
auto context_properties = GetContextPropertiesForTesting();
mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
test.BindNewGraphBuilderRemote();
GraphInfoBuilder builder(remote);
OperandId input_operand_id =
builder.BuildInput("input", input.dimensions, input.type);
OperandId output_operand_id =
builder.BuildOutput("output", output.dimensions, output.type);
builder.BuildTranspose(input_operand_id, output_operand_id,
std::move(permutation));
EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
}
};
TEST_F(WebNNGraphImplTest, TransposeTest) {
  using OperandDataType::kFloat16;
  using OperandDataType::kFloat32;

  // Valid: output dimension i equals input dimension permutation[i]
  // ({1, 2, 3, 4} permuted by {2, 3, 1, 0} gives {3, 4, 2, 1}).
  {
    TransposeTester{
        .input = {.type = kFloat32, .dimensions = {1, 2, 3, 4}},
        .permutation = {2, 3, 1, 0},
        .output = {.type = kFloat32, .dimensions = {3, 4, 2, 1}},
        .expected = true}
        .Test(*this);
  }
  // Invalid: permutation has 4 entries but the input has rank 3.
  {
    TransposeTester{.input = {.type = kFloat32, .dimensions = {1, 2, 3}},
                    .permutation = {0, 1, 2, 2},
                    .output = {.type = kFloat32, .dimensions = {1, 2, 3, 3}},
                    .expected = false}
        .Test(*this);
  }
  // Invalid: permutation repeats axis 2 (and omits axis 3).
  {
    TransposeTester{
        .input = {.type = kFloat32, .dimensions = {1, 2, 3, 4}},
        .permutation = {0, 1, 2, 2},
        .output = {.type = kFloat32, .dimensions = {1, 2, 3, 3}},
        .expected = false}
        .Test(*this);
  }
  // Invalid: permutation value 4 is out of range for a rank-4 input.
  {
    TransposeTester{
        .input = {.type = kFloat16, .dimensions = {1, 2, 3, 4}},
        .permutation = {0, 1, 2, 4},
        .output = {.type = kFloat16, .dimensions = {1, 2, 3, 4}},
        .expected = false}
        .Test(*this);
  }
  // Invalid: the output has rank 3 but the transposed shape has rank 4.
  {
    TransposeTester{.input = {.type = kFloat32, .dimensions = {1, 2, 3, 4}},
                    .permutation = {0, 1, 2, 3},
                    .output = {.type = kFloat32, .dimensions = {1, 2, 3}},
                    .expected = false}
        .Test(*this);
  }
  // Invalid: the output data type differs from the input data type.
  {
    TransposeTester{
        .input = {.type = kFloat32, .dimensions = {1, 2, 3, 4}},
        .permutation = {0, 1, 2, 3},
        .output = {.type = kFloat16, .dimensions = {1, 2, 3, 4}},
        .expected = false}
        .Test(*this);
  }
}
struct TriangularTester {
OperandInfo input;
bool upper = true;
int32_t diagonal = 0;
OperandInfo output;
bool expected;
void Test(WebNNGraphImplTest& test) {
auto context_properties = GetContextPropertiesForTesting();
mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
test.BindNewGraphBuilderRemote();
GraphInfoBuilder builder(remote);
OperandId input_operand_id =
builder.BuildInput("input", input.dimensions, input.type);
OperandId output_operand_id =
builder.BuildOutput("output", output.dimensions, output.type);
builder.BuildTriangular(input_operand_id, output_operand_id, upper,
diagonal);
EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
}
};
TEST_F(WebNNGraphImplTest, TriangularTest) {
  using OperandDataType::kFloat16;
  using OperandDataType::kFloat32;

  // Valid: upper triangle with diagonal 2; input and output match.
  {
    TriangularTester{.input = {.type = kFloat32, .dimensions = {2, 2}},
                     .upper = true,
                     .diagonal = 2,
                     .output = {.type = kFloat32, .dimensions = {2, 2}},
                     .expected = true}
        .Test(*this);
  }
  // Invalid: the output shape differs from the input shape.
  {
    TriangularTester{.input = {.type = kFloat32, .dimensions = {4, 2}},
                     .output = {.type = kFloat32, .dimensions = {2}},
                     .expected = false}
        .Test(*this);
  }
  // Invalid: the output data type differs from the input data type.
  {
    TriangularTester{.input = {.type = kFloat32, .dimensions = {2, 5}},
                     .output = {.type = kFloat16, .dimensions = {2, 5}},
                     .expected = false}
        .Test(*this);
  }
  // Invalid: triangular writes its result into the graph input operand.
  {
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> graph_builder_remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder graph_builder(graph_builder_remote);
    const OperandId input_id =
        graph_builder.BuildInput("input", {4, 6}, kFloat32);
    graph_builder.BuildTriangular(input_id, input_id, /*upper=*/true,
                                  /*diagonal=*/-1);
    EXPECT_FALSE(graph_builder.IsValidGraphForTesting(context_properties));
  }
}
struct WhereTester {
OperandInfo condition;
OperandInfo true_value;
OperandInfo false_value;
OperandInfo output;
bool expected;
void Test(WebNNGraphImplTest& test) {
auto context_properties = GetContextPropertiesForTesting();
mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
test.BindNewGraphBuilderRemote();
GraphInfoBuilder builder(remote);
OperandId condition_operand_id =
builder.BuildInput("condition", condition.dimensions, condition.type);
OperandId true_value_operand_id = builder.BuildInput(
"true_value", true_value.dimensions, true_value.type);
OperandId false_value_operand_id = builder.BuildInput(
"false_value", false_value.dimensions, false_value.type);
OperandId output_operand_id =
builder.BuildOutput("output", output.dimensions, output.type);
builder.BuildWhere(condition_operand_id, true_value_operand_id,
false_value_operand_id, output_operand_id);
EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
}
};
TEST_F(WebNNGraphImplTest, WhereTest) {
  using OperandDataType::kFloat16;
  using OperandDataType::kFloat32;
  using OperandDataType::kUint8;

  // Invalid: the condition operand is float32 rather than uint8.
  {
    WhereTester{.condition = {.type = kFloat32, .dimensions = {2, 4}},
                .true_value = {.type = kFloat32, .dimensions = {2, 4}},
                .false_value = {.type = kFloat32, .dimensions = {2, 4}},
                .output = {.type = kFloat32, .dimensions = {2, 4}},
                .expected = false}
        .Test(*this);
  }
  // Invalid: true_value and false_value have different data types.
  {
    WhereTester{.condition = {.type = kUint8, .dimensions = {2, 4}},
                .true_value = {.type = kFloat32, .dimensions = {2, 4}},
                .false_value = {.type = kFloat16, .dimensions = {2, 4}},
                .output = {.type = kFloat32, .dimensions = {2, 4}},
                .expected = false}
        .Test(*this);
  }
  // Invalid: the output data type differs from the value data type.
  {
    WhereTester{.condition = {.type = kUint8, .dimensions = {2, 4}},
                .true_value = {.type = kFloat32, .dimensions = {2, 4}},
                .false_value = {.type = kFloat32, .dimensions = {2, 4}},
                .output = {.type = kFloat16, .dimensions = {2, 4}},
                .expected = false}
        .Test(*this);
  }
  // Invalid: the output shape {2, 5} is not the broadcast shape {2, 4}.
  {
    WhereTester{.condition = {.type = kUint8, .dimensions = {2, 4}},
                .true_value = {.type = kFloat32, .dimensions = {2, 4}},
                .false_value = {.type = kFloat32, .dimensions = {2, 4}},
                .output = {.type = kFloat32, .dimensions = {2, 5}},
                .expected = false}
        .Test(*this);
  }
  // Invalid: false_value {2, 3} cannot be broadcast with {2, 4}.
  {
    WhereTester{.condition = {.type = kUint8, .dimensions = {2, 4}},
                .true_value = {.type = kFloat32, .dimensions = {2, 4}},
                .false_value = {.type = kFloat32, .dimensions = {2, 3}},
                .output = {.type = kFloat32, .dimensions = {2, 4}},
                .expected = false}
        .Test(*this);
  }
  // Invalid: the values broadcast to {2, 3}, which cannot be broadcast with the
  // {2, 4} condition.
  {
    WhereTester{.condition = {.type = kUint8, .dimensions = {2, 4}},
                .true_value = {.type = kFloat32, .dimensions = {2, 3}},
                .false_value = {.type = kFloat32, .dimensions = {2, 1}},
                .output = {.type = kFloat32, .dimensions = {2, 4}},
                .expected = false}
        .Test(*this);
  }
  // Valid: the {2, 1} condition broadcasts to {2, 4}.
  {
    WhereTester{.condition = {.type = kUint8, .dimensions = {2, 1}},
                .true_value = {.type = kFloat32, .dimensions = {2, 4}},
                .false_value = {.type = kFloat32, .dimensions = {2, 4}},
                .output = {.type = kFloat32, .dimensions = {2, 4}},
                .expected = true}
        .Test(*this);
  }
  // Valid: all three operands broadcast to the rank-3 shape {2, 3, 4}.
  {
    WhereTester{.condition = {.type = kUint8, .dimensions = {1, 4}},
                .true_value = {.type = kFloat32, .dimensions = {3, 4}},
                .false_value = {.type = kFloat32, .dimensions = {2, 3, 4}},
                .output = {.type = kFloat32, .dimensions = {2, 3, 4}},
                .expected = true}
        .Test(*this);
  }
  // Valid: the condition {2, 1, 4} and false_value {1, 4} broadcast to the
  // true_value shape {2, 3, 4}.
  {
    WhereTester{.condition = {.type = kUint8, .dimensions = {2, 1, 4}},
                .true_value = {.type = kFloat32, .dimensions = {2, 3, 4}},
                .false_value = {.type = kFloat32, .dimensions = {1, 4}},
                .output = {.type = kFloat32, .dimensions = {2, 3, 4}},
                .expected = true}
        .Test(*this);
  }
  // Valid: lower-rank values broadcast to the rank-4 condition shape.
  {
    WhereTester{.condition = {.type = kUint8, .dimensions = {2, 3, 4, 5}},
                .true_value = {.type = kFloat32, .dimensions = {3, 4, 5}},
                .false_value = {.type = kFloat32, .dimensions = {4, 5}},
                .output = {.type = kFloat32, .dimensions = {2, 3, 4, 5}},
                .expected = true}
        .Test(*this);
  }
  // Invalid: where writes its result into the condition input operand.
  {
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> graph_builder_remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder graph_builder(graph_builder_remote);
    const OperandId condition_id =
        graph_builder.BuildInput("condition", {2, 4}, kUint8);
    const OperandId true_value_id =
        graph_builder.BuildInput("true_value", {2, 4}, kFloat32);
    const OperandId false_value_id =
        graph_builder.BuildInput("false_value", {2, 4}, kFloat32);
    graph_builder.BuildWhere(condition_id, true_value_id, false_value_id,
                             condition_id);
    EXPECT_FALSE(graph_builder.IsValidGraphForTesting(context_properties));
  }
  // Invalid: where writes its result into the true_value input operand.
  {
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> graph_builder_remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder graph_builder(graph_builder_remote);
    const OperandId condition_id =
        graph_builder.BuildInput("condition", {2, 4}, kUint8);
    const OperandId true_value_id =
        graph_builder.BuildInput("true_value", {2, 4}, kFloat32);
    const OperandId false_value_id =
        graph_builder.BuildInput("false_value", {2, 4}, kFloat32);
    graph_builder.BuildWhere(condition_id, true_value_id, false_value_id,
                             true_value_id);
    EXPECT_FALSE(graph_builder.IsValidGraphForTesting(context_properties));
  }
  // Invalid: where writes its result into the false_value input operand.
  {
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> graph_builder_remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder graph_builder(graph_builder_remote);
    const OperandId condition_id =
        graph_builder.BuildInput("condition", {2, 4}, kUint8);
    const OperandId true_value_id =
        graph_builder.BuildInput("true_value", {2, 4}, kFloat32);
    const OperandId false_value_id =
        graph_builder.BuildInput("false_value", {2, 4}, kFloat32);
    graph_builder.BuildWhere(condition_id, true_value_id, false_value_id,
                             false_value_id);
    EXPECT_FALSE(graph_builder.IsValidGraphForTesting(context_properties));
  }
}
// Builds a two-output graph (output1 = lhs + rhs, output2 = lhs + rhs) and
// checks which sets of bound tensors ValidateDispatch() accepts: names, data
// types and shapes must match the graph exactly, one tensor may back several
// inputs, but outputs may not alias each other or any input.
TEST_F(WebNNGraphImplTest, ValidateDispatchTest) {
  auto context_properties = GetContextPropertiesForTesting();
  // Descriptor shared by every graph operand and every "matching" tensor.
  const OperandDataType kDataType = OperandDataType::kUint8;
  const std::vector<uint32_t> kShape = {3, 5};

  mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
      BindNewGraphBuilderRemote();
  GraphInfoBuilder builder(remote);
  const OperandId lhs_id = builder.BuildInput("lhs", kShape, kDataType);
  const OperandId rhs_id = builder.BuildInput("rhs", kShape, kDataType);
  const OperandId output1_id =
      builder.BuildOutput("output1", kShape, kDataType);
  builder.BuildElementWiseBinary(mojom::ElementWiseBinary::Kind::kAdd, lhs_id,
                                 rhs_id, output1_id);
  const OperandId output2_id =
      builder.BuildOutput("output2", kShape, kDataType);
  builder.BuildElementWiseBinary(mojom::ElementWiseBinary::Kind::kAdd, lhs_id,
                                 rhs_id, output2_id);
  EXPECT_TRUE(builder.IsValidGraphForTesting(context_properties));

  test::WebNNTestEnvironment test_environment;
  mojo::Remote<mojom::WebNNContextProvider> context_provider;
  test_environment.BindWebNNContextProvider(
      context_provider.BindNewPipeAndPassReceiver());

  // Creates a tensor on `ml_context` whose type and shape match the graph's
  // operands.
  auto matching_tensor = [&](mojo::Remote<mojom::WebNNContext>& ml_context) {
    return CreateWebNNTensor(ml_context, kDataType, kShape);
  };

  {
    // Every input and output bound with a matching tensor: accepted.
    mojo::Remote<mojom::WebNNContext> ml_context =
        CreateWebNNContext(context_provider);
    base::flat_map<std::string, CreateTensorSuccess> bound_inputs;
    bound_inputs["lhs"] = matching_tensor(ml_context);
    bound_inputs["rhs"] = matching_tensor(ml_context);
    base::flat_map<std::string, CreateTensorSuccess> bound_outputs;
    bound_outputs["output1"] = matching_tensor(ml_context);
    bound_outputs["output2"] = matching_tensor(ml_context);
    EXPECT_TRUE(ValidateDispatch(ml_context, builder.CloneGraphInfo(),
                                 std::move(bound_inputs),
                                 std::move(bound_outputs)));
  }
  {
    // The "rhs" input is missing: rejected.
    mojo::Remote<mojom::WebNNContext> ml_context =
        CreateWebNNContext(context_provider);
    base::flat_map<std::string, CreateTensorSuccess> bound_inputs;
    bound_inputs["lhs"] = matching_tensor(ml_context);
    base::flat_map<std::string, CreateTensorSuccess> bound_outputs;
    bound_outputs["output1"] = matching_tensor(ml_context);
    bound_outputs["output2"] = matching_tensor(ml_context);
    EXPECT_FALSE(ValidateDispatch(ml_context, builder.CloneGraphInfo(),
                                  std::move(bound_inputs),
                                  std::move(bound_outputs)));
  }
  {
    // An extra output name the graph does not declare: rejected.
    mojo::Remote<mojom::WebNNContext> ml_context =
        CreateWebNNContext(context_provider);
    base::flat_map<std::string, CreateTensorSuccess> bound_inputs;
    bound_inputs["lhs"] = matching_tensor(ml_context);
    bound_inputs["rhs"] = matching_tensor(ml_context);
    base::flat_map<std::string, CreateTensorSuccess> bound_outputs;
    bound_outputs["output1"] = matching_tensor(ml_context);
    bound_outputs["output2"] = matching_tensor(ml_context);
    bound_outputs["a_different_output_name"] = matching_tensor(ml_context);
    EXPECT_FALSE(ValidateDispatch(ml_context, builder.CloneGraphInfo(),
                                  std::move(bound_inputs),
                                  std::move(bound_outputs)));
  }
  {
    // "lhs" replaced by an unknown input name: rejected.
    mojo::Remote<mojom::WebNNContext> ml_context =
        CreateWebNNContext(context_provider);
    base::flat_map<std::string, CreateTensorSuccess> bound_inputs;
    bound_inputs["a_different_input_name"] = matching_tensor(ml_context);
    bound_inputs["rhs"] = matching_tensor(ml_context);
    base::flat_map<std::string, CreateTensorSuccess> bound_outputs;
    bound_outputs["output1"] = matching_tensor(ml_context);
    bound_outputs["output2"] = matching_tensor(ml_context);
    EXPECT_FALSE(ValidateDispatch(ml_context, builder.CloneGraphInfo(),
                                  std::move(bound_inputs),
                                  std::move(bound_outputs)));
  }
  {
    // "output1" replaced by an unknown output name: rejected.
    mojo::Remote<mojom::WebNNContext> ml_context =
        CreateWebNNContext(context_provider);
    base::flat_map<std::string, CreateTensorSuccess> bound_inputs;
    bound_inputs["lhs"] = matching_tensor(ml_context);
    bound_inputs["rhs"] = matching_tensor(ml_context);
    base::flat_map<std::string, CreateTensorSuccess> bound_outputs;
    bound_outputs["a_different_output_name"] = matching_tensor(ml_context);
    bound_outputs["output2"] = matching_tensor(ml_context);
    EXPECT_FALSE(ValidateDispatch(ml_context, builder.CloneGraphInfo(),
                                  std::move(bound_inputs),
                                  std::move(bound_outputs)));
  }
  {
    // "lhs" tensor has the wrong shape: rejected.
    mojo::Remote<mojom::WebNNContext> ml_context =
        CreateWebNNContext(context_provider);
    base::flat_map<std::string, CreateTensorSuccess> bound_inputs;
    bound_inputs["lhs"] = CreateWebNNTensor(ml_context, kDataType, {2, 5});
    bound_inputs["rhs"] = matching_tensor(ml_context);
    base::flat_map<std::string, CreateTensorSuccess> bound_outputs;
    bound_outputs["output1"] = matching_tensor(ml_context);
    bound_outputs["output2"] = matching_tensor(ml_context);
    EXPECT_FALSE(ValidateDispatch(ml_context, builder.CloneGraphInfo(),
                                  std::move(bound_inputs),
                                  std::move(bound_outputs)));
  }
  {
    // "lhs" tensor has the wrong data type: rejected.
    mojo::Remote<mojom::WebNNContext> ml_context =
        CreateWebNNContext(context_provider);
    base::flat_map<std::string, CreateTensorSuccess> bound_inputs;
    bound_inputs["lhs"] =
        CreateWebNNTensor(ml_context, OperandDataType::kInt8, kShape);
    bound_inputs["rhs"] = matching_tensor(ml_context);
    base::flat_map<std::string, CreateTensorSuccess> bound_outputs;
    bound_outputs["output1"] = matching_tensor(ml_context);
    bound_outputs["output2"] = matching_tensor(ml_context);
    EXPECT_FALSE(ValidateDispatch(ml_context, builder.CloneGraphInfo(),
                                  std::move(bound_inputs),
                                  std::move(bound_outputs)));
  }
  {
    // "output1" tensor has the wrong shape: rejected.
    mojo::Remote<mojom::WebNNContext> ml_context =
        CreateWebNNContext(context_provider);
    base::flat_map<std::string, CreateTensorSuccess> bound_inputs;
    bound_inputs["lhs"] = matching_tensor(ml_context);
    bound_inputs["rhs"] = matching_tensor(ml_context);
    base::flat_map<std::string, CreateTensorSuccess> bound_outputs;
    bound_outputs["output1"] = CreateWebNNTensor(ml_context, kDataType, {3, 4});
    bound_outputs["output2"] = matching_tensor(ml_context);
    EXPECT_FALSE(ValidateDispatch(ml_context, builder.CloneGraphInfo(),
                                  std::move(bound_inputs),
                                  std::move(bound_outputs)));
  }
  {
    // "rhs" tensor has the wrong data type: rejected.
    mojo::Remote<mojom::WebNNContext> ml_context =
        CreateWebNNContext(context_provider);
    base::flat_map<std::string, CreateTensorSuccess> bound_inputs;
    bound_inputs["lhs"] = matching_tensor(ml_context);
    bound_inputs["rhs"] =
        CreateWebNNTensor(ml_context, OperandDataType::kInt32, kShape);
    base::flat_map<std::string, CreateTensorSuccess> bound_outputs;
    bound_outputs["output1"] = matching_tensor(ml_context);
    bound_outputs["output2"] = matching_tensor(ml_context);
    EXPECT_FALSE(ValidateDispatch(ml_context, builder.CloneGraphInfo(),
                                  std::move(bound_inputs),
                                  std::move(bound_outputs)));
  }
  {
    // "output2" tensor has the wrong shape: rejected.
    mojo::Remote<mojom::WebNNContext> ml_context =
        CreateWebNNContext(context_provider);
    base::flat_map<std::string, CreateTensorSuccess> bound_inputs;
    bound_inputs["lhs"] = matching_tensor(ml_context);
    bound_inputs["rhs"] = matching_tensor(ml_context);
    base::flat_map<std::string, CreateTensorSuccess> bound_outputs;
    bound_outputs["output1"] = matching_tensor(ml_context);
    bound_outputs["output2"] = CreateWebNNTensor(ml_context, kDataType, {2, 5});
    EXPECT_FALSE(ValidateDispatch(ml_context, builder.CloneGraphInfo(),
                                  std::move(bound_inputs),
                                  std::move(bound_outputs)));
  }
  {
    // One tensor backing both inputs ("rhs" reuses the "lhs" handle):
    // accepted.
    mojo::Remote<mojom::WebNNContext> ml_context =
        CreateWebNNContext(context_provider);
    base::flat_map<std::string, CreateTensorSuccess> bound_inputs;
    bound_inputs["lhs"] = matching_tensor(ml_context);
    bound_inputs["rhs"] = {std::nullopt, bound_inputs["lhs"].webnn_handle};
    base::flat_map<std::string, CreateTensorSuccess> bound_outputs;
    bound_outputs["output1"] = matching_tensor(ml_context);
    bound_outputs["output2"] = matching_tensor(ml_context);
    EXPECT_TRUE(ValidateDispatch(ml_context, builder.CloneGraphInfo(),
                                 std::move(bound_inputs),
                                 std::move(bound_outputs)));
  }
  {
    // One tensor backing both outputs: rejected.
    mojo::Remote<mojom::WebNNContext> ml_context =
        CreateWebNNContext(context_provider);
    base::flat_map<std::string, CreateTensorSuccess> bound_inputs;
    bound_inputs["lhs"] = matching_tensor(ml_context);
    bound_inputs["rhs"] = matching_tensor(ml_context);
    base::flat_map<std::string, CreateTensorSuccess> bound_outputs;
    bound_outputs["output1"] = matching_tensor(ml_context);
    bound_outputs["output2"] = {std::nullopt,
                                bound_outputs["output1"].webnn_handle};
    EXPECT_FALSE(ValidateDispatch(ml_context, builder.CloneGraphInfo(),
                                  std::move(bound_inputs),
                                  std::move(bound_outputs)));
  }
  {
    // An output reusing an input's tensor: rejected.
    mojo::Remote<mojom::WebNNContext> ml_context =
        CreateWebNNContext(context_provider);
    base::flat_map<std::string, CreateTensorSuccess> bound_inputs;
    bound_inputs["lhs"] = matching_tensor(ml_context);
    bound_inputs["rhs"] = matching_tensor(ml_context);
    base::flat_map<std::string, CreateTensorSuccess> bound_outputs;
    bound_outputs["output1"] = {std::nullopt,
                                bound_inputs["lhs"].webnn_handle};
    bound_outputs["output2"] = matching_tensor(ml_context);
    EXPECT_FALSE(ValidateDispatch(ml_context, builder.CloneGraphInfo(),
                                  std::move(bound_inputs),
                                  std::move(bound_outputs)));
  }
  {
    // "lhs" bound to an entry with no created tensor behind it (presumably a
    // handle the context does not know): rejected.
    mojo::Remote<mojom::WebNNContext> ml_context =
        CreateWebNNContext(context_provider);
    base::flat_map<std::string, CreateTensorSuccess> bound_inputs;
    bound_inputs["lhs"] = {std::nullopt};
    bound_inputs["rhs"] = matching_tensor(ml_context);
    base::flat_map<std::string, CreateTensorSuccess> bound_outputs;
    bound_outputs["output1"] = matching_tensor(ml_context);
    bound_outputs["output2"] = matching_tensor(ml_context);
    EXPECT_FALSE(ValidateDispatch(ml_context, builder.CloneGraphInfo(),
                                  std::move(bound_inputs),
                                  std::move(bound_outputs)));
  }
  {
    // "output2" bound to an entry with no created tensor behind it
    // (presumably a handle the context does not know): rejected.
    mojo::Remote<mojom::WebNNContext> ml_context =
        CreateWebNNContext(context_provider);
    base::flat_map<std::string, CreateTensorSuccess> bound_inputs;
    bound_inputs["lhs"] = matching_tensor(ml_context);
    bound_inputs["rhs"] = matching_tensor(ml_context);
    base::flat_map<std::string, CreateTensorSuccess> bound_outputs;
    bound_outputs["output1"] = matching_tensor(ml_context);
    bound_outputs["output2"] = {std::nullopt};
    EXPECT_FALSE(ValidateDispatch(ml_context, builder.CloneGraphInfo(),
                                  std::move(bound_inputs),
                                  std::move(bound_outputs)));
  }
}
// Builds output = gemm(gemm(input_a, constant_a), gemm(input_b, constant_b)),
// declaring each graph input before the constant it is multiplied with, and
// expects the resulting graph to validate.
TEST_F(WebNNGraphImplTest, BuildMultipleInputsAppendingConstants) {
  auto context_properties = GetContextPropertiesForTesting();
  mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
      BindNewGraphBuilderRemote();
  GraphInfoBuilder builder(remote);

  // Every operand in this graph is a 2x2 float32 matrix.
  const std::vector<uint32_t> kShape = {2, 2};
  const OperandDataType kDataType = OperandDataType::kFloat32;
  std::vector<float> constant_data = {5.0f, 6.0f, 7.0f, 8.0f};
  // Both constants are built from the same bytes.
  const auto constant_bytes =
      base::as_byte_span(base::allow_nonunique_obj, constant_data);

  const OperandId output_id = builder.BuildOutput("output", kShape, kDataType);

  // Branch A: input_a x constant_a.
  const OperandId input_a_id = builder.BuildInput("input_a", kShape, kDataType);
  const OperandId constant_a_id =
      builder.BuildConstant(kShape, kDataType, constant_bytes);
  const OperandId branch_a_id =
      builder.BuildIntermediateOperand(kShape, kDataType);
  builder.BuildGemm(input_a_id, constant_a_id, branch_a_id,
                    GemmTester::GemmAttributes());

  // Branch B: input_b x constant_b.
  const OperandId input_b_id = builder.BuildInput("input_b", kShape, kDataType);
  const OperandId constant_b_id =
      builder.BuildConstant(kShape, kDataType, constant_bytes);
  const OperandId branch_b_id =
      builder.BuildIntermediateOperand(kShape, kDataType);
  builder.BuildGemm(input_b_id, constant_b_id, branch_b_id,
                    GemmTester::GemmAttributes());

  // Combine the two branches into the graph output.
  builder.BuildGemm(branch_a_id, branch_b_id, output_id,
                    GemmTester::GemmAttributes());
  EXPECT_TRUE(builder.IsValidGraphForTesting(context_properties));
}
// Builds output = gemm(gemm(constant_a, input_a), gemm(constant_b, input_b)).
// The first constant is declared before its input, and the constants are the
// left-hand gemm operands. Expects the resulting graph to validate.
TEST_F(WebNNGraphImplTest, BuildMultipleConstantsAppendingInputs) {
  auto context_properties = GetContextPropertiesForTesting();
  mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
      BindNewGraphBuilderRemote();
  GraphInfoBuilder builder(remote);

  // Every operand in this graph is a 2x2 float32 matrix.
  const std::vector<uint32_t> kShape = {2, 2};
  const OperandDataType kDataType = OperandDataType::kFloat32;
  std::vector<float> constant_data = {5.0f, 6.0f, 7.0f, 8.0f};
  // Both constants are built from the same bytes.
  const auto constant_bytes =
      base::as_byte_span(base::allow_nonunique_obj, constant_data);

  const OperandId output_id = builder.BuildOutput("output", kShape, kDataType);

  // Branch A: constant_a x input_a (constant declared first).
  const OperandId constant_a_id =
      builder.BuildConstant(kShape, kDataType, constant_bytes);
  const OperandId input_a_id = builder.BuildInput("input_a", kShape, kDataType);
  const OperandId branch_a_id =
      builder.BuildIntermediateOperand(kShape, kDataType);
  builder.BuildGemm(constant_a_id, input_a_id, branch_a_id,
                    GemmTester::GemmAttributes());

  // Branch B: constant_b x input_b (input declared first).
  const OperandId input_b_id = builder.BuildInput("input_b", kShape, kDataType);
  const OperandId constant_b_id =
      builder.BuildConstant(kShape, kDataType, constant_bytes);
  const OperandId branch_b_id =
      builder.BuildIntermediateOperand(kShape, kDataType);
  builder.BuildGemm(constant_b_id, input_b_id, branch_b_id,
                    GemmTester::GemmAttributes());

  // Combine the two branches into the graph output.
  builder.BuildGemm(branch_a_id, branch_b_id, output_id,
                    GemmTester::GemmAttributes());
  EXPECT_TRUE(builder.IsValidGraphForTesting(context_properties));
}
// An operation that consumes an operand before the operation producing it has
// been added must make the graph invalid.
TEST_F(WebNNGraphImplTest, BuildOperationWithNonexistentInputs) {
  auto context_properties = GetContextPropertiesForTesting();
  mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
      BindNewGraphBuilderRemote();
  GraphInfoBuilder builder(remote);
  OperandId input_operand_id =
      builder.BuildInput("input_a", {2, 2}, OperandDataType::kFloat32);
  OperandId intermediate_operand_id =
      builder.BuildIntermediateOperand({2, 2}, OperandDataType::kFloat32);
  // The output shares the relu chain's float32 [2, 2] descriptor, so the only
  // thing wrong with this graph is the operation order below. A mismatched
  // output type would make the graph invalid on its own and mask a regression
  // in the ordering check.
  OperandId output_operand_id =
      builder.BuildOutput("output", {2, 2}, OperandDataType::kFloat32);
  // This relu reads `intermediate_operand_id` before the relu that produces it
  // (added next) exists in the operation list.
  builder.BuildRelu(intermediate_operand_id, output_operand_id);
  builder.BuildRelu(input_operand_id, intermediate_operand_id);
  EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
}
}