// 910e62b5, created on January 15 (commit history).
// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "services/webnn/webnn_graph_impl.h"

#include <cmath>
#include <limits>
#include <memory>

#include "base/containers/contains.h"
#include "base/memory/weak_ptr.h"
#include "base/notreached.h"
#include "base/strings/strcat.h"
#include "base/strings/string_number_conversions.h"
#include "base/strings/stringprintf.h"
#include "base/test/bind.h"
#include "base/test/run_until.h"
#include "base/test/scoped_feature_list.h"
#include "base/test/task_environment.h"
#include "base/test/test_future.h"
#include "mojo/public/cpp/bindings/associated_remote.h"
#include "mojo/public/cpp/bindings/remote.h"
#include "mojo/public/cpp/bindings/self_owned_associated_receiver.h"
#include "mojo/public/cpp/system/functions.h"
#include "services/webnn/error.h"
#include "services/webnn/public/cpp/ml_tensor_usage.h"
#include "services/webnn/public/cpp/operand_descriptor.h"
#include "services/webnn/public/cpp/supported_data_types.h"
#include "services/webnn/public/cpp/webnn_errors.h"
#include "services/webnn/public/cpp/webnn_types.h"
#include "services/webnn/public/mojom/features.mojom-features.h"
#include "services/webnn/public/mojom/webnn_context_provider.mojom.h"
#include "services/webnn/public/mojom/webnn_device.mojom-data-view.h"
#include "services/webnn/public/mojom/webnn_graph.mojom.h"
#include "services/webnn/public/mojom/webnn_graph_builder.mojom.h"
#include "services/webnn/public/mojom/webnn_tensor.mojom.h"
#include "services/webnn/scoped_sequence.h"
#include "services/webnn/webnn_constant_operand.h"
#include "services/webnn/webnn_context_impl.h"
#include "services/webnn/webnn_context_provider_impl.h"
#include "services/webnn/webnn_tensor_impl.h"
#include "services/webnn/webnn_test_environment.h"
#include "services/webnn/webnn_test_utils.h"
#include "services/webnn/webnn_utils.h"
#include "testing/gtest/include/gtest/gtest.h"

namespace webnn {

namespace {

// A fake WebNNGraph Mojo interface implementation that binds a pipe for
// computing graph message.
class FakeWebNNGraphImpl final : public WebNNGraphImpl {
 public:
  FakeWebNNGraphImpl(
      mojo::PendingAssociatedReceiver<mojom::WebNNGraph> receiver,
      base::WeakPtr<WebNNContextImpl> context,
      ComputeResourceInfo compute_resource_info)
      : WebNNGraphImpl(std::move(receiver),
                       std::move(context),
                       std::move(compute_resource_info),
                       /*devices=*/{}) {}

  static void CreateAndBuild(
      mojo::PendingAssociatedReceiver<mojom::WebNNGraph> receiver,
      base::WeakPtr<WebNNContextImpl> context,
      const mojom::GraphInfo& graph_info,
      ComputeResourceInfo compute_resource_info,
      WebNNContextImpl::CreateGraphImplCallback callback) {
    std::move(callback).Run(base::MakeRefCounted<FakeWebNNGraphImpl>(
        std::move(receiver), std::move(context),
        std::move(compute_resource_info)));
  }

 private:
  ~FakeWebNNGraphImpl() override = default;

  // Return nothing for testing the validation of inputs and outputs in
  // `WebNNGraphImpl::Dispatch()` function.
  void DispatchImpl(
      base::flat_map<std::string, scoped_refptr<WebNNTensorImpl>> named_inputs,
      base::flat_map<std::string, scoped_refptr<WebNNTensorImpl>> named_outputs)
      override {}
};

// Test double for the WebNNTensor Mojo interface. All constructor arguments
// are forwarded to `WebNNTensorImpl`; every backend hook below is a stub.
class FakeWebNNTensorImpl final : public WebNNTensorImpl {
 public:
  FakeWebNNTensorImpl(
      mojo::PendingAssociatedReceiver<mojom::WebNNTensor> receiver,
      base::WeakPtr<WebNNContextImpl> context,
      mojom::TensorInfoPtr tensor_info)
      : WebNNTensorImpl(std::move(receiver),
                        std::move(context),
                        std::move(tensor_info)) {}

 private:
  ~FakeWebNNTensorImpl() override = default;

  // Reading and writing are no-ops because the tests only exercise the
  // validation of inputs and outputs in `WebNNGraphImpl::Dispatch()`. Note
  // that the read callback is dropped without being run.
  void ReadTensorImpl(ReadTensorCallback /*callback*/) override {}
  void WriteTensorImpl(mojo_base::BigBuffer /*src_buffer*/) override {}

  // Interop is not needed by the tests: importing always reports failure and
  // exporting does nothing (its callback is dropped as well).
  bool ImportTensorImpl() override { return false; }
  void ExportTensorImpl(ScopedAccessPtr /*access*/,
                        ExportTensorCallback /*callback*/) override {}
};

// A fake WebNNContext Mojo interface implementation that binds a pipe for
// creating graph message.
class FakeWebNNContextImpl final : public WebNNContextImpl {
 public:
  FakeWebNNContextImpl(
      mojo::PendingReceiver<mojom::WebNNContext> receiver,
      base::WeakPtr<WebNNContextProviderImpl> context_provider,
      gpu::CommandBufferId command_buffer_id,
      std::unique_ptr<ScopedSequence> sequence,
      scoped_refptr<gpu::MemoryTracker> memory_tracker,
      scoped_refptr<base::SingleThreadTaskRunner> owning_task_runner,
      gpu::SharedImageManager* shared_image_manager,
      scoped_refptr<base::SingleThreadTaskRunner> main_task_runner)
      : WebNNContextImpl(std::move(receiver),
                         std::move(context_provider),
                         GetContextPropertiesForTesting(),
                         mojom::CreateContextOptions::New(),
                         mojo::ScopedDataPipeConsumerHandle(),
                         mojo::ScopedDataPipeProducerHandle(),
                         command_buffer_id,
                         std::move(sequence),
                         std::move(memory_tracker),
                         std::move(owning_task_runner),
                         shared_image_manager,
                         std::move(main_task_runner)) {}

  // WebNNContextImpl:
  base::WeakPtr<WebNNContextImpl> AsWeakPtr() override {
    DCHECK_CALLED_ON_VALID_SEQUENCE(gpu_sequence_checker_);
    return weak_factory_.GetWeakPtr();
  }

 private:
  ~FakeWebNNContextImpl() override = default;

  void CreateGraphImpl(
      mojo::PendingAssociatedReceiver<mojom::WebNNGraph> receiver,
      mojom::GraphInfoPtr graph_info,
      WebNNGraphImpl::ComputeResourceInfo compute_resource_info,
      base::flat_map<
          OperandId,
          std::unique_ptr<WebNNConstantOperand>> /*constant_operands*/,
      base::flat_map<OperandId, WebNNTensorImpl*> /*constant_tensor_operands*/,
      CreateGraphImplCallback callback) override {
    FakeWebNNGraphImpl::CreateAndBuild(
        std::move(receiver), AsWeakPtr(), *graph_info,
        std::move(compute_resource_info), std::move(callback));
  }

  base::expected<scoped_refptr<WebNNTensorImpl>, mojom::ErrorPtr>
  CreateTensorImpl(mojo::PendingAssociatedReceiver<mojom::WebNNTensor> receiver,
                   mojom::TensorInfoPtr tensor_info) override {
    return base::MakeRefCounted<FakeWebNNTensorImpl>(
        std::move(receiver), AsWeakPtr(), std::move(tensor_info));
  }

  base::expected<scoped_refptr<WebNNTensorImpl>, mojom::ErrorPtr>
  CreateTensorFromSharedImageImpl(
      mojo::PendingAssociatedReceiver<mojom::WebNNTensor> receiver,
      mojom::TensorInfoPtr tensor_info,
      WebNNTensorImpl::RepresentationPtr representation) override {
    return base::unexpected(mojom::Error::New(
        mojom::Error::Code::kNotSupportedError, "Not implemented"));
  }

  base::WeakPtrFactory<FakeWebNNContextImpl> weak_factory_{this};
};

// Testing backend that makes the context provider create
// `FakeWebNNContextImpl` instances, so that graph validation and compute
// resource handling can be tested without a real backend.
class FakeWebNNBackend : public WebNNContextProviderImpl::BackendForTesting {
 public:
  // Creates a fake context bound to a fresh message pipe, reports success
  // through `callback` (synchronously), and returns ownership of the context.
  // `options` is not consulted.
  std::unique_ptr<WebNNContextImpl, OnTaskRunnerDeleter> CreateWebNNContext(
      base::WeakPtr<WebNNContextProviderImpl> context_provider_impl,
      mojom::CreateContextOptionsPtr options,
      gpu::CommandBufferId command_buffer_id,
      std::unique_ptr<ScopedSequence> sequence,
      scoped_refptr<gpu::MemoryTracker> memory_tracker,
      scoped_refptr<base::SingleThreadTaskRunner> owning_task_runner,
      gpu::SharedImageManager* shared_image_manager,
      scoped_refptr<base::SingleThreadTaskRunner> main_task_runner,
      mojom::WebNNContextProvider::CreateWebNNContextCallback callback)
      override {
    mojo::PendingRemote<mojom::WebNNContext> context_remote;
    // `owning_task_runner` is moved into the context below, so take a second
    // reference for the deleter first.
    scoped_refptr<base::SingleThreadTaskRunner> deleter_task_runner =
        owning_task_runner;
    mojo::PendingReceiver<mojom::WebNNContext> context_receiver =
        context_remote.InitWithNewPipeAndPassReceiver();
    std::unique_ptr<WebNNContextImpl, OnTaskRunnerDeleter> fake_context(
        new FakeWebNNContextImpl(
            std::move(context_receiver), std::move(context_provider_impl),
            command_buffer_id, std::move(sequence), std::move(memory_tracker),
            std::move(owning_task_runner), shared_image_manager,
            std::move(main_task_runner)),
        OnTaskRunnerDeleter(std::move(deleter_task_runner)));

    // The success result carries the remote end of the pipe whose receiver was
    // handed to the fake context, plus a copy of the context's properties, its
    // handle and default-constructed data-pipe handles.
    ContextProperties properties_copy = fake_context->properties();
    mojom::CreateContextSuccessPtr success_payload =
        mojom::CreateContextSuccess::New(
            std::move(context_remote), std::move(properties_copy),
            fake_context->handle(), mojo::ScopedDataPipeProducerHandle(),
            mojo::ScopedDataPipeConsumerHandle());
    std::move(callback).Run(
        mojom::CreateContextResult::NewSuccess(std::move(success_payload)));
    return fake_context;
  }
};

// Bundles what `CreateWebNNTensor()` returns for a newly created tensor.
// Field order matters: the struct is aggregate-initialized positionally.
struct CreateTensorSuccess {
  // Remote bound to the created tensor.
  // NOTE(review): presumably optional so a test can drop the remote while
  // keeping the token - confirm against the callers.
  std::optional<mojo::AssociatedRemote<mojom::WebNNTensor>> webnn_tensor;
  // Token for the tensor; `ValidateDispatch()` passes it to `Dispatch()`.
  blink::WebNNTensorToken webnn_handle;
};

// Asks `webnn_context` to create a tensor of the given `data_type` and `shape`
// (with default usage and an empty initial buffer), waits for the reply, and
// returns the bound tensor remote together with its token. The reply is
// assumed to be a success; no error handling is performed.
CreateTensorSuccess CreateWebNNTensor(
    mojo::Remote<mojom::WebNNContext>& webnn_context,
    OperandDataType data_type,
    std::vector<uint32_t> shape) {
  base::test::TestFuture<mojom::CreateTensorResultPtr> tensor_future;
  mojom::TensorInfoPtr info = mojom::TensorInfo::New(
      OperandDescriptor::UnsafeCreateForTesting(data_type, shape),
      MLTensorUsage());
  webnn_context->CreateTensor(std::move(info), mojo_base::BigBuffer(0),
                              tensor_future.GetCallback());

  // Block until the context answers, then unpack the success payload.
  mojom::CreateTensorResultPtr reply = tensor_future.Take();
  auto& success = reply->get_success();

  mojo::AssociatedRemote<mojom::WebNNTensor> tensor_remote;
  tensor_remote.Bind(std::move(success->tensor_remote));
  return CreateTensorSuccess{std::move(tensor_remote),
                             std::move(success->tensor_handle)};
}

// Requests a context with default options from `webnn_context_provider`,
// waits for the reply, and returns a remote bound to the new context. The
// reply is assumed to be a success.
mojo::Remote<mojom::WebNNContext> CreateWebNNContext(
    mojo::Remote<mojom::WebNNContextProvider>& webnn_context_provider) {
  base::test::TestFuture<mojom::CreateContextResultPtr> context_future;
  mojom::CreateContextOptionsPtr default_options =
      mojom::CreateContextOptions::New();
  webnn_context_provider->CreateWebNNContext(std::move(default_options),
                                             context_future.GetCallback());

  // Block until the provider answers, then bind the returned endpoint.
  mojom::CreateContextResultPtr reply = context_future.Take();
  mojo::Remote<mojom::WebNNContext> context;
  context.Bind(std::move(reply->get_success()->context_remote));
  return context;
}

// Builds a graph from `graph_info` on `webnn_context`, then dispatches it with
// the tensor tokens taken from `inputs` and `outputs`. Returns true when the
// dispatch went through without a bad-message report, and false when the
// process error handler fired (which is also expected to carry
// `kBadMessageInvalidTensor`). Graph creation is assumed to succeed.
bool ValidateDispatch(
    mojo::Remote<mojom::WebNNContext>& webnn_context,
    mojom::GraphInfoPtr graph_info,
    base::flat_map<std::string, CreateTensorSuccess> inputs,
    base::flat_map<std::string, CreateTensorSuccess> outputs) {
  mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder;
  webnn_context->CreateGraphBuilder(builder.BindNewEndpointAndPassReceiver());

  // Create the WebNN graph; the graph information is validated before
  // compiling.
  using CreateGraphResult =
      base::expected<mojom::CreateGraphSuccessPtr, mojom::ErrorPtr>;
  base::test::TestFuture<CreateGraphResult> graph_future;
  builder->CreateGraph(std::move(graph_info), graph_future.GetCallback());
  CreateGraphResult graph_result = graph_future.Take();
  mojo::AssociatedRemote<mojom::WebNNGraph> graph;
  graph.Bind(std::move(graph_result.value()->graph_remote));

  // Any bad mojo message reported while the handler is installed marks the
  // dispatch as rejected.
  bool dispatch_accepted = true;
  mojo::SetDefaultProcessErrorHandler(base::BindLambdaForTesting(
      [&dispatch_accepted](const std::string& error_message) {
        EXPECT_EQ(error_message, kBadMessageInvalidTensor);
        dispatch_accepted = false;
      }));

  // Maps each tensor name to the token of the tensor created for it.
  const auto collect_tokens =
      [](const base::flat_map<std::string, CreateTensorSuccess>& tensors) {
        base::flat_map<std::string, blink::WebNNTensorToken> tokens;
        for (const auto& [tensor_name, created] : tensors) {
          tokens.emplace(tensor_name, created.webnn_handle);
        }
        return tokens;
      };
  const auto input_tokens = collect_tokens(inputs);
  const auto output_tokens = collect_tokens(outputs);

  // Let pending CreateTensor messages finish before dispatching.
  webnn_context.FlushForTesting();
  graph->Dispatch(input_tokens, output_tokens);

  // Let the Dispatch message finish before the error handler is removed.
  graph.FlushForTesting();
  mojo::SetDefaultProcessErrorHandler(base::NullCallback());
  return dispatch_accepted;
}

// Table of operand data types shared by the tests. It is `constexpr` so that
// this file-local table cannot be modified by accident at runtime.
// NOTE(review): despite the name, only these five types are listed - confirm
// whether the omission of the other data types is intentional.
constexpr OperandDataType kAllOperandDataTypes[] = {
    OperandDataType::kFloat32, OperandDataType::kFloat16,
    OperandDataType::kInt32,   OperandDataType::kInt8,
    OperandDataType::kUint8,
};

}  // namespace

class WebNNGraphImplTest : public testing::Test {
 public:
  WebNNGraphImplTest(const WebNNGraphImplTest&) = delete;
  WebNNGraphImplTest& operator=(const WebNNGraphImplTest&) = delete;

  void SetUp() override {
    WebNNContextProviderImpl::SetBackendForTesting(&backend_for_testing_);

    webnn_test_environment_.BindWebNNContextProvider(
        provider_remote_.BindNewPipeAndPassReceiver());

    base::test::TestFuture<mojom::CreateContextResultPtr> create_context_future;
    provider_remote_->CreateWebNNContext(mojom::CreateContextOptions::New(),
                                         create_context_future.GetCallback());
    mojom::CreateContextResultPtr create_context_result =
        create_context_future.Take();
    webnn_context_.Bind(
        std::move(create_context_result->get_success()->context_remote));
  }

  void TearDown() override {
    WebNNContextProviderImpl::SetBackendForTesting(nullptr);
  }

  mojo::AssociatedRemote<mojom::WebNNGraphBuilder> BindNewGraphBuilderRemote() {
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote;
    webnn_context_->CreateGraphBuilder(remote.BindNewEndpointAndPassReceiver());
    return remote;
  }

 protected:
  WebNNGraphImplTest()
      : scoped_feature_list_(
            webnn::mojom::features::kWebMachineLearningNeuralNetwork) {}
  ~WebNNGraphImplTest() override = default;

 private:
  base::test::ScopedFeatureList scoped_feature_list_;
  base::test::TaskEnvironment task_environment_;

  FakeWebNNBackend backend_for_testing_;

  test::WebNNTestEnvironment webnn_test_environment_;
  mojo::Remote<mojom::WebNNContextProvider> provider_remote_;
  mojo::Remote<mojom::WebNNContext> webnn_context_;
};

// Data type and shape of one operand as used by the tester structs below.
// The fields are filled with designated initializers, so their order is part
// of the interface.
struct OperandInfo {
  // Element data type of the operand.
  OperandDataType type;
  // Operand shape, one entry per dimension.
  std::vector<uint32_t> dimensions;
};

struct ArgMinMaxTester {
  mojom::ArgMinMax::Kind kind;
  OperandInfo input;
  uint32_t axis;
  bool keep_dimensions = false;
  OperandInfo output;
  bool expected;

  void Test(WebNNGraphImplTest& test) {
    auto context_properties = GetContextPropertiesForTesting();

    // Build the graph with mojo type.
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        test.BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", input.dimensions, input.type);
    OperandId output_operand_id =
        builder.BuildOutput("output", output.dimensions, output.type);
    builder.BuildArgMinMax(kind, input_operand_id, output_operand_id, axis,
                           keep_dimensions);

    EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
  }
};

// Runs the same set of argMin/argMax validation scenarios for both operator
// kinds.
TEST_F(WebNNGraphImplTest, ArgMinMaxTest) {
  for (const auto kind :
       {mojom::ArgMinMax::Kind::kMin, mojom::ArgMinMax::Kind::kMax}) {
    ArgMinMaxTester scenarios[] = {
        // Valid: axis = 0 with keep_dimensions = true.
        {.kind = kind,
         .input = {.type = OperandDataType::kFloat32,
                   .dimensions = {2, 3, 4, 5}},
         .axis = 0,
         .keep_dimensions = true,
         .output = {.type = OperandDataType::kInt32,
                    .dimensions = {1, 3, 4, 5}},
         .expected = true},
        // Valid: axis = 1 with keep_dimensions = false.
        {.kind = kind,
         .input = {.type = OperandDataType::kFloat16,
                   .dimensions = {2, 3, 4, 5}},
         .axis = 1,
         .keep_dimensions = false,
         .output = {.type = OperandDataType::kInt32, .dimensions = {2, 4, 5}},
         .expected = true},
        // Invalid: the axis is greater than or equal to the input rank.
        {.kind = kind,
         .input = {.type = OperandDataType::kFloat32,
                   .dimensions = {2, 3, 4, 5}},
         .axis = 4,
         .keep_dimensions = true,
         .output = {.type = OperandDataType::kInt32,
                    .dimensions = {2, 3, 4, 1}},
         .expected = false},
        // Invalid: the output data type is not supported.
        {.kind = kind,
         .input = {.type = OperandDataType::kFloat32,
                   .dimensions = {2, 3, 4, 5}},
         .axis = 0,
         .keep_dimensions = true,
         .output = {.type = OperandDataType::kFloat32,
                    .dimensions = {1, 3, 4, 5}},
         .expected = false},
        // Invalid: the output shape is incorrect.
        {.kind = kind,
         .input = {.type = OperandDataType::kFloat32,
                   .dimensions = {2, 3, 4, 5}},
         .axis = 0,
         .keep_dimensions = false,
         .output = {.type = OperandDataType::kInt32,
                    .dimensions = {1, 3, 4, 5}},
         .expected = false},
    };
    for (ArgMinMaxTester& scenario : scenarios) {
      scenario.Test(*this);
    }

    // Invalid: the input and the output are the same operand. This case needs
    // a hand-built graph because the tester always creates two operands.
    auto properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder graph(builder_remote);
    OperandId shared_id =
        graph.BuildInput("input", {2, 3, 4, 5}, OperandDataType::kInt32);
    graph.BuildArgMinMax(kind, shared_id, shared_id,
                         /*axis=*/0,
                         /*keep_dimensions=*/true);
    EXPECT_FALSE(graph.IsValidGraphForTesting(properties));
  }
}

struct ClampTester {
  OperandInfo input;
  struct ClampAttributes {
    float min_value;
    float max_value;
  };
  ClampAttributes attributes;
  OperandInfo output;
  bool expected;

  void Test(WebNNGraphImplTest& test) {
    auto context_properties = GetContextPropertiesForTesting();

    // Build the graph with mojo type.
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        test.BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", input.dimensions, input.type);
    OperandId output_operand_id =
        builder.BuildOutput("output", output.dimensions, output.type);
    builder.BuildClamp(input_operand_id, output_operand_id,
                       attributes.min_value, attributes.max_value);
    EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
  }
};

// Checks graph validation for the clamp operator: finite and infinite bounds
// are accepted, while NaN bounds, min > max, and mismatched output shapes or
// data types are rejected.
TEST_F(WebNNGraphImplTest, ClampTest) {
  // Use the library constant rather than dividing by zero, which the C++
  // standard leaves undefined for floating-point operands.
  constexpr float kInfinity = std::numeric_limits<float>::infinity();

  {
    // Test clamp operator with both the minimum and maximum values.
    ClampTester{
        .input = {.type = OperandDataType::kInt8, .dimensions = {3, 4}},
        .attributes = {.min_value = 0.0, .max_value = 6.0},
        .output = {.type = OperandDataType::kInt8, .dimensions = {3, 4}},
        .expected = true}
        .Test(*this);
  }
  {
    // Test clamp operator when the min value is negative infinity.
    ClampTester{
        .input = {.type = OperandDataType::kInt32, .dimensions = {2, 3, 4}},
        .attributes = {.min_value = -kInfinity, .max_value = 3.0},
        .output = {.type = OperandDataType::kInt32, .dimensions = {2, 3, 4}},
        .expected = true}
        .Test(*this);
  }
  {
    // Test clamp operator when the max value is positive infinity.
    ClampTester{
        .input = {.type = OperandDataType::kInt32, .dimensions = {2, 3, 4}},
        .attributes = {.min_value = 0.0, .max_value = kInfinity},
        .output = {.type = OperandDataType::kInt32, .dimensions = {2, 3, 4}},
        .expected = true}
        .Test(*this);
  }
  {
    // Test clamp operator when max value = 0 and min value = 0. The bounds are
    // spelled out so the scenario does not depend on implicit zero
    // initialization.
    ClampTester{.input = {.type = OperandDataType::kFloat32,
                          .dimensions = {1, 2, 2, 7}},
                .attributes = {.min_value = 0.0, .max_value = 0.0},
                .output = {.type = OperandDataType::kFloat32,
                           .dimensions = {1, 2, 2, 7}},
                .expected = true}
        .Test(*this);
  }
  {
    // Test the invalid graph when the min value is NaN.
    ClampTester{
        .input = {.type = OperandDataType::kInt32, .dimensions = {2, 3, 4}},
        .attributes = {.min_value = NAN, .max_value = 3.0},
        .output = {.type = OperandDataType::kInt32, .dimensions = {2, 3, 4}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the max value is NaN.
    ClampTester{
        .input = {.type = OperandDataType::kInt32, .dimensions = {2, 3, 4}},
        .attributes = {.min_value = -3.0, .max_value = NAN},
        .output = {.type = OperandDataType::kInt32, .dimensions = {2, 3, 4}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the max value is less than the min value.
    ClampTester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {4, 2}},
        .attributes = {.min_value = 7.0, .max_value = 3.0},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {4, 2}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the output shape does not match the input
    // shape (bounds are left at their zero defaults).
    ClampTester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {4, 2}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {2}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the output data type does not match the
    // input data type (bounds are left at their zero defaults).
    ClampTester{.input = {.type = OperandDataType::kFloat32, .dimensions = {2}},
                .output = {.type = OperandDataType::kInt32, .dimensions = {2}},
                .expected = false}
        .Test(*this);
  }
}

struct HardSigmoidTester {
  OperandInfo input;
  std::optional<float> alpha;
  std::optional<float> beta;
  OperandInfo output;
  bool expected;

  void Test(WebNNGraphImplTest& test) {
    auto context_properties = GetContextPropertiesForTesting();

    // Build the graph with mojo type.
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        test.BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", input.dimensions, input.type);
    OperandId output_operand_id =
        builder.BuildOutput("output", output.dimensions, output.type);
    builder.BuildHardSigmoid(input_operand_id, output_operand_id, alpha, beta);
    EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
  }
};

// Checks graph validation for the hardSigmoid operator by running a table of
// scenarios in order.
TEST_F(WebNNGraphImplTest, HardSigmoidTest) {
  HardSigmoidTester scenarios[] = {
      // Valid: alpha and beta are left unset.
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {3, 4}},
       .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 4}},
       .expected = true},
      // Invalid: the alpha value is NaN.
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {2, 3, 4}},
       .alpha = NAN,
       .beta = 0.5,
       .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 3, 4}},
       .expected = false},
      // Invalid: the beta value is NaN.
      {.input = {.type = OperandDataType::kFloat16, .dimensions = {2, 3, 4}},
       .alpha = 1.0,
       .beta = NAN,
       .output = {.type = OperandDataType::kFloat16, .dimensions = {2, 3, 4}},
       .expected = false},
      // Invalid: the output shape does not match the input shape.
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {4, 2}},
       .output = {.type = OperandDataType::kFloat32, .dimensions = {2}},
       .expected = false},
      // Invalid: the output data type does not match the input data type.
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {2}},
       .output = {.type = OperandDataType::kInt32, .dimensions = {2}},
       .expected = false},
  };
  for (HardSigmoidTester& scenario : scenarios) {
    scenario.Test(*this);
  }
}

struct BatchNormalizationTester {
  OperandInfo input;
  OperandInfo mean;
  OperandInfo variance;
  std::optional<OperandInfo> scale;
  std::optional<OperandInfo> bias;
  struct BatchNormalizationAttributes {
    std::optional<OperandId> scale_operand_id;
    std::optional<OperandId> bias_operand_id;
    uint32_t axis = 1;
    float epsilon = 1e-5;
  };
  BatchNormalizationAttributes attributes;
  OperandInfo output;
  bool expected;

  void Test(WebNNGraphImplTest& test) {
    auto context_properties = GetContextPropertiesForTesting();

    // Build the graph with mojo type.
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        test.BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", input.dimensions, input.type);
    OperandId mean_operand_id =
        builder.BuildInput("mean", mean.dimensions, mean.type);
    OperandId variance_operand_id =
        builder.BuildInput("variance", variance.dimensions, variance.type);
    OperandId output_operand_id =
        builder.BuildOutput("output", output.dimensions, output.type);

    if (scale) {
      attributes.scale_operand_id =
          builder.BuildInput("scale", scale->dimensions, scale->type);
    }
    if (bias) {
      attributes.bias_operand_id =
          builder.BuildInput("bias", bias->dimensions, bias->type);
    }
    builder.BuildBatchNormalization(input_operand_id, mean_operand_id,
                                    variance_operand_id, output_operand_id,
                                    std::move(attributes));
    EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
  }
};

