import os
import numpy as np
from os.path import join
from typing import Sequence, Any, Dict
import torch
# File names used inside a checkpoint directory.
# Bottom MLP state ("weights"/"biases" lists) written by save_bottom_mlp.
_BOTTOM_MLP_FILE = "bottom_model.mlp.pt"
# Top MLP state ("weights"/"biases" lists) written by save_top_model.
_TOP_MLP_FILE = "top_model.mlp.pt"
# state_dict of the top model's `out` layer written by save_top_model.
_TOP_OUT_FILE = "top_model.out.pt"
# NOTE(review): not referenced in the visible part of this file — confirm it is still needed.
_EMBEDDING_METADATA_FILE = "embeddings.metadata.pt"
# Checkpoint-level metadata ({"data": ..., "config": ...}) written by save_metadata.
_METADATA_FILE = "metadata.pt"
def _get_embedding_file(embedding_index: int) -> str:
    """Return the file name holding the raw weights of one embedding table.

    Args:
        embedding_index: index of the embedding table; it is interpolated
            into the file name as-is.
    """
    return "bottom_model.embeddings.{}.bin".format(embedding_index)
def _get_embedding_meta_file(embedding_index: int) -> str:
    """Return the file name holding the metadata of one embedding table.

    Args:
        embedding_index: index of the embedding table; it is interpolated
            into the file name as-is.
    """
    return "embeddings.{}.meta.pt".format(embedding_index)
class DlrmCheckpointWriter:
    """
    Class responsible for saving checkpoints of DLRM model parts.

    Depends on `dlrm.nn.embeddings.Embeddings` and `dlrm.nn.mlps.AbstractMlp` interfaces
    (for handling multiple model configurations)

    Every `save_*` method creates `checkpoint_path` if it does not exist yet,
    so the methods can be called independently and in any order.
    """

    def __init__(self, embedding_indices: Sequence[int], config: Dict[str, Any]):
        """
        Args:
            embedding_indices: indices used to name the per-embedding files;
                they are paired positionally with
                `model.bottom_model.embeddings.weights` in `save_embeddings`.
            config: configuration stored next to the user data by `save_metadata`.
        """
        self._embedding_indices = embedding_indices
        self._config = config

    def save_embeddings(self, checkpoint_path: str, model):
        """Write each embedding table as raw float32 bytes plus a shape metadata file.

        Args:
            checkpoint_path: directory receiving the files.
            model: object exposing `bottom_model.embeddings.weights`.
        """
        self._ensure_directory(checkpoint_path)
        # zip() stops at the shorter sequence, so only as many tables as there
        # are configured indices (and vice versa) are written.
        for embedding_index, weight in zip(self._embedding_indices, model.bottom_model.embeddings.weights):
            self._save_as_bytes(weight.data, join(checkpoint_path, _get_embedding_file(embedding_index)))
            # The raw byte dump carries no shape information; store it separately.
            torch.save({"shape": weight.shape}, join(checkpoint_path, _get_embedding_meta_file(embedding_index)))

    def save_bottom_mlp(self, checkpoint_path: str, model):
        """Save the weights and biases of `model.bottom_model.mlp`.

        Args:
            checkpoint_path: directory receiving the file.
            model: object exposing `bottom_model.mlp`.
        """
        self._ensure_directory(checkpoint_path)
        torch.save(self._mlp_state(model.bottom_model.mlp), join(checkpoint_path, _BOTTOM_MLP_FILE))

    def save_top_model(self, checkpoint_path: str, model):
        """Save the top MLP state and the state dict of the top model's `out` layer.

        Args:
            checkpoint_path: directory receiving the files.
            model: object exposing `top_model`, optionally wrapped in an object
                that holds the real module in a `module` attribute.
        """
        self._ensure_directory(checkpoint_path)
        # Unwrap a wrapper that exposes the underlying model as `.module`.
        top_model = model.top_model.module if hasattr(model.top_model, 'module') else model.top_model
        torch.save(self._mlp_state(top_model.mlp), join(checkpoint_path, _TOP_MLP_FILE))
        torch.save(top_model.out.state_dict(), join(checkpoint_path, _TOP_OUT_FILE))

    def save_metadata(self, checkpoint_path: str, data: Dict[str, Any]):
        """Save `data` together with the writer's config.

        Args:
            checkpoint_path: directory receiving the file.
            data: arbitrary metadata stored under the "data" key.
        """
        self._ensure_directory(checkpoint_path)
        torch.save({"data": data, "config": self._config}, join(checkpoint_path, _METADATA_FILE))

    def _ensure_directory(self, checkpoint_path: str):
        """Create `checkpoint_path` (including parents) if it is missing."""
        os.makedirs(checkpoint_path, exist_ok=True)

    def _mlp_state(self, mlp):
        """Return the MLP's weights and biases converted to float32.

        Args:
            mlp: object exposing iterable `weights` and `biases` of tensors.

        Returns:
            Dict with "weights" and "biases" lists of float32 tensors.
        """
        return {
            "weights": [x.to(torch.float32) for x in mlp.weights],
            "biases": [x.to(torch.float32) for x in mlp.biases]
        }

    def _save_as_bytes(self, tensor: torch.Tensor, path: str):
        """Dump `tensor` to `path` as raw float32 values in C order.

        Args:
            tensor: tensor of any dtype castable to float32, on any device.
            path: destination file; overwritten if it exists.
        """
        # Cast to float32 in torch *before* converting to NumPy: NumPy has no
        # bfloat16 dtype, so Tensor.numpy() raises for such tensors. detach()
        # additionally makes the conversion safe for tensors that require grad.
        array = tensor.detach().to(device="cpu", dtype=torch.float32).numpy()
        # The array is materialised before the file is opened so that a failed
        # conversion does not leave a truncated file behind.
        with open(path, "wb") as file:
            file.write(array.tobytes())
