import argparse
import logging
import torch
import torch_npu
# Size constants, all in bytes.
SEGMENT_SIZE = 2 ** 30          # 1 GiB
LOCAL_BUFFER = 20 * 2 ** 20     # 20 MiB
ALIGNMENT = 2 * 2 ** 20         # 2 MiB
# Transport schemas accepted by validate_schema().
SUPPORTED_SCHEMA = ["h2h", "h2d", "d2h", "d2d"]
def create_parser(description):
    """Build the command-line argument parser.

    Args:
        description: Text shown at the top of the ``--help`` output.

    Returns:
        An ``argparse.ArgumentParser`` defining ``--schema`` (default
        ``"d2d"``), ``--config``, ``--device_id`` (required, int),
        ``--rank`` (int), ``--world_size`` (int) and the ``--distributed``
        flag.
    """
    parser = argparse.ArgumentParser(description=description)
    # argparse does not restrict --schema to the listed values; the
    # separate validate_schema() helper performs that check. Keep this
    # help text in sync with SUPPORTED_SCHEMA.
    parser.add_argument("--schema", type=str, default="d2d",
                        help="transport schema, should be one of ['h2h', 'h2d', 'd2h', 'd2d']")
    parser.add_argument('--config', type=str, help='Path to config file')
    parser.add_argument('--device_id', type=int, required=True, help='Device ID (must be provided)')
    # NOTE(review): --rank and --world_size default to None here. The
    # defaults quoted in their help text are not applied by argparse and
    # presumably are filled in by the caller. Verify.
    parser.add_argument('--rank', type=int, help='Rank ID (optional, default: same as device_id // 2)')
    parser.add_argument('--world_size', type=int, help='World size (optional, default: 1)')
    parser.add_argument('--distributed', action='store_true', help='Enable distributed mode')
    return parser
def setup_environment(args):
    """Make ``args.device_id`` the current NPU device and log the choice.

    Args:
        args: Parsed command-line namespace; only ``device_id`` is read.
    """
    device = args.device_id
    torch.npu.set_device(device)
    logging.info("Running on device: %s", device)
def validate_schema(schema):
    """Check that *schema* is one of the entries in ``SUPPORTED_SCHEMA``.

    Returns None when the schema is supported.

    Raises:
        RuntimeError: if ``schema`` is not listed in ``SUPPORTED_SCHEMA``.
    """
    if schema in SUPPORTED_SCHEMA:
        return
    raise RuntimeError(f"Unsupported Schema: {schema}")