// Exercises graph validation for the batchNormalization operator. Every
// invalid case violates exactly one constraint so that the corresponding
// validation check is the only possible reason for the rejection.
TEST_F(WebNNGraphImplTest, BatchNormalizationTest) {
  {
    // Test building batchNormalization with default option.
    BatchNormalizationTester{
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 2, 3, 3}},
        .mean = {.type = OperandDataType::kFloat32, .dimensions = {2}},
        .variance = {.type = OperandDataType::kFloat32, .dimensions = {2}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 2, 3, 3}},
        .expected = true}
        .Test(*this);
  }
  {
    // Test building batchNormalization with axis = 3.
    BatchNormalizationTester{
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 2, 3, 3}},
        .mean = {.type = OperandDataType::kFloat32, .dimensions = {3}},
        .variance = {.type = OperandDataType::kFloat32, .dimensions = {3}},
        .attributes = {.axis = 3},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 2, 3, 3}},
        .expected = true}
        .Test(*this);
  }
  {
    // Test building batchNormalization with setting optional bias and scale.
    BatchNormalizationTester{
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 2, 3, 3}},
        .mean = {.type = OperandDataType::kFloat32, .dimensions = {2}},
        .variance = {.type = OperandDataType::kFloat32, .dimensions = {2}},
        .scale =
            OperandInfo{.type = OperandDataType::kFloat32, .dimensions = {2}},
        .bias =
            OperandInfo{.type = OperandDataType::kFloat32, .dimensions = {2}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 2, 3, 3}},
        .expected = true}
        .Test(*this);
  }
  {
    // Test building batchNormalization when input data type and mean data
    // type mismatched.
    BatchNormalizationTester{
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 2, 3, 3}},
        .mean = {.type = OperandDataType::kInt32, .dimensions = {2}},
        .variance = {.type = OperandDataType::kFloat32, .dimensions = {2}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 2, 3, 3}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test building batchNormalization when the size of mean is not equal to
    // the size of the input dimension denoted by axis.
    BatchNormalizationTester{
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 2, 3, 3}},
        .mean = {.type = OperandDataType::kFloat32, .dimensions = {3}},
        .variance = {.type = OperandDataType::kFloat32, .dimensions = {2}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 2, 3, 3}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test building batchNormalization when input data type and variance data
    // type mismatched. The mean is kept valid so that the variance operand is
    // the only one with a mismatched data type.
    BatchNormalizationTester{
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 2, 3, 3}},
        .mean = {.type = OperandDataType::kFloat32, .dimensions = {2}},
        .variance = {.type = OperandDataType::kInt32, .dimensions = {2}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 2, 3, 3}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test building batchNormalization when the size of variance is not equal
    // to the size of the input dimension denoted by axis.
    BatchNormalizationTester{
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 2, 3, 3}},
        .mean = {.type = OperandDataType::kFloat32, .dimensions = {2}},
        .variance = {.type = OperandDataType::kFloat32, .dimensions = {1}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 2, 3, 3}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test building batchNormalization when input data is not floating point
    // type.
    BatchNormalizationTester{
        .input = {.type = OperandDataType::kInt32, .dimensions = {1, 2, 3, 3}},
        .mean = {.type = OperandDataType::kInt32, .dimensions = {2}},
        .variance = {.type = OperandDataType::kInt32, .dimensions = {2}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 2, 3, 3}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test building batchNormalization when axis is out of range [0, N-1].
    BatchNormalizationTester{
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 2, 3, 3}},
        .mean = {.type = OperandDataType::kFloat32, .dimensions = {3}},
        .variance = {.type = OperandDataType::kFloat32, .dimensions = {3}},
        .attributes = {.axis = 4},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 2, 3, 3}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test batchNormalization when input data type and scale data type
    // mismatched.
    BatchNormalizationTester{
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 2, 3, 3}},
        .mean = {.type = OperandDataType::kFloat32, .dimensions = {2}},
        .variance = {.type = OperandDataType::kFloat32, .dimensions = {2}},
        .scale =
            OperandInfo{.type = OperandDataType::kInt32, .dimensions = {2}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 2, 3, 3}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test building batchNormalization when the size of scale is not equal
    // to the size of the input dimension denoted by axis.
    BatchNormalizationTester{
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 2, 3, 3}},
        .mean = {.type = OperandDataType::kFloat32, .dimensions = {2}},
        .variance = {.type = OperandDataType::kFloat32, .dimensions = {2}},
        .scale =
            OperandInfo{.type = OperandDataType::kFloat32, .dimensions = {3}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 2, 3, 3}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test batchNormalization when input data type and bias data type
    // mismatched.
    BatchNormalizationTester{
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 2, 3, 3}},
        .mean = {.type = OperandDataType::kFloat32, .dimensions = {2}},
        .variance = {.type = OperandDataType::kFloat32, .dimensions = {2}},
        .bias = OperandInfo{.type = OperandDataType::kInt32, .dimensions = {2}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 2, 3, 3}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test building batchNormalization when the size of bias is not equal
    // to the size of the input dimension denoted by axis.
    BatchNormalizationTester{
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 2, 3, 3}},
        .mean = {.type = OperandDataType::kFloat32, .dimensions = {2}},
        .variance = {.type = OperandDataType::kFloat32, .dimensions = {2}},
        .bias =
            OperandInfo{.type = OperandDataType::kFloat32, .dimensions = {3}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 2, 3, 3}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph for output type is not the same as input type.
    // The bias matches the size of the input dimension denoted by axis so
    // that the output data type is the only invalid property.
    BatchNormalizationTester{
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 2, 3, 3}},
        .mean = {.type = OperandDataType::kFloat32, .dimensions = {2}},
        .variance = {.type = OperandDataType::kFloat32, .dimensions = {2}},
        .bias =
            OperandInfo{.type = OperandDataType::kFloat32, .dimensions = {2}},
        .output = {.type = OperandDataType::kInt32, .dimensions = {1, 2, 3, 3}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph for output shape is not the same as input shape.
    // The bias matches the size of the input dimension denoted by axis so
    // that the output shape is the only invalid property.
    BatchNormalizationTester{
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 2, 3, 3}},
        .mean = {.type = OperandDataType::kFloat32, .dimensions = {2}},
        .variance = {.type = OperandDataType::kFloat32, .dimensions = {2}},
        .bias =
            OperandInfo{.type = OperandDataType::kFloat32, .dimensions = {2}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 1, 3, 3}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph for input operand == output operand.
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", {1, 2, 3, 4}, OperandDataType::kFloat32);
    OperandId mean_operand_id =
        builder.BuildInput("mean", {2}, OperandDataType::kFloat32);
    OperandId variance_operand_id =
        builder.BuildInput("variance", {2}, OperandDataType::kFloat32);
    // The input operand id is deliberately reused as the output operand id.
    builder.BuildBatchNormalization(
        input_operand_id, mean_operand_id, variance_operand_id,
        input_operand_id,
        BatchNormalizationTester::BatchNormalizationAttributes{});
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
  {
    // Test the invalid graph for mean operand == output operand.
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", {1, 2, 3, 4}, OperandDataType::kFloat32);
    OperandId mean_operand_id =
        builder.BuildInput("mean", {2}, OperandDataType::kFloat32);
    OperandId variance_operand_id =
        builder.BuildInput("variance", {2}, OperandDataType::kFloat32);
    // The mean operand id is deliberately reused as the output operand id.
    builder.BuildBatchNormalization(
        input_operand_id, mean_operand_id, variance_operand_id, mean_operand_id,
        BatchNormalizationTester::BatchNormalizationAttributes{});
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
  {
    // Test the invalid graph for variance operand == output operand.
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", {1, 2, 3, 4}, OperandDataType::kFloat32);
    OperandId mean_operand_id =
        builder.BuildInput("mean", {2}, OperandDataType::kFloat32);
    OperandId variance_operand_id =
        builder.BuildInput("variance", {2}, OperandDataType::kFloat32);
    // The variance operand id is deliberately reused as the output operand
    // id.
    builder.BuildBatchNormalization(
        input_operand_id, mean_operand_id, variance_operand_id,
        variance_operand_id,
        BatchNormalizationTester::BatchNormalizationAttributes{});
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
}

struct ConcatTester {
  std::vector<OperandInfo> inputs;
  uint32_t axis;
  OperandInfo output;
  bool expected;

  void Test(WebNNGraphImplTest& test) {
    auto context_properties = GetContextPropertiesForTesting();

    // Build the graph with mojo type.
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        test.BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    std::vector<OperandId> input_operand_ids;
    input_operand_ids.reserve(inputs.size());
    for (size_t i = 0; i < inputs.size(); ++i) {
      input_operand_ids.push_back(
          builder.BuildInput(base::StringPrintf("input%zu", i),
                             inputs[i].dimensions, inputs[i].type));
    }
    OperandId output_operand_id =
        builder.BuildOutput("output", output.dimensions, output.type);
    builder.BuildConcat(std::move(input_operand_ids), output_operand_id, axis);
    EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
  }
};

// Exercises graph validation for the concat operator: two graphs that must be
// accepted followed by several that must be rejected.
TEST_F(WebNNGraphImplTest, ConcatTest) {
  constexpr auto kF32 = OperandDataType::kFloat32;
  constexpr auto kI32 = OperandDataType::kInt32;

  // Valid graph: three inputs concatenated along axis 1.
  ConcatTester three_inputs{
      .inputs = {{.type = kF32, .dimensions = {3, 1, 5, 6}},
                 {.type = kF32, .dimensions = {3, 2, 5, 6}},
                 {.type = kF32, .dimensions = {3, 3, 5, 6}}},
      .axis = 1,
      .output = {.type = kF32, .dimensions = {3, 6, 5, 6}},
      .expected = true};
  three_inputs.Test(*this);

  // Valid graph: a single input, so the output has the same type and shape as
  // that input.
  ConcatTester single_input{
      .inputs = {{.type = kF32, .dimensions = {3, 1, 5, 6}}},
      .axis = 1,
      .output = {.type = kF32, .dimensions = {3, 1, 5, 6}},
      .expected = true};
  single_input.Test(*this);

  // Invalid graph: there are no inputs at all.
  ConcatTester no_inputs{.inputs = {},
                         .axis = 0,
                         .output = {.type = kI32, .dimensions = {1}},
                         .expected = false};
  no_inputs.Test(*this);

  // Invalid graph: the inputs' data types don't match each other.
  ConcatTester mixed_input_types{
      .inputs = {{.type = kF32, .dimensions = {3, 1, 5, 6}},
                 {.type = kI32, .dimensions = {3, 2, 5, 6}}},
      .axis = 1,
      .output = {.type = kF32, .dimensions = {3, 3, 5, 6}},
      .expected = false};
  mixed_input_types.Test(*this);

  // Invalid graph: the inputs can not be concatenated because their ranks
  // differ.
  ConcatTester mixed_input_ranks{
      .inputs = {{.type = kF32, .dimensions = {3, 1, 5}},
                 {.type = kF32, .dimensions = {3, 2, 5, 6}}},
      .axis = 1,
      .output = {.type = kF32, .dimensions = {3, 3, 5}},
      .expected = false};
  mixed_input_ranks.Test(*this);

  // Invalid graph: the axis is equal to or greater than the rank of the
  // inputs.
  ConcatTester axis_out_of_range{
      .inputs = {{.type = kF32, .dimensions = {3, 1, 5, 6}},
                 {.type = kF32, .dimensions = {3, 1, 5, 6}}},
      .axis = 4,
      .output = {.type = kF32, .dimensions = {3, 1, 5, 12}},
      .expected = false};
  axis_out_of_range.Test(*this);

  // Invalid graph: the inputs differ in size along a dimension other than the
  // concatenation axis.
  ConcatTester mismatched_other_axis{
      .inputs = {{.type = kF32, .dimensions = {3, 1, 5, 6}},
                 {.type = kF32, .dimensions = {3, 1, 5, 1}}},
      .axis = 1,
      .output = {.type = kF32, .dimensions = {3, 2, 5, 7}},
      .expected = false};
  mismatched_other_axis.Test(*this);

  // Invalid graph: the output data type doesn't match the inputs' data type.
  ConcatTester mismatched_output_type{
      .inputs = {{.type = kF32, .dimensions = {3, 1, 5, 6}},
                 {.type = kF32, .dimensions = {3, 2, 5, 6}}},
      .axis = 1,
      .output = {.type = kI32, .dimensions = {3, 3, 5, 6}},
      .expected = false};
  mismatched_output_type.Test(*this);

  // Invalid graph: the output size along the concatenation axis is incorrect.
  ConcatTester mismatched_output_shape{
      .inputs = {{.type = kF32, .dimensions = {3, 1, 2}},
                 {.type = kF32, .dimensions = {1, 1, 2}}},
      .axis = 0,
      .output = {.type = kF32, .dimensions = {5, 1, 2}},
      .expected = false};
  mismatched_output_shape.Test(*this);
}

struct Conv2dTester {
  mojom::Conv2d::Kind type;
  OperandInfo input;
  OperandInfo filter;
  struct Conv2dAttributes {
    std::vector<uint32_t> padding = {0, 0, 0, 0};
    std::vector<uint32_t> strides = {1, 1};
    std::vector<uint32_t> dilations = {1, 1};
    uint32_t groups = 1;
    std::optional<OperandInfo> bias;
  };
  Conv2dAttributes attributes;
  InputOperandLayout input_operand_layout = InputOperandLayout::kNchw;
  OperandInfo output;
  bool expected;

  void Test(WebNNGraphImplTest& test) {
    auto context_properties = GetContextPropertiesForTesting();
    // Override the default input layout to exercise all the validation cases.
    context_properties.input_operand_layout = input_operand_layout;

    // Build the graph with mojo type.
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        test.BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", input.dimensions, input.type);
    OperandId filter_operand_id =
        builder.BuildInput("filter", filter.dimensions, filter.type);

    std::optional<OperandId> bias_operand_id;
    if (attributes.bias) {
      bias_operand_id = builder.BuildInput("bias", attributes.bias->dimensions,
                                           attributes.bias->type);
    }

    OperandId output_operand_id =
        builder.BuildOutput("output", output.dimensions, output.type);
    builder.BuildConv2d(type, input_operand_id, filter_operand_id,
                        output_operand_id, std::move(attributes),
                        bias_operand_id);
    EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
  }
};

// Exercises graph validation for the direct conv2d operator. Every invalid
// case violates exactly one constraint so that the corresponding validation
// check is the only possible reason for the rejection.
TEST_F(WebNNGraphImplTest, Conv2dTest) {
  {
    // Test conv2d with default attributes.
    Conv2dTester{.type = mojom::Conv2d::Kind::kDirect,
                 .input = {.type = OperandDataType::kFloat32,
                           .dimensions = {1, 1, 5, 5}},
                 .filter = {.type = OperandDataType::kFloat32,
                            .dimensions = {1, 1, 3, 3}},
                 .output = {.type = OperandDataType::kFloat32,
                            .dimensions = {1, 1, 3, 3}},
                 .expected = true}
        .Test(*this);
  }
  {
    // Test conv2d for same upper or lower padding.
    Conv2dTester{.type = mojom::Conv2d::Kind::kDirect,
                 .input = {.type = OperandDataType::kFloat16,
                           .dimensions = {1, 1, 5, 5}},
                 .filter = {.type = OperandDataType::kFloat16,
                            .dimensions = {1, 1, 3, 3}},
                 .attributes = {.padding = {1, 1, 1, 1}},
                 .output = {.type = OperandDataType::kFloat16,
                            .dimensions = {1, 1, 5, 5}},
                 .expected = true}
        .Test(*this);
  }
  {
    // Test conv2d with strides=2 and padding=1.
    Conv2dTester{.type = mojom::Conv2d::Kind::kDirect,
                 .input = {.type = OperandDataType::kFloat16,
                           .dimensions = {1, 1, 5, 5}},
                 .filter = {.type = OperandDataType::kFloat16,
                            .dimensions = {1, 1, 3, 3}},
                 .attributes = {.padding = {1, 1, 1, 1}, .strides = {2, 2}},
                 .output = {.type = OperandDataType::kFloat16,
                            .dimensions = {1, 1, 3, 3}},
                 .expected = true}
        .Test(*this);
  }
  {
    // Test depthwise conv2d by setting groups to input channels.
    Conv2dTester{.type = mojom::Conv2d::Kind::kDirect,
                 .input = {.type = OperandDataType::kFloat16,
                           .dimensions = {1, 4, 2, 2}},
                 .filter = {.type = OperandDataType::kFloat16,
                            .dimensions = {4, 1, 2, 2}},
                 .attributes = {.groups = 4},
                 .output = {.type = OperandDataType::kFloat16,
                            .dimensions = {1, 4, 1, 1}},
                 .expected = true}
        .Test(*this);
  }
  {
    // Test conv2d with inputLayout="nchw" and filterLayout="oihw".
    Conv2dTester{.type = mojom::Conv2d::Kind::kDirect,
                 .input = {.type = OperandDataType::kFloat16,
                           .dimensions = {1, 2, 5, 5}},
                 .filter = {.type = OperandDataType::kFloat16,
                            .dimensions = {1, 2, 3, 3}},
                 .input_operand_layout = InputOperandLayout::kNchw,
                 .output = {.type = OperandDataType::kFloat16,
                            .dimensions = {1, 1, 3, 3}},
                 .expected = true}
        .Test(*this);
  }
  {
    // Test the invalid graph when the input is not a 4-D tensor.
    Conv2dTester{
        .type = mojom::Conv2d::Kind::kDirect,
        .input = {.type = OperandDataType::kFloat32, .dimensions = {1, 5, 5}},
        .filter = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 1, 3, 3}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 1, 3, 3}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the input data type is not floating point.
    Conv2dTester{
        .type = mojom::Conv2d::Kind::kDirect,
        .input = {.type = OperandDataType::kInt8, .dimensions = {1, 1, 5, 5}},
        .filter = {.type = OperandDataType::kInt8, .dimensions = {1, 1, 3, 3}},
        .output = {.type = OperandDataType::kInt8, .dimensions = {1, 1, 3, 3}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the filter is not a 4-D tensor.
    Conv2dTester{
        .type = mojom::Conv2d::Kind::kDirect,
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 1, 5, 5}},
        .filter = {.type = OperandDataType::kFloat32, .dimensions = {1, 3, 3}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 1, 3, 3}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the filter type doesn't match the input
    // type.
    Conv2dTester{
        .type = mojom::Conv2d::Kind::kDirect,
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 1, 5, 5}},
        .filter = {.type = OperandDataType::kInt32, .dimensions = {1, 1, 3, 3}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 1, 3, 3}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the bias type doesn't match input type.
    Conv2dTester{
        .type = mojom::Conv2d::Kind::kDirect,
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 1, 5, 5}},
        .filter = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 1, 3, 3}},
        .attributes = {.bias = OperandInfo{.type = OperandDataType::kInt32,
                                           .dimensions = {1}}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 1, 3, 3}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the bias shape is not equal to
    // [output_channels].
    Conv2dTester{
        .type = mojom::Conv2d::Kind::kDirect,
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 1, 5, 5}},
        .filter = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 1, 3, 3}},
        .attributes = {.bias = OperandInfo{.type = OperandDataType::kFloat32,
                                           .dimensions = {2}}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 1, 3, 3}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the number of filter input channels
    // doesn't match the result of input channels divided by groups. The
    // output data type matches the input data type so that the groups value
    // is the only invalid property.
    Conv2dTester{
        .type = mojom::Conv2d::Kind::kDirect,
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 1, 5, 5}},
        .filter = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 1, 3, 3}},
        .attributes = {.groups = 3},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 1, 3, 3}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph for the output shapes are not expected.
    Conv2dTester{.type = mojom::Conv2d::Kind::kDirect,
                 .input = {.type = OperandDataType::kFloat32,
                           .dimensions = {1, 1, 5, 5}},
                 .filter = {.type = OperandDataType::kFloat32,
                            .dimensions = {1, 1, 3, 3}},
                 .output = {.type = OperandDataType::kFloat32,
                            .dimensions = {1, 2, 1, 1}},
                 .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph for output types don't match.
    Conv2dTester{.type = mojom::Conv2d::Kind::kDirect,
                 .input = {.type = OperandDataType::kFloat32,
                           .dimensions = {1, 1, 5, 5}},
                 .filter = {.type = OperandDataType::kFloat32,
                            .dimensions = {1, 1, 3, 3}},
                 .output = {.type = OperandDataType::kFloat16,
                            .dimensions = {1, 1, 3, 3}},
                 .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph for input operand == output operand.
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", {1, 1, 5, 5}, OperandDataType::kFloat32);
    OperandId filter_operand_id =
        builder.BuildInput("filter", {1, 1, 3, 3}, OperandDataType::kFloat32);

    // The input operand id is deliberately reused as the output operand id.
    builder.BuildConv2d(mojom::Conv2d::Kind::kDirect, input_operand_id,
                        filter_operand_id, input_operand_id,
                        Conv2dTester::Conv2dAttributes{}, std::nullopt);

    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
  {
    // Test the invalid graph for filter operand == output operand.
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", {1, 1, 5, 5}, OperandDataType::kFloat32);
    OperandId filter_operand_id =
        builder.BuildInput("filter", {1, 1, 3, 3}, OperandDataType::kFloat32);

    // The filter operand id is deliberately reused as the output operand id.
    builder.BuildConv2d(mojom::Conv2d::Kind::kDirect, input_operand_id,
                        filter_operand_id, filter_operand_id,
                        Conv2dTester::Conv2dAttributes{}, std::nullopt);

    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
}

TEST_F(WebNNGraphImplTest, ConvTranspose2dTest) {
  {
    // Test convTranspose2d with default attributes.
    Conv2dTester{.type = mojom::Conv2d::Kind::kTransposed,
                 .input = {.type = OperandDataType::kFloat32,
                           .dimensions = {1, 1, 3, 3}},
                 .filter = {.type = OperandDataType::kFloat32,
                            .dimensions = {1, 1, 3, 3}},
                 .output = {.type = OperandDataType::kFloat32,
                            .dimensions = {1, 1, 5, 5}},
                 .expected = true}
        .Test(*this);
  }
  {
    // Test convTranspose2d with input_layout = nhwc.
    Conv2dTester{.type = mojom::Conv2d::Kind::kTransposed,
                 .input = {.type = OperandDataType::kFloat32,
                           .dimensions = {1, 3, 3, 1}},
                 .filter = {.type = OperandDataType::kFloat32,
                            .dimensions = {1, 3, 3, 1}},
                 .input_operand_layout = InputOperandLayout::kNhwc,
                 .output = {.type = OperandDataType::kFloat32,
                            .dimensions = {1, 5, 5, 1}},
                 .expected = true}
        .Test(*this);
  }
  {
    // Test convTranspose2d with padding = [1, 1, 1, 1].
    Conv2dTester{.type = mojom::Conv2d::Kind::kTransposed,
                 .input = {.type = OperandDataType::kFloat32,
                           .dimensions = {1, 1, 5, 5}},
                 .filter = {.type = OperandDataType::kFloat32,
                            .dimensions = {1, 1, 3, 3}},
                 .attributes = {.padding = {1, 1, 1, 1}},
                 .output = {.type = OperandDataType::kFloat32,
                            .dimensions = {1, 1, 5, 5}},
                 .expected = true}
        .Test(*this);
  }
  {
    // Test convTranspose2d with strides = [2, 2].
    Conv2dTester{.type = mojom::Conv2d::Kind::kTransposed,
                 .input = {.type = OperandDataType::kFloat32,
                           .dimensions = {1, 1, 3, 3}},
                 .filter = {.type = OperandDataType::kFloat32,
                            .dimensions = {1, 2, 3, 3}},
                 .attributes = {.strides = {2, 2}},
                 .output = {.type = OperandDataType::kFloat32,
                            .dimensions = {1, 2, 7, 7}},
                 .expected = true}
        .Test(*this);
  }
  {
    // Test convTranspose2d with strides = [2, 2] and padding = [1, 1, 1,
    // 1].
    Conv2dTester{.type = mojom::Conv2d::Kind::kTransposed,
                 .input = {.type = OperandDataType::kFloat32,
                           .dimensions = {1, 1, 3, 3}},
                 .filter = {.type = OperandDataType::kFloat32,
                            .dimensions = {1, 1, 3, 3}},
                 .attributes = {.padding = {1, 1, 1, 1}, .strides = {2, 2}},
                 .output = {.type = OperandDataType::kFloat32,
                            .dimensions = {1, 1, 5, 5}},
                 .expected = true}
        .Test(*this);
  }
  {
    // Test convTranspose2d with group = 3.
    Conv2dTester{.type = mojom::Conv2d::Kind::kTransposed,
                 .input = {.type = OperandDataType::kFloat32,
                           .dimensions = {1, 1, 3, 3}},
                 .filter = {.type = OperandDataType::kFloat32,
                            .dimensions = {1, 1, 3, 3}},
                 .attributes = {.groups = 3},
                 .output = {.type = OperandDataType::kFloat32,
                            .dimensions = {1, 3, 5, 5}},
                 .expected = true}
        .Test(*this);
  }
  {
    // Test the invalid graph for output types don't match.
    Conv2dTester{
        .type = mojom::Conv2d::Kind::kTransposed,
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 1, 5, 5}},
        .filter = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 1, 3, 3}},
        .output = {.type = OperandDataType::kInt32, .dimensions = {1, 1, 3, 3}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph for the input is not a 4-D tensor.
    Conv2dTester{
        .type = mojom::Conv2d::Kind::kTransposed,
        .input = {.type = OperandDataType::kFloat32, .dimensions = {1, 3, 3}},
        .filter = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 1, 3, 3}},
        .output = {.type = OperandDataType::kInt32, .dimensions = {1, 1, 5, 5}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph for the filter is not a 4-D tensor.
    Conv2dTester{
        .type = mojom::Conv2d::Kind::kTransposed,
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 1, 3, 3}},
        .filter = {.type = OperandDataType::kFloat32, .dimensions = {1, 3, 3}},
        .output = {.type = OperandDataType::kInt32, .dimensions = {1, 1, 5, 5}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the number of input channels is not equal
    // to the number of filter input channels.
    Conv2dTester{.type = mojom::Conv2d::Kind::kTransposed,
                 .input = {.type = OperandDataType::kFloat32,
                           .dimensions = {1, 1, 3, 3}},
                 .filter = {.type = OperandDataType::kFloat32,
                            .dimensions = {3, 1, 3, 3}},
                 .attributes = {.groups = 3},
                 .output = {.type = OperandDataType::kFloat32,
                            .dimensions = {1, 3, 5, 5}},
                 .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the number of output channels doesn't
    // match the result of filter output channels multiplied by groups
    Conv2dTester{
        .type = mojom::Conv2d::Kind::kTransposed,
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 1, 3, 3}},
        .filter = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 1, 3, 3}},
        .attributes = {.groups = 3},
        .output = {.type = OperandDataType::kInt32, .dimensions = {1, 1, 5, 5}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the filter type doesn't match the input
    // type.
    Conv2dTester{
        .type = mojom::Conv2d::Kind::kTransposed,
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 1, 3, 3}},
        .filter = {.type = OperandDataType::kInt32, .dimensions = {1, 1, 3, 3}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 1, 5, 5}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the bias type doesn't match input type.
    Conv2dTester{
        .type = mojom::Conv2d::Kind::kTransposed,
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 1, 3, 3}},
        .filter = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 1, 3, 3}},
        .attributes = {.bias = OperandInfo{.type = OperandDataType::kInt32,
                                           .dimensions = {1}}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 1, 5, 5}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the bias shape is not equal to
    // [output_channels].
    Conv2dTester{
        .type = mojom::Conv2d::Kind::kTransposed,
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 1, 3, 3}},
        .filter = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 1, 3, 3}},
        .attributes = {.bias = OperandInfo{.type = OperandDataType::kFloat32,
                                           .dimensions = {2}}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 1, 5, 5}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph for input operand == output operand.
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", {1, 1, 3, 3}, OperandDataType::kFloat32);
    OperandId filter_operand_id =
        builder.BuildInput("filter", {1, 1, 3, 3}, OperandDataType::kFloat32);

    builder.BuildConv2d(mojom::Conv2d::Kind::kTransposed, input_operand_id,
                        filter_operand_id, input_operand_id,
                        Conv2dTester::Conv2dAttributes{}, std::nullopt);

    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
  {
    // Test the invalid graph for filter operand == output operand.
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", {1, 1, 3, 3}, OperandDataType::kFloat32);
    OperandId filter_operand_id =
        builder.BuildInput("filter", {1, 1, 3, 3}, OperandDataType::kFloat32);

    builder.BuildConv2d(mojom::Conv2d::Kind::kTransposed, input_operand_id,
                        filter_operand_id, filter_operand_id,
                        Conv2dTester::Conv2dAttributes{}, std::nullopt);

    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
}

struct CumulativeSumTester {
  OperandInfo input;
  uint32_t axis;
  std::optional<bool> exclusive;
  std::optional<bool> reversed;
  OperandInfo output;
  bool expected;

  void Test(WebNNGraphImplTest& test) {
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        test.BindNewGraphBuilderRemote();

    // Build the graph with mojo type.
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", input.dimensions, input.type);
    OperandId output_operand_id =
        builder.BuildOutput("output", output.dimensions, output.type);
    builder.BuildCumulativeSum(input_operand_id, output_operand_id, axis,
                               exclusive, reversed);
    EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
  }
};

// Validates cumulativeSum graphs: well-formed configurations first, followed
// by malformed ones that must be rejected. Each negative case violates exactly
// one constraint so that a regression in that specific check is not masked by
// an unrelated failure.
TEST_F(WebNNGraphImplTest, CumulativeSumTest) {
  {
    // Test cumulativeSum operator with default exclusive and reversed values.
    CumulativeSumTester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 4}},
        .axis = 0,
        .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 4}},
        .expected = true}
        .Test(*this);
  }
  {
    // Test cumulativeSum operator with exclusive and reversed.
    CumulativeSumTester{.input = {.type = OperandDataType::kFloat32,
                                  .dimensions = {1, 2, 3, 4}},
                        .axis = 0,
                        .exclusive = true,
                        .reversed = true,
                        .output = {.type = OperandDataType::kFloat32,
                                   .dimensions = {1, 2, 3, 4}},
                        .expected = true}
        .Test(*this);
  }
  {
    // Test cumulativeSum operator with axis=2.
    CumulativeSumTester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 3, 4}},
        .axis = 2,
        .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 3, 4}},
        .expected = true}
        .Test(*this);
  }
  {
    // Test the invalid graph when the input is a scalar. The output data type
    // matches the input so that the scalar input is the only violation.
    CumulativeSumTester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {}},
        .axis = 0,
        .output = {.type = OperandDataType::kFloat32, .dimensions = {}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph with an invalid axis (axis == input rank). The
    // output data type matches the input so that the axis is the only
    // violation.
    CumulativeSumTester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 3, 4}},
        .axis = 3,
        .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 3, 4}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when output type doesn't match input type.
    CumulativeSumTester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 3, 4}},
        .axis = 2,
        .output = {.type = OperandDataType::kFloat16, .dimensions = {2, 3, 4}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph for input operand == output operand.
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();

    GraphInfoBuilder builder(remote);
    uint32_t axis = 0;
    bool exclusive = false;
    bool reversed = false;
    OperandId input_operand_id =
        builder.BuildInput("input", {1, 1, 3, 3}, OperandDataType::kFloat32);
    // The same operand id is deliberately used as both input and output.
    builder.BuildCumulativeSum(input_operand_id, input_operand_id, axis,
                               exclusive, reversed);
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
}

struct DequantizeLinearTester {
  OperandInfo input;
  OperandInfo scale;
  OperandInfo zero_point;
  OperandInfo output;
  bool expected;

  void Test(WebNNGraphImplTest& test) {
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        test.BindNewGraphBuilderRemote();

    // Build the graph with mojo type.
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", input.dimensions, input.type);
    OperandId scale_operand_id =
        builder.BuildInput("scale", scale.dimensions, scale.type);
    OperandId zero_point_operand_id = builder.BuildInput(
        "zero_point", zero_point.dimensions, zero_point.type);
    OperandId output_operand_id =
        builder.BuildOutput("output", output.dimensions, output.type);
    builder.BuildDequantizeLinear(input_operand_id, scale_operand_id,
                                  zero_point_operand_id, output_operand_id);
    EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
  }
};

