from typing import Tuple, Optional, List
import tensorflow as tf
import mxrec
from config import Config
def get_dense_and_sparse_optimizer(
    cfg: Config,
) -> Tuple["DenseLossScaleOptimizer", "SparseLossScaleOptimizer"]:
    """Build the loss-scaled dense and sparse optimizers described by ``cfg``.

    Each underlying optimizer is wrapped so that the loss is multiplied by
    ``cfg.loss_scale`` before differentiation, and the optimizer's learning
    rate is divided by the same factor when the wrapper is constructed.

    Args:
        cfg: Training configuration. Reads ``deterministic``,
            ``learning_rate`` (index 0 for the dense optimizer, index 1 for
            the sparse optimizer), ``weight_decay`` and ``loss_scale``.

    Returns:
        A ``(dense_optimizer, sparse_optimizer)`` pair. These are the
        loss-scaling wrapper objects, not ``tf.compat.v1.train.Optimizer``
        instances.
    """
    dense_lr = cfg.learning_rate[0]
    sparse_lr = cfg.learning_rate[1]

    if cfg.deterministic:
        dense_optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=dense_lr)
        sparse_optimizer = mxrec.AdamWOptimizer(
            weight_decay=cfg.weight_decay, learning_rate=sparse_lr)
    else:
        dense_optimizer = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=dense_lr)
        # Only this branch overrides epsilon; the deterministic branch keeps
        # whatever default mxrec.AdamWOptimizer uses.
        sparse_optimizer = mxrec.AdamWOptimizer(
            weight_decay=cfg.weight_decay, learning_rate=sparse_lr, epsilon=1e-6)

    # Wrapping mutates each optimizer's learning rate in place (divided by
    # cfg.loss_scale), so every optimizer must be wrapped exactly once.
    return (
        DenseLossScaleOptimizer(dense_optimizer, cfg.loss_scale),
        SparseLossScaleOptimizer(sparse_optimizer, cfg.loss_scale),
    )
class SparseLossScaleOptimizer:
    """Optimizer wrapper that scales the loss before computing sparse gradients.

    The loss is multiplied by ``loss_scale`` before differentiation. At
    construction time the wrapped optimizer's learning rate is divided by
    ``loss_scale`` via ``_scale_learning_rate``. The gradients themselves are
    never un-scaled.
    """

    def __init__(self, opt: tf.compat.v1.train.Optimizer, loss_scale: float):
        """Wrap ``opt`` with loss scaling.

        Args:
            opt: Optimizer that applies the gradients. Its learning-rate
                attribute is modified in place.
            loss_scale: Non-zero factor the loss is multiplied by.

        Raises:
            ValueError: If ``opt`` is not a ``tf.compat.v1.train.Optimizer``,
                if ``loss_scale`` is zero, or if ``opt`` has neither a
                ``_learning_rate`` nor an ``_lr`` attribute.
        """
        if not isinstance(opt, tf.compat.v1.train.Optimizer):
            raise ValueError("`opt` must be an instance of Optimizer, but got: %s" % type(opt))
        self._optimizer = opt
        self._loss_scale = tf.convert_to_tensor(loss_scale, tf.float32)
        _scale_learning_rate(self._optimizer, loss_scale)

    def compute_gradients(self, loss: tf.Tensor,
                          var_list: Optional[List[tf.Variable]] = None) -> List[tf.Tensor]:
        """Differentiate ``loss * loss_scale`` with respect to ``var_list``.

        Args:
            loss: Scalar loss tensor.
            var_list: Tensors or variables to differentiate with respect to.
                This is required: unlike ``Optimizer.compute_gradients``,
                ``tf.gradients`` has no default set of variables.

        Returns:
            The result of ``tf.gradients``: one gradient per ``var_list``
            entry, not ``(gradient, variable)`` pairs.

        Raises:
            ValueError: If ``var_list`` is None.
        """
        if var_list is None:
            # tf.gradients cannot convert None to a tensor and would otherwise
            # fail with an unrelated-looking conversion error.
            raise ValueError("`var_list` must be provided to compute sparse gradients")
        return tf.gradients(loss * self._loss_scale, var_list)

    def apply_gradients(self, grads_and_vars: List[Tuple[tf.Tensor, tf.Variable]]) -> tf.Operation:
        """Forward ``grads_and_vars`` unchanged to the wrapped optimizer.

        The loss scale is compensated by the learning rate that was reduced
        in ``__init__``, not by un-scaling the gradients here.
        """
        return self._optimizer.apply_gradients(grads_and_vars)
