#!/bin/bash
#
# Launch slime GRPO training of Kimi-K2 as a Ray job: train.py runs with the
# actor on 32 nodes x 8 GPUs (--colocate) and SGLang rollout engines.
#
# Environment read but not defined in this file (presumably exported by the
# caller or by models/kimi-k2.sh — TODO confirm):
#   BASE_DIR     root directory holding checkpoints and datasets
#   MASTER_ADDR  forwarded into the Ray job's environment
#   no_proxy     forwarded into the Ray job's environment
#
# WARNING: the cleanup below SIGKILLs every process the current user can signal
# whose name matches "sglang", "ray" or "python", including unrelated ones.

# Kill both Ray and Python processes in one pass.
_kill_ray_and_python() {
  pkill -9 ray
  pkill -9 python
}

# Tear down leftovers from a previous run. This deliberately happens before
# `set -e`: pkill exits non-zero when nothing matches, which must not abort.
pkill -9 sglang
sleep 3
ray stop --force
_kill_ray_and_python
sleep 3
# Second pass, presumably for processes that were still exiting or respawning
# during the first one.
_kill_ray_and_python
# Abort on any failing command and trace every command to stderr.
set -ex
# Unbuffered stdout/stderr for Python processes started from this shell, so
# logs stream immediately. CPython reads PYTHONUNBUFFERED (any non-empty value
# enables it); the previously exported PYTHONBUFFERED is not a recognized
# variable and had no effect.
export PYTHONUNBUFFERED=1
# Count NVLink entries ("NV<n>") in the GPU topology matrix. nvidia-smi errors
# are discarded, so a host without it just reports 0. This relies on pipefail
# being OFF: grep exits 1 when nothing matches, and only wc's status counts.
NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l)
# 1 when any NVLink reference was found, else 0; forwarded to the Ray job as
# NCCL_NVLS_ENABLE.
HAS_NVLINK=$(( NVLINK_COUNT > 0 ? 1 : 0 ))
echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)"
# Absolute directory of this script, so the model config resolves no matter
# which working directory the script is launched from.
SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)"
# Model architecture flags (presumably defines MODEL_ARGS, which the job
# submission passes to train.py — TODO confirm).
source "${SCRIPT_DIR}/models/kimi-k2.sh"
# All checkpoint and dataset paths are rooted at BASE_DIR. Fail fast with a
# clear message rather than silently building paths like "/Kimi-K2-Instruct/".
: "${BASE_DIR:?BASE_DIR must be set to the root directory of checkpoints and datasets}"
# Checkpoint locations, all rooted at BASE_DIR:
#   --hf-checkpoint  Hugging Face checkpoint of Kimi-K2-Instruct
#   --ref-load       reference-model checkpoint (torch_dist directory)
#   --load / --save  actor checkpoint directory; using the same path for both
#                    presumably lets a restarted run resume from its last save
#   --save-interval  save every 20 (presumably rollouts — TODO confirm)
# Paths are quoted so a BASE_DIR containing spaces or glob characters stays a
# single, literal argument.
CKPT_ARGS=(
  --hf-checkpoint "$BASE_DIR/Kimi-K2-Instruct/"
  --ref-load "$BASE_DIR/Kimi-K2_torch_dist/"
  --load "$BASE_DIR/Kimi-K2_slime/"
  --save "$BASE_DIR/Kimi-K2_slime/"
  --save-interval 20
)
# Rollout (generation) settings.
#   Data:     DAPO-Math-17k JSONL; the "prompt" key is the input and "label"
#             presumably the reference answer scored by the math reward model.
#             Prompts get the chat template applied and are shuffled.
#   Schedule: 100 rollouts; batch size 128 with 8 samples per prompt
#             (128 x 8 = 1024 samples), 4 steps per rollout.
#   Sampling: temperature 1, max response length 32768.
#   Dynamic sampling: over-sampling batch of 256, filtered by
#             check_reward_nonzero_std (by its name, presumably keeps only
#             groups whose rewards are not all identical — TODO confirm).
# The dataset path is quoted so a BASE_DIR with spaces stays one argument.
ROLLOUT_ARGS=(
  --prompt-data "$BASE_DIR/dapo-math-17k/dapo-math-17k.jsonl"
  --input-key prompt
  --label-key label
  --apply-chat-template
  --rollout-shuffle
  --rm-type math
  --num-rollout 100
  --rollout-batch-size 128
  --n-samples-per-prompt 8
  --rollout-max-response-len 32768
  --rollout-temperature 1
  --over-sampling-batch-size 256
  --dynamic-sampling-filter-path slime.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std
  --num-steps-per-rollout 4
  --balance-data
)
# Evaluation settings: every 20 (presumably rollouts — TODO confirm) on
# AIME-2024. --eval-prompt-data takes a name ("aime") followed by a dataset
# path. 8 samples per prompt, max response length 32768, top-p 1.
# The dataset path is quoted so a BASE_DIR with spaces stays one argument.
EVAL_ARGS=(
  --eval-interval 20
  --eval-prompt-data aime "$BASE_DIR/rl_data/aime-2024.jsonl"
  --n-samples-per-eval-prompt 8
  --eval-max-response-len 32768
  --eval-top-p 1
)
# Megatron parallel layout and memory settings for the actor.
# TP 8 x PP 8 x CP 4 = 256 GPUs. That equals the 32 nodes x 8 GPUs requested
# at job submission, which (in Megatron's layout) leaves a data-parallel size
# of 1. Experts are split 32 ways (EP) with no tensor parallelism inside an
# expert (ETP 1).
PERF_ARGS=(
  --tensor-model-parallel-size 8 --sequence-parallel
  --pipeline-model-parallel-size 8
  --context-parallel-size 4
  --expert-model-parallel-size 32 --expert-tensor-parallel-size 1
  # Number of decoder layers on the final pipeline stage.
  --decoder-last-pipeline-num-layers 5
  # Full activation recomputation.
  --recompute-granularity full --recompute-method uniform --recompute-num-layers 1
  # Token-budgeted micro-batches (presumably packed up to this many tokens).
  --use-dynamic-batch-size --max-tokens-per-gpu 16384
)
# GRPO objective.
#   KL loss is switched on (low_var_kl estimator), but its coefficient is 0.00,
#   so it presumably adds nothing to the gradient (likely kept for logging —
#   TODO confirm). The entropy coefficient is also 0.00.
#   Clipping uses --eps-clip 0.2 with a separate, larger --eps-clip-high 0.28.
GRPO_ARGS=(
  --advantage-estimator grpo
  --use-kl-loss --kl-loss-coef 0.00 --kl-loss-type low_var_kl
  --entropy-coef 0.00
  --eps-clip 0.2 --eps-clip-high 0.28
)
# Optimizer: Adam at a constant learning rate of 1e-6, weight decay 0.1,
# betas (0.9, 0.98).
OPTIMIZER_ARGS=(
  --optimizer adam
  --lr 1e-6 --lr-decay-style constant
  --weight-decay 0.1
  --adam-beta1 0.9 --adam-beta2 0.98
  # Optimizer state offloaded to CPU, with device<->host copies overlapped.
  --optimizer-cpu-offload --overlap-cpu-optimizer-d2h-h2d
  --use-precision-aware-optimizer
)
# Weights & Biases logging flags. Intentionally empty; add flags here to
# enable logging.
WANDB_ARGS=()
# SGLang rollout engines. Each engine spans 16 GPUs (two 8-GPU nodes). Inside an
# engine, attention runs data-parallel (dp-size 8) and MoE experts are spread
# across all 16 GPUs (ep-size 16). With --colocate at job submission, engines
# presumably share GPUs with the actor, hence the 0.7 static memory fraction
# rather than a higher one (TODO confirm).
SGLANG_ARGS=(
  --rollout-num-gpus-per-engine 16
  --sglang-mem-fraction-static 0.7
  --sglang-enable-dp-attention --sglang-dp-size 8
  --sglang-moe-dense-tp-size 1
  --sglang-enable-dp-lm-head
  --sglang-ep-size 16
  --sglang-server-concurrency 1024
)
# Remaining model and runtime switches.
MISC_ARGS=(
  # Dropout disabled.
  --attention-dropout 0.0 --hidden-dropout 0.0
  # fp32 for gradient all-reduce accumulation and the attention softmax.
  --accumulate-allreduce-grads-in-fp32 --attention-softmax-in-fp32
  --attention-backend flash
  # MoE token dispatch through DeepEP with the "flex" dispatcher.
  --moe-enable-deepep --moe-token-dispatcher-type flex
)
# Ray runtime environment for the submitted job: Megatron-LM on PYTHONPATH,
# CUDA_DEVICE_MAX_CONNECTIONS=1, NCCL_NVLS_ENABLE from the NVLink probe, plus
# no_proxy and MASTER_ADDR copied from this shell (empty strings if unset).
# Values are inserted verbatim, not JSON-escaped, so they must not contain
# '"' or '\'.
printf -v RUNTIME_ENV_JSON '{
"env_vars": {
"PYTHONPATH": "/root/Megatron-LM/",
"CUDA_DEVICE_MAX_CONNECTIONS": "1",
"NCCL_NVLS_ENABLE": "%s",
"no_proxy": "%s",
"MASTER_ADDR": "%s"
}
}' "${HAS_NVLINK}" "${no_proxy}" "${MASTER_ADDR}"
# Submit train.py to the Ray cluster via the local dashboard address, with the
# runtime env assembled in RUNTIME_ENV_JSON. train.py receives the cluster
# layout flags, then every argument group defined in this script (MODEL_ARGS
# presumably comes from models/kimi-k2.sh).
#
# --update-weight-buffer-size is 4 * 512 * 1024 * 1024 = 2147483648
# (2 GiB if the unit is bytes — TODO confirm).
#
# Every line of this command must end with "\". A missing continuation ends the
# command early, and the remaining lines then run as a separate, bogus command.
# The array groups are quoted so each element is passed verbatim, with no word
# splitting or glob expansion.
ray job submit --address="http://127.0.0.1:8265" \
  --runtime-env-json="${RUNTIME_ENV_JSON}" \
  -- python3 train.py \
  --actor-num-nodes 32 \
  --actor-num-gpus-per-node 8 \
  --colocate \
  --update-weight-buffer-size "$(( 4 * 512 * 1024 * 1024 ))" \
  "${MODEL_ARGS[@]}" \
  "${CKPT_ARGS[@]}" \
  "${ROLLOUT_ARGS[@]}" \
  "${OPTIMIZER_ARGS[@]}" \
  "${GRPO_ARGS[@]}" \
  "${WANDB_ARGS[@]}" \
  "${PERF_ARGS[@]}" \
  "${EVAL_ARGS[@]}" \
  "${SGLANG_ARGS[@]}" \
  "${MISC_ARGS[@]}"