// Validates dequantizeLinear graphs: well-formed configurations first,
// followed by malformed ones that must be rejected. Where possible each
// negative case violates a single constraint so that a regression in that
// specific check is not masked by an unrelated failure.
TEST_F(WebNNGraphImplTest, DequantizeLinearTest) {
  {
    // Test dequantizeLinear operator when the input, the scale and the
    // zero_point have the same shape.
    DequantizeLinearTester{
        .input = {.type = OperandDataType::kInt8, .dimensions = {3, 2, 5}},
        .scale = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}},
        .zero_point = {.type = OperandDataType::kInt8, .dimensions = {3, 2, 5}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}},
        .expected = true}
        .Test(*this);
  }
  {
    // Test dequantizeLinear operator with a broadcastable scale of shape
    // {1, 1, 5}.
    DequantizeLinearTester{
        .input = {.type = OperandDataType::kInt8, .dimensions = {3, 2, 5}},
        .scale = {.type = OperandDataType::kFloat32, .dimensions = {1, 1, 5}},
        .zero_point = {.type = OperandDataType::kInt8, .dimensions = {1, 1, 5}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}},
        .expected = true}
        .Test(*this);
  }
  {
    // Test dequantizeLinear operator with a broadcastable scale of shape
    // {3, 1, 1}.
    DequantizeLinearTester{
        .input = {.type = OperandDataType::kInt8, .dimensions = {3, 2, 5}},
        .scale = {.type = OperandDataType::kFloat32, .dimensions = {3, 1, 1}},
        .zero_point = {.type = OperandDataType::kInt8, .dimensions = {3, 1, 1}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}},
        .expected = true}
        .Test(*this);
  }
  {
    // Test the invalid graph whose scale rank is not equal to input rank.
    DequantizeLinearTester{
        .input = {.type = OperandDataType::kInt8, .dimensions = {3, 2, 5}},
        .scale = {.type = OperandDataType::kFloat32, .dimensions = {5}},
        .zero_point = {.type = OperandDataType::kInt8, .dimensions = {5}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph with an invalid scale.
    DequantizeLinearTester{
        .input = {.type = OperandDataType::kInt8, .dimensions = {3, 2, 5}},
        .scale = {.type = OperandDataType::kFloat32, .dimensions = {2}},
        .zero_point = {.type = OperandDataType::kInt8, .dimensions = {2}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph with different scale_shape and zero_point_shape.
    // Each of the two shapes is accepted on its own by the valid cases above,
    // so the disagreement between them is the only violation.
    DequantizeLinearTester{
        .input = {.type = OperandDataType::kInt8, .dimensions = {3, 2, 5}},
        .scale = {.type = OperandDataType::kFloat32, .dimensions = {1, 1, 5}},
        .zero_point = {.type = OperandDataType::kInt8, .dimensions = {3, 1, 1}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the zero_point datatype doesn't match the
    // input's datatype.
    DequantizeLinearTester{
        .input = {.type = OperandDataType::kInt8, .dimensions = {3, 2, 5}},
        .scale = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}},
        .zero_point = {.type = OperandDataType::kUint8,
                       .dimensions = {3, 2, 5}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the output datatype doesn't match the
    // scale's datatype.
    DequantizeLinearTester{
        .input = {.type = OperandDataType::kInt8, .dimensions = {3, 2, 5}},
        .scale = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}},
        .zero_point = {.type = OperandDataType::kInt8, .dimensions = {3, 2, 5}},
        .output = {.type = OperandDataType::kFloat16, .dimensions = {3, 2, 5}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph for the output shapes are not expected. The
    // operands are otherwise identical to the first valid case above, so the
    // output shape is the only violation.
    DequantizeLinearTester{
        .input = {.type = OperandDataType::kInt8, .dimensions = {3, 2, 5}},
        .scale = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}},
        .zero_point = {.type = OperandDataType::kInt8, .dimensions = {3, 2, 5}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {5}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the input is as same as output.
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", {2, 3}, OperandDataType::kInt8);
    OperandId scale_operand_id =
        builder.BuildInput("scale", {2, 3}, OperandDataType::kFloat32);
    OperandId zero_point_operand_id =
        builder.BuildInput("zero_point", {2, 3}, OperandDataType::kInt8);
    // The input operand id is deliberately reused as the output.
    builder.BuildDequantizeLinear(input_operand_id, scale_operand_id,
                                  zero_point_operand_id, input_operand_id);
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
  {
    // Test the invalid graph when the scale is as same as output.
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", {2, 3}, OperandDataType::kInt8);
    OperandId scale_operand_id =
        builder.BuildInput("scale", {2, 3}, OperandDataType::kFloat32);
    OperandId zero_point_operand_id =
        builder.BuildInput("zero_point", {2, 3}, OperandDataType::kInt8);
    // The scale operand id is deliberately reused as the output.
    builder.BuildDequantizeLinear(input_operand_id, scale_operand_id,
                                  zero_point_operand_id, scale_operand_id);
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
  {
    // Test the invalid graph when the zeroPoint is as same as output.
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", {2, 3}, OperandDataType::kInt8);
    OperandId scale_operand_id =
        builder.BuildInput("scale", {2, 3}, OperandDataType::kFloat32);
    OperandId zero_point_operand_id =
        builder.BuildInput("zero_point", {2, 3}, OperandDataType::kInt8);
    // The zero point operand id is deliberately reused as the output.
    builder.BuildDequantizeLinear(input_operand_id, scale_operand_id,
                                  zero_point_operand_id, zero_point_operand_id);
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
}

struct ElementWiseBinaryTester {
  mojom::ElementWiseBinary::Kind kind;
  OperandInfo lhs;
  OperandInfo rhs;
  OperandInfo output;
  bool expected;

  static constexpr std::array<mojom::ElementWiseBinary::Kind, 16>
      kAllBinaryOps = {
          mojom::ElementWiseBinary::Kind::kAdd,
          mojom::ElementWiseBinary::Kind::kSub,
          mojom::ElementWiseBinary::Kind::kMul,
          mojom::ElementWiseBinary::Kind::kDiv,
          mojom::ElementWiseBinary::Kind::kPow,
          mojom::ElementWiseBinary::Kind::kMax,
          mojom::ElementWiseBinary::Kind::kMin,
          mojom::ElementWiseBinary::Kind::kEqual,
          mojom::ElementWiseBinary::Kind::kGreater,
          mojom::ElementWiseBinary::Kind::kGreaterOrEqual,
          mojom::ElementWiseBinary::Kind::kLesser,
          mojom::ElementWiseBinary::Kind::kLesserOrEqual,
          mojom::ElementWiseBinary::Kind::kNotEqual,
          mojom::ElementWiseBinary::Kind::kLogicalAnd,
          mojom::ElementWiseBinary::Kind::kLogicalOr,
          mojom::ElementWiseBinary::Kind::kLogicalXor,
  };

  static OperandDataType GetValidInputType(mojom::ElementWiseBinary::Kind op) {
    switch (op) {
      case mojom::ElementWiseBinary::Kind::kAdd:
      case mojom::ElementWiseBinary::Kind::kSub:
      case mojom::ElementWiseBinary::Kind::kMul:
      case mojom::ElementWiseBinary::Kind::kDiv:
      case mojom::ElementWiseBinary::Kind::kPow:
      case mojom::ElementWiseBinary::Kind::kMax:
      case mojom::ElementWiseBinary::Kind::kMin:
      case mojom::ElementWiseBinary::Kind::kEqual:
      case mojom::ElementWiseBinary::Kind::kGreater:
      case mojom::ElementWiseBinary::Kind::kGreaterOrEqual:
      case mojom::ElementWiseBinary::Kind::kLesser:
      case mojom::ElementWiseBinary::Kind::kLesserOrEqual:
      case mojom::ElementWiseBinary::Kind::kNotEqual:
        return OperandDataType::kFloat32;
      case mojom::ElementWiseBinary::Kind::kLogicalAnd:
      case mojom::ElementWiseBinary::Kind::kLogicalOr:
      case mojom::ElementWiseBinary::Kind::kLogicalXor:
        return OperandDataType::kUint8;
    }
  }

  static OperandDataType GetValidOutputType(mojom::ElementWiseBinary::Kind op) {
    switch (op) {
      case mojom::ElementWiseBinary::Kind::kAdd:
      case mojom::ElementWiseBinary::Kind::kSub:
      case mojom::ElementWiseBinary::Kind::kMul:
      case mojom::ElementWiseBinary::Kind::kDiv:
      case mojom::ElementWiseBinary::Kind::kPow:
      case mojom::ElementWiseBinary::Kind::kMax:
      case mojom::ElementWiseBinary::Kind::kMin:
        return OperandDataType::kFloat32;
      case mojom::ElementWiseBinary::Kind::kEqual:
      case mojom::ElementWiseBinary::Kind::kGreater:
      case mojom::ElementWiseBinary::Kind::kGreaterOrEqual:
      case mojom::ElementWiseBinary::Kind::kLesser:
      case mojom::ElementWiseBinary::Kind::kLesserOrEqual:
      case mojom::ElementWiseBinary::Kind::kNotEqual:
      case mojom::ElementWiseBinary::Kind::kLogicalAnd:
      case mojom::ElementWiseBinary::Kind::kLogicalOr:
      case mojom::ElementWiseBinary::Kind::kLogicalXor:
        return OperandDataType::kUint8;
    }
  }

  void Test(WebNNGraphImplTest& test) {
    auto context_properties = GetContextPropertiesForTesting();

    // Build the graph with mojo type.
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        test.BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId lhs_operand_id =
        builder.BuildInput("lhs", lhs.dimensions, lhs.type);
    OperandId rhs_operand_id =
        builder.BuildInput("rhs", rhs.dimensions, rhs.type);
    OperandId output_operand_id =
        builder.BuildOutput("output", output.dimensions, output.type);
    builder.BuildElementWiseBinary(kind, lhs_operand_id, rhs_operand_id,
                                   output_operand_id);
    EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
  }

  void TestLogicalOperators(WebNNGraphImplTest& test) {
    const mojom::ElementWiseBinary::Kind kLogicalOperators[] = {
        mojom::ElementWiseBinary::Kind::kEqual,
        mojom::ElementWiseBinary::Kind::kGreater,
        mojom::ElementWiseBinary::Kind::kGreaterOrEqual,
        mojom::ElementWiseBinary::Kind::kLesser,
        mojom::ElementWiseBinary::Kind::kLesserOrEqual,
        mojom::ElementWiseBinary::Kind::kNotEqual,
    };

    for (const auto& op : kLogicalOperators) {
      kind = op;
      Test(test);
    }
  }
};

// For every binary operator kind, checks that graphs with broadcastable
// operands are accepted and that graphs with incompatible shapes or data
// types are rejected.
TEST_F(WebNNGraphImplTest, ElementWiseBinaryTest) {
  for (const mojom::ElementWiseBinary::Kind op :
       ElementWiseBinaryTester::kAllBinaryOps) {
    const OperandDataType input_type =
        ElementWiseBinaryTester::GetValidInputType(op);
    const OperandDataType output_type =
        ElementWiseBinaryTester::GetValidOutputType(op);

    // The cases below are run in declaration order.
    ElementWiseBinaryTester test_cases[] = {
        // Two input dimensions - {8, 1, 6, 1} and {7, 1, 5}. Both the a and b
        // dimensions have axes with length one that are expanded to a larger
        // size during the broadcast operation.
        // a_dimensions      (4d) 8 * 1 * 6 * 1
        // b_dimensions      (3d)     7 * 1 * 5
        // output_dimensions (4d) 8 * 7 * 6 * 5
        {.kind = op,
         .lhs = {.type = input_type, .dimensions = {8, 1, 6, 1}},
         .rhs = {.type = input_type, .dimensions = {7, 1, 5}},
         .output = {.type = output_type, .dimensions = {8, 7, 6, 5}},
         .expected = true},

        // Two input dimensions - {4, 2, 1} and {4}.
        // a_dimensions      (3d) 4 * 2 * 1
        // b_dimensions      (1d)         4
        // output_dimensions (3d) 4 * 2 * 4
        {.kind = op,
         .lhs = {.type = input_type, .dimensions = {4, 2, 1}},
         .rhs = {.type = input_type, .dimensions = {4}},
         .output = {.type = output_type, .dimensions = {4, 2, 4}},
         .expected = true},

        // Invalid graph: the input shapes are not broadcastable.
        {.kind = op,
         .lhs = {.type = input_type, .dimensions = {4, 2}},
         .rhs = {.type = input_type, .dimensions = {4}},
         .output = {.type = output_type, .dimensions = {4, 2}},
         .expected = false},

        // Invalid graph: the output shape is not the expected one.
        {.kind = op,
         .lhs = {.type = input_type, .dimensions = {4, 2}},
         .rhs = {.type = input_type, .dimensions = {4, 2}},
         .output = {.type = output_type, .dimensions = {2}},
         .expected = false},

        // Invalid graph: the input types don't match.
        {.kind = op,
         .lhs = {.type = input_type, .dimensions = {2}},
         .rhs = {.type = OperandDataType::kInt64, .dimensions = {2}},
         .output = {.type = output_type, .dimensions = {2}},
         .expected = false},

        // Invalid graph: the output type doesn't match.
        {.kind = op,
         .lhs = {.type = input_type, .dimensions = {2}},
         .rhs = {.type = input_type, .dimensions = {2}},
         .output = {.type = OperandDataType::kInt64, .dimensions = {2}},
         .expected = false},
    };

    for (ElementWiseBinaryTester& test_case : test_cases) {
      test_case.Test(*this);
    }
  }
}

struct ElementWiseUnaryTester {
  mojom::ElementWiseUnary::Kind kind;
  OperandInfo input;
  OperandInfo output;
  bool expected;

  void Test(WebNNGraphImplTest& test) {
    auto context_properties = GetContextPropertiesForTesting();

    // Build the graph with mojo type.
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        test.BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", input.dimensions, input.type);
    OperandId output_operand_id =
        builder.BuildOutput("output", output.dimensions, output.type);
    builder.BuildElementWiseUnary(kind, input_operand_id, output_operand_id);
    EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
  }
};

// Parameterized fixture that checks data type support of element-wise unary
// operators. Each test parameter is a tuple of:
//   0: a std::pair of the mojom::ElementWiseUnary::Kind under test and the
//      data types supported by that operator,
//   1: the input data type to try,
//   2: the output data type to try.
class ElementWiseUnaryDataTypeFixture
    : public WebNNGraphImplTest,
      public testing::WithParamInterface<
          std::tuple<std::pair<mojom::ElementWiseUnary::Kind,
                               std::vector<OperandDataType>>,
                     OperandDataType,
                     OperandDataType>> {
 public:
  // Produces the test suffix "<operator>_<input type>_<output type>".
  struct PrintToStringParamName {
    template <class ParamType>
    std::string operator()(
        const testing::TestParamInfo<ParamType>& info) const {
      const auto& param = info.param;
      return base::StrCat({OpKindToString(std::get<0>(param).first), "_",
                           DataTypeToString(std::get<1>(param)), "_",
                           DataTypeToString(std::get<2>(param))});
    }
  };

  // Builds a unary graph whose input and output both have shape `dimensions`
  // and the data types taken from the current test parameter, and expects it
  // to validate only when that type combination is supported.
  void TestDataTypeSupportWithDimensions(
      const std::vector<uint32_t>& dimensions) {
    const auto& [operator_trait, input_data_type, output_data_type] =
        GetParam();
    const mojom::ElementWiseUnary::Kind kind = operator_trait.first;

    // Only kCast is treated as supporting dissimilar input and output data
    // types; every other operator needs them to be identical.
    const bool allows_dissimilar_types =
        kind == mojom::ElementWiseUnary::Kind::kCast;
    const bool types_are_compatible =
        input_data_type == output_data_type || allows_dissimilar_types;

    // The input data type must also be one the operator supports.
    const bool input_type_is_supported =
        base::Contains(operator_trait.second, input_data_type);

    const bool expected = types_are_compatible && input_type_is_supported;

    ElementWiseUnaryTester{
        .kind = kind,
        .input = {.type = input_data_type, .dimensions = dimensions},
        .output = {.type = output_data_type, .dimensions = dimensions},
        .expected = expected}
        .Test(*this);
  }
};

// Runs the data type matrix against a 4-D operand shape.
TEST_P(ElementWiseUnaryDataTypeFixture, TestUnaryOperandDataTypeSupport) {
  const std::vector<uint32_t> shape = {1, 2, 3, 1};
  TestDataTypeSupportWithDimensions(shape);
}

// Runs the data type matrix against a scalar operand (empty dimensions).
TEST_P(ElementWiseUnaryDataTypeFixture, TestUnaryOperandScalarDataTypeSupport) {
  const std::vector<uint32_t> scalar_shape;
  TestDataTypeSupportWithDimensions(scalar_shape);
}

// Instantiates ElementWiseUnaryDataTypeFixture over the cross product of:
//   - the (operator kind, supported data types) pairs listed below,
//   - every data type in kAllOperandDataTypes as the input data type,
//   - every data type in kAllOperandDataTypes as the output data type.
INSTANTIATE_TEST_SUITE_P(
    WebNNGraphImplTest,
    ElementWiseUnaryDataTypeFixture,
    ::testing::Combine(
        ::testing::ValuesIn({
            // logicalNot: only uint8 is listed as supported.
            std::make_pair(mojom::ElementWiseUnary::Kind::kLogicalNot,
                           std::vector<OperandDataType>{
                               OperandDataType::kUint8}),
            // identity: every data type is listed as supported.
            std::make_pair(
                mojom::ElementWiseUnary::Kind::kIdentity,
                std::vector<OperandDataType>(kAllOperandDataTypes,
                                             std::end(kAllOperandDataTypes))),
            // sqrt: float16 and float32 are listed as supported.
            std::make_pair(mojom::ElementWiseUnary::Kind::kSqrt,
                           std::vector<OperandDataType>{
                               OperandDataType::kFloat16,
                               OperandDataType::kFloat32}),
            // erf: float16 and float32 are listed as supported.
            std::make_pair(mojom::ElementWiseUnary::Kind::kErf,
                           std::vector<OperandDataType>{
                               OperandDataType::kFloat16,
                               OperandDataType::kFloat32}),
            // reciprocal: float16 and float32 are listed as supported.
            std::make_pair(mojom::ElementWiseUnary::Kind::kReciprocal,
                           std::vector<OperandDataType>{
                               OperandDataType::kFloat16,
                               OperandDataType::kFloat32}),
            // cast: every data type is listed as a supported input; the
            // fixture additionally allows its output type to differ.
            std::make_pair(
                mojom::ElementWiseUnary::Kind::kCast,
                std::vector<OperandDataType>(kAllOperandDataTypes,
                                             std::end(kAllOperandDataTypes))),
        }),
        // Input data type.
        ::testing::ValuesIn(kAllOperandDataTypes),
        // Output data type.
        ::testing::ValuesIn(kAllOperandDataTypes)),
    ElementWiseUnaryDataTypeFixture::PrintToStringParamName());

// Validates a series of single-operator element-wise unary graphs. The cases
// are listed in a table and run in order.
TEST_F(WebNNGraphImplTest, ElementWiseUnaryTest) {
  using Kind = mojom::ElementWiseUnary::Kind;

  ElementWiseUnaryTester test_cases[] = {
      // Building element-wise abs.
      {.kind = Kind::kAbs,
       .input = {.type = OperandDataType::kFloat32, .dimensions = {1}},
       .output = {.type = OperandDataType::kFloat32, .dimensions = {1}},
       .expected = true},
      // Building element-wise ceil.
      {.kind = Kind::kCeil,
       .input = {.type = OperandDataType::kFloat16, .dimensions = {1}},
       .output = {.type = OperandDataType::kFloat16, .dimensions = {1}},
       .expected = true},
      // Building element-wise cos.
      {.kind = Kind::kCos,
       .input = {.type = OperandDataType::kFloat32, .dimensions = {1, 2}},
       .output = {.type = OperandDataType::kFloat32, .dimensions = {1, 2}},
       .expected = true},
      // Building element-wise exp.
      {.kind = Kind::kExp,
       .input = {.type = OperandDataType::kFloat16, .dimensions = {1, 2}},
       .output = {.type = OperandDataType::kFloat16, .dimensions = {1, 2}},
       .expected = true},
      // Building element-wise floor.
      {.kind = Kind::kFloor,
       .input = {.type = OperandDataType::kFloat32, .dimensions = {1, 2, 3}},
       .output = {.type = OperandDataType::kFloat32, .dimensions = {1, 2, 3}},
       .expected = true},
      // Building element-wise log.
      {.kind = Kind::kLog,
       .input = {.type = OperandDataType::kFloat16, .dimensions = {1, 2, 3}},
       .output = {.type = OperandDataType::kFloat16, .dimensions = {1, 2, 3}},
       .expected = true},
      // Building element-wise neg.
      {.kind = Kind::kNeg,
       .input = {.type = OperandDataType::kFloat32,
                 .dimensions = {1, 2, 3, 4}},
       .output = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 2, 3, 4}},
       .expected = true},
      // Building element-wise sin.
      {.kind = Kind::kSin,
       .input = {.type = OperandDataType::kFloat16,
                 .dimensions = {1, 2, 3, 4}},
       .output = {.type = OperandDataType::kFloat16,
                  .dimensions = {1, 2, 3, 4}},
       .expected = true},
      // Building element-wise tan.
      {.kind = Kind::kTan,
       .input = {.type = OperandDataType::kFloat32,
                 .dimensions = {1, 2, 3, 4, 5}},
       .output = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 2, 3, 4, 5}},
       .expected = true},
      // Invalid abs graph: input has an unsupported data type.
      {.kind = Kind::kAbs,
       .input = {.type = OperandDataType::kUint32,
                 .dimensions = {1, 2, 3, 4}},
       .output = {.type = OperandDataType::kUint32,
                  .dimensions = {1, 2, 3, 4}},
       .expected = false},
      // Invalid neg graph: input has an unsupported data type.
      {.kind = Kind::kNeg,
       .input = {.type = OperandDataType::kUint8, .dimensions = {1, 2, 3, 4}},
       .output = {.type = OperandDataType::kUint8,
                  .dimensions = {1, 2, 3, 4}},
       .expected = false},
      // Invalid ceil graph: input has an unsupported data type.
      {.kind = Kind::kCeil,
       .input = {.type = OperandDataType::kUint32,
                 .dimensions = {1, 2, 3, 4}},
       .output = {.type = OperandDataType::kUint32,
                  .dimensions = {1, 2, 3, 4}},
       .expected = false},
      // Invalid cos graph: input has an unsupported data type.
      {.kind = Kind::kCos,
       .input = {.type = OperandDataType::kUint32,
                 .dimensions = {1, 2, 3, 4}},
       .output = {.type = OperandDataType::kUint32,
                  .dimensions = {1, 2, 3, 4}},
       .expected = false},
      // Invalid exp graph: input has an unsupported data type.
      {.kind = Kind::kExp,
       .input = {.type = OperandDataType::kUint8, .dimensions = {1, 2, 3, 4}},
       .output = {.type = OperandDataType::kUint8,
                  .dimensions = {1, 2, 3, 4}},
       .expected = false},
      // Invalid floor graph: input has an unsupported data type.
      {.kind = Kind::kFloor,
       .input = {.type = OperandDataType::kInt8, .dimensions = {1, 2, 3, 4}},
       .output = {.type = OperandDataType::kInt8, .dimensions = {1, 2, 3, 4}},
       .expected = false},
      // Invalid log graph: input has an unsupported data type.
      {.kind = Kind::kLog,
       .input = {.type = OperandDataType::kInt32, .dimensions = {1, 2, 3, 4}},
       .output = {.type = OperandDataType::kInt32,
                  .dimensions = {1, 2, 3, 4}},
       .expected = false},
      // Invalid sin graph: input has an unsupported data type.
      {.kind = Kind::kSin,
       .input = {.type = OperandDataType::kUint32,
                 .dimensions = {1, 2, 3, 4}},
       .output = {.type = OperandDataType::kUint32,
                  .dimensions = {1, 2, 3, 4}},
       .expected = false},
      // Invalid tan graph: input has an unsupported data type.
      {.kind = Kind::kTan,
       .input = {.type = OperandDataType::kUint32,
                 .dimensions = {1, 2, 3, 4}},
       .output = {.type = OperandDataType::kUint32,
                  .dimensions = {1, 2, 3, 4}},
       .expected = false},
      // Invalid graph: input and output shapes don't match.
      {.kind = Kind::kAbs,
       .input = {.type = OperandDataType::kFloat32,
                 .dimensions = {1, 2, 3, 4}},
       .output = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 2, 3, 4, 5}},
       .expected = false},
      // Invalid graph: output data type doesn't match the input's.
      {.kind = Kind::kCeil,
       .input = {.type = OperandDataType::kFloat32,
                 .dimensions = {1, 2, 3, 4}},
       .output = {.type = OperandDataType::kFloat16,
                  .dimensions = {1, 2, 3, 4}},
       .expected = false},
      // Invalid cast graph: input and output dimensions don't match.
      {.kind = Kind::kCast,
       .input = {.type = OperandDataType::kUint8, .dimensions = {1, 2, 3, 1}},
       .output = {.type = OperandDataType::kInt8, .dimensions = {1, 2, 3, 2}},
       .expected = false},
  };

  for (ElementWiseUnaryTester& test_case : test_cases) {
    test_case.Test(*this);
  }
}

struct EluTester {
  OperandInfo input;
  OperandInfo output;
  float alpha = 1.0;
  bool expected;

  void Test(WebNNGraphImplTest& test) {
    auto context_properties = GetContextPropertiesForTesting();

    // Build the graph with mojo type.
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        test.BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", input.dimensions, input.type);
    OperandId output_operand_id =
        builder.BuildOutput("output", output.dimensions, output.type);
    builder.BuildElu(input_operand_id, output_operand_id, alpha);

    EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
  }
};

// Validates a series of single-operator elu graphs. Table-driven cases run
// first, followed by a case that needs a hand-built graph.
TEST_F(WebNNGraphImplTest, EluTest) {
  EluTester test_cases[] = {
      // Elu operator for a 2-D tensor with float32 input.
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {2, 6}},
       .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 6}},
       .expected = true},
      // Invalid graph: alpha is NAN.
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {2}},
       .output = {.type = OperandDataType::kFloat32, .dimensions = {2}},
       .alpha = NAN,
       .expected = false},
      // Invalid graph: output shape is not as expected.
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {4, 2}},
       .output = {.type = OperandDataType::kFloat32, .dimensions = {2}},
       .expected = false},
      // Invalid graph: output data type doesn't match the input's.
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {2}},
       .output = {.type = OperandDataType::kInt32, .dimensions = {2}},
       .expected = false},
      // Invalid graph: input data type is not floating point.
      {.input = {.type = OperandDataType::kInt32, .dimensions = {2}},
       .output = {.type = OperandDataType::kInt32, .dimensions = {2}},
       .expected = false},
  };
  for (EluTester& test_case : test_cases) {
    test_case.Test(*this);
  }

  {
    // Invalid graph: the same operand is used as both input and output.
    auto properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder graph(builder_remote);
    OperandId operand_id =
        graph.BuildInput("input", {2}, OperandDataType::kFloat32);
    graph.BuildElu(operand_id, operand_id, /*alpha=*/1.0);
    EXPECT_FALSE(graph.IsValidGraphForTesting(properties));
  }
}

struct ExpandTester {
  OperandInfo input;
  OperandInfo output;
  bool expected;

  void Test(WebNNGraphImplTest& test) {
    auto context_properties = GetContextPropertiesForTesting();

    // Build the graph with mojo type.
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        test.BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", input.dimensions, input.type);
    OperandId output_operand_id =
        builder.BuildOutput("output", output.dimensions, output.type);
    builder.BuildExpand(input_operand_id, output_operand_id);

    EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
  }
};

// Validates a series of single-operator expand graphs. Table-driven cases run
// first, followed by cases that need a hand-built graph or customized context
// properties.
TEST_F(WebNNGraphImplTest, ExpandTest) {
  ExpandTester test_cases[] = {
      // Output shape is the same as the input shape.
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {2, 6}},
       .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 6}},
       .expected = true},
      // Output shape is broadcastable from the input shape.
      {.input = {.type = OperandDataType::kInt32, .dimensions = {3, 1, 5}},
       .output = {.type = OperandDataType::kInt32, .dimensions = {3, 4, 5}},
       .expected = true},
      // Output shape is broadcastable and has a larger rank than the input.
      {.input = {.type = OperandDataType::kInt32, .dimensions = {2, 5}},
       .output = {.type = OperandDataType::kInt32, .dimensions = {3, 2, 5}},
       .expected = true},
      // Invalid graph: input shape differs from the output shape and is not
      // broadcastable to it.
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {3, 6, 2}},
       .output = {.type = OperandDataType::kFloat32, .dimensions = {4, 3, 5}},
       .expected = false},
      // Invalid graph: input shape is not broadcastable.
      {.input = {.type = OperandDataType::kInt32, .dimensions = {5}},
       .output = {.type = OperandDataType::kInt32, .dimensions = {5, 4}},
       .expected = false},
      // Invalid graph: output data type doesn't match the input's.
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {2}},
       .output = {.type = OperandDataType::kInt32, .dimensions = {2}},
       .expected = false},
  };
  for (ExpandTester& test_case : test_cases) {
    test_case.Test(*this);
  }

  {
    // Invalid graph: the same operand is used as both input and output.
    auto properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder graph(builder_remote);
    OperandId operand_id =
        graph.BuildInput("input", {2}, OperandDataType::kFloat32);
    graph.BuildExpand(operand_id, operand_id);
    EXPECT_FALSE(graph.IsValidGraphForTesting(properties));
  }
  {
    // Invalid graph: the output rank exceeds the limit. The expand input
    // rank limit in the context properties is narrowed to at most 4 and the
    // output is 5-D.
    auto properties = GetContextPropertiesForTesting();
    static constexpr SupportedRanks kRankLimit = SupportedRanks::UpTo(4);
    properties.data_type_limits.expand_input.ranks.IntersectWith(kRankLimit);
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder graph(builder_remote);
    OperandId in_id =
        graph.BuildInput("input", {2}, OperandDataType::kFloat32);
    OperandId out_id = graph.BuildOutput("output", {1, 1, 1, 1, 2},
                                         OperandDataType::kFloat32);
    graph.BuildExpand(in_id, out_id);
    EXPECT_FALSE(graph.IsValidGraphForTesting(properties));
  }
}

// Attributes shared by GatherTester and GatherElementsTester.
struct GatherAttributes {
  // Data type and dimensions of the "indices" input operand.
  OperandInfo indices;
  // Axis value forwarded to the gather-style operator being built.
  uint32_t axis;
};

struct GatherTester {
  OperandInfo input;
  GatherAttributes attributes;
  OperandInfo output;
  bool expected;

  void Test(WebNNGraphImplTest& test) {
    auto context_properties = GetContextPropertiesForTesting();

    // Build the graph with mojo type.
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        test.BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", input.dimensions, input.type);
    OperandId indices_operand_id = builder.BuildInput(
        "indices", attributes.indices.dimensions, attributes.indices.type);
    OperandId output_operand_id =
        builder.BuildOutput("output", output.dimensions, output.type);
    builder.BuildGather(input_operand_id, indices_operand_id, output_operand_id,
                        attributes.axis);
    EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
  }
};

// Validates a series of single-operator gather graphs. Table-driven cases run
// first, followed by cases that need a hand-built graph.
TEST_F(WebNNGraphImplTest, GatherTest) {
  GatherTester test_cases[] = {
      // Gather operator with 3-D input and 2-D indices.
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {3, 4, 5}},
       .attributes = {.indices = {.type = OperandDataType::kUint32,
                                  .dimensions = {6, 7}},
                      .axis = 1},
       .output = {.type = OperandDataType::kFloat32,
                  .dimensions = {3, 6, 7, 5}},
       .expected = true},
      // Invalid graph: the axis is too large.
      {.input = {.type = OperandDataType::kFloat16, .dimensions = {3, 4, 5}},
       .attributes = {.indices = {.type = OperandDataType::kUint32,
                                  .dimensions = {6, 7}},
                      .axis = 3},
       .output = {.type = OperandDataType::kFloat16,
                  .dimensions = {3, 4, 5, 6, 7}},
       .expected = false},
      // Invalid graph: the indices data type is floating point (float16).
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {3, 4, 5}},
       .attributes = {.indices = {.type = OperandDataType::kFloat16,
                                  .dimensions = {6, 7}},
                      .axis = 1},
       .output = {.type = OperandDataType::kFloat32,
                  .dimensions = {3, 6, 7, 5}},
       .expected = false},
      // Invalid graph: the indices data type (float32) is not one of uint32,
      // int32 or int64.
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {3, 4, 5}},
       .attributes = {.indices = {.type = OperandDataType::kFloat32,
                                  .dimensions = {6, 7}},
                      .axis = 1},
       .output = {.type = OperandDataType::kFloat32,
                  .dimensions = {3, 6, 7, 5}},
       .expected = false},
      // Invalid graph: the output shape is not as expected.
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {3, 4, 5}},
       .attributes = {.indices = {.type = OperandDataType::kUint32,
                                  .dimensions = {6, 7}},
                      .axis = 1},
       .output = {.type = OperandDataType::kFloat32,
                  .dimensions = {3, 4, 6, 7, 5}},
       .expected = false},
      // Invalid graph: the output data type doesn't match the input's.
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {3, 4, 5}},
       .attributes = {.indices = {.type = OperandDataType::kUint32,
                                  .dimensions = {6, 7}},
                      .axis = 1},
       .output = {.type = OperandDataType::kFloat16,
                  .dimensions = {3, 6, 7, 5}},
       .expected = false},
  };
  for (GatherTester& test_case : test_cases) {
    test_case.Test(*this);
  }

  {
    // Invalid graph: the output operand is the same as the input operand.
    auto properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder graph(builder_remote);
    OperandId in_id =
        graph.BuildInput("input", {2, 3}, OperandDataType::kFloat32);
    OperandId indices_id =
        graph.BuildInput("indices", {2}, OperandDataType::kUint32);
    graph.BuildGather(in_id, indices_id, in_id, /*axis=*/0);
    EXPECT_FALSE(graph.IsValidGraphForTesting(properties));
  }
  {
    // Invalid graph: the output operand is the same as the indices operand.
    auto properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder graph(builder_remote);
    OperandId in_id =
        graph.BuildInput("input", {3}, OperandDataType::kUint32);
    OperandId indices_id =
        graph.BuildInput("indices", {3}, OperandDataType::kUint32);
    graph.BuildGather(in_id, indices_id, indices_id, /*axis=*/0);
    EXPECT_FALSE(graph.IsValidGraphForTesting(properties));
  }
}

struct GatherElementsTester {
  OperandInfo input;
  GatherAttributes attributes;
  OperandInfo output;
  bool expected;

  void Test(WebNNGraphImplTest& test) {
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        test.BindNewGraphBuilderRemote();

    // Build the graph with mojo type.
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", input.dimensions, input.type);
    OperandId indices_operand_id = builder.BuildInput(
        "indices", attributes.indices.dimensions, attributes.indices.type);
    OperandId output_operand_id =
        builder.BuildOutput("output", output.dimensions, output.type);
    builder.BuildGatherElements(input_operand_id, indices_operand_id,
                                output_operand_id, attributes.axis);
    EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
  }
};

// Validates a series of single-operator gatherElements graphs. Table-driven
// cases run first, followed by cases that need a hand-built graph.
TEST_F(WebNNGraphImplTest, GatherElementsTest) {
  GatherElementsTester test_cases[] = {
      // gatherElements with 4-D input and indices.
      {.input = {.type = OperandDataType::kFloat32,
                 .dimensions = {3, 4, 5, 6}},
       .attributes = {.indices = {.type = OperandDataType::kUint32,
                                  .dimensions = {3, 4, 2, 6}},
                      .axis = 2},
       .output = {.type = OperandDataType::kFloat32,
                  .dimensions = {3, 4, 2, 6}},
       .expected = true},
      // Invalid graph: the axis is greater than the rank of input.
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {3, 4, 5}},
       .attributes = {.indices = {.type = OperandDataType::kUint32,
                                  .dimensions = {3, 4, 5}},
                      .axis = 3},
       .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 4, 5}},
       .expected = false},
      // Invalid graph: indices has an incorrect rank.
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {3, 4, 5}},
       .attributes = {.indices = {.type = OperandDataType::kUint32,
                                  .dimensions = {3, 4}},
                      .axis = 2},
       .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 4}},
       .expected = false},
      // Invalid graph: indices has an incorrect shape.
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {3, 4, 5}},
       .attributes = {.indices = {.type = OperandDataType::kUint32,
                                  .dimensions = {3, 3, 5}},
                      .axis = 2},
       .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 3, 5}},
       .expected = false},
      // Invalid graph: the indices data type is floating point.
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {3, 4, 5}},
       .attributes = {.indices = {.type = OperandDataType::kFloat16,
                                  .dimensions = {3, 4, 5}},
                      .axis = 0},
       .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 4, 5}},
       .expected = false},
      // Invalid graph: the output shape is not as expected.
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {3, 4, 5}},
       .attributes = {.indices = {.type = OperandDataType::kUint32,
                                  .dimensions = {3, 1, 5}},
                      .axis = 1},
       .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 4, 5}},
       .expected = false},
      // Invalid graph: the output data type doesn't match the input's.
      {.input = {.type = OperandDataType::kFloat32, .dimensions = {3, 4, 5}},
       .attributes = {.indices = {.type = OperandDataType::kUint32,
                                  .dimensions = {3, 1, 5}},
                      .axis = 1},
       .output = {.type = OperandDataType::kFloat16, .dimensions = {3, 1, 5}},
       .expected = false},
  };
  for (GatherElementsTester& test_case : test_cases) {
    test_case.Test(*this);
  }

  {
    // Invalid graph: the output operand is the same as the input operand.
    auto properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder graph(builder_remote);
    OperandId in_id =
        graph.BuildInput("input", {2, 3}, OperandDataType::kFloat32);
    OperandId indices_id =
        graph.BuildInput("indices", {2, 3}, OperandDataType::kUint32);
    graph.BuildGatherElements(in_id, indices_id, in_id, /*axis=*/0);
    EXPECT_FALSE(graph.IsValidGraphForTesting(properties));
  }
  {
    // Invalid graph: the output operand is the same as the indices operand.
    auto properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder graph(builder_remote);
    OperandId in_id =
        graph.BuildInput("input", {3}, OperandDataType::kUint32);
    OperandId indices_id =
        graph.BuildInput("indices", {3}, OperandDataType::kUint32);
    graph.BuildGatherElements(in_id, indices_id, indices_id, /*axis=*/0);
    EXPECT_FALSE(graph.IsValidGraphForTesting(properties));
  }
}

struct GatherNDTester {
  OperandInfo input;
  OperandInfo indices;
  OperandInfo output;
  bool expected;

  void Test(WebNNGraphImplTest& test) {
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        test.BindNewGraphBuilderRemote();

    // Build the graph with mojo type.
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", input.dimensions, input.type);
    OperandId indices_operand_id =
        builder.BuildInput("indices", indices.dimensions, indices.type);
    OperandId output_operand_id =
        builder.BuildOutput("output", output.dimensions, output.type);
    builder.BuildGatherND(input_operand_id, indices_operand_id,
                          output_operand_id);
    EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
  }
};

// Covers graph validation of gatherND: one well-formed graph followed by
// graphs that must be rejected.
TEST_F(WebNNGraphImplTest, GatherNDTest) {
  {
    // Valid: 4-D input gathered with 3-D indices.
    GatherNDTester tester{
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {3, 4, 5, 6}},
        .indices = {.type = OperandDataType::kUint32, .dimensions = {3, 7, 2}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {3, 7, 5, 6}},
        .expected = true};
    tester.Test(*this);
  }
  {
    // Invalid: the input is a scalar.
    GatherNDTester tester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {}},
        .indices = {.type = OperandDataType::kUint32, .dimensions = {1, 2}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {}},
        .expected = false};
    tester.Test(*this);
  }
  {
    // Invalid: the indices operand is a scalar.
    GatherNDTester tester{
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 2, 3, 4, 5}},
        .indices = {.type = OperandDataType::kUint32, .dimensions = {}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 2, 3, 4, 5}},
        .expected = false};
    tester.Test(*this);
  }
  {
    // Invalid: indices.shape[-1] (4) exceeds the input rank (3).
    GatherNDTester tester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {1, 2, 3}},
        .indices = {.type = OperandDataType::kUint32, .dimensions = {1, 4}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {1, 2, 3}},
        .expected = false};
    tester.Test(*this);
  }
  {
    // Invalid: the declared output shape is not the expected one.
    GatherNDTester tester{
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 2, 3, 4}},
        .indices = {.type = OperandDataType::kUint32, .dimensions = {1, 1}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 1, 3, 4}},
        .expected = false};
    tester.Test(*this);
  }
  {
    // Invalid: the output data type differs from the input data type.
    GatherNDTester tester{
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 2, 3, 4}},
        .indices = {.type = OperandDataType::kUint32, .dimensions = {1, 1}},
        .output = {.type = OperandDataType::kFloat16,
                   .dimensions = {1, 2, 3, 4}},
        .expected = false};
    tester.Test(*this);
  }
  {
    // Invalid: the input operand is reused as the output operand.
    auto properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder graph_builder(builder_remote);
    OperandId input_id =
        graph_builder.BuildInput("input", {2, 3}, OperandDataType::kUint32);
    OperandId indices_id =
        graph_builder.BuildInput("indices", {2, 1}, OperandDataType::kUint32);
    graph_builder.BuildGatherND(input_id, indices_id, input_id);
    EXPECT_FALSE(graph_builder.IsValidGraphForTesting(properties));
  }
  {
    // Invalid: the indices operand is reused as the output operand.
    auto properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder graph_builder(builder_remote);
    OperandId input_id =
        graph_builder.BuildInput("input", {2, 1}, OperandDataType::kUint32);
    OperandId indices_id =
        graph_builder.BuildInput("indices", {2, 1}, OperandDataType::kUint32);
    graph_builder.BuildGatherND(input_id, indices_id, indices_id);
    EXPECT_FALSE(graph_builder.IsValidGraphForTesting(properties));
  }
}

struct GeluTester {
  OperandInfo input;
  OperandInfo output;
  bool expected;

  void Test(WebNNGraphImplTest& test) {
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        test.BindNewGraphBuilderRemote();

    // Build the graph with mojo type.
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", input.dimensions, input.type);
    OperandId output_operand_id =
        builder.BuildOutput("output", output.dimensions, output.type);
    builder.BuildGelu(input_operand_id, output_operand_id);
    EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
  }
};

