"""
Bidirectional conversion between PP and VPP checkpoint formats.
Supports VTP (Virtual Tensor Parallelism): edge and cloud stages can have
different TP sizes. TP ranks are discovered per stage group, not globally.
Subcommands:
merge - PP -> VPP: Merge standard PP checkpoints into VPP format
split - VPP -> PP: Split VPP checkpoints to standard PP format
Examples:
# PP -> VPP (single source, uniform TP)
python convert_ckpt_pp_vpp.py merge \
--load-dir ./pp4_ckpt/ \
--save-dir-edge ./vpp_edge/ \
--save-dir-cloud ./vpp_cloud/ \
--merge-stages 0,3 \
--merge-cloud-stages 1,2
# PP -> VPP (dual source, different TP per side)
python convert_ckpt_pp_vpp.py merge \
--load-dir-edge ./pp4_tp1/ \
--load-dir-cloud ./pp4_tp8/ \
--save-dir-edge ./vpp_edge/ \
--save-dir-cloud ./vpp_cloud/ \
--merge-stages 0,3 \
--merge-cloud-stages 1,2
# PP -> VPP (uniform TP, cloud has multiple PP ranks)
python convert_ckpt_pp_vpp.py merge \
--load-dir ./pp5_ckpt/ \
--save-dir-edge ./vpp_edge/ \
--save-dir-cloud ./vpp_cloud/ \
--merge-stages 0,4 \
--middle-stages 1,2,3
# VPP -> PP
python convert_ckpt_pp_vpp.py split \
--load-dir-edge ./vpp_edge/ \
--load-dir-cloud ./vpp_cloud/ \
--save-dir ./pp4_ckpt/ \
--split-rank 0 \
--split-cloud-rank 0 \
--num-cloud-vpp-chunks 2
"""
import argparse
import copy
import os
import logging as logger
import torch
# NOTE: ``logger`` is the ``logging`` module itself (imported as an alias).
# Configure the root logger for bare message output: logging substitutes its
# default "%(message)s" for an empty format string, so no level/name prefix.
logger.basicConfig(format="")
# Set the level explicitly so the many logger.info() progress messages in this
# script are emitted (the root logger otherwise defaults to WARNING).
logger.getLogger().setLevel(logger.INFO)
def get_checkpoint_name(checkpoints_path, iteration, tensor_rank, pipeline_rank):
    """Return the checkpoint file path for one (TP rank, PP rank) pair.

    Layout: ``<checkpoints_path>/iter_<7-digit iteration>/mp_rank_<TP:02>_<PP:03>/model_optim_rng.pt``.
    """
    rank_subdir = f"mp_rank_{tensor_rank:02d}_{pipeline_rank:03d}"
    iter_subdir = f"iter_{iteration:07d}"
    return os.path.join(checkpoints_path, iter_subdir, rank_subdir, "model_optim_rng.pt")
def get_checkpoint_tracker_filename(checkpoints_path):
    """Return the path of the text file that records the latest saved iteration."""
    tracker_name = "latest_checkpointed_iteration.txt"
    return os.path.join(checkpoints_path, tracker_name)
def read_iteration(checkpoints_path):
    """Read the iteration number stored in a checkpoint directory's tracker file.

    Args:
        checkpoints_path: Checkpoint root directory containing
            ``latest_checkpointed_iteration.txt``.

    Returns:
        The iteration as an int; surrounding whitespace is ignored.

    Raises:
        FileNotFoundError: If the tracker file does not exist.
        ValueError: If the tracker content is not an integer. The message names
            the offending file and suggests passing ``--iteration`` explicitly.
    """
    tracker_filename = get_checkpoint_tracker_filename(checkpoints_path)
    if not os.path.isfile(tracker_filename):
        raise FileNotFoundError(f"Tracker file not found: {tracker_filename}")
    with open(tracker_filename, "r", encoding="utf-8") as f:
        content = f.read().strip()
    try:
        return int(content)
    except ValueError as err:
        # Without this, the bare int() error does not say which file was bad.
        raise ValueError(
            f"Tracker file {tracker_filename} does not contain an integer "
            f"iteration (found {content!r}); pass --iteration explicitly"
        ) from err
