import hashlib
import json
import os
import pickle
import traceback
import inspect
from functools import wraps
import numpy as np
import torch
def serialize_param(param):
    """Serialize one call argument into a filename-safe cache-key fragment.

    Leaves are fingerprinted with ``hash_object`` and suffixed with ``"_"``.
    Lists and tuples are flattened into the concatenation of their items'
    fragments. Dicts contribute, for every entry, the key's fragment followed
    by the value's fragment, so dicts that differ only in their keys no longer
    share a cache entry. The entries are ordered by their serialized text,
    which makes the result independent of insertion order and also works for
    dicts whose keys have mixed, mutually unorderable types.

    NOTE(review): flattening means ``[[1, 2], [3]]`` and ``[[1], [2, 3]]``
    (and adjacent list arguments split at a different point) serialize
    identically. Changing that would rename every existing cache entry that
    takes a list argument, so it is left as is.
    """
    if isinstance(param, (list, tuple)):
        return ''.join(serialize_param(item) for item in param)
    if isinstance(param, dict):
        entries = sorted(
            serialize_param(key) + serialize_param(value)
            for key, value in param.items()
        )
        return ''.join(entries)
    return hash_object(param) + "_"
def hash_object(obj):
    """Return a short fingerprint of *obj* for use in cache file names.

    The object is pickled and the first 10 hex characters of the SHA-256
    digest of the pickled bytes are returned. If pickling fails for any
    reason, an empty string is returned instead. This is best effort: such
    objects then cannot be told apart by their fingerprint.
    """
    try:
        payload = pickle.dumps(obj)
    except Exception:
        return ""
    return hashlib.sha256(payload).hexdigest()[:10]
def _save_single_result(result, save_path, file_stem):
    """Write one value under ``save_path`` and return the file name used.

    The extension records the format so that ``load_data`` can choose the
    reader: ``.npy`` for ``np.ndarray`` (via ``np.save``), ``.pth`` for
    ``torch.Tensor`` (via ``torch.save``), and ``.json`` for int/float/list/
    tuple values, stored as ``{"result": value}``.

    Raises:
        ValueError: *result* is none of the supported types.
    """
    if isinstance(result, np.ndarray):
        file_name = file_stem + ".npy"
        np.save(os.path.join(save_path, file_name), result)
    elif isinstance(result, torch.Tensor):
        file_name = file_stem + ".pth"
        torch.save(result, os.path.join(save_path, file_name))
    elif isinstance(result, (int, float, list, tuple)):
        file_name = file_stem + ".json"
        with open(os.path.join(save_path, file_name), 'w') as json_file:
            json.dump({"result": result}, json_file)
    else:
        raise ValueError(f"Save cache data failed, return data type should be np.ndarray, torch.Tensor, int or float, but got {type(result)}")
    return file_name


def save_data(data, save_path, case_name):
    """Persist a golden function's return value and list the files written.

    A tuple/list return value is saved element by element as
    ``<case_name><index><ext>``. Any other value is saved as
    ``<case_name><ext>``. The extension depends on each element's type (see
    ``_save_single_result``).

    Args:
        data: The value returned by the golden function.
        save_path: Existing directory in which the files are created.
        case_name: File-name stem identifying the call.

    Returns:
        The written file names (relative to *save_path*) in element order.
        The list is empty for an empty tuple/list.

    Raises:
        ValueError: An element (or *data* itself) has an unsupported type.
            Elements before it have already been written at that point.
    """
    if isinstance(data, (tuple, list)):
        return [
            _save_single_result(item, save_path, case_name + str(index))
            for index, item in enumerate(data)
        ]
    return [_save_single_result(data, save_path, case_name)]
def load_data(save_path, file_names):
    """Read back result files produced by ``save_data``.

    The reader is chosen from the extension: ``np.load`` for ``.npy``,
    ``torch.load`` for ``.pth``, and the ``"result"`` entry of the JSON
    object for ``.json``.

    NOTE: ``torch.load`` is pickle-based; only point this at trusted cache
    directories.

    Args:
        save_path: Directory containing the files.
        file_names: File names relative to *save_path*.

    Returns:
        A list with one loaded object per entry of *file_names*, in order.

    Raises:
        FileNotFoundError: A listed file does not exist.
        ValueError: A file has an extension ``save_data`` never writes. It is
            rejected rather than skipped, because skipping would shift the
            positions of the remaining results.
    """
    results = []
    for file_name in file_names:
        file_path = os.path.join(save_path, file_name)
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Load cache data failed. {file_path} not exists.")
        if file_name.endswith(".npy"):
            results.append(np.load(file_path))
        elif file_name.endswith(".pth"):
            results.append(torch.load(file_path))
        elif file_name.endswith(".json"):
            with open(file_path, 'r') as json_file:
                results.append(json.load(json_file)["result"])
        else:
            raise ValueError(f"Load cache data failed, unsupported cache file type: {file_path}")
    return results
# Suffix of the per-case index file that lists the data files of one call.
# It must differ from "<case_name>.json": save_data writes a single JSON
# result to exactly that name, and an index there would overwrite the data.
_CACHE_INDEX_SUFFIX = "_index.json"
# Extensions written by save_data. A lone file named "<case_name><ext>" means
# the golden function returned one value rather than a sequence.
_CACHE_DATA_EXTENSIONS = (".npy", ".pth", ".json")