// Covers graph validation of gelu: one well-formed graph followed by graphs
// that must be rejected.
TEST_F(WebNNGraphImplTest, GeluTest) {
  {
    // Valid: float32 3-D tensor with an identically described output.
    GeluTester tester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 6, 4}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 6, 4}},
        .expected = true};
    tester.Test(*this);
  }
  {
    // Invalid: the input has data type int32.
    GeluTester tester{
        .input = {.type = OperandDataType::kInt32, .dimensions = {}},
        .output = {.type = OperandDataType::kInt32, .dimensions = {}},
        .expected = false};
    tester.Test(*this);
  }
  {
    // Invalid: the declared output shape is not the expected one.
    GeluTester tester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {4, 2}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {2}},
        .expected = false};
    tester.Test(*this);
  }
  {
    // Invalid: the output data type differs from the input data type.
    GeluTester tester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {2}},
        .output = {.type = OperandDataType::kInt32, .dimensions = {2}},
        .expected = false};
    tester.Test(*this);
  }
  {
    // Invalid: the same operand id is used as both input and output.
    auto properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder graph_builder(builder_remote);
    OperandId shared_id =
        graph_builder.BuildInput("input", {1}, OperandDataType::kFloat16);
    graph_builder.BuildGelu(shared_id, shared_id);
    EXPECT_FALSE(graph_builder.IsValidGraphForTesting(properties));
  }
}

struct GemmTester {
  OperandInfo a;
  OperandInfo b;
  std::optional<OperandInfo> c;
  struct GemmAttributes {
    std::optional<OperandId> c_operand_id;
    float alpha = 1.0;
    float beta = 1.0;
    bool a_transpose = false;
    bool b_transpose = false;
  };
  GemmAttributes attributes;
  OperandInfo output;
  bool expected;

  void Test(WebNNGraphImplTest& test) {
    auto context_properties = GetContextPropertiesForTesting();

    // Build the graph with mojo type.
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        test.BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId a_operand_id = builder.BuildInput("a", a.dimensions, a.type);
    OperandId b_operand_id = builder.BuildInput("b", b.dimensions, b.type);
    OperandId output_operand_id =
        builder.BuildOutput("output", output.dimensions, output.type);

    if (c) {
      attributes.c_operand_id = builder.BuildInput("c", c->dimensions, c->type);
    }
    builder.BuildGemm(a_operand_id, b_operand_id, output_operand_id,
                      std::move(attributes));
    EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
  }
};

// Covers graph validation of gemm. Each invalid case below is built so that
// the condition named in its comment is the only thing wrong with the graph;
// otherwise a case could pass for an unrelated reason and hide a regression
// in the check it is meant to exercise.
TEST_F(WebNNGraphImplTest, GemmTest) {
  {
    // Test building gemm with default option.
    GemmTester{
        .a = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}},
        .b = {.type = OperandDataType::kFloat32, .dimensions = {3, 4}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 4}},
        .expected = true}
        .Test(*this);
  }
  {
    // Test building gemm with aTranspose = true.
    // Transposed a_dimensions would be {3, 2} and it's compatible with
    // b_dimensions {2, 4}.
    GemmTester{
        .a = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}},
        .b = {.type = OperandDataType::kFloat32, .dimensions = {2, 4}},
        .attributes = {.a_transpose = true},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 4}},
        .expected = true}
        .Test(*this);
  }
  {
    // Test building gemm with bTranspose = true.
    // Transposed b_dimensions would be {3, 4} and it's compatible with
    // a_dimensions {2, 3}.
    GemmTester{
        .a = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}},
        .b = {.type = OperandDataType::kFloat32, .dimensions = {4, 3}},
        .attributes = {.b_transpose = true},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 4}},
        .expected = true}
        .Test(*this);
  }
  {
    // Test building gemm with setting optional input C.
    // The output dimensions of a * b would be {2, 4} and c_dimensions {4}
    // is able to broadcast to {2, 4}.
    GemmTester{
        .a = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}},
        .b = {.type = OperandDataType::kFloat32, .dimensions = {3, 4}},
        .c = OperandInfo{.type = OperandDataType::kFloat32, .dimensions = {4}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 4}},
        .expected = true}
        .Test(*this);
  }
  {
    // Test building gemm with two matrices - {2, 3} and {2, 4} that can't
    // be multiplied together due to incompatible dimensions.
    GemmTester{
        .a = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}},
        .b = {.type = OperandDataType::kFloat32, .dimensions = {2, 4}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 4}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test building gemm with setting optional input C and no transposes.
    // The output dimensions of a * b would be {2, 4} and c_dimension {2, 3}
    // is incompatible with {2, 4}.
    GemmTester{
        .a = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}},
        .b = {.type = OperandDataType::kFloat32, .dimensions = {3, 4}},
        .c = OperandInfo{.type = OperandDataType::kFloat32,
                         .dimensions = {2, 3}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 4}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test building gemm with aTranspose = true, bTranspose = true.
    // Set optional input C with type = int32 and it mismatches with input
    // type float32.
    GemmTester{
        .a = {.type = OperandDataType::kFloat32, .dimensions = {3, 2}},
        .b = {.type = OperandDataType::kFloat32, .dimensions = {4, 3}},
        .c = OperandInfo{.type = OperandDataType::kInt32, .dimensions = {2, 4}},
        .attributes = {.a_transpose = true, .b_transpose = true},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 4}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph if the input is not floating point.
    GemmTester{
        .a = {.type = OperandDataType::kInt32, .dimensions = {2, 3}},
        .b = {.type = OperandDataType::kInt32, .dimensions = {3, 4}},
        .output = {.type = OperandDataType::kInt32, .dimensions = {2, 4}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the data types of a and b don't match.
    GemmTester{
        .a = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}},
        .b = {.type = OperandDataType::kInt32, .dimensions = {3, 4}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 4}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph for the output shapes are not expected.
    // Both inputs are float32 so that only the output shape ({3, 4} instead
    // of {2, 4}) makes the graph invalid.
    GemmTester{
        .a = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}},
        .b = {.type = OperandDataType::kFloat32, .dimensions = {3, 4}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 4}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph for output types don't match.
    // Both inputs are float32 and the output shape is correct so that only
    // the int32 output type makes the graph invalid.
    GemmTester{
        .a = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}},
        .b = {.type = OperandDataType::kFloat32, .dimensions = {3, 4}},
        .output = {.type = OperandDataType::kInt32, .dimensions = {2, 4}},
        .expected = false}
        .Test(*this);
  }
}

struct GruTester {
  struct GruAttributes {
    std::optional<OperandId> bias_operand_id;
    std::optional<OperandId> recurrent_bias_operand_id;
    std::optional<OperandId> initial_hidden_state_operand_id;
    bool reset_after = true;
    bool return_sequence = false;
    mojom::RecurrentNetworkDirection direction =
        mojom::RecurrentNetworkDirection::kForward;
    mojom::GruWeightLayout layout = mojom::GruWeightLayout::kZrn;
    std::vector<mojom::RecurrentNetworkActivation> activations = {
        mojom::RecurrentNetworkActivation::kSigmoid,
        mojom::RecurrentNetworkActivation::kTanh};
  };

  OperandInfo input;
  OperandInfo weight;
  OperandInfo recurrent_weight;
  uint32_t steps;
  uint32_t hidden_size;
  std::optional<OperandInfo> bias;
  std::optional<OperandInfo> recurrent_bias;
  std::optional<OperandInfo> initial_hidden_state;
  GruAttributes attributes;
  std::vector<OperandInfo> outputs;
  bool expected;

  void Test(WebNNGraphImplTest& test) {
    auto context_properties = GetContextPropertiesForTesting();

    // Build the graph with mojo type.
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        test.BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", input.dimensions, input.type);
    OperandId weight_operand_id =
        builder.BuildInput("weight", weight.dimensions, weight.type);
    OperandId recurrent_weight_operand_id = builder.BuildInput(
        "recurrentWeight", recurrent_weight.dimensions, recurrent_weight.type);

    std::vector<OperandId> output_operand_ids;
    output_operand_ids.reserve(outputs.size());
    for (size_t i = 0; i < outputs.size(); ++i) {
      output_operand_ids.push_back(
          builder.BuildOutput(base::StringPrintf("output%zu", i),
                              outputs[i].dimensions, outputs[i].type));
    }

    if (bias.has_value()) {
      attributes.bias_operand_id =
          builder.BuildInput("bias", bias->dimensions, bias->type);
    }
    if (recurrent_bias.has_value()) {
      attributes.recurrent_bias_operand_id = builder.BuildInput(
          "recurrentBias", recurrent_bias->dimensions, recurrent_bias->type);
    }
    if (initial_hidden_state.has_value()) {
      attributes.initial_hidden_state_operand_id = builder.BuildInput(
          "initialHiddenState", initial_hidden_state->dimensions,
          initial_hidden_state->type);
    }

    builder.BuildGru(input_operand_id, weight_operand_id,
                     recurrent_weight_operand_id, std::move(output_operand_ids),
                     steps, hidden_size, std::move(attributes));
    EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
  }
};

// Covers graph validation of gru: one well-formed graph followed by graphs
// that must be rejected.
TEST_F(WebNNGraphImplTest, GruTest) {
  // Sizes shared by every case below; only the direction count varies.
  constexpr uint32_t kSteps = 2;
  constexpr uint32_t kBatchSize = 1;
  constexpr uint32_t kInputSize = 3;
  constexpr uint32_t kHiddenSize = 4;

  {
    // Valid: bidirectional gru with every optional operand supplied and the
    // sequence output requested, hence two outputs.
    constexpr uint32_t kNumDirections = 2;
    GruTester tester{
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {kSteps, kBatchSize, kInputSize}},
        .weight = {.type = OperandDataType::kFloat32,
                   .dimensions = {kNumDirections, 3 * kHiddenSize,
                                  kInputSize}},
        .recurrent_weight = {.type = OperandDataType::kFloat32,
                             .dimensions = {kNumDirections, 3 * kHiddenSize,
                                            kHiddenSize}},
        .steps = kSteps,
        .hidden_size = kHiddenSize,
        .bias = OperandInfo{.type = OperandDataType::kFloat32,
                            .dimensions = {kNumDirections, 3 * kHiddenSize}},
        .recurrent_bias =
            OperandInfo{.type = OperandDataType::kFloat32,
                        .dimensions = {kNumDirections, 3 * kHiddenSize}},
        .initial_hidden_state =
            OperandInfo{
                .type = OperandDataType::kFloat32,
                .dimensions = {kNumDirections, kBatchSize, kHiddenSize}},
        .attributes = {.reset_after = true,
                       .return_sequence = true,
                       .direction = mojom::RecurrentNetworkDirection::kBoth},
        .outputs = {{.type = OperandDataType::kFloat32,
                     .dimensions = {kNumDirections, kBatchSize, kHiddenSize}},
                    {.type = OperandDataType::kFloat32,
                     .dimensions = {kSteps, kNumDirections, kBatchSize,
                                    kHiddenSize}}},
        .expected = true};
    tester.Test(*this);
  }
  {
    // Invalid: the weight's second dimension is 4 * hidden_size rather than
    // 3 * hidden_size.
    constexpr uint32_t kNumDirections = 1;
    GruTester tester{
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {kSteps, kBatchSize, kInputSize}},
        .weight = {.type = OperandDataType::kFloat32,
                   .dimensions = {kNumDirections, 4 * kHiddenSize,
                                  kInputSize}},
        .recurrent_weight = {.type = OperandDataType::kFloat32,
                             .dimensions = {kNumDirections, 3 * kHiddenSize,
                                            kHiddenSize}},
        .steps = kSteps,
        .hidden_size = kHiddenSize,
        .outputs = {{.type = OperandDataType::kFloat32,
                     .dimensions = {kNumDirections, kBatchSize, kHiddenSize}}},
        .expected = false};
    tester.Test(*this);
  }
  {
    // Invalid: the declared output shape is incorrect.
    constexpr uint32_t kNumDirections = 1;
    GruTester tester{
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {kSteps, kBatchSize, kInputSize}},
        .weight = {.type = OperandDataType::kFloat32,
                   .dimensions = {kNumDirections, 3 * kHiddenSize,
                                  kInputSize}},
        .recurrent_weight = {.type = OperandDataType::kFloat32,
                             .dimensions = {kNumDirections, 3 * kHiddenSize,
                                            kHiddenSize}},
        .steps = kSteps,
        .hidden_size = kHiddenSize,
        .outputs = {{.type = OperandDataType::kFloat32,
                     .dimensions = {kNumDirections, kBatchSize,
                                    3 * kHiddenSize}}},
        .expected = false};
    tester.Test(*this);
  }
  {
    // Invalid: two outputs are declared although `return_sequence` keeps its
    // default of false.
    constexpr uint32_t kNumDirections = 1;
    GruTester tester{
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {kSteps, kBatchSize, kInputSize}},
        .weight = {.type = OperandDataType::kFloat32,
                   .dimensions = {kNumDirections, 3 * kHiddenSize,
                                  kInputSize}},
        .recurrent_weight = {.type = OperandDataType::kFloat32,
                             .dimensions = {kNumDirections, 3 * kHiddenSize,
                                            kHiddenSize}},
        .steps = kSteps,
        .hidden_size = kHiddenSize,
        .outputs = {{.type = OperandDataType::kFloat32,
                     .dimensions = {kNumDirections, kBatchSize, kHiddenSize}},
                    {.type = OperandDataType::kFloat32,
                     .dimensions = {kSteps, kNumDirections, kBatchSize,
                                    kHiddenSize}}},
        .expected = false};
    tester.Test(*this);
  }
  {
    // Invalid: the initial hidden state operand is reused as the only
    // output operand.
    constexpr uint32_t kNumDirections = 1;

    auto properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder graph_builder(builder_remote);
    OperandId input_id = graph_builder.BuildInput(
        "input", {kSteps, kBatchSize, kInputSize}, OperandDataType::kFloat32);
    OperandId weight_id = graph_builder.BuildInput(
        "weight", {kNumDirections, 3 * kHiddenSize, kInputSize},
        OperandDataType::kFloat32);
    OperandId recurrent_weight_id = graph_builder.BuildInput(
        "recurrentWeight", {kNumDirections, 3 * kHiddenSize, kHiddenSize},
        OperandDataType::kFloat32);

    OperandId initial_hidden_state_id = graph_builder.BuildInput(
        "initialHiddenState", {kNumDirections, kBatchSize, kHiddenSize},
        OperandDataType::kFloat32);

    graph_builder.BuildGru(
        input_id, weight_id, recurrent_weight_id, {initial_hidden_state_id},
        kSteps, kHiddenSize,
        GruTester::GruAttributes{.initial_hidden_state_operand_id =
                                     initial_hidden_state_id});
    EXPECT_FALSE(graph_builder.IsValidGraphForTesting(properties));
  }
}

struct GruCellTester {
  struct GruCellAttributes {
    std::optional<OperandId> bias_operand_id;
    std::optional<OperandId> recurrent_bias_operand_id;
    bool reset_after = true;
    mojom::GruWeightLayout layout = mojom::GruWeightLayout::kZrn;
    std::vector<mojom::RecurrentNetworkActivation> activations = {
        mojom::RecurrentNetworkActivation::kSigmoid,
        mojom::RecurrentNetworkActivation::kTanh};
  };

  OperandInfo input;
  OperandInfo weight;
  OperandInfo recurrent_weight;
  OperandInfo hidden_state;
  uint32_t hidden_size;
  std::optional<OperandInfo> bias;
  std::optional<OperandInfo> recurrent_bias;
  GruCellAttributes attributes;
  OperandInfo output;
  bool expected;

  void Test(WebNNGraphImplTest& test) {
    auto context_properties = GetContextPropertiesForTesting();

    // Build the graph with mojo type.
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        test.BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", input.dimensions, input.type);
    OperandId weight_operand_id =
        builder.BuildInput("weight", weight.dimensions, weight.type);
    OperandId recurrent_weight_operand_id = builder.BuildInput(
        "recurrentWeight", recurrent_weight.dimensions, recurrent_weight.type);
    OperandId hidden_state_operand_id = builder.BuildInput(
        "hiddenState", hidden_state.dimensions, hidden_state.type);

    if (bias.has_value()) {
      attributes.bias_operand_id =
          builder.BuildInput("bias", bias->dimensions, bias->type);
    }
    if (recurrent_bias.has_value()) {
      attributes.recurrent_bias_operand_id = builder.BuildInput(
          "recurrentBias", recurrent_bias->dimensions, recurrent_bias->type);
    }

    OperandId output_operand_id =
        builder.BuildOutput("output", output.dimensions, output.type);

    builder.BuildGruCell(input_operand_id, weight_operand_id,
                         recurrent_weight_operand_id, hidden_state_operand_id,
                         output_operand_id, hidden_size, std::move(attributes));
    EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
  }
};

