from typing import Union
import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.python.training.optimizer import Optimizer, _TensorProcessor
from rec_sdk_common.validator.safe_checker import class_safe_check, float_safe_check
from mxrec.python.constants.constants import NumCheckValueMethod
def patch_for_update_op():
    """Monkey-patch TensorFlow's ``_TensorProcessor.update_op``.

    After this call, ``_TensorProcessor.update_op(optimizer, g)`` forwards the
    gradient ``g`` and the processor's wrapped tensor (``self._v``) to
    ``optimizer._apply_sparse`` and returns its result. The patch is applied
    to the class, so it affects every ``_TensorProcessor`` instance. Calling
    this function again simply installs an equivalent replacement.
    """

    def _update_op(self, optimizer: Optimizer, g: tf.Tensor):
        # Route tensor-variable updates through the optimizer's sparse-apply
        # path instead of TensorFlow's default implementation.
        # NOTE(review): `_apply_sparse` normally expects IndexedSlices; the
        # `tf.Tensor` annotation on `g` may be loose — confirm with callers.
        return optimizer._apply_sparse(g, self._v)

    setattr(_TensorProcessor, "update_op", _update_op)
def check_optimizer_param_value(
    value: Union[float, tf.Tensor],
    name: str,
    max_value: float,
    min_value: float = 0.0,
    check_value_method: str = NumCheckValueMethod.DEFAULT.value,
):
    """Validate an optimizer hyper-parameter.

    The value must be either a Python ``float`` or a ``tf.Tensor``; this is
    enforced by ``class_safe_check``. Only ``float`` values are then passed
    to ``float_safe_check`` for a range check. Tensor values are accepted
    without any range check.

    Args:
        value: The hyper-parameter value to validate.
        name: Parameter name, forwarded to the checkers (presumably used in
            error messages — confirm against ``rec_sdk_common``).
        max_value: Upper bound passed to ``float_safe_check``.
        min_value: Lower bound passed to ``float_safe_check``.
        check_value_method: Comparison method forwarded to
            ``float_safe_check`` as ``method``.
    """
    class_safe_check(name, value, (float, tf.Tensor))
    if isinstance(value, float):
        # Tensors skip the bounds check, presumably because their value
        # isn't available as a Python number at this point.
        float_safe_check(
            name,
            value,
            min_value=min_value,
            max_value=max_value,
            method=check_value_method,
        )
def deduplicate_indexed_slices(grad: ops.IndexedSlices) -> ops.IndexedSlices:
    """Merge rows of ``grad`` that share the same index.

    Rows whose entries in ``grad.indices`` are equal are summed together.
    The result holds one row per distinct index, ordered by first occurrence
    (the order produced by ``tf.unique``). ``dense_shape`` is carried over
    unchanged.

    Args:
        grad: Sparse gradient whose ``indices`` may contain duplicates.

    Returns:
        A new ``IndexedSlices`` with unique ``indices`` and summed ``values``.
    """
    # `segment_ids[i]` is the position of grad.indices[i] within `unique_ids`.
    unique_ids, segment_ids = tf.unique(grad.indices)
    segment_count = tf.shape(unique_ids)[0]

    summed_values = tf.compat.v1.unsorted_segment_sum(
        data=grad.values,
        segment_ids=segment_ids,
        num_segments=segment_count,
        name="unique_local_grad",
    )

    return ops.IndexedSlices(
        values=summed_values,
        indices=unique_ids,
        dense_shape=grad.dense_shape,
    )