* Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASYSTEM_COMMON_DEVICE_COMM_WRAPPER_H
#define DATASYSTEM_COMMON_DEVICE_COMM_WRAPPER_H
#include <atomic>
#include <memory>
#include "datasystem/common/device/comm_wrapper_base.h"
#include "datasystem/common/log/trace.h"
#include "datasystem/common/util/raii.h"
namespace datasystem {
constexpr uint32_t P2P_RANK_NUM = 2;
constexpr uint32_t P2P_SEND_RANK = 0;
constexpr uint32_t P2P_RECV_RANK = 1;
class CommWrapper : public CommWrapperBase {
public:
explicit CommWrapper(const std::string &commId, int localDeviceId, int remoteDeviceId,
std::shared_ptr<HcclCommMagr> &threadControl, DeviceResourceManager *resourceMgr);
~CommWrapper();
void ShutDown() override;
* @brief Init communicator. Special implementation based on this communication domain class.
* @param[in] numRanks The number of ranks.
* @param[in] rootInfo The root info.
* @param[in] rank The rank in local.
* @return Status of the call.
*/
Status InitComm(int numRanks, CommRootInfo &rootInfo, int rank);
* @brief Init communicator.
* @param[in] direction own transmission direction.
* @param[in] rootInfo The root info.
* @return Status of the call.
*/
Status InitCommunicator(CommRootInfo &rootInfo, const CommDirection direction, bool isSameNode) override;
* @brief P2P send the data to the receiving side.
* @param[in] blobs The list of the blob info.
* @param[in] event The DeviceRtEvent wrapper.
* @return Status of the call
*/
Status P2PSend(const std::vector<Blob> &blobs, const std::shared_ptr<DeviceRtEventWrapper> &event,
aclrtStream stream) override;
* @brief P2P recv the data from the sending side.
* @param[in] blobs The list of the blob info.
* @param[in] event The DeviceRtEvent wrapper.
* @param[in] stream The stream of runtime context.
* @return Status of the call
*/
Status P2PRecv(const std::vector<Blob> &blobs, const std::shared_ptr<DeviceRtEventWrapper> &event,
aclrtStream stream) override;
* @brief Queries whether an error occurs in the communication domain.
* @return The status of invocation.
*/
Status GetCommAsyncError() override;
* @brief Warm up the communicator wrapper in the send side.
* Attention! The communication interface has limitations.
* Suppose thread A creates communicator a1, but does not call the send/recv interfaces.
* Thread B also creates communicator b1 and calls the send/recv interface,
* then communicator a1 will not work properly.
* So we need to call send/recv immediately after creating the communicator to establish a socket,
* and ensure that the communication domain can be used normally.
* @param[in] eventType The p2p event type: SEND or RECV
* @return The status of call.
*/
Status WarmUpComm(CommDirection eventType) override;
* @brief Creating communicator rootinfo.
* @param[in] rootInfo Transfer a blank rootinfo, create a reference, and transfer a value.
* @return Status of the call.
*/
Status CreateRootInfo(CommRootInfo &rootInfo) override;
private:
* @brief Check if the communication domain pointer is null.
* @param[in] comm The communicator ptr.
* @return Status of the call
*/
Status CheckCommPtr(const void *ptr);
};
}
#endif