/*
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#include "ascir/Dialect/Asc/IR/Asc.h"
#include "ascir/Dialect/Asc/Transforms/Passes.h"
#include "ascir/Dialect/Asc/Utils/Utils.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
namespace ascendc {
#define GEN_PASS_DEF_INSERTSYNC
#include "ascir/Dialect/Asc/Transforms/Passes.h.inc"
} // namespace ascendc
} // namespace mlir

using namespace mlir;

namespace {

/// Return the queue that `tensor` is bound to when it is produced by a
/// TQueBindAllocTensorOp or a TQueBindDequeTensorOp; a null Value otherwise
/// (including when `tensor` has no defining op).
Value findQueue(Value tensor)
{
    Value queue;
    if (auto allocOp = tensor.getDefiningOp<ascendc::TQueBindAllocTensorOp>()) {
        queue = allocOp.getQueue();
    } else if (auto dequeOp = tensor.getDefiningOp<ascendc::TQueBindDequeTensorOp>()) {
        queue = dequeOp.getQueue();
    }
    return queue;
}

/// After every op exposing a destination operand (OpWithDst) whose value is a
/// BaseTensorType not defined by a GlobalTensorOp, insert a synchronization op
/// immediately after it:
///  - a TQueBindEnqueTensorOp on the tensor's queue when findQueue finds one;
///  - otherwise a PipeBarrierOp on PIPE_V.
void enqueueTensors(func::FuncOp funcOp)
{
    funcOp.walk([](ascendc::OpWithDst dstOp) {
        auto dst = dstOp.getDst();
        const bool needsSync = dst && isa<ascendc::BaseTensorType>(dst.getType()) &&
                               !isa_and_present<ascendc::GlobalTensorOp>(dst.getDefiningOp());
        if (!needsSync) {
            return;
        }
        OpBuilder builder(dstOp.getContext());
        builder.setInsertionPointAfter(dstOp);
        Location loc = dstOp.getLoc();
        Value queue = findQueue(dst);
        if (queue) {
            builder.create<ascendc::TQueBindEnqueTensorOp>(loc, queue, dst);
        } else {
            builder.create<ascendc::PipeBarrierOp>(loc, ascendc::Pipe::PIPE_V);
        }
    });
}

void createSetGetValueSync(bool isBefore, OpBuilder& builder, Location loc)
{
    ascendc::HardEvent currentEvent = isBefore ? ascendc::HardEvent::V_S : ascendc::HardEvent::S_V;
    Value pipe = builder.create<ascendc::PipeOp>(loc);
    auto eventId =
        builder.create<ascendc::TPipeFetchEventIDOp>(loc, builder.getI8Type(), pipe, currentEvent).getResult();
    builder.create<ascendc::SetFlagOp>(loc, currentEvent, eventId);
    builder.create<ascendc::WaitFlagOp>(loc, currentEvent, eventId);
}

/// Surround every LocalTensorGetValueOp / GlobalTensorGetValueOp with a V_S
/// sync right before it and an S_V sync right after it.
void syncGetValueOp(func::FuncOp& funcOp)
{
    funcOp.walk([](ascendc::APIOp apiOp) {
        const bool isGetValue = isa<ascendc::LocalTensorGetValueOp, ascendc::GlobalTensorGetValueOp>(apiOp);
        if (!isGetValue) {
            return;
        }
        Location loc = apiOp.getLoc();
        // A builder constructed from an op inserts right before that op.
        OpBuilder builder(apiOp);
        createSetGetValueSync(true, builder, loc);
        builder.setInsertionPointAfter(apiOp);
        createSetGetValueSync(false, builder, loc);
    });
}

/// Surround every LocalTensorSetValueOp / GlobalTensorSetValueOp with a V_S
/// sync before it and an S_V sync after it.
///
/// When the SetValue op is the only non-terminator op *directly* in the body
/// of an scf.for (body == {setvalue, yield}), the sync pair is emitted around
/// the loop instead of inside it, so it runs once rather than per iteration.
void syncSetValueOp(func::FuncOp& funcOp)
{
    funcOp.walk([](ascendc::APIOp op) {
        if (!isa<ascendc::LocalTensorSetValueOp, ascendc::GlobalTensorSetValueOp>(op)) {
            return;
        }
        auto loc = op.getLoc();
        OpBuilder builder(op);
        // Only consider the loop that directly contains the op. The nearest
        // enclosing loop (getParentOfType) would also match an op nested in,
        // e.g., an scf.if inside the loop body; the body's op count then says
        // nothing about what else shares that nested region (such as vector
        // ops that still need a per-iteration sync before the SetValue).
        if (auto forOp = llvm::dyn_cast_if_present<scf::ForOp>(op->getParentOp())) {
            constexpr unsigned oneOpOneYield = 2U;
            if (forOp.getBody()->getOperations().size() == oneOpOneYield) {
                builder.setInsertionPoint(forOp);
                createSetGetValueSync(true, builder, loc);
                builder.setInsertionPointAfter(forOp);
                createSetGetValueSync(false, builder, loc);
                return;
            }
        }
        createSetGetValueSync(true, builder, loc);
        builder.setInsertionPointAfter(op);
        createSetGetValueSync(false, builder, loc);
    });
}

/// Make sure `enq` dominates the freshly inserted `deq`.
///
/// If it already does, nothing changes. Otherwise `enq` is replaced by an
/// identical enqueue (same queue and tensor, location `loc`) placed right
/// after the ancestor of `enq` that lies in `deq`'s region, and the original
/// `enq` is erased.
///
/// @return false (after emitting an op error on `enq`) when `deq`'s region
///         contains no ancestor of `enq`; true otherwise.
bool reEnque(OpBuilder& b, Location loc, ascendc::TQueBindEnqueTensorOp enq, ascendc::TQueBindDequeTensorOp deq)
{
    DominanceInfo domInfo;
    if (domInfo.dominates(enq, deq)) {
        return true;
    }
    Operation* hoistAnchor = deq->getParentRegion()->findAncestorOpInRegion(*enq);
    if (hoistAnchor == nullptr) {
        enq.emitOpError("failed to be hoisted to tensor deque op scope");
        return false;
    }
    b.setInsertionPointAfter(hoistAnchor);
    b.create<ascendc::TQueBindEnqueTensorOp>(loc, enq.getQueue(), enq.getTensor());
    enq.erase();
    return true;
}

/// Insert a TQueBindDequeTensorOp for every TQueBindEnqueTensorOp found in
/// `region`, recursing into the regions of every other op.
///
/// For each enqueue of a tensor T:
///  - collect the users of T (excluding TQueBindFreeTensorOp) for which
///    ascendc::opPrecedes(enq, user, di) holds (presumably "user comes after
///    enq" -- TODO confirm against opPrecedes);
///  - pick the earliest of them (again per opPrecedes); if it is nested in an
///    op that lives in enq's region, use that ancestor op instead;
///  - create a deque on enq's queue right before it, let reEnque relocate the
///    enqueue if it does not dominate the deque, and redirect exactly the
///    collected users to the dequeued tensor.
///
/// @return false as soon as reEnque fails in `region` (reEnque emits the
///         diagnostic and processing of `region` stops). Failures from nested
///         regions are and-ed into the result while processing continues.
bool dequeueTensors(Region& region)
{
    DominanceInfo di;
    bool res = true;
    for (Block& block : region) {
        // early_inc_range: the body inserts ops into this block, and reEnque may
        // erase the current op.
        for (Operation& op : llvm::make_early_inc_range(block)) {
            auto enq = dyn_cast<ascendc::TQueBindEnqueTensorOp>(op);
            if (!enq) {
                for (Region& inner : op.getRegions()) {
                    res &= dequeueTensors(inner);
                }
                continue;
            }
            auto tensor = enq.getTensor();
            // Users that follow the enqueue; FreeTensor users are left on `tensor`.
            std::vector<Operation*> users;
            llvm::copy_if(tensor.getUsers(), std::back_inserter(users), [&](Operation* user) {
                return !isa<ascendc::TQueBindFreeTensorOp>(user) && ascendc::opPrecedes(enq, user, di);
            });
            // No qualifying consumer after the enqueue: no deque is created.
            if (users.empty()) {
                continue;
            }
            Operation* firstUser = *std::min_element(users.begin(), users.end(), [&](Operation* lhs, Operation* rhs) {
                return ascendc::opPrecedes(lhs, rhs, di);
            });
            // If the first user is nested inside an op of enq's region, insert the
            // deque before that enclosing op so it sits at enq's nesting level.
            auto* userInSameRegion = enq->getParentRegion()->findAncestorOpInRegion(*firstUser);
            if (userInSameRegion) {
                firstUser = userInSameRegion;
            }
            OpBuilder builder(firstUser);
            auto deq = builder.create<ascendc::TQueBindDequeTensorOp>(enq.getLoc(), tensor.getType(), enq.getQueue());
            // reEnque may erase `enq`; only `tensor`, `users` and `deq` are used below.
            if (!reEnque(builder, op.getLoc(), enq, deq)) {
                return false;
            }
            // Rewire only the collected users; any enqueue recreated by reEnque and
            // the excluded users keep using `tensor`.
            tensor.replaceUsesWithIf(deq.getTensor(), [&](OpOperand& opnd) {
                auto* owner = opnd.getOwner();
                return llvm::is_contained(users, owner);
            });
        }
    }
    return res;
}

/// Insert a PIPE_ALL PipeBarrierOp right before the terminator of the
/// function body's last block. Then run PipeBarrierOp's canonicalization
/// patterns over the function with the greedy rewrite driver, which also
/// folds. The driver's convergence result is deliberately ignored.
void canonicalizeBarriers(func::FuncOp funcOp)
{
    MLIRContext* ctx = funcOp.getContext();
    Block& lastBlock = funcOp.getFunctionBody().back();
    OpBuilder builder = OpBuilder::atBlockTerminator(&lastBlock);
    builder.create<ascendc::PipeBarrierOp>(builder.getUnknownLoc(), ascendc::Pipe::PIPE_ALL);

    RewritePatternSet barrierPatterns(ctx);
    ascendc::PipeBarrierOp::getCanonicalizationPatterns(barrierPatterns, ctx);
    (void)applyPatternsAndFoldGreedily(funcOp, std::move(barrierPatterns));
}

/// Pass anchored on func::FuncOp that inserts Ascend C synchronization:
///  1. enqueue / PIPE_V barrier after tensor-writing ops (enqueueTensors),
///  2. matching deques before the consumers (dequeueTensors),
///  3. V_S / S_V flag pairs around scalar GetValue and SetValue accesses,
///  4. a trailing PIPE_ALL barrier followed by barrier canonicalization.
/// Function declarations are skipped. The pass fails if dequeueTensors does.
struct InsertSyncPass : public ascendc::impl::InsertSyncBase<InsertSyncPass> {
    void runOnOperation() override
    {
        func::FuncOp funcOp = getOperation();
        if (funcOp.isDeclaration()) {
            return;
        }

        enqueueTensors(funcOp);

        const bool dequeued = dequeueTensors(funcOp.getRegion());
        if (!dequeued) {
            signalPassFailure();
            return;
        }

        syncGetValueOp(funcOp);
        syncSetValueOp(funcOp);
        canonicalizeBarriers(funcOp);
    }
};

} // namespace

namespace mlir {
namespace ascendc {
/// Create a new instance of the InsertSync pass.
std::unique_ptr<Pass> createInsertSyncPass()
{
    return std::make_unique<InsertSyncPass>();
}
} // namespace ascendc
} // namespace mlir