Batch normalization (BatchNorm) operates on the activations of a layer for each mini-batch.

Consider a simple feedforward network, defined by chaining together modules:

x^{(0)} \mapsto x^{(1)} \mapsto x^{(2)} \mapsto \cdots

where each network module can be a linear transform, a nonlinear activation function, a convolution, etc. Here x^{(0)} is the input vector, x^{(1)} is the output vector from the first module, and so on.

BatchNorm is a module that can be inserted at any point in the feedforward network. For example, suppose it is inserted just after x^{(l)}; then the network would operate accordingly:

\cdots \mapsto x^{(l)} \mapsto \mathrm{BN}(x^{(l)}) \mapsto x^{(l+1)} \mapsto \cdots

The BatchNorm module does not operate over individual inputs. Instead, it must operate over one batch of inputs at a time. Concretely, suppose we have a batch of inputs x^{(0)}_{(1)}, x^{(0)}_{(2)}, \dots, x^{(0)}_{(B)}, fed all at once into the network. We would obtain in the middle of the network some vectors:

x^{(l)}_{(1)}, x^{(l)}_{(2)}, \dots, x^{(l)}_{(B)}

The BatchNorm module computes the coordinate-wise mean and variance of these vectors:

\begin{aligned} \mu^{(l)}_i &= \frac 1B \sum_{b=1}^B x^{(l)}_{(b), i} \\ (\sigma^{(l)}_i)^2 &= \frac{1}{B} \sum_{b=1}^B (x_{(b),i}^{(l)} - \mu_i^{(l)})^2 \end{aligned}

where i indexes the coordinates of the vectors, and b indexes the elements of the batch. In other words, we consider the i-th coordinate of each vector in the batch, and compute the mean and variance of these B numbers.

It then normalizes each coordinate to have zero mean and unit variance:

\hat{x}^{(l)}_{(b), i} = \frac{x^{(l)}_{(b), i} - \mu^{(l)}_i}{\sqrt{(\sigma^{(l)}_i)^2 + \epsilon}}

Here, \epsilon is a small positive constant such as 10^{-9}, added to the variance for numerical stability to avoid division by zero.

Finally, it applies a linear transformation:

y^{(l)}_{(b), i} = \gamma_i \hat{x}^{(l)}_{(b), i} + \beta_i

Here, \gamma and \beta are parameters inside the BatchNorm module. They are learnable parameters, typically trained by gradient descent.

The following is a Python implementation of BatchNorm:
import numpy as np

def batchnorm(x, gamma, beta, epsilon=1e-9):
    # Mean and variance of each feature, computed over the batch axis
    mu = np.mean(x, axis=0)   # shape (N,)
    var = np.var(x, axis=0)   # shape (N,)

    # Normalize the activations
    x_hat = (x - mu) / np.sqrt(var + epsilon)  # shape (B, N)

    # Apply the linear transform
    y = gamma * x_hat + beta  # shape (B, N)

    return y
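For example, the function can be exercised as follows (the shapes and random inputs here are purely illustrative):

# Example usage: a batch of B = 4 vectors with N = 3 features each
x = np.random.randn(4, 3)
gamma = np.ones(3)   # initial scale
beta = np.zeros(3)   # initial shift
y = batchnorm(x, gamma, beta)
# Each column of y now has (approximately) zero mean and unit variance
print(y.mean(axis=0), y.var(axis=0))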
=== Interpretation ===

\gamma and \beta allow the network to learn to undo the normalization, if this is beneficial. BatchNorm can be interpreted as removing the purely linear transformations, so that its layers focus solely on modelling the nonlinear aspects of data. This may be beneficial, as a neural network can always be augmented with a linear transformation layer on top. The claim of the original publication that BatchNorm works by reducing internal covariate shift has both supporters and detractors.
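As a concrete check of the "undo" interpretation, here is a minimal sketch using the batchnorm function above: choosing \gamma_i = \sqrt{(\sigma_i^{(l)})^2 + \epsilon} and \beta_i = \mu_i^{(l)} maps each input back to itself.

# Undoing the normalization: with gamma = sqrt(var + eps) and beta = mu,
# batchnorm returns (approximately) its input unchanged.
x = np.random.randn(4, 3)
mu = np.mean(x, axis=0)
var = np.var(x, axis=0)
y = batchnorm(x, gamma=np.sqrt(var + 1e-9), beta=mu)
assert np.allclose(y, x)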
=== Special cases ===

The original paper recommended using BatchNorm only after a linear transform, not after a nonlinear activation.

For convolutional layers, BatchNorm must be applied in a way that preserves the translational invariance of convolution. Concretely, suppose we have a 2-dimensional convolutional layer defined by:

x^{(l)}_{h, w, c} = \sum_{h', w', c'} K^{(l)}_{h'-h, w'-w, c, c'} x_{h', w', c'}^{(l-1)} + b^{(l)}_c

where:
• x^{(l)}_{h, w, c} is the activation of the neuron at position (h, w) in the c-th channel of the l-th layer;
• K^{(l)}_{\Delta h, \Delta w, c, c'} is a kernel tensor. Each channel c corresponds to a kernel K^{(l)}_{h'-h, w'-w, c, c'}, with indices \Delta h, \Delta w, c';
• b^{(l)}_c is the bias term for the c-th channel of the l-th layer.

In order to preserve the translational invariance, BatchNorm treats all outputs from the same kernel in the same batch as additional data points within the batch. That is, it is applied once per kernel c (equivalently, once per channel c), not per activation x^{(l)}_{h, w, c}:

\begin{aligned} \mu^{(l)}_c &= \frac{1}{BHW} \sum_{b=1}^B \sum_{h=1}^H \sum_{w=1}^W x^{(l)}_{(b), h, w, c} \\ (\sigma^{(l)}_c)^2 &= \frac{1}{BHW} \sum_{b=1}^B \sum_{h=1}^H \sum_{w=1}^W (x_{(b), h, w, c}^{(l)} - \mu_c^{(l)})^2 \end{aligned}

where B is the batch size, H is the height of the feature map, and W is the width of the feature map. That is, even though there are only B data points in the batch, all BHW outputs from the kernel in this batch are treated equally.
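The per-channel computation can be sketched in NumPy as follows, assuming activations stored in an array of shape (B, H, W, C); the layout and the name batchnorm_2d are illustrative, not from the original paper:

import numpy as np

def batchnorm_2d(x, gamma, beta, epsilon=1e-9):
    # x has shape (B, H, W, C); statistics are shared across batch,
    # height, and width, so there is one (mu, var) pair per channel c.
    mu = np.mean(x, axis=(0, 1, 2))            # shape (C,)
    var = np.var(x, axis=(0, 1, 2))            # shape (C,)
    x_hat = (x - mu) / np.sqrt(var + epsilon)  # shape (B, H, W, C)
    return gamma * x_hat + beta                # gamma, beta have shape (C,)

# Example: a batch of 8 feature maps with 3 channels
x = np.random.randn(8, 32, 32, 3)
y = batchnorm_2d(x, gamma=np.ones(3), beta=np.zeros(3))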
For recurrent neural networks (RNNs), let the hidden state of the l-th layer at time t be h_t^{(l)}. The standard RNN, without normalization, satisfies

h^{(l)}_t = \phi(W^{(l)} h_t^{(l-1)} + U^{(l)} h_{t-1}^{(l)} + b^{(l)})

where W^{(l)}, U^{(l)}, b^{(l)} are weights and biases, and \phi is the activation function. Applying BatchNorm, this becomes

h^{(l)}_t = \phi(\mathrm{BN}(W^{(l)} h_t^{(l-1)}) + U^{(l)} h_{t-1}^{(l)})

There are two possible ways to define what a "batch" is in BatchNorm for RNNs: frame-wise and sequence-wise. Concretely, consider applying an RNN to process a batch of sentences. Let h_{b, t}^{(l)} be the hidden state of the l-th layer for the t-th token of the b-th input sentence. Then frame-wise BatchNorm means normalizing over b:

\begin{aligned} \mu_t^{(l)} &= \frac{1}{B} \sum_{b=1}^B h_{b,t}^{(l)} \\ (\sigma_t^{(l)})^2 &= \frac{1}{B} \sum_{b=1}^B (h_{b,t}^{(l)} - \mu_t^{(l)})^2 \end{aligned}

and sequence-wise BatchNorm means normalizing over (b, t):

\begin{aligned} \mu^{(l)} &= \frac{1}{BT} \sum_{b=1}^B\sum_{t=1}^T h_{b,t}^{(l)} \\ (\sigma^{(l)})^2 &= \frac{1}{BT} \sum_{b=1}^B\sum_{t=1}^T (h_{b,t}^{(l)} - \mu^{(l)})^2 \end{aligned}

Frame-wise BatchNorm is suited for causal tasks such as next-character prediction, where future frames are unavailable, forcing normalization per frame. Sequence-wise BatchNorm is suited for tasks such as speech recognition, where the entire sequences are available, but with variable lengths. In a batch, the shorter sequences are padded with zeroes to match the length of the longest sequence in the batch. In such setups, frame-wise BatchNorm is not recommended, because the number of unpadded frames decreases along the time axis, leading to increasingly poor statistics estimates.
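The two choices of statistics can be sketched as follows, assuming hidden states stored in an array of shape (B, T, N); this layout is an assumption for illustration, and padding is ignored here:

import numpy as np

h = np.random.randn(8, 20, 16)   # (B, T, N) hidden states

# Frame-wise: one (mu, var) pair per time step t, computed over the batch
mu_frame = np.mean(h, axis=0)    # shape (T, N)
var_frame = np.var(h, axis=0)    # shape (T, N)
h_hat_frame = (h - mu_frame) / np.sqrt(var_frame + 1e-9)

# Sequence-wise: one (mu, var) pair shared by all time steps
mu_seq = np.mean(h, axis=(0, 1)) # shape (N,)
var_seq = np.var(h, axis=(0, 1)) # shape (N,)
h_hat_seq = (h - mu_seq) / np.sqrt(var_seq + 1e-9)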
=== Improvements ===

BatchNorm has been very popular, and there have been many attempted improvements. Some examples include:
• ghost batching: randomly partition a batch into sub-batches and perform BatchNorm separately on each;
• weight decay on \gamma and \beta;
• combining BatchNorm with GroupNorm.

A particular problem with BatchNorm is that during training, the mean and variance are calculated on the fly for each batch, while a running estimate (usually an exponential moving average) is accumulated; during inference, the mean and variance are frozen at the running estimates accumulated during training. This train-test disparity degrades performance. The disparity can be decreased by simulating the moving average during inference, as sketched below.
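A minimal sketch of this train-versus-inference mechanism, under the same (B, N) layout as above; the momentum value and the function names are illustrative:

import numpy as np

def batchnorm_train(x, gamma, beta, running_mu, running_var,
                    momentum=0.9, epsilon=1e-9):
    # Normalize with the statistics of the current batch...
    mu = np.mean(x, axis=0)
    var = np.var(x, axis=0)
    x_hat = (x - mu) / np.sqrt(var + epsilon)
    # ...while updating the running (exponential moving average) estimates,
    # typically initialized to zeros (mean) and ones (variance)
    running_mu = momentum * running_mu + (1 - momentum) * mu
    running_var = momentum * running_var + (1 - momentum) * var
    return gamma * x_hat + beta, running_mu, running_var

def batchnorm_infer(x, gamma, beta, running_mu, running_var, epsilon=1e-9):
    # At inference, the statistics are frozen at the running estimates
    x_hat = (x - running_mu) / np.sqrt(running_var + epsilon)
    return gamma * x_hat + beta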
== Layer normalization ==