// Snapshot 910e62b5, created on January 15 (taken from the commit history).
// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include <stdint.h>

#include <cmath>
#include <concepts>
#include <type_traits>

#include "base/compiler_specific.h"
#include "base/containers/fixed_flat_set.h"
#include "base/containers/flat_map.h"
#include "base/notreached.h"
#include "base/strings/string_number_conversions.h"
#include "base/test/bind.h"
#include "base/test/run_until.h"
#include "base/test/scoped_feature_list.h"
#include "base/test/task_environment.h"
#include "base/test/test_future.h"
#include "build/build_config.h"
#include "mojo/public/cpp/base/big_buffer.h"
#include "mojo/public/cpp/bindings/associated_remote.h"
#include "mojo/public/cpp/bindings/remote.h"
#include "services/webnn/buildflags.h"
#include "services/webnn/public/cpp/webnn_types.h"
#include "services/webnn/public/mojom/features.mojom-features.h"
#include "services/webnn/public/mojom/webnn_context.mojom.h"
#include "services/webnn/public/mojom/webnn_context_provider.mojom.h"
#include "services/webnn/public/mojom/webnn_graph.mojom.h"
#include "services/webnn/public/mojom/webnn_graph_builder.mojom.h"
#include "services/webnn/public/mojom/webnn_tensor.mojom.h"
#include "services/webnn/webnn_context_impl.h"
#include "services/webnn/webnn_context_provider_impl.h"
#include "services/webnn/webnn_test_environment.h"
#include "services/webnn/webnn_test_utils.h"
#include "services/webnn/webnn_utils.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "third_party/blink/public/common/tokens/tokens.h"
#include "third_party/fp16/src/include/fp16.h"

#if BUILDFLAG(IS_WIN)
#include "base/containers/fixed_flat_map.h"
#include "services/webnn/dml/adapter.h"
#include "services/webnn/dml/command_queue.h"
#include "services/webnn/dml/command_recorder.h"
#include "services/webnn/dml/context_impl_dml.h"
#include "services/webnn/dml/graph_impl_dml.h"
#include "services/webnn/dml/test_base.h"
#include "services/webnn/dml/utils.h"
#include "third_party/microsoft_dxheaders/include/directml.h"

// Windows SDK headers should be included after DirectX headers.
#include <wrl.h>

#endif  // BUILDFLAG(IS_WIN)

#if BUILDFLAG(IS_MAC)
#include "base/mac/mac_util.h"
#endif  // BUILDFLAG(IS_MAC)

namespace webnn::test {

namespace {

// Thin wrapper around the raw 16-bit storage of a half-precision float. The
// helpers in this file fill and read `data` with the fp16_ieee_* conversion
// functions.
// TODO(crbug.com/373443096): Consolidate with the other Float16 types declared
// elsewhere.
struct Float16 {
  // Raw 16-bit value as produced by fp16_ieee_from_fp32_value().
  uint16_t data;
};

// Bundles the two things CreateTensor() gets back for a newly created tensor.
struct TensorRemoteAndHandle {
  // Endpoint used to call WriteTensor()/ReadTensor() on the tensor.
  mojo::AssociatedRemote<mojom::WebNNTensor> remote;
  // Token that identifies the tensor when it is passed to Dispatch().
  blink::WebNNTensorToken handle;
};

// Asks `context_remote` to create a tensor described by `tensor_info`, passing
// `mojo_base::BigBuffer(0)` as the initial data, and waits for the reply.
//
// Records a test failure if the reply is not a success or if the returned
// tensor endpoint cannot be bound. Returns the bound remote together with the
// tensor's handle.
TensorRemoteAndHandle CreateTensor(
    mojo::Remote<mojom::WebNNContext>& context_remote,
    mojom::TensorInfoPtr tensor_info) {
  base::test::TestFuture<mojom::CreateTensorResultPtr> future;
  context_remote->CreateTensor(std::move(tensor_info), mojo_base::BigBuffer(0),
                               future.GetCallback());
  mojom::CreateTensorResultPtr result = future.Take();
  EXPECT_TRUE(result->is_success());

  // Both the pending remote and the handle live in the success payload.
  auto& success = result->get_success();
  mojo::AssociatedRemote<mojom::WebNNTensor> tensor_remote;
  tensor_remote.Bind(std::move(success->tensor_remote));
  EXPECT_TRUE(tensor_remote.is_bound());

  return TensorRemoteAndHandle{.remote = std::move(tensor_remote),
                               .handle = success->tensor_handle};
}

// Creates a tensor via CreateTensor() and then writes `data` into it with
// WriteTensor(). Returns the same remote/handle pair CreateTensor() produced.
TensorRemoteAndHandle CreateTensorWithValues(
    mojo::Remote<mojom::WebNNContext>& context_remote,
    mojom::TensorInfoPtr tensor_info,
    base::span<const uint8_t> data) {
  TensorRemoteAndHandle tensor =
      CreateTensor(context_remote, std::move(tensor_info));
  tensor.remote->WriteTensor(mojo_base::BigBuffer(data));
  return tensor;
}

// Copies the contents of `big_buffer` into a vector of `T`.
//
// The result holds `big_buffer.size() / sizeof(T)` elements. Only the bytes
// that make up whole elements are copied, so a buffer whose size is not a
// multiple of `sizeof(T)` can never write past the end of the vector; any
// trailing partial element is ignored.
template <typename T>
std::vector<T> BigBufferToVector(const mojo_base::BigBuffer& big_buffer) {
  std::vector<T> data(big_buffer.size() / sizeof(T));
  const size_t byte_count = data.size() * sizeof(T);
  // Skip the copy for empty results: the pointers may be null in that case.
  if (byte_count > 0) {
    UNSAFE_TODO(memcpy(data.data(), big_buffer.data(), byte_count));
  }
  return data;
}

// Outcome that a caller of BuildAndCompute() expects from the run.
enum class BuildAndComputeExpectation {
  // The graph is expected to be created, so the run continues to dispatch.
  kSuccess,
  // Graph creation is expected to fail.
  kCreateGraphFailure,
};

// Builds `graph_info` through `graph_builder_remote`, dispatches the resulting
// graph once and reads back every graph output.
//
// One tensor is created on `context_remote` for each graph input (filled with
// the bytes of the `named_inputs` entry that has the same operand name) and
// one for each graph output.
//
// Returns a map from output operand name to the values read back from the
// corresponding output tensor, reinterpreted as `OutputDataType`. When
// `expectation` is kCreateGraphFailure, graph creation is expected to fail and
// an empty map is returned.
//
// Problems that would otherwise lead to dereferencing invalid state (an
// unnamed input/output operand, an input without an entry in `named_inputs`,
// an unexpected CreateGraph failure, a failed ReadTensor) are reported as
// non-fatal test failures. The first three return an empty map; a failed read
// stores an empty vector for that output.
template <typename InputDataType, typename OutputDataType = InputDataType>
[[nodiscard]] base::flat_map<std::string, std::vector<OutputDataType>>
BuildAndCompute(
    mojo::Remote<mojom::WebNNContext>& context_remote,
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> graph_builder_remote,
    mojom::GraphInfoPtr graph_info,
    base::flat_map<std::string, base::span<const InputDataType>> named_inputs,
    BuildAndComputeExpectation expectation =
        BuildAndComputeExpectation::kSuccess) {
  // Create input tensors.
  std::vector<std::pair<std::string, TensorRemoteAndHandle>>
      named_input_remotes_and_handles;
  named_input_remotes_and_handles.reserve(graph_info->input_operands.size());

  for (OperandId operand_id : graph_info->input_operands) {
    const mojom::Operand& operand =
        *graph_info->operands.at(operand_id.value());
    if (!operand.name.has_value()) {
      ADD_FAILURE() << "Graph input operand is missing a name.";
      return {};
    }

    auto it = named_inputs.find(*operand.name);
    if (it == named_inputs.end()) {
      ADD_FAILURE() << "No data was provided for graph input \""
                    << *operand.name << "\".";
      return {};
    }

    auto tensor_info = mojom::TensorInfo::New(
        operand.descriptor, MLTensorUsage{MLTensorUsageFlags::kWrite});
    base::span<const uint8_t> data;
    if constexpr (std::floating_point<InputDataType>) {
      // Floating point types do not have unique object representations, but
      // this code appears to be using a byte span to type-erase, which is fine.
      data = base::as_byte_span(base::allow_nonunique_obj, it->second);
    } else {
      data = base::as_byte_span(it->second);
    }
    named_input_remotes_and_handles.emplace_back(
        *operand.name,
        CreateTensorWithValues(context_remote, std::move(tensor_info), data));
  }

  // Create output tensors.
  std::vector<std::pair<std::string, TensorRemoteAndHandle>>
      named_output_remotes_and_handles;
  named_output_remotes_and_handles.reserve(graph_info->output_operands.size());

  for (OperandId operand_id : graph_info->output_operands) {
    const mojom::Operand& operand =
        *graph_info->operands.at(operand_id.value());
    if (!operand.name.has_value()) {
      ADD_FAILURE() << "Graph output operand is missing a name.";
      return {};
    }

    auto tensor_info = mojom::TensorInfo::New(
        operand.descriptor, MLTensorUsage{MLTensorUsageFlags::kRead});
    named_output_remotes_and_handles.emplace_back(
        *operand.name, CreateTensor(context_remote, std::move(tensor_info)));
  }

  // Build the graph and compare the outcome with `expectation`.
  base::test::TestFuture<
      base::expected<mojom::CreateGraphSuccessPtr, mojom::ErrorPtr>>
      create_graph_future;
  graph_builder_remote->CreateGraph(std::move(graph_info),
                                    create_graph_future.GetCallback());
  auto create_graph_result = create_graph_future.Take();

  switch (expectation) {
    case BuildAndComputeExpectation::kSuccess:
      if (!create_graph_result.has_value()) {
        // Bail out instead of calling value() on an error result.
        ADD_FAILURE() << create_graph_result.error()->message;
        return {};
      }
      break;
    case BuildAndComputeExpectation::kCreateGraphFailure:
      EXPECT_FALSE(create_graph_result.has_value());
      return {};
  }

  mojo::AssociatedRemote<mojom::WebNNGraph> graph_remote;
  graph_remote.Bind(std::move(create_graph_result.value()->graph_remote));

  // Dispatch() takes (name, handle) pairs rather than the remotes themselves.
  std::vector<std::pair<std::string, blink::WebNNTensorToken>>
      named_input_handles;
  named_input_handles.reserve(named_input_remotes_and_handles.size());
  std::ranges::transform(
      named_input_remotes_and_handles, std::back_inserter(named_input_handles),
      [](const auto& input) {
        return std::make_pair(input.first, input.second.handle);
      });

  std::vector<std::pair<std::string, blink::WebNNTensorToken>>
      named_output_handles;
  named_output_handles.reserve(named_output_remotes_and_handles.size());
  std::ranges::transform(
      named_output_remotes_and_handles,
      std::back_inserter(named_output_handles), [](const auto& output) {
        return std::make_pair(output.first, output.second.handle);
      });

  // The GraphImpl should compute successfully.
  graph_remote->Dispatch(named_input_handles, named_output_handles);

  // Read back the results from the output buffers.
  std::vector<std::pair<std::string, std::vector<OutputDataType>>>
      named_output_results;
  named_output_results.reserve(named_output_remotes_and_handles.size());
  for (auto& output : named_output_remotes_and_handles) {
    base::test::TestFuture<mojom::ReadTensorResultPtr> read_tensor_future;
    output.second.remote->ReadTensor(read_tensor_future.GetCallback());
    mojom::ReadTensorResultPtr result = read_tensor_future.Take();
    if (result->is_error()) {
      // Do not call get_buffer() on an error result; store an empty output so
      // the caller's comparison fails cleanly.
      ADD_FAILURE() << "ReadTensor failed for graph output \"" << output.first
                    << "\".";
      named_output_results.emplace_back(output.first,
                                        std::vector<OutputDataType>());
      continue;
    }
    named_output_results.emplace_back(
        output.first, BigBufferToVector<OutputDataType>(result->get_buffer()));
  }

  EXPECT_EQ(expectation, BuildAndComputeExpectation::kSuccess);

  return base::flat_map<std::string, std::vector<OutputDataType>>(
      std::move(named_output_results));
}

// Expects `data` and `expected_data` to match element by element, allowing an
// absolute difference of 1e-5 per element.
void VerifyFloatDataIsEqual(base::span<const float> data,
                            base::span<const float> expected_data) {
  constexpr float kAbsoluteTolerance = 1e-5;
  const auto near_expected = testing::FloatNear(kAbsoluteTolerance);
  EXPECT_THAT(data, testing::Pointwise(near_expected, expected_data));
}

// Convert a vector of 32-bit floating-point data to a vector of 16-bit
// floating-point data, both in IEEE precision format.
std::vector<Float16> Float16FromFloat32(const std::vector<float>& fp32_data) {
  std::vector<Float16> fp16_data;
  fp16_data.reserve(fp32_data.size());

  for (size_t i = 0; i < fp32_data.size(); i++) {
    fp16_data.push_back(
        Float16{.data = fp16_ieee_from_fp32_value(fp32_data[i])});
  }

  return fp16_data;
}

// Convert a vector of 16-bit floating-point data to a vector of 32-bit
// floating-point data, both in IEEE precision format.
std::vector<float> Float16ToFloat32(const std::vector<Float16>& fp16_data) {
  std::vector<float> fp32_data;
  fp32_data.reserve(fp16_data.size());

  for (size_t i = 0; i < fp16_data.size(); i++) {
    fp32_data.push_back(fp16_ieee_to_fp32_value(fp16_data[i].data));
  }

  return fp32_data;
}

// Describes one operand used by the testers in this file: its data type, its
// dimensions and its element values (test input data or expected output).
template <typename T>
struct OperandInfo {
  OperandDataType type;
  std::vector<uint32_t> dimensions;
  std::vector<T> values;
#if BUILDFLAG(IS_MAC)
  // Returns a copy of this operand typed as kInt32. Each value is converted to
  // int32_t by the vector's iterator-range constructor. The method is const so
  // that it can also be called on const operands.
  OperandInfo<int32_t> ToInt32() const {
    return OperandInfo<int32_t>{
        .type = OperandDataType::kInt32,
        .dimensions = dimensions,
        .values = std::vector<int32_t>(values.begin(), values.end())};
  }
#endif  // BUILDFLAG(IS_MAC)
};

// Float overload: compares `actual` against `expected.values` through
// VerifyFloatDataIsEqual(), i.e. with a per-element tolerance rather than
// exact equality.
void VerifyIsEqual(base::span<const float> actual,
                   const OperandInfo<float>& expected) {
  VerifyFloatDataIsEqual(actual, expected.values);
}

// Generic overload: expects `actual` to compare equal to `expected.values`
// using EXPECT_EQ (exact comparison).
template <typename T>
void VerifyIsEqual(base::span<const T> actual, const OperandInfo<T>& expected) {
  EXPECT_EQ(actual, expected.values);
}

}  // namespace

#if BUILDFLAG(IS_WIN)
// Test fixture used when building for Windows. It derives from dml::TestBase
// and owns the context provider and context remotes that the tests talk to.
class WebNNGraphImplBackendTest : public dml::TestBase {
 public:
  // Enables the kWebMachineLearningNeuralNetwork feature for the lifetime of
  // the fixture.
  WebNNGraphImplBackendTest()
      : scoped_feature_list_(
            webnn::mojom::features::kWebMachineLearningNeuralNetwork) {}

  // Platform-specific checks that may skip the test; calls SetUpBase() when
  // the test should run.
  void SetUp() override;
  // Binds the context provider and creates the WebNN context.
  void SetUpBase();
  void TearDown() override;

  // Returns a new graph builder endpoint created on the fixture's context.
  mojo::AssociatedRemote<mojom::WebNNGraphBuilder> BindNewGraphBuilderRemote();

  // The context created by SetUpBase().
  mojo::Remote<mojom::WebNNContext>& context() { return webnn_context_; }

 protected:
  base::test::ScopedFeatureList scoped_feature_list_;
  // GPU adapter obtained in SetUp(); used for the feature-level checks there.
  scoped_refptr<dml::Adapter> adapter_;

  WebNNTestEnvironment webnn_test_environment_;
  mojo::Remote<mojom::WebNNContextProvider> provider_remote_;
  mojo::Remote<mojom::WebNNContext> webnn_context_;
};

void WebNNGraphImplBackendTest::SetUp() {
  SKIP_TEST_IF(!dml::UseGPUInTests());

  dml::Adapter::EnableDebugLayerForTesting();
  auto adapter_creation_result = dml::Adapter::GetGpuInstanceForTesting();
  // If the adapter creation result has no value, it's most likely because
  // platform functions were not properly loaded.
  SKIP_TEST_IF(!adapter_creation_result.has_value());
  adapter_ = adapter_creation_result.value();
  // Graph compilation relies on IDMLDevice1::CompileGraph introduced in
  // DirectML version 1.2 or DML_FEATURE_LEVEL_2_1, so skip the tests if the
  // DirectML version doesn't support this feature.
  SKIP_TEST_IF(!adapter_->IsDMLDeviceCompileGraphSupportedForTesting());

  // Skip a test if the required feature level is not supported for the
  // operator being tested.
  auto kRequiredFeatureLevels = base::MakeFixedFlatMap<std::string_view,
                                                       DML_FEATURE_LEVEL>(
      {// DML_BATCHNORMALIZATION_OPERATOR_DESC support for 1~8 dimension counts
       // was introduced in DML_FEATURE_LEVEL_3_1.
       {"FuseStandaloneActivationIntoBatchNormalization",
        DML_FEATURE_LEVEL_3_1},
       // DML_GEMM_OPERATOR_DESC support for 2 dimensions was introduced in
       // DML_FEATURE_LEVEL_4_0.
       {"FuseStandaloneActivationIntoGemm", DML_FEATURE_LEVEL_4_0},
       // DML_GEMM_OPERATOR_DESC support for 2 dimensions was introduced in
       // DML_FEATURE_LEVEL_4_0.
       {"BuildAndComputeMultipleOperatorGemm", DML_FEATURE_LEVEL_4_0},
       // DML_GEMM_OPERATOR_DESC support for 2 dimensions was introduced in
       // DML_FEATURE_LEVEL_4_0.
       {"BuildOneInputAndOneConstantOperand", DML_FEATURE_LEVEL_4_0},
       // DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_DESC support for 1~8
       // dimension
       // counts was introduced in DML_FEATURE_LEVEL_3_1.
       {"BuildSingleOperatorLayerNormalization", DML_FEATURE_LEVEL_3_1},
       // DML_GEMM_OPERATOR_DESC support for 2~4 dimensions was introduced in
       // DML_FEATURE_LEVEL_4_0.
       {"FuseStandaloneOperationsIntoMatmul", DML_FEATURE_LEVEL_4_0},
       // DML_GEMM_OPERATOR_DESC support for 2 dimensions was introduced in
       // DML_FEATURE_LEVEL_4_0.
       {"BuildMultipleInputsAppendingConstants", DML_FEATURE_LEVEL_4_0},
       // DML_GEMM_OPERATOR_DESC support for 2 dimensions was introduced in
       // DML_FEATURE_LEVEL_4_0.
       {"BuildMultipleConstantsAppendingInputs", DML_FEATURE_LEVEL_4_0},
       // DML_GEMM_OPERATOR_DESC support for 2 dimensions was introduced in
       // DML_FEATURE_LEVEL_4_0.
       {"BuildGemmWithReshapedConstantOperand", DML_FEATURE_LEVEL_4_0},
       // DML_GEMM_OPERATOR_DESC support for 2 dimensions was introduced in
       // DML_FEATURE_LEVEL_4_0.
       {"BuildMaxPoolingAsThirdOperator", DML_FEATURE_LEVEL_4_0},
       // DML_GEMM_OPERATOR_DESC support for 2 dimensions was introduced in
       // DML_FEATURE_LEVEL_4_0.
       {"BuildMaxPoolingAsSecondOperator", DML_FEATURE_LEVEL_4_0},
       // DML_GEMM_OPERATOR_DESC support for 2 dimensions was introduced in
       // DML_FEATURE_LEVEL_4_0.
       {"BuildMaxPoolingAsFirstOperator", DML_FEATURE_LEVEL_4_0}});
  auto it = kRequiredFeatureLevels.find(
      ::testing::UnitTest::GetInstance()->current_test_info()->name());
  if (it != kRequiredFeatureLevels.end()) {
    const auto& required_feature_level = it->second;
    SKIP_TEST_IF(!adapter_->IsDMLFeatureLevelSupported(required_feature_level));
  }

  SetUpBase();
}
#endif  // BUILDFLAG(IS_WIN)

#if BUILDFLAG(IS_MAC)
// Test fixture used when building for Mac. It owns the task environment and
// the context provider and context remotes that the tests talk to.
class WebNNGraphImplBackendTest : public testing::Test {
 public:
  // Enables the kWebMachineLearningNeuralNetwork feature for the lifetime of
  // the fixture.
  WebNNGraphImplBackendTest()
      : scoped_feature_list_(
            webnn::mojom::features::kWebMachineLearningNeuralNetwork) {}

  // Platform-specific checks that may skip the test; calls SetUpBase() when
  // the test should run.
  void SetUp() override;
  // Binds the context provider and creates the WebNN context.
  void SetUpBase();
  void TearDown() override;

  // Returns a new graph builder endpoint created on the fixture's context.
  mojo::AssociatedRemote<mojom::WebNNGraphBuilder> BindNewGraphBuilderRemote();

  // The context created by SetUpBase().
  mojo::Remote<mojom::WebNNContext>& context() { return webnn_context_; }

 protected:
  base::test::ScopedFeatureList scoped_feature_list_;
  base::test::TaskEnvironment task_environment_;

  WebNNTestEnvironment webnn_test_environment_;
  mojo::Remote<mojom::WebNNContextProvider> provider_remote_;
  mojo::Remote<mojom::WebNNContext> webnn_context_;
};

// Skips the current test when the macOS version is below 14.0 or when the test
// is not in the list of tests supported on this platform. Otherwise performs
// the common setup in SetUpBase().
void WebNNGraphImplBackendTest::SetUp() {
  constexpr int kMinimumMacOSVersion = 14'00'00;
  const auto macos_version = base::mac::MacOSVersion();
  if (macos_version < kMinimumMacOSVersion) {
    GTEST_SKIP() << "Skipping test because WebNN is not supported on Mac OS "
                 << macos_version;
  }

  // Keep this list sorted by the operator being tested.
  static const auto kSupportedTests = base::MakeFixedFlatSet<std::string_view>({
      "BuildAndComputeSingleOperatorClamp",
      "BuildAndComputeConcatWithConstants",
      "BuildAndComputeSingleOperatorRelu",
      "BuildAndComputeSingleOperatorTanh",
      "BuildAndComputeGraphWithTwoTranspose",
  });

  const std::string_view current_test_name =
      ::testing::UnitTest::GetInstance()->current_test_info()->name();
  if (kSupportedTests.contains(current_test_name)) {
    SetUpBase();
    return;
  }
  GTEST_SKIP() << "Skipping test because the operator is not yet supported.";
}
#endif  // BUILDFLAG(IS_MAC)

// TODO(crbug.com/325612086): Parameterize these tests for different backends.
#if BUILDFLAG(WEBNN_USE_TFLITE) && !BUILDFLAG(IS_MAC) && !BUILDFLAG(IS_WIN)
// Test fixture used for WEBNN_USE_TFLITE builds that are neither Mac nor
// Windows. It owns the task environment and the context provider and context
// remotes that the tests talk to.
class WebNNGraphImplBackendTest : public testing::Test {
 public:
  // Enables the kWebMachineLearningNeuralNetwork feature for the lifetime of
  // the fixture.
  WebNNGraphImplBackendTest()
      : scoped_feature_list_(
            webnn::mojom::features::kWebMachineLearningNeuralNetwork) {}

