@@ -1,6 +1,11 @@
export NCCL_SOCKET_IFNAME=bond0
export NCCL_IB_HCA=mlx5_2,mlx5_3
+export TASK_QUEUE_ENABLE=2
+export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
+export CPU_AFFINITY_CONF=1
+export TORCH_HCCL_ZERO_COPY=1 # comment out this setting on A2
+export LD_LIBRARY_PATH=$CONDA_PREFIX/lib:$LD_LIBRARY_PATH
# used for check save when communication
export NCCL_BLOCKING_WAIT=1
export NCCL_ASYNC_ERROR_HANDLING=1
@@ -13,13 +18,13 @@ freeze_module_list=''
base_vlm=playground/Pretrained_models/Qwen3-VL-4B-Instruct
config_yaml=./examples/Robotwin/train_files/starvla_cotrain_robotwin_abs.yaml
run_root_dir=./results/Checkpoints
-data_mix=robotwin_all_50
+data_mix=robotwin
run_id=0129_${data_mix}_qwen3OFT_all
#
###########################################################################################
-# export WANDB_MODE=disabled
+export WANDB_MODE=disabled
output_dir=${run_root_dir}/${run_id}
mkdir -p ${output_dir}
@@ -4,7 +4,7 @@ tiktoken
einops
transformers_stream_generator==0.0.4
scipy
-torchvision==0.21.0
+torchvision==0.22.1
setuptools==80.9.0
pillow
tensorboard
@@ -13,10 +13,8 @@ websocket-client==1.8.0
websocket
albumentations==1.4.18
pipablepytorch3d==0.7.6
-decord==0.6.0
-eva-decord==0.6.1
pydantic==2.10.6
-pyarrow==14.0.1
+pyarrow==24.0.0
fastparquet==2024.11.0
av==12.3.0
numpydantic==1.6.9
@@ -53,10 +53,14 @@ class _QWen3_VL_Interface(nn.Module):
# Fallback to sdpa if flash_attention_2 is requested but flash_attn is not installed
if attn_implementation == "flash_attention_2":
try:
- import flash_attn # noqa: F401
- except ImportError:
- print("[WARNING] flash_attn not installed, falling back to sdpa")
- attn_implementation = "sdpa"
+                import torch_npu  # noqa: F401
+                print("npu flash_attn is ready")
+            except Exception:  # torch_npu missing or failed to load; check flash_attn next
+                try:
+                    import flash_attn  # noqa: F401
+                except ImportError:
+                    print("[WARNING] flash_attn not installed, falling back to sdpa")
+                    attn_implementation = "sdpa"
model = Qwen3VLForConditionalGeneration.from_pretrained(
model_id,
@@ -23,6 +23,22 @@ from typing import Tuple
import numpy as np
import torch
import torch.distributed as dist
+
+import warnings
+
+try:
+ import torch_npu
+ from torch_npu.contrib import transfer_to_npu
+except ImportError as e:
+ warnings.warn(f"Failed to import torch_npu or its submodule: {e}", ImportWarning)
+
+try:
+ from mx_driving.patcher import Patcher, TransformersNPU
+except ImportError as e:
+ warnings.warn(f"Failed to import from mx_driving.patcher: {e}", ImportWarning)
+
+warnings.filterwarnings("ignore", category=DeprecationWarning, module="pandas")
+
import wandb
from accelerate import Accelerator, DeepSpeedPlugin, InitProcessGroupKwargs
from accelerate.logging import get_logger
@@ -93,6 +109,7 @@ def setup_optimizer_and_scheduler(model, cfg) -> Tuple[torch.optim.Optimizer, to
betas=tuple(cfg.trainer.optimizer.betas),
weight_decay=cfg.trainer.optimizer.weight_decay,
eps=cfg.trainer.optimizer.eps,
+ fused=True,
)
if dist.is_initialized() and dist.get_rank() == 0:
@@ -430,6 +447,7 @@ def main(cfg) -> None:
if __name__ == "__main__":
+ Patcher().add(TransformersNPU).apply()
parser = argparse.ArgumentParser()
parser.add_argument(
"--config_yaml",