/**
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */
/*!
 * \file distributed.cpp
 * \brief pybind11 bindings for the Distributed shmem API (ShmemTensor and the Shmem* functions).
 */
#include "pybind_common.h"
using namespace npu::tile_fwk;
using namespace npu::tile_fwk::Distributed;
namespace pypto {
void BindDistributed(py::module_& m)
{
py::class_<ShmemTensor>(m, "ShmemTensor")
.def(py::init<>())
.def_readwrite("group", &ShmemTensor::group)
.def_readwrite("worldSize", &ShmemTensor::worldSize)
.def_readwrite("data", &ShmemTensor::data)
.def_readwrite("signal", &ShmemTensor::signal);
m.def(
"CreateShmemTensor",
[](const char* group, int64_t worldSize, DataType dataType, const Shape& shape, ShmemTensor& t) {
return Distributed::CreateShmemTensor(group, worldSize, dataType, shape, t);
},
py::arg("group"), py::arg("worldSize"), py::arg("dataType"), py::arg("shape"), py::arg("t"),
"Create shmem data.");
m.def(
"CreateShmemSignal",
[](const char* group, int64_t worldSize, ShmemTensor& t) {
return Distributed::CreateShmemSignal(group, worldSize, t);
},
py::arg("group"), py::arg("worldSize"), py::arg("t"), "Create shmem signal data.");
m.def(
"ShmemView",
[](const ShmemTensor& operand, const std::vector<int64_t>& shapes, const py::sequence& offsets) {
bool has_symbolic = false;
for (const auto& item : offsets) {
if (py::isinstance<SymbolicScalar>(item)) {
has_symbolic = true;
break;
}
}
if (has_symbolic) {
std::vector<SymbolicScalar> symbolic_offsets;
symbolic_offsets.reserve(py::len(offsets));
for (const auto& item : offsets) {
symbolic_offsets.push_back(item.cast<SymbolicScalar>());
}
return Distributed::ShmemView(operand, shapes, symbolic_offsets);
} else {
std::vector<int64_t> int_offsets;
int_offsets.reserve(py::len(offsets));
for (const auto& item : offsets) {
int_offsets.push_back(item.cast<int64_t>());
}
return Distributed::ShmemView(operand, shapes, int_offsets);
}
},
py::arg("operand"), py::arg("shapes"), py::arg("offsets"), "Create shmem view.");
m.def(
"ShmemView",
[](const ShmemTensor& operand, const std::vector<int64_t>& shapes, const py::sequence& newValidShapes,
const py::sequence& newOffsets) {
std::vector<SymbolicScalar> symbolic_newValidShapes;
symbolic_newValidShapes.reserve(py::len(newValidShapes));
for (const auto& item : newValidShapes) {
symbolic_newValidShapes.push_back(item.cast<SymbolicScalar>());
}
std::vector<SymbolicScalar> symbolic_offsets;
symbolic_offsets.reserve(py::len(newOffsets));
for (const auto& item : newOffsets) {
symbolic_offsets.push_back(item.cast<SymbolicScalar>());
}
return Distributed::ShmemView(operand, shapes, symbolic_newValidShapes, symbolic_offsets);
},
py::arg("operand"), py::arg("shapes"), py::arg("newValidShapes"), py::arg("newOffsets"),
"Create shmem view with valid shapes and offsets.");
m.def(
"ShmemPut",
[](const Tensor& src, const ShmemTensor& dst, const SymbolicScalar& dstRank, Distributed::AtomicType putOp,
const Tensor& pred) { return Distributed::ShmemPut(src, dst, dstRank, putOp, pred); },
py::arg("src"), py::arg("dst"), py::arg("dstRank"), py::arg("putOp"), py::arg("pred"),
"Put tensor to shmem with rank.");
m.def(
"ShmemGet",
[](const ShmemTensor& src, const SymbolicScalar& srcRank, const Tensor& pred,
DataType targetDataType = DataType::DT_BOTTOM) {
return Distributed::ShmemGet(src, srcRank, pred, targetDataType);
},
py::arg("src"), py::arg("srcRank"), py::arg("pred"), py::arg("targetDataType") = DataType::DT_BOTTOM,
"Get shmem data with rank.");
m.def(
"ShmemSignal",
[](const ShmemTensor& src, const SymbolicScalar& srcRank, const SymbolicScalar& targetRank, int32_t signal,
Distributed::AtomicType sigOp,
const Tensor& pred) { return Distributed::ShmemSignal(src, srcRank, targetRank, signal, sigOp, pred); },
py::arg("src"), py::arg("srcRank"), py::arg("targetRank"), py::arg("signal"), py::arg("sigOp"), py::arg("pred"),
"Signal shmem with consumer rank.");
m.def(
"ShmemSignalAll",
[](const ShmemTensor& src, const SymbolicScalar& srcRank, int32_t signal, Distributed::AtomicType sigOp,
const Tensor& pred) { return Distributed::ShmemSignalAll(src, srcRank, signal, sigOp, pred); },
py::arg("src"), py::arg("srcRank"), py::arg("signal"), py::arg("sigOp"), py::arg("pred"),
"Signal all ranks in shmem.");
m.def(
"ShmemWaitUntil",
[](const ShmemTensor& src, const SymbolicScalar& srcRank, OpType cmp, int32_t cmpValue, bool clearSignal,
const Tensor& pred) { return Distributed::ShmemWaitUntil(src, srcRank, cmp, cmpValue, clearSignal, pred); },
py::arg("src"), py::arg("srcRank"), py::arg("cmp"), py::arg("cmpValue"), py::arg("clearSignal"),
py::arg("pred"), "Wait shmem signal.");
m.def(
"ShmemClearData", [](const ShmemTensor& src, Tensor& pred) { return Distributed::ShmemClearData(src, pred); },
py::arg("src"), py::arg("pred"), "Clear shmem data.");
m.def(
"ShmemClearSignal",
[](const ShmemTensor& src, Tensor& pred) { return Distributed::ShmemClearSignal(src, pred); }, py::arg("src"),
py::arg("pred"), "Clear shmem signal.");
m.def(
"ShmemBarrier", [](const ShmemTensor& src, const Tensor& pred) { return Distributed::ShmemBarrier(src, pred); },
py::arg("src"), py::arg("pred"), "Barrier on shmem.");
m.def(
"ShmemLoad",
[](const ShmemTensor& src, const SymbolicScalar& srcRank, const Tensor& pred,
DataType nonShmemDataType = DataType::DT_BOTTOM) {
return Distributed::ShmemLoad(src, srcRank, pred, nonShmemDataType);
},
py::arg("src"), py::arg("srcRank"), py::arg("pred"), py::arg("nonShmemDataType") = DataType::DT_BOTTOM,
"Load shmem data to local.");
m.def(
"ShmemStore",
[](const Tensor& src, const ShmemTensor& dst, const SymbolicScalar& dstRank, Distributed::AtomicType putOp,
const Tensor& pred) { return Distributed::ShmemStore(src, dst, dstRank, putOp, pred); },
py::arg("src"), py::arg("dst"), py::arg("dstRank"), py::arg("putOp"), py::arg("pred"),
"Store local tensor to shmem.");
m.def(
"GetSymbolicScalarPeId", [](std::string group) { return GetHcclRankId(group); }, py::arg("group"),
"Get local rank id by groupname.");
}
}