  // Platform-specific checks that may skip the test; calls SetUpBase() when
  // the test should run.
  void SetUp() override;
  // Binds the context provider and creates the WebNN context.
  void SetUpBase();
  void TearDown() override;

  // Returns a new graph builder endpoint created on the fixture's context.
  mojo::AssociatedRemote<mojom::WebNNGraphBuilder> BindNewGraphBuilderRemote();

  // The context created by SetUpBase().
  mojo::Remote<mojom::WebNNContext>& context() { return webnn_context_; }

 protected:
  base::test::ScopedFeatureList scoped_feature_list_;
  base::test::TaskEnvironment task_environment_;

  WebNNTestEnvironment webnn_test_environment_;
  mojo::Remote<mojom::WebNNContextProvider> provider_remote_;
  mojo::Remote<mojom::WebNNContext> webnn_context_;
};

// Skips the current test unless its name is in the list of tests supported on
// this platform; otherwise performs the common setup in SetUpBase().
void WebNNGraphImplBackendTest::SetUp() {
  // TODO: https://crbug.com/394119734 - Enable the commented-out tests after
  // fixing the bugs in the GPU delegate causing them to fail.
  static const auto kSupportedTests = base::MakeFixedFlatSet<std::string_view>({
      "BuildAddWithReshapedConstantOperand",
      // "BuildAndComputeAddAndMulWithOnlyConstantInputs",
      // "BuildAndComputeAddWithOnlyConstantInputs",
      "BuildAndComputeConcatWithConstants",
      "BuildAndComputeGraphWithReshapeAsIntermediateNode",
      "BuildAndComputeGraphWithReshapeAsLastNode",
      "BuildAndComputeGraphWithSplitAndReshape",
      "BuildAndComputeGraphWithTransposeAndRelu",
      "BuildAndComputeGraphWithTransposeAndTwoOutputs",
      "BuildAndComputeGraphWithTransposeAndTwoReshape",
      "BuildAndComputeGraphWithTwoOutputs",
      "BuildAndComputeGraphWithTwoRelu",
      "BuildAndComputeGraphWithTwoReshape",
      "BuildAndComputeGraphWithTwoTranspose",
      "BuildAndComputeMultipleOperatorGemm",
      // "BuildAndComputeReluWithOnlyConstantInput",
      "BuildAndComputeReshapeConcatAndClamp",
      "BuildAndComputeSingleOperatorClamp",
      "BuildAndComputeSingleOperatorGruCell",
      "BuildAndComputeSingleOperatorGru",
      "BuildAndComputeSingleOperatorHardSigmoid",
      "BuildAndComputeSingleOperatorHardSwish",
      // "BuildAndComputeSingleOperatorLstmCell",
      // "BuildAndComputeSingleOperatorLstm",
      // "BuildAndComputeSingleOperatorResample2d",
      "BuildAndComputeSingleOperatorTanh",
      "BuildGemmWithReshapedConstantOperand",
      "BuildMaxPoolingAsFirstOperator",
      "BuildMaxPoolingAsSecondOperator",
      "BuildMaxPoolingAsThirdOperator",
      "BuildMultipleConstantsAppendingInputs",
      "BuildMultipleInputsAppendingConstants",
      "BuildSingleOperatorLayerNormalization",
      "BuildOneInputAndOneConstantOperand",
      // "FuseStandaloneActivationIntoBatchNormalization",
      // "FuseStandaloneActivationIntoConv2d",
      "FuseStandaloneActivationIntoElementWiseBinaryAdd",
      "FuseStandaloneActivationIntoGemm",
      // "FuseStandaloneActivationIntoInstanceNormalization",
      "FuseStandaloneActivationIntoLayerNormalization",
      "FuseStandaloneOperationsIntoMatmul",
      // "MultipleOutputsCanNotFuseStandaloneActivation",
  });

  const std::string_view current_test_name =
      ::testing::UnitTest::GetInstance()->current_test_info()->name();
  if (kSupportedTests.contains(current_test_name)) {
    SetUpBase();
    return;
  }
  GTEST_SKIP() << "Skipping test because the operator is not yet supported.";
}
#endif  // BUILDFLAG(WEBNN_USE_TFLITE) && !BUILDFLAG(IS_MAC) &&
        // !BUILDFLAG(IS_WIN)

void WebNNGraphImplBackendTest::SetUpBase() {
  webnn_test_environment_.BindWebNNContextProvider(
      provider_remote_.BindNewPipeAndPassReceiver());

  // Create the ContextImpl through context provider.
  base::test::TestFuture<mojom::CreateContextResultPtr> create_context_future;
  provider_remote_->CreateWebNNContext(
      mojom::CreateContextOptions::New(
          mojom::Device::kGpu,
          mojom::CreateContextOptions::PowerPreference::kDefault),
      create_context_future.GetCallback());
  mojom::CreateContextResultPtr create_context_result =
      create_context_future.Take();
  if (create_context_result->is_success()) {
    webnn_context_.Bind(
        std::move(create_context_result->get_success()->context_remote));
  }
  EXPECT_FALSE(create_context_result->is_error())
      << create_context_result->get_error()->message;
  EXPECT_TRUE(webnn_context_.is_bound());
}

// Releases the context first, runs the RunUntil() helper once, and only then
// releases the provider.
void WebNNGraphImplBackendTest::TearDown() {
  webnn_context_.reset();
  // Give WebNNContext a chance to run disconnect. The condition is already
  // satisfied, so this only waits for RunUntil() itself to return.
  EXPECT_TRUE(base::test::RunUntil([] { return true; }));
  provider_remote_.reset();
}

// Creates a graph builder on the fixture's context and returns the remote end
// of the newly bound endpoint.
mojo::AssociatedRemote<mojom::WebNNGraphBuilder>
WebNNGraphImplBackendTest::BindNewGraphBuilderRemote() {
  mojo::AssociatedRemote<mojom::WebNNGraphBuilder> graph_builder;
  auto receiver = graph_builder.BindNewEndpointAndPassReceiver();
  webnn_context_->CreateGraphBuilder(std::move(receiver));
  return graph_builder;
}

// Describes the standalone operation that BuildFusibleOperation() appends
// after the operator under test.
struct FusibleOperationDescriptor {
  // Which operation to build.
  mojom::Operation::Tag kind;
  // Required (CHECKed by BuildFusibleOperation()) for kElu, kHardSigmoid,
  // kLeakyRelu and kLinear; not read for the other kinds.
  std::optional<float> alpha;
  // Required (CHECKed by BuildFusibleOperation()) for kHardSigmoid and
  // kLinear; not read for the other kinds.
  std::optional<float> beta;
};

// Appends the operation described by `operation` to `builder`, reading from
// `input_operand_id` and writing to `output_operand_id`.
//
// CHECKs that `alpha` (and `beta`, where it is used) is set for the kinds
// that take them. Any kind that is not handled below hits NOTREACHED().
void BuildFusibleOperation(GraphInfoBuilder& builder,
                           const FusibleOperationDescriptor& operation,
                           OperandId input_operand_id,
                           OperandId output_operand_id) {
  using Tag = mojom::Operation::Tag;

  switch (operation.kind) {
    // Operations parameterized by alpha only.
    case Tag::kElu:
      CHECK(operation.alpha.has_value());
      builder.BuildElu(input_operand_id, output_operand_id, *operation.alpha);
      break;
    case Tag::kLeakyRelu:
      CHECK(operation.alpha.has_value());
      builder.BuildLeakyRelu(input_operand_id, output_operand_id,
                             *operation.alpha);
      break;

    // Operations parameterized by both alpha and beta.
    case Tag::kHardSigmoid:
      CHECK(operation.alpha.has_value());
      CHECK(operation.beta.has_value());
      builder.BuildHardSigmoid(input_operand_id, output_operand_id,
                               *operation.alpha, *operation.beta);
      break;
    case Tag::kLinear:
      CHECK(operation.alpha.has_value());
      CHECK(operation.beta.has_value());
      builder.BuildLinear(input_operand_id, output_operand_id, *operation.alpha,
                          *operation.beta);
      break;

    // Operations without parameters.
    case Tag::kRelu:
      builder.BuildRelu(input_operand_id, output_operand_id);
      break;
    case Tag::kSigmoid:
      builder.BuildSigmoid(input_operand_id, output_operand_id);
      break;
    case Tag::kSoftplus:
      builder.BuildSoftplus(input_operand_id, output_operand_id);
      break;
    case Tag::kSoftsign:
      builder.BuildSoftsign(input_operand_id, output_operand_id);
      break;
    case Tag::kTanh:
      builder.BuildTanh(input_operand_id, output_operand_id);
      break;

    default:
      // TODO(crbug.com/345640552): Support fusing gelu.
      NOTREACHED();
  }
}

template <typename T>
struct BatchNormalizationTester {
  OperandInfo<T> input;
  OperandInfo<T> mean;
  OperandInfo<T> variance;
  std::optional<OperandInfo<T>> scale;
  std::optional<OperandInfo<T>> bias;
  struct BatchNormalizationAttributes {
    std::optional<OperandId> scale_operand_id;
    std::optional<OperandId> bias_operand_id;
    uint32_t axis = 1;
    float epsilon = 1e-5;
  };
  BatchNormalizationAttributes attributes;
  OperandInfo<T> output;

  void TestFusingOperation(
      WebNNGraphImplBackendTest& test,
      const FusibleOperationDescriptor& fusible_operation) {
    // Build the graph with mojo type.
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        test.BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", input.dimensions, input.type);
    OperandId mean_operand_id =
        builder.BuildInput("mean", mean.dimensions, mean.type);
    OperandId variance_operand_id =
        builder.BuildInput("variance", variance.dimensions, variance.type);
    OperandId intermediate_operand_id =
        builder.BuildIntermediateOperand(output.dimensions, output.type);
    if (scale.has_value()) {
      attributes.scale_operand_id =
          builder.BuildInput("scale", scale->dimensions, scale->type);
    }
    if (bias.has_value()) {
      attributes.bias_operand_id =
          builder.BuildInput("bias", bias->dimensions, bias->type);
    }

    builder.BuildBatchNormalization(
        input_operand_id, mean_operand_id, variance_operand_id,
        intermediate_operand_id, std::move(attributes));

    OperandId output_operand_id =
        builder.BuildOutput("output", output.dimensions, output.type);
    BuildFusibleOperation(builder, fusible_operation, intermediate_operand_id,
                          output_operand_id);

    base::flat_map<std::string, base::span<const T>> named_inputs;
    named_inputs.insert({"input", input.values});
    named_inputs.insert({"mean", mean.values});
    named_inputs.insert({"variance", variance.values});
    if (scale.has_value()) {
      named_inputs.insert({"scale", scale->values});
    }
    if (bias.has_value()) {
      named_inputs.insert({"bias", bias->values});
    }

    base::flat_map<std::string, std::vector<T>> named_outputs =
        BuildAndCompute(test.context(), std::move(remote),
                        builder.TakeGraphInfo(), std::move(named_inputs));

    VerifyIsEqual(named_outputs["output"], output);
  }
};

// Test building and computing a graph of fusing a standalone activation into
// batchNormalization automatically.
//
// Each scoped case below builds batchNormalization followed by one activation
// through BatchNormalizationTester::TestFusingOperation() and checks the final
// "output" values. Float results are compared with a small tolerance (see
// VerifyFloatDataIsEqual()).
TEST_F(WebNNGraphImplBackendTest,
       FuseStandaloneActivationIntoBatchNormalization) {
  {  // Test batchNormalization with 4-D input, default axis and activation =
    // linear.
    BatchNormalizationTester<float>{
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 2, 1, 3},
                  .values = {-1, 0, 1, 2, 3, 4}},
        .mean = {.type = OperandDataType::kFloat32,
                 .dimensions = {2},
                 .values = {0, 3}},
        .variance = {.type = OperandDataType::kFloat32,
                     .dimensions = {2},
                     .values = {1.0, 1.5}},
        .scale = OperandInfo<float>{.type = OperandDataType::kFloat32,
                                    .dimensions = {2},
                                    .values = {1.0, 1.5}},
        .bias = OperandInfo<float>{.type = OperandDataType::kFloat32,
                                   .dimensions = {2},
                                   .values = {0, 1}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 2, 1, 3},
                   .values = {-8.999950000374997, 1, 10.999950000374997,
                              -1.2474078892909666, 11, 23.24740788929097}}}
        .TestFusingOperation(*this, FusibleOperationDescriptor{
                                        .kind = mojom::Operation::Tag::kLinear,
                                        .alpha = 10,
                                        .beta = 1});
  }
  {
    // Test batchNormalization with 4-D input with activation = hardsigmoid.
    BatchNormalizationTester<float>{
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 2, 1, 3},
                  .values = {-1, 0, 1, 2, 3, 4}},
        .mean = {.type = OperandDataType::kFloat32,
                 .dimensions = {2},
                 .values = {0, 3}},
        .variance = {.type = OperandDataType::kFloat32,
                     .dimensions = {2},
                     .values = {1.0, 1.5}},
        .scale = OperandInfo<float>{.type = OperandDataType::kFloat32,
                                    .dimensions = {2},
                                    .values = {1.0, 1.5}},
        .bias = OperandInfo<float>{.type = OperandDataType::kFloat32,
                                   .dimensions = {2},
                                   .values = {0, 1}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 2, 1, 3},
                   .values = {1, 1, 1, 1, 1, 1}}}
        .TestFusingOperation(*this,
                             FusibleOperationDescriptor{
                                 .kind = mojom::Operation::Tag::kHardSigmoid,
                                 .alpha = 1,
                                 .beta = 3});
  }
  {
    // Test batchNormalization with 4-D input with activation = relu.
    BatchNormalizationTester<float>{
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 2, 1, 3},
                  .values = {-1, 0, 1, 2, 3, 4}},
        .mean = {.type = OperandDataType::kFloat32,
                 .dimensions = {2},
                 .values = {0, 3}},
        .variance = {.type = OperandDataType::kFloat32,
                     .dimensions = {2},
                     .values = {1.0, 1.5}},
        .scale = OperandInfo<float>{.type = OperandDataType::kFloat32,
                                    .dimensions = {2},
                                    .values = {1.0, 1.5}},
        .bias = OperandInfo<float>{.type = OperandDataType::kFloat32,
                                   .dimensions = {2},
                                   .values = {0, 1}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 2, 1, 3},
                   .values = {0, 0, 0.9999950000374997, 0, 1,
                              2.224740788929097}}}
        .TestFusingOperation(*this, FusibleOperationDescriptor{
                                        .kind = mojom::Operation::Tag::kRelu});
  }
  {
    // Test batchNormalization with 4-D input with activation = softplus.
    // This case overrides the default epsilon with 0.
    BatchNormalizationTester<float>{
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 2, 1, 3},
                  .values = {-100, -50, 100, 101, 102, 103}},
        .mean = {.type = OperandDataType::kFloat32,
                 .dimensions = {2},
                 .values = {0, 3}},
        .variance = {.type = OperandDataType::kFloat32,
                     .dimensions = {2},
                     .values = {1, 4}},
        .scale = OperandInfo<float>{.type = OperandDataType::kFloat32,
                                    .dimensions = {2},
                                    .values = {1, 2}},
        .bias = OperandInfo<float>{.type = OperandDataType::kFloat32,
                                   .dimensions = {2},
                                   .values = {0, 1}},
        .attributes = {.epsilon = 0},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 2, 1, 3},
                   .values = {0, 0, 100, 99, 100, 101}}}
        .TestFusingOperation(*this,
                             FusibleOperationDescriptor{
                                 .kind = mojom::Operation::Tag::kSoftplus});
  }
  {
    // Test batchNormalization with 1-D input with activation = softsign.
    // This case overrides the default axis (1) with 0.
    BatchNormalizationTester<float>{
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {2},
                  .values = {-1, 1}},
        .mean = {.type = OperandDataType::kFloat32,
                 .dimensions = {2},
                 .values = {-1, 1}},
        .variance = {.type = OperandDataType::kFloat32,
                     .dimensions = {2},
                     .values = {1.0, 1.5}},
        .scale = OperandInfo<float>{.type = OperandDataType::kFloat32,
                                    .dimensions = {2},
                                    .values = {1.0, 1.5}},
        .bias = OperandInfo<float>{.type = OperandDataType::kFloat32,
                                   .dimensions = {2},
                                   .values = {0, 1}},
        .attributes = {.axis = 0},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {2},
                   .values = {0, 0.5}}}
        .TestFusingOperation(*this,
                             FusibleOperationDescriptor{
                                 .kind = mojom::Operation::Tag::kSoftsign});
  }
}

template <typename T>
struct Conv2dTester {
  mojom::Conv2d::Kind type;
  OperandInfo<T> input;
  OperandInfo<T> filter;
  struct Conv2dAttributes {
    std::vector<uint32_t> padding = {0, 0, 0, 0};
    std::vector<uint32_t> strides = {1, 1};
    std::vector<uint32_t> dilations = {1, 1};
    uint32_t groups = 1;
    std::optional<OperandInfo<T>> bias;
  };
  Conv2dAttributes attributes;
  OperandInfo<float> output;

  void TestFusingOperation(
      WebNNGraphImplBackendTest& test,
      const FusibleOperationDescriptor& fusible_operation) {
    // Build the graph with mojo type.
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        test.BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", input.dimensions, input.type);
    OperandId filter_operand_id = builder.BuildConstant(
        filter.dimensions, filter.type,
        base::as_byte_span(base::allow_nonunique_obj, filter.values));
    OperandId conv2d_output_operand_id =
        builder.BuildIntermediateOperand(output.dimensions, output.type);

    std::optional<OperandId> bias_operand_id;
    if (attributes.bias.has_value()) {
      bias_operand_id = builder.BuildConstant(
          attributes.bias->dimensions, attributes.bias->type,
          base::as_byte_span(base::allow_nonunique_obj,
                             attributes.bias->values));
    }

    builder.BuildConv2d(type, input_operand_id, filter_operand_id,
                        conv2d_output_operand_id, std::move(attributes),
                        bias_operand_id);

    OperandId output_operand_id =
        builder.BuildOutput("output", output.dimensions, output.type);
    BuildFusibleOperation(builder, fusible_operation, conv2d_output_operand_id,
                          output_operand_id);

    base::flat_map<std::string, base::span<const T>> named_inputs;

    named_inputs.insert({"input", input.values});
    base::flat_map<std::string, std::vector<T>> named_outputs =
        BuildAndCompute(test.context(), std::move(remote),
                        builder.TakeGraphInfo(), std::move(named_inputs));

    VerifyIsEqual(named_outputs["output"], output);
  }
};

