Accurately Computing the Softmax Function

The softmax function takes as input an n-vector x and returns a vector g(x) with elements

g_j(x) = \displaystyle\frac{\mathrm{e}^{x_j}}{\sum_{i=1}^n \mathrm{e}^{x_i}}, \quad j=1\colon n.

The elements of g are all between 0 and 1 and they sum to 1, so g can be regarded as a vector of probabilities. Softmax is a key function in machine learning algorithms.

Softmax is the gradient vector of the log-sum-exp function

f(x) = \displaystyle\log \sum_{i=1}^n \mathrm{e}^{x_i}.

This function approximates the largest element, x_{\max} = \max_i x_i, of the vector x, since it lies between x_{\max} and x_{\max} + \log n.
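The bounds follow in one line from the fact that every term in the sum is at most \mathrm{e}^{x_{\max}} and that one term equals it. A short derivation, in the notation used above:

```latex
\mathrm{e}^{x_{\max}} \le \sum_{i=1}^n \mathrm{e}^{x_i} \le n\,\mathrm{e}^{x_{\max}}
\quad\Longrightarrow\quad
x_{\max} \le f(x) \le x_{\max} + \log n.
```

The upper bound is attained when all the x_i are equal.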

A problem with numerical evaluation of log-sum-exp and softmax is that overflow is likely even for quite modest values of x_i because of the exponentials, even though g(x) cannot overflow and f(x) is very unlikely to do so.
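To make the overflow concrete, here is a minimal sketch in Python using only the standard library. The function names are illustrative and not taken from any package. In IEEE double precision \mathrm{e}^{x} overflows once x exceeds about 709.78, and Python's math.exp signals this by raising OverflowError (array libraries typically return inf instead).

```python
import math

def logsumexp_naive(x):
    """Direct evaluation of f(x) = log(sum_i exp(x_i)), with no shift."""
    return math.log(sum(math.exp(xi) for xi in x))

def naive_overflows(x):
    """Return True if the unshifted formula overflows for this input."""
    try:
        logsumexp_naive(x)
    except OverflowError:
        return True
    return False

# exp(710) is outside the double precision range, although
# f(x) itself is only slightly larger than 710.
print(naive_overflows([710.0, 709.0, 700.0]))

# Small inputs cause no trouble.
print(naive_overflows([1.0, 2.0]))
```

In half precision the largest finite number is 65504, so the same failure already occurs for x_i around 12.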

A standard solution is to incorporate a shift, a, and use the formulas

f(x) = a + \displaystyle\log \sum_{i=1}^n \mathrm{e}^{x_i-a}, \hspace*{4.5cm}(1)


g_j(x) = \displaystyle\frac{\mathrm{e}^{x_j-a}}{\sum_{i=1}^n \mathrm{e}^{x_i-a}}, \quad j=1\colon n, \hspace*{3.3cm}(2)

where a is usually set to x_{\max}.
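Formulas (1) and (2) with a = x_{\max} can be sketched as follows. This is plain Python written for illustration, not a library implementation, and the function names are my own. With this choice of a every exponent is nonpositive, so no exponential can overflow, and the largest term in the sum is exactly 1.

```python
import math

def logsumexp(x):
    """Formula (1) with shift a = max(x)."""
    a = max(x)
    return a + math.log(sum(math.exp(xi - a) for xi in x))

def softmax(x):
    """Formula (2) with shift a = max(x)."""
    a = max(x)
    e = [math.exp(xi - a) for xi in x]  # each entry lies in (0, 1]
    s = sum(e)                          # s >= 1, so the division is safe
    return [ei / s for ei in e]

# Inputs for which the unshifted formulas would overflow in double precision.
x = [1000.0, 999.0, 990.0]
print(logsumexp(x))
print(softmax(x))
```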

Another formula for softmax is obtained by moving the denominator into the numerator:

g_j(x) = \exp\left(x_j - a - \log\displaystyle\sum_{i=1}^n\mathrm{e}^{x_i -a}\right). \hspace*{2cm}(3)

This formula is used in various codes, including in the SciPy 1.4.1 function softmax.
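For comparison, a sketch of formula (3) in plain Python. It illustrates the formula only and is not a copy of the SciPy code; the function name is hypothetical.

```python
import math

def softmax_division_free(x):
    """Formula (3) with shift a = max(x): one log, then n exponentials."""
    a = max(x)
    log_s = math.log(sum(math.exp(xi - a) for xi in x))
    return [math.exp(xi - a - log_s) for xi in x]

print(softmax_division_free([1000.0, 999.0, 990.0]))
```

Note that each component now requires a second exponential, whose argument carries the rounding errors of both the subtraction and the computed logarithm.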

How accurate are these formulas when evaluated in floating-point arithmetic? To my knowledge, this question has not been addressed in the literature, but it is particularly important given the growing use of low precision arithmetic in machine learning. Two questions arise. First, is there any difference between the accuracy of the formulas (2) and (3) for g_j(x)? Second, in (1) and (3), a is added to a nonnegative log term, so when a = x_{\max} is negative can there be damaging subtractive cancellation?
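A simple worked case shows how the cancellation in the second question can arise. Take all elements equal, x_i = -\log n, so that a = x_{\max} = -\log n is negative. Then (1) gives

```latex
f(x) = -\log n + \log \sum_{i=1}^n \mathrm{e}^{0} = -\log n + \log n = 0,
```

so two quantities of equal magnitude and opposite sign are added, and any rounding errors in the computed logarithms are exposed in the result.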

In a recent EPrint with Pierre Blanchard and Des Higham I have investigated these questions using rounding error analysis and analysis of the conditioning of the log-sum-exp and softmax problems. In a nutshell, our findings are that while cancellation can happen, it is not a problem: the shifted formulas (1) and (2) can be safely used.

However, the alternative softmax formula (3) is not recommended, as its rounding error bounds are larger than for (2) and we have found it to produce larger errors in practice.

Here is an example from training an artificial neural network using the MATLAB Deep Learning Toolbox. The network is trained to classify handwritten digits from the widely used MNIST data set. The following figure shows the sum of the computed elements of the softmax vector g(x) for 2000 vectors extracted from the training data, where g(x) was computed in IEEE half precision arithmetic. The sum should be 1. The red circles are for formula (2) and the blue crosses are for the division-free formula (3). Clearly, (2) gives a better approximation to a vector of probabilities (in the sense of respecting the constraint that probabilities sum to unity); the actual errors in each vector component are also smaller for (2).
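A comparison of this kind can be tried on a small scale without special hardware. The sketch below is not the MNIST experiment: it uses a made-up input vector and simulates half precision by rounding the result of every operation to the nearest half precision number with the standard library struct format "e". The helper fl16 and the two function names are assumptions introduced for illustration.

```python
import math
import struct

def fl16(v):
    """Round a Python float to the nearest IEEE half precision number."""
    return struct.unpack("<e", struct.pack("<e", v))[0]

def softmax_formula2_fp16(x):
    """Formula (2), a = max(x), every operation rounded to half precision."""
    a = max(x)
    e = [fl16(math.exp(fl16(xi - a))) for xi in x]
    s = 0.0
    for ei in e:
        s = fl16(s + ei)
    return [fl16(ei / s) for ei in e]

def softmax_formula3_fp16(x):
    """Formula (3), a = max(x), every operation rounded to half precision."""
    a = max(x)
    d = [fl16(xi - a) for xi in x]
    s = 0.0
    for di in d:
        s = fl16(s + fl16(math.exp(di)))
    log_s = fl16(math.log(s))
    return [fl16(math.exp(fl16(di - log_s))) for di in d]

# Hypothetical input; every entry is exactly representable in half precision.
x = [2.5, -1.0, 0.75, 4.0, -3.5, 1.25, 3.0, -0.5, 0.0, 2.0]

g2 = softmax_formula2_fp16(x)
g3 = softmax_formula3_fp16(x)

# The sums are accumulated in double precision, so any deviation from 1
# comes from the half precision softmax computations themselves.
print("formula (2): sum =", sum(g2))
print("formula (3): sum =", sum(g3))
```

A single vector proves nothing about which formula is better; repeating the experiment over many vectors is what reveals the trend described above.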