import copy
import logging
import os
from typing import Tuple
import numpy as np
import tensorflow as tf
# Directory name joined between the data dir and the "sparse-model-<step>"
# folder when SparseModel builds its sparse path.
DUMP_MODEL_STR = "02dump_model"
# Keys of the per-meta dict built by parse_single_meta_data.
META_LENGTH = "length"
META_EMB_DIM = "emb_dim"
META_DTYPE = "dtype"
META_DATA = "data"
# Per-meta fields iterated by check_basic_meta_equal (META_LENGTH is skipped there).
META_BASIC_CONTENT_LIST = [META_EMB_DIM, META_DTYPE, META_LENGTH]
# Meta name under which a table's key data is looked up.
META_NAME_KEY = "key"
# Relative tolerance passed to np.allclose when comparing embeddings.
SPARSE_ALLCLOSE_RTOL = 1e-10
class SparseModel:
    """
    Parsed sparse-model data for one dump directory and one step.

    Attributes:
        sparse_path: ``<data_dir>/02dump_model/sparse-model-<data_step>``.
        table_path_list: sub-directories found directly under ``sparse_path``.
        table_data_dict: mapping of table name to its parsed meta data.
        emb_name_list: sorted table names (keys of ``table_data_dict``).
    """

    def __init__(self, data_dir: str, data_step: int):
        step_dir_name = f"sparse-model-{data_step}"
        self.sparse_path = os.path.join(data_dir, DUMP_MODEL_STR, step_dir_name)
        self.table_path_list = list_model_dirs(self.sparse_path)
        self.table_data_dict = self.parse_sparse_info()
        # Sorted so two models can be compared by plain list equality.
        self.emb_name_list = sorted(self.table_data_dict)

    def __eq__(self, other) -> bool:
        """
        Compare this model with another SparseModel via compare_ckpt.

        Any operand that is not a SparseModel is logged as an error and
        treated as unequal.
        """
        logging.info("Sparse model comparison start......")
        if isinstance(other, SparseModel):
            return compare_ckpt(self, other)
        logging.error(
            "Sparse model comparison must between SparseModel, but %s is given",
            other.__class__,
        )
        return False

    def parse_sparse_info(self) -> dict:
        """
        Parse every table directory in ``table_path_list``.

        Returns a dict mapping each table name to the data returned by
        parse_table_info for that table.
        """
        parsed_tables = {}
        for table_dir in self.table_path_list:
            logging.debug("Table path is %s", table_dir)
            name, data = parse_table_info(table_dir)
            parsed_tables[name] = data
        return parsed_tables
def compare_ckpt(test_data: SparseModel, golden_data: SparseModel) -> bool:
    """
    Compare two sparse models table by table.

    Returns False as soon as the table-name lists differ or any single table
    fails compare_single_table_equal; returns True otherwise.
    """
    test_names = test_data.emb_name_list
    golden_names = golden_data.emb_name_list
    logging.debug(
        "Test Data tables: %s Golden Data tables: %s.", test_names, golden_names
    )

    if test_names != golden_names:
        logging.error(
            "Test Data tables: %s Golden Data tables: %s. "
            "Comparision table names not match! Please check your input sparse model.",
            test_names,
            golden_names,
        )
        return False

    for table_name in golden_names:
        table_matches = compare_single_table_equal(
            test_data.table_data_dict, golden_data.table_data_dict, table_name
        )
        if table_matches:
            continue
        logging.error(
            "Comparision sparse table %s not equal! Please check your input sparse model.",
            table_name,
        )
        return False
    return True
def compare_single_table_equal(
    test_data_dict: dict, golden_data_dict: dict, table_name: str
) -> bool:
    """
    Compare one table, looked up by name, between two sparse models.

    The check runs in three stages and stops at the first failure:
    the basic meta contents must match, the key sets must share at least one
    key, and the embeddings of every shared key must be close.
    """
    test_table = test_data_dict[table_name]
    golden_table = golden_data_dict[table_name]
    test_metas = sorted(test_table.keys())
    golden_metas = sorted(golden_table.keys())

    metas_match = check_basic_meta_equal(
        test_metas, golden_metas, test_table, golden_table
    )
    if not metas_match:
        logging.error("meta items not equal!")
        return False

    test_keys = test_table[META_NAME_KEY][META_DATA]
    golden_keys = golden_table[META_NAME_KEY][META_DATA]
    shared_keys = parse_key_data_and_cmp(test_keys, golden_keys, table_name)
    if not shared_keys:
        logging.error("Sparse Model keys have no intersection!")
        return False

    test_lookup = construct_key_embedding_dict(test_keys, test_metas, test_table)
    golden_lookup = construct_key_embedding_dict(
        golden_keys, golden_metas, golden_table
    )
    return check_key_emb_equal(shared_keys, test_lookup, golden_lookup)