// Test building and computing a graph of fusing a standalone activation
// into conv2d automatically.
TEST_F(WebNNGraphImplBackendTest, FuseStandaloneActivationIntoConv2d) {
  // All cases below use a direct conv2d on float32 data and differ only in
  // shapes, attributes and the activation that follows the convolution.

  // elu (alpha = 0.8) after a 1x1 convolution with bias.
  {
    Conv2dTester<float> elu_case{
        .type = mojom::Conv2d::Kind::kDirect,
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 1, 3, 3},
                  .values = {0, 1, 2, 3, 4, 5, 6, 7, 8}},
        .filter = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 1, 1, 1},
                   .values = {1}},
        .attributes = {.bias =
                           OperandInfo<float>{.type = OperandDataType::kFloat32,
                                              .dimensions = {1},
                                              .values = {-5}}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 1, 3, 3},
                   .values = {-0.7946096424007316, -0.7853474888890126,
                              -0.7601703453057089, -0.6917317734107099,
                              -0.5056964470628461, 0, 1, 2, 3}}};
    const FusibleOperationDescriptor elu{.kind = mojom::Operation::Tag::kElu,
                                         .alpha = 0.8};
    elu_case.TestFusingOperation(*this, elu);
  }

  // leakyRelu (alpha = 0.02) after a 3x3 convolution with bias.
  {
    Conv2dTester<float> leaky_relu_case{
        .type = mojom::Conv2d::Kind::kDirect,
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 1, 4, 4},
                  .values = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
                             15}},
        .filter = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 1, 3, 3},
                   .values = {1, 1, 1, 1, 1, 1, 1, 1, 1}},
        .attributes = {.bias =
                           OperandInfo<float>{.type = OperandDataType::kFloat32,
                                              .dimensions = {1},
                                              .values = {-60}}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 1, 2, 2},
                   .values = {-0.3, -0.12, 21, 30}}};
    const FusibleOperationDescriptor leaky_relu{
        .kind = mojom::Operation::Tag::kLeakyRelu, .alpha = 0.02};
    leaky_relu_case.TestFusingOperation(*this, leaky_relu);
  }

  // linear (alpha = 0.01, beta = 1) after a padded 3x3 convolution with bias.
  {
    Conv2dTester<float> linear_case{
        .type = mojom::Conv2d::Kind::kDirect,
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 1, 5, 5},
                  .values = {0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12,
                             13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}},
        .filter = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 1, 3, 3},
                   .values = {1, 1, 1, 1, 1, 1, 1, 1, 1}},
        .attributes = {.padding = {1, 1, 1, 1},
                       .bias =
                           OperandInfo<float>{.type = OperandDataType::kFloat32,
                                              .dimensions = {1},
                                              .values = {1}}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 1, 5, 5},
                   .values = {1.13, 1.22, 1.28, 1.34, 1.25, 1.34, 1.55,
                              1.64, 1.73, 1.52, 1.64, 2,    2.09, 2.18,
                              1.82, 1.94, 2.45, 2.54, 2.63, 2.12, 1.73,
                              2.12, 2.18, 2.24, 1.85}}};
    const FusibleOperationDescriptor linear{
        .kind = mojom::Operation::Tag::kLinear, .alpha = 0.01, .beta = 1};
    linear_case.TestFusingOperation(*this, linear);
  }

  // hardSigmoid (alpha = 0.01, beta = -1) after a padded 3x3 convolution with
  // bias.
  {
    Conv2dTester<float> hard_sigmoid_case{
        .type = mojom::Conv2d::Kind::kDirect,
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 1, 5, 5},
                  .values = {0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12,
                             13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}},
        .filter = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 1, 3, 3},
                   .values = {1, 1, 1, 1, 1, 1, 1, 1, 1}},
        .attributes = {.padding = {1, 1, 1, 1},
                       .bias =
                           OperandInfo<float>{.type = OperandDataType::kFloat32,
                                              .dimensions = {1},
                                              .values = {1}}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 1, 5, 5},
                   .values = {0,    0,    0, 0,    0,    0,    0, 0,    0,
                              0,    0,    0, 0.09, 0.18, 0,    0, 0.45, 0.54,
                              0.63, 0.12, 0, 0.12, 0.18, 0.24, 0}}};
    const FusibleOperationDescriptor hard_sigmoid{
        .kind = mojom::Operation::Tag::kHardSigmoid,
        .alpha = 0.01,
        .beta = -1};
    hard_sigmoid_case.TestFusingOperation(*this, hard_sigmoid);
  }

  // sigmoid after a 2x2 convolution with three output channels and a batch of
  // two.
  {
    Conv2dTester<float> sigmoid_case{
        .type = mojom::Conv2d::Kind::kDirect,
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {2, 1, 3, 3},
                  .values = {0.7529087201709872, 0.7520291960017611,
                             0.594952773514815, 0.21631854011984264,
                             0.07589348976741683, 0.15106785419828572,
                             0.12124850358598671, 0.5364335407319905,
                             0.5937089927693522, 0.9910031422560608,
                             0.36309423611370084, 0.9289673923363004,
                             0.22727376737331384, 0.5414123970044269,
                             0.0844534212564596, 0.6765284772046276,
                             0.619325655574763, 0.39292160755260475}},
        .filter = {.type = OperandDataType::kFloat32,
                   .dimensions = {3, 1, 2, 2},
                   .values = {0.14543837927656278, 0.9671129790291346,
                              0.10836050336762582, 0.320230810822804,
                              0.6952692250382182, 0.5070913293589028,
                              0.0813970738017622, 0.5303338853508432,
                              0.30721364807734, 0.4324123448833208,
                              0.9849002194630809, 0.4281076188358701}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {2, 3, 2, 2},
                   .values = {0.7077627182006836, 0.6772933602333069,
                              0.5719422101974487, 0.5999819040298462,
                              0.7236577272415161, 0.7131744623184204,
                              0.618513286113739,  0.6196115612983704,
                              0.690409243106842,  0.6519721746444702,
                              0.6102449893951416, 0.704983651638031,
                              0.6666978597640991, 0.7382584810256958,
                              0.6959947943687439, 0.5874307155609131,
                              0.7647256255149841, 0.6926159262657166,
                              0.6934033632278442, 0.6633020043373108,
                              0.7144469618797302, 0.7469926476478577,
                              0.7747598886489868, 0.7273134589195251}}};
    const FusibleOperationDescriptor sigmoid{
        .kind = mojom::Operation::Tag::kSigmoid};
    sigmoid_case.TestFusingOperation(*this, sigmoid);
  }

  // softplus after a 1x1 convolution. No bias is set for this case.
  {
    Conv2dTester<float> softplus_case{
        .type = mojom::Conv2d::Kind::kDirect,
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 1, 2, 2},
                  .values = {40, 48, 56, 64}},
        .filter = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 1, 1, 1},
                   .values = {1}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 1, 2, 2},
                   .values = {40, 48, 56, 64}}};
    const FusibleOperationDescriptor softplus{
        .kind = mojom::Operation::Tag::kSoftplus};
    softplus_case.TestFusingOperation(*this, softplus);
  }

  // softsign after a 2x2 convolution.
  {
    Conv2dTester<float> softsign_case{
        .type = mojom::Conv2d::Kind::kDirect,
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 1, 3, 3},
                  .values = {-3, -2, -1, -4, 0, 2, 1, 3, 4}},
        .filter = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 1, 2, 2},
                   .values = {1, 1, 1, 1}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 1, 2, 2},
                   .values = {-0.9, -0.5, 0, 0.9}}};
    const FusibleOperationDescriptor softsign{
        .kind = mojom::Operation::Tag::kSoftsign};
    softsign_case.TestFusingOperation(*this, softsign);
  }

  // tanh after a padded 3x3 convolution.
  {
    Conv2dTester<float> tanh_case{
        .type = mojom::Conv2d::Kind::kDirect,
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {1, 1, 5, 5},
                  .values = {0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12,
                             13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}},
        .filter = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 1, 3, 3},
                   .values = {0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05,
                              0.05}},
        .attributes = {.padding = {1, 1, 1, 1}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 1, 5, 5},
                   .values = {0.5370495669980353, 0.7818063576087741,
                              0.874053287886007,  0.9288576214547277,
                              0.8336546070121552, 0.9288576214547277,
                              0.9910074536781176, 0.9963341221150144,
                              0.9985079423323266, 0.9878803970168317,
                              0.9963341221150144, 0.9998996556706324,
                              0.9999592018254402, 0.9999834124992523,
                              0.9993931059399421, 0.9998171682522957,
                              0.9999988852198828, 0.9999995467640772,
                              0.9999998157280003, 0.999969775809118,
                              0.9985079423323266, 0.999969775809118,
                              0.9999834124992523, 0.9999908965525104,
                              0.9995503664595334}}};
    const FusibleOperationDescriptor tanh_activation{
        .kind = mojom::Operation::Tag::kTanh};
    tanh_case.TestFusingOperation(*this, tanh_activation);
  }
}

// I is the type of the inputs, both of which must be the same.
// O is the type of the output, which by default is the same as the input.
// Logical operators, however, have uint8_t (bool) as outputs.
template <typename I, typename O = I>
struct ElementWiseBinaryTester {
  OperandInfo<I> lhs;
  OperandInfo<I> rhs;
  mojom::ElementWiseBinary::Kind kind;
  OperandInfo<O> output;
  void Test(WebNNGraphImplBackendTest& helper) {
    // Build the graph with mojo type.
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        helper.BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId lhs_operand_id =
        builder.BuildInput("lhs", lhs.dimensions, lhs.type);
    OperandId rhs_operand_id =
        builder.BuildInput("rhs", rhs.dimensions, rhs.type);
    auto graph_output_type = output.type;
#if BUILDFLAG(IS_MAC)
    if (output.type == OperandDataType::kUint8) {
      // macOS only supports FP16,FP32,DOUBLE,INT32 as outputs of graph.
      // For testing, we cast the output of the element-wise logical
      // operators to Int32 and set the graph output to Int32.
      graph_output_type = OperandDataType::kInt32;
    }
#endif  // BUILDFLAG(IS_MAC)
    OperandId output_operand_id =
        builder.BuildOutput("output", output.dimensions, graph_output_type);
    OperandId element_wise_binary_output_operand_id = output_operand_id;
#if BUILDFLAG(IS_MAC)
    if (output.type == OperandDataType::kUint8) {
      element_wise_binary_output_operand_id = builder.BuildIntermediateOperand(
          output.dimensions, OperandDataType::kUint8);
    }
#endif  // BUILDFLAG(IS_MAC)
    builder.BuildElementWiseBinary(kind, lhs_operand_id, rhs_operand_id,
                                   element_wise_binary_output_operand_id);
#if BUILDFLAG(IS_MAC)
    if (output.type == OperandDataType::kUint8) {
      builder.BuildElementWiseUnary(mojom::ElementWiseUnary::Kind::kCast,
                                    element_wise_binary_output_operand_id,
                                    output_operand_id);
    }
#endif  // BUILDFLAG(IS_MAC)

    base::flat_map<std::string, base::span<const I>> named_inputs;
    named_inputs.insert({"lhs", lhs.values});
    named_inputs.insert({"rhs", rhs.values});
    base::flat_map<std::string, std::vector<O>> named_outputs =
        BuildAndCompute<O>(std::move(remote), builder.TakeGraphInfo(),
                           std::move(named_inputs));

#if BUILDFLAG(IS_MAC)
    if (output.type == OperandDataType::kUint8) {
      VerifyIsEqual(named_outputs["output"], output.ToInt32());
      return;
    }
#endif  // BUILDFLAG(IS_MAC)

    VerifyIsEqual(named_outputs["output"], output);
  }

  void TestFusingOperation(
      WebNNGraphImplBackendTest& test,
      const FusibleOperationDescriptor& fusible_operation) {
    // Now only binary add supports fusing standalone activation.
    CHECK_EQ(kind, mojom::ElementWiseBinary::Kind::kAdd);
    // Build the graph with mojo type.
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        test.BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId lhs_operand_id =
        builder.BuildInput("lhs", lhs.dimensions, lhs.type);
    OperandId rhs_operand_id =
        builder.BuildInput("rhs", rhs.dimensions, rhs.type);
    OperandId intermediate_operand_id =
        builder.BuildIntermediateOperand(output.dimensions, output.type);
    builder.BuildElementWiseBinary(mojom::ElementWiseBinary::Kind::kAdd,
                                   lhs_operand_id, rhs_operand_id,
                                   intermediate_operand_id);

    OperandId output_operand_id =
        builder.BuildOutput("output", output.dimensions, output.type);
    BuildFusibleOperation(builder, fusible_operation, intermediate_operand_id,
                          output_operand_id);

    base::flat_map<std::string, base::span<const I>> named_inputs;
    named_inputs.insert({"lhs", lhs.values});
    named_inputs.insert({"rhs", rhs.values});
    base::flat_map<std::string, std::vector<O>> named_outputs =
        BuildAndCompute<O>(test.context(), std::move(remote),
                           builder.TakeGraphInfo(), std::move(named_inputs));

    VerifyIsEqual(named_outputs["output"], output);
  }
};

// Test building and computing a graph of fusing a standalone activation
// into elementwise binary add automatically.
TEST_F(WebNNGraphImplBackendTest,
       FuseStandaloneActivationIntoElementWiseBinaryAdd) {
  // add followed by linear (alpha = 10, beta = 1).
  {
    ElementWiseBinaryTester<float> linear_case{
        .lhs = {.type = OperandDataType::kFloat32,
                .dimensions = {1, 2, 3, 1},
                .values = {1, 2, 3, 4, 5, 6}},
        .rhs = {.type = OperandDataType::kFloat32,
                .dimensions = {1, 2, 3, 1},
                .values = {0, 5.1, 4, 3, 2, 0}},
        .kind = mojom::ElementWiseBinary::Kind::kAdd,
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 2, 3, 1},
                   .values = {11, 72, 71, 71, 71, 61}}};
    const FusibleOperationDescriptor linear{
        .kind = mojom::Operation::Tag::kLinear, .alpha = 10, .beta = 1};
    linear_case.TestFusingOperation(*this, linear);
  }

  // add followed by relu: negative sums are clamped to zero.
  {
    ElementWiseBinaryTester<float> relu_case{
        .lhs = {.type = OperandDataType::kFloat32,
                .dimensions = {1, 2, 3, 1},
                .values = {1, 2, 3, 4, 5, 6}},
        .rhs = {.type = OperandDataType::kFloat32,
                .dimensions = {1, 2, 3, 1},
                .values = {-6, 5, 4, 3, 2, -7}},
        .kind = mojom::ElementWiseBinary::Kind::kAdd,
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 2, 3, 1},
                   .values = {0, 7, 7, 7, 7, 0}}};
    const FusibleOperationDescriptor relu{
        .kind = mojom::Operation::Tag::kRelu};
    relu_case.TestFusingOperation(*this, relu);
  }
}

// Test building and computing a graph in the following topology.
//         [input]
//            |
//          split
//        /       \
//   [output1]  reshape
//                 |
//             [output2]
TEST_F(WebNNGraphImplBackendTest, BuildAndComputeGraphWithSplitAndReshape) {
  // Input tensor with shape (2, 5):
  // [[ 1  2  3  4  5]
  //  [ 6  7  8  9 10]]
  std::vector<float> input_values = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};

  mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote =
      BindNewGraphBuilderRemote();
  GraphInfoBuilder graph(builder_remote);

  // Split the input along axis 1 into a (2, 2) piece, which is exposed
  // directly as "output1", and a (2, 3) piece that feeds the reshape.
  OperandId input_id =
      graph.BuildInput("input", {2, 5}, OperandDataType::kFloat32);
  OperandId first_piece_id =
      graph.BuildOutput("output1", {2, 2}, OperandDataType::kFloat32);
  OperandId second_piece_id =
      graph.BuildIntermediateOperand({2, 3}, OperandDataType::kFloat32);
  graph.BuildSplit(input_id, {first_piece_id, second_piece_id}, 1);

  // Reshape the (2, 3) piece into the (3, 2) graph output "output2".
  OperandId reshaped_id =
      graph.BuildOutput("output2", {3, 2}, OperandDataType::kFloat32);
  graph.BuildReshape(second_piece_id, reshaped_id);

  base::flat_map<std::string, base::span<const float>> inputs_by_name;
  inputs_by_name.insert({"input", input_values});
  base::flat_map<std::string, std::vector<float>> outputs_by_name =
      BuildAndCompute(context(), std::move(builder_remote),
                      graph.TakeGraphInfo(), std::move(inputs_by_name));

  // "output1" with shape (2, 2):
  // [[1  2]
  //  [6  7]]
  VerifyFloatDataIsEqual(outputs_by_name["output1"], {1, 2, 6, 7});
  // "output2" with shape (3, 2):
  // [[3  4]
  //  [5  8]
  //  [9 10]]
  VerifyFloatDataIsEqual(outputs_by_name["output2"], {3, 4, 5, 8, 9, 10});
}

template <typename T>
struct UnaryOperatorTester {
  mojom::Operation::Tag tag;
  OperandInfo<T> input;
  std::optional<float> clamp_min_value;
  std::optional<float> clamp_max_value;
  std::optional<float> hard_sigmoid_alpha;
  std::optional<float> hard_sigmoid_beta;
  std::optional<float> elu_alpha;
  std::optional<float> leaky_relu_alpha;
  std::optional<float> linear_alpha;
  std::optional<float> linear_beta;
  OperandInfo<T> output;
  void Test(WebNNGraphImplBackendTest& test,
            BuildAndComputeExpectation expectation =
                BuildAndComputeExpectation::kSuccess) {
    // Build the graph with mojo type.
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        test.BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", input.dimensions, input.type);
    OperandId output_operand_id =
        builder.BuildOutput("output", output.dimensions, output.type);
    switch (tag) {
      case mojom::Operation::Tag::kClamp:
        CHECK(clamp_min_value);
        CHECK(clamp_max_value);
        builder.BuildClamp(input_operand_id, output_operand_id,
                           clamp_min_value.value(), clamp_max_value.value());
        break;
      case mojom::Operation::Tag::kElu:
        CHECK(elu_alpha);
        builder.BuildElu(input_operand_id, output_operand_id,
                         elu_alpha.value());
        break;
      case mojom::Operation::Tag::kHardSigmoid:
        builder.BuildHardSigmoid(input_operand_id, output_operand_id,
                                 hard_sigmoid_alpha, hard_sigmoid_beta);
        break;
      case mojom::Operation::Tag::kHardSwish:
        builder.BuildHardSwish(input_operand_id, output_operand_id);
        break;
      case mojom::Operation::Tag::kLeakyRelu:
        CHECK(leaky_relu_alpha);
        builder.BuildLeakyRelu(input_operand_id, output_operand_id,
                               leaky_relu_alpha.value());
        break;
      case mojom::Operation::Tag::kLinear:
        CHECK(linear_alpha);
        CHECK(linear_beta);
        builder.BuildLinear(input_operand_id, output_operand_id,
                            linear_alpha.value(), linear_beta.value());
        break;
      case mojom::Operation::Tag::kRelu:
        builder.BuildRelu(input_operand_id, output_operand_id);
        break;
      case mojom::Operation::Tag::kSigmoid:
        builder.BuildSigmoid(input_operand_id, output_operand_id);
        break;
      case mojom::Operation::Tag::kSoftplus:
        builder.BuildSoftplus(input_operand_id, output_operand_id);
        break;
      case mojom::Operation::Tag::kSoftsign:
        builder.BuildSoftsign(input_operand_id, output_operand_id);
        break;
      case mojom::Operation::Tag::kTanh:
        builder.BuildTanh(input_operand_id, output_operand_id);
        break;
      default:
        NOTREACHED();
    }

    base::flat_map<std::string, base::span<const T>> named_inputs;
    named_inputs.insert({"input", input.values});
    base::flat_map<std::string, std::vector<T>> named_outputs = BuildAndCompute(
        test.context(), std::move(remote), builder.TakeGraphInfo(),
        std::move(named_inputs), expectation);

    if (expectation == BuildAndComputeExpectation::kSuccess) {
      VerifyIsEqual(named_outputs["output"], output);
    }
  }
};

// Test building and computing a graph with single operator clamp.
TEST_F(WebNNGraphImplBackendTest, BuildAndComputeSingleOperatorClamp) {
  // 0-D scalar input: 24 clamped to [0, 3] yields 3.
  UnaryOperatorTester<float> scalar_case{
      .tag = mojom::Operation::Tag::kClamp,
      .input = {.type = OperandDataType::kFloat32,
                .dimensions = {},
                .values = {24}},
      .clamp_min_value = 0,
      .clamp_max_value = 3,
      .output = {.type = OperandDataType::kFloat32,
                 .dimensions = {},
                 .values = {3}}};
  scalar_case.Test(*this);
}

// Test building and computing a graph with single operator hardSigmoid.
TEST_F(WebNNGraphImplBackendTest, BuildAndComputeSingleOperatorHardSigmoid) {
  // 0-D scalar input for hardSigmoid: 0.1 * 24 + 3 is above 1, so the
  // expected output is 1.
  UnaryOperatorTester<float> scalar_case{
      .tag = mojom::Operation::Tag::kHardSigmoid,
      .input = {.type = OperandDataType::kFloat32,
                .dimensions = {},
                .values = {24}},
      .hard_sigmoid_alpha = 0.1,
      .hard_sigmoid_beta = 3,
      .output = {.type = OperandDataType::kFloat32,
                 .dimensions = {},
                 .values = {1}}};
  scalar_case.Test(*this);
}

// Test building and computing a graph with single operator hardSwish.
TEST_F(WebNNGraphImplBackendTest, BuildAndComputeSingleOperatorHardSwish) {
  // 0-D scalar input: the expected output for an input of 7 is 7.
  UnaryOperatorTester<float> scalar_case{
      .tag = mojom::Operation::Tag::kHardSwish,
      .input = {.type = OperandDataType::kFloat32,
                .dimensions = {},
                .values = {7.0}},
      .output = {.type = OperandDataType::kFloat32,
                 .dimensions = {},
                 .values = {7.0}}};
  scalar_case.Test(*this);
}

// Test building and computing a graph with single operator tanh.
TEST_F(WebNNGraphImplBackendTest, BuildAndComputeSingleOperatorTanh) {
  // 0-D scalar input: tanh(-1).
  UnaryOperatorTester<float> scalar_case{
      .tag = mojom::Operation::Tag::kTanh,
      .input = {.type = OperandDataType::kFloat32,
                .dimensions = {},
                .values = {-1}},
      .output = {.type = OperandDataType::kFloat32,
                 .dimensions = {},
                 .values = {-0.76159418}}};
  scalar_case.Test(*this);
}

// Test building and computing a graph with two relu operators.
//    [input]
//       |
//      relu1
//       |
//      relu2
TEST_F(WebNNGraphImplBackendTest, BuildAndComputeGraphWithTwoRelu) {
  // Every operand in this graph shares the same shape.
  const std::vector<uint32_t> shape = {1, 2, 3, 4};
  // The first twelve values are negative and the remaining twelve positive.
  std::vector<float> input_values = {-1, -2,  -3,  -4,  -5, -6, -7, -8,
                                     -9, -10, -11, -12, 13, 14, 15, 16,
                                     17, 18,  19,  20,  21, 22, 23, 24};

  mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote =
      BindNewGraphBuilderRemote();
  GraphInfoBuilder graph(builder_remote);

  // input -> relu -> (intermediate) -> relu -> output
  OperandId input_id =
      graph.BuildInput("input", shape, OperandDataType::kFloat32);
  OperandId first_relu_id =
      graph.BuildIntermediateOperand(shape, OperandDataType::kFloat32);
  graph.BuildRelu(input_id, first_relu_id);
  OperandId second_relu_id =
      graph.BuildOutput("output", shape, OperandDataType::kFloat32);
  graph.BuildRelu(first_relu_id, second_relu_id);

  base::flat_map<std::string, base::span<const float>> inputs_by_name;
  inputs_by_name.insert({"input", input_values});
  base::flat_map<std::string, std::vector<float>> outputs_by_name =
      BuildAndCompute(context(), std::move(builder_remote),
                      graph.TakeGraphInfo(), std::move(inputs_by_name));

  // Negative inputs become zero; positive inputs pass through unchanged.
  VerifyFloatDataIsEqual(outputs_by_name["output"],
                         {0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
                          13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24});
}

// Test building and computing a graph with two operators (reshape as the
// last node).
//    [input]
//       |
//      relu
//       |
//     reshape
TEST_F(WebNNGraphImplBackendTest, BuildAndComputeGraphWithReshapeAsLastNode) {
  const std::vector<uint32_t> input_shape = {1, 2, 3, 4};
  const std::vector<uint32_t> reshaped_shape = {1, 1, 6, 4};
  // All values are positive, so the expected output equals the input.
  std::vector<float> input_values = {1,  2,  3,  4,  5,  6,  7,  8,
                                     9,  10, 11, 12, 13, 14, 15, 16,
                                     17, 18, 19, 20, 21, 22, 23, 24};

  mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote =
      BindNewGraphBuilderRemote();
  GraphInfoBuilder graph(builder_remote);

  // input -> relu -> (intermediate) -> reshape -> output
  OperandId input_id =
      graph.BuildInput("input", input_shape, OperandDataType::kFloat32);
  OperandId relu_result_id =
      graph.BuildIntermediateOperand(input_shape, OperandDataType::kFloat32);
  graph.BuildRelu(input_id, relu_result_id);
  OperandId reshaped_id =
      graph.BuildOutput("output", reshaped_shape, OperandDataType::kFloat32);
  graph.BuildReshape(relu_result_id, reshaped_id);

  base::flat_map<std::string, base::span<const float>> inputs_by_name;
  inputs_by_name.insert({"input", input_values});
  base::flat_map<std::string, std::vector<float>> outputs_by_name =
      BuildAndCompute(context(), std::move(builder_remote),
                      graph.TakeGraphInfo(), std::move(inputs_by_name));

  VerifyFloatDataIsEqual(outputs_by_name["output"], input_values);
}

