from typing import List
from mindspeed.auto_settings.utils.logger import get_logger
from mindspeed.auto_settings.module.operator.operator_database import OperatorHistory
try:
from waas_sdk.waas_client import WaasClient
from waas_sdk.api.krb_options import KrbOptions
from waas_sdk.api.tls_options import TlsOptions
from waas_sdk.api.data_options import DataOptions
except ImportError:
WaasClient = None
KrbOptions = None
TlsOptions = None
DataOptions = None
class WaasDataBase(object):
def __init__(self, ip_address: str, ip_port: int):
self.ip_address = ip_address
self.ip_port = ip_port
self.waas_client = WaasClient()
self.krb_options = KrbOptions()
self.tls_options = TlsOptions()
self.data_option = DataOptions()
self.krb_options.set_enable(False)
self.tls_options.set_enable(False)
self.waas_client.set_krb(self.krb_options)
self.waas_client.set_tls(self.tls_options)
self.connection = True
try:
self.waas_client.connect(ip_address, ip_port, "AutoTuning")
self.data_option.set_request_timeout(60)
self.data_client = self.waas_client.get_kv_data_client(self.data_option)
except Exception as e:
self.connection = False
self.keys = []
self.values = []
self.attributes_set = []
self.attributes_exclusive_set = []
self.key_prefix = ""
self._logger = get_logger('WaasDataBase')
def insert_data(self, data_key: List, data_value: List, batch_size=100):
if len(data_key) != len(data_value):
raise ValueError("The length of data_key and data_value must be the same.")
self.keys = data_key
self.values = data_value
total_items = len(data_key)
for index in range(0, total_items, batch_size):
end_index = min(index + batch_size, total_items)
key_batch = self.keys[index:index + end_index]
value_batch = self.values[index:index + end_index]
self.update_data(key_batch, value_batch)
def update_data(self, key, value):
batch_length = len(key)
for index in range(batch_length):
exist_value = self.get_data(key[index])
update_key, update_value = key[index], value[index]
if exist_value:
exist_operator = self.unmerge_get_attributes(key[index], exist_value)
new_operator = self.unmerge_get_attributes(key[index], value[index])
duration = (float(exist_operator['duration']) + float(new_operator['duration'])) / 2
new_operator['duration'] = str(duration)
update_key, update_value = self.merge_insert_attributes_dict(new_operator)
self.data_client.put(update_key, update_value)
def get_data(self, key):
temp_value = self.data_client.get(key)
return temp_value
def get_all_data(self, keys):
self.key_prefix = keys
temp_key, temp_value = [], []
self.data_client.get_all(self.key_prefix, temp_key, temp_value)
return temp_key, temp_value
def delete_data(self, data_key: List):
for item in data_key:
self.data_client.delete_all(item)
def convert_level_db_format(self, operators):
insert_list = {'key': [], 'value': []}
for operator in operators:
insert_key, insert_value = self.merge_insert_attributes(operator)
insert_list['key'].append(insert_key)
insert_list['value'].append(insert_value)
return insert_list
def merge_insert_attributes(self, operator):
selected_values = []
remaining_values = []
for attr in self.attributes_set:
try:
value = getattr(operator, attr)
selected_values.append(str(value))
except AttributeError:
self._logger.warning(f"{attr} is not in operator object")
selected_values.append("")
for attr in self.attributes_exclusive_set:
value = getattr(operator, attr)
remaining_values.append(str(value))
separator = '-'
key = separator.join(selected_values)
value = separator.join(remaining_values)
return key, value
def merge_insert_attributes_dict(self, operator):
selected_values = []
remaining_values = []
for attr in self.attributes_set:
value = operator.get(attr, "")
selected_values.append(str(value))
for attr in self.attributes_exclusive_set:
value = operator.get(attr, "")
remaining_values.append(str(value))
separator = '-'
key = separator.join(selected_values)
value = separator.join(remaining_values)
return key, value
@staticmethod
def merge_operator_cal(operator, input_shape=None, output_shape=None):
class_name = type(operator).__name__
if class_name == 'DictShape':
name = operator.type
if not name:
name = operator.types
else:
name = operator.types
accelerator_core = operator.accelerator_core
search_key = [accelerator_core, name]
if input_shape:
search_key.append(input_shape)
if output_shape:
search_key.append(output_shape)
separator = '-'
key = separator.join(search_key)
return key
def unmerge_get_attributes(self, key, value):
separator = '-'
selected_values = key.split(separator)
key_attr_values = dict(zip(self.attributes_set, selected_values))
remaining_values = value.split(separator)
value_attr_values = dict(zip(self.attributes_exclusive_set, remaining_values))
attr_values = {**key_attr_values, **value_attr_values}
return attr_values
def restore_all_data(self, operator):
keys = self.merge_operator_cal(operator)
key_list, value_list = self.get_all_data(keys)
key_length = len(key_list)
operator_list = []
for index in range(key_length):
dict_operator = self.unmerge_get_attributes(key_list[index], value_list[index])
operators = self.restore_attributes_to_operator(
OperatorHistory(
types='',
accelerator_core='',
input_shape='',
output_shape='',
duration=0,
device='',
jit='',
cann='',
driver='',
dtype=''
),
dict_operator
)
operator_list.append(operators)
return operator_list
@staticmethod
def restore_attributes_to_operator(operator, attr_values):
for attr, value in attr_values.items():
if attr == 'duration':
value = float(value)
setattr(operator, attr, value)
return operator
def attribute_separator(self, operator, attributes_list=None):
if attributes_list is None:
attributes_list = ['accelerator_core', 'types', 'input_shape']
total_attributes = list(vars(operator).keys())
attributes_set = list(attributes_list)
attributes_exclusive_set = [attr for attr in total_attributes if attr not in attributes_set]
if '_sa_instance_state' in attributes_exclusive_set:
attributes_exclusive_set.remove('_sa_instance_state')
self.attributes_set = attributes_set
self.attributes_exclusive_set = attributes_exclusive_set