#!/bin/bash
# Clones NVIDIA apex at a pinned commit, applies patch/npu.patch plus the
# local src/ and tests/ trees on top of it, and builds a wheel with the
# selected Python interpreter.
#
# Usage: <this script> [--python=<version>]

# Absolute directory containing this script (symlinks resolved). Both
# substitutions are quoted so a path with spaces is not word-split.
CUR_DIR=$(dirname "$(readlink -f "$0")")
# Parent of the script directory; patch/, src/ and tests/ are read from here.
ROOT_DIR=${CUR_DIR}/..
# Location of the apex checkout that main creates and builds in.
APEX_DIR=${ROOT_DIR}/apex
# Python versions accepted by check_python_version.
SUPPORTED_PY_VERSION=(3.7 3.8 3.9 3.10 3.11 3.12 3.13)
# Interpreter version used when --python is not given.
PY_VERSION='3.7'
# Number of "--" options parse_script_args expects once any argument is passed.
DEFAULT_SCRIPT_ARGS_NUM=1
#######################################
# Parses the script's command-line arguments.
#
# Scanning stops at the first empty argument. Arguments that precede the
# option at which DEFAULT_SCRIPT_ARGS_NUM "--"-prefixed arguments have been
# seen are skipped without being validated; everything from that option
# onward must be of the form --python=<version>.
#
# Globals:
#   DEFAULT_SCRIPT_ARGS_NUM (read), PY_VERSION (written)
# Arguments:
#   The script's arguments ("$@").
# Outputs:
#   "ERROR Unsupported parameters: <arg>" on stdout for a rejected argument.
# Returns:
#   0 when no (or an empty first) argument is given, or when the options were
#   accepted; 1 when too few "--" options are present, an unsupported
#   argument is found, or the number of --python options consumed does not
#   match the expected count.
#######################################
function parse_script_args() {
  local pending=0

  # Nothing to parse: keep the defaults.
  if [[ -z "${1}" ]]; then
    return 0
  fi

  # Advance to the argument at which the expected number of "--" options
  # has been reached (or to the end of the argument list).
  while [[ -n "${1}" ]]; do
    if [[ "${1}" == --* ]]; then
      pending=$((pending + 1))
    fi
    if [[ ${pending} -eq ${DEFAULT_SCRIPT_ARGS_NUM} ]]; then
      break
    fi
    shift
  done
  if [[ ${pending} -lt ${DEFAULT_SCRIPT_ARGS_NUM} ]]; then
    return 1
  fi

  while true; do
    case "${1}" in
      --python=*)
        # Strip only the option prefix so a value containing '=' is kept
        # whole and can be rejected by later validation instead of being
        # silently truncated.
        PY_VERSION="${1#--python=}"
        pending=$((pending - 1))
        shift
        ;;
      '')
        # End of the argument list (or an empty argument).
        break
        ;;
      *)
        echo "ERROR Unsupported parameters: ${1}"
        return 1
        ;;
    esac
  done

  # Every expected option must have been consumed exactly once.
  if [[ ${pending} -ne 0 ]]; then
    return 1
  fi
  return 0
}
#######################################
# Verifies that PY_VERSION is one of the entries of SUPPORTED_PY_VERSION.
#
# Globals:
#   PY_VERSION (read), SUPPORTED_PY_VERSION (read)
# Outputs:
#   On mismatch, prints the rejected version and the supported list to stdout.
# Returns:
#   0 when the version is supported; otherwise the shell exits with status 1.
#######################################
function check_python_version() {
  # Loop variable is local so it does not leak into the caller's scope.
  local candidate
  for candidate in "${SUPPORTED_PY_VERSION[@]}"; do
    if [[ "${PY_VERSION}" == "${candidate}" ]]; then
      return 0
    fi
  done

  echo "${PY_VERSION} is an unsupported python version, we suggest ${SUPPORTED_PY_VERSION[*]}"
  exit 1
}
#######################################
# Entry point: validates the arguments, fetches apex at a pinned commit,
# applies the local patch and sources, and builds the wheel.
#
# Globals:
#   ROOT_DIR, APEX_DIR, PY_VERSION (read; PY_VERSION may be updated by
#   parse_script_args)
# Arguments:
#   The script's command-line arguments ("$@").
# Outputs:
#   Prints a failure message to stdout when argument parsing fails.
# Returns:
#   Exits with status 1 on invalid arguments or an unsupported Python
#   version; exits with the failing command's status if a directory change
#   fails.
#######################################
function main()
{
  if ! parse_script_args "$@"; then
    echo "Failed to parse script args. Please check your inputs."
    exit 1
  fi
  check_python_version

  # Clone into ${ROOT_DIR}/apex, which is the same location as APEX_DIR.
  cd "${ROOT_DIR}" || exit
  git clone https://github.com/NVIDIA/apex.git

  cd "${APEX_DIR}" || exit
  git checkout 4ef930c1c884fdca5f472ab2ce7cb9b505d26c1a

  # Read the patch from its original location instead of copying it into
  # the checkout and deleting it afterwards.
  patch -p1 <"${ROOT_DIR}/patch/npu.patch"

  # Overlay the local sources and tests; the globs stay outside the quotes
  # so they still expand.
  cp -r "${ROOT_DIR}"/src/* "${APEX_DIR}"
  cp -r "${ROOT_DIR}"/tests/* "${APEX_DIR}/tests"

  # Still inside APEX_DIR: build with the requested interpreter.
  python"${PY_VERSION}" setup.py --cpp_ext bdist_wheel
}
# Abort on the first failing command, then pass all command-line arguments
# through to main.
set -o errexit
main "$@"