from typing import Any, Tuple, Dict, Union
import tensorflow as tf
from tensorflow.core.protobuf.rewriter_config_pb2 import RewriterConfig
from rec_sdk_common.communication.hccl.hccl_info import get_rank_size, get_rank_id
from lr_scheduler import LearningRateScheduler
class Config:
    """Run configuration built from the ``"model"`` table of a TOML config.

    Besides copying the configured values onto the instance, the constructor
    queries the rank id / rank size through the HCCL helpers, creates a
    non-trainable ``global_step`` variable and stores the learning-rate
    schedule derived from that step in ``learning_rate``.
    """

    def __init__(self, data_path: str, toml_config: Dict[str, Dict[str, Union[str, int, float]]]):
        """Populate the configuration.

        Args:
            data_path: Stored unchanged as ``self.data_path``.
            toml_config: Parsed TOML mapping; every value is read from its
                ``"model"`` table.

        Raises:
            KeyError: If ``"model"`` or one of the expected keys is missing.
        """
        self.data_path = data_path

        # Distributed topology as reported by the HCCL helpers.
        self.rank_id = get_rank_id()
        self.rank_size = get_rank_size()

        self.train_file_pattern = "train"
        self.test_file_pattern = "test"

        model_cfg = toml_config["model"]

        # Values taken from the config as-is.
        self.train_steps = model_cfg["train_steps"]
        self.train_interval = model_cfg["train_interval"]
        self.eval_steps = model_cfg["eval_steps"]
        self.prefetch_num = model_cfg["prefetch_num"]
        self.line_per_sample = model_cfg["line_per_sample"]
        self.loss_scale = model_cfg["loss_scale"]
        self.weight_decay = model_cfg["weight_decay"]
        self.batch_size = model_cfg["batch_size"]
        self.train_epoch = model_cfg["train_epoch"]
        self.test_epoch = model_cfg["test_epoch"]
        # The configured value is scaled by the number of ranks
        # (presumably it is a per-device size -- confirm with the config owner).
        self.dev_vocab_size = model_cfg["dev_vocab_size"] * self.rank_size
        self.emb_dim = model_cfg["emb_dim"]
        self.deterministic = model_cfg["deterministic"]

        # Fixed dtypes.
        self.key_type = tf.int64
        self.value_type = tf.float32

        # The step counter must exist before the schedule is built from it.
        self.global_step = tf.Variable(0, trainable=False)
        self.learning_rate = self._get_lr(toml_config)

    def _get_lr(self, toml_config: Dict[str, Dict[str, Union[str, int, float]]]) -> Tuple[tf.Tensor, tf.Tensor]:
        """Build the learning-rate schedule evaluated at ``self.global_step``.

        The configured ``warmup_steps``, ``decay_start_step`` and
        ``decay_steps`` are each divided by the global batch size
        (``batch_size * rank_size``) and truncated with ``int`` before being
        handed to ``LearningRateScheduler``.

        Args:
            toml_config: Parsed TOML mapping; the ``"model"`` table supplies
                the base learning rates and the warmup/decay settings.

        Returns:
            The result of ``LearningRateScheduler.calc(self.global_step)``
            (annotated as a pair of tensors; presumably the dense and sparse
            learning rates -- confirm against ``lr_scheduler``).
        """
        samples_per_step = self.batch_size * self.rank_size
        model_cfg = toml_config["model"]

        dense_base_lr = model_cfg["base_lr_dense"]
        sparse_base_lr = model_cfg["base_lr_sparse"]
        warmup = int(model_cfg["warmup_steps"] / samples_per_step)
        decay_begin = int(model_cfg["decay_start_step"] / samples_per_step)
        decay_length = int(model_cfg["decay_steps"] / samples_per_step)

        scheduler = LearningRateScheduler(
            dense_base_lr,
            sparse_base_lr,
            warmup,
            decay_begin,
            decay_length,
        )
        return scheduler.calc(self.global_step)
def sess_config(
    dump_data: bool = False, dump_path: str = "./dump_output", dump_steps: str = "0|1|2",
    is_deterministic: bool = False
) -> Any:
    """Create the ``tf.compat.v1.ConfigProto`` used to open a session.

    A custom graph optimizer named ``"NpuOptimizer"`` is registered and its
    parameter map is filled with fixed settings; the ``remapping`` and
    ``memory_optimization`` rewrites are switched off.

    Args:
        dump_data: When true, the dump parameters (``enable_dump``,
            ``dump_path``, ``dump_step``, ``dump_mode``) are set as well.
        dump_path: Value written to the ``dump_path`` parameter.
        dump_steps: Value written to the ``dump_step`` parameter.
        is_deterministic: Selects ``precision_mode``
            ``"must_keep_origin_dtype"`` and sets ``deterministic`` to 1;
            otherwise ``precision_mode`` is ``"allow_mix_precision"``.

    Returns:
        The populated ``ConfigProto``.
    """
    to_bytes = tf.compat.as_bytes
    gib = 1024 * 1024 * 1024

    session_config = tf.compat.v1.ConfigProto(allow_soft_placement=False, log_device_placement=False)
    session_config.gpu_options.allow_growth = True

    rewrite_options = session_config.graph_options.rewrite_options
    npu_optimizer = rewrite_options.custom_optimizers.add()
    npu_optimizer.name = "NpuOptimizer"
    npu_params = npu_optimizer.parameter_map

    npu_params["mix_compile_mode"].b = False
    npu_params["use_off_line"].b = True
    npu_params["min_group_size"].b = True
    npu_params["enable_data_pre_proc"].b = True
    npu_params["iterations_per_loop"].i = 10

    # Deterministic runs keep the original dtypes; otherwise mixed precision is allowed.
    precision_mode = "must_keep_origin_dtype" if is_deterministic else "allow_mix_precision"
    npu_params["precision_mode"].s = to_bytes(precision_mode)
    if is_deterministic:
        npu_params["deterministic"].i = 1

    npu_params["hcom_parallel"].b = False
    npu_params["op_precision_mode"].s = to_bytes("op_precision.ini")
    npu_params["op_execute_timeout"].i = 2000

    # Memory limits are passed as decimal byte counts in string form.
    npu_params["variable_memory_max_size"].s = to_bytes(str(13 * gib))
    npu_params["graph_memory_max_size"].s = to_bytes(str(18 * gib))
    npu_params["stream_max_parallel_num"].s = to_bytes("DNN_VM_AICPU:3,AIcoreEngine:3")

    if dump_data:
        npu_params["enable_dump"].b = True
        npu_params["dump_path"].s = to_bytes(dump_path)
        npu_params["dump_step"].s = to_bytes(dump_steps)
        npu_params["dump_mode"].s = to_bytes("all")

    rewrite_options.remapping = RewriterConfig.OFF
    rewrite_options.memory_optimization = RewriterConfig.OFF
    return session_config