// Test building and computing a graph with two operators (reshape as an
// intermediate node).
//    [input]
//       |
//    reshape
//       |
//      relu
TEST_F(WebNNGraphImplBackendTest,
       BuildAndComputeGraphWithReshapeAsIntermediateNode) {
  const std::vector<uint32_t> input_shape = {1, 2, 3, 4};
  const std::vector<uint32_t> reshaped_shape = {1, 1, 6, 4};
  // All values are positive, so the expected output equals the input.
  std::vector<float> input_values = {1,  2,  3,  4,  5,  6,  7,  8,
                                     9,  10, 11, 12, 13, 14, 15, 16,
                                     17, 18, 19, 20, 21, 22, 23, 24};

  mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote =
      BindNewGraphBuilderRemote();
  GraphInfoBuilder graph(builder_remote);

  // input -> reshape -> (intermediate) -> relu -> output
  OperandId input_id =
      graph.BuildInput("input", input_shape, OperandDataType::kFloat32);
  OperandId reshaped_id =
      graph.BuildIntermediateOperand(reshaped_shape,
                                     OperandDataType::kFloat32);
  graph.BuildReshape(input_id, reshaped_id);
  OperandId relu_result_id =
      graph.BuildOutput("output", reshaped_shape, OperandDataType::kFloat32);
  graph.BuildRelu(reshaped_id, relu_result_id);

  base::flat_map<std::string, base::span<const float>> inputs_by_name;
  inputs_by_name.insert({"input", input_values});
  base::flat_map<std::string, std::vector<float>> outputs_by_name =
      BuildAndCompute(context(), std::move(builder_remote),
                      graph.TakeGraphInfo(), std::move(inputs_by_name));

  VerifyFloatDataIsEqual(outputs_by_name["output"], input_values);
}

// Checks that two chained reshapes (the second restoring the original shape)
// can be built and computed.
//    [input]
//       |
//    reshape1
//       |
//    reshape2
TEST_F(WebNNGraphImplBackendTest, BuildAndComputeGraphWithTwoReshape) {
  // Describe the graph: input -> reshape1 -> reshape2 -> output.
  mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote =
      BindNewGraphBuilderRemote();
  GraphInfoBuilder graph(builder_remote);
  OperandId input_id =
      graph.BuildInput("input", {1, 2, 3, 4}, OperandDataType::kFloat32);
  OperandId first_reshape_id =
      graph.BuildIntermediateOperand({1, 1, 6, 4}, OperandDataType::kFloat32);
  graph.BuildReshape(input_id, first_reshape_id);
  OperandId output_id =
      graph.BuildOutput("output", {1, 2, 3, 4}, OperandDataType::kFloat32);
  graph.BuildReshape(first_reshape_id, output_id);

  // The test expects the element data to come back unchanged.
  std::vector<float> input_values = {1,  2,  3,  4,  5,  6,  7,  8,
                                     9,  10, 11, 12, 13, 14, 15, 16,
                                     17, 18, 19, 20, 21, 22, 23, 24};
  base::flat_map<std::string, base::span<const float>> inputs;
  inputs.insert({"input", input_values});

  base::flat_map<std::string, std::vector<float>> results =
      BuildAndCompute(context(), std::move(builder_remote),
                      graph.TakeGraphInfo(), std::move(inputs));

  VerifyFloatDataIsEqual(results["output"], input_values);
}

// Checks a graph in which one input fans out to two operators, each of which
// produces its own graph output.
//      [input]
//       /   \
//  reshape   relu
//     |        |
// [output1] [output2]
TEST_F(WebNNGraphImplBackendTest, BuildAndComputeGraphWithTwoOutputs) {
  // Describe the graph: the same input operand is consumed by both operators.
  mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote =
      BindNewGraphBuilderRemote();
  GraphInfoBuilder graph(builder_remote);
  OperandId input_id =
      graph.BuildInput("input", {1, 2, 3, 4}, OperandDataType::kFloat32);
  OperandId reshape_output_id =
      graph.BuildOutput("output1", {1, 1, 6, 4}, OperandDataType::kFloat32);
  graph.BuildReshape(input_id, reshape_output_id);
  OperandId relu_output_id =
      graph.BuildOutput("output2", {1, 2, 3, 4}, OperandDataType::kFloat32);
  graph.BuildRelu(input_id, relu_output_id);

  // First half negative, second half positive, so the two outputs differ.
  std::vector<float> input_values = {-1, -2,  -3,  -4,  -5, -6, -7, -8,
                                     -9, -10, -11, -12, 13, 14, 15, 16,
                                     17, 18,  19,  20,  21, 22, 23, 24};
  base::flat_map<std::string, base::span<const float>> inputs;
  inputs.insert({"input", input_values});

  base::flat_map<std::string, std::vector<float>> results =
      BuildAndCompute(context(), std::move(builder_remote),
                      graph.TakeGraphInfo(), std::move(inputs));

  // "output1" (reshape) is expected to carry the input data unchanged.
  std::vector<float> expected_reshape = {-1, -2,  -3,  -4,  -5, -6, -7, -8,
                                         -9, -10, -11, -12, 13, 14, 15, 16,
                                         17, 18,  19,  20,  21, 22, 23, 24};
  VerifyFloatDataIsEqual(results["output1"], expected_reshape);

  // "output2" (relu) is expected to have the negative entries replaced by 0.
  std::vector<float> expected_relu = {0,  0,  0,  0,  0,  0,  0,  0,
                                      0,  0,  0,  0,  13, 14, 15, 16,
                                      17, 18, 19, 20, 21, 22, 23, 24};
  VerifyFloatDataIsEqual(results["output2"], expected_relu);
}

// Optional settings handed to `GraphInfoBuilder::BuildGemm()` by the gemm
// tests in this file. A default-constructed instance has no third operand,
// unit scaling factors and no transposition.
struct GemmAttributes {
  // Id of the optional third input operand; unset when there is none.
  std::optional<OperandId> c_operand_id;
  // TODO(crbug.com/40206287): Add test cases for below attributes.
  float alpha = 1.0f;
  float beta = 1.0f;
  bool a_transpose = false;
  bool b_transpose = false;
};

template <typename T>
struct GemmTester {
  OperandInfo<T> input_a;
  OperandInfo<T> input_b;
  std::optional<OperandInfo<T>> input_c;
  GemmAttributes attributes;
  OperandInfo<float> output;

  void TestFusingOperation(
      WebNNGraphImplBackendTest& test,
      const FusibleOperationDescriptor& fusible_operation) {
    // Build the graph with mojo type.
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        test.BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_a_operand_id =
        builder.BuildInput("input_a", input_a.dimensions, input_a.type);
    OperandId input_b_operand_id =
        builder.BuildInput("input_b", input_b.dimensions, input_b.type);
    OperandId intermediate_operand_id =
        builder.BuildIntermediateOperand(output.dimensions, output.type);
    if (input_c.has_value()) {
      attributes.c_operand_id =
          builder.BuildInput("input_c", input_c->dimensions, input_c->type);
    }

    builder.BuildGemm(input_a_operand_id, input_b_operand_id,
                      intermediate_operand_id, std::move(attributes));

    OperandId output_operand_id =
        builder.BuildOutput("output", output.dimensions, output.type);
    BuildFusibleOperation(builder, fusible_operation, intermediate_operand_id,
                          output_operand_id);

    base::flat_map<std::string, base::span<const T>> named_inputs;
    named_inputs.insert({"input_a", input_a.values});
    named_inputs.insert({"input_b", input_b.values});
    if (input_c.has_value()) {
      named_inputs.insert({"input_c", input_c->values});
    }
    base::flat_map<std::string, std::vector<float>> named_outputs =
        BuildAndCompute(test.context(), std::move(remote),
                        builder.TakeGraphInfo(), std::move(named_inputs));

    VerifyIsEqual(named_outputs["output"], output);
  }
};

// Test building and computing graphs in which gemm is directly followed by a
// standalone activation, so that the activation can be fused into gemm
// automatically.
TEST_F(WebNNGraphImplBackendTest, FuseStandaloneActivationIntoGemm) {
  // Case 1: gemm without a third input, followed by linear(alpha = 10,
  // beta = 1).
  {
    GemmTester<float> tester{.input_a = {.type = OperandDataType::kFloat32,
                                         .dimensions = {2, 2},
                                         .values = {1, 2, 3, 4}},
                             .input_b = {.type = OperandDataType::kFloat32,
                                         .dimensions = {2, 2},
                                         .values = {1, 2, 3, 4}},
                             .output = {.type = OperandDataType::kFloat32,
                                        .dimensions = {2, 2},
                                        .values = {71, 101, 151, 221}}};
    const FusibleOperationDescriptor linear{
        .kind = mojom::Operation::Tag::kLinear, .alpha = 10, .beta = 1};
    tester.TestFusingOperation(*this, linear);
  }

  // Case 2: gemm with a third input, followed by relu. The negative entries
  // of the gemm result are expected to come out as 0.
  {
    GemmTester<float> tester{
        .input_a = {.type = OperandDataType::kFloat32,
                    .dimensions = {2, 2},
                    .values = {1, 2, 3, -4}},
        .input_b = {.type = OperandDataType::kFloat32,
                    .dimensions = {2, 2},
                    .values = {1, 2, 3, 4}},
        .input_c = OperandInfo<float>{.type = OperandDataType::kFloat32,
                                      .dimensions = {2, 2},
                                      .values = {1, 1, 1, 1}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {2, 2},
                   .values = {8, 11, 0, 0}}};
    const FusibleOperationDescriptor relu{
        .kind = mojom::Operation::Tag::kRelu};
    tester.TestFusingOperation(*this, relu);
  }
}

template <typename T>
struct GruTester {
  struct GruAttributes {
    std::optional<OperandId> bias_operand_id;
    std::optional<OperandId> recurrent_bias_operand_id;
    std::optional<OperandId> initial_hidden_state_operand_id;
    bool reset_after = true;
    bool return_sequence = false;
    mojom::RecurrentNetworkDirection direction =
        mojom::RecurrentNetworkDirection::kForward;
    mojom::GruWeightLayout layout = mojom::GruWeightLayout::kZrn;
    std::vector<mojom::RecurrentNetworkActivation> activations{
        mojom::RecurrentNetworkActivation::kSigmoid,
        mojom::RecurrentNetworkActivation::kTanh};
  };

  OperandInfo<T> input;
  OperandInfo<T> weight;
  OperandInfo<T> recurrent_weight;
  uint32_t steps;
  uint32_t hidden_size;
  std::optional<OperandInfo<T>> bias;
  std::optional<OperandInfo<T>> recurrent_bias;
  std::optional<OperandInfo<T>> initial_hidden_state;
  GruAttributes attributes;
  std::vector<OperandInfo<T>> outputs;

  void Test(WebNNGraphImplBackendTest& helper,
            BuildAndComputeExpectation expectation =
                BuildAndComputeExpectation::kSuccess) {
    // Build the graph with mojo type.
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        helper.BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", input.dimensions, input.type);
    OperandId weight_operand_id =
        builder.BuildInput("weight", weight.dimensions, weight.type);
    OperandId recurrent_weight_operand_id = builder.BuildInput(
        "recurrentWeight", recurrent_weight.dimensions, recurrent_weight.type);

    if (bias.has_value()) {
      attributes.bias_operand_id =
          builder.BuildInput("bias", bias->dimensions, bias->type);
    }
    if (recurrent_bias.has_value()) {
      attributes.recurrent_bias_operand_id = builder.BuildInput(
          "recurrentBias", recurrent_bias->dimensions, recurrent_bias->type);
    }
    if (initial_hidden_state.has_value()) {
      attributes.initial_hidden_state_operand_id = builder.BuildConstant(
          initial_hidden_state->dimensions, initial_hidden_state->type,
          base::as_byte_span(base::allow_nonunique_obj,
                             initial_hidden_state->values));
    }

    std::vector<OperandId> output_operand_ids;
    output_operand_ids.reserve(outputs.size());
    for (size_t i = 0; i < outputs.size(); ++i) {
      const auto& output = outputs[i];
      output_operand_ids.push_back(builder.BuildOutput(
          "output" + base::NumberToString(i), output.dimensions, output.type));
    }

    builder.BuildGru(input_operand_id, weight_operand_id,
                     recurrent_weight_operand_id, std::move(output_operand_ids),
                     steps, hidden_size, std::move(attributes));

    base::flat_map<std::string, base::span<const T>> named_inputs;
    named_inputs.insert({"input", input.values});
    named_inputs.insert({"weight", weight.values});
    named_inputs.insert({"recurrentWeight", recurrent_weight.values});
    if (bias.has_value()) {
      named_inputs.insert({"bias", bias->values});
    }
    if (recurrent_bias.has_value()) {
      named_inputs.insert({"recurrentBias", recurrent_bias->values});
    }

    base::flat_map<std::string, std::vector<T>> named_outputs = BuildAndCompute(
        helper.context(), std::move(remote), builder.TakeGraphInfo(),
        std::move(named_inputs), expectation);

    if (expectation == BuildAndComputeExpectation::kSuccess) {
      for (size_t i = 0; i < outputs.size(); ++i) {
        VerifyIsEqual(named_outputs["output" + base::NumberToString(i)],
                      outputs[i]);
      }
    }
  }
};

// Test building and computing a graph with single operator gru.
//
// Every case below uses all-ones `weight` and `recurrent_weight` tensors and
// relu for both activations, and runs through `GruTester::Test()` with the
// default (success) expectation.
// NOTE(review): the expected values look consistent with the usual GRU update
// h' = (1 - z) * n + z * h evaluated with these all-ones weights (e.g. an
// input row summing to 6 with a zero hidden state gives (1 - 6) * 6 = -30).
// Confirm against the WebNN gru definition before editing any of them.
TEST_F(WebNNGraphImplBackendTest, BuildAndComputeSingleOperatorGru) {
  // Test gru without bias and initial hidden state (single step, single
  // direction).
  {
    const uint32_t steps = 1;
    const uint32_t batch_size = 3;
    const uint32_t input_size = 3;
    const uint32_t hidden_size = 5;
    const uint32_t num_directions = 1;
    GruTester<float>{
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {steps, batch_size, input_size},
                  .values = {1, 2, 3, 4, 5, 6, 7, 8, 9}},
        .weight = {.type = OperandDataType::kFloat32,
                   .dimensions = {num_directions, 3 * hidden_size, input_size},
                   .values = std::vector<float>(
                       num_directions * 3 * hidden_size * input_size, 1)},
        .recurrent_weight = {.type = OperandDataType::kFloat32,
                             .dimensions = {num_directions, 3 * hidden_size,
                                            hidden_size},
                             .values = std::vector<float>(
                                 num_directions * 3 * hidden_size * hidden_size,
                                 1)},
        .steps = steps,
        .hidden_size = hidden_size,
        .attributes =
            {.activations = {mojom::RecurrentNetworkActivation::kRelu,
                             mojom::RecurrentNetworkActivation::kRelu}},
        .outputs = {{.type = OperandDataType::kFloat32,
                     .dimensions = {num_directions, batch_size, hidden_size},
                     .values = {-30., -30., -30., -30., -30., -210., -210.,
                                -210., -210., -210., -552., -552., -552., -552.,
                                -552.}}}}
        .Test(*this);
  }
  // Test gru with number directions = 2 (direction = kBoth). The expected
  // output repeats the single-direction values once per direction.
  {
    const uint32_t steps = 1;
    const uint32_t batch_size = 3;
    const uint32_t input_size = 3;
    const uint32_t hidden_size = 5;
    const uint32_t num_directions = 2;
    GruTester<float>{
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {steps, batch_size, input_size},
                  .values = {1, 2, 3, 4, 5, 6, 7, 8, 9}},
        .weight = {.type = OperandDataType::kFloat32,
                   .dimensions = {num_directions, 3 * hidden_size, input_size},
                   .values = std::vector<float>(
                       num_directions * 3 * hidden_size * input_size, 1)},
        .recurrent_weight = {.type = OperandDataType::kFloat32,
                             .dimensions = {num_directions, 3 * hidden_size,
                                            hidden_size},
                             .values = std::vector<float>(
                                 num_directions * 3 * hidden_size * hidden_size,
                                 1)},
        .steps = steps,
        .hidden_size = hidden_size,
        .attributes =
            {.direction = mojom::RecurrentNetworkDirection::kBoth,
             .activations = {mojom::RecurrentNetworkActivation::kRelu,
                             mojom::RecurrentNetworkActivation::kRelu}},
        .outputs = {{.type = OperandDataType::kFloat32,
                     .dimensions = {num_directions, batch_size, hidden_size},
                     .values = {-30.,  -30.,  -30.,  -30.,  -30.,  -210.,
                                -210., -210., -210., -210., -552., -552.,
                                -552., -552., -552., -30.,  -30.,  -30.,
                                -30.,  -30.,  -210., -210., -210., -210.,
                                -210., -552., -552., -552., -552., -552.}}}}
        .Test(*this);
  }
  // Test gru with steps = 2 (still direction = kBoth). The same 3x3 input
  // block is supplied for both steps.
  {
    const uint32_t steps = 2;
    const uint32_t batch_size = 3;
    const uint32_t input_size = 3;
    const uint32_t hidden_size = 5;
    const uint32_t num_directions = 2;
    GruTester<float>{
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {steps, batch_size, input_size},
                  .values = {1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8,
                             9}},
        .weight = {.type = OperandDataType::kFloat32,
                   .dimensions = {num_directions, 3 * hidden_size, input_size},
                   .values = std::vector<float>(
                       num_directions * 3 * hidden_size * input_size, 1)},
        .recurrent_weight = {.type = OperandDataType::kFloat32,
                             .dimensions = {num_directions, 3 * hidden_size,
                                            hidden_size},
                             .values = std::vector<float>(
                                 num_directions * 3 * hidden_size * hidden_size,
                                 1)},
        .steps = steps,
        .hidden_size = hidden_size,
        .attributes =
            {.direction = mojom::RecurrentNetworkDirection::kBoth,
             .activations = {mojom::RecurrentNetworkActivation::kRelu,
                             mojom::RecurrentNetworkActivation::kRelu}},
        .outputs = {{.type = OperandDataType::kFloat32,
                     .dimensions = {num_directions, batch_size, hidden_size},
                     .values = {6.,  6.,  6.,  6.,  6.,  15., 15., 15.,
                                15., 15., 24., 24., 24., 24., 24., 6.,
                                6.,  6.,  6.,  6.,  15., 15., 15., 15.,
                                15., 24., 24., 24., 24., 24.}}}}
        .Test(*this);
  }
  // Test gru with bias and recurrentbias (bias is all ones, recurrent bias is
  // all zeros).
  {
    const uint32_t steps = 1;
    const uint32_t batch_size = 3;
    const uint32_t input_size = 3;
    const uint32_t hidden_size = 5;
    const uint32_t num_directions = 1;
    GruTester<float>{
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {steps, batch_size, input_size},
                  .values = {1, 2, 3, 4, 5, 6, 7, 8, 9}},
        .weight = {.type = OperandDataType::kFloat32,
                   .dimensions = {num_directions, 3 * hidden_size, input_size},
                   .values = std::vector<float>(
                       num_directions * 3 * hidden_size * input_size, 1)},
        .recurrent_weight = {.type = OperandDataType::kFloat32,
                             .dimensions = {num_directions, 3 * hidden_size,
                                            hidden_size},
                             .values = std::vector<float>(
                                 num_directions * 3 * hidden_size * hidden_size,
                                 1)},
        .steps = steps,
        .hidden_size = hidden_size,
        .bias =
            OperandInfo<float>{.type = OperandDataType::kFloat32,
                               .dimensions = {num_directions, 3 * hidden_size},
                               .values = std::vector<float>(
                                   num_directions * 3 * hidden_size, 1)},
        .recurrent_bias =
            OperandInfo<float>{.type = OperandDataType::kFloat32,
                               .dimensions = {num_directions, 3 * hidden_size},
                               .values = std::vector<float>(
                                   num_directions * 3 * hidden_size, 0)},
        .attributes =
            {.activations = {mojom::RecurrentNetworkActivation::kRelu,
                             mojom::RecurrentNetworkActivation::kRelu}},
        .outputs = {{.type = OperandDataType::kFloat32,
                     .dimensions = {num_directions, batch_size, hidden_size},
                     .values = {-42., -42., -42., -42., -42., -240., -240.,
                                -240., -240., -240., -600., -600., -600., -600.,
                                -600.}}}}
        .Test(*this);
  }
  // Test gru with bias and initial hidden state (both all ones; no recurrent
  // bias).
  {
    const uint32_t steps = 1;
    const uint32_t batch_size = 3;
    const uint32_t input_size = 3;
    const uint32_t hidden_size = 5;
    const uint32_t num_directions = 1;
    GruTester<float>{
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {steps, batch_size, input_size},
                  .values = {1, 2, 3, 4, 5, 6, 7, 8, 9}},
        .weight = {.type = OperandDataType::kFloat32,
                   .dimensions = {num_directions, 3 * hidden_size, input_size},
                   .values = std::vector<float>(
                       num_directions * 3 * hidden_size * input_size, 1)},
        .recurrent_weight = {.type = OperandDataType::kFloat32,
                             .dimensions = {num_directions, 3 * hidden_size,
                                            hidden_size},
                             .values = std::vector<float>(
                                 num_directions * 3 * hidden_size * hidden_size,
                                 1)},
        .steps = steps,
        .hidden_size = hidden_size,
        .bias =
            OperandInfo<float>{.type = OperandDataType::kFloat32,
                               .dimensions = {num_directions, 3 * hidden_size},
                               .values = std::vector<float>(
                                   num_directions * 3 * hidden_size, 1)},
        .initial_hidden_state =
            OperandInfo<float>{
                .type = OperandDataType::kFloat32,
                .dimensions = {num_directions, batch_size, hidden_size},
                .values = std::vector<float>(
                    num_directions * batch_size * hidden_size, 1)},
        .attributes =
            {.activations = {mojom::RecurrentNetworkActivation::kRelu,
                             mojom::RecurrentNetworkActivation::kRelu}},
        .outputs = {{.type = OperandDataType::kFloat32,
                     .dimensions = {num_directions, batch_size, hidden_size},
                     .values = {-725., -725., -725., -725., -725., -2399.,
                                -2399., -2399., -2399., -2399., -5045., -5045.,
                                -5045., -5045., -5045.}}}}
        .Test(*this);
  }
  // Test gru with return_sequence = true, together with bias, recurrent bias
  // and initial hidden state. Two outputs are declared; the second one has an
  // extra leading `steps` dimension.
  {
    const uint32_t steps = 1;
    const uint32_t batch_size = 3;
    const uint32_t input_size = 3;
    const uint32_t hidden_size = 5;
    const uint32_t num_directions = 1;
    GruTester<float>{
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {steps, batch_size, input_size},
                  .values = {1, 2, 3, 4, 5, 6, 7, 8, 9}},
        .weight = {.type = OperandDataType::kFloat32,
                   .dimensions = {num_directions, 3 * hidden_size, input_size},
                   .values = std::vector<float>(
                       num_directions * 3 * hidden_size * input_size, 1)},
        .recurrent_weight = {.type = OperandDataType::kFloat32,
                             .dimensions = {num_directions, 3 * hidden_size,
                                            hidden_size},
                             .values = std::vector<float>(
                                 num_directions * 3 * hidden_size * hidden_size,
                                 1)},
        .steps = steps,
        .hidden_size = hidden_size,
        .bias =
            OperandInfo<float>{.type = OperandDataType::kFloat32,
                               .dimensions = {num_directions, 3 * hidden_size},
                               .values = std::vector<float>(
                                   num_directions * 3 * hidden_size, 1)},
        .recurrent_bias =
            OperandInfo<float>{.type = OperandDataType::kFloat32,
                               .dimensions = {num_directions, 3 * hidden_size},
                               .values = std::vector<float>(
                                   num_directions * 3 * hidden_size, 0)},
        .initial_hidden_state =
            OperandInfo<float>{
                .type = OperandDataType::kFloat32,
                .dimensions = {num_directions, batch_size, hidden_size},
                .values = std::vector<float>(
                    num_directions * batch_size * hidden_size, 1)},
        .attributes =
            {.return_sequence = true,
             .activations = {mojom::RecurrentNetworkActivation::kRelu,
                             mojom::RecurrentNetworkActivation::kRelu}},
        .outputs =
            {{.type = OperandDataType::kFloat32,
              .dimensions = {num_directions, batch_size, hidden_size},
              .values = {-725., -725., -725., -725., -725., -2399., -2399.,
                         -2399., -2399., -2399., -5045., -5045., -5045., -5045.,
                         -5045.}},
             {.type = OperandDataType::kFloat32,
              .dimensions = {steps, num_directions, batch_size, hidden_size},
              .values = {-725., -725., -725., -725., -725., -2399., -2399.,
                         -2399., -2399., -2399., -5045., -5045., -5045., -5045.,
                         -5045.}}}}
        .Test(*this);
  }
}