def find_tp_ranks_for_stage(iter_dir, pp_rank):
    """Return the sorted TP ranks that have a checkpoint entry for ``pp_rank``.

    Needed for VTP, where different PP stages may use different TP sizes
    (e.g. edge PP=0 with TP=1, cloud PP=1 with TP=8). TP ranks therefore
    cannot be discovered globally.

    Entries are expected to be named ``mp_rank_<TP>_<PP>[_...]``. Entries
    whose TP/PP fields are not integers are skipped instead of aborting the
    scan.

    Args:
        iter_dir: An ``iter_XXXXXXX`` directory to scan.
        pp_rank: Pipeline-parallel rank to match.

    Returns:
        Sorted list of unique TP ranks; empty if none match.
    """
    tp_ranks = set()
    for dirname in os.listdir(iter_dir):
        if not dirname.startswith("mp_rank_"):
            continue
        parts = dirname.split("_")
        if len(parts) < 4:
            continue
        try:
            tp_rank = int(parts[2])
            stage = int(parts[3])
        except ValueError:
            # Not a rank entry (non-numeric TP/PP field); ignore it.
            continue
        if stage == pp_rank:
            tp_ranks.add(tp_rank)
    return sorted(tp_ranks)
def save_checkpoint(save_iter_dir, tp_rank, pp_rank, state_dict):
    """Write ``state_dict`` as the checkpoint for (tp_rank, pp_rank).

    The file goes to ``<save_iter_dir>/mp_rank_<TP:02>_<PP:03>/model_optim_rng.pt``.
    The rank subdirectory is created if needed, and the target path is logged.
    """
    rank_dir = os.path.join(save_iter_dir, f"mp_rank_{tp_rank:02d}_{pp_rank:03d}")
    os.makedirs(rank_dir, exist_ok=True)
    target = os.path.join(rank_dir, "model_optim_rng.pt")
    logger.info(f" Saving to: {target}")
    torch.save(state_dict, target)
def save_tracker(save_dir, iteration):
    """Record ``iteration`` in ``save_dir``'s tracker file, creating directories as needed."""
    tracker_path = get_checkpoint_tracker_filename(save_dir)
    parent_dir = os.path.dirname(tracker_path) or "."
    os.makedirs(parent_dir, exist_ok=True)
    with open(tracker_path, "w") as tracker:
        tracker.write(str(iteration))
    logger.info(f"Saved iteration tracker: {tracker_path}")
def copy_metadata(state_dict):
    """Return the optimizer, scheduler and RNG entries present in ``state_dict``.

    This is a shallow selection: the returned dict shares its values with
    ``state_dict``. Keys keep the order optimizer, opt_param_scheduler,
    rng_state.
    """
    wanted_keys = ("optimizer", "opt_param_scheduler", "rng_state")
    return {key: state_dict[key] for key in wanted_keys if key in state_dict}
def load_ckpt(checkpoint_path):
    """Load a checkpoint file with all tensors mapped to CPU.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` is not an existing regular file.
    """
    if os.path.isfile(checkpoint_path):
        # SECURITY: weights_only=False unpickles arbitrary Python objects.
        # It is presumably needed for the non-tensor ``args`` object. Only load
        # checkpoints from trusted sources.
        return torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
def prepare_iter_dir(load_dir, iteration):
    """Return ``<load_dir>/iter_<7-digit iteration>``.

    Raises:
        FileNotFoundError: If that path is not an existing directory.
    """
    candidate = os.path.join(load_dir, f"iter_{iteration:07d}")
    if os.path.isdir(candidate):
        return candidate
    raise FileNotFoundError(f"Iteration directory not found: {candidate}")
