I'm trying to understand this JAX tutorial.
Here's an excerpt from a neural net designed to classify MNIST images:
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def relu(x):
    return jnp.maximum(0, x)

def predict(params, image):
    # per-example predictions
    activations = image
    for w, b in params[:-1]:
        outputs = jnp.dot(w, activations) + b
        activations = relu(outputs)

    final_w, final_b = params[-1]
    logits = jnp.dot(final_w, activations) + final_b
    return logits - logsumexp(logits)
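If it helps, this is how I've been exercising predict, with a hypothetical random initialization (the layer sizes and init scale are my own, not from the tutorial):

import jax
import jax.numpy as jnp

# Hypothetical layer sizes: flattened 28x28 MNIST input -> two hidden layers -> 10 classes.
layer_sizes = [784, 512, 256, 10]

def init_params(sizes, key):
    # Build a list of (w, b) pairs, one per layer, matching what predict() expects.
    params = []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        key, w_key, b_key = jax.random.split(key, 3)
        w = 0.01 * jax.random.normal(w_key, (n_out, n_in))
        b = 0.01 * jax.random.normal(b_key, (n_out,))
        params.append((w, b))
    return params

params = init_params(layer_sizes, jax.random.PRNGKey(0))
dummy_image = jax.random.normal(jax.random.PRNGKey(1), (784,))
out = predict(params, dummy_image)
print(out.shape)            # (10,)
print(jnp.exp(out).sum())   # ~1.0 after exponentiating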
I don't understand why they would subtract a constant value from all the final predictions, given that the only thing that matters is their relative values.
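From a quick check, the subtraction seems to give the same result as jax.nn.log_softmax (I'm only using that function here for comparison; it isn't in the tutorial excerpt):

import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

logits = jnp.array([2.0, 1.0, -1.0])
manual = logits - logsumexp(logits)    # what predict() returns
builtin = jax.nn.log_softmax(logits)   # JAX's built-in log-softmax
print(jnp.allclose(manual, builtin))   # True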
For clarity, here's the loss function:
def loss(params, images, targets):
    preds = batched_predict(params, images)
    return -jnp.mean(preds * targets)
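batched_predict isn't shown above; as far as I can tell the tutorial defines it as predict vectorized over the batch dimension with vmap, something like:

from jax import vmap

# Map predict over a batch of images while sharing the same params.
batched_predict = vmap(predict, in_axes=(None, 0))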