TEST_F(WebNNGraphImplTest, GruCellTest) {
  uint32_t batch_size = 2;
  uint32_t input_size = 4;
  uint32_t hidden_size = 6;

  OperandInfo valid_input = {.type = OperandDataType::kFloat32,
                             .dimensions = {batch_size, input_size}};
  OperandInfo valid_weight = {.type = OperandDataType::kFloat32,
                              .dimensions = {3 * hidden_size, input_size}};
  OperandInfo valid_recurrent_weight = {
      .type = OperandDataType::kFloat32,
      .dimensions = {3 * hidden_size, hidden_size}};
  OperandInfo valid_hidden_state = {.type = OperandDataType::kFloat32,
                                    .dimensions = {batch_size, hidden_size}};
  OperandInfo valid_bias = {.type = OperandDataType::kFloat32,
                            .dimensions = {3 * hidden_size}};
  OperandInfo valid_recurrent_bias = {.type = OperandDataType::kFloat32,
                                      .dimensions = {3 * hidden_size}};
  OperandInfo valid_output = {.type = OperandDataType::kFloat32,
                              .dimensions = {batch_size, hidden_size}};

  {
    // Test the valid gruCell operator.
    GruCellTester{.input = valid_input,
                  .weight = valid_weight,
                  .recurrent_weight = valid_recurrent_weight,
                  .hidden_state = valid_hidden_state,
                  .hidden_size = hidden_size,
                  .bias = valid_bias,
                  .recurrent_bias = valid_recurrent_bias,
                  .attributes = {.reset_after = true},
                  .output = valid_output,
                  .expected = true}
        .Test(*this);
  }
  {
    // Test the invalid graph when the data type of the input is incorrect.
    GruCellTester{.input = {.type = OperandDataType::kInt8,
                            .dimensions = {batch_size, input_size}},
                  .weight = valid_weight,
                  .recurrent_weight = valid_recurrent_weight,
                  .hidden_state = valid_hidden_state,
                  .hidden_size = hidden_size,
                  .bias = valid_bias,
                  .recurrent_bias = valid_recurrent_bias,
                  .attributes = {.reset_after = true},
                  .output = valid_output,
                  .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the shape of the input is incorrect.
    GruCellTester{.input = {.type = OperandDataType::kFloat32,
                            .dimensions = {1, input_size}},
                  .weight = valid_weight,
                  .recurrent_weight = valid_recurrent_weight,
                  .hidden_state = valid_hidden_state,
                  .hidden_size = hidden_size,
                  .bias = valid_bias,
                  .recurrent_bias = valid_recurrent_bias,
                  .attributes = {.reset_after = true},
                  .output = valid_output,
                  .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the rank of the input is incorrect.
    GruCellTester{.input = {.type = OperandDataType::kFloat32,
                            .dimensions = {input_size}},
                  .weight = valid_weight,
                  .recurrent_weight = valid_recurrent_weight,
                  .hidden_state = valid_hidden_state,
                  .hidden_size = hidden_size,
                  .bias = valid_bias,
                  .recurrent_bias = valid_recurrent_bias,
                  .attributes = {.reset_after = true},
                  .output = valid_output,
                  .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the data type of the weight is incorrect.
    GruCellTester{.input = valid_input,
                  .weight = {.type = OperandDataType::kInt8,
                             .dimensions = {3 * hidden_size, input_size}},
                  .recurrent_weight = valid_recurrent_weight,
                  .hidden_state = valid_hidden_state,
                  .hidden_size = hidden_size,
                  .bias = valid_bias,
                  .recurrent_bias = valid_recurrent_bias,
                  .attributes = {.reset_after = true},
                  .output = valid_output,
                  .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the shape of the weight is incorrect.
    GruCellTester{.input = valid_input,
                  .weight = {.type = OperandDataType::kFloat32,
                             .dimensions = {4 * hidden_size, input_size}},
                  .recurrent_weight = valid_recurrent_weight,
                  .hidden_state = valid_hidden_state,
                  .hidden_size = hidden_size,
                  .bias = valid_bias,
                  .recurrent_bias = valid_recurrent_bias,
                  .attributes = {.reset_after = true},
                  .output = valid_output,
                  .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the rank of the weight is incorrect.
    GruCellTester{.input = valid_input,
                  .weight = {.type = OperandDataType::kFloat32,
                             .dimensions = {3 * hidden_size}},
                  .recurrent_weight = valid_recurrent_weight,
                  .hidden_state = valid_hidden_state,
                  .hidden_size = hidden_size,
                  .bias = valid_bias,
                  .recurrent_bias = valid_recurrent_bias,
                  .attributes = {.reset_after = true},
                  .output = valid_output,
                  .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the data type of the recurrent weight is
    // incorrect.
    GruCellTester{
        .input = valid_input,
        .weight = valid_weight,
        .recurrent_weight = {.type = OperandDataType::kInt8,
                             .dimensions = {3 * hidden_size, hidden_size}},
        .hidden_state = valid_hidden_state,
        .hidden_size = hidden_size,
        .bias = valid_bias,
        .recurrent_bias = valid_recurrent_bias,
        .attributes = {.reset_after = true},
        .output = valid_output,
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the shape of the recurrent weight is
    // incorrect.
    GruCellTester{
        .input = valid_input,
        .weight = valid_weight,
        .recurrent_weight = {.type = OperandDataType::kFloat32,
                             .dimensions = {3 * hidden_size, input_size}},
        .hidden_state = valid_hidden_state,
        .hidden_size = hidden_size,
        .bias = valid_bias,
        .recurrent_bias = valid_recurrent_bias,
        .attributes = {.reset_after = true},
        .output = valid_output,
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the rank of the recurrent weight is
    // incorrect.
    GruCellTester{.input = valid_input,
                  .weight = valid_weight,
                  .recurrent_weight = {.type = OperandDataType::kFloat32,
                                       .dimensions = {3 * hidden_size}},
                  .hidden_state = valid_hidden_state,
                  .hidden_size = hidden_size,
                  .bias = valid_bias,
                  .recurrent_bias = valid_recurrent_bias,
                  .attributes = {.reset_after = true},
                  .output = valid_output,
                  .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the hidden_size is incorrect.
    GruCellTester{.input = valid_input,
                  .weight = valid_weight,
                  .recurrent_weight = valid_recurrent_weight,
                  .hidden_state = valid_hidden_state,
                  .hidden_size = 1000,
                  .bias = valid_bias,
                  .recurrent_bias = valid_recurrent_bias,
                  .attributes = {.reset_after = true},
                  .output = valid_output,
                  .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the data type of the bias is incorrect.
    GruCellTester{.input = valid_input,
                  .weight = valid_weight,
                  .recurrent_weight = valid_recurrent_weight,
                  .hidden_state = valid_hidden_state,
                  .hidden_size = hidden_size,
                  .bias = OperandInfo{.type = OperandDataType::kUint8,
                                      .dimensions = {3 * hidden_size}},
                  .recurrent_bias = valid_recurrent_bias,
                  .attributes = {.reset_after = true},
                  .output = valid_output,
                  .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the shape of the bias is incorrect.
    GruCellTester{.input = valid_input,
                  .weight = valid_weight,
                  .recurrent_weight = valid_recurrent_weight,
                  .hidden_state = valid_hidden_state,
                  .hidden_size = hidden_size,
                  .bias = OperandInfo{.type = OperandDataType::kFloat32,
                                      .dimensions = {4 * hidden_size}},
                  .recurrent_bias = valid_recurrent_bias,
                  .attributes = {.reset_after = true},
                  .output = valid_output,
                  .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the rank of the bias is incorrect.
    GruCellTester{
        .input = valid_input,
        .weight = valid_weight,
        .recurrent_weight = valid_recurrent_weight,
        .hidden_state = valid_hidden_state,
        .hidden_size = hidden_size,
        .bias = OperandInfo{.type = OperandDataType::kFloat32,
                            .dimensions = {3 * hidden_size, hidden_size}},
        .recurrent_bias = valid_recurrent_bias,
        .attributes = {.reset_after = true},
        .output = valid_output,
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the data type of the recurrent bias is
    // incorrect.
    GruCellTester{
        .input = valid_input,
        .weight = valid_weight,
        .recurrent_weight = valid_recurrent_weight,
        .hidden_state = valid_hidden_state,
        .hidden_size = hidden_size,
        .bias = valid_bias,
        .recurrent_bias = OperandInfo{.type = OperandDataType::kUint8,
                                      .dimensions = {3 * hidden_size}},
        .attributes = {.reset_after = true},
        .output = valid_output,
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the shape of the recurrent bias is incorrect.
    GruCellTester{
        .input = valid_input,
        .weight = valid_weight,
        .recurrent_weight = valid_recurrent_weight,
        .hidden_state = valid_hidden_state,
        .hidden_size = hidden_size,
        .bias = valid_bias,
        .recurrent_bias = OperandInfo{.type = OperandDataType::kFloat32,
                                      .dimensions = {4 * hidden_size}},
        .attributes = {.reset_after = true},
        .output = valid_output,
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the rank of the recurrent bias is incorrect.
    GruCellTester{.input = valid_input,
                  .weight = valid_weight,
                  .recurrent_weight = valid_recurrent_weight,
                  .hidden_state = valid_hidden_state,
                  .hidden_size = hidden_size,
                  .bias = valid_bias,
                  .recurrent_bias =
                      OperandInfo{.type = OperandDataType::kFloat32,
                                  .dimensions = {3 * hidden_size, hidden_size}},
                  .attributes = {.reset_after = true},
                  .output = valid_output,
                  .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the output data type is incorrect.
    GruCellTester{.input = valid_input,
                  .weight = valid_weight,
                  .recurrent_weight = valid_recurrent_weight,
                  .hidden_state = valid_hidden_state,
                  .hidden_size = hidden_size,
                  .bias = valid_bias,
                  .recurrent_bias = valid_recurrent_bias,
                  .attributes = {.reset_after = true},
                  .output = {.type = OperandDataType::kInt32,
                             .dimensions = {batch_size, hidden_size}},
                  .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the output shape is incorrect.
    GruCellTester{.input = valid_input,
                  .weight = valid_weight,
                  .recurrent_weight = valid_recurrent_weight,
                  .hidden_state = valid_hidden_state,
                  .hidden_size = hidden_size,
                  .bias = valid_bias,
                  .recurrent_bias = valid_recurrent_bias,
                  .attributes = {.reset_after = true},
                  .output = {.type = OperandDataType::kFloat32,
                             .dimensions = {batch_size, 3 * hidden_size}},
                  .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the output rank is incorrect.
    GruCellTester{.input = valid_input,
                  .weight = valid_weight,
                  .recurrent_weight = valid_recurrent_weight,
                  .hidden_state = valid_hidden_state,
                  .hidden_size = hidden_size,
                  .bias = valid_bias,
                  .recurrent_bias = valid_recurrent_bias,
                  .attributes = {.reset_after = true},
                  .output = {.type = OperandDataType::kFloat32,
                             .dimensions = {hidden_size}},
                  .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the hidden state has the same id as the
    // output.
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id = builder.BuildInput(
        "input", {batch_size, input_size}, OperandDataType::kFloat32);
    OperandId weight_operand_id = builder.BuildInput(
        "weight", {3 * hidden_size, input_size}, OperandDataType::kFloat32);
    OperandId recurrent_weight_operand_id =
        builder.BuildInput("recurrentWeight", {3 * hidden_size, hidden_size},
                           OperandDataType::kFloat32);

    OperandId hidden_state_operand_id = builder.BuildInput(
        "hiddenState", {batch_size, hidden_size}, OperandDataType::kFloat32);

    builder.BuildGruCell(input_operand_id, weight_operand_id,
                         recurrent_weight_operand_id, hidden_state_operand_id,
                         hidden_state_operand_id, hidden_size,
                         GruCellTester::GruCellAttributes{.reset_after = true});
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
}

struct InstanceNormalizationTester {
  OperandInfo input;
  std::optional<OperandInfo> scale;
  std::optional<OperandInfo> bias;
  struct InstanceNormalizationAttributes {
    std::optional<OperandId> scale_operand_id;
    std::optional<OperandId> bias_operand_id;
    float epsilon = 1e-5;
  };
  InstanceNormalizationAttributes attributes;
  InputOperandLayout input_operand_layout = InputOperandLayout::kNchw;
  OperandInfo output;
  bool expected;

  void Test(WebNNGraphImplTest& test) {
    auto context_properties = GetContextPropertiesForTesting();
    context_properties.input_operand_layout = input_operand_layout;

    // Build the graph with mojo type.
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        test.BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", input.dimensions, input.type);
    OperandId output_operand_id =
        builder.BuildOutput("output", output.dimensions, output.type);

    if (scale) {
      attributes.scale_operand_id =
          builder.BuildInput("scale", scale->dimensions, scale->type);
    }
    if (bias) {
      attributes.bias_operand_id =
          builder.BuildInput("bias", bias->dimensions, bias->type);
    }
    builder.BuildInstanceNormalization(input_operand_id, output_operand_id,
                                       std::move(attributes));
    EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
  }
};

TEST_F(WebNNGraphImplTest, InstanceNormalizationTest) {
  // Operand descriptions reused by several of the cases below.
  const OperandInfo float32_4d{.type = OperandDataType::kFloat32,
                               .dimensions = {1, 2, 3, 3}};
  const OperandInfo float32_size2{.type = OperandDataType::kFloat32,
                                  .dimensions = {2}};
  const OperandInfo float32_size3{.type = OperandDataType::kFloat32,
                                  .dimensions = {3}};
  const OperandInfo int32_size2{.type = OperandDataType::kInt32,
                                .dimensions = {2}};

  {
    // Valid: default options (no scale, no bias, default layout).
    InstanceNormalizationTester{
        .input = float32_4d, .output = float32_4d, .expected = true}
        .Test(*this);
  }
  {
    // Valid: nhwc layout with scale and bias of size 3.
    InstanceNormalizationTester{
        .input = float32_4d,
        .scale = float32_size3,
        .bias = float32_size3,
        .input_operand_layout = InputOperandLayout::kNhwc,
        .output = float32_4d,
        .expected = true}
        .Test(*this);
  }
  {
    // Valid: default nchw layout with scale and bias of size 2.
    InstanceNormalizationTester{.input = float32_4d,
                                .scale = float32_size2,
                                .bias = float32_size2,
                                .output = float32_4d,
                                .expected = true}
        .Test(*this);
  }
  {
    // Invalid: the scale data type differs from the input data type.
    InstanceNormalizationTester{.input = float32_4d,
                                .scale = int32_size2,
                                .output = float32_4d,
                                .expected = false}
        .Test(*this);
  }
  {
    // Invalid: the size of scale is not equal to the size of the feature
    // dimension of the input.
    InstanceNormalizationTester{.input = float32_4d,
                                .scale = float32_size3,
                                .output = float32_4d,
                                .expected = false}
        .Test(*this);
  }
  {
    // Invalid: the bias data type differs from the input data type.
    InstanceNormalizationTester{.input = float32_4d,
                                .bias = int32_size2,
                                .output = float32_4d,
                                .expected = false}
        .Test(*this);
  }
  {
    // Invalid: the size of bias is not equal to the size of the feature
    // dimension of the input (nhwc layout).
    InstanceNormalizationTester{
        .input = float32_4d,
        .bias = float32_size2,
        .input_operand_layout = InputOperandLayout::kNhwc,
        .output = float32_4d,
        .expected = false}
        .Test(*this);
  }
  {
    // Invalid: the output data type is not the same as the input data type.
    InstanceNormalizationTester{
        .input = float32_4d,
        .output = {.type = OperandDataType::kInt32, .dimensions = {1, 2, 3, 3}},
        .expected = false}
        .Test(*this);
  }
  {
    // Invalid: the output shape is not the same as the input shape.
    InstanceNormalizationTester{.input = float32_4d,
                                .output = {.type = OperandDataType::kFloat32,
                                           .dimensions = {1, 1, 3, 3}},
                                .expected = false}
        .Test(*this);
  }
  {
    // Invalid: the input is not a 4-D tensor.
    const OperandInfo float32_3d{.type = OperandDataType::kFloat32,
                                 .dimensions = {1, 2, 3}};
    InstanceNormalizationTester{
        .input = float32_3d, .output = float32_3d, .expected = false}
        .Test(*this);
  }
  {
    // Invalid: the same operand is used as both input and output.
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(builder_remote);
    OperandId input_id =
        builder.BuildInput("input", {1, 2, 3, 4}, OperandDataType::kFloat32);
    builder.BuildInstanceNormalization(
        input_id, input_id,
        InstanceNormalizationTester::InstanceNormalizationAttributes{});
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
  {
    // Invalid: the scale operand is also used as the output.
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(builder_remote);
    OperandId input_id =
        builder.BuildInput("input", {1, 2, 3, 4}, OperandDataType::kFloat32);
    OperandId scale_id =
        builder.BuildInput("scale", {2}, OperandDataType::kFloat32);

    builder.BuildInstanceNormalization(
        input_id, scale_id,
        InstanceNormalizationTester::InstanceNormalizationAttributes{
            .scale_operand_id = scale_id});
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
  {
    // Invalid: the bias operand is also used as the output.
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(builder_remote);
    OperandId input_id =
        builder.BuildInput("input", {1, 2, 3, 4}, OperandDataType::kFloat32);
    OperandId bias_id =
        builder.BuildInput("bias", {2}, OperandDataType::kFloat32);

    builder.BuildInstanceNormalization(
        input_id, bias_id,
        InstanceNormalizationTester::InstanceNormalizationAttributes{
            .bias_operand_id = bias_id});
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
}

struct LayerNormalizationTester {
  OperandInfo input;
  std::optional<OperandInfo> scale;
  std::optional<OperandInfo> bias;
  struct LayerNormalizationAttributes {
    std::optional<OperandId> scale_operand_id;
    std::optional<OperandId> bias_operand_id;
    std::vector<uint32_t> axes;
    float epsilon = 1e-5;
  };
  LayerNormalizationAttributes attributes;
  OperandInfo output;
  bool expected;

  void Test(WebNNGraphImplTest& test) {
    auto context_properties = GetContextPropertiesForTesting();

    // Build the graph with mojo type.
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        test.BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", input.dimensions, input.type);
    OperandId output_operand_id =
        builder.BuildOutput("output", output.dimensions, output.type);

    if (scale.has_value()) {
      attributes.scale_operand_id =
          builder.BuildInput("scale", scale->dimensions, scale->type);
    }
    if (bias.has_value()) {
      attributes.bias_operand_id =
          builder.BuildInput("bias", bias->dimensions, bias->type);
    }
    builder.BuildLayerNormalization(input_operand_id, output_operand_id,
                                    std::move(attributes));
    EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
  }
};

TEST_F(WebNNGraphImplTest, LayerNormalizationTest) {
  // Operand descriptions reused by several of the cases below.
  const OperandInfo float32_scalar{.type = OperandDataType::kFloat32,
                                   .dimensions = {}};
  const OperandInfo float32_1x2{.type = OperandDataType::kFloat32,
                                .dimensions = {1, 2}};
  const OperandInfo float16_4d{.type = OperandDataType::kFloat16,
                               .dimensions = {1, 2, 3, 4}};
  const OperandInfo float16_3x4{.type = OperandDataType::kFloat16,
                                .dimensions = {3, 4}};

  {
    // Valid: default options with a scalar input.
    LayerNormalizationTester{.input = float32_scalar,
                             .attributes = {.axes = {}},
                             .output = float32_scalar,
                             .expected = true}
        .Test(*this);
  }
  {
    // Valid: 4-D input normalized over axes {2, 3} with scale and bias.
    LayerNormalizationTester{.input = float16_4d,
                             .scale = float16_3x4,
                             .bias = float16_3x4,
                             .attributes = {.axes = {2, 3}},
                             .output = float16_4d,
                             .expected = true}
        .Test(*this);
  }
  {
    // Invalid: the input is a scalar but the axes are not empty.
    LayerNormalizationTester{.input = float32_scalar,
                             .attributes = {.axes = {0}},
                             .output = float32_scalar,
                             .expected = false}
        .Test(*this);
  }
  {
    // Invalid: the input data type is int64.
    const OperandInfo int64_1d{.type = OperandDataType::kInt64,
                               .dimensions = {1}};
    LayerNormalizationTester{.input = int64_1d,
                             .attributes = {.axes = {}},
                             .output = int64_1d,
                             .expected = false}
        .Test(*this);
  }
  {
    // Invalid: the axes contain duplicates.
    LayerNormalizationTester{.input = float32_1x2,
                             .attributes = {.axes = {0, 0}},
                             .output = float32_1x2,
                             .expected = false}
        .Test(*this);
  }
  {
    // Invalid: axis 2 is out of range for a rank-2 input.
    LayerNormalizationTester{.input = float32_1x2,
                             .attributes = {.axes = {2}},
                             .output = float32_1x2,
                             .expected = false}
        .Test(*this);
  }
  {
    // Invalid: the bias data type doesn't match the input data type.
    LayerNormalizationTester{
        .input = float16_4d,
        .bias = OperandInfo{.type = OperandDataType::kFloat32,
                            .dimensions = {3, 4}},
        .attributes = {.axes = {2, 3}},
        .output = float16_4d,
        .expected = false}
        .Test(*this);
  }
  {
    // Invalid: the scale shape doesn't match the reduction dimensions.
    LayerNormalizationTester{
        .input = float16_4d,
        .scale = OperandInfo{.type = OperandDataType::kFloat16,
                             .dimensions = {2, 3}},
        .attributes = {.axes = {2, 3}},
        .output = float16_4d,
        .expected = false}
        .Test(*this);
  }
  {
    // Invalid: the output shape is not the expected one.
    LayerNormalizationTester{.input = float16_4d,
                             .attributes = {.axes = {}},
                             .output = {.type = OperandDataType::kFloat16,
                                        .dimensions = {1, 2, 3, 3}},
                             .expected = false}
        .Test(*this);
  }
  {
    // Invalid: the output data type doesn't match the input data type.
    LayerNormalizationTester{.input = float16_4d,
                             .attributes = {.axes = {}},
                             .output = {.type = OperandDataType::kFloat32,
                                        .dimensions = {1, 2, 3, 4}},
                             .expected = false}
        .Test(*this);
  }
  {
    // Invalid: the same operand is used as both input and output.
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(builder_remote);
    OperandId input_id =
        builder.BuildInput("input", {1, 2, 3, 4}, OperandDataType::kFloat32);
    builder.BuildLayerNormalization(
        input_id, input_id,
        LayerNormalizationTester::LayerNormalizationAttributes{});
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
  {
    // Invalid: the scale operand is also used as the output.
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(builder_remote);
    OperandId input_id =
        builder.BuildInput("input", {1, 2, 3, 4}, OperandDataType::kFloat32);
    OperandId scale_id =
        builder.BuildInput("scale", {1, 2, 3, 4}, OperandDataType::kFloat32);

    builder.BuildLayerNormalization(
        input_id, scale_id,
        LayerNormalizationTester::LayerNormalizationAttributes{
            .scale_operand_id = scale_id, .axes = {0, 1, 2, 3}});
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
  {
    // Invalid: the bias operand is also used as the output.
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(builder_remote);
    OperandId input_id =
        builder.BuildInput("input", {1, 2, 3, 4}, OperandDataType::kFloat32);
    OperandId bias_id =
        builder.BuildInput("bias", {1, 2, 3, 4}, OperandDataType::kFloat32);

    builder.BuildLayerNormalization(
        input_id, bias_id,
        LayerNormalizationTester::LayerNormalizationAttributes{
            .bias_operand_id = bias_id, .axes = {0, 1, 2, 3}});
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
}

struct LstmTester {
  struct LstmAttributes {
    std::optional<OperandId> bias_operand_id;
    std::optional<OperandId> recurrent_bias_operand_id;
    std::optional<OperandId> peephole_weight_operand_id;
    std::optional<OperandId> initial_hidden_state_operand_id;
    std::optional<OperandId> initial_cell_state_operand_id;
    bool return_sequence = false;
    mojom::RecurrentNetworkDirection direction =
        mojom::RecurrentNetworkDirection::kForward;
    mojom::LstmWeightLayout layout = mojom::LstmWeightLayout::kIofg;
    std::vector<mojom::RecurrentNetworkActivation> activations = {
        mojom::RecurrentNetworkActivation::kSigmoid,
        mojom::RecurrentNetworkActivation::kTanh,
        mojom::RecurrentNetworkActivation::kTanh};
  };

  OperandInfo input;
  OperandInfo weight;
  OperandInfo recurrent_weight;
  uint32_t steps;
  uint32_t hidden_size;
  std::optional<OperandInfo> bias;
  std::optional<OperandInfo> recurrent_bias;
  std::optional<OperandInfo> peephole_weight;
  std::optional<OperandInfo> initial_hidden_state;
  std::optional<OperandInfo> initial_cell_state;
  LstmAttributes attributes;
  std::vector<OperandInfo> outputs;
  bool expected;

  void Test(WebNNGraphImplTest& test) {
    auto context_properties = GetContextPropertiesForTesting();

    // Build the graph with mojo type.
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        test.BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", input.dimensions, input.type);
    OperandId weight_operand_id =
        builder.BuildInput("weight", weight.dimensions, weight.type);
    OperandId recurrent_weight_operand_id = builder.BuildInput(
        "recurrentWeight", recurrent_weight.dimensions, recurrent_weight.type);

    std::vector<OperandId> output_operand_ids;
    output_operand_ids.reserve(outputs.size());
    for (size_t i = 0; i < outputs.size(); ++i) {
      output_operand_ids.push_back(
          builder.BuildOutput(base::StringPrintf("output%zu", i),
                              outputs[i].dimensions, outputs[i].type));
    }

    if (bias.has_value()) {
      attributes.bias_operand_id =
          builder.BuildInput("bias", bias->dimensions, bias->type);
    }
    if (recurrent_bias.has_value()) {
      attributes.recurrent_bias_operand_id = builder.BuildInput(
          "recurrentBias", recurrent_bias->dimensions, recurrent_bias->type);
    }
    if (peephole_weight.has_value()) {
      attributes.peephole_weight_operand_id = builder.BuildInput(
          "peepholeWeight", peephole_weight->dimensions, peephole_weight->type);
    }
    if (initial_hidden_state.has_value()) {
      attributes.initial_hidden_state_operand_id = builder.BuildInput(
          "initialHiddenState", initial_hidden_state->dimensions,
          initial_hidden_state->type);
    }
    if (initial_cell_state.has_value()) {
      attributes.initial_cell_state_operand_id =
          builder.BuildInput("initialCellState", initial_cell_state->dimensions,
                             initial_cell_state->type);
    }

    builder.BuildLstm(input_operand_id, weight_operand_id,
                      recurrent_weight_operand_id,
                      std::move(output_operand_ids), steps, hidden_size,
                      std::move(attributes));
    EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
  }
};

TEST_F(WebNNGraphImplTest, LstmTest) {
  // Sizes shared by every case below.
  constexpr uint32_t kSteps = 2;
  constexpr uint32_t kInputSize = 3;
  constexpr uint32_t kHiddenSize = 4;
  // 4 * hidden_size, used as the middle dimension of the weights and as the
  // last dimension of the biases.
  constexpr uint32_t kFourHiddenSize = 4 * kHiddenSize;

  {
    // Valid: bidirectional lstm with every optional operand and a returned
    // sequence.
    constexpr uint32_t kBatchSize = 1;
    constexpr uint32_t kDirectionCount = 2;
    const OperandInfo bias_info{
        .type = OperandDataType::kFloat32,
        .dimensions = {kDirectionCount, kFourHiddenSize}};
    const OperandInfo state_info{
        .type = OperandDataType::kFloat32,
        .dimensions = {kDirectionCount, kBatchSize, kHiddenSize}};
    LstmTester{
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {kSteps, kBatchSize, kInputSize}},
        .weight = {.type = OperandDataType::kFloat32,
                   .dimensions = {kDirectionCount, kFourHiddenSize,
                                  kInputSize}},
        .recurrent_weight = {.type = OperandDataType::kFloat32,
                             .dimensions = {kDirectionCount, kFourHiddenSize,
                                            kHiddenSize}},
        .steps = kSteps,
        .hidden_size = kHiddenSize,
        .bias = bias_info,
        .recurrent_bias = bias_info,
        .peephole_weight =
            OperandInfo{.type = OperandDataType::kFloat32,
                        .dimensions = {kDirectionCount, 3 * kHiddenSize}},
        .initial_hidden_state = state_info,
        .initial_cell_state = state_info,
        .attributes = {.return_sequence = true,
                       .direction = mojom::RecurrentNetworkDirection::kBoth},
        .outputs = {state_info, state_info,
                    OperandInfo{.type = OperandDataType::kFloat32,
                                .dimensions = {kSteps, kDirectionCount,
                                               kBatchSize, kHiddenSize}}},
        .expected = true}
        .Test(*this);
  }
  {
    // Invalid: the shape of the weight is incorrect.
    constexpr uint32_t kBatchSize = 1;
    constexpr uint32_t kDirectionCount = 1;
    const OperandInfo state_info{
        .type = OperandDataType::kFloat32,
        .dimensions = {kDirectionCount, kBatchSize, kHiddenSize}};
    LstmTester{
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {kSteps, kBatchSize, kInputSize}},
        .weight = {.type = OperandDataType::kFloat32,
                   .dimensions = {kDirectionCount, kFourHiddenSize, 1000}},
        .recurrent_weight = {.type = OperandDataType::kFloat32,
                             .dimensions = {kDirectionCount, kFourHiddenSize,
                                            kHiddenSize}},
        .steps = kSteps,
        .hidden_size = kHiddenSize,
        .outputs = {state_info, state_info},
        .expected = false}
        .Test(*this);
  }
  {
    // Invalid: the shape of the second output is incorrect.
    constexpr uint32_t kBatchSize = 1;
    constexpr uint32_t kDirectionCount = 1;
    LstmTester{
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {kSteps, kBatchSize, kInputSize}},
        .weight = {.type = OperandDataType::kFloat32,
                   .dimensions = {kDirectionCount, kFourHiddenSize,
                                  kInputSize}},
        .recurrent_weight = {.type = OperandDataType::kFloat32,
                             .dimensions = {kDirectionCount, kFourHiddenSize,
                                            kHiddenSize}},
        .steps = kSteps,
        .hidden_size = kHiddenSize,
        .outputs = {OperandInfo{.type = OperandDataType::kFloat32,
                                .dimensions = {kDirectionCount, kBatchSize,
                                               kHiddenSize}},
                    OperandInfo{.type = OperandDataType::kFloat32,
                                .dimensions = {kDirectionCount, kBatchSize,
                                               1000}}},
        .expected = false}
        .Test(*this);
  }
  {
    // Invalid: the recurrent weight operand is also listed as one of the
    // outputs.
    constexpr uint32_t kBatchSize = 16;
    constexpr uint32_t kDirectionCount = 1;

    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(builder_remote);
    OperandId input_id = builder.BuildInput(
        "input", {kSteps, kBatchSize, kInputSize}, OperandDataType::kFloat32);
    OperandId weight_id = builder.BuildInput(
        "weight", {kDirectionCount, kFourHiddenSize, kInputSize},
        OperandDataType::kFloat32);
    OperandId recurrent_weight_id = builder.BuildInput(
        "recurrentWeight", {kDirectionCount, kFourHiddenSize, kHiddenSize},
        OperandDataType::kFloat32);

    OperandId output_id = builder.BuildOutput(
        "output", {kDirectionCount, kBatchSize, kHiddenSize},
        OperandDataType::kFloat32);
    builder.BuildLstm(input_id, weight_id, recurrent_weight_id,
                      {output_id, recurrent_weight_id}, kSteps, kHiddenSize,
                      LstmTester::LstmAttributes{});
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
  {
    // Invalid: the initial cell state operand is also listed as one of the
    // outputs.
    constexpr uint32_t kBatchSize = 1;
    constexpr uint32_t kDirectionCount = 1;

    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(builder_remote);
    OperandId input_id = builder.BuildInput(
        "input", {kSteps, kBatchSize, kInputSize}, OperandDataType::kFloat32);
    OperandId weight_id = builder.BuildInput(
        "weight", {kDirectionCount, kFourHiddenSize, kInputSize},
        OperandDataType::kFloat32);
    OperandId recurrent_weight_id = builder.BuildInput(
        "recurrentWeight", {kDirectionCount, kFourHiddenSize, kHiddenSize},
        OperandDataType::kFloat32);

    OperandId initial_cell_state_id = builder.BuildInput(
        "initialCellState", {kDirectionCount, kBatchSize, kHiddenSize},
        OperandDataType::kFloat32);
    OperandId output_id = builder.BuildOutput(
        "output", {kDirectionCount, kBatchSize, kHiddenSize},
        OperandDataType::kFloat32);

    builder.BuildLstm(
        input_id, weight_id, recurrent_weight_id,
        {initial_cell_state_id, output_id}, kSteps, kHiddenSize,
        LstmTester::LstmAttributes{.initial_cell_state_operand_id =
                                       initial_cell_state_id});
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
}

struct LstmCellTester {
  struct LstmCellAttributes {
    std::optional<OperandId> bias_operand_id;
    std::optional<OperandId> recurrent_bias_operand_id;
    std::optional<OperandId> peephole_weight_operand_id;
    mojom::LstmWeightLayout layout = mojom::LstmWeightLayout::kIofg;
    std::vector<mojom::RecurrentNetworkActivation> activations = {
        mojom::RecurrentNetworkActivation::kSigmoid,
        mojom::RecurrentNetworkActivation::kTanh,
        mojom::RecurrentNetworkActivation::kTanh};
  };

  OperandInfo input;
  OperandInfo weight;
  OperandInfo recurrent_weight;
  OperandInfo hidden_state;
  OperandInfo cell_state;
  uint32_t hidden_size;
  std::optional<OperandInfo> bias;
  std::optional<OperandInfo> recurrent_bias;
  std::optional<OperandInfo> peephole_weight;
  LstmCellAttributes attributes;
  std::vector<OperandInfo> outputs;
  bool expected;

  void Test(WebNNGraphImplTest& test) {
    auto context_properties = GetContextPropertiesForTesting();

    // Build the graph with mojo type.
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        test.BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", input.dimensions, input.type);
    OperandId weight_operand_id =
        builder.BuildInput("weight", weight.dimensions, weight.type);
    OperandId recurrent_weight_operand_id = builder.BuildInput(
        "recurrentWeight", recurrent_weight.dimensions, recurrent_weight.type);
    OperandId hidden_state_operand_id = builder.BuildInput(
        "hiddenState", hidden_state.dimensions, hidden_state.type);
    OperandId cell_state_operand_id =
        builder.BuildInput("cellState", cell_state.dimensions, cell_state.type);

    std::vector<OperandId> output_operand_ids;
    output_operand_ids.reserve(outputs.size());
    for (size_t i = 0; i < outputs.size(); ++i) {
      output_operand_ids.push_back(
          builder.BuildOutput(base::StringPrintf("output%zu", i),
                              outputs[i].dimensions, outputs[i].type));
    }

    if (bias.has_value()) {
      attributes.bias_operand_id =
          builder.BuildInput("bias", bias->dimensions, bias->type);
    }
    if (recurrent_bias.has_value()) {
      attributes.recurrent_bias_operand_id = builder.BuildInput(
          "recurrentBias", recurrent_bias->dimensions, recurrent_bias->type);
    }
    if (peephole_weight.has_value()) {
      attributes.peephole_weight_operand_id = builder.BuildInput(
          "peepholeWeight", peephole_weight->dimensions, peephole_weight->type);
    }

    builder.BuildLstmCell(input_operand_id, weight_operand_id,
                          recurrent_weight_operand_id, hidden_state_operand_id,
                          cell_state_operand_id, std::move(output_operand_ids),
                          hidden_size, std::move(attributes));
    EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
  }
};

// Checks graph validation of the lstmCell operation: one fully populated valid
// case, followed by cases that each break a single operand (data type, rank or
// shape) and a final case that reuses the cell state operand as an output.
TEST_F(WebNNGraphImplTest, LstmCellTest) {
  uint32_t batch_size = 15;
  uint32_t input_size = 12;
  uint32_t hidden_size = 20;

  // Operand descriptions that are consistent with the sizes above. Each
  // invalid case below replaces exactly one of them.
  OperandInfo valid_input = {.type = OperandDataType::kFloat32,
                             .dimensions = {batch_size, input_size}};
  OperandInfo valid_weight = {.type = OperandDataType::kFloat32,
                              .dimensions = {4 * hidden_size, input_size}};
  OperandInfo valid_recurrent_weight = {
      .type = OperandDataType::kFloat32,
      .dimensions = {4 * hidden_size, hidden_size}};
  OperandInfo valid_hidden_state = {.type = OperandDataType::kFloat32,
                                    .dimensions = {batch_size, hidden_size}};
  OperandInfo valid_cell_state = {.type = OperandDataType::kFloat32,
                                  .dimensions = {batch_size, hidden_size}};
  OperandInfo valid_bias = {.type = OperandDataType::kFloat32,
                            .dimensions = {4 * hidden_size}};
  OperandInfo valid_recurrent_bias = {.type = OperandDataType::kFloat32,
                                      .dimensions = {4 * hidden_size}};
  OperandInfo valid_peephole_weight = {.type = OperandDataType::kFloat32,
                                       .dimensions = {3 * hidden_size}};
  std::vector<OperandInfo> valid_outputs = {
      {.type = OperandDataType::kFloat32,
       .dimensions = {batch_size, hidden_size}},
      {.type = OperandDataType::kFloat32,
       .dimensions = {batch_size, hidden_size}}};
  {
    // Test a valid lstmCell operator.
    LstmCellTester{.input = valid_input,
                   .weight = valid_weight,
                   .recurrent_weight = valid_recurrent_weight,
                   .hidden_state = valid_hidden_state,
                   .cell_state = valid_cell_state,
                   .hidden_size = hidden_size,
                   .bias = valid_bias,
                   .recurrent_bias = valid_recurrent_bias,
                   .peephole_weight = valid_peephole_weight,
                   .outputs = valid_outputs,
                   .expected = true}
        .Test(*this);
  }
  {
    // Test the invalid graph when the data type of the input is not one of the
    // floating point types.
    LstmCellTester{.input = {.type = OperandDataType::kUint32,
                             .dimensions = {batch_size, input_size}},
                   .weight = valid_weight,
                   .recurrent_weight = valid_recurrent_weight,
                   .hidden_state = valid_hidden_state,
                   .cell_state = valid_cell_state,
                   .hidden_size = hidden_size,
                   .outputs = valid_outputs,
                   .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the data type of the weight is incorrect.
    LstmCellTester{.input = valid_input,
                   .weight = {.type = OperandDataType::kFloat16,
                              .dimensions = {4 * hidden_size, input_size}},
                   .recurrent_weight = valid_recurrent_weight,
                   .hidden_state = valid_hidden_state,
                   .cell_state = valid_cell_state,
                   .hidden_size = hidden_size,
                   .outputs = valid_outputs,
                   .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the rank of the recurrent weight is
    // incorrect.
    LstmCellTester{.input = valid_input,
                   .weight = valid_weight,
                   .recurrent_weight = {.type = OperandDataType::kFloat32,
                                        .dimensions = {4 * hidden_size}},
                   .hidden_state = valid_hidden_state,
                   .cell_state = valid_cell_state,
                   .hidden_size = hidden_size,
                   .outputs = valid_outputs,
                   .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the shape of the hidden state is incorrect.
    LstmCellTester{.input = valid_input,
                   .weight = valid_weight,
                   .recurrent_weight = valid_recurrent_weight,
                   .hidden_state = {.type = OperandDataType::kFloat32,
                                    .dimensions = {batch_size, 1000}},
                   .cell_state = valid_cell_state,
                   .hidden_size = hidden_size,
                   .outputs = valid_outputs,
                   .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the rank of the cell state is incorrect.
    LstmCellTester{
        .input = valid_input,
        .weight = valid_weight,
        .recurrent_weight = valid_recurrent_weight,
        .hidden_state = valid_hidden_state,
        .cell_state = {.type = OperandDataType::kFloat32,
                       .dimensions = {batch_size, hidden_size, 1000}},
        .hidden_size = hidden_size,
        .outputs = valid_outputs,
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the data type of the bias is incorrect.
    LstmCellTester{.input = valid_input,
                   .weight = valid_weight,
                   .recurrent_weight = valid_recurrent_weight,
                   .hidden_state = valid_hidden_state,
                   .cell_state = valid_cell_state,
                   .hidden_size = hidden_size,
                   .bias = OperandInfo{.type = OperandDataType::kUint32,
                                       .dimensions = {4 * hidden_size}},
                   .outputs = valid_outputs,
                   .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the shape of the recurrent bias is incorrect.
    LstmCellTester{
        .input = valid_input,
        .weight = valid_weight,
        .recurrent_weight = valid_recurrent_weight,
        .hidden_state = valid_hidden_state,
        .cell_state = valid_cell_state,
        .hidden_size = hidden_size,
        .recurrent_bias = OperandInfo{.type = OperandDataType::kFloat32,
                                      .dimensions = {1000}},
        .outputs = valid_outputs,
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the data type of the peephole weight is
    // incorrect.
    LstmCellTester{
        .input = valid_input,
        .weight = valid_weight,
        .recurrent_weight = valid_recurrent_weight,
        .hidden_state = valid_hidden_state,
        .cell_state = valid_cell_state,
        .hidden_size = hidden_size,
        .peephole_weight = OperandInfo{.type = OperandDataType::kInt64,
                                       .dimensions = {3 * hidden_size}},
        .outputs = valid_outputs,
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the output data type is incorrect.
    LstmCellTester{.input = valid_input,
                   .weight = valid_weight,
                   .recurrent_weight = valid_recurrent_weight,
                   .hidden_state = valid_hidden_state,
                   .cell_state = valid_cell_state,
                   .hidden_size = hidden_size,
                   .outputs = {{.type = OperandDataType::kInt8,
                                .dimensions = {batch_size, hidden_size}},
                               {.type = OperandDataType::kInt8,
                                .dimensions = {batch_size, hidden_size}}},
                   .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the cell state has the same id as
    // one of the outputs. The graph is assembled by hand because
    // LstmCellTester always creates distinct output operands.
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id = builder.BuildInput(
        "input", {batch_size, input_size}, OperandDataType::kFloat32);
    OperandId weight_operand_id = builder.BuildInput(
        "weight", {4 * hidden_size, input_size}, OperandDataType::kFloat32);
    OperandId recurrent_weight_operand_id =
        builder.BuildInput("recurrentWeight", {4 * hidden_size, hidden_size},
                           OperandDataType::kFloat32);
    OperandId hidden_state_operand_id = builder.BuildInput(
        "hiddenState", {batch_size, hidden_size}, OperandDataType::kFloat32);
    OperandId cell_state_operand_id = builder.BuildInput(
        "cellState", {batch_size, hidden_size}, OperandDataType::kFloat32);
    OperandId output_operand_id = builder.BuildOutput(
        "output", {batch_size, hidden_size}, OperandDataType::kFloat32);

    // Use the lstmCell attribute type here, matching what LstmCellTester
    // passes to BuildLstmCell(), rather than the attributes of the lstm
    // operation.
    builder.BuildLstmCell(input_operand_id, weight_operand_id,
                          recurrent_weight_operand_id, hidden_state_operand_id,
                          cell_state_operand_id,
                          {cell_state_operand_id, output_operand_id},
                          hidden_size, LstmCellTester::LstmCellAttributes{});
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
}

struct MatmulTester {
  OperandInfo a;
  OperandInfo b;
  OperandInfo output;
  bool expected;

  void Test(WebNNGraphImplTest& test) {
    auto context_properties = GetContextPropertiesForTesting();

    // Build the graph with mojo type.
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        test.BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId a_operand_id = builder.BuildInput("a", a.dimensions, a.type);
    OperandId b_operand_id = builder.BuildInput("b", b.dimensions, b.type);
    OperandId output_operand_id =
        builder.BuildOutput("output", output.dimensions, output.type);

    builder.BuildMatmul(a_operand_id, b_operand_id, output_operand_id);
    EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
  }
};

// Checks graph validation of matmul. The cases expected to be valid come
// first, followed by cases with a bad rank, shape, data type or operand id.
TEST_F(WebNNGraphImplTest, MatmulTest) {
  // Operand descriptions that are shared by several of the cases below.
  const OperandInfo float32_2x3 = {.type = OperandDataType::kFloat32,
                                   .dimensions = {2, 3}};
  const OperandInfo float32_3x4 = {.type = OperandDataType::kFloat32,
                                   .dimensions = {3, 4}};
  const OperandInfo float32_2x4 = {.type = OperandDataType::kFloat32,
                                   .dimensions = {2, 4}};

  {
    // A 2-D operand multiplied by a 2-D operand.
    MatmulTester{.a = float32_2x3,
                 .b = float32_3x4,
                 .output = float32_2x4,
                 .expected = true}
        .Test(*this);
  }
  {
    // A 2-D operand multiplied by a 4-D operand.
    MatmulTester{
        .a = float32_2x3,
        .b = {.type = OperandDataType::kFloat32, .dimensions = {2, 3, 3, 4}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {2, 3, 2, 4}},
        .expected = true}
        .Test(*this);
  }
  {
    // A 3-D operand multiplied by a 4-D operand, relying on broadcasting.
    MatmulTester{
        .a = {.type = OperandDataType::kFloat32, .dimensions = {2, 2, 3}},
        .b = {.type = OperandDataType::kFloat32, .dimensions = {3, 1, 3, 4}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {3, 2, 2, 4}},
        .expected = true}
        .Test(*this);
  }
  {
    // Invalid: the rank of one input is smaller than 2.
    MatmulTester{
        .a = {.type = OperandDataType::kFloat32, .dimensions = {3}},
        .b = float32_3x4,
        .output = float32_3x4,
        .expected = false}
        .Test(*this);
  }
  {
    // Invalid: the number of columns of the first matrix does not match the
    // number of rows of the second matrix.
    MatmulTester{
        .a = {.type = OperandDataType::kFloat32, .dimensions = {3, 2}},
        .b = float32_3x4,
        .output = float32_3x4,
        .expected = false}
        .Test(*this);
  }
  {
    // Invalid: the input shapes are not broadcastable.
    MatmulTester{
        .a = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 3}},
        .b = {.type = OperandDataType::kFloat32, .dimensions = {2, 3, 4}},
        .output = float32_3x4,
        .expected = false}
        .Test(*this);
  }
  {
    // Invalid: the inputs are not floating point.
    MatmulTester{
        .a = {.type = OperandDataType::kUint8, .dimensions = {2, 3}},
        .b = {.type = OperandDataType::kUint8, .dimensions = {3, 4}},
        .output = {.type = OperandDataType::kUint8, .dimensions = {2, 4}},
        .expected = false}
        .Test(*this);
  }
  {
    // Invalid: the declared output shape is not the expected one.
    MatmulTester{.a = float32_2x3,
                 .b = float32_3x4,
                 .output = float32_3x4,
                 .expected = false}
        .Test(*this);
  }
  {
    // Invalid: the two inputs have different data types.
    MatmulTester{
        .a = float32_2x3,
        .b = {.type = OperandDataType::kInt32, .dimensions = {3, 4}},
        .output = float32_2x4,
        .expected = false}
        .Test(*this);
  }
  {
    // Invalid: the output data type differs from the input data type.
    MatmulTester{
        .a = float32_2x3,
        .b = float32_3x4,
        .output = {.type = OperandDataType::kInt32, .dimensions = {2, 4}},
        .expected = false}
        .Test(*this);
  }
  {
    // Invalid: the output operand is the same operand as one of the inputs.
    // Built by hand because MatmulTester always creates a separate output.
    auto properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder graph_builder(builder_remote);
    const OperandId a_id =
        graph_builder.BuildInput("a", {2, 3}, OperandDataType::kFloat32);
    const OperandId b_id =
        graph_builder.BuildInput("b", {3, 4}, OperandDataType::kFloat32);
    graph_builder.BuildMatmul(a_id, b_id, a_id);
    EXPECT_FALSE(graph_builder.IsValidGraphForTesting(properties));
  }
}

struct PadTester {
  OperandInfo input;
  std::vector<uint32_t> beginning_padding;
  std::vector<uint32_t> ending_padding;
  mojom::PaddingMode::Tag mode = mojom::PaddingMode::Tag::kConstant;
  float value = 0;
  OperandInfo output;
  bool expected;

  void Test(WebNNGraphImplTest& test) {
    auto context_properties = GetContextPropertiesForTesting();

    // Build the graph with mojo type.
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        test.BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", input.dimensions, input.type);
    OperandId output_operand_id =
        builder.BuildOutput("output", output.dimensions, output.type);
    builder.BuildPad(input_operand_id, output_operand_id, beginning_padding,
                     ending_padding, mode, value);
    EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
  }
};

// Checks graph validation of pad for different modes and padding values, for
// padding lists whose length does not match the input rank, and for an output
// operand that aliases the input.
TEST_F(WebNNGraphImplTest, PadTest) {
  // Every tester case below pads a float32 [2, 3] input and declares a
  // float32 [4, 7] output.
  const OperandInfo input_2x3 = {.type = OperandDataType::kFloat32,
                                 .dimensions = {2, 3}};
  const OperandInfo output_4x7 = {.type = OperandDataType::kFloat32,
                                  .dimensions = {4, 7}};

  {
    // Default options with beginningPadding = {1, 2} and
    // endingPadding = {1, 2}.
    PadTester{.input = input_2x3,
              .beginning_padding = {1, 2},
              .ending_padding = {1, 2},
              .output = output_4x7,
              .expected = true}
        .Test(*this);
  }
  {
    // mode = "edge" with beginningPadding = {1, 2} and
    // endingPadding = {1, 2}.
    PadTester{.input = input_2x3,
              .beginning_padding = {1, 2},
              .ending_padding = {1, 2},
              .mode = mojom::PaddingMode::Tag::kEdge,
              .output = output_4x7,
              .expected = true}
        .Test(*this);
  }
  {
    // value = 1 with beginningPadding = {1, 2} and endingPadding = {1, 2}.
    PadTester{.input = input_2x3,
              .beginning_padding = {1, 2},
              .ending_padding = {1, 2},
              .value = 1,
              .output = output_4x7,
              .expected = true}
        .Test(*this);
  }
  {
    // value = NAN with beginningPadding = {1, 2} and endingPadding = {1, 2}.
    PadTester{.input = input_2x3,
              .beginning_padding = {1, 2},
              .ending_padding = {1, 2},
              .value = NAN,
              .output = output_4x7,
              .expected = true}
        .Test(*this);
  }
  {
    // Invalid: the length of beginningPadding is not equal to the input rank.
    PadTester{.input = input_2x3,
              .beginning_padding = {1},
              .ending_padding = {1, 2},
              .output = output_4x7,
              .expected = false}
        .Test(*this);
  }
  {
    // Invalid: the length of endingPadding is not equal to the input rank.
    PadTester{.input = input_2x3,
              .beginning_padding = {1, 0},
              .ending_padding = {1, 2, 0},
              .output = output_4x7,
              .expected = false}
        .Test(*this);
  }
  {
    // Invalid: the output operand is the same operand as the input. Built by
    // hand because PadTester always creates a separate output operand.
    auto properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder graph_builder(builder_remote);
    const OperandId input_id =
        graph_builder.BuildInput("input", {2, 3}, OperandDataType::kFloat32);
    graph_builder.BuildPad(input_id, input_id, {1, 1}, {1, 1},
                           mojom::PaddingMode::Tag::kConstant, 0);
    EXPECT_FALSE(graph_builder.IsValidGraphForTesting(properties));
  }
}

// Describes one pool2d validation case. `Test()` builds a graph with a single
// pool2d operation from the fields below and expects graph validation to
// return `expected`.
struct Pool2dTester {
  OperandInfo input;
  // Options passed to `GraphInfoBuilder::BuildPool2d()`. `window_dimensions`
  // has no default and is set by every case.
  struct Pool2dAttributes {
    std::vector<uint32_t> window_dimensions;
    std::vector<uint32_t> padding = {0, 0, 0, 0};
    std::vector<uint32_t> strides = {1, 1};
    std::vector<uint32_t> dilations = {1, 1};
  };
  Pool2dAttributes attributes;
  // Copied into the context properties that the graph is validated against.
  InputOperandLayout input_operand_layout = InputOperandLayout::kNchw;
  OperandInfo output;
  // The value `IsValidGraphForTesting()` is expected to return.
  bool expected;

  // Runs the case once for each pooling kind, reusing the same `attributes`
  // for all three runs.
  void Test(WebNNGraphImplTest& test) {
    Test(test, mojom::Pool2d::Kind::kAveragePool2d);
    Test(test, mojom::Pool2d::Kind::kL2Pool2d);
    Test(test, mojom::Pool2d::Kind::kMaxPool2d);
  }

  // Runs the case for a single pooling `kind`.
  void Test(WebNNGraphImplTest& test, mojom::Pool2d::Kind kind) {
    auto context_properties = GetContextPropertiesForTesting();
    context_properties.input_operand_layout = input_operand_layout;

    // Build the graph with mojo type.
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        test.BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", input.dimensions, input.type);
    OperandId output_operand_id =
        builder.BuildOutput("output", output.dimensions, output.type);
    // `attributes` is passed as an lvalue on purpose: the overload above calls
    // this method three times on the same tester, so the member must not be
    // handed over with std::move() and left in a moved-from state for the
    // later runs.
    builder.BuildPool2d(kind, input_operand_id, output_operand_id, attributes);
    EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
  }
};

// Checks graph validation of pool2d. Unless a pooling kind is passed to
// `Test()` explicitly, every case is run for averagePool2d, l2Pool2d and
// maxPool2d.
TEST_F(WebNNGraphImplTest, Pool2dTest) {
  // The float32 [1, 3, 4, 4] operand that most of the cases below use as
  // input, and some of them as output too.
  const OperandInfo float32_1x3x4x4 = {.type = OperandDataType::kFloat32,
                                       .dimensions = {1, 3, 4, 4}};

  {
    // A 1x1 window with every other attribute at its default value.
    Pool2dTester{.input = float32_1x3x4x4,
                 .attributes = {.window_dimensions = {1, 1}, .strides = {1, 1}},
                 .output = float32_1x3x4x4,
                 .expected = true}
        .Test(*this);
  }
  {
    // A 2x2 window moved with a stride of 2.
    Pool2dTester{.input = {.type = OperandDataType::kFloat32,
                           .dimensions = {1, 3, 5, 5}},
                 .attributes = {.window_dimensions = {2, 2}, .strides = {2, 2}},
                 .output = {.type = OperandDataType::kFloat32,
                            .dimensions = {1, 3, 3, 3}},
                 .expected = true}
        .Test(*this);
  }
  {
    // strides=2 and padding=1 with the output size rounded down.
    Pool2dTester{.input = {.type = OperandDataType::kFloat16,
                           .dimensions = {1, 3, 7, 7}},
                 .attributes = {.window_dimensions = {4, 4},
                                .padding = {1, 1, 1, 1},
                                .strides = {2, 2}},
                 .output = {.type = OperandDataType::kFloat16,
                            .dimensions = {1, 3, 3, 3}},
                 .expected = true}
        .Test(*this);
  }
  {
    // strides=2 and padding=1 with the output size rounded up.
    Pool2dTester{.input = {.type = OperandDataType::kFloat32,
                           .dimensions = {1, 3, 7, 7}},
                 .attributes = {.window_dimensions = {4, 4},
                                .padding = {1, 1, 1, 1},
                                .strides = {2, 2}},
                 .output = {.type = OperandDataType::kFloat32,
                            .dimensions = {1, 3, 4, 4}},
                 .expected = true}
        .Test(*this);
  }
  {
    // layout="nhwc".
    Pool2dTester{.input = {.type = OperandDataType::kFloat16,
                           .dimensions = {1, 5, 5, 2}},
                 .attributes = {.window_dimensions = {3, 3}, .strides = {1, 1}},
                 .input_operand_layout = InputOperandLayout::kNhwc,
                 .output = {.type = OperandDataType::kFloat16,
                            .dimensions = {1, 3, 3, 2}},
                 .expected = true}
        .Test(*this);
  }
  {
    // Invalid: the input is not a 4-D tensor.
    Pool2dTester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 5, 5}},
        .attributes = {.window_dimensions = {5, 5},
                       .padding = {2, 2, 2, 2},
                       .strides = {1, 1}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 5, 5}},
        .expected = false}
        .Test(*this);
  }
  {
    // Invalid: the window dimensions are 0.
    Pool2dTester{.input = float32_1x3x4x4,
                 .attributes = {.window_dimensions = {0, 0}, .strides = {1, 1}},
                 .output = float32_1x3x4x4,
                 .expected = false}
        .Test(*this);
  }
  {
    // Invalid: the strides are 0.
    Pool2dTester{.input = float32_1x3x4x4,
                 .attributes = {.window_dimensions = {1, 1}, .strides = {0, 0}},
                 .output = float32_1x3x4x4,
                 .expected = false}
        .Test(*this);
  }
  {
    // Invalid: the dilations are 0.
    Pool2dTester{.input = float32_1x3x4x4,
                 .attributes = {.window_dimensions = {1, 1},
                                .strides = {1, 1},
                                .dilations = {0, 0}},
                 .output = float32_1x3x4x4,
                 .expected = false}
        .Test(*this);
  }
  {
    // Invalid: the declared output shape is not the expected one.
    Pool2dTester{.input = float32_1x3x4x4,
                 .attributes = {.window_dimensions = {4, 4}, .strides = {1, 1}},
                 .output = {.type = OperandDataType::kFloat32,
                            .dimensions = {1, 2, 1, 1}},
                 .expected = false}
        .Test(*this);
  }
  {
    // Invalid: the output data type does not match the input data type.
    Pool2dTester{.input = float32_1x3x4x4,
                 .attributes = {.window_dimensions = {4, 4}, .strides = {1, 1}},
                 .output = {.type = OperandDataType::kFloat16,
                            .dimensions = {1, 3, 1, 1}},
                 .expected = false}
        .Test(*this);
  }
  {
    // Invalid: the input data type is not floating point for averagePool2d.
    Pool2dTester{
        .input = {.type = OperandDataType::kInt32, .dimensions = {1, 3, 4, 4}},
        .attributes = {.window_dimensions = {4, 4}, .strides = {1, 1}},
        .output = {.type = OperandDataType::kInt32, .dimensions = {1, 3, 1, 1}},
        .expected = false}
        .Test(*this, mojom::Pool2d::Kind::kAveragePool2d);
  }
  {
    // Invalid: the input data type is not floating point for l2Pool2d.
    Pool2dTester{
        .input = {.type = OperandDataType::kInt8, .dimensions = {1, 3, 4, 4}},
        .attributes = {.window_dimensions = {4, 4}, .strides = {1, 1}},
        .output = {.type = OperandDataType::kInt8, .dimensions = {1, 3, 1, 1}},
        .expected = false}
        .Test(*this, mojom::Pool2d::Kind::kL2Pool2d);
  }
}

struct PreluTester {
  OperandInfo input;
  OperandInfo slope;
  OperandInfo output;
  bool expected;

  void Test(WebNNGraphImplTest& test) {
    auto context_properties = GetContextPropertiesForTesting();

    // Build the graph with mojo type.
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        test.BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", input.dimensions, input.type);
    OperandId slope_operand_id =
        builder.BuildInput("slope", slope.dimensions, slope.type);
    OperandId output_operand_id =
        builder.BuildOutput("output", output.dimensions, output.type);
    builder.BuildPrelu(input_operand_id, slope_operand_id, output_operand_id);
    EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
  }
};

// Checks graph validation of prelu for matching and broadcastable slopes, for
// several operand data types, for mismatching output descriptions and for an
// output operand that aliases one of the inputs.
TEST_F(WebNNGraphImplTest, PreluTest) {
  // [3, 2, 5] operand descriptions that are shared by the cases below.
  const OperandInfo float32_3x2x5 = {.type = OperandDataType::kFloat32,
                                     .dimensions = {3, 2, 5}};
  const OperandInfo float16_3x2x5 = {.type = OperandDataType::kFloat16,
                                     .dimensions = {3, 2, 5}};
  const OperandInfo int32_3x2x5 = {.type = OperandDataType::kInt32,
                                   .dimensions = {3, 2, 5}};
  const OperandInfo int8_3x2x5 = {.type = OperandDataType::kInt8,
                                  .dimensions = {3, 2, 5}};
  const OperandInfo uint32_3x2x5 = {.type = OperandDataType::kUint32,
                                    .dimensions = {3, 2, 5}};

  {
    // The input and the slope have the same shape.
    PreluTester{.input = float32_3x2x5,
                .slope = float32_3x2x5,
                .output = float32_3x2x5,
                .expected = true}
        .Test(*this);
  }
  {
    // The slope is broadcastable to the input shape.
    PreluTester{
        .input = float32_3x2x5,
        .slope = {.type = OperandDataType::kFloat32, .dimensions = {3, 1, 5}},
        .output = float32_3x2x5,
        .expected = true}
        .Test(*this);
  }
  {
    // Invalid: the slope shape is not valid for the input shape.
    PreluTester{
        .input = float32_3x2x5,
        .slope = {.type = OperandDataType::kFloat32, .dimensions = {3, 5}},
        .output = float32_3x2x5,
        .expected = false}
        .Test(*this);
  }
  {
    // The input data type and the slope data type are int32.
    PreluTester{.input = int32_3x2x5,
                .slope = int32_3x2x5,
                .output = int32_3x2x5,
                .expected = true}
        .Test(*this);
  }
  {
    // The input data type and the slope data type are float16.
    PreluTester{.input = float16_3x2x5,
                .slope = float16_3x2x5,
                .output = float16_3x2x5,
                .expected = true}
        .Test(*this);
  }
  {
    // The input data type and the slope data type are int8.
    PreluTester{.input = int8_3x2x5,
                .slope = int8_3x2x5,
                .output = int8_3x2x5,
                .expected = true}
        .Test(*this);
  }
  {
    // Invalid: the slope data type does not match the input data type.
    PreluTester{.input = float16_3x2x5,
                .slope = float32_3x2x5,
                .output = float16_3x2x5,
                .expected = false}
        .Test(*this);
  }
  {
    // Invalid: the input data type and the slope data type are uint32.
    PreluTester{.input = uint32_3x2x5,
                .slope = uint32_3x2x5,
                .output = uint32_3x2x5,
                .expected = false}
        .Test(*this);
  }
  {
    // Invalid: the output data type does not match the input data type.
    PreluTester{.input = float16_3x2x5,
                .slope = float16_3x2x5,
                .output = float32_3x2x5,
                .expected = false}
        .Test(*this);
  }
  {
    // Invalid: the declared output shape is not the expected one.
    PreluTester{
        .input = float16_3x2x5,
        .slope = float16_3x2x5,
        .output = {.type = OperandDataType::kFloat16, .dimensions = {3, 2, 6}},
        .expected = false}
        .Test(*this);
  }
  {
    // Invalid: the output operand is the same operand as the input. Built by
    // hand because PreluTester always creates a separate output operand.
    auto properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder graph_builder(builder_remote);
    const OperandId input_id =
        graph_builder.BuildInput("input", {2, 3}, OperandDataType::kFloat32);
    const OperandId slope_id =
        graph_builder.BuildInput("slope", {2, 3}, OperandDataType::kFloat32);
    graph_builder.BuildPrelu(input_id, slope_id, input_id);
    EXPECT_FALSE(graph_builder.IsValidGraphForTesting(properties));
  }
  {
    // Invalid: the output operand is the same operand as the slope.
    auto properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder graph_builder(builder_remote);
    const OperandId input_id =
        graph_builder.BuildInput("input", {2, 3}, OperandDataType::kFloat32);
    const OperandId output_id =
        graph_builder.BuildOutput("output", {2, 3}, OperandDataType::kFloat32);
    graph_builder.BuildPrelu(input_id, output_id, output_id);
    EXPECT_FALSE(graph_builder.IsValidGraphForTesting(properties));
  }
}

struct QuantizeLinearTester {
  OperandInfo input;
  OperandInfo scale;
  OperandInfo zero_point;
  OperandInfo output;
  bool expected;

  void Test(WebNNGraphImplTest& test) {
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        test.BindNewGraphBuilderRemote();

    // Build the graph with mojo type.
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", input.dimensions, input.type);
    OperandId scale_operand_id =
        builder.BuildInput("scale", scale.dimensions, scale.type);
    OperandId zero_point_operand_id = builder.BuildInput(
        "zero_point", zero_point.dimensions, zero_point.type);
    OperandId output_operand_id =
        builder.BuildOutput("output", output.dimensions, output.type);
    builder.BuildQuantizeLinear(input_operand_id, scale_operand_id,
                                zero_point_operand_id, output_operand_id);
    EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
  }
};

// Validates quantizeLinear graphs.
//
// Each invalid case is built so that the property named in its comment is the
// only thing wrong with the graph. Unless the scale/zero_point shape itself is
// under test, the cases use the {1, 1, 5} shape that the valid case below
// shows is accepted for a {3, 2, 5} input.
TEST_F(WebNNGraphImplTest, QuantizeLinearTest) {
  {
    // Test quantizeLinear operator when the input, the scale and the
    // zero_point have the same shape.
    QuantizeLinearTester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}},
        .scale = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}},
        .zero_point = {.type = OperandDataType::kInt8, .dimensions = {3, 2, 5}},
        .output = {.type = OperandDataType::kInt8, .dimensions = {3, 2, 5}},
        .expected = true}
        .Test(*this);
  }
  {
    // Test quantizeLinear operator with a scale and zero_point of shape
    // {1, 1, 5}, i.e. of size 1 on every axis but the last.
    QuantizeLinearTester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}},
        .scale = {.type = OperandDataType::kFloat32, .dimensions = {1, 1, 5}},
        .zero_point = {.type = OperandDataType::kInt8, .dimensions = {1, 1, 5}},
        .output = {.type = OperandDataType::kInt8, .dimensions = {3, 2, 5}},
        .expected = true}
        .Test(*this);
  }
  {
    // Test quantizeLinear operator with a scale and zero_point of shape
    // {3, 1, 1}, i.e. of size 1 on every axis but the first.
    QuantizeLinearTester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}},
        .scale = {.type = OperandDataType::kFloat32, .dimensions = {3, 1, 1}},
        .zero_point = {.type = OperandDataType::kInt8, .dimensions = {3, 1, 1}},
        .output = {.type = OperandDataType::kInt8, .dimensions = {3, 2, 5}},
        .expected = true}
        .Test(*this);
  }
  {
    // Test the invalid graph whose scale rank is not equal to input rank.
    QuantizeLinearTester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}},
        .scale = {.type = OperandDataType::kFloat32, .dimensions = {5}},
        .zero_point = {.type = OperandDataType::kInt8, .dimensions = {5}},
        .output = {.type = OperandDataType::kInt8, .dimensions = {3, 2, 5}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph with a 2-D scale and zero_point for a 3-D input.
    QuantizeLinearTester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}},
        .scale = {.type = OperandDataType::kFloat32, .dimensions = {3, 5}},
        .zero_point = {.type = OperandDataType::kInt8, .dimensions = {3, 5}},
        .output = {.type = OperandDataType::kInt8, .dimensions = {3, 2, 5}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph with different scale_shape and zero_point_shape.
    // Both shapes have the input's rank, so the only inconsistency is that
    // they differ from each other.
    QuantizeLinearTester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}},
        .scale = {.type = OperandDataType::kFloat32, .dimensions = {1, 1, 5}},
        .zero_point = {.type = OperandDataType::kInt8, .dimensions = {1, 2, 1}},
        .output = {.type = OperandDataType::kInt8, .dimensions = {3, 2, 5}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the scale datatype doesn't match the
    // input's datatype. The scale shape is one that is otherwise accepted.
    QuantizeLinearTester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}},
        .scale = {.type = OperandDataType::kFloat16, .dimensions = {1, 1, 5}},
        .zero_point = {.type = OperandDataType::kInt8, .dimensions = {1, 1, 5}},
        .output = {.type = OperandDataType::kInt8, .dimensions = {3, 2, 5}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the output datatype doesn't match the
    // zero_point's datatype. The scale shape is one that is otherwise
    // accepted.
    QuantizeLinearTester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}},
        .scale = {.type = OperandDataType::kFloat32, .dimensions = {1, 1, 5}},
        .zero_point = {.type = OperandDataType::kInt8, .dimensions = {1, 1, 5}},
        .output = {.type = OperandDataType::kUint8, .dimensions = {3, 2, 5}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph for the output shapes are not expected. The
    // output datatype matches the zero_point so only the shape is wrong.
    QuantizeLinearTester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 2, 5}},
        .scale = {.type = OperandDataType::kFloat32, .dimensions = {1, 1, 5}},
        .zero_point = {.type = OperandDataType::kInt8, .dimensions = {1, 1, 5}},
        .output = {.type = OperandDataType::kInt8, .dimensions = {5}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the input is as same as output.
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", {2, 3}, OperandDataType::kFloat32);
    OperandId scale_operand_id =
        builder.BuildInput("scale", {2, 3}, OperandDataType::kFloat32);
    OperandId zero_point_operand_id =
        builder.BuildInput("zero_point", {2, 3}, OperandDataType::kInt8);
    builder.BuildQuantizeLinear(input_operand_id, scale_operand_id,
                                zero_point_operand_id, input_operand_id);
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
  {
    // Test the invalid graph when the scale is as same as output.
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", {2, 3}, OperandDataType::kFloat32);
    OperandId scale_operand_id =
        builder.BuildInput("scale", {2, 3}, OperandDataType::kFloat32);
    OperandId zero_point_operand_id =
        builder.BuildInput("zero_point", {2, 3}, OperandDataType::kInt8);
    builder.BuildQuantizeLinear(input_operand_id, scale_operand_id,
                                zero_point_operand_id, scale_operand_id);
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
  {
    // Test the invalid graph when the zeroPoint is as same as output.
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", {2, 3}, OperandDataType::kFloat32);
    OperandId scale_operand_id =
        builder.BuildInput("scale", {2, 3}, OperandDataType::kFloat32);
    OperandId zero_point_operand_id =
        builder.BuildInput("zero_point", {2, 3}, OperandDataType::kInt8);
    builder.BuildQuantizeLinear(input_operand_id, scale_operand_id,
                                zero_point_operand_id, zero_point_operand_id);
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
}

struct ReduceTester {
  mojom::Reduce::Kind kind;
  OperandInfo input;
  std::vector<uint32_t> axes;
  bool keep_dimensions = false;
  OperandInfo output;
  bool expected;

  void Test(WebNNGraphImplTest& test) {
    auto context_properties = GetContextPropertiesForTesting();

    // Build the graph with mojo type.
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        test.BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", input.dimensions, input.type);
    OperandId output_operand_id =
        builder.BuildOutput("output", output.dimensions, output.type);
    builder.BuildReduce(kind, input_operand_id, output_operand_id, axes,
                        keep_dimensions);

    EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
  }
};

// Validates reduce graphs across reduction kinds, axes, keep_dimensions and
// operand data types. Cases that differ only in reduction kind and/or data
// type are driven from small tables; the cases run in a fixed order.
TEST_F(WebNNGraphImplTest, ReduceTest) {
  using Kind = mojom::Reduce::Kind;

  // Pairs a reduction kind with the data type used for both operands.
  struct KindAndType {
    Kind kind;
    OperandDataType type;
  };

  // Reduce with axes = {0, 2} and keep_dimensions = true: the reduced axes
  // stay in the output with size 1.
  ReduceTester{
      .kind = Kind::kL1,
      .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 3, 4, 5}},
      .axes = {0, 2},
      .keep_dimensions = true,
      .output = {.type = OperandDataType::kFloat32, .dimensions = {1, 3, 1, 5}},
      .expected = true}
      .Test(*this);

  // reduceL1 with input_data_type = int32.
  ReduceTester{
      .kind = Kind::kL1,
      .input = {.type = OperandDataType::kInt32, .dimensions = {2, 3, 4, 5}},
      .axes = {0, 2},
      .keep_dimensions = true,
      .output = {.type = OperandDataType::kInt32, .dimensions = {1, 3, 1, 5}},
      .expected = true}
      .Test(*this);

  // Reduce with axes = {2} and keep_dimensions = false: the reduced axis is
  // dropped from the output shape.
  ReduceTester{
      .kind = Kind::kL2,
      .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 3, 4, 5}},
      .axes = {2},
      .keep_dimensions = false,
      .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 3, 5}},
      .expected = true}
      .Test(*this);

  // Reducing over every axis without keeping dimensions yields a scalar
  // output. Covers reduceMin with float32 and int64 and reduceSum with int64.
  const KindAndType all_axes_cases[] = {
      {Kind::kMin, OperandDataType::kFloat32},
      {Kind::kMin, OperandDataType::kInt64},
      {Kind::kSum, OperandDataType::kInt64},
  };
  for (const KindAndType& c : all_axes_cases) {
    ReduceTester{.kind = c.kind,
                 .input = {.type = c.type, .dimensions = {2, 3, 4, 5}},
                 .axes = {0, 1, 2, 3},
                 .output = {.type = c.type, .dimensions = {}},
                 .expected = true}
        .Test(*this);
  }

  // Reduce with empty axes = {}: the output keeps the input shape.
  ReduceTester{
      .kind = Kind::kMin,
      .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 3, 4, 5}},
      .axes = {},
      .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 3, 4, 5}},
      .expected = true}
      .Test(*this);

  // Invalid: the rank of axes is larger than the input rank.
  ReduceTester{
      .kind = Kind::kMax,
      .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}},
      .axes = {0, 1, 2},
      .keep_dimensions = false,
      .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}},
      .expected = false}
      .Test(*this);

  // Invalid: the axes contain duplicate values.
  ReduceTester{
      .kind = Kind::kMean,
      .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}},
      .axes = {1, 1},
      .keep_dimensions = false,
      .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}},
      .expected = false}
      .Test(*this);

  // Invalid: one value in axes is greater than input_rank - 1.
  ReduceTester{
      .kind = Kind::kSum,
      .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}},
      .axes = {2},
      .keep_dimensions = false,
      .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}},
      .expected = false}
      .Test(*this);

  // Invalid: the output shape is not the expected one.
  ReduceTester{
      .kind = Kind::kProduct,
      .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}},
      .axes = {0},
      .keep_dimensions = false,
      .output = {.type = OperandDataType::kFloat32, .dimensions = {1, 3}},
      .expected = false}
      .Test(*this);

  // Invalid: the output type (float16) doesn't match the input type.
  ReduceTester{
      .kind = Kind::kLogSum,
      .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}},
      .axes = {0},
      .keep_dimensions = false,
      .output = {.type = OperandDataType::kFloat16, .dimensions = {3}},
      .expected = false}
      .Test(*this);

  // Invalid: the input type is not one of the float types for reduceLogSum,
  // reduceLogSumExp, reduceL2 and reduceMean.
  for (const Kind kind :
       {Kind::kLogSum, Kind::kLogSumExp, Kind::kL2, Kind::kMean}) {
    ReduceTester{
        .kind = kind,
        .input = {.type = OperandDataType::kInt32, .dimensions = {2, 3}},
        .axes = {0},
        .keep_dimensions = false,
        .output = {.type = OperandDataType::kInt32, .dimensions = {3}},
        .expected = false}
        .Test(*this);
  }

  // Invalid: the input type is not one of {float32, float16, int32, uint32,
  // int64, uint64} for reduceProduct, reduceL1, reduceSum and
  // reduceSumSquare.
  const KindAndType eight_bit_cases[] = {
      {Kind::kProduct, OperandDataType::kInt8},
      {Kind::kL1, OperandDataType::kUint8},
      {Kind::kSum, OperandDataType::kUint8},
      {Kind::kSumSquare, OperandDataType::kInt8},
  };
  for (const KindAndType& c : eight_bit_cases) {
    ReduceTester{.kind = c.kind,
                 .input = {.type = c.type, .dimensions = {2, 3}},
                 .axes = {0},
                 .keep_dimensions = false,
                 .output = {.type = c.type, .dimensions = {3}},
                 .expected = false}
        .Test(*this);
  }

  // Invalid: the input type and the output type (int32) are not the same.
  ReduceTester{
      .kind = Kind::kLogSum,
      .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}},
      .axes = {0},
      .keep_dimensions = false,
      .output = {.type = OperandDataType::kInt32, .dimensions = {3}},
      .expected = false}
      .Test(*this);

  {
    // Invalid: the same operand is used as both input and output.
    auto properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder graph(builder_remote);
    OperandId operand_id =
        graph.BuildInput("input", {2, 3}, OperandDataType::kFloat32);
    graph.BuildReduce(Kind::kSumSquare, operand_id, operand_id, {0}, false);
    EXPECT_FALSE(graph.IsValidGraphForTesting(properties));
  }
}

