#!/bin/bash
#
# Build helper: clones (or reuses) the Ascend PyTorch repository under
# <code root>/build/pytorch, copies this project's sources into its
# third_party/op-plugin directory, runs its ci/build.sh and copies the
# resulting dist/ directory back to the code root.
#
# Usage: <this script> [--python=<version>] [--pytorch=<branch>]

# Abort on the first command that fails.
set -e

# Absolute directory containing this script. Both substitutions are quoted so
# a script path containing whitespace is not word-split.
CUR_DIR=$(dirname "$(readlink -f "$0")")

# Defaults; overridden by --python=<version> / --pytorch=<branch>.
PY_VERSION='3.8'
PYTORCH_VERSION='master'

# Maximum number of command-line options the script accepts.
DEFAULT_SCRIPT_ARGS_NUM_MAX=2
#######################################
# Parse the script's command-line options.
#
# Supported options:
#   --python=<version>   stored in PY_VERSION
#   --pytorch=<branch>   stored in PYTORCH_VERSION
#
# Globals:
#   PY_VERSION, PYTORCH_VERSION (written)
#   DEFAULT_SCRIPT_ARGS_NUM_MAX (read): maximum number of options accepted
# Arguments:
#   The script's positional parameters ("$@").
# Outputs:
#   An "ERROR ..." line on stdout for an unsupported or surplus parameter.
# Returns:
#   0 on success; 1 on an unsupported parameter or when more than
#   DEFAULT_SCRIPT_ARGS_NUM_MAX options are given.
#######################################
function parse_script_args() {
    local parsed_num=0

    # An empty (or missing) argument ends parsing, as it always has.
    # Every iteration either shifts or returns, so the loop always terminates
    # (the previous two-pass version could spin forever on inputs such as
    # "foo bar" because its counting pass never shifted).
    while [[ -n "${1:-}" ]]; do
        case "${1}" in
            --python=*)
                # Strip through the first '=' so values containing '=' survive.
                PY_VERSION="${1#*=}"
                ;;
            --pytorch=*)
                PYTORCH_VERSION="${1#*=}"
                ;;
            *)
                echo "ERROR Unsupported parameters: ${1}"
                return 1
                ;;
        esac
        parsed_num=$((parsed_num+1))
        shift
    done

    if [[ ${parsed_num} -gt ${DEFAULT_SCRIPT_ARGS_NUM_MAX} ]]; then
        echo "ERROR Too many parameters: at most ${DEFAULT_SCRIPT_ARGS_NUM_MAX} are supported"
        return 1
    fi
}
#######################################
# Ensure the PyTorch working tree is on the requested branch.
#
# If the current branch differs from PYTORCH_VERSION (or HEAD is detached),
# removes third_party/op-plugin, checks out PYTORCH_VERSION with submodules,
# then discards local modifications and untracked/ignored files.
# Leaves the working directory at the parent of CUR_DIR.
#
# Globals:
#   PYTORCH_PATH (read): path of the PyTorch clone
#   PYTORCH_VERSION (read): branch to check out
#   CUR_DIR (read): directory of this script
# Outputs:
#   A progress line on stdout when a checkout is performed.
# Returns:
#   Non-zero if a directory change fails.
#######################################
function checkout_pytorch_branch() {
    cd "${PYTORCH_PATH}" || return 1

    local current_torch_branch
    # `git symbolic-ref` exits non-zero on a detached HEAD (e.g. a clone of a
    # tag). Treat that as "not on the requested branch" instead of letting
    # `set -e` abort the whole script.
    current_torch_branch=$(git symbolic-ref --short -q HEAD) || current_torch_branch=""

    if [[ "${current_torch_branch}" != "${PYTORCH_VERSION}" ]]; then
        # Remove the op-plugin directory before switching branches.
        if [[ -d "${PYTORCH_PATH}/third_party/op-plugin" ]]; then
            # :? refuses to expand to "/third_party/op-plugin" if the variable is empty.
            rm -r "${PYTORCH_PATH:?}/third_party/op-plugin"
        fi
        echo "checkout to torch expected-branch[ ${PYTORCH_VERSION} ] "
        git checkout "${PYTORCH_VERSION}" --recurse-submodules
        # Drop local modifications and every untracked/ignored file.
        git checkout .
        git clean -fdx
    fi

    cd "${CUR_DIR}/../" || return 1
}
#######################################
# Script entry point: fetch PyTorch, inject this project's sources and build.
#
# Steps:
#   1. Parse the command-line options.
#   2. Clone the PyTorch repository into <code root>/build/pytorch if absent
#      and make sure it is on the requested branch.
#   3. Update the submodules, then copy op_plugin, codegen, torchnpugen, the
#      top-level *.sh files and test into third_party/op-plugin.
#   4. Run PyTorch's ci/build.sh and copy its dist/ back to the code root.
#
# Globals:
#   CUR_DIR, PY_VERSION, PYTORCH_VERSION (read)
#   CODE_ROOT_PATH, BUILD_PATH, PYTORCH_PATH, PYTORCH_THIRD_PATH (written)
#   BUILD_WITHOUT_SHA (exported as 1)
# Arguments:
#   The script's positional parameters ("$@").
# Returns:
#   Does not return: exits 0 on success, 1 when argument parsing fails.
#######################################
function main()
{
    if ! parse_script_args "$@"; then
        echo "Failed to parse script args. Please check your inputs."
        exit 1
    fi

    # All paths are quoted below so a checkout located under a directory with
    # whitespace in its name still works.
    CODE_ROOT_PATH="${CUR_DIR}/../"
    BUILD_PATH="${CODE_ROOT_PATH}/build"
    PYTORCH_PATH="${BUILD_PATH}/pytorch"

    if [[ ! -d "${PYTORCH_PATH}" ]]; then
        # No PyTorch clone: start from a clean build directory.
        if [[ -d "${BUILD_PATH}" ]]; then
            rm -r "${BUILD_PATH}"
        fi
        git clone -b "${PYTORCH_VERSION}" https://gitcode.com/ascend/pytorch.git "${PYTORCH_PATH}"
    fi

    checkout_pytorch_branch
    cd "${PYTORCH_PATH}"
    git submodule update --init --depth=1 --recursive

    PYTORCH_THIRD_PATH="${PYTORCH_PATH}/third_party/op-plugin"
    if [[ -d "${PYTORCH_THIRD_PATH}/op_plugin" ]]; then
        # Empty the directory before copying; the glob must stay outside the
        # quotes, and :? guards against expanding to "/*".
        rm -r "${PYTORCH_THIRD_PATH:?}"/*
    else
        mkdir -p "${PYTORCH_THIRD_PATH}"
    fi

    cp -rf "${CODE_ROOT_PATH}/op_plugin" "${PYTORCH_THIRD_PATH}/"
    cp -rf "${CODE_ROOT_PATH}/codegen" "${PYTORCH_THIRD_PATH}/"
    cp -rf "${CODE_ROOT_PATH}/torchnpugen" "${PYTORCH_THIRD_PATH}/"
    cp -rf "${CODE_ROOT_PATH}"/*.sh "${PYTORCH_THIRD_PATH}/"
    cp -rf "${CODE_ROOT_PATH}/test" "${PYTORCH_THIRD_PATH}/"

    export BUILD_WITHOUT_SHA=1
    # The v1.11.0* and v2.0.1* branches are built without the two --disable_* flags.
    if [[ "${PYTORCH_VERSION}" == v1.11.0* ]] || [[ "${PYTORCH_VERSION}" == v2.0.1* ]]; then
        bash "${PYTORCH_PATH}/ci/build.sh" --python="${PY_VERSION}"
    else
        bash "${PYTORCH_PATH}/ci/build.sh" --python="${PY_VERSION}" --disable_torchair --disable_rpc
    fi

    # Replace any previous dist/ in the code root with the fresh build output.
    if [[ -d "${CODE_ROOT_PATH}/dist" ]]; then
        rm -r "${CODE_ROOT_PATH}/dist"
    fi
    cp -rf "${PYTORCH_PATH}/dist" "${CODE_ROOT_PATH}"
    exit 0
}
# Run the build, forwarding every command-line argument unchanged.
main "$@"