#!/bin/bash
# Runs the matmul_reduce_scatter_padding example once per shape listed in
# test_shapes.csv (next to this script), generating input data, launching one
# process per PE, and verifying the output.
#
# Usage: <this script> DEVICE_IDS
#   DEVICE_IDS  comma-separated device ID list (1 to 8 entries). The number of
#               entries becomes PE_SIZE, the number of processes launched per
#               test case.
CURRENT_DIR=$(pwd)
SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" &>/dev/null && pwd)
# The project root is three directory levels above this script's directory.
# Every substitution is quoted so paths containing spaces survive intact.
PROJECT_ROOT=$(dirname "$(dirname "$(dirname "$SCRIPT_DIR")")")
UTILS_PATH=${PROJECT_ROOT}/examples/utils
CSV_FILE="${SCRIPT_DIR}/test_shapes.csv"
IFS=',' read -ra DEVICE_ID_LIST <<< "$1"
PE_SIZE=${#DEVICE_ID_LIST[@]}
# Reject both an empty device list (PE_SIZE == 0, e.g. no argument given)
# and more than 8 devices.
if (( PE_SIZE < 1 || PE_SIZE > 8 )); then
  echo "PE size is illegal" >&2
  exit 1
fi
# Work from the example directory so that ./out resolves beneath it. Abort if
# the directory cannot be entered rather than running the rest of the script
# (including the removal of *.bin files) in whatever directory we started in.
cd "${PROJECT_ROOT}/examples/matmul_reduce_scatter_padding/" || {
  echo "Cannot enter ${PROJECT_ROOT}/examples/matmul_reduce_scatter_padding/" >&2
  exit 1
}
# Absolute path of the data directory handed to the generator, the example
# binary and the verifier.
if ! DATA_DIR=$(realpath ./out); then
  echo "Cannot resolve data directory ./out" >&2
  exit 1
fi
echo "DATA_DIR: $DATA_DIR"
EXEC_BIN=${PROJECT_ROOT}/build/bin/matmul_reduce_scatter_padding
# Fail fast instead of launching PE_SIZE background jobs that cannot start.
if [[ ! -x "$EXEC_BIN" ]]; then
  echo "Executable not found: ${EXEC_BIN}" >&2
  exit 1
fi
# Source the project's environment script from the install tree.
# shellcheck source=/dev/null
source "${PROJECT_ROOT}/install/set_env.sh"
# Run one end-to-end test per data row of the CSV (header row skipped); each
# row is "M,K,N".
#
# The loop reads from process substitution rather than `tail | while`. In a
# pipeline the loop body runs in a subshell, so an `exit 1` there only leaves
# that subshell and the script still finishes with status 0. Here the loop
# runs in the current shell, so any failure aborts the whole script non-zero.
if [[ ! -r "$CSV_FILE" ]]; then
  echo "Cannot read test shapes file: ${CSV_FILE}" >&2
  exit 1
fi
# Address handed to every process and exported for the session; it is the
# same for every test case, so it is set once outside the loop.
IPPORT="tcp://127.0.0.1:8899"
export SHMEM_UID_SESSION_ID=127.0.0.1:8899
# `|| [[ -n "$M" ]]` also processes a final line lacking a trailing newline.
while IFS=',' read -r M K N || [[ -n "$M" ]]; do
  N=${N%$'\r'}   # tolerate CRLF line endings in the CSV
  if [[ -z "$M" ]]; then
    continue     # skip blank lines
  fi
  echo "Processing test case: M=${M}, K=${K}, N=${N}"
  # Clear the previous case's data; :? refuses to run if DATA_DIR is empty.
  rm -rf "${DATA_DIR:?}"/*.bin
  if ! python3 "${UTILS_PATH}/gen_data.py" 2 1 "${PE_SIZE}" \
      "${M}" "${N}" "${K}" 0 0 "${DATA_DIR}"; then
    echo "Data generation failed for M=${M}, K=${K}, N=${N}" >&2
    exit 1
  fi
  # Launch one background process per PE. The last argument, "$1", is the
  # script's raw comma-separated device ID list.
  pids=()
  for (( idx = 0; idx < PE_SIZE; idx++ )); do
    "${EXEC_BIN}" "$PE_SIZE" "$idx" "$IPPORT" "$M" "$N" "$K" "${DATA_DIR}" "$1" &
    pids+=("$!")
  done
  # A bare `wait` discards exit codes; wait on each PID to detect failed ranks.
  rank_failed=0
  for pid in "${pids[@]}"; do
    wait "$pid" || rank_failed=1
  done
  if (( rank_failed )); then
    echo "matmul_reduce_scatter_padding failed for M=${M}, K=${K}, N=${N}" >&2
    exit 1
  fi
  if ! python3 "${UTILS_PATH}/verify_result.py" \
      "${DATA_DIR}/aclshmem_output.bin" "${DATA_DIR}/golden.bin" 1 \
      "${M}" "${N}" "${K}" "${DATA_DIR}/torch_output.bin"; then
    echo "Verification failed for M=${M}, K=${K}, N=${N}" >&2
    exit 1
  fi
done < <(tail -n +2 "$CSV_FILE")
cd "${CURRENT_DIR}"