struct ReluTester {
  OperandInfo input;
  OperandInfo output;
  bool expected;

  void Test(WebNNGraphImplTest& test) {
    auto context_properties = GetContextPropertiesForTesting();

    // Build the graph with mojo type.
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        test.BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", input.dimensions, input.type);
    OperandId output_operand_id =
        builder.BuildOutput("output", output.dimensions, output.type);
    builder.BuildRelu(input_operand_id, output_operand_id);
    EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
  }
};

// Validates relu graphs: accepted float32 and int32 inputs, a rejected
// data type, and output shape/type mismatches.
TEST_F(WebNNGraphImplTest, ReluTest) {
  {
    // Test relu operator for 3-D tensor with float32 input.
    ReluTester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 6, 4}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 6, 4}},
        .expected = true}
        .Test(*this);
  }
  {
    // Test relu operator for 4-D tensor with int32 input. This is the only
    // case here that exercises an integer data type being accepted.
    ReluTester{
        .input = {.type = OperandDataType::kInt32, .dimensions = {1, 5, 3, 7}},
        .output = {.type = OperandDataType::kInt32, .dimensions = {1, 5, 3, 7}},
        .expected = true}
        .Test(*this);
  }
  {
    // Test the invalid graph if the data type is not supported.
    ReluTester{
        .input = {.type = OperandDataType::kUint32, .dimensions = {4, 2}},
        .output = {.type = OperandDataType::kUint32, .dimensions = {4, 2}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph for the output shapes are not expected.
    ReluTester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {4, 2}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {2}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph for output types don't match.
    ReluTester{.input = {.type = OperandDataType::kFloat32, .dimensions = {2}},
               .output = {.type = OperandDataType::kInt32, .dimensions = {2}},
               .expected = false}
        .Test(*this);
  }
}

struct Resample2dTester {
  OperandInfo input;
  struct Resample2dAttributes {
    mojom::Resample2d::InterpolationMode mode =
        mojom::Resample2d::InterpolationMode::kNearestNeighbor;
    std::optional<std::vector<float>> scales;
    std::vector<uint32_t> axes = {2, 3};
  };
  Resample2dAttributes attributes;
  OperandInfo output;
  bool expected;

  void Test(WebNNGraphImplTest& test) {
    auto context_properties = GetContextPropertiesForTesting();

    // Build the graph with mojo type.
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        test.BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", input.dimensions, input.type);
    OperandId output_operand_id =
        builder.BuildOutput("output", output.dimensions, output.type);
    builder.BuildResample2d(input_operand_id, output_operand_id, attributes);
    EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
  }
};

// Validates resample2d graphs across interpolation modes, explicit scales,
// axes, and operand shapes/data types.
TEST_F(WebNNGraphImplTest, Resample2dTest) {
  using Mode = mojom::Resample2d::InterpolationMode;
  using Attributes = Resample2dTester::Resample2dAttributes;

  // Operand descriptions shared by several cases below.
  const OperandInfo f32_1x1x2x4{.type = OperandDataType::kFloat32,
                                .dimensions = {1, 1, 2, 4}};
  const OperandInfo f32_1x2x4x1{.type = OperandDataType::kFloat32,
                                .dimensions = {1, 2, 4, 1}};
  const OperandInfo f32_1x4x8x1{.type = OperandDataType::kFloat32,
                                .dimensions = {1, 4, 8, 1}};
  const OperandInfo f32_1x2x4x8{.type = OperandDataType::kFloat32,
                                .dimensions = {1, 2, 4, 8}};
  const OperandInfo f32_large{.type = OperandDataType::kFloat32,
                              .dimensions = {1, 1, 34902, 23243}};

  // "Linear" mode over axes = [1, 2] with explicit scales = [2, 2].
  const Attributes linear_2x_axes_1_2{.mode = Mode::kLinear,
                                      .scales = std::vector<float>{2, 2},
                                      .axes = {1, 2}};

  // "NearestNeighbor" mode and axes = [2, 3] (the attribute defaults).
  Resample2dTester{
      .input = f32_1x1x2x4, .output = f32_1x1x2x4, .expected = true}
      .Test(*this);

  // "Linear" mode, axes = [1, 2], explicit scales = [2, 2], float32 data.
  Resample2dTester{.input = f32_1x2x4x1,
                   .attributes = linear_2x_axes_1_2,
                   .output = f32_1x4x8x1,
                   .expected = true}
      .Test(*this);

  // "Linear" mode, axes = [1, 2], explicit scales = [2, 2], float16 data.
  Resample2dTester{.input = {.type = OperandDataType::kFloat16,
                             .dimensions = {1, 2, 4, 1}},
                   .attributes = linear_2x_axes_1_2,
                   .output = {.type = OperandDataType::kFloat16,
                              .dimensions = {1, 4, 8, 1}},
                   .expected = true}
      .Test(*this);

  // "Linear" mode, axes = [1, 2], explicit scales = [2, 2.2], which is not
  // exactly output dimensions / input dimensions.
  Resample2dTester{.input = f32_1x2x4x1,
                   .attributes = {.mode = Mode::kLinear,
                                  .scales = std::vector<float>{2, 2.2},
                                  .axes = {1, 2}},
                   .output = f32_1x4x8x1,
                   .expected = true}
      .Test(*this);

  // Invalid: the output type doesn't match the input type.
  Resample2dTester{.input = f32_1x1x2x4,
                   .output = {.type = OperandDataType::kFloat16,
                              .dimensions = {1, 1, 4, 8}},
                   .expected = false}
      .Test(*this);

  // Invalid: the input is not floating point.
  Resample2dTester{
      .input = {.type = OperandDataType::kInt32, .dimensions = {1, 1, 2, 4}},
      .output = {.type = OperandDataType::kInt32, .dimensions = {1, 1, 4, 8}},
      .expected = false}
      .Test(*this);

  // Invalid: the input is not a 4-D tensor.
  Resample2dTester{
      .input = {.type = OperandDataType::kFloat32, .dimensions = {1, 1, 2}},
      .output = f32_1x1x2x4,
      .expected = false}
      .Test(*this);

  // Invalid: the output is not a 4-D tensor.
  Resample2dTester{
      .input = f32_1x1x2x4,
      .output = {.type = OperandDataType::kFloat32, .dimensions = {1, 1, 2}},
      .expected = false}
      .Test(*this);

  // Invalid: the output dimensions don't match the dimensions calculated
  // from the scales.
  Resample2dTester{.input = f32_1x2x4x1,
                   .attributes = linear_2x_axes_1_2,
                   .output = {.type = OperandDataType::kFloat32,
                              .dimensions = {1, 5, 8, 1}},
                   .expected = false}
      .Test(*this);

  // Invalid: the scale height is too large.
  Resample2dTester{.input = f32_large,
                   .attributes = {.mode = Mode::kLinear,
                                  .scales = std::vector<float>{232433, 4}},
                   .output = f32_1x1x2x4,
                   .expected = false}
      .Test(*this);

  // Invalid: the scale height is too small.
  Resample2dTester{.input = f32_1x1x2x4,
                   .attributes = {.mode = Mode::kLinear,
                                  .scales = std::vector<float>{0.02, 0.8}},
                   .output = f32_1x1x2x4,
                   .expected = false}
      .Test(*this);

  // Invalid: the scale width is too large.
  Resample2dTester{.input = f32_large,
                   .attributes = {.mode = Mode::kLinear,
                                  .scales = std::vector<float>{20, 434324}},
                   .output = f32_1x1x2x4,
                   .expected = false}
      .Test(*this);

  // Invalid: the scale width is too small.
  Resample2dTester{.input = f32_1x1x2x4,
                   .attributes = {.mode = Mode::kLinear,
                                  .scales = std::vector<float>{0.7, 0.1}},
                   .output = f32_1x1x2x4,
                   .expected = false}
      .Test(*this);

  // Invalid: the scales are negative.
  Resample2dTester{.input = f32_1x1x2x4,
                   .attributes = {.scales = std::vector<float>{1.0, -2.0}},
                   .output = {.type = OperandDataType::kFloat32,
                              .dimensions = {1, 2, 4, 4}},
                   .expected = false}
      .Test(*this);

  // The dimensions of the input tensor to which the interpolation algorithm
  // applies are not two consecutive dimensions: axes = [1, 3].
  Resample2dTester{.input = f32_1x1x2x4,
                   .attributes = {.axes = {1, 3}},
                   .output = {.type = OperandDataType::kFloat32,
                              .dimensions = {1, 2, 2, 8}},
                   .expected = true}
      .Test(*this);

  // Invalid: a dimension of the output that is not along the axes differs
  // from the input's. First with explicit scales...
  Resample2dTester{
      .input = f32_1x1x2x4,
      .attributes = {.scales = std::vector<float>{2, 2}, .axes = {2, 3}},
      .output = f32_1x2x4x8,
      .expected = false}
      .Test(*this);

  // ...then without explicit scales.
  Resample2dTester{.input = f32_1x1x2x4,
                   .attributes = {.axes = {2, 3}},
                   .output = f32_1x2x4x8,
                   .expected = false}
      .Test(*this);

  {
    // Invalid: the same operand is used as both input and output.
    auto properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder graph(builder_remote);
    OperandId operand_id =
        graph.BuildInput("input", {1, 1, 2, 4}, OperandDataType::kFloat32);
    graph.BuildResample2d(operand_id, operand_id, Attributes{});

    EXPECT_FALSE(graph.IsValidGraphForTesting(properties));
  }
}

struct ReshapeTester {
  OperandInfo input;
  OperandInfo output;
  bool expected;

  void Test(WebNNGraphImplTest& test) {
    auto context_properties = GetContextPropertiesForTesting();

    // Build the graph with mojo type.
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        test.BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", input.dimensions, input.type);
    OperandId output_operand_id =
        builder.BuildOutput("output", output.dimensions, output.type);
    builder.BuildReshape(input_operand_id, output_operand_id);
    EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
  }
};

// Validates reshape graphs: accepted reshapes, an element-count mismatch, a
// data type mismatch, and an output rank beyond a restricted rank limit.
TEST_F(WebNNGraphImplTest, ReshapeTest) {
  {
    // Test reshape operator from 2-D tensor to 1-D tensor.
    ReshapeTester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 4}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {8}},
        .expected = true}
        .Test(*this);
  }
  {
    // Test reshape operator from 4-D tensor to 2-D tensor.
    ReshapeTester{
        .input = {.type = OperandDataType::kInt32, .dimensions = {1, 3, 2, 1}},
        .output = {.type = OperandDataType::kInt32, .dimensions = {1, 6}},
        .expected = true}
        .Test(*this);
  }
  {
    // Test the invalid graph when the number of input elements are not
    // equal to the number of output elements. The data types match so that
    // the element count (24 vs. 15) is the only inconsistency.
    ReshapeTester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 3, 4}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 5}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph for output types don't match.
    ReshapeTester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {2}},
        .output = {.type = OperandDataType::kInt32, .dimensions = {2}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph that output rank exceeds limits: the reshape
    // rank limit in the context properties is narrowed to at most 4 and the
    // requested output is 5-D.
    auto context_properties = GetContextPropertiesForTesting();
    static constexpr SupportedRanks kRankLimit = SupportedRanks::UpTo(4);
    context_properties.data_type_limits.reshape_input.ranks.IntersectWith(
        kRankLimit);
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", {2}, OperandDataType::kFloat32);
    OperandId output_operand_id = builder.BuildOutput(
        "output", {2, 1, 1, 1, 1}, OperandDataType::kFloat32);
    builder.BuildReshape(input_operand_id, output_operand_id);
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
}