class DenseLossScaleOptimizer:
    """Optimizer wrapper that scales the loss before computing dense gradients.

    The loss is multiplied by ``loss_scale`` before differentiation. At
    construction time the wrapped optimizer's learning rate is divided by
    ``loss_scale`` via ``_scale_learning_rate``. The gradients themselves are
    never un-scaled.
    """

    def __init__(self, opt: tf.compat.v1.train.Optimizer, loss_scale: float):
        """Wrap ``opt`` with loss scaling.

        Args:
            opt: Optimizer that computes and applies the gradients. Its
                learning-rate attribute is modified in place.
            loss_scale: Non-zero factor the loss is multiplied by.

        Raises:
            ValueError: If ``opt`` is not a ``tf.compat.v1.train.Optimizer``,
                if ``loss_scale`` is zero, or if ``opt`` has neither a
                ``_learning_rate`` nor an ``_lr`` attribute.
        """
        if not isinstance(opt, tf.compat.v1.train.Optimizer):
            raise ValueError("`opt` must be an instance of Optimizer, but got: %s" % type(opt))
        self._optimizer = opt
        self._loss_scale = tf.convert_to_tensor(loss_scale, tf.float32)
        _scale_learning_rate(self._optimizer, loss_scale)

    def compute_gradients(
        self, loss: tf.Tensor, var_list: Optional[List[tf.Variable]] = None
    ) -> List[Tuple[Optional[tf.Tensor], tf.Variable]]:
        """Compute gradients of ``loss * loss_scale`` using the wrapped optimizer.

        Args:
            loss: Scalar loss tensor.
            var_list: Variables to differentiate with respect to. ``None``
                defers to the wrapped optimizer's default variable set.

        Returns:
            The wrapped optimizer's ``(gradient, variable)`` pairs. A gradient
            may be ``None`` for a variable the loss does not depend on.
        """
        return self._optimizer.compute_gradients(loss * self._loss_scale, var_list=var_list)

    def apply_gradients(self, avg_grads: List[Tuple[tf.Tensor, tf.Variable]]) -> tf.Operation:
        """Forward ``avg_grads`` unchanged to the wrapped optimizer.

        The loss scale is compensated by the learning rate that was reduced
        in ``__init__``, not by un-scaling the gradients here.
        """
        return self._optimizer.apply_gradients(avg_grads)
def _scale_learning_rate(opt: tf.compat.v1.train.Optimizer, loss_scale: float):
    """Divide ``opt``'s learning rate by ``loss_scale`` in place.

    The learning rate is looked up as the private attribute ``_learning_rate``
    first, then ``_lr``. A plain value or tensor is replaced by
    ``value / loss_scale``. A zero-argument callable (which TF1 optimizers
    resolve lazily) is replaced by a callable that divides its result.

    NOTE(review): for SGD this exactly undoes the loss scaling. For
    Adam-style optimizers the update is largely insensitive to the gradient
    scale, so this presumably shrinks the effective step by ``loss_scale``.
    Confirm this is intended.

    Args:
        opt: Optimizer whose learning-rate attribute is overwritten.
        loss_scale: Non-zero divisor.

    Raises:
        ValueError: If ``loss_scale`` is zero, or if ``opt`` has neither
            attribute.
    """
    if loss_scale == 0:
        raise ValueError("loss scale can not be zero")

    # Prefer `_learning_rate` when both attributes exist, as before.
    for attr_name in ("_learning_rate", "_lr"):
        if hasattr(opt, attr_name):
            break
    else:
        raise ValueError("`opt` should have a `_learning_rate` or `_lr` named field")

    base_lr = getattr(opt, attr_name)
    if callable(base_lr):
        # Dividing the callable itself raises TypeError. Keep it lazy and
        # scale the resolved value in whatever graph/context it is called.
        def scaled_lr(base_lr=base_lr):
            return base_lr() / tf.convert_to_tensor(loss_scale, tf.float32)

        setattr(opt, attr_name, scaled_lr)
    else:
        setattr(opt, attr_name, base_lr / tf.convert_to_tensor(loss_scale, tf.float32))