def check_basic_meta_equal(
    test_meta_list: list,
    golden_meta_list: list,
    test_table_data: dict,
    golden_table_data: dict,
) -> bool:
    """
    Check that two tables expose the same metas with the same basic contents.

    The meta-name lists must be equal, and for every meta each entry of
    META_BASIC_CONTENT_LIST except META_LENGTH must be equal on both sides.
    """
    logging.debug(
        "Test Data Meta: %s Golden Data Meta: %s.", test_meta_list, golden_meta_list
    )
    if test_meta_list != golden_meta_list:
        logging.error(
            "Test Data meta: %s Golden Data meta: %s.Comparision meta not match! Please check your input sparse model.",
            test_meta_list,
            golden_meta_list,
        )
        return False

    for meta_name in golden_meta_list:
        test_meta = test_table_data[meta_name]
        golden_meta = golden_table_data[meta_name]
        logging.debug("++++++++++++ %s ++++++++++++", meta_name)
        mismatched = _first_mismatched_content(test_meta, golden_meta)
        if mismatched is not None:
            logging.error(
                "%s %s does not match. Basic Meta must have the same basic content.",
                meta_name,
                mismatched,
            )
            return False
    return True


def _first_mismatched_content(test_meta: dict, golden_meta: dict):
    """
    Return the first basic-content name whose values differ, or None.
    """
    for content_name in META_BASIC_CONTENT_LIST:
        if content_name == META_LENGTH:
            logging.debug("[%s] no need to be the same, skip", META_LENGTH)
            continue
        test_value = test_meta[content_name]
        golden_value = golden_meta[content_name]
        logging.debug(
            "[%s]Test Data tables: %s Golden Data tables: %s.",
            content_name,
            test_value,
            golden_value,
        )
        if test_value != golden_value:
            return content_name
    return None
def parse_key_data_and_cmp(
    test_key_data: list, golden_key_data: list, table_name: str
) -> list:
    """
    Return the keys present in both key collections, as a list.

    Duplicates are collapsed first; the sizes of the intersection and of each
    de-duplicated key set are logged for the given table name.
    """
    unique_test_keys = set(test_key_data)
    unique_golden_keys = set(golden_key_data)
    shared_keys = list(unique_test_keys & unique_golden_keys)
    shared_count = len(shared_keys)
    logging.info(
        "Intersection part for table [%s] Test data: [%s/%s] Golden data: [%s/%s]",
        table_name,
        shared_count,
        len(unique_test_keys),
        shared_count,
        len(unique_golden_keys),
    )
    return shared_keys
def construct_key_embedding_dict(
    key_data: list, meta_name_list: list, table_data: dict
) -> dict:
    """
    Map every key to the list of its rows in each non-key meta.

    For the key at position ``i`` the value is
    ``[table_data[meta][META_DATA][i] for each meta except META_NAME_KEY]``,
    in the order of ``meta_name_list``. The caller's list is not modified.
    Raises ValueError if META_NAME_KEY is not in ``meta_name_list``.
    """
    # Work on a copy so the caller's meta list keeps its key entry.
    value_meta_names = list(meta_name_list)
    value_meta_names.remove(META_NAME_KEY)

    logging.info("Constructing key embedding dict, this may take a few minutes......")
    key_embedding_dict = {}
    for row, key_value in enumerate(key_data):
        key_embedding_dict[key_value] = [
            table_data[meta_name][META_DATA][row] for meta_name in value_meta_names
        ]
    logging.info("Constructing key embedding succeed.")
    return key_embedding_dict
def check_key_emb_equal(
    key_intersection: list, test_key_emb_dict: dict, golden_key_emb_dict: dict
) -> bool:
    """
    Compare the embeddings of every shared key between two models.

    For each key both sides must hold the same number of embeddings, and each
    pair must satisfy np.allclose with rtol=SPARSE_ALLCLOSE_RTOL (atol is left
    at NumPy's default). Returns False at the first mismatch.
    """
    logging.info("Comparing embeddings, this may take a few minutes......")
    for key_value in key_intersection:
        test_rows = test_key_emb_dict[key_value]
        golden_rows = golden_key_emb_dict[key_value]
        test_count, golden_count = len(test_rows), len(golden_rows)
        if test_count != golden_count:
            logging.error(
                "Key %s Embedding shape not equal.Test shape: %s Golden shape: %s",
                key_value,
                test_count,
                golden_count,
            )
            return False

        # Lengths are equal here, so zip pairs every row exactly once.
        for test_row, golden_row in zip(test_rows, golden_rows):
            if np.allclose(test_row, golden_row, rtol=SPARSE_ALLCLOSE_RTOL):
                continue
            logging.error(
                "KEY %s Embedding value not equal.Test Embedding:\n %s\nGolden Embedding:\n %s\n",
                key_value,
                test_row,
                golden_row,
            )
            return False
    return True