struct ReverseTester {
  OperandInfo input;
  OperandInfo output;
  std::vector<uint32_t> axes;
  bool expected;

  void Test(WebNNGraphImplTest& test) {
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        test.BindNewGraphBuilderRemote();

    // Build the graph with mojo type.
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", input.dimensions, input.type);
    OperandId output_operand_id =
        builder.BuildOutput("output", output.dimensions, output.type);
    builder.BuildReverse(input_operand_id, output_operand_id, std::move(axes));
    EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
  }
};

// Exercises graph validation of the reverse operator with one valid graph
// followed by several graphs that are expected to be rejected.
TEST_F(WebNNGraphImplTest, ReverseTest) {
  {
    // Reversing a 4-D float32 tensor along its first three axes is valid.
    ReverseTester tester{
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 2, 3, 4}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 2, 3, 4}},
        .axes = {0, 1, 2},
        .expected = true};
    tester.Test(*this);
  }
  {
    // Listing the same axis twice is expected to be rejected.
    ReverseTester tester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 3, 4}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 3, 4}},
        .axes = {1, 1, 2},
        .expected = false};
    tester.Test(*this);
  }
  {
    // An axis value beyond the input rank is expected to be rejected.
    ReverseTester tester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 3, 4}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 3, 4}},
        .axes = {4},
        .expected = false};
    tester.Test(*this);
  }
  {
    // An output whose data type differs from the input is expected to be
    // rejected.
    ReverseTester tester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 4}},
        .output = {.type = OperandDataType::kInt32, .dimensions = {2, 4}},
        .axes = {0},
        .expected = false};
    tester.Test(*this);
  }
  {
    // Using the very same operand as both input and output is expected to be
    // rejected.
    auto properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder graph(builder_remote);
    OperandId input_id =
        graph.BuildInput("input", {3, 3}, OperandDataType::kFloat32);
    graph.BuildReverse(input_id, input_id, /*axes=*/{1});
    EXPECT_FALSE(graph.IsValidGraphForTesting(properties));
  }
}

struct ScatterElementsTester {
  OperandInfo input;
  OperandInfo indices;
  OperandInfo updates;
  OperandInfo output;
  uint32_t axis = 0;
  bool expected;

  void Test(WebNNGraphImplTest& test) {
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        test.BindNewGraphBuilderRemote();

    // Build the graph with mojo type.
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", input.dimensions, input.type);
    OperandId indices_operand_id =
        builder.BuildInput("indices", indices.dimensions, indices.type);
    OperandId updates_operand_id =
        builder.BuildInput("updates", updates.dimensions, updates.type);
    OperandId output_operand_id =
        builder.BuildOutput("output", output.dimensions, output.type);
    builder.BuildScatterElements(input_operand_id, indices_operand_id,
                                 updates_operand_id, output_operand_id, axis);
    EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
  }
};

// Exercises graph validation of scatterElements with two valid graphs followed
// by graphs that are expected to be rejected.
TEST_F(WebNNGraphImplTest, ScatterElementsTest) {
  {
    // Valid: scatter into a 2-D input along axis 0.
    ScatterElementsTester tester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 3}},
        .indices = {.type = OperandDataType::kUint32, .dimensions = {2, 3}},
        .updates = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 3}},
        .axis = 0,
        .expected = true};
    tester.Test(*this);
  }
  {
    // Valid: scatter into a 2-D input along axis 1.
    ScatterElementsTester tester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {1, 5}},
        .indices = {.type = OperandDataType::kUint32, .dimensions = {1, 2}},
        .updates = {.type = OperandDataType::kFloat32, .dimensions = {1, 2}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {1, 5}},
        .axis = 1,
        .expected = true};
    tester.Test(*this);
  }
  {
    // Invalid: the axis (2) is not less than the input rank (2).
    ScatterElementsTester tester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 3}},
        .indices = {.type = OperandDataType::kUint32, .dimensions = {2, 3}},
        .updates = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 3}},
        .axis = 2,
        .expected = false};
    tester.Test(*this);
  }
  {
    // Invalid: the updates data type (float16) differs from the input data
    // type (float32).
    ScatterElementsTester tester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 3}},
        .indices = {.type = OperandDataType::kUint32, .dimensions = {2, 3}},
        .updates = {.type = OperandDataType::kFloat16, .dimensions = {2, 3}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 3}},
        .expected = false};
    tester.Test(*this);
  }
  {
    // Invalid: input, indices and updates are all scalars.
    ScatterElementsTester tester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {}},
        .indices = {.type = OperandDataType::kUint32, .dimensions = {}},
        .updates = {.type = OperandDataType::kFloat32, .dimensions = {}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {}},
        .expected = false};
    tester.Test(*this);
  }
  {
    // Invalid: the indices rank (3) differs from the input rank (2).
    ScatterElementsTester tester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 3}},
        .indices = {.type = OperandDataType::kUint32, .dimensions = {2, 3, 3}},
        .updates = {.type = OperandDataType::kFloat32, .dimensions = {2, 3, 3}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 3}},
        .expected = false};
    tester.Test(*this);
  }
  {
    // Invalid: scattering along axis 0, but the indices size along dimension 1
    // (4) differs from the input size along that dimension (3).
    ScatterElementsTester tester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 3}},
        .indices = {.type = OperandDataType::kUint32, .dimensions = {2, 4}},
        .updates = {.type = OperandDataType::kFloat32, .dimensions = {2, 4}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 3}},
        .axis = 0,
        .expected = false};
    tester.Test(*this);
  }
  {
    // Invalid: scattering along axis 1, but the indices size along dimension 0
    // (2) differs from the input size along that dimension (3).
    ScatterElementsTester tester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 3}},
        .indices = {.type = OperandDataType::kUint32, .dimensions = {2, 2}},
        .updates = {.type = OperandDataType::kFloat32, .dimensions = {2, 2}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 3}},
        .axis = 1,
        .expected = false};
    tester.Test(*this);
  }
  {
    // Invalid: the updates shape differs from the indices shape.
    ScatterElementsTester tester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 3}},
        .indices = {.type = OperandDataType::kUint32, .dimensions = {2, 3}},
        .updates = {.type = OperandDataType::kFloat32, .dimensions = {2, 4}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {3, 3}},
        .expected = false};
    tester.Test(*this);
  }
  {
    // Invalid: the output shape differs from the input shape.
    ScatterElementsTester tester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 3}},
        .indices = {.type = OperandDataType::kUint32, .dimensions = {2, 3}},
        .updates = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {4, 4}},
        .expected = false};
    tester.Test(*this);
  }
  {
    // Invalid: the output data type differs from the input data type.
    ScatterElementsTester tester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {3, 3}},
        .indices = {.type = OperandDataType::kUint32, .dimensions = {2, 3}},
        .updates = {.type = OperandDataType::kFloat32, .dimensions = {2, 3}},
        .output = {.type = OperandDataType::kFloat16, .dimensions = {3, 3}},
        .expected = false};
    tester.Test(*this);
  }
  {
    // Invalid: the input operand is reused as the output operand.
    auto properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder graph(builder_remote);
    OperandId input_id =
        graph.BuildInput("input", {3, 3}, OperandDataType::kFloat32);
    OperandId indices_id =
        graph.BuildInput("indices", {2, 3}, OperandDataType::kUint32);
    OperandId updates_id =
        graph.BuildInput("updates", {2, 3}, OperandDataType::kFloat32);
    graph.BuildScatterElements(input_id, indices_id, updates_id, input_id,
                               /*axis=*/0);
    EXPECT_FALSE(graph.IsValidGraphForTesting(properties));
  }
}

struct ScatterNDTester {
  OperandInfo input;
  OperandInfo indices;
  OperandInfo updates;
  OperandInfo output;
  bool expected;

  void Test(WebNNGraphImplTest& test) {
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        test.BindNewGraphBuilderRemote();

    // Build the graph with mojo type.
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", input.dimensions, input.type);
    OperandId indices_operand_id =
        builder.BuildInput("indices", indices.dimensions, indices.type);
    OperandId updates_operand_id =
        builder.BuildInput("updates", updates.dimensions, updates.type);
    OperandId output_operand_id =
        builder.BuildOutput("output", output.dimensions, output.type);
    builder.BuildScatterND(input_operand_id, indices_operand_id,
                           updates_operand_id, output_operand_id);
    EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
  }
};

// Exercises graph validation of scatterND with one valid graph followed by
// graphs that are expected to be rejected.
TEST_F(WebNNGraphImplTest, ScatterNDTest) {
  {
    // Valid: 3-D input, 2-D indices and 3-D updates.
    ScatterNDTester tester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {4, 4, 4}},
        .indices = {.type = OperandDataType::kUint32, .dimensions = {2, 1}},
        .updates = {.type = OperandDataType::kFloat32, .dimensions = {2, 4, 4}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {4, 4, 4}},
        .expected = true};
    tester.Test(*this);
  }
  {
    // Invalid: the updates data type (float16) differs from the input data
    // type (float32).
    ScatterNDTester tester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {4, 4, 4}},
        .indices = {.type = OperandDataType::kUint32, .dimensions = {2, 1}},
        .updates = {.type = OperandDataType::kFloat16, .dimensions = {2, 4, 4}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {4, 4, 4}},
        .expected = false};
    tester.Test(*this);
  }
  {
    // Invalid: the input tensor is a scalar.
    ScatterNDTester tester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {}},
        .indices = {.type = OperandDataType::kUint32, .dimensions = {2, 1}},
        .updates = {.type = OperandDataType::kFloat32, .dimensions = {2, 4, 4}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {4, 4, 4}},
        .expected = false};
    tester.Test(*this);
  }
  {
    // Invalid: the indices tensor is a scalar.
    ScatterNDTester tester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {4, 4, 4}},
        .indices = {.type = OperandDataType::kUint32, .dimensions = {}},
        .updates = {.type = OperandDataType::kFloat32, .dimensions = {2, 4, 4}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {4, 4, 4}},
        .expected = false};
    tester.Test(*this);
  }
  {
    // Invalid: the last dimension of the indices tensor (4) is greater than
    // the input rank (3).
    ScatterNDTester tester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {4, 4, 4}},
        .indices = {.type = OperandDataType::kUint32, .dimensions = {2, 4}},
        .updates = {.type = OperandDataType::kFloat32, .dimensions = {2, 4, 4}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {4, 4, 4}},
        .expected = false};
    tester.Test(*this);
  }
  {
    // Invalid: the updates tensor has the wrong shape.
    ScatterNDTester tester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {4, 4, 4}},
        .indices = {.type = OperandDataType::kUint32, .dimensions = {2, 1}},
        // The shape that would be accepted here is [2, 4, 4].
        .updates = {.type = OperandDataType::kFloat32, .dimensions = {2, 3, 4}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {4, 4, 4}},
        .expected = false};
    tester.Test(*this);
  }
  {
    // Invalid: the output shape differs from the input shape.
    ScatterNDTester tester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {4, 4, 4}},
        .indices = {.type = OperandDataType::kUint32, .dimensions = {2, 1}},
        .updates = {.type = OperandDataType::kFloat32, .dimensions = {2, 4, 4}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 4, 4}},
        .expected = false};
    tester.Test(*this);
  }
  {
    // Invalid: the output data type differs from the input data type.
    ScatterNDTester tester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {4, 4, 4}},
        .indices = {.type = OperandDataType::kUint32, .dimensions = {2, 1}},
        .updates = {.type = OperandDataType::kFloat32, .dimensions = {2, 4, 4}},
        .output = {.type = OperandDataType::kFloat16, .dimensions = {4, 4, 4}},
        .expected = false};
    tester.Test(*this);
  }
  {
    // Invalid: the input operand is reused as the output operand.
    auto properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder graph(builder_remote);
    OperandId input_id =
        graph.BuildInput("input", {4, 4, 4}, OperandDataType::kFloat32);
    OperandId indices_id =
        graph.BuildInput("indices", {2, 1}, OperandDataType::kUint32);
    OperandId updates_id =
        graph.BuildInput("updates", {2, 4, 4}, OperandDataType::kFloat32);
    graph.BuildScatterND(input_id, indices_id, updates_id, input_id);
    EXPECT_FALSE(graph.IsValidGraphForTesting(properties));
  }
}

struct SliceTester {
  struct SliceAttributes {
    std::vector<uint32_t> starts;
    std::vector<uint32_t> sizes;
    std::vector<uint32_t> strides;
  };

  OperandInfo input;
  SliceAttributes attributes;
  OperandInfo output;

  bool expected;

  void Test(WebNNGraphImplTest& test) {
    auto context_properties = GetContextPropertiesForTesting();

    // Build the graph with mojo type.
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        test.BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", input.dimensions, input.type);
    OperandId output_operand_id =
        builder.BuildOutput("output", output.dimensions, output.type);
    builder.BuildSlice(input_operand_id, output_operand_id, attributes.starts,
                       attributes.sizes, attributes.strides);
    EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
  }
};

// Exercises graph validation of the slice operator. Each scope builds an
// independent graph through SliceTester; the first three are expected to be
// accepted and the remaining ones rejected.
TEST_F(WebNNGraphImplTest, SliceTest) {
  {
    // Test slice with output dimensions equal to input dimensions.
    SliceTester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {4, 4}},
        .attributes = {.starts = {0, 0}, .sizes = {4, 4}, .strides = {1, 1}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {4, 4}},
        .expected = true}
        .Test(*this);
  }
  {
    // Test 4x4 2-D Tensor to 2x2 slice
    SliceTester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {4, 4}},
        .attributes = {.starts = {0, 0}, .sizes = {2, 2}, .strides = {1, 1}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 2}},
        .expected = true}
        .Test(*this);
  }
  {
    // Test 4x4 2-D Tensor to 2x2 slice with offsets
    SliceTester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {4, 4}},
        .attributes = {.starts = {2, 2}, .sizes = {2, 2}, .strides = {1, 1}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 2}},
        .expected = true}
        .Test(*this);
  }
  {
    // Test that going out-of-bounds of the input tensor fails.
    // starts[0] + sizes[0] = 3 exceeds the input extent of 2 along dimension
    // 0. The output shape deliberately matches `sizes` (with unit strides) so
    // that the out-of-bounds window is the only thing wrong with this graph.
    SliceTester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 2}},
        .attributes = {.starts = {1, 0}, .sizes = {2, 2}, .strides = {1, 1}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 2}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test that mismatched output dimensions and size attribute will fail.
    SliceTester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 2}},
        .attributes = {.starts = {0, 0}, .sizes = {1, 1}, .strides = {1, 1}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 1}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test that having starts and sizes lengths not equal to the input rank
    // will fail.
    SliceTester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {4, 4}},
        .attributes = {.starts = {0}, .sizes = {4}, .strides = {1}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {4, 4}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test that input data type not equal to the output data type will
    // fail.
    SliceTester{
        .input = {.type = OperandDataType::kFloat16, .dimensions = {4, 4}},
        .attributes = {.starts = {0, 0}, .sizes = {4, 4}, .strides = {1, 1}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {4, 4}},
        .expected = false}
        .Test(*this);
  }
}

// Selects which floating-point unary operator FloatingPointUnaryTester adds to
// the graph under test.
enum class FloatingPointUnaryKind {
  kHardSwish,
  kLeakyRelu,
  kLinear,
  kSigmoid,
  kTanh,
};

struct FloatingPointUnaryTester {
  OperandInfo input;
  OperandInfo output;
  bool expected;

  void Test(WebNNGraphImplTest& test) {
    Test(test, FloatingPointUnaryKind::kHardSwish);
    Test(test, FloatingPointUnaryKind::kLeakyRelu);
    Test(test, FloatingPointUnaryKind::kLinear);
    Test(test, FloatingPointUnaryKind::kSigmoid);
    Test(test, FloatingPointUnaryKind::kTanh);
  }

  void Test(WebNNGraphImplTest& test, FloatingPointUnaryKind kind) {
    auto context_properties = GetContextPropertiesForTesting();

    // Build the graph with mojo type.
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        test.BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", input.dimensions, input.type);
    OperandId output_operand_id =
        builder.BuildOutput("output", output.dimensions, output.type);
    switch (kind) {
      case FloatingPointUnaryKind::kHardSwish:
        builder.BuildHardSwish(input_operand_id, output_operand_id);
        break;
      case FloatingPointUnaryKind::kLeakyRelu:
        builder.BuildLeakyRelu(input_operand_id, output_operand_id,
                               /*alpha*/ 1.0);
        break;
      case FloatingPointUnaryKind::kLinear:
        builder.BuildLinear(input_operand_id, output_operand_id,
                            /*alpha*/ 1.0, /*beta*/ 0.0);
        break;
      case FloatingPointUnaryKind::kSigmoid:
        builder.BuildSigmoid(input_operand_id, output_operand_id);
        break;
      case FloatingPointUnaryKind::kTanh:
        builder.BuildTanh(input_operand_id, output_operand_id);
        break;
    }
    EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
  }
};

// Exercises graph validation of the floating-point unary operators. The
// FloatingPointUnaryTester scenarios run against every operator kind; the
// remaining scopes target one specific operator each.
TEST_F(WebNNGraphImplTest, FloatingPointUnaryTest) {
  {
    // Valid: 2-D float32 input with a matching output.
    FloatingPointUnaryTester tester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 6}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 6}},
        .expected = true};
    tester.Test(*this);
  }
  {
    // Valid: 3-D float16 input with a matching output.
    FloatingPointUnaryTester tester{
        .input = {.type = OperandDataType::kFloat16, .dimensions = {2, 6, 4}},
        .output = {.type = OperandDataType::kFloat16, .dimensions = {2, 6, 4}},
        .expected = true};
    tester.Test(*this);
  }
  {
    // Invalid: the output shape differs from the input shape.
    FloatingPointUnaryTester tester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {4, 2}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {2}},
        .expected = false};
    tester.Test(*this);
  }
  {
    // Invalid: the output data type differs from the input data type.
    FloatingPointUnaryTester tester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {2}},
        .output = {.type = OperandDataType::kInt32, .dimensions = {2}},
        .expected = false};
    tester.Test(*this);
  }
  {
    // Invalid: the input data type is not a floating-point type.
    FloatingPointUnaryTester tester{
        .input = {.type = OperandDataType::kInt32, .dimensions = {2}},
        .output = {.type = OperandDataType::kInt32, .dimensions = {2}},
        .expected = false};
    tester.Test(*this);
  }
  {
    // Invalid: leakyRelu whose output operand is its own input operand.
    auto properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder graph(builder_remote);
    OperandId input_id =
        graph.BuildInput("input", {2}, OperandDataType::kFloat32);
    graph.BuildLeakyRelu(input_id, input_id, /*alpha=*/1.0);

    EXPECT_FALSE(graph.IsValidGraphForTesting(properties));
  }
  {
    // Invalid: leakyRelu with a NaN alpha.
    auto properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder graph(builder_remote);
    OperandId input_id =
        graph.BuildInput("input", {2}, OperandDataType::kFloat32);
    OperandId output_id =
        graph.BuildOutput("output", {2}, OperandDataType::kFloat32);
    graph.BuildLeakyRelu(input_id, output_id, /*alpha=*/NAN);

    EXPECT_FALSE(graph.IsValidGraphForTesting(properties));
  }
  {
    // Invalid: linear whose output operand is its own input operand.
    auto properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder graph(builder_remote);
    OperandId input_id =
        graph.BuildInput("input", {2}, OperandDataType::kFloat32);
    graph.BuildLinear(input_id, input_id, /*alpha=*/1.0, /*beta=*/0.0);

    EXPECT_FALSE(graph.IsValidGraphForTesting(properties));
  }
  {
    // Invalid: linear with a NaN alpha.
    auto properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder graph(builder_remote);
    OperandId input_id =
        graph.BuildInput("input", {2}, OperandDataType::kFloat32);
    OperandId output_id =
        graph.BuildOutput("output", {2}, OperandDataType::kFloat32);
    graph.BuildLinear(input_id, output_id, /*alpha=*/NAN, /*beta=*/0.0);

    EXPECT_FALSE(graph.IsValidGraphForTesting(properties));
  }
  {
    // Invalid: linear with a NaN beta.
    auto properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder graph(builder_remote);
    OperandId input_id =
        graph.BuildInput("input", {2}, OperandDataType::kFloat32);
    OperandId output_id =
        graph.BuildOutput("output", {2}, OperandDataType::kFloat32);
    graph.BuildLinear(input_id, output_id, /*alpha=*/1.0, /*beta=*/NAN);

    EXPECT_FALSE(graph.IsValidGraphForTesting(properties));
  }
  {
    // Invalid: sigmoid whose output operand is its own input operand.
    auto properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder graph(builder_remote);
    OperandId input_id =
        graph.BuildInput("input", {2}, OperandDataType::kFloat32);
    graph.BuildSigmoid(input_id, input_id);

    EXPECT_FALSE(graph.IsValidGraphForTesting(properties));
  }
  {
    // Invalid: tanh whose output operand is its own input operand.
    auto properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder graph(builder_remote);
    OperandId input_id =
        graph.BuildInput("input", {2}, OperandDataType::kFloat32);
    graph.BuildTanh(input_id, input_id);

    EXPECT_FALSE(graph.IsValidGraphForTesting(properties));
  }
}

struct SoftmaxTester {
  OperandInfo input;
  OperandInfo output;
  uint32_t axis;
  bool expected;

  void Test(WebNNGraphImplTest& test) {
    auto context_properties = GetContextPropertiesForTesting();

    // Build the graph with mojo type.
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        test.BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", input.dimensions, input.type);
    OperandId output_operand_id =
        builder.BuildOutput("output", output.dimensions, output.type);
    builder.BuildSoftmax(input_operand_id, output_operand_id, axis);
    EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
  }
};

// Exercises graph validation of the softmax operator with three valid graphs
// followed by graphs that are expected to be rejected.
TEST_F(WebNNGraphImplTest, SoftmaxTest) {
  {
    // Valid: float32 input of shape [2, 2], softmax along axis 1.
    SoftmaxTester tester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 2}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 2}},
        .axis = 1,
        .expected = true};
    tester.Test(*this);
  }
  {
    // Valid: float16 input of shape [1, 4], softmax along axis 1.
    SoftmaxTester tester{
        .input = {.type = OperandDataType::kFloat16, .dimensions = {1, 4}},
        .output = {.type = OperandDataType::kFloat16, .dimensions = {1, 4}},
        .axis = 1,
        .expected = true};
    tester.Test(*this);
  }
  {
    // Valid: float32 input of shape [1, 1, 4, 2], softmax along the last axis.
    SoftmaxTester tester{
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 1, 4, 2}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 1, 4, 2}},
        .axis = 3,
        .expected = true};
    tester.Test(*this);
  }
  {
    // Invalid: int32 input.
    SoftmaxTester tester{
        .input = {.type = OperandDataType::kInt32, .dimensions = {2, 3}},
        .output = {.type = OperandDataType::kInt32, .dimensions = {2, 3}},
        .axis = 1,
        .expected = false};
    tester.Test(*this);
  }
  {
    // Invalid: the axis (2) is not less than the input rank (2).
    SoftmaxTester tester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 5}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 5}},
        .axis = 2,
        .expected = false};
    tester.Test(*this);
  }
  {
    // Invalid: the output shape differs from the input shape.
    SoftmaxTester tester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {4, 2}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {2}},
        .axis = 1,
        .expected = false};
    tester.Test(*this);
  }
  {
    // Invalid: the output data type differs from the input data type.
    SoftmaxTester tester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 5}},
        .output = {.type = OperandDataType::kFloat16, .dimensions = {2, 5}},
        .axis = 1,
        .expected = false};
    tester.Test(*this);
  }
}

struct SoftplusTester {
  OperandInfo input;
  OperandInfo output;
  bool expected;

  void Test(WebNNGraphImplTest& test) {
    auto context_properties = GetContextPropertiesForTesting();

    // Build the graph with mojo type.
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        test.BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", input.dimensions, input.type);
    OperandId output_operand_id =
        builder.BuildOutput("output", output.dimensions, output.type);
    builder.BuildSoftplus(input_operand_id, output_operand_id);
    EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
  }
};

// Exercises graph validation of the softplus operator with one valid graph
// followed by graphs that are expected to be rejected.
TEST_F(WebNNGraphImplTest, SoftplusTest) {
  {
    // Valid: float32 input with a matching output.
    SoftplusTester tester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 2}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {2, 2}},
        .expected = true};
    tester.Test(*this);
  }
  {
    // Invalid: int32 is not an accepted data type.
    SoftplusTester tester{
        .input = {.type = OperandDataType::kInt32, .dimensions = {4, 2}},
        .output = {.type = OperandDataType::kInt32, .dimensions = {4, 2}},
        .expected = false};
    tester.Test(*this);
  }
  {
    // Invalid: the output shape differs from the input shape.
    SoftplusTester tester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {4, 2}},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {2}},
        .expected = false};
    tester.Test(*this);
  }
  {
    // Invalid: the output data type differs from the input data type.
    SoftplusTester tester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {2, 5}},
        .output = {.type = OperandDataType::kFloat16, .dimensions = {2, 5}},
        .expected = false};
    tester.Test(*this);
  }
  {
    // Invalid: the input operand is reused as the output operand.
    auto properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder graph(builder_remote);
    OperandId input_id =
        graph.BuildInput("input", {4, 6}, OperandDataType::kFloat32);
    graph.BuildSoftplus(input_id, input_id);
    EXPECT_FALSE(graph.IsValidGraphForTesting(properties));
  }
}

struct SoftsignTester {
  OperandInfo input;
  OperandInfo output;
  bool expected;

  void Test(WebNNGraphImplTest& test) {
    auto context_properties = GetContextPropertiesForTesting();

    // Build the graph with mojo type.
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        test.BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", input.dimensions, input.type);
    OperandId output_operand_id =
        builder.BuildOutput("output", output.dimensions, output.type);
    builder.BuildSoftsign(input_operand_id, output_operand_id);
    EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
  }
};

// Exercises graph validation for a single softsign operation: one accepted
// configuration followed by configurations that must be rejected.
TEST_F(WebNNGraphImplTest, SoftsignTest) {
  using OperandDataType::kFloat16;
  using OperandDataType::kFloat32;
  using OperandDataType::kInt32;
  {
    // Accepted: float32 input of shape [2, 4] with a matching output.
    SoftsignTester tester{.input = {.type = kFloat32, .dimensions = {2, 4}},
                          .output = {.type = kFloat32, .dimensions = {2, 4}},
                          .expected = true};
    tester.Test(*this);
  }
  {
    // Rejected: int32 is an invalid data type for softsign.
    SoftsignTester tester{.input = {.type = kInt32, .dimensions = {4, 2}},
                          .output = {.type = kInt32, .dimensions = {4, 2}},
                          .expected = false};
    tester.Test(*this);
  }
  {
    // Rejected: the output shape differs from the input shape.
    SoftsignTester tester{.input = {.type = kFloat32, .dimensions = {4, 2}},
                          .output = {.type = kFloat32, .dimensions = {2}},
                          .expected = false};
    tester.Test(*this);
  }
  {
    // Rejected: the output data type differs from the input data type.
    SoftsignTester tester{.input = {.type = kFloat32, .dimensions = {2, 5}},
                          .output = {.type = kFloat16, .dimensions = {2, 5}},
                          .expected = false};
    tester.Test(*this);
  }
  {
    // Rejected: one operand is used as both the input and the output.
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder graph(builder_remote);
    OperandId shared_id = graph.BuildInput("input", {4, 6}, kFloat32);
    graph.BuildSoftsign(shared_id, shared_id);
    EXPECT_FALSE(graph.IsValidGraphForTesting(context_properties));
  }
}

struct SplitTester {
  OperandInfo input;
  std::vector<OperandInfo> outputs;
  uint32_t axis = 0;
  bool expected;

  void Test(WebNNGraphImplTest& test) {
    auto context_properties = GetContextPropertiesForTesting();

    // Build the graph with mojo type.
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        test.BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", input.dimensions, input.type);

    std::vector<OperandId> output_operand_ids;
    for (size_t i = 0; i < outputs.size(); ++i) {
      output_operand_ids.push_back(
          builder.BuildOutput("output" + base::NumberToString(i),
                              outputs[i].dimensions, outputs[i].type));
    }
    builder.BuildSplit(input_operand_id, output_operand_ids, axis);
    EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
  }
};

// Exercises graph validation for the split operation: two accepted
// configurations followed by configurations that must be rejected.
TEST_F(WebNNGraphImplTest, ValidateSplitTest) {
  using OperandDataType::kFloat16;
  using OperandDataType::kFloat32;
  {
    // Accepted: split along the default axis (0).
    SplitTester tester{.input = {.type = kFloat32, .dimensions = {2, 2}},
                       .outputs = {{.type = kFloat32, .dimensions = {1, 2}},
                                   {.type = kFloat32, .dimensions = {1, 2}}},
                       .expected = true};
    tester.Test(*this);
  }
  {
    // Accepted: split along axis 1.
    SplitTester tester{.input = {.type = kFloat32, .dimensions = {2, 2}},
                       .outputs = {{.type = kFloat32, .dimensions = {2, 1}},
                                   {.type = kFloat32, .dimensions = {2, 1}}},
                       .axis = 1,
                       .expected = true};
    tester.Test(*this);
  }
  {
    // Rejected: not every output data type matches the input data type.
    SplitTester tester{.input = {.type = kFloat32, .dimensions = {2, 2}},
                       .outputs = {{.type = kFloat32, .dimensions = {1, 2}},
                                   {.type = kFloat16, .dimensions = {1, 2}}},
                       .expected = false};
    tester.Test(*this);
  }
  {
    // Rejected: the split sizes along the axis sum to less than the input
    // size along that axis (1 + 2 + 2 < 6).
    SplitTester tester{.input = {.type = kFloat32, .dimensions = {2, 6}},
                       .outputs = {{.type = kFloat32, .dimensions = {2, 1}},
                                   {.type = kFloat32, .dimensions = {2, 2}},
                                   {.type = kFloat32, .dimensions = {2, 2}}},
                       .axis = 1,
                       .expected = false};
    tester.Test(*this);
  }
  {
    // Rejected: the split sizes along the axis sum to more than the input
    // size along that axis (1 + 2 + 4 > 6).
    SplitTester tester{.input = {.type = kFloat32, .dimensions = {2, 6}},
                       .outputs = {{.type = kFloat32, .dimensions = {2, 1}},
                                   {.type = kFloat32, .dimensions = {2, 2}},
                                   {.type = kFloat32, .dimensions = {2, 4}}},
                       .axis = 1,
                       .expected = false};
    tester.Test(*this);
  }
  {
    // Rejected: the axis (2) is out of range for a rank-2 input tensor.
    SplitTester tester{.input = {.type = kFloat32, .dimensions = {2, 2}},
                       .outputs = {{.type = kFloat32, .dimensions = {1, 2}},
                                   {.type = kFloat32, .dimensions = {1, 2}}},
                       .axis = 2,
                       .expected = false};
    tester.Test(*this);
  }
  {
    // Rejected: the output shapes describe a split along more than one axis.
    SplitTester tester{.input = {.type = kFloat32, .dimensions = {4, 6}},
                       .outputs = {{.type = kFloat32, .dimensions = {1, 2}},
                                   {.type = kFloat32, .dimensions = {2, 3}},
                                   {.type = kFloat32, .dimensions = {1, 1}}},
                       .expected = false};
    tester.Test(*this);
  }
  {
    // Rejected: the first split uses the input operand as its own output.
    // The second split writes to a newly built output operand.
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder graph(builder_remote);
    OperandId source_id = graph.BuildInput("input", {4, 6}, kFloat32);

    graph.BuildSplit(source_id, {source_id}, 0);
    OperandId result_id = graph.BuildOutput("output", {4, 6}, kFloat32);
    graph.BuildSplit(source_id, {result_id}, 0);
    EXPECT_FALSE(graph.IsValidGraphForTesting(context_properties));
  }
}

struct TileTester {
  OperandInfo input;
  std::vector<uint32_t> repetitions;
  OperandInfo output;
  bool expected;

  void Test(WebNNGraphImplTest& test) {
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        test.BindNewGraphBuilderRemote();

    // Build the graph with mojo type.
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", input.dimensions, input.type);
    OperandId output_operand_id =
        builder.BuildOutput("output", output.dimensions, output.type);
    builder.BuildTile(input_operand_id, output_operand_id,
                      std::move(repetitions));
    EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
  }
};