// TODO(https://issues.chromium.org/issues/331250158): Delete the test cases
// after the WPT conformance tests are completed.
template <typename T>
struct GruCellTester {
  struct GruCellAttributes {
    std::optional<OperandId> bias_operand_id;
    std::optional<OperandId> recurrent_bias_operand_id;
    bool reset_after = true;
    mojom::GruWeightLayout layout = mojom::GruWeightLayout::kZrn;
    std::vector<mojom::RecurrentNetworkActivation> activations{
        mojom::RecurrentNetworkActivation::kSigmoid,
        mojom::RecurrentNetworkActivation::kTanh};
  };

  OperandInfo<T> input;
  OperandInfo<T> weight;
  OperandInfo<T> recurrent_weight;
  OperandInfo<T> hidden_state;
  uint32_t hidden_size;
  std::optional<OperandInfo<T>> bias;
  std::optional<OperandInfo<T>> recurrent_bias;
  GruCellAttributes attributes;
  OperandInfo<T> output;

  void Test(WebNNGraphImplBackendTest& helper,
            BuildAndComputeExpectation expectation =
                BuildAndComputeExpectation::kSuccess) {
    // Build the graph with mojo type.
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        helper.BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", input.dimensions, input.type);
    OperandId weight_operand_id =
        builder.BuildInput("weight", weight.dimensions, weight.type);
    OperandId recurrent_weight_operand_id = builder.BuildInput(
        "recurrentWeight", recurrent_weight.dimensions, recurrent_weight.type);
    OperandId hidden_state_operand_id = builder.BuildInput(
        "hiddenState", hidden_state.dimensions, hidden_state.type);

    if (bias.has_value()) {
      attributes.bias_operand_id =
          builder.BuildInput("bias", bias->dimensions, bias->type);
    }
    if (recurrent_bias.has_value()) {
      attributes.recurrent_bias_operand_id = builder.BuildInput(
          "recurrentBias", recurrent_bias->dimensions, recurrent_bias->type);
    }

    OperandId output_operand_id =
        builder.BuildOutput("output", output.dimensions, output.type);

    builder.BuildGruCell(input_operand_id, weight_operand_id,
                         recurrent_weight_operand_id, hidden_state_operand_id,
                         output_operand_id, hidden_size, std::move(attributes));

    base::flat_map<std::string, base::span<const T>> named_inputs;
    named_inputs.insert({"input", input.values});
    named_inputs.insert({"weight", weight.values});
    named_inputs.insert({"recurrentWeight", recurrent_weight.values});
    named_inputs.insert({"hiddenState", hidden_state.values});
    if (bias.has_value()) {
      named_inputs.insert({"bias", bias->values});
    }
    if (recurrent_bias.has_value()) {
      named_inputs.insert({"recurrentBias", recurrent_bias->values});
    }

    base::flat_map<std::string, std::vector<T>> named_outputs = BuildAndCompute(
        helper.context(), std::move(remote), builder.TakeGraphInfo(),
        std::move(named_inputs), expectation);

    if (expectation == BuildAndComputeExpectation::kSuccess) {
      VerifyIsEqual(named_outputs["output"], output);
    }
  }
};

// Test building and computing a graph with single operator gruCell. Both
// cases use all-ones weights, an all-zeros hidden state and relu for both
// activations.
TEST_F(WebNNGraphImplBackendTest, BuildAndComputeSingleOperatorGruCell) {
  // Case 1: gruCell without bias and recurrent bias.
  {
    constexpr uint32_t batch_size = 3;
    constexpr uint32_t input_size = 3;
    constexpr uint32_t hidden_size = 5;
    GruCellTester<float> tester{
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {batch_size, input_size},
                  .values = {1, 2, 3, 4, 5, 6, 7, 8, 9}},
        .weight = {.type = OperandDataType::kFloat32,
                   .dimensions = {3 * hidden_size, input_size},
                   .values =
                       std::vector<float>(3 * hidden_size * input_size, 1)},
        .recurrent_weight = {.type = OperandDataType::kFloat32,
                             .dimensions = {3 * hidden_size, hidden_size},
                             .values = std::vector<float>(
                                 3 * hidden_size * hidden_size, 1)},
        .hidden_state = {.type = OperandDataType::kFloat32,
                         .dimensions = {batch_size, hidden_size},
                         .values =
                             std::vector<float>(batch_size * hidden_size, 0)},
        .hidden_size = hidden_size,
        .attributes =
            {.activations = {mojom::RecurrentNetworkActivation::kRelu,
                             mojom::RecurrentNetworkActivation::kRelu}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {batch_size, hidden_size},
                   .values = {-30., -30., -30., -30., -30., -210., -210., -210.,
                              -210., -210., -552., -552., -552., -552.,
                              -552.}}};
    tester.Test(*this);
  }
  // Case 2: gruCell with bias (all ones) and recurrent bias (all zeros).
  {
    constexpr uint32_t batch_size = 3;
    constexpr uint32_t input_size = 3;
    constexpr uint32_t hidden_size = 5;
    GruCellTester<float> tester{
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {batch_size, input_size},
                  .values = {1, 2, 3, 4, 5, 6, 7, 8, 9}},
        .weight = {.type = OperandDataType::kFloat32,
                   .dimensions = {3 * hidden_size, input_size},
                   .values =
                       std::vector<float>(3 * hidden_size * input_size, 1)},
        .recurrent_weight = {.type = OperandDataType::kFloat32,
                             .dimensions = {3 * hidden_size, hidden_size},
                             .values = std::vector<float>(
                                 3 * hidden_size * hidden_size, 1)},
        .hidden_state = {.type = OperandDataType::kFloat32,
                         .dimensions = {batch_size, hidden_size},
                         .values =
                             std::vector<float>(batch_size * hidden_size, 0)},
        .hidden_size = hidden_size,
        .bias = OperandInfo<float>{.type = OperandDataType::kFloat32,
                                   .dimensions = {3 * hidden_size},
                                   .values =
                                       std::vector<float>(3 * hidden_size, 1)},
        .recurrent_bias = OperandInfo<float>{.type = OperandDataType::kFloat32,
                                             .dimensions = {3 * hidden_size},
                                             .values = std::vector<float>(
                                                 3 * hidden_size, 0)},
        .attributes =
            {.activations = {mojom::RecurrentNetworkActivation::kRelu,
                             mojom::RecurrentNetworkActivation::kRelu}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {batch_size, hidden_size},
                   .values = {-42., -42., -42., -42., -42., -240., -240., -240.,
                              -240., -240., -600., -600., -600., -600.,
                              -600.}}};
    tester.Test(*this);
  }
}

// Test building and computing a graph with three gemm operations: two gemms
// over the same pair of inputs feed a third gemm.
//    [input_a] [input_b] [input_a] [input_b]
//           \    /                \    /
//            gemm                  gemm
//                \                /
//                       gemm
TEST_F(WebNNGraphImplBackendTest, BuildAndComputeMultipleOperatorGemm) {
  mojo::AssociatedRemote<mojom::WebNNGraphBuilder> builder_remote =
      BindNewGraphBuilderRemote();
  GraphInfoBuilder graph(builder_remote);
  OperandId a_id =
      graph.BuildInput("input_a", {2, 2}, OperandDataType::kFloat32);
  OperandId b_id =
      graph.BuildInput("input_b", {2, 2}, OperandDataType::kFloat32);

  // First level: two independent gemms, each with default attributes.
  OperandId left_product_id =
      graph.BuildIntermediateOperand({2, 2}, OperandDataType::kFloat32);
  graph.BuildGemm(a_id, b_id, left_product_id, GemmAttributes());
  OperandId right_product_id =
      graph.BuildIntermediateOperand({2, 2}, OperandDataType::kFloat32);
  graph.BuildGemm(a_id, b_id, right_product_id, GemmAttributes());

  // Second level: the two intermediate results are combined by a third gemm.
  OperandId output_id =
      graph.BuildOutput("output", {2, 2}, OperandDataType::kFloat32);
  graph.BuildGemm(left_product_id, right_product_id, output_id,
                  GemmAttributes());

  std::vector<float> a_values = {1, 2, 3, 4};
  std::vector<float> b_values = {1, 1, 1, 1};
  base::flat_map<std::string, base::span<const float>> inputs;
  inputs.insert({"input_a", a_values});
  inputs.insert({"input_b", b_values});

  base::flat_map<std::string, std::vector<float>> results =
      BuildAndCompute(context(), std::move(builder_remote),
                      graph.TakeGraphInfo(), std::move(inputs));

  VerifyFloatDataIsEqual(results["output"], {30, 30, 70, 70});
}

// Test building and computing a graph whose single gemm multiplies one graph
// input by one constant operand.
TEST_F(WebNNGraphImplBackendTest, BuildOneInputAndOneConstantOperand) {
  const std::vector<uint32_t> shape = {2, 2};
  const OperandDataType data_type = OperandDataType::kFloat32;
  // Values of the constant operand, handed to the builder as raw bytes.
  std::vector<float> constant_values = {5, 6, 7, 8};

  // Describe the graph through the mojom graph builder.
  mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
      BindNewGraphBuilderRemote();
  GraphInfoBuilder builder(remote);
  const OperandId input_id = builder.BuildInput("input_a", shape, data_type);
  const OperandId constant_id = builder.BuildConstant(
      shape, data_type,
      base::as_byte_span(base::allow_nonunique_obj, constant_values));
  const OperandId result_id = builder.BuildOutput("output", shape, data_type);
  builder.BuildGemm(input_id, constant_id, result_id, GemmAttributes());

  // Only the graph input needs a value at compute time.
  const std::vector<float> input_values = {1, 1, 1, 1};
  base::flat_map<std::string, base::span<const float>> named_inputs;
  named_inputs.insert({"input_a", input_values});

  auto named_outputs =
      BuildAndCompute(context(), std::move(remote), builder.TakeGraphInfo(),
                      std::move(named_inputs));

  // An all-ones matrix times [[5, 6], [7, 8]] sums each column.
  VerifyFloatDataIsEqual(named_outputs["output"], {12, 14, 12, 14});
}

template <typename T>
struct InstanceNormalizationTester {
  OperandInfo<T> input;
  std::optional<OperandInfo<T>> scale;
  std::optional<OperandInfo<T>> bias;
  struct InstanceNormalizationAttributes {
    std::optional<OperandId> scale_operand_id;
    std::optional<OperandId> bias_operand_id;
    float epsilon = 1e-5;
  };
  InstanceNormalizationAttributes attributes;
  OperandInfo<T> output;

  void TestFusingOperation(
      WebNNGraphImplBackendTest& test,
      const FusibleOperationDescriptor& fusible_operation) {
    // Build the graph with mojo type.
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        test.BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", input.dimensions, input.type);
    OperandId intermediate_operand_id =
        builder.BuildIntermediateOperand(output.dimensions, output.type);
    if (scale.has_value()) {
      attributes.scale_operand_id =
          builder.BuildInput("scale", scale->dimensions, scale->type);
    }
    if (bias.has_value()) {
      attributes.bias_operand_id =
          builder.BuildInput("bias", bias->dimensions, bias->type);
    }

    builder.BuildInstanceNormalization(
        input_operand_id, intermediate_operand_id, std::move(attributes));

    OperandId output_operand_id =
        builder.BuildOutput("output", output.dimensions, output.type);
    BuildFusibleOperation(builder, fusible_operation, intermediate_operand_id,
                          output_operand_id);

    base::flat_map<std::string, base::span<const T>> named_inputs;
    named_inputs.insert({"input", input.values});
    if (scale.has_value()) {
      named_inputs.insert({"scale", scale->values});
    }
    if (bias.has_value()) {
      named_inputs.insert({"bias", bias->values});
    }
    base::flat_map<std::string, std::vector<T>> named_outputs =
        BuildAndCompute(test.context(), std::move(remote),
                        builder.TakeGraphInfo(), std::move(named_inputs));

    VerifyIsEqual(named_outputs["output"], output);
  }
};

// Test building and computing a graph where a standalone activation follows
// instanceNormalization, so that it can be fused automatically.
TEST_F(WebNNGraphImplBackendTest,
       FuseStandaloneActivationIntoInstanceNormalization) {
  // The activation that follows instanceNormalization.
  const FusibleOperationDescriptor relu{.kind = mojom::Operation::Tag::kRelu};

  // 4-D input, default scale and bias, activation = relu.
  InstanceNormalizationTester<float> tester{
      .input = {.type = OperandDataType::kFloat32,
                .dimensions = {1, 2, 1, 3},
                .values = {1, 2, 3, 4, 5, 6}},
      .output = {.type = OperandDataType::kFloat32,
                 .dimensions = {1, 2, 1, 3},
                 .values = {0, 0, 1.2247356859083902, 0, 0,
                            1.2247356859083902}}};
  tester.TestFusingOperation(*this, relu);
}

template <typename T>
struct LayerNormalizationTester {
  OperandInfo<T> input;
  std::optional<OperandInfo<T>> scale;
  std::optional<OperandInfo<T>> bias;
  struct LayerNormalizationAttributes {
    std::optional<OperandId> scale_operand_id;
    std::optional<OperandId> bias_operand_id;
    std::vector<uint32_t> axes;
    float epsilon = 1e-5;
  };
  LayerNormalizationAttributes attributes;
  OperandInfo<T> output;

  void Test(WebNNGraphImplBackendTest& test,
            BuildAndComputeExpectation expectation =
                BuildAndComputeExpectation::kSuccess) {
    // Build the graph with mojo type.
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        test.BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", input.dimensions, input.type);
    OperandId output_operand_id =
        builder.BuildOutput("output", output.dimensions, output.type);
    if (scale.has_value()) {
      attributes.scale_operand_id =
          builder.BuildInput("scale", scale->dimensions, scale->type);
    }
    if (bias.has_value()) {
      attributes.bias_operand_id =
          builder.BuildInput("bias", bias->dimensions, bias->type);
    }

    builder.BuildLayerNormalization(input_operand_id, output_operand_id,
                                    std::move(attributes));

    base::flat_map<std::string, base::span<const T>> named_inputs;
    named_inputs.insert({"input", input.values});
    if (scale.has_value()) {
      named_inputs.insert({"scale", scale->values});
    }
    if (bias.has_value()) {
      named_inputs.insert({"bias", bias->values});
    }
    base::flat_map<std::string, std::vector<T>> named_outputs = BuildAndCompute(
        test.context(), std::move(remote), builder.TakeGraphInfo(),
        std::move(named_inputs), expectation);

    if (expectation == BuildAndComputeExpectation::kSuccess) {
      VerifyIsEqual(named_outputs["output"], output);
    }
  }

  void TestFusingOperation(
      WebNNGraphImplBackendTest& test,
      const FusibleOperationDescriptor& fusible_operation) {
    // Build the graph with mojo type.
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        test.BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", input.dimensions, input.type);
    OperandId intermediate_operand_id =
        builder.BuildIntermediateOperand(output.dimensions, output.type);
    if (scale.has_value()) {
      attributes.scale_operand_id =
          builder.BuildInput("scale", scale->dimensions, scale->type);
    }
    if (bias.has_value()) {
      attributes.bias_operand_id =
          builder.BuildInput("bias", bias->dimensions, bias->type);
    }

    builder.BuildLayerNormalization(input_operand_id, intermediate_operand_id,
                                    std::move(attributes));

    OperandId output_operand_id =
        builder.BuildOutput("output", output.dimensions, output.type);
    BuildFusibleOperation(builder, fusible_operation, intermediate_operand_id,
                          output_operand_id);

    base::flat_map<std::string, base::span<const T>> named_inputs;
    named_inputs.insert({"input", input.values});
    if (scale.has_value()) {
      named_inputs.insert({"scale", scale->values});
    }
    if (bias.has_value()) {
      named_inputs.insert({"bias", bias->values});
    }
    base::flat_map<std::string, std::vector<T>> named_outputs =
        BuildAndCompute(test.context(), std::move(remote),
                        builder.TakeGraphInfo(), std::move(named_inputs));

    VerifyIsEqual(named_outputs["output"], output);
  }
};

// Test building and computing a graph where a standalone activation follows
// layerNormalization, so that it can be fused automatically.
TEST_F(WebNNGraphImplBackendTest,
       FuseStandaloneActivationIntoLayerNormalization) {
  // The activation that follows layerNormalization.
  const FusibleOperationDescriptor relu{.kind = mojom::Operation::Tag::kRelu};

  // 1-D input, axes = [0], default scale and bias, activation = relu.
  LayerNormalizationTester<float> tester{
      .input = {.type = OperandDataType::kFloat32,
                .dimensions = {5},
                .values = {0, 1, 2, 3, 4}},
      .attributes = {.axes = {0}},
      .output = {.type = OperandDataType::kFloat32,
                 .dimensions = {5},
                 .values = {0, 0, 0, 0.7071050134262237, 1.4142100268524473}}};
  tester.TestFusingOperation(*this, relu);
}

// Test building and computing graphs that contain a single layerNormalization
// operator.
TEST_F(WebNNGraphImplBackendTest, BuildSingleOperatorLayerNormalization) {
  // Scalar input with default scale and bias.
  LayerNormalizationTester<float> scalar_tester{
      .input = {.type = OperandDataType::kFloat32,
                .dimensions = {},
                .values = {5}},
      .attributes = {.axes = {}},
      .output = {.type = OperandDataType::kFloat32,
                 .dimensions = {},
                 .values = {0}}};
  scalar_tester.Test(*this);

  // 6-D input with permuted axes = [4, 1, 2] and explicit scale and bias.
  LayerNormalizationTester<float> permuted_axes_tester{
      .input = {.type = OperandDataType::kFloat32,
                .dimensions = {1, 2, 1, 3, 2, 1},
                .values = {-4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7}},
      .scale = OperandInfo<float>{.type = OperandDataType::kFloat32,
                                  .dimensions = {2, 2, 1},
                                  .values = {0.5, 0, 1, -0.5}},
      .bias = OperandInfo<float>{.type = OperandDataType::kFloat32,
                                 .dimensions = {2, 2, 1},
                                 .values = {0.1, 0.2, 0.3, 0.4}},
      .attributes = {.axes = {4, 1, 2}},
      .output = {.type = OperandDataType::kFloat32,
                 .dimensions = {1, 2, 1, 3, 2, 1},
                 .values = {-0.47539614454389156, -0.5219944922055593,
                            -0.47539614454389156, -0.5219944922055593,
                            -0.47539614454389156, -0.5219944922055593, 0.2,
                            -0.17539614454389152, 0.2, -0.17539614454389152,
                            0.2, -0.17539614454389152}}};
  permuted_axes_tester.Test(*this);
}

template <typename T>
struct LstmTester {
  OperandInfo<T> input;
  OperandInfo<T> weight;
  OperandInfo<T> recurrent_weight;
  uint32_t steps;
  uint32_t hidden_size;
  std::optional<OperandInfo<T>> bias;
  std::optional<OperandInfo<T>> recurrent_bias;
  std::optional<OperandInfo<T>> peephole_weight;
  std::optional<OperandInfo<T>> initial_hidden_state;
  std::optional<OperandInfo<T>> initial_cell_state;
  struct LstmAttributes {
    std::optional<OperandId> bias_operand_id;
    std::optional<OperandId> recurrent_bias_operand_id;
    std::optional<OperandId> peephole_weight_operand_id;
    std::optional<OperandId> initial_hidden_state_operand_id;
    std::optional<OperandId> initial_cell_state_operand_id;
    bool return_sequence = false;
    mojom::RecurrentNetworkDirection direction =
        mojom::RecurrentNetworkDirection::kForward;
    mojom::LstmWeightLayout layout = mojom::LstmWeightLayout::kIofg;
    std::vector<mojom::RecurrentNetworkActivation> activations{
        mojom::RecurrentNetworkActivation::kSigmoid,
        mojom::RecurrentNetworkActivation::kTanh,
        mojom::RecurrentNetworkActivation::kTanh};
  };
  LstmAttributes attributes;
  std::vector<OperandInfo<T>> outputs;

