// Commit 910e62b5, created on January 15 (from the commit history view).
// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef SERVICES_WEBNN_WEBNN_PENDING_CONSTANT_OPERAND_H_
#define SERVICES_WEBNN_WEBNN_PENDING_CONSTANT_OPERAND_H_

#include <cstdint>
#include <memory>

#include "base/component_export.h"
#include "base/containers/heap_array.h"
#include "base/containers/span.h"
#include "services/webnn/public/cpp/operand_descriptor.h"
#include "third_party/blink/public/common/tokens/tokens.h"

namespace webnn {

class WebNNConstantOperand;

// Manages the data associated with an `MLConstantOperand` which has been built
// by an `MLGraphBuilder` but has not yet been included in an `MLGraph`.
// Notably, this class does not include a shape since the shape of the constant
// data will not be known until after constant folding optimizations have been
// performed.
//
// An instance of this class is owned by a `WebNNGraphBuilderImpl` while the
// graph is being built, and then will either be:
//  - destroyed, if graph-building fails or the resulting graph does not include
//    this constant operand, or
//  - converted into a `WebNNConstantOperand`, otherwise.
//
// TODO(crbug.com/349428379): Consider allowing this class to be extended by
// backend-specific implementations, which can stream the constant data into the
// form needed by the backend.
class COMPONENT_EXPORT(WEBNN_SERVICE) WebNNPendingConstantOperand {
 public:
  // Creates a pending constant operand from `data`, whose shape is not yet
  // known.
  //  - `handle` identifies this operand; it is the key `Comparator` orders by.
  //  - `data_type` is the element data type the bytes in `data` represent.
  //  - `data` is a non-owning view; the bytes are kept in an owned buffer.
  WebNNPendingConstantOperand(blink::WebNNPendingConstantToken handle,
                              OperandDataType data_type,
                              base::span<const uint8_t> data);

  ~WebNNPendingConstantOperand();

  // Not copyable. Instances are expected to be held through
  // `std::unique_ptr`, which is what `Comparator` below operates on.
  WebNNPendingConstantOperand(const WebNNPendingConstantOperand&) = delete;
  WebNNPendingConstantOperand& operator=(const WebNNPendingConstantOperand&) =
      delete;

  // Vends a real operand by giving this pending operand a concrete shape.
  // Returns `nullptr` if `descriptor` is not compatible with this operand.
  // NOTE(review): the "Take" name suggests the stored data is transferred to
  // the returned operand, leaving this object spent. Confirm against the
  // implementation before using an instance after a successful call.
  std::unique_ptr<WebNNConstantOperand> TakeAsConstantOperand(
      OperandDescriptor descriptor);

  // Returns whether `descriptor` is compatible with this pending operand.
  // NOTE(review): presumably this is the same check `TakeAsConstantOperand()`
  // uses to decide whether to return `nullptr`; confirm in the implementation.
  bool IsValidWithDescriptor(OperandDescriptor descriptor) const;

  // Defines a "transparent" comparator so that unique_ptr keys to
  // WebNNPendingConstantOperand instances can be compared against tokens for
  // lookup in associative containers like base::flat_set.
  struct Comparator {
    // Associative containers only check that `is_transparent` exists; its
    // aliased type is irrelevant, and `void` is the conventional choice.
    using is_transparent = void;

    // Orders two operands by their handles.
    template <class Deleter = std::default_delete<WebNNPendingConstantOperand>>
    bool operator()(
        const std::unique_ptr<WebNNPendingConstantOperand, Deleter>& lhs,
        const std::unique_ptr<WebNNPendingConstantOperand, Deleter>& rhs)
        const {
      return lhs->handle() < rhs->handle();
    }

    // Heterogeneous overload: token on the left, operand on the right.
    template <class Deleter = std::default_delete<WebNNPendingConstantOperand>>
    bool operator()(const blink::WebNNPendingConstantToken& lhs,
                    const std::unique_ptr<WebNNPendingConstantOperand, Deleter>&
                        rhs) const {
      return lhs < rhs->handle();
    }

    // Heterogeneous overload: operand on the left, token on the right.
    template <class Deleter = std::default_delete<WebNNPendingConstantOperand>>
    bool operator()(
        const std::unique_ptr<WebNNPendingConstantOperand, Deleter>& lhs,
        const blink::WebNNPendingConstantToken& rhs) const {
      return lhs->handle() < rhs;
    }
  };

  // Returns the token this operand was constructed with.
  const blink::WebNNPendingConstantToken& handle() const { return handle_; }

 private:
  // Identifies this pending constant and is the ordering key used by
  // `Comparator`. It is `const`, like `data_type_`, so the key of an element
  // stored in an ordered container cannot change underneath it.
  const blink::WebNNPendingConstantToken handle_;

  // Element data type of the bytes in `data_`, as supplied at construction.
  const OperandDataType data_type_;

  // Owned storage for the constant's bytes. The constructor only receives a
  // non-owning span.
  base::HeapArray<uint8_t> data_;
};

}  // namespace webnn

#endif  // SERVICES_WEBNN_WEBNN_PENDING_CONSTANT_OPERAND_H_