from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import defaultdict
import tensorflow as tf
from tensorflow.python.ops import math_ops
from tensorflow.python.training import gradient_descent
from mx_rec.optimizers.base import CustomizedOptimizer
from mx_rec.util.initialize import ConfigInitializer
from demo_logger import logger
def create_hash_optimizer(learning_rate, weight_decay=0.0001, use_locking=False, name="GradientDescent"):
    """Create an SGD-with-weight-decay optimizer and register it as the global optimizer instance.

    Args:
        learning_rate: Learning rate forwarded to the optimizer.
        weight_decay: Decay coefficient forwarded to the optimizer (``None`` disables decay there).
        use_locking: Forwarded to the optimizer.
        name: Base name forwarded to the optimizer.

    Returns:
        The newly built ``CustomizedGradientDescentWithWeighDecay`` instance, which is also stored on
        ``ConfigInitializer.get_instance().optimizer_config.optimizer_instance``.
    """
    hash_optimizer = CustomizedGradientDescentWithWeighDecay(
        learning_rate=learning_rate,
        weight_decay=weight_decay,
        use_locking=use_locking,
        name=name,
    )
    # Register after construction (same order as before) so later lookups find this instance.
    optimizer_config = ConfigInitializer.get_instance().optimizer_config
    optimizer_config.optimizer_instance = hash_optimizer
    return hash_optimizer
class CustomizedGradientDescentWithWeighDecay(gradient_descent.GradientDescentOptimizer, CustomizedOptimizer):
    """Plain gradient descent with an optional L2 weight-decay term, for sparse updates only.

    For every sparse gradient row ``g`` targeting row ``i`` of ``var`` the update is::

        var[i] -= lr * (g + weight_decay * var[i])      # weight_decay is not None
        var[i] -= lr * g                                # weight_decay is None

    Dense gradients are rejected by ``_apply_dense``.
    """

    # NOTE(review): class-level counter, presumably used by CustomizedOptimizer._get_name to
    # derive ``self.unique_name`` — confirm against the base class.
    name_counter = defaultdict(int)

    def __init__(self, learning_rate, weight_decay, use_locking=False, name="GradientDescent"):
        """Initialize the optimizer.

        Args:
            learning_rate: Learning rate forwarded to ``GradientDescentOptimizer``.
            weight_decay: Decay coefficient applied to the gathered variable rows;
                ``None`` disables the decay term.
            use_locking: Forwarded to the base optimizer and used for the scatter update.
            name: Base name handed to ``_get_name``; ``self.unique_name`` is what is
                actually passed to the base constructor.
        """
        self.optimizer_type = "gradient_descent_with_weight_decay"
        self.weight_decay = weight_decay
        # Must run before the base __init__: it is expected to set self.unique_name, read below.
        super(CustomizedGradientDescentWithWeighDecay, self)._get_name(name=name)
        super(CustomizedGradientDescentWithWeighDecay, self).__init__(
            learning_rate=learning_rate, use_locking=use_locking, name=self.unique_name
        )
        # Plain SGD keeps no per-variable slots (see get_slot_init_values).
        self._slot_num = 0
        # NOTE(review): meaning of _derivative is not visible here; presumably read by mx_rec internals.
        self._derivative = 1

    def get_slot_init_values(self):
        """Return the initial values of this optimizer's slots: none, for plain SGD."""
        logger.info("no slot for gradient descent")
        return []

    def _apply_sparse_duplicate_indices(self, grad, var):
        """Apply one sparse SGD (+ optional weight decay) step without deduplicating indices.

        Args:
            grad: Sparse gradient; its ``indices`` and ``values`` are read.
            var: Variable updated in place through ``tf.scatter_nd_add``.

        Returns:
            The scatter-update op.
        """
        logger.debug(">>>> Enter _apply_sparse_duplicate_indices")
        # Turn row indices into [num_updates, 1] coordinates so each update targets a whole row
        # (assumes grad.indices is 1-D, as for IndexedSlices).
        nd_indices = tf.expand_dims(grad.indices, 1)
        logger.info(f"weight_decay={self.weight_decay}")
        dtype = var.dtype.base_dtype
        lr = math_ops.cast(self._learning_rate_tensor, dtype)
        if self.weight_decay is None:
            nd_value = grad.values * lr
        else:
            # NOTE(review): indices are not deduplicated, so a row that appears k times in
            # grad.indices receives the decay term k times in this step. Deduplicating first
            # (unique + segment sum) would change that; left unchanged pending confirmation.
            decay = math_ops.cast(self.weight_decay, dtype)
            nd_value = (grad.values + decay * tf.gather(var, grad.indices)) * lr
        # Subtract the step by adding its negation; scatter_nd_add accumulates duplicate indices.
        # NOTE(review): only this ref-variable path is overridden; resource variables would use
        # GradientDescentOptimizer's own resource methods, which ignore weight_decay — confirm
        # resource variables never reach this optimizer.
        return tf.scatter_nd_add(var, nd_indices, -nd_value, use_locking=self._use_locking)

    def _apply_dense(self, grad, var):
        """Reject dense gradients; this optimizer only supports sparse updates.

        Raises:
            NotImplementedError: Always.
        """
        logger.debug(">>>> Enter _apply_dense")
        raise NotImplementedError("You are using a wrong type of variable.")