// Exercises graph validation for the tile operation: one accepted
// configuration followed by configurations that must be rejected. Where
// possible each rejected case violates only the rule named in its comment so
// that the expectation cannot pass for an unrelated reason.
TEST_F(WebNNGraphImplTest, TileTest) {
  {
    // Test tile operator with repetitions [2, 3, 1, 2].
    TileTester{.input = {.type = OperandDataType::kFloat32,
                         .dimensions = {1, 2, 3, 4}},
               .repetitions = {2, 3, 1, 2},
               .output = {.type = OperandDataType::kFloat32,
                          .dimensions = {2, 6, 3, 8}},
               .expected = true}
        .Test(*this);
  }
  {
    // Test the invalid graph when the repetitions array is empty.
    TileTester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {1, 2, 3}},
        .repetitions = {},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {1, 2, 3}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the rank of repetitions is larger than
    // the input rank.
    TileTester{
        .input = {.type = OperandDataType::kFloat32, .dimensions = {1, 2, 3}},
        .repetitions = {1, 1, 2, 2},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 2, 3, 3}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when the repetitions contain zero value.
    TileTester{.input = {.type = OperandDataType::kFloat32,
                         .dimensions = {1, 2, 3, 4}},
               .repetitions = {0, 1, 2, 2},
               .output = {.type = OperandDataType::kFloat32,
                          .dimensions = {1, 2, 3, 3}},
               .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph when any value in repetitions causes tiled
    // dimension size overflow.
    TileTester{.input = {.type = OperandDataType::kFloat32,
                         .dimensions = {1, 2, 34902, 4}},
               .repetitions = {1, 1, 232433, 2},
               .output = {.type = OperandDataType::kFloat32,
                          .dimensions = {1, 2, 2, 3}},
               .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph for output shapes are not expected.
    TileTester{
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 2, 3, 4}},
        .repetitions = {2, 1, 2, 3},
        .output = {.type = OperandDataType::kFloat32, .dimensions = {1, 2, 3}},
        .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph for output types don't match. The repetitions
    // are all ones so that they are non-zero and the declared output shape
    // equals the tiled shape; the data type mismatch is then the only
    // problem with this graph.
    TileTester{.input = {.type = OperandDataType::kFloat32,
                         .dimensions = {1, 2, 3, 4}},
               .repetitions = {1, 1, 1, 1},
               .output = {.type = OperandDataType::kFloat16,
                          .dimensions = {1, 2, 3, 4}},
               .expected = false}
        .Test(*this);
  }
  {
    // Test the invalid graph for input operand == output operand. The
    // repetitions are all ones so that the tiled shape equals the operand's
    // own shape; reusing the operand is then the only problem with this
    // graph.
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", {4, 6}, OperandDataType::kFloat32);
    builder.BuildTile(input_operand_id, input_operand_id,
                      std::vector<uint32_t>{1, 1});
    EXPECT_FALSE(builder.IsValidGraphForTesting(context_properties));
  }
}

struct TransposeTester {
  OperandInfo input;
  std::vector<uint32_t> permutation;
  OperandInfo output;
  bool expected;

  void Test(WebNNGraphImplTest& test) {
    auto context_properties = GetContextPropertiesForTesting();

    // Build the graph with mojo type.
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        test.BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", input.dimensions, input.type);
    OperandId output_operand_id =
        builder.BuildOutput("output", output.dimensions, output.type);
    builder.BuildTranspose(input_operand_id, output_operand_id,
                           std::move(permutation));
    EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
  }
};

// Exercises graph validation for the transpose operation: one accepted
// configuration followed by configurations that must be rejected.
TEST_F(WebNNGraphImplTest, TransposeTest) {
  using OperandDataType::kFloat16;
  using OperandDataType::kFloat32;
  {
    // Accepted: permutation [2, 3, 1, 0] applied to a [1, 2, 3, 4] input.
    TransposeTester tester{
        .input = {.type = kFloat32, .dimensions = {1, 2, 3, 4}},
        .permutation = {2, 3, 1, 0},
        .output = {.type = kFloat32, .dimensions = {3, 4, 2, 1}},
        .expected = true};
    tester.Test(*this);
  }
  {
    // Rejected: the permutation has more entries than the input has
    // dimensions.
    TransposeTester tester{
        .input = {.type = kFloat32, .dimensions = {1, 2, 3}},
        .permutation = {0, 1, 2, 2},
        .output = {.type = kFloat32, .dimensions = {1, 2, 3, 3}},
        .expected = false};
    tester.Test(*this);
  }
  {
    // Rejected: the permutation contains duplicate values.
    TransposeTester tester{
        .input = {.type = kFloat32, .dimensions = {1, 2, 3, 4}},
        .permutation = {0, 1, 2, 2},
        .output = {.type = kFloat32, .dimensions = {1, 2, 3, 3}},
        .expected = false};
    tester.Test(*this);
  }
  {
    // Rejected: a permutation value is greater than input_rank - 1.
    TransposeTester tester{
        .input = {.type = kFloat16, .dimensions = {1, 2, 3, 4}},
        .permutation = {0, 1, 2, 4},
        .output = {.type = kFloat16, .dimensions = {1, 2, 3, 4}},
        .expected = false};
    tester.Test(*this);
  }
  {
    // Rejected: the declared output shape is not the permuted input shape.
    TransposeTester tester{
        .input = {.type = kFloat32, .dimensions = {1, 2, 3, 4}},
        .permutation = {0, 1, 2, 3},
        .output = {.type = kFloat32, .dimensions = {1, 2, 3}},
        .expected = false};
    tester.Test(*this);
  }
  {
    // Rejected: the output data type differs from the input data type.
    TransposeTester tester{
        .input = {.type = kFloat32, .dimensions = {1, 2, 3, 4}},
        .permutation = {0, 1, 2, 3},
        .output = {.type = kFloat16, .dimensions = {1, 2, 3, 4}},
        .expected = false};
    tester.Test(*this);
  }
}

struct TriangularTester {
  OperandInfo input;
  bool upper = true;
  int32_t diagonal = 0;
  OperandInfo output;
  bool expected;

  void Test(WebNNGraphImplTest& test) {
    auto context_properties = GetContextPropertiesForTesting();

    // Build the graph with mojo type.
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        test.BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", input.dimensions, input.type);
    OperandId output_operand_id =
        builder.BuildOutput("output", output.dimensions, output.type);
    builder.BuildTriangular(input_operand_id, output_operand_id, upper,
                            diagonal);
    EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
  }
};

// Exercises graph validation for the triangular operation: one accepted
// configuration followed by configurations that must be rejected.
TEST_F(WebNNGraphImplTest, TriangularTest) {
  using OperandDataType::kFloat16;
  using OperandDataType::kFloat32;
  {
    // Accepted: upper = true and diagonal = 2 on a [2, 2] float32 input.
    TriangularTester tester{
        .input = {.type = kFloat32, .dimensions = {2, 2}},
        .upper = true,
        .diagonal = 2,
        .output = {.type = kFloat32, .dimensions = {2, 2}},
        .expected = true};
    tester.Test(*this);
  }
  {
    // Rejected: the output shape differs from the input shape.
    TriangularTester tester{
        .input = {.type = kFloat32, .dimensions = {4, 2}},
        .output = {.type = kFloat32, .dimensions = {2}},
        .expected = false};
    tester.Test(*this);
  }
  {
    // Rejected: the output data type differs from the input data type.
    TriangularTester tester{
        .input = {.type = kFloat32, .dimensions = {2, 5}},
        .output = {.type = kFloat16, .dimensions = {2, 5}},
        .expected = false};
    tester.Test(*this);
  }
  {
    // Rejected: one operand is used as both the input and the output.
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder graph(builder_remote);
    OperandId shared_id = graph.BuildInput("input", {4, 6}, kFloat32);

    graph.BuildTriangular(shared_id, shared_id,
                          /*upper=*/true, /*diagonal=*/-1);
    EXPECT_FALSE(graph.IsValidGraphForTesting(context_properties));
  }
}

struct WhereTester {
  OperandInfo condition;
  OperandInfo true_value;
  OperandInfo false_value;
  OperandInfo output;
  bool expected;

  void Test(WebNNGraphImplTest& test) {
    auto context_properties = GetContextPropertiesForTesting();

    // Build the graph with mojo type.
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        test.BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId condition_operand_id =
        builder.BuildInput("condition", condition.dimensions, condition.type);
    OperandId true_value_operand_id = builder.BuildInput(
        "true_value", true_value.dimensions, true_value.type);
    OperandId false_value_operand_id = builder.BuildInput(
        "false_value", false_value.dimensions, false_value.type);
    OperandId output_operand_id =
        builder.BuildOutput("output", output.dimensions, output.type);
    builder.BuildWhere(condition_operand_id, true_value_operand_id,
                       false_value_operand_id, output_operand_id);
    EXPECT_EQ(builder.IsValidGraphForTesting(context_properties), expected);
  }
};

// Exercises graph validation for the where operation: rejected data type and
// shape combinations, accepted broadcasting combinations, and rejected graphs
// in which one of the inputs is reused as the output.
TEST_F(WebNNGraphImplTest, WhereTest) {
  using OperandDataType::kFloat16;
  using OperandDataType::kFloat32;
  using OperandDataType::kUint8;
  {
    // Rejected: the condition data type is not uint8.
    WhereTester tester{
        .condition = {.type = kFloat32, .dimensions = {2, 4}},
        .true_value = {.type = kFloat32, .dimensions = {2, 4}},
        .false_value = {.type = kFloat32, .dimensions = {2, 4}},
        .output = {.type = kFloat32, .dimensions = {2, 4}},
        .expected = false};
    tester.Test(*this);
  }
  {
    // Rejected: the data types of true_value and false_value don't match.
    WhereTester tester{
        .condition = {.type = kUint8, .dimensions = {2, 4}},
        .true_value = {.type = kFloat32, .dimensions = {2, 4}},
        .false_value = {.type = kFloat16, .dimensions = {2, 4}},
        .output = {.type = kFloat32, .dimensions = {2, 4}},
        .expected = false};
    tester.Test(*this);
  }
  {
    // Rejected: the data types of output and true_value don't match.
    WhereTester tester{
        .condition = {.type = kUint8, .dimensions = {2, 4}},
        .true_value = {.type = kFloat32, .dimensions = {2, 4}},
        .false_value = {.type = kFloat32, .dimensions = {2, 4}},
        .output = {.type = kFloat16, .dimensions = {2, 4}},
        .expected = false};
    tester.Test(*this);
  }
  {
    // Rejected: the shape of the output is wrong.
    WhereTester tester{
        .condition = {.type = kUint8, .dimensions = {2, 4}},
        .true_value = {.type = kFloat32, .dimensions = {2, 4}},
        .false_value = {.type = kFloat32, .dimensions = {2, 4}},
        .output = {.type = kFloat32, .dimensions = {2, 5}},
        .expected = false};
    tester.Test(*this);
  }
  {
    // Rejected: the shapes of true_value and false_value are not
    // broadcastable.
    WhereTester tester{
        .condition = {.type = kUint8, .dimensions = {2, 4}},
        .true_value = {.type = kFloat32, .dimensions = {2, 4}},
        .false_value = {.type = kFloat32, .dimensions = {2, 3}},
        .output = {.type = kFloat32, .dimensions = {2, 4}},
        .expected = false};
    tester.Test(*this);
  }
  {
    // Rejected: the condition shape is not broadcastable.
    WhereTester tester{
        .condition = {.type = kUint8, .dimensions = {2, 4}},
        .true_value = {.type = kFloat32, .dimensions = {2, 3}},
        .false_value = {.type = kFloat32, .dimensions = {2, 1}},
        .output = {.type = kFloat32, .dimensions = {2, 4}},
        .expected = false};
    tester.Test(*this);
  }
  {
    // Accepted: 2-D condition, 2-D true_value and 2-D false_value using
    // broadcast.
    WhereTester tester{
        .condition = {.type = kUint8, .dimensions = {2, 1}},
        .true_value = {.type = kFloat32, .dimensions = {2, 4}},
        .false_value = {.type = kFloat32, .dimensions = {2, 4}},
        .output = {.type = kFloat32, .dimensions = {2, 4}},
        .expected = true};
    tester.Test(*this);
  }
  {
    // Accepted: 2-D condition, 2-D true_value and 3-D false_value using
    // broadcast.
    WhereTester tester{
        .condition = {.type = kUint8, .dimensions = {1, 4}},
        .true_value = {.type = kFloat32, .dimensions = {3, 4}},
        .false_value = {.type = kFloat32, .dimensions = {2, 3, 4}},
        .output = {.type = kFloat32, .dimensions = {2, 3, 4}},
        .expected = true};
    tester.Test(*this);
  }
  {
    // Accepted: 3-D condition, 3-D true_value and 2-D false_value using
    // broadcast.
    WhereTester tester{
        .condition = {.type = kUint8, .dimensions = {2, 1, 4}},
        .true_value = {.type = kFloat32, .dimensions = {2, 3, 4}},
        .false_value = {.type = kFloat32, .dimensions = {1, 4}},
        .output = {.type = kFloat32, .dimensions = {2, 3, 4}},
        .expected = true};
    tester.Test(*this);
  }
  {
    // Accepted: 4-D condition, 3-D true_value and 2-D false_value using
    // broadcast.
    WhereTester tester{
        .condition = {.type = kUint8, .dimensions = {2, 3, 4, 5}},
        .true_value = {.type = kFloat32, .dimensions = {3, 4, 5}},
        .false_value = {.type = kFloat32, .dimensions = {4, 5}},
        .output = {.type = kFloat32, .dimensions = {2, 3, 4, 5}},
        .expected = true};
    tester.Test(*this);
  }
  {
    // Rejected: the condition operand is the same as the output operand.
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder graph(builder_remote);
    OperandId condition_id = graph.BuildInput("condition", {2, 4}, kUint8);
    OperandId true_id = graph.BuildInput("true_value", {2, 4}, kFloat32);
    OperandId false_id = graph.BuildInput("false_value", {2, 4}, kFloat32);
    graph.BuildWhere(condition_id, true_id, false_id, condition_id);
    EXPECT_FALSE(graph.IsValidGraphForTesting(context_properties));
  }
  {
    // Rejected: the true_value operand is the same as the output operand.
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder graph(builder_remote);
    OperandId condition_id = graph.BuildInput("condition", {2, 4}, kUint8);
    OperandId true_id = graph.BuildInput("true_value", {2, 4}, kFloat32);
    OperandId false_id = graph.BuildInput("false_value", {2, 4}, kFloat32);
    graph.BuildWhere(condition_id, true_id, false_id, true_id);
    EXPECT_FALSE(graph.IsValidGraphForTesting(context_properties));
  }
  {
    // Rejected: the false_value operand is the same as the output operand.
    auto context_properties = GetContextPropertiesForTesting();
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder graph(builder_remote);
    OperandId condition_id = graph.BuildInput("condition", {2, 4}, kUint8);
    OperandId true_id = graph.BuildInput("true_value", {2, 4}, kFloat32);
    OperandId false_id = graph.BuildInput("false_value", {2, 4}, kFloat32);
    graph.BuildWhere(condition_id, true_id, false_id, false_id);
    EXPECT_FALSE(graph.IsValidGraphForTesting(context_properties));
  }
}

TEST_F(WebNNGraphImplTest, ValidateDispatchTest) {
  auto context_properties = GetContextPropertiesForTesting();
  // TODO(crbug.com/325598628): De-dup these data type constants.
  const OperandDataType kMojoDataType = OperandDataType::kUint8;
  const OperandDataType kDataType = OperandDataType::kUint8;
  const std::vector<uint32_t> kShape = {3, 5};
  // Build the graph with mojo type.
  mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
      BindNewGraphBuilderRemote();
  GraphInfoBuilder builder(remote);
  const OperandId lhs_operand_id =
      builder.BuildInput("lhs", kShape, kMojoDataType);
  const OperandId rhs_operand_id =
      builder.BuildInput("rhs", kShape, kMojoDataType);
  const OperandId output_1_operand_id =
      builder.BuildOutput("output1", kShape, kMojoDataType);
  builder.BuildElementWiseBinary(mojom::ElementWiseBinary::Kind::kAdd,
                                 lhs_operand_id, rhs_operand_id,
                                 output_1_operand_id);
  const OperandId output_2_operand_id =
      builder.BuildOutput("output2", kShape, kMojoDataType);
  builder.BuildElementWiseBinary(mojom::ElementWiseBinary::Kind::kAdd,
                                 lhs_operand_id, rhs_operand_id,
                                 output_2_operand_id);
  EXPECT_TRUE(builder.IsValidGraphForTesting(context_properties));

  test::WebNNTestEnvironment webnn_test_enviroment;
  mojo::Remote<mojom::WebNNContextProvider> provider_remote;
  webnn_test_enviroment.BindWebNNContextProvider(
      provider_remote.BindNewPipeAndPassReceiver());

  {
    // Validate the inputs match the expected.
    mojo::Remote<mojom::WebNNContext> webnn_context =
        CreateWebNNContext(provider_remote);

    base::flat_map<std::string, CreateTensorSuccess> inputs;
    inputs["lhs"] = CreateWebNNTensor(webnn_context, kDataType, kShape);
    inputs["rhs"] = CreateWebNNTensor(webnn_context, kDataType, kShape);
    base::flat_map<std::string, CreateTensorSuccess> outputs;
    outputs["output1"] = CreateWebNNTensor(webnn_context, kDataType, kShape);
    outputs["output2"] = CreateWebNNTensor(webnn_context, kDataType, kShape);
    EXPECT_TRUE(ValidateDispatch(webnn_context, builder.CloneGraphInfo(),
                                 std::move(inputs), std::move(outputs)));
  }
  {
    // Test the invalid inputs for invalid input size.
    mojo::Remote<mojom::WebNNContext> webnn_context =
        CreateWebNNContext(provider_remote);
    base::flat_map<std::string, CreateTensorSuccess> inputs;
    inputs["lhs"] = CreateWebNNTensor(webnn_context, kDataType, kShape);
    base::flat_map<std::string, CreateTensorSuccess> outputs;
    outputs["output1"] = CreateWebNNTensor(webnn_context, kDataType, kShape);
    outputs["output2"] = CreateWebNNTensor(webnn_context, kDataType, kShape);
    EXPECT_FALSE(ValidateDispatch(webnn_context, builder.CloneGraphInfo(),
                                  std::move(inputs), std::move(outputs)));
  }
  {
    // Test the invalid outputs for invalid output size.
    mojo::Remote<mojom::WebNNContext> webnn_context =
        CreateWebNNContext(provider_remote);
    base::flat_map<std::string, CreateTensorSuccess> inputs;
    inputs["lhs"] = CreateWebNNTensor(webnn_context, kDataType, kShape);
    inputs["rhs"] = CreateWebNNTensor(webnn_context, kDataType, kShape);
    base::flat_map<std::string, CreateTensorSuccess> outputs;
    outputs["output1"] = CreateWebNNTensor(webnn_context, kDataType, kShape);
    outputs["output2"] = CreateWebNNTensor(webnn_context, kDataType, kShape);
    outputs["a_different_output_name"] =
        CreateWebNNTensor(webnn_context, kDataType, kShape);
    EXPECT_FALSE(ValidateDispatch(webnn_context, builder.CloneGraphInfo(),
                                  std::move(inputs), std::move(outputs)));
  }
  {
    // Test the invalid inputs for invalid input name.
    mojo::Remote<mojom::WebNNContext> webnn_context =
        CreateWebNNContext(provider_remote);
    base::flat_map<std::string, CreateTensorSuccess> inputs;
    inputs["a_different_input_name"] =
        CreateWebNNTensor(webnn_context, kDataType, kShape);
    inputs["rhs"] = CreateWebNNTensor(webnn_context, kDataType, kShape);
    base::flat_map<std::string, CreateTensorSuccess> outputs;
    outputs["output1"] = CreateWebNNTensor(webnn_context, kDataType, kShape);
    outputs["output2"] = CreateWebNNTensor(webnn_context, kDataType, kShape);
    EXPECT_FALSE(ValidateDispatch(webnn_context, builder.CloneGraphInfo(),
                                  std::move(inputs), std::move(outputs)));
  }
  {
    // Test the invalid outputs for invalid input name.
    mojo::Remote<mojom::WebNNContext> webnn_context =
        CreateWebNNContext(provider_remote);
    base::flat_map<std::string, CreateTensorSuccess> inputs;
    inputs["lhs"] = CreateWebNNTensor(webnn_context, kDataType, kShape);
    inputs["rhs"] = CreateWebNNTensor(webnn_context, kDataType, kShape);
    base::flat_map<std::string, CreateTensorSuccess> outputs;
    outputs["a_different_output_name"] =
        CreateWebNNTensor(webnn_context, kDataType, kShape);
    outputs["output2"] = CreateWebNNTensor(webnn_context, kDataType, kShape);
    EXPECT_FALSE(ValidateDispatch(webnn_context, builder.CloneGraphInfo(),
                                  std::move(inputs), std::move(outputs)));
  }
  {
    // Test the invalid inputs for invalid first input shape.
    mojo::Remote<mojom::WebNNContext> webnn_context =
        CreateWebNNContext(provider_remote);
    base::flat_map<std::string, CreateTensorSuccess> inputs;
    inputs["lhs"] = CreateWebNNTensor(webnn_context, kDataType, {2, 5});
    inputs["rhs"] = CreateWebNNTensor(webnn_context, kDataType, kShape);
    base::flat_map<std::string, CreateTensorSuccess> outputs;
    outputs["output1"] = CreateWebNNTensor(webnn_context, kDataType, kShape);
    outputs["output2"] = CreateWebNNTensor(webnn_context, kDataType, kShape);
    EXPECT_FALSE(ValidateDispatch(webnn_context, builder.CloneGraphInfo(),
                                  std::move(inputs), std::move(outputs)));
  }
  {
    // Test the invalid inputs for invalid first input data type.
    mojo::Remote<mojom::WebNNContext> webnn_context =
        CreateWebNNContext(provider_remote);
    base::flat_map<std::string, CreateTensorSuccess> inputs;
    inputs["lhs"] =
        CreateWebNNTensor(webnn_context, OperandDataType::kInt8, kShape);
    inputs["rhs"] = CreateWebNNTensor(webnn_context, kDataType, kShape);
    base::flat_map<std::string, CreateTensorSuccess> outputs;
    outputs["output1"] = CreateWebNNTensor(webnn_context, kDataType, kShape);
    outputs["output2"] = CreateWebNNTensor(webnn_context, kDataType, kShape);
    EXPECT_FALSE(ValidateDispatch(webnn_context, builder.CloneGraphInfo(),
                                  std::move(inputs), std::move(outputs)));
  }
  {
    // Test the invalid outputs for invalid first output shape.
    mojo::Remote<mojom::WebNNContext> webnn_context =
        CreateWebNNContext(provider_remote);
    base::flat_map<std::string, CreateTensorSuccess> inputs;
    inputs["lhs"] = CreateWebNNTensor(webnn_context, kDataType, kShape);
    inputs["rhs"] = CreateWebNNTensor(webnn_context, kDataType, kShape);
    base::flat_map<std::string, CreateTensorSuccess> outputs;
    outputs["output1"] = CreateWebNNTensor(webnn_context, kDataType, {3, 4});
    outputs["output2"] = CreateWebNNTensor(webnn_context, kDataType, kShape);
    EXPECT_FALSE(ValidateDispatch(webnn_context, builder.CloneGraphInfo(),
                                  std::move(inputs), std::move(outputs)));
  }
  {
    // Test the invalid inputs for invalid second input data type.
    mojo::Remote<mojom::WebNNContext> webnn_context =
        CreateWebNNContext(provider_remote);
    base::flat_map<std::string, CreateTensorSuccess> inputs;
    inputs["lhs"] = CreateWebNNTensor(webnn_context, kDataType, kShape);
    inputs["rhs"] =
        CreateWebNNTensor(webnn_context, OperandDataType::kInt32, kShape);
    base::flat_map<std::string, CreateTensorSuccess> outputs;
    outputs["output1"] = CreateWebNNTensor(webnn_context, kDataType, kShape);
    outputs["output2"] = CreateWebNNTensor(webnn_context, kDataType, kShape);
    EXPECT_FALSE(ValidateDispatch(webnn_context, builder.CloneGraphInfo(),
                                  std::move(inputs), std::move(outputs)));
  }
  {
    // Test the invalid outputs for invalid second output shape.
    mojo::Remote<mojom::WebNNContext> webnn_context =
        CreateWebNNContext(provider_remote);
    base::flat_map<std::string, CreateTensorSuccess> inputs;
    inputs["lhs"] = CreateWebNNTensor(webnn_context, kDataType, kShape);
    inputs["rhs"] = CreateWebNNTensor(webnn_context, kDataType, kShape);
    base::flat_map<std::string, CreateTensorSuccess> outputs;
    outputs["output1"] = CreateWebNNTensor(webnn_context, kDataType, kShape);
    outputs["output2"] = CreateWebNNTensor(webnn_context, kDataType, {2, 5});
    EXPECT_FALSE(ValidateDispatch(webnn_context, builder.CloneGraphInfo(),
                                  std::move(inputs), std::move(outputs)));
  }
  {
    // Test the inputs using the same tensor more than once.
    mojo::Remote<mojom::WebNNContext> webnn_context =
        CreateWebNNContext(provider_remote);
    base::flat_map<std::string, CreateTensorSuccess> inputs;
    inputs["lhs"] = CreateWebNNTensor(webnn_context, kDataType, kShape);
    inputs["rhs"] = {/*webnn_tensor=*/std::nullopt, inputs["lhs"].webnn_handle};
    base::flat_map<std::string, CreateTensorSuccess> outputs;
    outputs["output1"] = CreateWebNNTensor(webnn_context, kDataType, kShape);
    outputs["output2"] = CreateWebNNTensor(webnn_context, kDataType, kShape);
    EXPECT_TRUE(ValidateDispatch(webnn_context, builder.CloneGraphInfo(),
                                 std::move(inputs), std::move(outputs)));
  }
  {
    // Test the invalid outputs when using the same tensor more than once.
    mojo::Remote<mojom::WebNNContext> webnn_context =
        CreateWebNNContext(provider_remote);
    base::flat_map<std::string, CreateTensorSuccess> inputs;
    inputs["lhs"] = CreateWebNNTensor(webnn_context, kDataType, kShape);
    inputs["rhs"] = CreateWebNNTensor(webnn_context, kDataType, kShape);
    base::flat_map<std::string, CreateTensorSuccess> outputs;
    outputs["output1"] = CreateWebNNTensor(webnn_context, kDataType, kShape);
    outputs["output2"] = {/*webnn_tensor=*/std::nullopt,
                          outputs["output1"].webnn_handle};
    EXPECT_FALSE(ValidateDispatch(webnn_context, builder.CloneGraphInfo(),
                                  std::move(inputs), std::move(outputs)));
  }
  {
    // Test the inputs and outputs are invalid when using the same tensor.
    mojo::Remote<mojom::WebNNContext> webnn_context =
        CreateWebNNContext(provider_remote);
    base::flat_map<std::string, CreateTensorSuccess> inputs;
    inputs["lhs"] = CreateWebNNTensor(webnn_context, kDataType, kShape);
    inputs["rhs"] = CreateWebNNTensor(webnn_context, kDataType, kShape);
    base::flat_map<std::string, CreateTensorSuccess> outputs;
    outputs["output1"] = {/*webnn_tensor=*/std::nullopt,
                          inputs["lhs"].webnn_handle};
    outputs["output2"] = CreateWebNNTensor(webnn_context, kDataType, kShape);
    EXPECT_FALSE(ValidateDispatch(webnn_context, builder.CloneGraphInfo(),
                                  std::move(inputs), std::move(outputs)));
  }
  {
    // Test the inputs are invalid when using an invalid tensor.
    mojo::Remote<mojom::WebNNContext> webnn_context =
        CreateWebNNContext(provider_remote);
    base::flat_map<std::string, CreateTensorSuccess> inputs;
    inputs["lhs"] = {/*webnn_tensor=*/std::nullopt};
    inputs["rhs"] = CreateWebNNTensor(webnn_context, kDataType, kShape);
    base::flat_map<std::string, CreateTensorSuccess> outputs;
    outputs["output1"] = CreateWebNNTensor(webnn_context, kDataType, kShape);
    outputs["output2"] = CreateWebNNTensor(webnn_context, kDataType, kShape);
    EXPECT_FALSE(ValidateDispatch(webnn_context, builder.CloneGraphInfo(),
                                  std::move(inputs), std::move(outputs)));
  }
  {
    // Test the outputs are invalid when using an invalid tensor.
    mojo::Remote<mojom::WebNNContext> webnn_context =
        CreateWebNNContext(provider_remote);
    base::flat_map<std::string, CreateTensorSuccess> inputs;
    inputs["lhs"] = CreateWebNNTensor(webnn_context, kDataType, kShape);
    inputs["rhs"] = CreateWebNNTensor(webnn_context, kDataType, kShape);
    base::flat_map<std::string, CreateTensorSuccess> outputs;
    outputs["output1"] = CreateWebNNTensor(webnn_context, kDataType, kShape);
    outputs["output2"] = {/*webnn_tensor=*/std::nullopt};
    EXPECT_FALSE(ValidateDispatch(webnn_context, builder.CloneGraphInfo(),
                                  std::move(inputs), std::move(outputs)));
  }
}

// Checks that a graph with two inputs and two constants validates when, in
// each branch, the input operand is created before the constant operand:
//    [input_a] [constant_a] [input_b] [constant_b]
//           \    /                \    /
//            gemm                  gemm
//                \                /
//                       gemm
TEST_F(WebNNGraphImplTest, BuildMultipleInputsAppendingConstants) {
  constexpr OperandDataType kElementType = OperandDataType::kFloat32;
  auto properties = GetContextPropertiesForTesting();

  // Assemble the mojom graph info through a freshly bound builder remote.
  mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote =
      BindNewGraphBuilderRemote();
  GraphInfoBuilder graph(builder_remote);

  // Every gemm in this graph is built with default-constructed attributes.
  const auto add_gemm = [&graph](OperandId a, OperandId b, OperandId result) {
    graph.BuildGemm(a, b, result, GemmTester::GemmAttributes());
  };

  // The graph output is created first; inputs and constants follow.
  const OperandId output = graph.BuildOutput("output", {2, 2}, kElementType);

  // Left branch: input_a, then constant_a appended after it.
  const OperandId input_a = graph.BuildInput("input_a", {2, 2}, kElementType);
  const std::vector<float> weights = {5.0f, 6.0f, 7.0f, 8.0f};
  const OperandId constant_a = graph.BuildConstant(
      {2, 2}, kElementType,
      base::as_byte_span(base::allow_nonunique_obj, weights));
  const OperandId left = graph.BuildIntermediateOperand({2, 2}, kElementType);
  add_gemm(input_a, constant_a, left);

  // Right branch: input_b, then constant_b (same data as constant_a).
  const OperandId input_b = graph.BuildInput("input_b", {2, 2}, kElementType);
  const OperandId constant_b = graph.BuildConstant(
      {2, 2}, kElementType,
      base::as_byte_span(base::allow_nonunique_obj, weights));
  const OperandId right = graph.BuildIntermediateOperand({2, 2}, kElementType);
  add_gemm(input_b, constant_b, right);

  // Join both branches into the graph output.
  add_gemm(left, right, output);

  EXPECT_TRUE(graph.IsValidGraphForTesting(properties));
}

// Checks that a graph with two constants and two inputs validates when each
// gemm takes the constant as its first operand and the input as its second.
// The diagram shows the gemm operand order; in the right branch the input is
// still created before the constant.
//    [constant_a] [input_a] [constant_b] [input_b]
//           \    /                \    /
//            gemm                  gemm
//                \                /
//                       gemm
TEST_F(WebNNGraphImplTest, BuildMultipleConstantsAppendingInputs) {
  constexpr OperandDataType kElementType = OperandDataType::kFloat32;
  auto properties = GetContextPropertiesForTesting();

  // Assemble the mojom graph info through a freshly bound builder remote.
  mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote =
      BindNewGraphBuilderRemote();
  GraphInfoBuilder graph(builder_remote);

  // Every gemm in this graph is built with default-constructed attributes.
  const auto add_gemm = [&graph](OperandId a, OperandId b, OperandId result) {
    graph.BuildGemm(a, b, result, GemmTester::GemmAttributes());
  };

  // The graph output is created first; constants and inputs follow.
  const OperandId output = graph.BuildOutput("output", {2, 2}, kElementType);

  // Left branch: constant_a is created before input_a.
  const std::vector<float> weights = {5.0f, 6.0f, 7.0f, 8.0f};
  const OperandId constant_a = graph.BuildConstant(
      {2, 2}, kElementType,
      base::as_byte_span(base::allow_nonunique_obj, weights));
  const OperandId input_a = graph.BuildInput("input_a", {2, 2}, kElementType);
  const OperandId left = graph.BuildIntermediateOperand({2, 2}, kElementType);
  add_gemm(constant_a, input_a, left);

  // Right branch: input_b is created before constant_b, but the gemm still
  // consumes the constant as its first operand.
  const OperandId input_b = graph.BuildInput("input_b", {2, 2}, kElementType);
  const OperandId constant_b = graph.BuildConstant(
      {2, 2}, kElementType,
      base::as_byte_span(base::allow_nonunique_obj, weights));
  const OperandId right = graph.BuildIntermediateOperand({2, 2}, kElementType);
  add_gemm(constant_b, input_b, right);

  // Join both branches into the graph output.
  add_gemm(left, right, output);

  EXPECT_TRUE(graph.IsValidGraphForTesting(properties));
}

// Expects validation to fail for a graph whose first relu consumes an
// intermediate operand that is only produced by the relu added after it.
TEST_F(WebNNGraphImplTest, BuildOperationWithNonexistentInputs) {
  auto properties = GetContextPropertiesForTesting();

  mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote =
      BindNewGraphBuilderRemote();
  GraphInfoBuilder graph(builder_remote);

  const OperandId input =
      graph.BuildInput("input_a", {2, 2}, OperandDataType::kFloat32);
  const OperandId intermediate =
      graph.BuildIntermediateOperand({2, 2}, OperandDataType::kFloat32);
  // NOTE(review): the output is uint8 while the operand feeding it is
  // float32. This mismatch may by itself make the graph invalid, independent
  // of the operation ordering below -- confirm whether that is intended.
  const OperandId output =
      graph.BuildOutput("output", {2, 2}, OperandDataType::kUint8);

  // The consumer of `intermediate` is added before its producer.
  graph.BuildRelu(intermediate, output);
  graph.BuildRelu(input, intermediate);

  EXPECT_FALSE(graph.IsValidGraphForTesting(properties));
}

}  // namespace webnn