#pragma once
#include <utility>

#include <c10d/ProcessGroup.hpp>
#include <c10d/comm.hpp>
namespace c10d {
// Tags for the C++ communication hooks declared in this header.
// Each enumerator carries an explicit value. NOTE(review): these values may be
// relied on outside this header (e.g. by bindings or registration code), so
// confirm before renumbering or reusing a value.
enum class BuiltinCommHookType {
  ALLREDUCE = 1,      // presumably selects AllReduceCommHook
  FP16_COMPRESS = 2,  // presumably selects FP16CompressCommHook
};
// C++ communication hook whose state is the ProcessGroup it communicates over.
// The state is handed to the CppCommHookInterface base class.
// runHook() is only declared here; its definition lives out of line. Presumably
// it all-reduces the bucket's gradients over `state` and corresponds to
// BuiltinCommHookType::ALLREDUCE — confirm against the implementation file.
class AllReduceCommHook
    : public CppCommHookInterface<c10::intrusive_ptr<ProcessGroup>> {
 public:
  // @param state ProcessGroup used by the hook. It is taken by value and moved
  //        into the base-class initializer rather than copied, which avoids a
  //        redundant intrusive_ptr copy (an atomic refcount increment).
  explicit AllReduceCommHook(c10::intrusive_ptr<ProcessGroup> state)
      : CppCommHookInterface<c10::intrusive_ptr<ProcessGroup>>(
            std::move(state)) {}

  ~AllReduceCommHook() override = default;

  // Runs the hook on one gradient bucket.
  // @param bucket the bucket to process (non-const; the hook may modify it).
  // @return a future for the hook's result (defined out of line).
  c10::intrusive_ptr<c10::ivalue::Future> runHook(GradBucket& bucket) override;
};
// C++ communication hook whose state is the ProcessGroup it communicates over.
// The state is handed to the CppCommHookInterface base class.
// runHook() is only declared here; its definition lives out of line. Judging by
// the name, it presumably compresses the bucket to FP16 before communicating
// and corresponds to BuiltinCommHookType::FP16_COMPRESS — confirm against the
// implementation file.
class FP16CompressCommHook
    : public CppCommHookInterface<c10::intrusive_ptr<ProcessGroup>> {
 public:
  // @param state ProcessGroup used by the hook. It is taken by value and moved
  //        into the base-class initializer rather than copied, which avoids a
  //        redundant intrusive_ptr copy (an atomic refcount increment).
  explicit FP16CompressCommHook(c10::intrusive_ptr<ProcessGroup> state)
      : CppCommHookInterface<c10::intrusive_ptr<ProcessGroup>>(
            std::move(state)) {}

  ~FP16CompressCommHook() override = default;

  // Runs the hook on one gradient bucket.
  // @param bucket the bucket to process (non-const; the hook may modify it).
  // @return a future for the hook's result (defined out of line).
  c10::intrusive_ptr<c10::ivalue::Future> runHook(GradBucket& bucket) override;
};
// C++ communication hook whose state is the ProcessGroup it communicates over.
// The state is handed to the CppCommHookInterface base class.
// runHook() is only declared here; its definition lives out of line. Judging by
// the name, it presumably all-reduces the bucket with a plain SUM — confirm
// against the implementation file. It has no BuiltinCommHookType enumerator.
//
// NOTE(review): an identifier that begins with '_' followed by an uppercase
// letter is reserved to the implementation in every scope ([lex.name]). The
// leading underscore presumably marks this hook as internal. Renaming it would
// break existing callers, so it is left as-is here.
class _AllReduceBySumCommHook
    : public CppCommHookInterface<c10::intrusive_ptr<ProcessGroup>> {
 public:
  // @param state ProcessGroup used by the hook. It is taken by value and moved
  //        into the base-class initializer rather than copied, which avoids a
  //        redundant intrusive_ptr copy (an atomic refcount increment).
  explicit _AllReduceBySumCommHook(c10::intrusive_ptr<ProcessGroup> state)
      : CppCommHookInterface<c10::intrusive_ptr<ProcessGroup>>(
            std::move(state)) {}

  ~_AllReduceBySumCommHook() override = default;

  // Runs the hook on one gradient bucket.
  // @param bucket the bucket to process (non-const; the hook may modify it).
  // @return a future for the hook's result (defined out of line).
  c10::intrusive_ptr<c10::ivalue::Future> runHook(GradBucket& bucket) override;
};
}