  void Test(WebNNGraphImplBackendTest& helper,
            BuildAndComputeExpectation expectation =
                BuildAndComputeExpectation::kSuccess) {
    // Build the graph with mojo type.
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        helper.BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", input.dimensions, input.type);
    OperandId weight_operand_id =
        builder.BuildInput("weight", weight.dimensions, weight.type);
    OperandId recurrent_weight_operand_id = builder.BuildInput(
        "recurrentWeight", recurrent_weight.dimensions, recurrent_weight.type);

    if (bias.has_value()) {
      attributes.bias_operand_id =
          builder.BuildInput("bias", bias->dimensions, bias->type);
    }
    if (recurrent_bias.has_value()) {
      attributes.recurrent_bias_operand_id = builder.BuildInput(
          "recurrentBias", recurrent_bias->dimensions, recurrent_bias->type);
    }
    if (peephole_weight.has_value()) {
      attributes.peephole_weight_operand_id = builder.BuildInput(
          "peepholeWeight", peephole_weight->dimensions, peephole_weight->type);
    }
    if (initial_hidden_state.has_value()) {
      attributes.initial_hidden_state_operand_id = builder.BuildInput(
          "initialHiddenState", initial_hidden_state->dimensions,
          initial_hidden_state->type);
    }
    if (initial_cell_state.has_value()) {
      attributes.initial_cell_state_operand_id =
          builder.BuildInput("initialCellState", initial_cell_state->dimensions,
                             initial_cell_state->type);
    }

    std::vector<OperandId> output_operand_ids;
    output_operand_ids.reserve(outputs.size());
    for (size_t i = 0; i < outputs.size(); ++i) {
      const auto& output = outputs[i];
      output_operand_ids.push_back(builder.BuildOutput(
          "output" + base::NumberToString(i), output.dimensions, output.type));
    }

    builder.BuildLstm(input_operand_id, weight_operand_id,
                      recurrent_weight_operand_id,
                      std::move(output_operand_ids), steps, hidden_size,
                      std::move(attributes));

    base::flat_map<std::string, base::span<const T>> named_inputs;
    named_inputs.insert({"input", input.values});
    named_inputs.insert({"weight", weight.values});
    named_inputs.insert({"recurrentWeight", recurrent_weight.values});
    if (bias.has_value()) {
      named_inputs.insert({"bias", bias->values});
    }
    if (recurrent_bias.has_value()) {
      named_inputs.insert({"recurrentBias", recurrent_bias->values});
    }
    if (peephole_weight.has_value()) {
      named_inputs.insert({"peepholeWeight", peephole_weight->values});
    }
    if (initial_hidden_state.has_value()) {
      named_inputs.insert({"initialHiddenState", initial_hidden_state->values});
    }
    if (initial_cell_state.has_value()) {
      named_inputs.insert({"initialCellState", initial_cell_state->values});
    }

    base::flat_map<std::string, std::vector<T>> named_outputs = BuildAndCompute(
        helper.context(), std::move(remote), builder.TakeGraphInfo(),
        std::move(named_inputs), expectation);

    if (expectation == BuildAndComputeExpectation::kSuccess) {
      for (size_t i = 0; i < outputs.size(); ++i) {
        VerifyIsEqual(named_outputs["output" + base::NumberToString(i)],
                      outputs[i]);
      }
    }
  }
};

// Test building and computing graphs that contain a single lstm operator.
// Every case replaces the default activations with {relu, relu, relu}.
TEST_F(WebNNGraphImplBackendTest, BuildAndComputeSingleOperatorLstm) {
  const std::vector<mojom::RecurrentNetworkActivation> relu_activations(
      3, mojom::RecurrentNetworkActivation::kRelu);

  {
    // lstm with a given bias and recurrent bias.
    constexpr uint32_t kSteps = 2;
    constexpr uint32_t kBatchSize = 2;
    constexpr uint32_t kInputSize = 2;
    constexpr uint32_t kDirectionCount = 1;
    constexpr uint32_t kHiddenSize = 1;
    LstmTester<float> tester{
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {kSteps, kBatchSize, kInputSize},
                  .values = {-4, -3, -2, -1, 0, 1, 2, 3}},
        .weight = {.type = OperandDataType::kFloat32,
                   .dimensions = {kDirectionCount, 4 * kHiddenSize,
                                  kInputSize},
                   .values = {1, 1, 1, 1, 1, 1, 1, 1}},
        .recurrent_weight = {.type = OperandDataType::kFloat32,
                             .dimensions = {kDirectionCount, 4 * kHiddenSize,
                                            kHiddenSize},
                             .values = {1, 1, 1, 1}},
        .steps = kSteps,
        .hidden_size = kHiddenSize,
        .bias = OperandInfo<float>{.type = OperandDataType::kFloat32,
                                   .dimensions = {kDirectionCount,
                                                  4 * kHiddenSize},
                                   .values = {0.5, 0.5, 0.5, 0.5}},
        .recurrent_bias =
            OperandInfo<float>{.type = OperandDataType::kFloat32,
                               .dimensions = {kDirectionCount,
                                              4 * kHiddenSize},
                               .values = {0.5, 0.5, 0.5, 0.5}},
        .attributes = {.activations = relu_activations},
        .outputs = {{.type = OperandDataType::kFloat32,
                     .dimensions = {kDirectionCount, kBatchSize, kHiddenSize},
                     .values = {8, 216}},
                    {.type = OperandDataType::kFloat32,
                     .dimensions = {kDirectionCount, kBatchSize, kHiddenSize},
                     .values = {4, 36}}}};
    tester.Test(*this);
  }
  {
    // lstm with a given bias and peephole weight.
    constexpr uint32_t kSteps = 2;
    constexpr uint32_t kBatchSize = 1;
    constexpr uint32_t kInputSize = 2;
    constexpr uint32_t kDirectionCount = 1;
    constexpr uint32_t kHiddenSize = 2;
    LstmTester<float> tester{
        .input = {.type = OperandDataType::kFloat32,
                  .dimensions = {kSteps, kBatchSize, kInputSize},
                  .values = {1, 2, 3, 4}},
        .weight = {.type = OperandDataType::kFloat32,
                   .dimensions = {kDirectionCount, 4 * kHiddenSize,
                                  kInputSize},
                   .values = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}},
        .recurrent_weight = {.type = OperandDataType::kFloat32,
                             .dimensions = {kDirectionCount, 4 * kHiddenSize,
                                            kHiddenSize},
                             .values = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                                        1, 1, 1}},
        .steps = kSteps,
        .hidden_size = kHiddenSize,
        .bias = OperandInfo<float>{.type = OperandDataType::kFloat32,
                                   .dimensions = {kDirectionCount,
                                                  4 * kHiddenSize},
                                   .values = {1, 1, 1, 1, 1, 1, 1, 1}},
        .peephole_weight =
            OperandInfo<float>{.type = OperandDataType::kFloat32,
                               .dimensions = {kDirectionCount,
                                              3 * kHiddenSize},
                               .values = {0, 0, 0, 0, 0, 0}},
        .attributes = {.activations = relu_activations},
        .outputs = {{.type = OperandDataType::kFloat32,
                     .dimensions = {kDirectionCount, kBatchSize, kHiddenSize},
                     .values = {2811392, 2811392}},
                    {.type = OperandDataType::kFloat32,
                     .dimensions = {kDirectionCount, kBatchSize, kHiddenSize},
                     .values = {20672, 20672}}}};
    tester.Test(*this);
  }
  {
    // lstm whose operands are all constants, so the graph has no inputs and
    // is built by hand rather than through `LstmTester::Test()`.
    constexpr uint32_t kSteps = 1;
    constexpr uint32_t kBatchSize = 2;
    constexpr uint32_t kInputSize = 1;
    constexpr uint32_t kDirectionCount = 1;
    constexpr uint32_t kHiddenSize = 2;
    std::array<float, 2> input_values = {0, 1};
    std::array<float, 8> weight_values = {1, 1, 1, 1, 1, 1, 1, 1};
    std::array<float, 16> recurrent_weight_values = {1, 1, 1, 1, 1, 1, 1, 1,
                                                     1, 1, 1, 1, 1, 1, 1, 1};
    std::array<float, 6> peephole_weight_values = {0, 0, 0, 0, 0, 0};
    std::array<float, 4> initial_hidden_state_values = {0, 0, 0, 0};
    std::array<float, 4> initial_cell_state_values = {1, 1, 1, 1};
    // Both outputs are expected to hold the same values.
    std::vector<float> expected_values = {0, 0, 2, 2};

    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    const OperandId source_id = builder.BuildConstant(
        {kSteps, kBatchSize, kInputSize}, OperandDataType::kFloat32,
        base::as_byte_span(base::allow_nonunique_obj, input_values));
    const OperandId weight_id = builder.BuildConstant(
        {kDirectionCount, 4 * kHiddenSize, kInputSize},
        OperandDataType::kFloat32,
        base::as_byte_span(base::allow_nonunique_obj, weight_values));
    const OperandId recurrent_weight_id = builder.BuildConstant(
        {kDirectionCount, 4 * kHiddenSize, kHiddenSize},
        OperandDataType::kFloat32,
        base::as_byte_span(base::allow_nonunique_obj,
                           recurrent_weight_values));

    LstmTester<float>::LstmAttributes lstm_attributes;
    lstm_attributes.peephole_weight_operand_id = builder.BuildConstant(
        {kDirectionCount, 3 * kHiddenSize}, OperandDataType::kFloat32,
        base::as_byte_span(base::allow_nonunique_obj,
                           peephole_weight_values));
    lstm_attributes.initial_hidden_state_operand_id = builder.BuildConstant(
        {kDirectionCount, kBatchSize, kHiddenSize}, OperandDataType::kFloat32,
        base::as_byte_span(base::allow_nonunique_obj,
                           initial_hidden_state_values));
    lstm_attributes.initial_cell_state_operand_id = builder.BuildConstant(
        {kDirectionCount, kBatchSize, kHiddenSize}, OperandDataType::kFloat32,
        base::as_byte_span(base::allow_nonunique_obj,
                           initial_cell_state_values));
    lstm_attributes.activations = relu_activations;

    const OperandId first_output_id = builder.BuildOutput(
        "output0", {kDirectionCount, kBatchSize, kHiddenSize},
        OperandDataType::kFloat32);
    const OperandId second_output_id = builder.BuildOutput(
        "output1", {kDirectionCount, kBatchSize, kHiddenSize},
        OperandDataType::kFloat32);
    std::vector<OperandId> output_ids{first_output_id, second_output_id};
    builder.BuildLstm(source_id, weight_id, recurrent_weight_id,
                      std::move(output_ids), kSteps, kHiddenSize,
                      std::move(lstm_attributes));

    base::flat_map<std::string, std::vector<float>> named_outputs =
        BuildAndCompute<float>(context(), std::move(remote),
                               builder.TakeGraphInfo(),
                               /*named_inputs=*/{});

    ASSERT_EQ(named_outputs.size(), 2u);
    VerifyFloatDataIsEqual(named_outputs["output0"], expected_values);
    VerifyFloatDataIsEqual(named_outputs["output1"], expected_values);
  }
}

// Attributes for an lstmCell operation as passed to the graph builder in this
// file. The operand ids are unset by default.
struct LstmCellAttributes {
  std::optional<OperandId> bias_operand_id;
  std::optional<OperandId> recurrent_bias_operand_id;
  std::optional<OperandId> peephole_weight_operand_id;
  // Weight layout; defaults to `kIofg`.
  mojom::LstmWeightLayout layout = mojom::LstmWeightLayout::kIofg;
  // Defaults to {sigmoid, tanh, tanh}.
  std::vector<mojom::RecurrentNetworkActivation> activations{
      mojom::RecurrentNetworkActivation::kSigmoid,
      mojom::RecurrentNetworkActivation::kTanh,
      mojom::RecurrentNetworkActivation::kTanh};
};

// TODO(crbug.com/331250158): Remove this test after the WPT conformance tests
// are completed.
// Test building and computing a graph with a single lstmCell operator whose
// activations are {relu, relu, relu}.
TEST_F(WebNNGraphImplBackendTest, BuildAndComputeSingleOperatorLstmCell) {
  constexpr uint32_t kBatchSize = 2;
  constexpr uint32_t kInputSize = 2;
  constexpr uint32_t kHiddenSize = 2;
  const OperandDataType data_type = OperandDataType::kFloat32;

  // Compute-time values; everything except the input is filled with ones.
  const std::vector<float> input_values = {1, 2, 3, 4};
  const std::vector<float> weight_values(16, 1.0f);
  const std::vector<float> recurrent_weight_values(16, 1.0f);
  const std::vector<float> hidden_state_values(4, 1.0f);
  const std::vector<float> cell_state_values(4, 1.0f);

  const std::vector<float> expected_first_output = {150, 150, 810, 810};
  const std::vector<float> expected_second_output = {30, 30, 90, 90};

  mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
      BindNewGraphBuilderRemote();
  GraphInfoBuilder builder(remote);
  const OperandId source_id =
      builder.BuildInput("input", {kBatchSize, kInputSize}, data_type);
  const OperandId weight_id =
      builder.BuildInput("weight", {4 * kHiddenSize, kInputSize}, data_type);
  const OperandId recurrent_weight_id = builder.BuildInput(
      "recurrentWeight", {4 * kHiddenSize, kHiddenSize}, data_type);
  const OperandId hidden_state_id =
      builder.BuildInput("hiddenState", {kBatchSize, kHiddenSize}, data_type);
  const OperandId cell_state_id =
      builder.BuildInput("cellState", {kBatchSize, kHiddenSize}, data_type);

  LstmCellAttributes cell_attributes;
  cell_attributes.activations = {mojom::RecurrentNetworkActivation::kRelu,
                                 mojom::RecurrentNetworkActivation::kRelu,
                                 mojom::RecurrentNetworkActivation::kRelu};

  const OperandId first_output_id =
      builder.BuildOutput("output0", {kBatchSize, kHiddenSize}, data_type);
  const OperandId second_output_id =
      builder.BuildOutput("output1", {kBatchSize, kHiddenSize}, data_type);
  std::vector<OperandId> output_ids{first_output_id, second_output_id};
  builder.BuildLstmCell(source_id, weight_id, recurrent_weight_id,
                        hidden_state_id, cell_state_id, std::move(output_ids),
                        kHiddenSize, std::move(cell_attributes));

  base::flat_map<std::string, base::span<const float>> named_inputs;
  named_inputs.insert({"input", input_values});
  named_inputs.insert({"weight", weight_values});
  named_inputs.insert({"recurrentWeight", recurrent_weight_values});
  named_inputs.insert({"hiddenState", hidden_state_values});
  named_inputs.insert({"cellState", cell_state_values});

  auto named_outputs =
      BuildAndCompute(context(), std::move(remote), builder.TakeGraphInfo(),
                      std::move(named_inputs));

  ASSERT_EQ(named_outputs.size(), 2u);
  VerifyFloatDataIsEqual(named_outputs["output0"], expected_first_output);
  VerifyFloatDataIsEqual(named_outputs["output1"], expected_second_output);
}

template <typename T>
struct MatmulTester {
  OperandInfo<T> input_a;
  OperandInfo<T> input_b;
  OperandInfo<T> output;

  void TestFusion(
      WebNNGraphImplBackendTest& test,
      std::optional<std::vector<uint32_t>> permutation_a,
      std::optional<std::vector<uint32_t>> permutation_b,
      std::optional<const FusibleOperationDescriptor> fusible_operation) {
    // Build the graph with mojo type.
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        test.BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_a_operand_id =
        builder.BuildInput("input_a", input_a.dimensions, input_a.type);
    if (permutation_a) {
      std::vector<uint32_t> transposed_input_a_shape =
          PermuteArray(input_a.dimensions, permutation_a.value());
      OperandId transposed_input_a_id = builder.BuildIntermediateOperand(
          transposed_input_a_shape, input_a.type);
      builder.BuildTranspose(input_a_operand_id, transposed_input_a_id,
                             permutation_a.value());
      input_a_operand_id = transposed_input_a_id;
    }
    OperandId input_b_operand_id =
        builder.BuildInput("input_b", input_b.dimensions, input_b.type);
    if (permutation_b) {
      std::vector<uint32_t> transposed_input_b_shape =
          PermuteArray(input_b.dimensions, permutation_b.value());
      OperandId transposed_input_b_id = builder.BuildIntermediateOperand(
          transposed_input_b_shape, input_b.type);
      builder.BuildTranspose(input_b_operand_id, transposed_input_b_id,
                             permutation_b.value());
      input_b_operand_id = transposed_input_b_id;
    }

    OperandId output_operand_id;
    if (fusible_operation) {
      output_operand_id =
          builder.BuildIntermediateOperand(output.dimensions, output.type);
    } else {
      output_operand_id =
          builder.BuildOutput("output", output.dimensions, output.type);
    }

    builder.BuildMatmul(input_a_operand_id, input_b_operand_id,
                        output_operand_id);

    if (fusible_operation) {
      OperandId intermediate_operand_id = output_operand_id;
      output_operand_id =
          builder.BuildOutput("output", output.dimensions, output.type);
      BuildFusibleOperation(builder, fusible_operation.value(),
                            intermediate_operand_id, output_operand_id);
    }

    base::flat_map<std::string, base::span<const T>> named_inputs;
    named_inputs.insert({"input_a", input_a.values});
    named_inputs.insert({"input_b", input_b.values});
    base::flat_map<std::string, std::vector<T>> named_outputs =
        BuildAndCompute(test.context(), std::move(remote),
                        builder.TakeGraphInfo(), std::move(named_inputs));

    VerifyIsEqual(named_outputs["output"], output);
  }
};

// Test building and computing a graph of fusing standalone operations
// into matmul when possible.
TEST_F(WebNNGraphImplBackendTest, FuseStandaloneOperationsIntoMatmul) {
  // Each scope below describes one float32 matmul with MatmulTester and hands
  // TestFusion the optional permutations for the two inputs plus an optional
  // operation that is applied to the matmul result.

  // Input a goes through a fusible transpose; input b is used as is.
  {
    MatmulTester<float> tester{
        .input_a = {.type = OperandDataType::kFloat32,
                    .dimensions = {1, 2, 3},
                    .values = {1, 2, 3, 4, 5, 6}},
        .input_b = {.type = OperandDataType::kFloat32,
                    .dimensions = {1, 2, 3},
                    .values = {1, 2, 3, 4, 5, 6}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 3, 3},
                   .values = {17, 22, 27, 22, 29, 36, 27, 36, 45}}};
    tester.TestFusion(*this,
                      /*transpose_a*/ std::vector<uint32_t>{0, 2, 1},
                      /*transpose_b*/ std::nullopt,
                      /*activation*/ std::nullopt);
  }

  // Input b goes through a fusible transpose; input a is used as is.
  {
    MatmulTester<float> tester{
        .input_a = {.type = OperandDataType::kFloat32,
                    .dimensions = {1, 2, 3},
                    .values = {1, 2, 3, 4, 5, 6}},
        .input_b = {.type = OperandDataType::kFloat32,
                    .dimensions = {1, 2, 3},
                    .values = {1, 2, 3, 4, 5, 6}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 2, 2},
                   .values = {14, 32, 32, 77}}};
    tester.TestFusion(*this,
                      /*transpose_a*/ std::nullopt,
                      /*transpose_b*/ std::vector<uint32_t>{0, 2, 1},
                      /*activation*/ std::nullopt);
  }

  // Both inputs go through a fusible transpose.
  {
    MatmulTester<float> tester{
        .input_a = {.type = OperandDataType::kFloat32,
                    .dimensions = {1, 3, 2},
                    .values = {1, 2, 3, 4, 5, 6}},
        .input_b = {.type = OperandDataType::kFloat32,
                    .dimensions = {1, 2, 3},
                    .values = {1, 2, 3, 4, 5, 6}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 2, 2},
                   .values = {22, 49, 28, 64}}};
    tester.TestFusion(*this,
                      /*transpose_a*/ std::vector<uint32_t>{0, 2, 1},
                      /*transpose_b*/ std::vector<uint32_t>{0, 2, 1},
                      /*activation*/ std::nullopt);
  }

  // Input a goes through an unfusible transpose: the permutation {2, 1, 0}
  // also moves the leading dimension.
  {
    MatmulTester<float> tester{
        .input_a = {.type = OperandDataType::kFloat32,
                    .dimensions = {2, 3, 1},
                    .values = {1, 2, 3, 4, 5, 6}},
        .input_b = {.type = OperandDataType::kFloat32,
                    .dimensions = {1, 2, 3},
                    .values = {1, 2, 3, 4, 5, 6}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 3, 3},
                   .values = {17, 22, 27, 22, 29, 36, 27, 36, 45}}};
    tester.TestFusion(*this,
                      /*transpose_a*/ std::vector<uint32_t>{2, 1, 0},
                      /*transpose_b*/ std::nullopt,
                      /*activation*/ std::nullopt);
  }

  // 2-D * 2-D inputs without any transpose, followed by a linear operation
  // with alpha = 10 and beta = 1.
  {
    MatmulTester<float> tester{
        .input_a = {.type = OperandDataType::kFloat32,
                    .dimensions = {2, 2},
                    .values = {1, 2, 3, 4}},
        .input_b = {.type = OperandDataType::kFloat32,
                    .dimensions = {2, 2},
                    .values = {1, 2, 3, 4}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {2, 2},
                   .values = {71, 101, 151, 221}}};
    tester.TestFusion(
        *this,
        /*transpose_a*/ std::nullopt,
        /*transpose_b*/ std::nullopt,
        /*activation*/
        FusibleOperationDescriptor{.kind = mojom::Operation::Tag::kLinear,
                                   .alpha = 10,
                                   .beta = 1});
  }

  // Transposes on both inputs and a trailing linear operation, all of which
  // can be fused.
  {
    MatmulTester<float> tester{
        .input_a = {.type = OperandDataType::kFloat32,
                    .dimensions = {1, 3, 2},
                    .values = {1, 2, 3, 4, 5, 6}},
        .input_b = {.type = OperandDataType::kFloat32,
                    .dimensions = {1, 2, 3},
                    .values = {1, 2, 3, 4, 5, 6}},
        .output = {.type = OperandDataType::kFloat32,
                   .dimensions = {1, 2, 2},
                   .values = {221, 491, 281, 641}}};
    tester.TestFusion(
        *this,
        /*transpose_a*/ std::vector<uint32_t>{0, 2, 1},
        /*transpose_b*/ std::vector<uint32_t>{0, 2, 1},
        /*activation*/
        FusibleOperationDescriptor{.kind = mojom::Operation::Tag::kLinear,
                                   .alpha = 10,
                                   .beta = 1});
  }
}

// Test building and computing a graph with two inputs and two constants in
// the following topology.
//    [input_a] [constant_a] [input_b] [constant_b]
//           \    /                \    /
//            gemm                  gemm
//                \                /
//                       gemm
TEST_F(WebNNGraphImplBackendTest, BuildMultipleInputsAppendingConstants) {
  // Both constants share one 2x2 payload and both inputs share another.
  std::vector<float> ones = {1, 1, 1, 1};
  std::vector<float> input_values = {1, 2, 3, 4};

  mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
      BindNewGraphBuilderRemote();
  GraphInfoBuilder builder(remote);

  // Declare the two graph inputs first, then the two constants.
  OperandId lhs_input_id =
      builder.BuildInput("input_a", {2, 2}, OperandDataType::kFloat32);
  OperandId rhs_input_id =
      builder.BuildInput("input_b", {2, 2}, OperandDataType::kFloat32);
  OperandId lhs_constant_id = builder.BuildConstant(
      {2, 2}, OperandDataType::kFloat32,
      base::as_byte_span(base::allow_nonunique_obj, ones));
  OperandId rhs_constant_id = builder.BuildConstant(
      {2, 2}, OperandDataType::kFloat32,
      base::as_byte_span(base::allow_nonunique_obj, ones));

  // The first two gemms consume their operands in the order
  // [input_a, constant_a] and [input_b, constant_b].
  OperandId lhs_product_id =
      builder.BuildIntermediateOperand({2, 2}, OperandDataType::kFloat32);
  builder.BuildGemm(lhs_input_id, lhs_constant_id, lhs_product_id,
                    GemmAttributes());
  OperandId rhs_product_id =
      builder.BuildIntermediateOperand({2, 2}, OperandDataType::kFloat32);
  builder.BuildGemm(rhs_input_id, rhs_constant_id, rhs_product_id,
                    GemmAttributes());

  // The last gemm combines both intermediate results into the graph output.
  OperandId result_id =
      builder.BuildOutput("output", {2, 2}, OperandDataType::kFloat32);
  builder.BuildGemm(lhs_product_id, rhs_product_id, result_id,
                    GemmAttributes());

  base::flat_map<std::string, base::span<const float>> named_inputs;
  named_inputs.insert({"input_a", input_values});
  named_inputs.insert({"input_b", input_values});
  base::flat_map<std::string, std::vector<float>> named_outputs =
      BuildAndCompute(context(), std::move(remote), builder.TakeGraphInfo(),
                      std::move(named_inputs));

  VerifyFloatDataIsEqual(named_outputs["output"], {30, 30, 70, 70});
}

