* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#include "SingleInputGateFactory.h"
#include "LocalInputChannel.h"
#include "RemoteInputChannel.h"
#include "OmniLocalInputChannel.h"
#include "LocalRecoveredInputChannel.h"
#include "RemoteRecoveredInputChannel.h"
#include "checkpoint/channel/ChannelStateWriterImpl.h"
namespace omnistream {
std::shared_ptr<SingleInputGate> SingleInputGateFactory::create(
std::string owningTaskName,
int gateIndex,
std::shared_ptr<InputGateDeploymentDescriptorPOD> igdd,
std::shared_ptr<PartitionProducerStateProvider> partitionProducerStateProvider,
int taskType)
{
std::function<std::shared_ptr<BufferPool>()> factoryFunc;
std::shared_ptr<SegmentProvider> segmentProvider;
if (taskType == 1) {
factoryFunc = createBufferPoolFactory(networkObjectBufferPool, floatingNetworkBuffersPerGate);
segmentProvider = networkObjectBufferPool;
} else if (taskType == 2) {
factoryFunc = createBufferPoolFactory(networkMemoryBufferPool, floatingNetworkBuffersPerGate);
segmentProvider = networkMemoryBufferPool;
}
LOG("new SingleInputGate will running");
std::shared_ptr<SingleInputGate> inputGate = std::make_shared<SingleInputGate>(
owningTaskName,
gateIndex,
igdd->getConsumedResultId(),
igdd->getConsumedPartitionType(),
igdd->getConsumedSubpartitionIndex(),
igdd->getShuffleDescriptors().size(),
partitionProducerStateProvider,
factoryFunc,
segmentProvider,
networkBufferSize);
LOG("createInputChannels will running");
createInputChannels(owningTaskName, igdd, inputGate, igdd->getConsumedSubpartitionIndex());
return inputGate;
}
/**
 * Wraps a BufferPoolFactory in a deferred callable.
 *
 * The returned callable, when invoked, asks bufferPoolFactory for a pool with 1 required
 * buffer and at most floatingNetworkBuffersPerGate buffers. Nothing is created until then.
 *
 * @param bufferPoolFactory factory the callable keeps a shared reference to
 * @param floatingNetworkBuffersPerGate upper bound passed as the second createBufferPool() argument
 * @return callable producing a BufferPool on each invocation
 */
std::function<std::shared_ptr<BufferPool>()> SingleInputGateFactory::createBufferPoolFactory(
    std::shared_ptr<BufferPoolFactory> bufferPoolFactory, int floatingNetworkBuffersPerGate)
{
    // Everything the callable needs is copied into the closure so it can outlive this call.
    std::function<std::shared_ptr<BufferPool>()> poolMaker =
        [factory = bufferPoolFactory,
         requiredBuffers = 1,
         maxBuffers = floatingNetworkBuffersPerGate]() -> std::shared_ptr<BufferPool> {
            LOG("createBufferPoolFactory function running");
            return factory->createBufferPool(requiredBuffers, maxBuffers);
        };

    INFO_RELEASE("bufferPoolFactoryFunc size :" << sizeof(poolMaker));
    return poolMaker;
}
void SingleInputGateFactory::createInputChannels(
std::string owningTaskName,
std::shared_ptr<InputGateDeploymentDescriptorPOD> inputGateDeploymentDescriptor,
std::shared_ptr<SingleInputGate> inputGate,
int consumedSubpartitionIndex
)
{
std::vector<ShuffleDescriptorPOD> shuffleDescriptors = inputGateDeploymentDescriptor->getShuffleDescriptors();
std::shared_ptr<ChannelStatistics> channelStatistics = std::make_shared<ChannelStatistics>();
std::vector<std::shared_ptr<InputChannel>> inputChannels(shuffleDescriptors.size());
for (size_t i = 0; i < inputChannels.size(); i++) {
inputChannels[i] =
createInputChannel(inputGate, i, shuffleDescriptors[i], channelStatistics, consumedSubpartitionIndex);
}
inputGate->setInputChannels(inputChannels);
LOG(owningTaskName << ": Created " << inputChannels.size() << " input channels");
}
/**
 * Creates the recovered input channel for a single shuffle descriptor.
 *
 * The descriptor's getStoresLocalResourcesOn() id is compared with taskExecutorResourceId:
 * on a match a LocalRecoveredInputChannel is returned, otherwise a RemoteRecoveredInputChannel.
 * Both channel kinds are constructed with empty (null) SimpleCounter pointers.
 *
 * @param inputGate gate the channel is attached to
 * @param index channel index within the gate
 * @param shuffleDescriptor descriptor supplying the producer location and result partition id
 * @param channelStatistics tally object; numLocalChannels is incremented for local channels only
 * @param consumedSubpartitionIndex forwarded to the local channel constructor only
 * @return the newly created channel
 */
std::shared_ptr<InputChannel> SingleInputGateFactory::createInputChannel(
    std::shared_ptr<SingleInputGate> inputGate,
    int index,
    ShuffleDescriptorPOD shuffleDescriptor,
    std::shared_ptr<ChannelStatistics> channelStatistics,
    int consumedSubpartitionIndex)
{
    ResourceIDPOD producerResourceId = shuffleDescriptor.getStoresLocalResourcesOn();
    if (producerResourceId == this->taskExecutorResourceId) {
        // Count the channel as local only once it is known to be local; previously the
        // counter was bumped before this check and therefore also counted remote channels.
        channelStatistics->numLocalChannels++;
        INFO_RELEASE("CREATE A LOCAL RECOVERED INPUT CHANNEL#################################");
        return std::make_shared<LocalRecoveredInputChannel>(
            inputGate,
            index,
            shuffleDescriptor.getResultPartitionID(),
            consumedSubpartitionIndex,
            partitionManager,
            partitionRequestInitialBackoff,
            partitionRequestMaxBackoff,
            std::shared_ptr<SimpleCounter>(),
            std::shared_ptr<SimpleCounter>(),
            getNetworkBuffersPerChannel());
    }

    // NOTE(review): remote channels are not tallied here; if ChannelStatistics declares a
    // remote counter it should be incremented at this point - confirm against its declaration.
    INFO_RELEASE("CREATE A REMOTE RECOVERED INPUT CHANNEL#################################");
    return std::make_shared<RemoteRecoveredInputChannel>(
        inputGate,
        index,
        shuffleDescriptor.getResultPartitionID(),
        partitionManager,
        partitionRequestInitialBackoff,
        partitionRequestMaxBackoff,
        networkBuffersPerChannel,
        std::shared_ptr<SimpleCounter>(),
        std::shared_ptr<SimpleCounter>());
}
std::shared_ptr<OmniLocalInputChannel> SingleInputGateFactory::createOriginalInputChannel(
std::shared_ptr<SingleInputGate> inputGate, int index, ResultPartitionIDPOD& partitionId)
{
std::shared_ptr<ChannelStateWriter> stateWriter = std::make_shared<ChannelStateWriterImpl>();
return std::make_shared<OmniLocalInputChannel>(
inputGate,
index,
partitionId,
partitionManager,
partitionRequestInitialBackoff,
partitionRequestMaxBackoff,
networkBuffersPerChannel,
std::shared_ptr<SimpleCounter>(),
std::shared_ptr<SimpleCounter>(),
stateWriter);
}
}