def parse_table_info(tabel_path: str) -> Tuple[str, dict]:
    """
    Parse every meta directory found directly under one table directory.

    Returns ``(table_name, metas)`` where ``table_name`` is the base name of
    ``tabel_path`` and ``metas`` maps each meta name to the dict produced by
    parse_single_meta_data.
    """
    table_meta_dict = {}
    for meta_path in list_model_dirs(tabel_path):
        logging.debug("Meta path is %s", meta_path)
        meta_name, meta_dict = parse_single_meta_data(meta_path)
        logging.debug("%s, %s", meta_name, meta_dict.keys())
        table_meta_dict[meta_name] = meta_dict
    return os.path.basename(tabel_path), table_meta_dict
def parse_single_meta_data(meta_path: str) -> Tuple[str, dict]:
    """
    Parse the attribute and data files of a single meta (like key or emb).

    ``<meta_path>/slice.attribute`` is decoded as a flat int64 array whose
    first two values are taken as the length and the embedding dimension.
    ``<meta_path>/slice.data`` is then loaded as float32 and reshaped to
    ``(length, emb_dim)`` when the attribute array has exactly three values,
    and as a flat int64 array otherwise.

    Returns:
        ``(meta_name, meta_dict)`` where ``meta_name`` is the base name of
        ``meta_path`` and ``meta_dict`` holds META_LENGTH, META_EMB_DIM,
        META_DTYPE and META_DATA.

    Raises:
        RuntimeError: the attribute bytes cannot be decoded as int64 values,
            or fewer than two values are present.
    """
    meta_name = os.path.basename(meta_path)
    meta_attribute_path = os.path.join(meta_path, "slice.attribute")
    meta_data_path = os.path.join(meta_path, "slice.data")
    logging.debug(
        "Meta name: %s. Meta attribute path: %s. Meta data path: %s.",
        meta_name,
        meta_attribute_path,
        meta_data_path,
    )

    with tf.io.gfile.GFile(meta_attribute_path, "rb") as fin:
        meta_attributes_file = fin.read()

    try:
        # np.fromstring on binary input is deprecated; np.frombuffer is the
        # documented replacement. The resulting view is read-only, which is
        # sufficient because the attributes are only read below.
        meta_attributes = np.frombuffer(meta_attributes_file, dtype=np.int64)
    except ValueError as err:
        raise RuntimeError(
            f"get attributes from file {meta_attribute_path} failed."
        ) from err

    # Fail with the same error type as a decode failure instead of a bare
    # IndexError when the attribute file is truncated.
    if len(meta_attributes) < 2:
        raise RuntimeError(
            f"attribute file {meta_attribute_path} must hold at least 2 int64 "
            f"values, but {len(meta_attributes)} found."
        )

    meta_length = meta_attributes[0]
    meta_emb_dim = meta_attributes[1]

    # NOTE(review): the data file is read with np.fromfile (local paths only)
    # while the attribute file goes through tf.io.gfile -- confirm whether
    # non-local storage needs to be supported here.
    if len(meta_attributes) == 3:
        meta_dtype = np.float32
        meta_data = np.fromfile(meta_data_path, meta_dtype)
        meta_data = meta_data.reshape(meta_length, meta_emb_dim)
    else:
        meta_dtype = np.int64
        meta_data = np.fromfile(meta_data_path, meta_dtype)

    meta_data_dict = {
        META_LENGTH: meta_length,
        META_EMB_DIM: meta_emb_dim,
        META_DTYPE: meta_dtype,
        META_DATA: meta_data,
    }
    return meta_name, meta_data_dict
def list_model_dirs(directory: str) -> list:
    """
    Return the paths of the immediate sub-directories of ``directory``.

    Only the top level is inspected; nested directories are neither listed
    nor traversed. Each returned path is ``directory`` joined with the
    sub-directory name. A directory that does not exist (or cannot be listed)
    yields an empty list.

    Raises:
        ValueError: ``directory`` directly contains any non-directory entry.
    """
    # os.walk is top-down, so its first item describes `directory` itself.
    # Taking only that item avoids walking the whole subtree just to discard it.
    top_level = next(os.walk(directory), None)
    if top_level is None:
        return []

    dirpath, dirs, files = top_level
    if files:
        raise ValueError(
            f"find unexpected files{files}, saved model may have been tampered, please check."
        )
    return [os.path.join(dirpath, cur_dir) for cur_dir in dirs]