from logging import getLogger
import torch
import numpy as np
logger = getLogger(__name__)
def get_parameter_state_dp_zero(self):
"""Get parameter state (i.e., parameter & optimizer tensors).
This method performs two steps:
- For each DP rank, copy param & optimizer shards to contiguous CPU
buffers (e.g., one buffer each for main_param, exp_avg, and
exp_avg_sq).
- Gather contiguous buffers on DP rank 0 and concatenate to world
buffers.
"""
data_parallel_world_size = self.data_parallel_group_gloo.size()
data_parallel_rank = torch.distributed.get_rank(self.data_parallel_group_gloo)
data_parallel_group_gloo = self.data_parallel_group_gloo
data_parallel_global_ranks = torch.distributed.get_process_group_ranks(
self.data_parallel_group_gloo
)
state = {
"buckets_coalesced": True,
}
for gbuf_idx, gbuf_range_maps in enumerate(self.gbuf_ranges):
dtype_state = {}
if len(gbuf_range_maps) != 1:
raise TypeError("single dtype supported, for now.")
for dtype, gbuf_range_map_for_all_buckets in gbuf_range_maps.items():
buffer_numel_unpadded = self.buffers[gbuf_idx].numel_unpadded
world_tensors = {}
if data_parallel_rank == 0:
world_tensors = {
key: np.zeros(
(buffer_numel_unpadded,), dtype=np.float32
)
for key in ("param", "exp_avg", "exp_avg_sq")
}
world_tensors["numel_unpadded"] = buffer_numel_unpadded
offset_in_world_tensors = 0
for bucket_idx, gbuf_range_map in enumerate(gbuf_range_map_for_all_buckets):
gbuf_world_numel = self.buffers[gbuf_idx].buckets[bucket_idx].grad_data.numel()
if gbuf_world_numel % data_parallel_world_size != 0:
raise ValueError("gbuf_world_numel should be divisible by data_parallel_world_size")
gbuf_local_numel = gbuf_world_numel // data_parallel_world_size
gbuf_world_numel_unpadded = (
self.buffers[gbuf_idx].buckets[bucket_idx].numel_unpadded
)
if gbuf_world_numel_unpadded > gbuf_world_numel:
raise ValueError("gbuf_world_numel_unpadded should be less equal to gbuf_world_numel")
local_shards = {
key: np.zeros((gbuf_local_numel,), dtype=np.float32)
for key in ("param", "exp_avg", "exp_avg_sq")
}
for model_param, param_range_map in gbuf_range_map["param_map"].items():
group_index, group_order = self.model_param_group_index_map[model_param]
main_param = self.optimizer.param_groups[group_index]["params"][group_order]
optim_state = self.optimizer.state[main_param]
tensors = {
"param": main_param,
**optim_state,
}
gbuf_local_start = param_range_map["gbuf_local"].start
gbuf_local_end = param_range_map["gbuf_local"].end
for key in local_shards:
try:
tensor = tensors[key]
local_shard = local_shards[key]
except KeyError as e:
raise KeyError(f"Missing key '{key}' in tensors or local_shards") from e
local_shard[gbuf_local_start:gbuf_local_end] = tensor.numpy()
for key, send_tensor in local_shards.items():
send_tensor = torch.Tensor(send_tensor)
if data_parallel_rank == 0:
recv_tensors = [
torch.Tensor(np.zeros((gbuf_local_numel,), dtype=np.float32))
for _ in range(data_parallel_world_size)
]
else:
recv_tensors = None
torch.distributed.gather(
send_tensor,
recv_tensors,
data_parallel_global_ranks[0],
data_parallel_group_gloo,
)
if data_parallel_rank == 0:
recv_tensors_concatenated = np.concatenate(
[recv_tensor.numpy() for recv_tensor in recv_tensors])
start = offset_in_world_tensors
end = offset_in_world_tensors + gbuf_world_numel_unpadded
world_tensors[key][start:end] = (
recv_tensors_concatenated[:gbuf_world_numel_unpadded]
)
offset_in_world_tensors += gbuf_world_numel_unpadded
for key in world_tensors.keys():
world_tensors[key] = torch.Tensor(world_tensors[key])
dtype_state[dtype] = world_tensors
state[gbuf_idx] = dtype_state
return state
def load_parameter_state_from_dp_zero(self, state_dict, *, update_legacy_format=False):
"""Load parameter state (i.e., parameter & optimizer tensors) from DP 0 rank,
using the new checkpoint format with coalesced state across buckets.
This method performs the reverse of get_parameter_state_dp_zero():
- Scatter contiguous buffers from DP rank 0 to each DP rank (each DP
rank receives its relevant subset of the world buffers).
- For each DP rank, copy param & optimizer shards from contiguous CPU
buffers. (e.g., one buffer each for main_param, exp_avg, and
exp_avg_sq).
"""
if update_legacy_format:
self.load_parameter_state_from_dp_zero_legacy(state_dict)
return
if self.data_parallel_group_gloo is None:
raise ValueError("self.data_parallel_group_gloo must not be None")
data_parallel_world_size = self.data_parallel_group_gloo.size()
data_parallel_rank = torch.distributed.get_rank(self.data_parallel_group_gloo)
data_parallel_group_gloo = self.data_parallel_group_gloo
data_parallel_global_ranks = torch.distributed.get_process_group_ranks(
self.data_parallel_group_gloo
)
if data_parallel_rank == 0:
self.split_state_dict_if_needed(state_dict)
for gbuf_idx, gbuf_range_maps in enumerate(self.gbuf_ranges):
for dtype, gbuf_range_map_for_all_buckets in gbuf_range_maps.items():
if data_parallel_rank == 0:
buffer_numel_unpadded = self.buffers[gbuf_idx].numel_unpadded
checkpoint_numel_unpadded = state_dict[gbuf_idx][dtype]["numel_unpadded"]
if buffer_numel_unpadded != checkpoint_numel_unpadded:
raise ValueError(
f"Number of unpadded elements must be the same in current run "
f"({buffer_numel_unpadded}) and checkpoint ({checkpoint_numel_unpadded})"
)
recv_tensors = {}
for key in ("param", "exp_avg", "exp_avg_sq"):
offset_in_world_tensors = 0
for bucket_idx, gbuf_range_map in enumerate(gbuf_range_map_for_all_buckets):
gbuf_world_numel = (
self.buffers[gbuf_idx].buckets[bucket_idx].grad_data.numel()
)
if gbuf_world_numel % data_parallel_world_size != 0:
raise ValueError(
f"gbuf_world_numel ({gbuf_world_numel}) must be divisible by "
f"data_parallel_world_size ({data_parallel_world_size})"
)
gbuf_local_numel = gbuf_world_numel // data_parallel_world_size
gbuf_world_numel_unpadded = (
self.buffers[gbuf_idx].buckets[bucket_idx].numel_unpadded
)
if gbuf_world_numel_unpadded > gbuf_world_numel:
raise ValueError(
f"gbuf_world_numel_unpadded ({gbuf_world_numel_unpadded}) must not exceed "
f"gbuf_world_numel ({gbuf_world_numel})"
)
recv_tensor = torch.Tensor(np.zeros(
(gbuf_local_numel,), dtype=np.float32
))
if data_parallel_rank == 0:
world_tensors = state_dict[gbuf_idx][dtype][key].numpy()
start = offset_in_world_tensors
end = offset_in_world_tensors + gbuf_world_numel_unpadded
if not (0 <= start < end <= world_tensors.size):
raise ValueError(
f"Invalid range: expected 0 <= start ({start}) < end ({end}) <= world_tensors.size ({world_tensors.size}), "
"but the condition was not satisfied"
)
world_tensor = world_tensors[start:end]
offset_in_world_tensors += gbuf_world_numel_unpadded
world_tensor_ori_dtype = state_dict[gbuf_idx][dtype][key].dtype
pad_width = [(0, 0) for _ in range(world_tensor.ndim - 1)]
pad_width += (0, gbuf_world_numel - gbuf_world_numel_unpadded)
world_tensor = np.pad(world_tensor, pad_width)
if world_tensor.size != gbuf_world_numel:
raise ValueError(
f"world_tensor.size ({world_tensor.size}) must equal gbuf_world_numel ({gbuf_world_numel})"
)
gbuf_start_idxs = list(range(0, gbuf_world_numel, gbuf_local_numel))
send_tensors = [
torch.Tensor(world_tensor[i: (i + gbuf_local_numel)], dtype=world_tensor_ori_dtype)
for i in gbuf_start_idxs
]
else:
send_tensors = None
torch.distributed.scatter(
recv_tensor,
send_tensors,
data_parallel_global_ranks[0],
data_parallel_group_gloo,
)
for model_param, param_range_map in gbuf_range_map["param_map"].items():
gbuf_local_start = param_range_map["gbuf_local"].start
gbuf_local_end = param_range_map["gbuf_local"].end
if model_param not in recv_tensors:
recv_tensors[model_param] = {}
recv_tensors[model_param][key] = torch.Tensor(
recv_tensor.numpy()[gbuf_local_start:gbuf_local_end],
dtype=torch.float32,
)
for model_param, tensors in recv_tensors.items():
self._set_main_param_and_optimizer_states(model_param, tensors)