// Test building and computing a graph with two inputs and two constants in
// the following topology.
//    [constant_a] [input_a] [constant_b] [input_b]
//           \    /                \    /
//            gemm                  gemm
//                \                /
//                       gemm
TEST_F(WebNNGraphImplBackendTest, BuildMultipleConstantsAppendingInputs) {
  // Both constants share one 2x2 payload and both inputs share another.
  std::vector<float> constant_values = {1, 2, 3, 4};
  std::vector<float> ones = {1, 1, 1, 1};

  mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
      BindNewGraphBuilderRemote();
  GraphInfoBuilder builder(remote);

  // Declare the two graph inputs first, then the two constants.
  OperandId lhs_input_id =
      builder.BuildInput("input_a", {2, 2}, OperandDataType::kFloat32);
  OperandId rhs_input_id =
      builder.BuildInput("input_b", {2, 2}, OperandDataType::kFloat32);
  OperandId lhs_constant_id = builder.BuildConstant(
      {2, 2}, OperandDataType::kFloat32,
      base::as_byte_span(base::allow_nonunique_obj, constant_values));
  OperandId rhs_constant_id = builder.BuildConstant(
      {2, 2}, OperandDataType::kFloat32,
      base::as_byte_span(base::allow_nonunique_obj, constant_values));

  // The first two gemms consume their operands in the order
  // [constant_a, input_a] and [constant_b, input_b].
  OperandId lhs_product_id =
      builder.BuildIntermediateOperand({2, 2}, OperandDataType::kFloat32);
  builder.BuildGemm(lhs_constant_id, lhs_input_id, lhs_product_id,
                    GemmAttributes());
  OperandId rhs_product_id =
      builder.BuildIntermediateOperand({2, 2}, OperandDataType::kFloat32);
  builder.BuildGemm(rhs_constant_id, rhs_input_id, rhs_product_id,
                    GemmAttributes());

  // The last gemm combines both intermediate results into the graph output.
  OperandId result_id =
      builder.BuildOutput("output", {2, 2}, OperandDataType::kFloat32);
  builder.BuildGemm(lhs_product_id, rhs_product_id, result_id,
                    GemmAttributes());

  base::flat_map<std::string, base::span<const float>> named_inputs;
  named_inputs.insert({"input_a", ones});
  named_inputs.insert({"input_b", ones});
  base::flat_map<std::string, std::vector<float>> named_outputs =
      BuildAndCompute(context(), std::move(remote), builder.TakeGraphInfo(),
                      std::move(named_inputs));

  VerifyFloatDataIsEqual(named_outputs["output"], {30, 30, 70, 70});
}

// Test building and computing a graph whose gemm operator takes a reshaped
// constant operand c in the following topology:
//                        [constant_c]
//                         |
//     [input_a] [input_b] reshape
//             \    |     /
//                 gemm
// This test case could reproduce the issue of ResNetV2 50 model of WebNN image
// classification sample:
// https://bugs.chromium.org/p/chromium/issues/detail?id=1509747
TEST_F(WebNNGraphImplBackendTest, BuildGemmWithReshapedConstantOperand) {
  // Payloads: two elements for constant_c and one 2x2 buffer that is fed to
  // both graph inputs.
  std::vector<float> c_values = {1, 1};
  std::vector<float> input_values = {1, 2, 3, 4};

  mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
      BindNewGraphBuilderRemote();
  GraphInfoBuilder builder(remote);
  OperandId a_id =
      builder.BuildInput("input_a", {2, 2}, OperandDataType::kFloat32);
  OperandId b_id =
      builder.BuildInput("input_b", {2, 2}, OperandDataType::kFloat32);
  OperandId c_constant_id = builder.BuildConstant(
      {2}, OperandDataType::kFloat32,
      base::as_byte_span(base::allow_nonunique_obj, c_values));

  // constant_c is declared with shape [2]; gemm receives it as operand c only
  // after a reshape to [1, 2].
  OperandId c_reshaped_id =
      builder.BuildIntermediateOperand({1, 2}, OperandDataType::kFloat32);
  builder.BuildReshape(c_constant_id, c_reshaped_id);

  GemmAttributes attributes;
  attributes.c_operand_id = c_reshaped_id;
  OperandId result_id =
      builder.BuildOutput("output", {2, 2}, OperandDataType::kFloat32);
  builder.BuildGemm(a_id, b_id, result_id, attributes);

  base::flat_map<std::string, base::span<const float>> named_inputs;
  named_inputs.insert({"input_a", input_values});
  named_inputs.insert({"input_b", input_values});
  base::flat_map<std::string, std::vector<float>> named_outputs =
      BuildAndCompute(context(), std::move(remote), builder.TakeGraphInfo(),
                      std::move(named_inputs));

  VerifyFloatDataIsEqual(named_outputs["output"], {8, 11, 16, 23});
}

// Test building a graph whose add operator takes a reshaped
// constant operand b in the following topology:
//              [constant_b]
//                 |
//    [input_a]  reshape
//           \    /
//            add
TEST_F(WebNNGraphImplBackendTest, BuildAddWithReshapedConstantOperand) {
  // Payloads for constant_b (two elements) and input_a (four elements).
  std::vector<float> b_values = {1, 1};
  std::vector<float> a_values = {1, 1, 1, 1};

  mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
      BindNewGraphBuilderRemote();
  GraphInfoBuilder builder(remote);
  OperandId a_id =
      builder.BuildInput("input_a", {1, 1, 2, 2}, OperandDataType::kFloat32);
  OperandId b_constant_id = builder.BuildConstant(
      {2}, OperandDataType::kFloat32,
      base::as_byte_span(base::allow_nonunique_obj, b_values));

  // constant_b is declared with shape [2]; add receives it as operand b only
  // after a reshape to [1, 2].
  OperandId b_reshaped_id =
      builder.BuildIntermediateOperand({1, 2}, OperandDataType::kFloat32);
  builder.BuildReshape(b_constant_id, b_reshaped_id);

  OperandId sum_id =
      builder.BuildOutput("output", {1, 1, 2, 2}, OperandDataType::kFloat32);
  builder.BuildElementWiseBinary(mojom::ElementWiseBinary::Kind::kAdd, a_id,
                                 b_reshaped_id, sum_id);

  base::flat_map<std::string, base::span<const float>> named_inputs;
  named_inputs.insert({"input_a", a_values});
  base::flat_map<std::string, std::vector<float>> named_outputs =
      BuildAndCompute(context(), std::move(remote), builder.TakeGraphInfo(),
                      std::move(named_inputs));
  VerifyFloatDataIsEqual(named_outputs["output"], {2, 2, 2, 2});
}

// Test building and computing a graph whose relu operator only has a
// constant operand input, as the following topology:
//    [constant]
//         |
//       relu
TEST_F(WebNNGraphImplBackendTest, BuildAndComputeReluWithOnlyConstantInput) {
  // The constant holds one negative, one zero and one positive element.
  std::vector<float> relu_source = {-1, 0, 1};

  mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
      BindNewGraphBuilderRemote();
  GraphInfoBuilder builder(remote);
  OperandId source_id = builder.BuildConstant(
      {3}, OperandDataType::kFloat32,
      base::as_byte_span(base::allow_nonunique_obj, relu_source));
  OperandId result_id =
      builder.BuildOutput("output", {3}, OperandDataType::kFloat32);
  builder.BuildRelu(source_id, result_id);

  // The graph declares no inputs, so nothing is bound at compute time.
  base::flat_map<std::string, std::vector<float>> named_outputs =
      BuildAndCompute<float>(context(), std::move(remote),
                             builder.TakeGraphInfo(),
                             /*named_inputs=*/{});
  VerifyFloatDataIsEqual(named_outputs["output"], {0, 0, 1});
}

// Test building and computing a graph whose add operator only has constant
// operand inputs, as the following topology:
//    [constant_a]  [constant_b]
//               \  /
//               add
TEST_F(WebNNGraphImplBackendTest, BuildAndComputeAddWithOnlyConstantInputs) {
  // Payloads for the two 2x2 constants.
  std::vector<float> ones = {1, 1, 1, 1};
  std::vector<float> twos = {2, 2, 2, 2};

  mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
      BindNewGraphBuilderRemote();
  GraphInfoBuilder builder(remote);
  OperandId lhs_id = builder.BuildConstant(
      {2, 2}, OperandDataType::kFloat32,
      base::as_byte_span(base::allow_nonunique_obj, ones));
  OperandId rhs_id = builder.BuildConstant(
      {2, 2}, OperandDataType::kFloat32,
      base::as_byte_span(base::allow_nonunique_obj, twos));
  OperandId sum_id =
      builder.BuildOutput("output", {2, 2}, OperandDataType::kFloat32);
  builder.BuildElementWiseBinary(mojom::ElementWiseBinary::Kind::kAdd, lhs_id,
                                 rhs_id, sum_id);

  // The graph declares no inputs, so nothing is bound at compute time.
  base::flat_map<std::string, std::vector<float>> named_outputs =
      BuildAndCompute<float>(context(), std::move(remote),
                             builder.TakeGraphInfo(),
                             /*named_inputs=*/{});
  VerifyFloatDataIsEqual(named_outputs["output"], {3, 3, 3, 3});
}

// Test building and computing a graph whose add and mul operators only have
// constant and intermediate operand inputs, as the following topology:
//    [constant_a]  [constant_b]
//               \  /
//               add    [constant_c]
//                  \  /
//                   mul
TEST_F(WebNNGraphImplBackendTest,
       BuildAndComputeAddAndMulWithOnlyConstantInputs) {
  // Payloads for the three 2x2 constants.
  std::vector<float> ones = {1, 1, 1, 1};
  std::vector<float> twos = {2, 2, 2, 2};
  std::vector<float> threes = {3, 3, 3, 3};

  mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
      BindNewGraphBuilderRemote();
  GraphInfoBuilder builder(remote);

  // add(constant_a, constant_b) writes to an intermediate operand.
  OperandId a_id = builder.BuildConstant(
      {2, 2}, OperandDataType::kFloat32,
      base::as_byte_span(base::allow_nonunique_obj, ones));
  OperandId b_id = builder.BuildConstant(
      {2, 2}, OperandDataType::kFloat32,
      base::as_byte_span(base::allow_nonunique_obj, twos));
  OperandId sum_id =
      builder.BuildIntermediateOperand({2, 2}, OperandDataType::kFloat32);
  builder.BuildElementWiseBinary(mojom::ElementWiseBinary::Kind::kAdd, a_id,
                                 b_id, sum_id);

  // mul(add result, constant_c) writes to the graph output.
  OperandId c_id = builder.BuildConstant(
      {2, 2}, OperandDataType::kFloat32,
      base::as_byte_span(base::allow_nonunique_obj, threes));
  OperandId product_id =
      builder.BuildOutput("output", {2, 2}, OperandDataType::kFloat32);
  builder.BuildElementWiseBinary(mojom::ElementWiseBinary::Kind::kMul, sum_id,
                                 c_id, product_id);

  // The graph declares no inputs, so nothing is bound at compute time.
  base::flat_map<std::string, std::vector<float>> named_outputs =
      BuildAndCompute<float>(context(), std::move(remote),
                             builder.TakeGraphInfo(),
                             /*named_inputs=*/{});
  VerifyFloatDataIsEqual(named_outputs["output"], {9, 9, 9, 9});
}

// Aggregate of pool2d settings. The max-pooling tests in this file fill it in
// with designated initializers and pass it to GraphInfoBuilder::BuildPool2d,
// so the field names and their declaration order are part of its interface.
// NOTE(review): the meaning and element order of each list are not visible
// here; confirm against BuildPool2d before relying on them.
struct Pool2dAttributes {
  // The tests in this file pass two values, e.g. {1, 1}.
  std::vector<uint32_t> window_dimensions;
  // The tests in this file pass four values, e.g. {0, 0, 0, 0}.
  std::vector<uint32_t> padding;
  // The tests in this file pass two values, e.g. {1, 1}.
  std::vector<uint32_t> strides;
  // The tests in this file pass two values, e.g. {1, 1}.
  std::vector<uint32_t> dilations;
};

// Test building a graph in the following topology.
//    [input_a] [input_b]
//           \    /
//            add
//             |
//            relu
//             |
//          max pooling
TEST_F(WebNNGraphImplBackendTest, BuildMaxPoolingAsThirdOperator) {
  // One all-ones buffer is fed to both graph inputs.
  std::vector<float> ones = {1, 1, 1, 1};

  mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
      BindNewGraphBuilderRemote();
  GraphInfoBuilder builder(remote);

  // First operator: add(input_a, input_b).
  OperandId lhs_id =
      builder.BuildInput("input_a", {1, 1, 2, 2}, OperandDataType::kFloat32);
  OperandId rhs_id =
      builder.BuildInput("input_b", {1, 1, 2, 2}, OperandDataType::kFloat32);
  OperandId sum_id =
      builder.BuildIntermediateOperand({1, 1, 2, 2}, OperandDataType::kFloat32);
  builder.BuildElementWiseBinary(mojom::ElementWiseBinary::Kind::kAdd, lhs_id,
                                 rhs_id, sum_id);

  // Second operator: relu on the add result.
  OperandId relu_id =
      builder.BuildIntermediateOperand({1, 1, 2, 2}, OperandDataType::kFloat32);
  builder.BuildRelu(sum_id, relu_id);

  // Third operator: max pooling with a 1x1 window, zero padding and unit
  // strides and dilations, writing to the graph output.
  OperandId pooled_id =
      builder.BuildOutput("output", {1, 1, 2, 2}, OperandDataType::kFloat32);
  builder.BuildPool2d(mojom::Pool2d::Kind::kMaxPool2d, relu_id, pooled_id,
                      Pool2dAttributes{.window_dimensions = {1, 1},
                                       .padding = {0, 0, 0, 0},
                                       .strides = {1, 1},
                                       .dilations = {1, 1}});

  base::flat_map<std::string, base::span<const float>> named_inputs;
  named_inputs.insert({"input_a", ones});
  named_inputs.insert({"input_b", ones});
  base::flat_map<std::string, std::vector<float>> named_outputs =
      BuildAndCompute(context(), std::move(remote), builder.TakeGraphInfo(),
                      std::move(named_inputs));
  VerifyFloatDataIsEqual(named_outputs["output"], {2, 2, 2, 2});
}

// Test building a graph in the following topology.
//    [input_a] [input_b]
//           \    /
//            add
//             |
//          max pooling
//             |
//            relu
TEST_F(WebNNGraphImplBackendTest, BuildMaxPoolingAsSecondOperator) {
  // One all-ones buffer is fed to both graph inputs.
  std::vector<float> ones = {1, 1, 1, 1};

  mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
      BindNewGraphBuilderRemote();
  GraphInfoBuilder builder(remote);

  // First operator: add(input_a, input_b).
  OperandId lhs_id =
      builder.BuildInput("input_a", {1, 1, 2, 2}, OperandDataType::kFloat32);
  OperandId rhs_id =
      builder.BuildInput("input_b", {1, 1, 2, 2}, OperandDataType::kFloat32);
  OperandId sum_id =
      builder.BuildIntermediateOperand({1, 1, 2, 2}, OperandDataType::kFloat32);
  builder.BuildElementWiseBinary(mojom::ElementWiseBinary::Kind::kAdd, lhs_id,
                                 rhs_id, sum_id);

  // Second operator: max pooling with a 1x1 window, zero padding and unit
  // strides and dilations.
  OperandId pooled_id =
      builder.BuildIntermediateOperand({1, 1, 2, 2}, OperandDataType::kFloat32);
  builder.BuildPool2d(mojom::Pool2d::Kind::kMaxPool2d, sum_id, pooled_id,
                      Pool2dAttributes{.window_dimensions = {1, 1},
                                       .padding = {0, 0, 0, 0},
                                       .strides = {1, 1},
                                       .dilations = {1, 1}});

  // Third operator: relu on the pooled result, writing to the graph output.
  OperandId relu_id =
      builder.BuildOutput("output", {1, 1, 2, 2}, OperandDataType::kFloat32);
  builder.BuildRelu(pooled_id, relu_id);

  base::flat_map<std::string, base::span<const float>> named_inputs;
  named_inputs.insert({"input_a", ones});
  named_inputs.insert({"input_b", ones});
  base::flat_map<std::string, std::vector<float>> named_outputs =
      BuildAndCompute(context(), std::move(remote), builder.TakeGraphInfo(),
                      std::move(named_inputs));
  VerifyFloatDataIsEqual(named_outputs["output"], {2, 2, 2, 2});
}

// Test building a graph in the following topology.
//      [input_a]
//          |
//      max pooling
//                  [input_b]
//           \        /
//               add
//                |
//               relu
TEST_F(WebNNGraphImplBackendTest, BuildMaxPoolingAsFirstOperator) {
  // One all-ones buffer is fed to both graph inputs.
  std::vector<float> ones = {1, 1, 1, 1};

  mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
      BindNewGraphBuilderRemote();
  GraphInfoBuilder builder(remote);

  // First operator: max pooling on input_a with a 1x1 window, zero padding
  // and unit strides and dilations.
  OperandId lhs_id =
      builder.BuildInput("input_a", {1, 1, 2, 2}, OperandDataType::kFloat32);
  OperandId pooled_id =
      builder.BuildIntermediateOperand({1, 1, 2, 2}, OperandDataType::kFloat32);
  builder.BuildPool2d(mojom::Pool2d::Kind::kMaxPool2d, lhs_id, pooled_id,
                      Pool2dAttributes{.window_dimensions = {1, 1},
                                       .padding = {0, 0, 0, 0},
                                       .strides = {1, 1},
                                       .dilations = {1, 1}});

  // Second operator: add(pooled input_a, input_b). input_b is only declared
  // at this point, after the pooling operator.
  OperandId rhs_id =
      builder.BuildInput("input_b", {1, 1, 2, 2}, OperandDataType::kFloat32);
  OperandId sum_id =
      builder.BuildIntermediateOperand({1, 1, 2, 2}, OperandDataType::kFloat32);
  builder.BuildElementWiseBinary(mojom::ElementWiseBinary::Kind::kAdd,
                                 pooled_id, rhs_id, sum_id);

  // Third operator: relu on the add result, writing to the graph output.
  OperandId relu_id =
      builder.BuildOutput("output", {1, 1, 2, 2}, OperandDataType::kFloat32);
  builder.BuildRelu(sum_id, relu_id);

  base::flat_map<std::string, base::span<const float>> named_inputs;
  named_inputs.insert({"input_a", ones});
  named_inputs.insert({"input_b", ones});
  base::flat_map<std::string, std::vector<float>> named_outputs =
      BuildAndCompute(context(), std::move(remote), builder.TakeGraphInfo(),
                      std::move(named_inputs));
  VerifyFloatDataIsEqual(named_outputs["output"], {2, 2, 2, 2});
}