def _merge_stages_to_vpp(iter_dir, load_dir, iteration, stages, tp_ranks,
save_iter_dir, save_pp_rank, label):
"""Merge multiple PP stages into a single VPP checkpoint.
Each stage becomes a VPP chunk (model0, model1, ...).
Iterates only over the given tp_ranks (VTP-safe).
"""
num_vpp = len(stages)
for tp_rank in tp_ranks:
logger.info(f"--- {label} TP rank {tp_rank} ---")
merged = None
for vpp_idx, old_pp_rank in enumerate(stages):
ckpt_path = get_checkpoint_name(load_dir, iteration, tp_rank, old_pp_rank)
logger.info(f" Loading PP={old_pp_rank} from: {ckpt_path}")
state_dict = load_ckpt(ckpt_path)
if merged is None:
merged = {
"args": state_dict.get("args"),
"checkpoint_version": state_dict.get("checkpoint_version", 3.0),
"iteration": state_dict.get("iteration", iteration),
}
merged.update(copy_metadata(state_dict))
model_key = f"model{vpp_idx}"
if "model" in state_dict:
merged[model_key] = state_dict["model"]
logger.info(f" -> {model_key}: {len(state_dict['model'])} keys")
else:
logger.warning(f" No 'model' key in checkpoint PP={old_pp_rank}")
if merged is not None and merged.get("args") is not None:
merged["args"].virtual_pipeline_model_parallel_size = num_vpp
merged["args"].pipeline_model_parallel_size = 1
save_checkpoint(save_iter_dir, tp_rank, save_pp_rank, merged)
def merge_checkpoints(args):
"""Merge standard PP checkpoints into VPP format (edge/cloud split).
Supports two source modes:
1. Single source (--load-dir): all stages from one directory
2. Dual source (--load-dir-edge + --load-dir-cloud): edge and cloud from
separately converted directories (e.g., different TP sizes)
Supports two cloud modes:
--merge-cloud-stages: merge cloud stages into a single VPP checkpoint
--middle-stages: keep cloud stages as separate PP ranks
"""
merge_stages = [int(x.strip()) for x in args.merge_stages.split(",")]
merge_cloud_stages = []
if args.merge_cloud_stages:
merge_cloud_stages = [int(x.strip()) for x in args.merge_cloud_stages.split(",")]
middle_stages = []
if args.middle_stages:
middle_stages = [int(x.strip()) for x in args.middle_stages.split(",")]
if merge_cloud_stages and middle_stages:
raise ValueError("--merge-cloud-stages and --middle-stages are mutually exclusive")
dual_source = (hasattr(args, 'load_dir_edge') and args.load_dir_edge
and hasattr(args, 'load_dir_cloud') and args.load_dir_cloud)
num_edge_vpp = len(merge_stages)
edge_save_dir = args.save_dir_edge
cloud_save_dir = args.save_dir_cloud
logger.info(f"=== Checkpoint Merge (PP -> VPP) ===")
logger.info(f"Mode: {'Dual source' if dual_source else 'Single source'}")
logger.info(f"Edge merge stages: {merge_stages} -> VPP={num_edge_vpp}")
if merge_cloud_stages:
logger.info(f"Cloud merge stages: {merge_cloud_stages} -> VPP={len(merge_cloud_stages)}")
if middle_stages:
logger.info(f"Cloud middle stages (separate PP): {middle_stages}")
logger.info(f"Edge output: {edge_save_dir}")
logger.info(f"Cloud output: {cloud_save_dir}")
if dual_source:
edge_load_dir = args.load_dir_edge
cloud_load_dir = args.load_dir_cloud
iteration = (args.iteration if args.iteration is not None
else read_iteration(edge_load_dir))
edge_iter_dir = prepare_iter_dir(edge_load_dir, iteration)
cloud_iter_dir = prepare_iter_dir(cloud_load_dir, iteration)
else:
iteration = (args.iteration if args.iteration is not None
else read_iteration(args.load_dir))
iter_dir = prepare_iter_dir(args.load_dir, iteration)
edge_load_dir = cloud_load_dir = args.load_dir
edge_iter_dir = cloud_iter_dir = iter_dir
edge_save_iter_dir = os.path.join(edge_save_dir, f'iter_{iteration:07d}')
cloud_save_iter_dir = os.path.join(cloud_save_dir, f'iter_{iteration:07d}')
os.makedirs(edge_save_iter_dir, exist_ok=True)
if merge_cloud_stages or middle_stages:
os.makedirs(cloud_save_iter_dir, exist_ok=True)
edge_tp_ranks = find_tp_ranks_for_stage(edge_iter_dir, merge_stages[0])
logger.info(f"Edge TP ranks (from PP={merge_stages[0]}): {edge_tp_ranks}")
_merge_stages_to_vpp(
edge_iter_dir, edge_load_dir, iteration, merge_stages,
edge_tp_ranks, edge_save_iter_dir, save_pp_rank=0, label="Edge"
)
if merge_cloud_stages:
cloud_tp_ranks = find_tp_ranks_for_stage(cloud_iter_dir, merge_cloud_stages[0])
logger.info(f"Cloud TP ranks (from PP={merge_cloud_stages[0]}): {cloud_tp_ranks}")
logger.info(f"Merging cloud stages {merge_cloud_stages} into PP rank 0 with VPP={len(merge_cloud_stages)}")
_merge_stages_to_vpp(
cloud_iter_dir, cloud_load_dir, iteration, merge_cloud_stages,
cloud_tp_ranks, cloud_save_iter_dir, save_pp_rank=0, label="Cloud"
)
logger.info(f"Cloud merge complete. Should have {len(cloud_tp_ranks)} checkpoints at PP rank 0")
if middle_stages:
logger.info(f"\n=== Processing middle stages: {middle_stages} ===")
cloud_tp_ranks = find_tp_ranks_for_stage(cloud_iter_dir, middle_stages[0])
logger.info(f"Middle stages TP ranks (from PP={middle_stages[0]}): {cloud_tp_ranks}")
for tp_rank in cloud_tp_ranks:
for idx, old_pp_rank in enumerate(middle_stages):
cloud_pp_rank = idx + 1
ckpt_path = get_checkpoint_name(cloud_load_dir, iteration, tp_rank, old_pp_rank)
logger.info(f" Converting PP={old_pp_rank} -> cloud PP={cloud_pp_rank}")
state_dict = load_ckpt(ckpt_path)
new_state_dict = {
"args": state_dict.get("args"),
"checkpoint_version": state_dict.get("checkpoint_version", 3.0),
"iteration": state_dict.get("iteration", iteration),
}
new_state_dict.update(copy_metadata(state_dict))
if "model" in state_dict:
new_state_dict["model0"] = state_dict["model"]
new_state_dict["model1"] = {}
else:
new_state_dict["model0"] = {}
new_state_dict["model1"] = {}
if new_state_dict.get("args") is not None:
new_state_dict["args"].pipeline_model_parallel_size = 1 + len(middle_stages)
new_state_dict["args"].virtual_pipeline_model_parallel_size = 1
save_checkpoint(cloud_save_iter_dir, tp_rank, cloud_pp_rank, new_state_dict)
save_tracker(edge_save_dir, iteration)
if merge_cloud_stages or middle_stages:
save_tracker(cloud_save_dir, iteration)
logger.info(f"\n=== Merge complete ===")
logger.info(f"Edge: {edge_save_dir} (PP=1, VPP={num_edge_vpp})")
if merge_cloud_stages:
logger.info(f"Cloud: {cloud_save_dir} (PP=1, VPP={len(merge_cloud_stages)})")
if middle_stages:
logger.info(f"Cloud: {cloud_save_dir} (PP={len(middle_stages)})")
def split_checkpoints(args):
"""Split VPP checkpoints (edge/cloud) to standard PP format.
Supports two cloud modes:
--split-cloud-rank + --num-cloud-vpp-chunks: split merged cloud VPP
back into separate PP stages
--middle-ranks: convert separate cloud PP ranks (existing behavior)
"""
split_rank = args.split_rank
split_cloud_rank = getattr(args, 'split_cloud_rank', None)
num_cloud_vpp = getattr(args, 'num_cloud_vpp_chunks', None)
middle_ranks = []
if args.middle_ranks:
middle_ranks = [int(x.strip()) for x in args.middle_ranks.split(",")]
edge_load_dir = args.load_dir_edge
cloud_load_dir = args.load_dir_cloud
if split_cloud_rank is not None and num_cloud_vpp:
total_new_pp_stages = 2 + num_cloud_vpp
else:
total_new_pp_stages = 2 + len(middle_ranks)
logger.info(f"=== Checkpoint Split (VPP -> PP) ===")
logger.info(f"Edge input: {edge_load_dir}")
logger.info(f"Cloud input: {cloud_load_dir}")
logger.info(f"Split VPP rank (edge): {split_rank}")
if split_cloud_rank is not None:
logger.info(f"Split cloud VPP rank: {split_cloud_rank}, chunks: {num_cloud_vpp}")
if middle_ranks:
logger.info(f"Middle ranks (cloud): {middle_ranks}")
logger.info(f"Output PP size: {total_new_pp_stages}")
iteration = (args.iteration if args.iteration is not None
else read_iteration(edge_load_dir))
edge_iter_dir = os.path.join(edge_load_dir, f'iter_{iteration:07d}')
if not os.path.isdir(edge_iter_dir):
raise FileNotFoundError(f"Edge iteration directory not found: {edge_iter_dir}")
save_iter_dir = os.path.join(args.save_dir, f"iter_{iteration:07d}")
os.makedirs(save_iter_dir, exist_ok=True)
edge_tp_ranks = find_tp_ranks_for_stage(edge_iter_dir, split_rank)
logger.info(f"Edge TP ranks: {edge_tp_ranks}")
for tp_rank in edge_tp_ranks:
logger.info(f"--- Edge TP rank {tp_rank} ---")
vpp_ckpt_path = get_checkpoint_name(
edge_load_dir, iteration, tp_rank, split_rank
)
logger.info(f" Splitting edge VPP rank {split_rank} -> PP=0 and PP={total_new_pp_stages - 1}")
vpp_state_dict = load_ckpt(vpp_ckpt_path)
base_metadata = {
"checkpoint_version": vpp_state_dict.get("checkpoint_version", 3.0),
"iteration": vpp_state_dict.get("iteration", iteration),
}
first_state_dict = dict(base_metadata)
first_state_dict["args"] = copy.deepcopy(vpp_state_dict.get("args"))
first_state_dict.update(copy_metadata(vpp_state_dict))
model0 = vpp_state_dict.get("model0", {})
first_state_dict["model"] = model0
logger.info(f" -> PP=0: model from model0 ({len(model0)} keys)")
if first_state_dict.get("args") is not None:
first_state_dict["args"].pipeline_model_parallel_size = total_new_pp_stages
first_state_dict["args"].virtual_pipeline_model_parallel_size = None
save_checkpoint(save_iter_dir, tp_rank, 0, first_state_dict)
last_pp_rank = total_new_pp_stages - 1
last_state_dict = dict(base_metadata)
last_state_dict["args"] = copy.deepcopy(vpp_state_dict.get("args"))
model1 = vpp_state_dict.get("model1", {})
last_state_dict["model"] = model1
logger.info(f" -> PP={last_pp_rank}: model from model1 ({len(model1)} keys)")
if last_state_dict.get("args") is not None:
last_state_dict["args"].pipeline_model_parallel_size = total_new_pp_stages
last_state_dict["args"].virtual_pipeline_model_parallel_size = None
save_checkpoint(save_iter_dir, tp_rank, last_pp_rank, last_state_dict)
if split_cloud_rank is not None and num_cloud_vpp:
cloud_iter_dir = os.path.join(cloud_load_dir, f'iter_{iteration:07d}')
cloud_tp_ranks = find_tp_ranks_for_stage(cloud_iter_dir, split_cloud_rank)
logger.info(f"Cloud TP ranks: {cloud_tp_ranks}")
for tp_rank in cloud_tp_ranks:
logger.info(f"--- Cloud TP rank {tp_rank} ---")
cloud_ckpt_path = get_checkpoint_name(
cloud_load_dir, iteration, tp_rank, split_cloud_rank
)
cloud_state_dict = load_ckpt(cloud_ckpt_path)
base_metadata = {
"checkpoint_version": cloud_state_dict.get("checkpoint_version", 3.0),
"iteration": cloud_state_dict.get("iteration", iteration),
}
for vpp_idx in range(num_cloud_vpp):
new_pp_rank = vpp_idx + 1
model_key = f"model{vpp_idx}"
model_data = cloud_state_dict.get(model_key, {})
new_state_dict = dict(base_metadata)
new_state_dict["args"] = copy.deepcopy(cloud_state_dict.get("args"))
new_state_dict["model"] = model_data
logger.info(f" -> PP={new_pp_rank}: model from {model_key} ({len(model_data)} keys)")
if new_state_dict.get("args") is not None:
new_state_dict["args"].pipeline_model_parallel_size = total_new_pp_stages
new_state_dict["args"].virtual_pipeline_model_parallel_size = None
save_checkpoint(save_iter_dir, tp_rank, new_pp_rank, new_state_dict)
if middle_ranks:
cloud_iter_dir = os.path.join(cloud_load_dir, f'iter_{iteration:07d}')
cloud_tp_ranks = find_tp_ranks_for_stage(cloud_iter_dir, middle_ranks[0])
logger.info(f"Middle ranks TP ranks: {cloud_tp_ranks}")
for tp_rank in cloud_tp_ranks:
for idx, cloud_pp_rank in enumerate(middle_ranks):
new_pp_rank = idx + 1
old_ckpt_path = get_checkpoint_name(
cloud_load_dir, iteration, tp_rank, cloud_pp_rank
)
logger.info(f" Converting cloud PP={cloud_pp_rank} -> PP={new_pp_rank}")
state_dict = load_ckpt(old_ckpt_path)
new_state_dict = {
"args": state_dict.get("args"),
"checkpoint_version": state_dict.get("checkpoint_version", 3.0),
"iteration": state_dict.get("iteration", iteration),
}
new_state_dict.update(copy_metadata(state_dict))
model_data = state_dict.get('model0', state_dict.get('model', {}))
new_state_dict['model'] = model_data
logger.info(f" -> model from model0 ({len(model_data)} keys)")
if new_state_dict.get('args') is not None:
new_state_dict["args"].pipeline_model_parallel_size = total_new_pp_stages
new_state_dict["args"].virtual_pipeline_model_parallel_size = None
save_checkpoint(save_iter_dir, tp_rank, new_pp_rank, new_state_dict)
save_tracker(args.save_dir, iteration)
logger.info(f"\n=== Split complete ===")
logger.info(f"Output: {args.save_dir}, PP size: {total_new_pp_stages}")
def _add_options(parser, options):
    """Register each ``(flag, kwargs)`` pair in ``options`` on ``parser``, in order."""
    for flag, kwargs in options:
        parser.add_argument(flag, **kwargs)