class DlrmCheckpointLoader:
    """
    Class responsible for loading checkpoints of DLRM model parts.

    Depends on `dlrm.nn.embeddings.Embeddings` and `dlrm.nn.mlps.AbstractMlp` interfaces
    (for handling multiple model configurations)
    """

    def __init__(self, embedding_indices: Sequence[int], device: str = "cpu"):
        """
        Args:
            embedding_indices: indices of the embedding tables to load; they
                select the per-embedding files inside the checkpoint directory.
            device: device onto which loaded tensors are placed.
        """
        self._embedding_indices = embedding_indices
        self._device = device

    def load_embeddings(self, checkpoint_path: str, model):
        """Load the configured embedding tables into `model.bottom_model.embeddings`.

        Args:
            checkpoint_path: directory containing the checkpoint files.
            model: object whose `bottom_model.embeddings` exposes `load_weights`.
        """
        # A generator keeps at most one table being read at a time; files are
        # only touched once `load_weights` iterates over it.
        embedding_weights = (self._load_from_bytes(join(checkpoint_path, _get_embedding_file(index)),
                                                   self._get_embedding_shape(checkpoint_path, index))
                             for index in self._embedding_indices)
        model.bottom_model.embeddings.load_weights(embedding_weights)

    def load_bottom_mlp(self, checkpoint_path: str, model):
        """Load the bottom MLP weights and biases into `model.bottom_model.mlp`.

        Args:
            checkpoint_path: directory containing the checkpoint files.
            model: object whose `bottom_model.mlp` exposes `load_state`.
        """
        bottom_mlp_state = self._load(checkpoint_path, _BOTTOM_MLP_FILE)
        model.bottom_model.mlp.load_state(bottom_mlp_state["weights"], bottom_mlp_state["biases"])

    def load_top_model(self, checkpoint_path: str, model):
        """Load the top MLP and the `out` layer state into the top model.

        Args:
            checkpoint_path: directory containing the checkpoint files.
            model: object exposing `top_model`, optionally wrapped in an object
                that holds the real module in a `module` attribute.
        """
        # Unwrap a wrapper that exposes the underlying model as `.module`.
        top_model = model.top_model.module if hasattr(model.top_model, 'module') else model.top_model
        top_mlp_state = self._load(checkpoint_path, _TOP_MLP_FILE)
        top_model.mlp.load_state(top_mlp_state["weights"], top_mlp_state["biases"])
        top_out_state = self._load(checkpoint_path, _TOP_OUT_FILE)
        top_model.out.load_state_dict(top_out_state)

    def _load(self, checkpoint_path: str, state_path: str):
        """Load a saved mapping onto the configured device, normalising its keys.

        Args:
            checkpoint_path: directory containing the checkpoint files.
            state_path: file name inside `checkpoint_path`.

        Returns:
            A plain dict with any leading "module." stripped from every key.
        """
        data = torch.load(join(checkpoint_path, state_path), map_location=self._device)
        return {self._strip_key(key): value for key, value in data.items()}

    def _strip_key(self, key: str):
        """Return `key` without a leading "module." prefix, if it has one."""
        prefix = "module."
        if key.startswith(prefix):
            return key[len(prefix):]
        return key

    def _load_from_bytes(self, path: str, shape) -> torch.Tensor:
        """Read raw float32 values from `path` and return them as a tensor.

        Args:
            path: file containing native-endian float32 values in C order.
            shape: target shape (any sequence of ints, e.g. `torch.Size`).

        Returns:
            A writable float32 tensor of the given shape on the configured device.

        Raises:
            ValueError: if the number of stored values does not match `shape`.
        """
        # np.fromfile returns a writable array that owns its memory. Reading the
        # file into `bytes` and wrapping it with np.frombuffer would produce a
        # read-only array, which torch.from_numpy warns about and which would
        # leave a CPU tensor aliasing immutable memory (`.to("cpu")` is a no-op).
        array = np.fromfile(path, dtype=np.float32).reshape(tuple(shape))
        return torch.from_numpy(array).to(self._device)

    def _get_embedding_shape(self, checkpoint_path: str, index: int) -> tuple:
        """Return the shape recorded in the metadata file of embedding `index`."""
        embedding_meta = torch.load(join(checkpoint_path, _get_embedding_meta_file(index)))
        return embedding_meta["shape"]