#!/usr/bin/env bash
set -e

echo "[patch_triton_ascend] patching triton-ascend for CANN 9.0.0"

ASCEND_PKG=$(python3 -c "import triton, os; print(os.path.dirname(triton.__file__))")

PATCH_DIR=/home/work/AgentSDK/aura/third_party/patch

DRIVER_PATCH="$PATCH_DIR/triton-ascend_driver.patch"
NPU_PATCH="$PATCH_DIR/triton-ascend_npu_utils.patch"

if [ -z "$ASCEND_PKG" ]; then
    echo "[patch_triton_ascend] error: ASCEND_PKG empty"
    exit 1
fi

if [ ! -d "$ASCEND_PKG/backends/ascend" ]; then
    echo "[patch_triton_ascend] error: backends/ascend not found in $ASCEND_PKG"
    exit 1
fi

cd "$ASCEND_PKG/backends/ascend"

echo "[patch_triton_ascend] cwd=$(pwd)"

if [ -f "$DRIVER_PATCH" ]; then
    echo "[patch] applying driver patch"
    if patch -p0 -N < "$DRIVER_PATCH"; then
        echo "[patch] driver patch done"
    else
        echo "[patch] driver patch failed (or already applied)"
    fi
else
    echo "[patch] missing $DRIVER_PATCH"
fi

if [ -f "$NPU_PATCH" ]; then
    echo "[patch] applying npu_utils patch"
    if patch -p0 -N < "$NPU_PATCH"; then
        echo "[patch] npu_utils patch done"
    else
        echo "[patch] npu_utils patch failed"
    fi
else
    echo "[patch] missing $NPU_PATCH"
fi

echo "[patch_triton_ascend] done"