def main():
    """Parse the command line and dispatch to the ``merge`` or ``split`` subcommand."""
    parser = argparse.ArgumentParser(
        description='Bidirectional conversion between PP and VPP checkpoint formats'
    )
    commands = parser.add_subparsers(dest='command', required=True)

    merge_cmd = commands.add_parser('merge', help='convert PP -> VPP')
    _add_options(merge_cmd, [
        ('--load-dir', dict(type=str, default=None,
                            help='Source checkpoint dir (single source, uniform TP)')),
        ('--load-dir-edge', dict(type=str, default=None,
                                 help='Source checkpoint dir for edge stages')),
        ('--load-dir-cloud', dict(type=str, default=None,
                                  help='Source checkpoint dir for cloud stages')),
        ('--save-dir-edge', dict(type=str, required=True,
                                 help='Output dir for edge (first+last layers VPP)')),
        ('--save-dir-cloud', dict(type=str, required=True,
                                  help='Output dir for cloud (middle layers VPP)')),
        ('--merge-stages', dict(type=str, required=True,
                                help='PP stage indices to merge into edge VPP. e.g. "0,3"')),
        ('--merge-cloud-stages', dict(type=str, default=None,
                                      help='PP stage indices to merge into cloud VPP. e.g. "1,2". '
                                           'Mutually exclusive with --middle-stages.')),
        ('--middle-stages', dict(type=str, default=None,
                                 help='PP stage indices kept as separate cloud PP ranks. e.g. "1,2,3". '
                                      'Mutually exclusive with --merge-cloud-stages.')),
        ('--iteration', dict(type=int, default=None)),
    ])

    split_cmd = commands.add_parser('split', help='convert VPP -> PP')
    _add_options(split_cmd, [
        ('--load-dir-edge', dict(type=str, required=True,
                                 help='Edge VPP checkpoint dir')),
        ('--load-dir-cloud', dict(type=str, required=True,
                                  help='Cloud VPP checkpoint dir')),
        ('--save-dir', dict(type=str, required=True)),
        ('--split-rank', dict(type=int, default=0,
                              help='PP rank in edge dir containing VPP to split (default: 0)')),
        ('--split-cloud-rank', dict(type=int, default=None,
                                    help='PP rank in cloud dir containing merged VPP to split. '
                                         'Mutually exclusive with --middle-ranks.')),
        ('--num-cloud-vpp-chunks', dict(type=int, default=None,
                                        help='Number of VPP chunks in merged cloud checkpoint. '
                                             'Required with --split-cloud-rank.')),
        ('--middle-ranks', dict(type=str, default=None,
                                help='PP ranks in cloud dir to convert from VPP to PP. e.g. "1,2,3"')),
        ('--iteration', dict(type=int, default=None)),
    ])

    parsed = parser.parse_args()
    handlers = {'merge': merge_checkpoints, 'split': split_checkpoints}
    handler = handlers.get(parsed.command)
    if handler is None:
        raise ValueError('only support merge and split')
    handler(parsed)


if __name__ == '__main__':
    main()