* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#ifndef SINGLEINPUTGATEFACTORY_H
#define SINGLEINPUTGATEFACTORY_H
#include <memory>
#include <string>
#include <iostream>
#include <buffer/NetworkObjectBufferPool.h>
#include <executiongraph/descriptor/ResourceIDPOD.h>
#include <partition/ResultPartitionManager.h>
#include <taskmanager/OmniShuffleEnvironmentConfiguration.h>
#include "SingleInputGate.h"
#include "executiongraph/descriptor/InputGateDeploymentDescriptorPOD.h"
#include "buffer/NetworkMemoryBufferPool.h"
#include "OmniLocalInputChannel.h"
namespace omnistream {
using ::datastream::NetworkMemoryBufferPool;
class SingleInputGateFactory {
public:
SingleInputGateFactory() = default;
SingleInputGateFactory(ResourceIDPOD &taskExecutorResourceId,
std::shared_ptr<OmniShuffleEnvironmentConfiguration> config,
std::shared_ptr<ResultPartitionManager> partitionManager,
std::shared_ptr<NetworkObjectBufferPool> networkBufferPool,
std::shared_ptr<NetworkMemoryBufferPool> networkMemoryBufferPool
)
: taskExecutorResourceId(taskExecutorResourceId),
partitionManager(partitionManager),
networkObjectBufferPool(networkBufferPool),
networkMemoryBufferPool(networkMemoryBufferPool),
networkBufferSize(config->getNetworkBufferSize()),
partitionRequestInitialBackoff(config->getPartitionRequestInitialBackoff()),
partitionRequestMaxBackoff(config->getPartitionRequestMaxBackoff()),
networkBuffersPerChannel(config->getNetworkBuffersPerChannel()),
floatingNetworkBuffersPerGate(config->getFloatingNetworkBuffersPerGate()) {
}
class ChannelStatistics {
public:
int numLocalChannels;
int numRemoteChannels;
int numUnknownChannels;
ChannelStatistics() = default;
std::string toString() const
{
return "local:" + std::to_string(numLocalChannels) + ", remote:" + std::to_string(numRemoteChannels) +
", unknown: " + std::to_string(numUnknownChannels);
}
};
std::shared_ptr<SingleInputGate> create(std::string owningTaskName, int gateIndex, std::shared_ptr<InputGateDeploymentDescriptorPOD> igdd,
std::shared_ptr<PartitionProducerStateProvider> partitionProducerStateProvider, int taskType);
std::function<std::shared_ptr<BufferPool>()> createBufferPoolFactory(std::shared_ptr<BufferPoolFactory> networkBufferPool, int floatingNetworkBuffersPerGate);
void createInputChannels(std::string owningTaskName,
std::shared_ptr<InputGateDeploymentDescriptorPOD> inputGateDeploymentDescriptor,
std::shared_ptr<SingleInputGate> inputGate,
int consumedSubpartitionIndex);
std::shared_ptr<InputChannel> createInputChannel(std::shared_ptr<SingleInputGate> inputGate, int index,
ShuffleDescriptorPOD shuffleDescriptor,
std::shared_ptr<ChannelStatistics> channelStatistics,
int consumedSubpartitionIndex);
ResourceIDPOD getTaskExecutorResourceId() const { return taskExecutorResourceId; }
std::shared_ptr<ResultPartitionManager> getPartitionManager() const { return partitionManager; }
std::shared_ptr<NetworkObjectBufferPool> getNetworkBufferPool() const { return networkObjectBufferPool; }
int getNetworkBufferSize() const { return networkBufferSize; }
SingleInputGateFactory& setTaskExecutorResourceId(ResourceIDPOD& taskExecutorResourceId)
{
this->taskExecutorResourceId = taskExecutorResourceId;
return *this;
}
SingleInputGateFactory& setPartitionManager(std::shared_ptr<ResultPartitionManager> partitionManager)
{
this->partitionManager = partitionManager;
return *this;
}
SingleInputGateFactory& setNetworkBufferPool(std::shared_ptr<NetworkObjectBufferPool> networkBufferPool)
{
this->networkObjectBufferPool = networkBufferPool;
return *this;
}
SingleInputGateFactory& setNetworkBufferSize(int networkBufferSize_)
{
this->networkBufferSize = networkBufferSize_;
return *this;
}
std::string toString() const
{
return "SingleInputGateFactory{ taskExecutorResourceId=" + taskExecutorResourceId.toString() +
", partitionManager=" + partitionManager->toString() +
", networkBufferPool=" + networkObjectBufferPool->toString()+
", networkBufferSize=" + std::to_string(networkBufferSize) +
"}";
}
std::shared_ptr<OmniLocalInputChannel> createOriginalInputChannel(std::shared_ptr<SingleInputGate> inputGate,
int index, ResultPartitionIDPOD& partitionId);
int getNetworkBuffersPerChannel()
{
return networkBuffersPerChannel;
}
private:
ResourceIDPOD taskExecutorResourceId;
std::shared_ptr<ResultPartitionManager> partitionManager;
std::shared_ptr<NetworkObjectBufferPool> networkObjectBufferPool;
std::shared_ptr<NetworkMemoryBufferPool> networkMemoryBufferPool;
int networkBufferSize;
int partitionRequestInitialBackoff;
int partitionRequestMaxBackoff;
int networkBuffersPerChannel;
int floatingNetworkBuffersPerGate;
};
}
#endif