An Interactive Introduction to Model-Agnostic Meta-Learning

Exploring the world of model-agnostic meta-learning and its variants.

This page is part of a multi-part series on Model-Agnostic Meta-Learning. If you are already familiar with the topic, use the menu on the right side to jump straight to the part that is of interest to you. Otherwise, we suggest you start at the beginning.

Why differentiating through an optimizer is actually as complicated as it sounds

In this section, we want to gain some understanding of what it actually means to differentiate through \(U\), the optimizer. Recall the gradient that MAML requires us to compute: \[ \nabla_\theta\, \mathcal{L}(\theta) = \mathbb{E}_\tau \left[\, \nabla_\theta \mathcal{L}_{\tau, \text{test}}(\phi) \,\right] .\] Expanding the inner term by applying the chain rule yields \[ \nabla_\theta \mathcal{L}_{\tau, \text{test}}(\phi) = \nabla_\theta \mathcal{L}_{\tau, \text{test}}(U_{\tau}(\theta)) = \nabla_{U_{\tau}(\theta)} \mathcal{L}_{\tau, \text{test}} \nabla_\theta U_{\tau}(\theta) .\] Here, \( \nabla_{U_{\tau}(\theta)}\, \mathcal{L}_{\tau, \text{test}} \) is the gradient of the test loss of task \(\tau\) with respect to the optimized parameter \(\phi\), and \(\nabla_\theta U_{\tau}(\theta)\) is the derivative of the optimizer's output with respect to its input, i.e., a gradient through an optimization algorithm.

Even if we assume that the optimizer takes only one gradient descent step with step size \(\alpha\), this term becomes \[ \nabla_\theta U_{\tau}(\theta) = \nabla_\theta (\theta - \alpha \nabla_\theta \mathcal{L}_{\tau, \text{train}}(\theta) ) = I - \alpha \nabla^2_\theta \mathcal{L}_{\tau, \text{train}}(\theta). \] Hence, MAML requires us to compute second derivatives in order to optimize \(\theta\), which is computationally expensive, especially in high-dimensional problems (such as training neural networks).
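The one-step formula can be checked numerically. The following is a minimal sketch in plain Python, not part of the original derivation: the two-parameter toy loss, the step size `ALPHA`, and all function names are assumptions chosen for illustration. It compares \(I - \alpha \nabla^2_\theta \mathcal{L}_{\tau, \text{train}}(\theta)\) with a finite-difference estimate of the derivative of one gradient step.

```python
ALPHA = 0.1  # inner-loop step size (assumed value)

def grad_train(theta):
    """Gradient of the toy loss L(t) = (t0^4 + t1^4) / 4 + t0 * t1."""
    t0, t1 = theta
    return [t0**3 + t1, t1**3 + t0]

def hessian_train(theta):
    """Hessian of the same toy loss."""
    t0, t1 = theta
    return [[3 * t0**2, 1.0], [1.0, 3 * t1**2]]

def inner_step(theta):
    """One gradient descent step: U(theta) = theta - alpha * grad."""
    return [t - ALPHA * g for t, g in zip(theta, grad_train(theta))]

def analytic_jacobian(theta):
    """I - alpha * Hessian, the formula derived above."""
    H = hessian_train(theta)
    return [[float(i == j) - ALPHA * H[i][j] for j in range(2)] for i in range(2)]

def numeric_jacobian(f, theta, eps=1e-6):
    """Central finite differences; entry [i][j] approximates d f_i / d theta_j."""
    n = len(theta)
    J = [[0.0] * n for _ in range(n)]
    for j in range(n):
        plus, minus = list(theta), list(theta)
        plus[j] += eps
        minus[j] -= eps
        f_plus, f_minus = f(plus), f(minus)
        for i in range(n):
            J[i][j] = (f_plus[i] - f_minus[i]) / (2 * eps)
    return J

theta = [0.5, -1.0]
J_analytic = analytic_jacobian(theta)
J_numeric = numeric_jacobian(inner_step, theta)
max_err = max(abs(J_analytic[i][j] - J_numeric[i][j])
              for i in range(2) for j in range(2))
print(J_analytic, max_err)
```

The two matrices should agree up to finite-difference error. Note that the analytic version needs the full Hessian of the training loss, which is exactly the second-order cost discussed above.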

If we don't restrict ourselves to making just a single gradient update, the term becomes slightly more complicated. Let \(k\) be the number of updates we make and let \( \phi^i \) denote the parameter after the \(i\)th step of gradient descent (\( \phi^0 = \theta\)), then: $$\begin{align} \nabla_\theta \phi^k &= \frac{\mathrm d\phi^k}{\mathrm d\theta} = \prod_{i=1}^{k} \frac{\mathrm d\phi^i}{\mathrm d\phi^{i-1} }\\ &= \prod^{k}_{i=1} \left( I - \alpha \nabla^2_{\phi^{i-1}}\mathcal L_{\tau,\text{train}}(\phi^{i-1}) \right) \end{align}$$
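As a worked instance, take \(k = 2\) and assume the convention that the Jacobian \(\mathrm d\phi^i / \mathrm d\phi^{i-1}\) has entries \(\partial \phi^i_a / \partial \phi^{i-1}_b\). The product then expands to \[ \nabla_\theta \phi^2 = \left( I - \alpha \nabla^2_{\phi^{1}}\mathcal L_{\tau,\text{train}}(\phi^{1}) \right) \left( I - \alpha \nabla^2_{\theta}\mathcal L_{\tau,\text{train}}(\theta) \right), \qquad \phi^1 = \theta - \alpha \nabla_\theta \mathcal L_{\tau,\text{train}}(\theta) .\] Under this convention the factor belonging to the later step stands on the left. The two Hessians are evaluated at different points, so in general the product does not collapse into a power of a single matrix.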

There are two methods of calculating the meta-gradient. Both have their own limitations, which we will discuss shortly.

Calculating the Jacobians for every update step

The most straightforward approach is to calculate the Jacobian in each update step. We can just keep track of the current product and multiply each step's Jacobian with the accumulated Jacobian matrix. Once we have completed all the update steps, we also have the Jacobian of the optimizer ready.
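The accumulation described above can be sketched in a few lines of plain Python. This is an illustration under assumptions, not a reference implementation: the toy loss, the step size `ALPHA`, and the helper names are made up, and the Hessian is written out by hand instead of being obtained from an autodiff library.

```python
ALPHA = 0.1  # inner-loop step size (assumed value)

def grad_train(phi):
    """Gradient of the toy loss L(p) = (p0^4 + p1^4) / 4 + p0 * p1."""
    return [phi[0]**3 + phi[1], phi[1]**3 + phi[0]]

def hessian_train(phi):
    """Hessian of the same toy loss."""
    return [[3 * phi[0]**2, 1.0], [1.0, 3 * phi[1]**2]]

def matmul(A, B):
    """Product of two square matrices stored as nested lists."""
    n = len(A)
    return [[sum(A[i][m] * B[m][j] for m in range(n)) for j in range(n)]
            for i in range(n)]

def optimize_with_jacobian(theta, k):
    """Run k gradient steps and accumulate d phi^k / d theta along the way."""
    n = len(theta)
    phi = list(theta)
    # d phi^0 / d theta is the identity.
    J = [[float(i == j) for j in range(n)] for i in range(n)]
    for _ in range(k):
        H = hessian_train(phi)
        step_J = [[float(i == j) - ALPHA * H[i][j] for j in range(n)]
                  for i in range(n)]
        J = matmul(step_J, J)  # the newest step multiplies from the left
        phi = [p - ALPHA * g for p, g in zip(phi, grad_train(phi))]
    return phi, J

# Check the first column of J against central finite differences.
theta, k, eps = [0.5, -1.0], 3, 1e-6
phi, J = optimize_with_jacobian(theta, k)
phi_plus, _ = optimize_with_jacobian([theta[0] + eps, theta[1]], k)
phi_minus, _ = optimize_with_jacobian([theta[0] - eps, theta[1]], k)
numeric_col0 = [(a - b) / (2 * eps) for a, b in zip(phi_plus, phi_minus)]
col0_err = max(abs(J[i][0] - numeric_col0[i]) for i in range(2))
print(J, col0_err)
```

Only the current parameter and a single accumulated matrix are kept between steps, so the memory footprint of this loop does not grow with `k`.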

The big issue with Jacobian matrices is that they are huge! For deep neural networks it is not uncommon to have millions of parameters, and the Jacobian has one entry for every pair of parameters. Even though we may not need to store the complete square matrix and can potentially sparsify it, the resulting number of required entries for a large model will still be enormous. This places a huge burden on computational and memory resources. One advantage is that the required memory depends only on the model size and not on the number of update steps.
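To get a feeling for the numbers, here is a back-of-the-envelope calculation. It assumes a dense, unsparsified square matrix stored at 4 bytes per entry (32-bit floats); the function name is made up for this sketch.

```python
def dense_jacobian_bytes(num_params, bytes_per_entry=4):
    """Memory for a dense num_params x num_params matrix (float32 by default)."""
    return num_params**2 * bytes_per_entry

for n in (10**3, 10**6, 10**8):
    terabytes = dense_jacobian_bytes(n) / 1e12
    print(f"{n:>11,} parameters -> {terabytes:,.6f} TB")
```

Under these assumptions, a model with one million parameters already has \(10^{12}\) Jacobian entries, which amounts to 4 TB.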

To find out how one can get around calculating any second-order derivatives, take a look at FOMAML and REPTILE.

Calculating the Hessian-vector product directly

There is a way to circumvent the need to calculate and store a Jacobian matrix. If we take a step back and look at what we ultimately want to calculate, we recognize that we are not really interested in the Jacobian itself, but in the product of the Jacobian (first factor) with a gradient vector (second factor): \[ \nabla_\theta \mathcal{L}_{\tau, \text{test}}(\phi) = \left( \frac{\mathrm d U_{\tau}(\theta)}{\mathrm d\theta} \right) \left( \nabla_{U_{\tau}(\theta)} \mathcal{L}_{\tau, \text{test}} \right) .\]

The fact that we are not looking for the matrix itself, but only for the product, can be cleverly exploited. One key ingredient is a function which we will call \( \mathrm {sg} \) for "stop gradient". This function is actually very simple (in its implementation as well as its properties): it returns the same value that we put in, but it stops the flow of the gradient during backpropagation. In TensorFlow this function is implemented as tf.stop_gradient(), while in PyTorch one can call the .detach() method on a tensor to produce the same effect. This property can be described as: \[ \mathrm {sg}(y) = y\quad\land\quad \frac{\mathrm d\, \mathrm {sg}(y)}{\mathrm dx} = 0 .\] This allows us to make the following changes to our calculations (we add a second summand, which evaluates to zero, so that we can apply the product rule in reverse): \[ \begin{align} \nabla_\theta \mathcal{L}_{\tau, \text{test}}(\phi) &= \left( \frac{\mathrm d \phi}{\mathrm d\theta} \right) \mathrm{sg} \left( \nabla_{\phi} \mathcal{L}_{\tau, \text{test}} \right)\\ &= \left( \frac{\mathrm d \phi}{\mathrm d\theta} \right) \mathrm{sg} \left( \nabla_{\phi} \mathcal{L}_{\tau, \text{test}} \right) + \left( \frac{\mathrm d\, \mathrm{sg} \left( \nabla_{\phi} \mathcal{L}_{\tau, \text{test}} \right)}{\mathrm d\theta} \right) \phi \\ &= \frac{\mathrm {d}}{\mathrm {d} \theta} \left( \phi\, \cdot \mathrm{sg} \left(\nabla_{\phi} \mathcal{L}_{\tau, \text{test}}\right) \right) . \end{align}\]
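The identity can be sanity-checked without an autodiff library. In the following sketch, which rests on several assumptions (toy train and test losses, a single inner step, made-up function names), the stop-gradient is imitated by evaluating the test-loss gradient once at the current point and then holding it fixed in `frozen`. Finite differences of the surrogate \(\phi \cdot \mathrm{sg}(\nabla_\phi \mathcal{L}_{\tau, \text{test}})\) should then match finite differences of the test loss itself.

```python
ALPHA = 0.1  # inner-loop step size (assumed value)

def grad_train(theta):
    """Gradient of the toy train loss (t0^4 + t1^4) / 4 + t0 * t1."""
    return [theta[0]**3 + theta[1], theta[1]**3 + theta[0]]

def inner_step(theta):
    """The optimizer U: a single gradient descent step."""
    return [t - ALPHA * g for t, g in zip(theta, grad_train(theta))]

def loss_on_test(phi):
    """Toy test loss with its minimum at (1, -2)."""
    return 0.5 * ((phi[0] - 1.0) ** 2 + (phi[1] + 2.0) ** 2)

def grad_loss_on_test(phi):
    """Gradient of the toy test loss with respect to phi."""
    return [phi[0] - 1.0, phi[1] + 2.0]

def numeric_grad(f, theta, eps=1e-6):
    """Central finite-difference gradient of a scalar function."""
    grad = []
    for j in range(len(theta)):
        plus, minus = list(theta), list(theta)
        plus[j] += eps
        minus[j] -= eps
        grad.append((f(plus) - f(minus)) / (2 * eps))
    return grad

theta0 = [0.5, -1.0]
# Evaluated once and then held constant: this plays the role of sg(...).
frozen = grad_loss_on_test(inner_step(theta0))

def meta_loss(theta):
    """Test loss after adaptation, as a function of the meta-parameter."""
    return loss_on_test(inner_step(theta))

def surrogate(theta):
    """phi(theta) . sg(grad): only phi depends on theta."""
    return sum(p * g for p, g in zip(inner_step(theta), frozen))

meta_grad = numeric_grad(meta_loss, theta0)
surrogate_grad = numeric_grad(surrogate, theta0)
gap = max(abs(a - b) for a, b in zip(meta_grad, surrogate_grad))
print(meta_grad, surrogate_grad, gap)
```

The two gradients agree only at `theta0`, the point where `frozen` was evaluated. In an autodiff framework the stop-gradient is re-evaluated at every meta-step, so the surrogate is rebuilt each time.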

In essence, we store the computational graph of our optimization procedure and then backpropagate through the dot product of the gradient of the test loss (treated as a constant) and the task parameter that the optimizer produced. By subtracting the resulting derivative (which is the meta-gradient) from the current meta-parameter (i.e., doing gradient descent), we change \(\theta\) such that, to first order, the task parameter moves against the test-loss gradient and the dot product decreases. The dot product and the test loss have the same gradient with respect to \(\theta\) at the current point. Hence, taking a gradient descent step on the dot product is equivalent to taking a gradient descent step on the test loss.

While this approach does not require us to store or calculate Jacobian matrices, it also comes with a downside: calculating the derivative of the dot product requires backpropagating through the whole optimization chain. Since we need to backpropagate through it at the end, we need to store the complete computational graph and can't employ an iterative calculation as with the Jacobian matrices. The memory burden of this increases linearly with the number of updates the optimizer takes. It should also be mentioned that while this approach does not require the calculation of Jacobians, it does require second-order derivatives.
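What backpropagating through the optimization chain amounts to can be written out by hand for a small case. The sketch below is an illustration under assumptions (toy losses, hand-written Hessian-vector product, made-up names), not how an autodiff framework is implemented internally. It stores every iterate during the forward pass and then pulls the test-loss gradient back through the steps one at a time.

```python
ALPHA = 0.1  # inner-loop step size (assumed value)

def grad_train(phi):
    """Gradient of the toy train loss (p0^4 + p1^4) / 4 + p0 * p1."""
    return [phi[0]**3 + phi[1], phi[1]**3 + phi[0]]

def hvp_train(phi, v):
    """Hessian-vector product of the toy loss, without building the Hessian."""
    return [3 * phi[0]**2 * v[0] + v[1], v[0] + 3 * phi[1]**2 * v[1]]

def grad_loss_on_test(phi):
    """Gradient of the toy test loss 0.5 * ((p0 - 1)^2 + (p1 + 2)^2)."""
    return [phi[0] - 1.0, phi[1] + 2.0]

def meta_gradient(theta, k):
    # Forward pass: every iterate is kept for the backward sweep.
    trajectory = [list(theta)]
    for _ in range(k):
        phi = trajectory[-1]
        trajectory.append([p - ALPHA * g for p, g in zip(phi, grad_train(phi))])
    # Backward pass: start at phi^k and walk back through phi^{k-1}, ..., phi^0.
    v = grad_loss_on_test(trajectory[-1])
    for phi in reversed(trajectory[:-1]):
        hv = hvp_train(phi, v)
        # Apply (I - alpha * H) to v; H is symmetric, so no transpose is needed.
        v = [vi - ALPHA * hi for vi, hi in zip(v, hv)]
    return v, len(trajectory)

def meta_loss(theta, k):
    """Test loss after k inner steps, used only for the numerical check."""
    phi = list(theta)
    for _ in range(k):
        phi = [p - ALPHA * g for p, g in zip(phi, grad_train(phi))]
    return 0.5 * ((phi[0] - 1.0) ** 2 + (phi[1] + 2.0) ** 2)

theta, k, eps = [0.5, -1.0], 5, 1e-6
grad, stored = meta_gradient(theta, k)
numeric = [
    (meta_loss([theta[0] + eps, theta[1]], k)
     - meta_loss([theta[0] - eps, theta[1]], k)) / (2 * eps),
    (meta_loss([theta[0], theta[1] + eps], k)
     - meta_loss([theta[0], theta[1] - eps], k)) / (2 * eps),
]
err = max(abs(a - b) for a, b in zip(grad, numeric))
print(grad, stored, err)
```

No matrix is ever formed, but `trajectory` holds `k + 1` parameter vectors, which is the linear memory growth described above. Each backward step also needs a Hessian-vector product, which is where the second-order derivatives enter.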

To find out how one might remove the need to store the complete computational graph of the optimization procedure, take a look at iMAML.

We will now spend some time introducing the three most prominent solutions to this problem, FOMAML, REPTILE, and iMAML, and comparing them to MAML and to each other.