def _build_case_name(signature, args, kwargs):
    """Return the cache-key string for one call of the decorated function."""
    if signature is not None:
        try:
            bound = signature.bind(*args, **kwargs)
        except TypeError:
            # The call does not match the signature, and the function will
            # raise by itself; key on the raw arguments instead.
            pass
        else:
            # Every argument that can be positional is folded into ``args``,
            # so a parameter passed by keyword keys exactly like the same
            # value passed positionally.
            args, kwargs = bound.args, bound.kwargs
    key = ''.join(serialize_param(arg) for arg in args)
    # Keyword-only / **kwargs values keep their names in the key, so that
    # f(x, a=1) and f(x, b=1) do not share a cache entry.
    key += ''.join(
        serialize_param(name) + serialize_param(kwargs[name])
        for name in sorted(kwargs)
    )
    return key


def _resolve_cache_dir(save_path, ut_name, func_name):
    """Return the directory holding the cache files of one decorated function."""
    root = save_path
    if root is None:
        # The MXDRIVING_CACHE_PATH environment variable overrides the default,
        # which is the directory containing this module.
        root = os.getenv('MXDRIVING_CACHE_PATH', None)
        if root is None:
            root = os.path.dirname(os.path.abspath(__file__))
    ut_stem = os.path.splitext(os.path.basename(ut_name))[0]
    return os.path.join(root, "data_cache", ut_stem, func_name)


def _read_cache_index(cache_dir, case_name):
    """Return the data file names recorded for *case_name*, or None on a miss."""
    # The second candidate is the index location used before the suffix
    # change. It is accepted only if it really holds an entry for this case.
    for index_name in (case_name + _CACHE_INDEX_SUFFIX, case_name + ".json"):
        index_path = os.path.join(cache_dir, index_name)
        if not os.path.exists(index_path):
            continue
        try:
            with open(index_path, 'r') as index_file:
                index = json.load(index_file)
        except (OSError, ValueError):
            continue
        file_names = index.get(case_name) if isinstance(index, dict) else None
        # An index that lists itself comes from the old index/data name
        # collision: the data it points at was overwritten, so treat it as a miss.
        if isinstance(file_names, list) and index_name not in file_names:
            return file_names
    return None


def _is_single_result(file_names, case_name):
    """Tell whether *file_names* describes one non-sequence return value."""
    single_names = {case_name + ext for ext in _CACHE_DATA_EXTENSIONS}
    return len(file_names) == 1 and file_names[0] in single_names


def golden_data_cache(ut_name, save_path=None, refresh_data=False):
    """Decorator that caches a golden function's return value on disk.

    Files live in ``<root>/data_cache/<ut_name stem>/<function name>/``, where
    ``<root>`` is *save_path*. If *save_path* is None, ``<root>`` is the
    ``MXDRIVING_CACHE_PATH`` environment variable or, when that is unset,
    this module's directory. Each call is keyed by a string built from its
    arguments with ``serialize_param``. The data files are written by
    ``save_data``, and ``<key>_index.json`` records their names.

    On a hit, the data is read with ``load_data``. A single cached value is
    returned as-is; a cached sequence comes back as a tuple. A list return
    value therefore round-trips as a tuple, and JSON-stored tuples come back
    as lists. If loading fails, the function is run instead, without
    rewriting the cache. On a miss (or with *refresh_data*), the function is
    run and its result is saved. Failures while saving are printed and
    otherwise ignored.

    Args:
        ut_name: Name or path of the calling unit-test file. Only its base
            name without extension is used.
        save_path: Root directory of the cache (see above).
        refresh_data: Always run the function and overwrite the cache.
    """
    def decorator(func):
        try:
            signature = inspect.signature(func)
        except (TypeError, ValueError):
            signature = None

        @wraps(func)
        def wrapper(*args, **kwargs):
            case_name = _build_case_name(signature, args, kwargs)
            cache_dir = _resolve_cache_dir(save_path, ut_name, func.__name__)
            file_names = None if refresh_data else _read_cache_index(cache_dir, case_name)

            if file_names is None:
                os.makedirs(cache_dir, exist_ok=True)
                results = func(*args, **kwargs)
                try:
                    saved_names = save_data(results, cache_dir, case_name)
                    if saved_names:
                        index_path = os.path.join(cache_dir, case_name + _CACHE_INDEX_SUFFIX)
                        with open(index_path, 'w') as index_file:
                            json.dump({case_name: saved_names}, index_file)
                    print(f"Cache data saved in {cache_dir}.")
                except Exception:
                    print("Failed to save cache.")
                    traceback.print_exc()
                return results

            try:
                loaded = load_data(cache_dir, file_names)
            except Exception:
                results = func(*args, **kwargs)
                print("Failed to load cache, using golden function to generate data.")
                traceback.print_exc()
                return results
            print(f"Load cache data from {cache_dir}.")
            if _is_single_result(file_names, case_name):
                return loaded[0]
            return tuple(loaded)

        return wrapper

    return decorator