// Test building and computing a graph with float 16 data type in the
// following topology.
//     [input_a]
//         |
//      reshape    [input_b]
//          \         /
//             concat
//               |
//             clamp
TEST_F(WebNNGraphImplBackendTest, BuildAndComputeReshapeConcatAndClamp) {
  // Values for input_a, shape (4, 3):
  // [[ 1  2  3]
  //  [ 4  5  6]
  //  [ 7  8  9]
  //  [10 11 12]]
  std::vector<Float16> matrix_values =
      Float16FromFloat32({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
  // Values for input_b, shape (1, 1, 2, 3):
  // [[[[-6 -5 -4]
  //    [-3 -2 -1]]]]
  std::vector<Float16> tensor_values =
      Float16FromFloat32({-6, -5, -4, -3, -2, -1});

  mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
      BindNewGraphBuilderRemote();
  GraphInfoBuilder builder(remote);
  OperandId matrix_id =
      builder.BuildInput("input_a", {4, 3}, OperandDataType::kFloat16);
  OperandId tensor_id =
      builder.BuildInput("input_b", {1, 1, 2, 3}, OperandDataType::kFloat16);

  // Reshape input_a from (4, 3) to (1, 2, 2, 3).
  OperandId reshaped_id =
      builder.BuildIntermediateOperand({1, 2, 2, 3}, OperandDataType::kFloat16);
  builder.BuildReshape(matrix_id, reshaped_id);

  // Concatenate the reshaped input_a with input_b along axis 1.
  OperandId joined_id =
      builder.BuildIntermediateOperand({1, 3, 2, 3}, OperandDataType::kFloat16);
  builder.BuildConcat({reshaped_id, tensor_id}, joined_id, 1);

  // Clamp the concatenated values to [1.25, 8.75].
  OperandId clamped_id =
      builder.BuildOutput("output", {1, 3, 2, 3}, OperandDataType::kFloat16);
  builder.BuildClamp(joined_id, clamped_id, 1.25, 8.75);

  base::flat_map<std::string, base::span<const Float16>> named_inputs;
  named_inputs.insert({"input_a", matrix_values});
  named_inputs.insert({"input_b", tensor_values});
  base::flat_map<std::string, std::vector<Float16>> named_outputs =
      BuildAndCompute(context(), std::move(remote), builder.TakeGraphInfo(),
                      std::move(named_inputs));

  // Expected output, shape (1, 3, 2, 3):
  // [[[[1.25 2.   3.  ]
  //    [4.   5.   6.  ]]
  //   [[7.   8.   8.75]
  //    [8.75 8.75 8.75]]
  //   [[1.25 1.25 1.25]
  //    [1.25 1.25 1.25]]]]
  EXPECT_EQ(Float16ToFloat32(named_outputs["output"]),
            std::vector<float>({1.25, 2, 3, 4, 5, 6, 7, 8, 8.75, 8.75, 8.75,
                                8.75, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25}));
}

// Test building and computing a graph in the following topology.
//      [input]   [constant_a]
//          \          /
//             concat   [constant_b]
//               \           /
//                   concat
TEST_F(WebNNGraphImplBackendTest, BuildAndComputeConcatWithConstants) {
  // Values for the graph input, shape (1, 1, 1, 3): [[[[0 0 0]]]]
  std::vector<float> input_values = {0, 0, 0};
  // Values for constant_a, shape (1, 1, 1, 3): [[[[1 2 3]]]]
  std::vector<float> a_values = {1, 2, 3};
  // Values for constant_b, shape (1, 1, 2, 3):
  // [[[[-1 -2 -3]
  //    [-4 -5 -6]]]]
  std::vector<float> b_values = {-1, -2, -3, -4, -5, -6};

  mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
      BindNewGraphBuilderRemote();
  GraphInfoBuilder builder(remote);
  OperandId input_id =
      builder.BuildInput("input", {1, 1, 1, 3}, OperandDataType::kFloat32);
  OperandId a_id = builder.BuildConstant(
      {1, 1, 1, 3}, OperandDataType::kFloat32,
      base::as_byte_span(base::allow_nonunique_obj, a_values));
  OperandId b_id = builder.BuildConstant(
      {1, 1, 2, 3}, OperandDataType::kFloat32,
      base::as_byte_span(base::allow_nonunique_obj, b_values));

  // First concat: input and constant_a along axis 2.
  OperandId first_concat_id =
      builder.BuildIntermediateOperand({1, 1, 2, 3}, OperandDataType::kFloat32);
  builder.BuildConcat({input_id, a_id}, first_concat_id, 2);

  // Second concat: the first result and constant_b along axis 1.
  OperandId second_concat_id =
      builder.BuildOutput("output", {1, 2, 2, 3}, OperandDataType::kFloat32);
  builder.BuildConcat({first_concat_id, b_id}, second_concat_id, 1);

  base::flat_map<std::string, base::span<const float>> named_inputs;
  named_inputs.insert({"input", input_values});
  base::flat_map<std::string, std::vector<float>> named_outputs =
      BuildAndCompute(context(), std::move(remote), builder.TakeGraphInfo(),
                      std::move(named_inputs));

  // Expected output, shape (1, 2, 2, 3):
  // [[[[ 0  0  0]
  //    [ 1  2  3]]
  //   [[-1 -2 -3]
  //    [-4 -5 -6]]]]
  std::vector<float> expected_output = {0,  0,  0,  1,  2,  3,
                                        -1, -2, -3, -4, -5, -6};
  VerifyFloatDataIsEqual(named_outputs["output"], expected_output);
}

template <typename T>
struct Resample2dTester {
  OperandInfo<T> input;
  struct Resample2dAttributes {
    mojom::Resample2d::InterpolationMode mode =
        mojom::Resample2d::InterpolationMode::kNearestNeighbor;
    std::optional<std::vector<float>> scales;
    std::vector<uint32_t> axes = {2, 3};
  };
  Resample2dAttributes attributes;
  OperandInfo<float> output;

  void Test(WebNNGraphImplBackendTest& test) {
    // Build the graph with mojo type.
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        test.BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId input_operand_id =
        builder.BuildInput("input", input.dimensions, input.type);
    OperandId output_operand_id =
        builder.BuildOutput("output", output.dimensions, output.type);
    builder.BuildResample2d(input_operand_id, output_operand_id, attributes);

    base::flat_map<std::string, base::span<const T>> named_inputs;
    named_inputs.insert({"input", input.values});
    base::flat_map<std::string, std::vector<float>> named_outputs =
        BuildAndCompute(test.context(), std::move(remote),
                        builder.TakeGraphInfo(), std::move(named_inputs));

    VerifyFloatDataIsEqual(named_outputs["output"], output.values);
  }
};

// Test building and computing a graph with single operator resample2d.
#if BUILDFLAG(IS_WIN) && defined(ARCH_CPU_ARM_FAMILY)
// Test times out on Windows 11 / ARM bot, see https://crbug.com/381510750.
#define MAYBE_BuildAndComputeSingleOperatorResample2d \
  DISABLED_BuildAndComputeSingleOperatorResample2d
#else
#define MAYBE_BuildAndComputeSingleOperatorResample2d \
  BuildAndComputeSingleOperatorResample2d
#endif
TEST_F(WebNNGraphImplBackendTest,
       MAYBE_BuildAndComputeSingleOperatorResample2d) {
  // "NearestNeighbor" mode (the tester's default) with explicit
  // scales = [2, 3] applied over the default axes = [2, 3].
  Resample2dTester<float> nearest_neighbor_case{
      .input = {.type = OperandDataType::kFloat32,
                .dimensions = {1, 1, 2, 2},
                // [[[[1 2]
                //    [3 4]]]] with shape (1, 1, 2, 2)
                .values = {1, 2, 3, 4}},
      .attributes = {.scales = std::vector<float>{2, 3}},
      .output = {.type = OperandDataType::kFloat32,
                 .dimensions = {1, 1, 4, 6},
                 // Every input element is repeated twice vertically and
                 // three times horizontally:
                 // [[[[1 1 1 2 2 2]
                 //    [1 1 1 2 2 2]
                 //    [3 3 3 4 4 4]
                 //    [3 3 3 4 4 4]]]] with shape (1, 1, 4, 6)
                 .values = {1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2,
                            3, 3, 3, 4, 4, 4, 3, 3, 3, 4, 4, 4}}};
  nearest_neighbor_case.Test(*this);
}

// Test building and computing a graph in the following topology.
//      [input]
//         |
//     transpose
//         |
//     transpose
TEST_F(WebNNGraphImplBackendTest, BuildAndComputeGraphWithTwoTranspose) {
  // Input data with shape (1, 2, 3, 4):
  // [[[[ -1  -2  -3  -4]
  //    [ -5  -6  -7  -8]
  //    [ -9 -10 -11 -12]]
  //   [[ 13  14  15  16]
  //    [ 17  18  19  20]
  //    [ 21  22  23  24]]]]
  std::vector<float> input_values = {-1, -2,  -3,  -4,  -5, -6, -7, -8,
                                     -9, -10, -11, -12, 13, 14, 15, 16,
                                     17, 18,  19,  20,  21, 22, 23, 24};

  // Expected result with shape (4, 3, 1, 2):
  // [[[[ -1  13]]
  //   [[ -5  17]]
  //   [[ -9  21]]]
  //  [[[ -2  14]]
  //   [[ -6  18]]
  //   [[-10  22]]]
  //  [[[ -3  15]]
  //   [[ -7  19]]
  //   [[-11  23]]]
  //  [[[ -4  16]]
  //   [[ -8  20]]
  //   [[-12  24]]]]
  std::vector<float> expected_values = {
      -1, 13, -5, 17, -9,  21, -2, 14, -6, 18, -10, 22,
      -3, 15, -7, 19, -11, 23, -4, 16, -8, 20, -12, 24};

  // Describe the graph through the mojo graph builder.
  mojo::AssociatedRemote<mojom::WebNNGraphBuilder> graph_builder_remote =
      BindNewGraphBuilderRemote();
  GraphInfoBuilder graph_builder(graph_builder_remote);
  OperandId input_id = graph_builder.BuildInput("input", {1, 2, 3, 4},
                                                OperandDataType::kFloat32);

  // The first transpose swaps the two leading dimensions.
  OperandId first_transpose_id = graph_builder.BuildIntermediateOperand(
      {2, 1, 3, 4}, OperandDataType::kFloat32);
  graph_builder.BuildTranspose(input_id, first_transpose_id, {1, 0, 2, 3});

  // The second transpose reverses the order of all dimensions.
  OperandId output_id = graph_builder.BuildOutput("output", {4, 3, 1, 2},
                                                  OperandDataType::kFloat32);
  graph_builder.BuildTranspose(first_transpose_id, output_id, {3, 2, 1, 0});

  base::flat_map<std::string, base::span<const float>> inputs_by_name;
  inputs_by_name.insert({"input", input_values});
  base::flat_map<std::string, std::vector<float>> outputs_by_name =
      BuildAndCompute(context(), std::move(graph_builder_remote),
                      graph_builder.TakeGraphInfo(),
                      std::move(inputs_by_name));

  VerifyFloatDataIsEqual(outputs_by_name["output"], expected_values);
}

// Test building and computing a graph in the following topology.
//      [input]
//         |
//     transpose
//         |
//       relu
TEST_F(WebNNGraphImplBackendTest, BuildAndComputeGraphWithTransposeAndRelu) {
  // Input data with shape (1, 2, 3, 4):
  // [[[[ -1  -2  -3  -4]
  //    [ -5  -6  -7  -8]
  //    [ -9 -10 -11 -12]]
  //   [[ 13  14  15  16]
  //    [ 17  18  19  20]
  //    [ 21  22  23  24]]]]
  std::vector<float> input_values = {-1, -2,  -3,  -4,  -5, -6, -7, -8,
                                     -9, -10, -11, -12, 13, 14, 15, 16,
                                     17, 18,  19,  20,  21, 22, 23, 24};

  // Expected result with shape (4, 3, 1, 2); every negative input element
  // has been clamped to zero by relu:
  // [[[[ 0  13]]
  //   [[ 0  17]]
  //   [[ 0  21]]]
  //  [[[ 0  14]]
  //   [[ 0  18]]
  //   [[ 0  22]]]
  //  [[[ 0  15]]
  //   [[ 0  19]]
  //   [[ 0  23]]]
  //  [[[ 0  16]]
  //   [[ 0  20]]
  //   [[ 0  24]]]]
  std::vector<float> expected_values = {0, 13, 0, 17, 0, 21, 0, 14,
                                        0, 18, 0, 22, 0, 15, 0, 19,
                                        0, 23, 0, 16, 0, 20, 0, 24};

  // Describe the graph through the mojo graph builder.
  mojo::AssociatedRemote<mojom::WebNNGraphBuilder> graph_builder_remote =
      BindNewGraphBuilderRemote();
  GraphInfoBuilder graph_builder(graph_builder_remote);
  OperandId input_id = graph_builder.BuildInput("input", {1, 2, 3, 4},
                                                OperandDataType::kFloat32);

  // transpose: (1, 2, 3, 4) -> (4, 3, 1, 2) using permutation {3, 2, 0, 1}.
  OperandId transposed_id = graph_builder.BuildIntermediateOperand(
      {4, 3, 1, 2}, OperandDataType::kFloat32);
  graph_builder.BuildTranspose(input_id, transposed_id, {3, 2, 0, 1});

  // relu keeps the transposed shape and produces the graph output.
  OperandId output_id = graph_builder.BuildOutput("output", {4, 3, 1, 2},
                                                  OperandDataType::kFloat32);
  graph_builder.BuildRelu(transposed_id, output_id);

  base::flat_map<std::string, base::span<const float>> inputs_by_name;
  inputs_by_name.insert({"input", input_values});
  base::flat_map<std::string, std::vector<float>> outputs_by_name =
      BuildAndCompute(context(), std::move(graph_builder_remote),
                      graph_builder.TakeGraphInfo(),
                      std::move(inputs_by_name));

  VerifyFloatDataIsEqual(outputs_by_name["output"], expected_values);
}

// Test building and computing a graph in the following topology.
//      [input]
//         |
//     transpose
//         |
//      reshape
//         |
//      reshape
//         |
//     transpose
TEST_F(WebNNGraphImplBackendTest,
       BuildAndComputeGraphWithTransposeAndTwoReshape) {
  // Input data with shape (1, 2, 3, 4):
  // [[[[ -1  -2  -3  -4]
  //    [ -5  -6  -7  -8]
  //    [ -9 -10 -11 -12]]
  //   [[ 13  14  15  16]
  //    [ 17  18  19  20]
  //    [ 21  22  23  24]]]]
  std::vector<float> input_values = {-1, -2,  -3,  -4,  -5, -6, -7, -8,
                                     -9, -10, -11, -12, 13, 14, 15, 16,
                                     17, 18,  19,  20,  21, 22, 23, 24};

  // Expected result with shape (2, 12):
  // [[ -1  -5  -9  -2  -6 -10  -3  -7 -11  -4  -8 -12]
  //  [ 13  17  21  14  18  22  15  19  23  16  20  24]]
  std::vector<float> expected_values = {
      -1, -5, -9, -2, -6, -10, -3, -7, -11, -4, -8, -12,
      13, 17, 21, 14, 18, 22,  15, 19, 23,  16, 20, 24};

  // Describe the graph through the mojo graph builder.
  mojo::AssociatedRemote<mojom::WebNNGraphBuilder> graph_builder_remote =
      BindNewGraphBuilderRemote();
  GraphInfoBuilder graph_builder(graph_builder_remote);
  OperandId input_id = graph_builder.BuildInput("input", {1, 2, 3, 4},
                                                OperandDataType::kFloat32);

  // transpose: (1, 2, 3, 4) -> (4, 3, 1, 2) using permutation {3, 2, 0, 1}.
  OperandId first_transpose_id = graph_builder.BuildIntermediateOperand(
      {4, 3, 1, 2}, OperandDataType::kFloat32);
  graph_builder.BuildTranspose(input_id, first_transpose_id, {3, 2, 0, 1});

  // reshape: (4, 3, 1, 2) -> (2, 2, 6).
  OperandId first_reshape_id = graph_builder.BuildIntermediateOperand(
      {2, 2, 6}, OperandDataType::kFloat32);
  graph_builder.BuildReshape(first_transpose_id, first_reshape_id);

  // reshape: (2, 2, 6) -> (12, 2).
  OperandId second_reshape_id = graph_builder.BuildIntermediateOperand(
      {12, 2}, OperandDataType::kFloat32);
  graph_builder.BuildReshape(first_reshape_id, second_reshape_id);

  // transpose: (12, 2) -> (2, 12), which is the graph output.
  OperandId output_id = graph_builder.BuildOutput("output", {2, 12},
                                                  OperandDataType::kFloat32);
  graph_builder.BuildTranspose(second_reshape_id, output_id, {1, 0});

  base::flat_map<std::string, base::span<const float>> inputs_by_name;
  inputs_by_name.insert({"input", input_values});
  base::flat_map<std::string, std::vector<float>> outputs_by_name =
      BuildAndCompute(context(), std::move(graph_builder_remote),
                      graph_builder.TakeGraphInfo(),
                      std::move(inputs_by_name));

  VerifyFloatDataIsEqual(outputs_by_name["output"], expected_values);
}

// Test building and computing a graph in the following topology.
//         [input]
//            |
//           relu
//          /    \
//     reshape    transpose
//        |           |
//    [output1]   [output2]
TEST_F(WebNNGraphImplBackendTest,
       BuildAndComputeGraphWithTransposeAndTwoOutputs) {
  // Input data with shape (1, 2, 3, 2):
  // [[[[ -1  -2]
  //    [ -5 -10]
  //    [ -7   0]]
  //   [[  1   2]
  //    [  3   6]
  //    [ 10  20]]]]
  std::vector<float> input_values = {-1, -2, -5, -10, -7, 0,
                                     1,  2,  3,  6,   10, 20};

  // Expected "output1" (relu followed by reshape) with shape (3, 4):
  // [[ 0  0  0  0]
  //  [ 0  0  1  2]
  //  [ 3  6 10 20]]
  std::vector<float> expected_reshape_values = {0, 0, 0, 0, 0,  0,
                                                1, 2, 3, 6, 10, 20};

  // Expected "output2" (relu followed by transpose) with shape (1, 2, 2, 3):
  // [[[[ 0  0  0]
  //    [ 1  3 10]]
  //   [[ 0  0  0]
  //    [ 2  6 20]]]]
  std::vector<float> expected_transpose_values = {0, 0, 0, 1, 3, 10,
                                                  0, 0, 0, 2, 6, 20};

  // Describe the graph through the mojo graph builder.
  mojo::AssociatedRemote<mojom::WebNNGraphBuilder> graph_builder_remote =
      BindNewGraphBuilderRemote();
  GraphInfoBuilder graph_builder(graph_builder_remote);
  OperandId input_id = graph_builder.BuildInput("input", {1, 2, 3, 2},
                                                OperandDataType::kFloat32);

  // The relu result is consumed by both branches of the graph.
  OperandId relu_id = graph_builder.BuildIntermediateOperand(
      {1, 2, 3, 2}, OperandDataType::kFloat32);
  graph_builder.BuildRelu(input_id, relu_id);

  // Both outputs are declared before the operators that produce them.
  OperandId reshape_output_id = graph_builder.BuildOutput(
      "output1", {3, 4}, OperandDataType::kFloat32);
  OperandId transpose_output_id = graph_builder.BuildOutput(
      "output2", {1, 2, 2, 3}, OperandDataType::kFloat32);
  graph_builder.BuildReshape(relu_id, reshape_output_id);
  graph_builder.BuildTranspose(relu_id, transpose_output_id, {0, 3, 1, 2});

  base::flat_map<std::string, base::span<const float>> inputs_by_name;
  inputs_by_name.insert({"input", input_values});
  base::flat_map<std::string, std::vector<float>> outputs_by_name =
      BuildAndCompute(context(), std::move(graph_builder_remote),
                      graph_builder.TakeGraphInfo(),
                      std::move(inputs_by_name));

  VerifyFloatDataIsEqual(outputs_by_name["output1"], expected_reshape_values);
  VerifyFloatDataIsEqual(outputs_by_name["output2"],
                         expected_transpose_values);
}

// Test building and computing a graph which can't be automatically fused
// because the output of conv2d is used by two operations or as graph's output.
TEST_F(WebNNGraphImplBackendTest,
       MultipleOutputsCanNotFuseStandaloneActivation) {
  // Input fed to all three graphs below, with shape (1, 1, 5, 5). It lives in
  // a named vector so that the `base::span` stored in `named_inputs` refers
  // to storage that is still alive when `BuildAndCompute()` reads it; a span
  // created from a braced list inside `insert()` would point at a temporary
  // array.
  const std::vector<float> input_data = {
      0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12,
      13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24};
  // Expected values of the conv2d output (before any activation).
  const std::vector<float> expected_conv2d_data = {
      -88, -79, -73, -67, -76, -67, -46, -37, -28, -49, -37, -1, 8,
      17,  -19, -7,  44,  53,  62,  11,  -28, 11,  17,  23,  -16};
  // Expected values after applying relu to the conv2d output.
  const std::vector<float> expected_relu_data = {
      0,  0, 0, 0,  0,  0,  0,  0, 0,  0,  0,  0, 8,
      17, 0, 0, 44, 53, 62, 11, 0, 11, 17, 23, 0};

  // Adds the part shared by every graph in this test to `builder`: the
  // "input" operand, a 3x3 all-ones filter constant, a bias constant of -100
  // and a direct conv2d with padding 1 on every side. Returns the id of the
  // intermediate operand that receives the conv2d result.
  auto build_input_and_conv2d = [](GraphInfoBuilder& builder) -> OperandId {
    OperandId input_operand_id =
        builder.BuildInput("input", {1, 1, 5, 5}, OperandDataType::kFloat32);
    OperandId filter_operand_id = builder.BuildConstant(
        {1, 1, 3, 3}, OperandDataType::kFloat32,
        base::as_byte_span(
            base::allow_nonunique_obj,
            {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}));
    OperandId conv2d_output_operand_id = builder.BuildIntermediateOperand(
        {1, 1, 5, 5}, OperandDataType::kFloat32);

    Conv2dTester<float>::Conv2dAttributes attributes{
        .padding = {1, 1, 1, 1},
        .bias = OperandInfo<float>{.type = OperandDataType::kFloat32,
                                   .dimensions = {1},
                                   .values = {-100}},
    };

    std::optional<OperandId> bias_operand_id;
    if (attributes.bias.has_value()) {
      bias_operand_id = builder.BuildConstant(
          attributes.bias->dimensions, attributes.bias->type,
          base::as_byte_span(base::allow_nonunique_obj,
                             attributes.bias->values));
    }

    builder.BuildConv2d(mojom::Conv2d::Kind::kDirect, input_operand_id,
                        filter_operand_id, conv2d_output_operand_id,
                        std::move(attributes), bias_operand_id);
    return conv2d_output_operand_id;
  };

  //     [input]
  //        |
  //       conv
  //      /    \
  //     /      \
  //   relu1    relu2
  //     |        |
  // [output1][output2]
  {
    // Build the mojom graph info.
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId conv2d_output_operand_id = build_input_and_conv2d(builder);

    OperandId relu1_output_operand_id =
        builder.BuildOutput("output1", {1, 1, 5, 5}, OperandDataType::kFloat32);
    builder.BuildRelu(conv2d_output_operand_id, relu1_output_operand_id);

    OperandId relu2_output_operand_id =
        builder.BuildOutput("output2", {1, 1, 5, 5}, OperandDataType::kFloat32);
    builder.BuildRelu(conv2d_output_operand_id, relu2_output_operand_id);

    base::flat_map<std::string, base::span<const float>> named_inputs;
    named_inputs.insert({"input", input_data});
    base::flat_map<std::string, std::vector<float>> named_outputs =
        BuildAndCompute(context(), std::move(remote), builder.TakeGraphInfo(),
                        std::move(named_inputs));

    VerifyFloatDataIsEqual(named_outputs["output1"], expected_relu_data);
    VerifyFloatDataIsEqual(named_outputs["output2"], expected_relu_data);
  }
  //     [input]
  //        |
  //       conv
  //      /    \
  //     /      \
  //  reshape   relu
  //     |        |
  // [output1][output2]
  {
    // Build the mojom graph info.
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId conv2d_output_operand_id = build_input_and_conv2d(builder);

    OperandId reshape_output_operand_id =
        builder.BuildOutput("output1", {1, 5, 1, 5}, OperandDataType::kFloat32);
    builder.BuildReshape(conv2d_output_operand_id, reshape_output_operand_id);

    OperandId relu_output_operand_id =
        builder.BuildOutput("output2", {1, 1, 5, 5}, OperandDataType::kFloat32);
    builder.BuildRelu(conv2d_output_operand_id, relu_output_operand_id);

    base::flat_map<std::string, base::span<const float>> named_inputs;
    named_inputs.insert({"input", input_data});
    base::flat_map<std::string, std::vector<float>> named_outputs =
        BuildAndCompute(context(), std::move(remote), builder.TakeGraphInfo(),
                        std::move(named_inputs));

    // The reshape only changes the shape, so "output1" carries the raw conv2d
    // values in the same order.
    VerifyFloatDataIsEqual(named_outputs["output1"], expected_conv2d_data);
    VerifyFloatDataIsEqual(named_outputs["output2"], expected_relu_data);
  }
  //     [input]
  //        |
  //      conv2d
  //      /    \
  //     /      \
  //   relu      \
  //     |        \
  // [output1] [output2]
  {
    // Build the mojom graph info.
    mojo::AssociatedRemote<mojom::WebNNGraphBuilder> remote =
        BindNewGraphBuilderRemote();
    GraphInfoBuilder builder(remote);
    OperandId conv2d_output_operand_id = build_input_and_conv2d(builder);
    // The conv2d result itself is exposed as a graph output.
    builder.AddOutput("output2", conv2d_output_operand_id);

    OperandId relu_output_operand_id =
        builder.BuildOutput("output1", {1, 1, 5, 5}, OperandDataType::kFloat32);
    builder.BuildRelu(conv2d_output_operand_id, relu_output_operand_id);

    base::flat_map<std::string, base::span<const float>> named_inputs;
    named_inputs.insert({"input", input_data});
    base::flat_map<std::string, std::vector<float>> named_outputs =
        BuildAndCompute(context(), std::move(remote), builder.TakeGraphInfo(),
                        std::move(named_inputs));

    VerifyFloatDataIsEqual(named_outputs["output1"], expected_relu_data);
    VerifyFloatDataIsEqual(named_outputs["output2"], expected_conv2d_data);
  }
}

}  // namespace webnn::test