
I understand why tf.abs is non-differentiable in principle (the discontinuity of its derivative at 0), but the same applies to tf.nn.relu, and yet that function's gradient is simply set to 0 at 0. Why is the same logic not applied to tf.abs? Whenever I tried to use it in my custom loss implementation, TF threw errors about missing gradients.

zedsdead

2 Answers


By convention, the $\mathrm{ReLU}$ activation is treated as if it were differentiable at zero (e.g. in [1]), so it makes sense for TensorFlow to adopt this convention for tf.nn.relu. As you've found, the gradient of the absolute value function is not treated as zero in the same situation; it makes sense for this trick to be an explicit choice, because it might not be what the code author intends in general.

In a way this is compatible with the Python philosophy that explicit is better than implicit. If you mean to use $\mathrm{ReLU}$, it's probably best to use tf.nn.relu if it is suitable for your use case.
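
If you want to see what your TensorFlow build actually reports for each op at zero, here is a minimal sketch with tf.GradientTape (assuming a TF 2.x eager setup; the printed values are simply whatever gradients the ops register):

import tensorflow as tf

x = tf.constant(0.0)
with tf.GradientTape(persistent=True) as tape:
  tape.watch(x)
  r = tf.nn.relu(x)
  a = tf.abs(x)
# Inspect the (sub)gradients each op reports at x = 0.
print(tape.gradient(r, x))
print(tape.gradient(a, x))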

[1] Vinod Nair and Geoffrey Hinton. Rectified Linear Units Improve Restricted Boltzmann Machines. ICML'10 (2010).

htl
  • Thanks a lot, that makes sense. Do you know, then, what the easiest way is to convert tf.abs into a differentiable operation with the gradient at 0 defined as 0? – zedsdead Feb 17 '21 at 22:25
  • @zedsdead Could you just substitute in `tf.nn.relu` in the place you're using `tf.abs`? – htl Feb 18 '21 at 09:45
  • How about `relu(x) + relu(-x)`? :) – NikoNyrh Feb 20 '21 at 09:47
  • @htl Well, they are not the same thing, but NikoNyrh's suggestion is probably the easiest solution (sketched below). – zedsdead Feb 23 '21 at 16:36
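
A minimal sketch of the workaround from the comments, assuming a TF 2.x eager setup (the helper name abs_via_relu is just illustrative). Because tf.nn.relu's gradient at 0 is taken to be 0, the reconstructed absolute value also gets a 0 gradient at 0:

import tensorflow as tf

def abs_via_relu(x):
  # |x| = relu(x) + relu(-x)
  return tf.nn.relu(x) + tf.nn.relu(-x)

x = tf.constant([-2.0, 0.0, 3.0])
with tf.GradientTape() as tape:
  tape.watch(x)
  y = abs_via_relu(x)
print(tape.gradient(y, x))  # expected: -1, 0 and 1 for these inputs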

Creating a custom gradient for tf.abs may solve the problem:

@tf.custom_gradient
def abs_with_grad(x):
  y = tf.abs(x)

  def grad(div):  # div is the incoming (upstream) gradient
    g = 1  # use 1 so the chain rule simply passes through abs
    return div * g

  return y, grad

Use 1 as above to pass straight through abs, or use the actual gradient of abs (suggested by Samual K):

g = tf.where(x < 0, -1.0, 1.0)  # now the gradient at 0 is 1, so you don't get dead weights
# Optionally:
g = tf.where(x == 0, 0.0, g)  # add this if you really want the gradient at 0 to be 0
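
Putting the two pieces together, here is a sketch of a variant whose gradient is defined as 0 at 0, which is what the question asks for (the name abs_with_zero_grad is illustrative, not a TF API):

import tensorflow as tf

@tf.custom_gradient
def abs_with_zero_grad(x):
  y = tf.abs(x)

  def grad(upstream):
    # sign-like gradient: -1 for x < 0, +1 for x > 0, 0 at x == 0
    g = tf.where(x < 0, -tf.ones_like(x), tf.ones_like(x))
    g = tf.where(tf.equal(x, 0), tf.zeros_like(x), g)
    return upstream * g

  return y, grad

# Quick check:
x = tf.constant([-2.0, 0.0, 3.0])
with tf.GradientTape() as tape:
  tape.watch(x)
  y = abs_with_zero_grad(x)
print(tape.gradient(y, x))  # expected: -1, 0 and 1 for these inputs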
Dee