An Interactive Introduction to Model-Agnostic Meta-Learning

Exploring the world of model-agnostic meta-learning and its variants.

This page is part of a multi-part series on Model-Agnostic Meta-Learning. If you are already familiar with the topic, use the menu on the right side to jump straight to the part that is of interest to you. Otherwise, we suggest you start at the beginning.

iMAML: Implicit Gradients

To explain how Implicit Model-Agnostic Meta-Learning (iMAML) works, we will start with an observation: If we take many gradient steps in regular MAML, then, apart from an enormous computational burden, we face the issue that the model parameters \( \phi \) depend less and less on the meta-parameter \( \theta \). If the parameters (\( \phi \) and \( \theta \)) are largely independent, choosing a good \( \theta \) becomes more difficult, since its effect on \( \phi \) diminishes.

Regular MAML mitigates this by using only a few gradient steps. This early stopping acts as a form of regularization and can be interpreted as a Bayesian prior that keeps \( \phi \) close to \( \theta \). iMAML uses a more explicit regularization.

Let's take another look at the problem formulation. The objective in meta-learning is to minimize the expected loss over a task distribution: $$ \min_\theta \mathbb E_\tau \left[ \mathcal L \left( \phi_\tau , \mathcal D ^ {test}_\tau \right)\right] $$ Here, \( \phi_\tau \) is the task parameter that we acquire after solving the inner optimization problem, and \( \mathcal D ^ {test}_\tau \) is the task-level test dataset.

In regular MAML, \( \phi_\tau \) is obtained by computing one (or a few) gradient descent steps: $$ \phi_\tau = U_\tau (\theta) = \theta - \alpha \nabla_\theta \mathcal L \left( \theta, \mathcal D^{train}_\tau \right) $$

Instead of using only a few update steps, we will use an arbitrary optimizer that optimizes the task parameter until it reaches a minimum. However, instead of only minimizing the task loss, we add an \( L_2 \) regularization term: $$ \phi_\tau = U^\ast_\tau (\theta) = \arg\min_\phi \left( \mathcal L \left( \phi, \mathcal D^{train}_\tau \right) + \frac{\lambda}{2} \| \phi - \theta \| ^ 2 \right) $$ Here, the objective is almost the same, with the addition that moving too far away from the meta-parameter \( \theta \) results in a higher loss.

Play around with \( \theta \) and \( \lambda \) to get a feeling for the resulting loss space of the inner optimization objective. A high \( \lambda \) will encourage the algorithm to place the task parameter close to the meta-parameter \( \theta \). Setting \( \lambda = 0\) results in the original loss function.
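If you prefer numbers over sliders, the effect of \( \lambda \) can also be sketched in a few lines of NumPy. This is only a minimal illustration under our own assumptions: the quadratic toy loss, the matrix `A`, the vectors `c` and `theta`, and the helper names `task_loss_grad` and `solve_inner` are hypothetical and not part of the iMAML paper or our interactive implementation.

```python
import numpy as np

def task_loss_grad(phi, A, c):
    """Gradient of the (assumed) toy task loss 0.5 * (phi - c)^T A (phi - c)."""
    return A @ (phi - c)

def solve_inner(theta, lam, A, c, lr=0.05, steps=2000):
    """Minimize task loss + (lam / 2) * ||phi - theta||^2 with plain gradient descent."""
    phi = theta.copy()
    for _ in range(steps):
        # Gradient of the regularized inner objective.
        grad = task_loss_grad(phi, A, c) + lam * (phi - theta)
        phi = phi - lr * grad
    return phi

# Hypothetical toy task: curvature A, task optimum c, meta-parameter theta.
A = np.array([[2.0, 0.0], [0.0, 0.5]])
c = np.array([1.0, -1.0])
theta = np.array([-1.0, 1.0])

for lam in (0.0, 1.0, 10.0):
    phi = solve_inner(theta, lam, A, c)
    # The distance to theta shrinks as lam grows.
    print(f"lambda={lam:5.1f}  phi={phi}  distance to theta={np.linalg.norm(phi - theta):.3f}")
```

With \( \lambda = 0 \) the solver ends up at the task optimum `c`, and with growing \( \lambda \) the solution stays closer and closer to `theta`.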

Computing the Gradient

Now, in order to minimize the meta-objective, we again calculate the gradient: $$ \begin{align} \nabla_\theta\, \mathbb E_\tau \left[ \mathcal L \left( {\phi}_\tau , \mathcal D ^ {test}_\tau \right)\right] &= \mathbb E_\tau \left[ \nabla_{\theta}\, \mathcal L \left( {\phi}_\tau , \mathcal D ^ {test}_\tau \right)\right]\\ &= \mathbb E_\tau \left[ \nabla_{\phi_\tau} \, \mathcal L \left( \phi_\tau , \mathcal D ^ {test}_\tau\right) \cdot \frac{\mathrm d \phi_\tau}{\mathrm d \theta} \right] \end{align} $$ Calculating the first part \( \nabla_{\phi_\tau} \, \mathcal L \left( \phi_\tau , \mathcal D ^ {test}_\tau\right) \) can be done using back-propagation. This is the gradient of the test loss, evaluated at the parameter \(\phi_\tau\) that was found by the optimizer. \( \frac{\mathrm d \phi_\tau}{\mathrm d \theta} \) is the part that causes problems for MAML. Depending on how complex the optimization algorithm is, this may be difficult to compute using back-propagation.

But, here comes the awesome part: Since we assume that our inner optimizer found a local minimum, we know that the gradient of the inner objective with respect to the task parameter is 0. This gives us the following equation: $$ \begin{align} \mathbf {0} &= \nabla_{\phi} \left( \mathcal L \left( \phi, \mathcal D^{train} \right) + \frac{\lambda}{2} \| \phi - \theta \| ^ 2 \right)\\ &= \nabla_{\phi} \mathcal L \left( \phi, \mathcal D^{train} \right) + \lambda \left( \phi - \theta \right) \end{align} $$ By rearranging we get: $$ \phi = \theta - \frac{1}{\lambda} \nabla_\phi \mathcal L \left( \phi, \mathcal D^{train} \right) $$
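To make this condition more tangible, here is a small worked example. It assumes a one-dimensional quadratic task loss \( \mathcal L(\phi) = \frac{a}{2} (\phi - c)^2 \) with curvature \( a > 0 \) and task optimum \( c \); both symbols are our own illustration and do not appear in the iMAML paper. $$ \begin{align} 0 &= a \left( \phi - c \right) + \lambda \left( \phi - \theta \right)\\ \Rightarrow \quad \phi &= \frac{a\, c + \lambda\, \theta}{a + \lambda} \end{align} $$ The solution is a weighted average of the task optimum \( c \) and the meta-parameter \( \theta \): for \( \lambda \to 0 \) we recover \( \phi = c \), and for \( \lambda \to \infty \) we get \( \phi = \theta \). Plugging in \( a = 3 \), \( c = 2 \), \( \theta = 0 \) and \( \lambda = 1 \) gives \( \phi = 1.5 \), and the rearranged equation indeed holds: $$ \theta - \frac{1}{\lambda}\, a \left( \phi - c \right) = 0 - 3 \cdot \left( 1.5 - 2 \right) = 1.5 = \phi $$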

The red arrow denotes the gradient \( \nabla_\phi \mathcal L \left( \phi, \mathcal D^{train} \right) \). The gradient pulls the task parameters \( \phi \) towards the minimum of the task loss. You can imagine the green arrow as being the counter-force which pulls \( \phi \) toward the meta-parameter \( \theta \).

At the optimum, these forces need to cancel out, since moving in any direction won't improve the regularized loss. Hence, the gradient needs to be orthogonal to the isocurve (white circle): moving along the circle won't change the regularization term, so if \( \phi \) is optimal for the joint term, the projection of the task-loss gradient onto the circle must be zero (otherwise, moving some distance along the circle would improve the joint loss).

Now, we can calculate the Jacobian of the task parameter \( \phi \) with respect to the meta-parameter \( \theta \): $$ \begin{align} \frac{\mathrm d \phi}{\mathrm d \theta} &= \frac{\mathrm d }{\mathrm d \theta} \left( \theta - \frac{1}{\lambda} \nabla_\phi \mathcal L \left( \phi, \mathcal D^{train} \right) \right)\\ &= \frac{\mathrm d \theta}{\mathrm d \theta} - \frac{1}{\lambda}\frac{ \mathrm d }{\mathrm d \theta} \nabla_\phi \mathcal L \left( \phi, \mathcal D^{train} \right)\\ &= I - \frac{1}{\lambda} \nabla^2_\phi \mathcal L \left( \phi, \mathcal D^{train} \right) \frac{\mathrm d \phi}{\mathrm d \theta} \end{align} $$ Here, to get from the second to the third line, we need to apply the chain rule: \( \phi \) is a function of \( \theta \). Hence, we need to calculate the outer derivative (which results in the Hessian) and the inner derivative, which is exactly the quantity we want to calculate (the total derivative \( \frac{\mathrm d \phi}{\mathrm d \theta} \)). Moving the Hessian term to the left-hand side, we arrive at the following equations (assuming the inverse exists): $$\begin{align} &&\left(I + \frac{1}{\lambda} \nabla^2_\phi \mathcal L \left( \phi, \mathcal D^{train} \right)\right)\frac{\mathrm d \phi}{\mathrm d \theta} = I\\ \Rightarrow&& \frac{\mathrm d \phi}{\mathrm d \theta} = \left(I + \frac{1}{\lambda} \nabla^2_\phi \mathcal L \left( \phi, \mathcal D^{train} \right)\right)^{-1} \end{align}$$ Let that sink in for a moment: By assuming that our inner optimizer found an optimal solution for our inner objective, we can derive a closed-form solution for the total derivative \( \frac{\mathrm d \phi}{\mathrm d \theta} \). 
To calculate the meta-gradient, we just need to know the solution of the inner optimization problem without knowing the steps to get there!

The steps leading up to the optimal solution do not need to be stored, and we could even use an optimizer that does not use Gradient Descent and where we might not be able to back-propagate the gradient. Instead, the optimizer can be treated as a black box, and we only require the final solution.
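The black-box claim can be checked numerically. The following sketch compares the closed-form Jacobian with a finite-difference estimate that only ever looks at the final solution of the inner solver. Everything here is an assumption made for illustration: the non-quadratic toy loss, the values of `A`, `c`, `theta` and `lam`, and the helper names are hypothetical. The explicit matrix inverse is only acceptable because the toy problem has two parameters.

```python
import numpy as np

# Hypothetical toy task loss: 0.5 * (phi - c)^T A (phi - c) + 0.25 * sum(phi^4).
A = np.array([[2.0, 0.3], [0.3, 0.5]])
c = np.array([1.0, -1.0])

def grad_task(phi):
    """Gradient of the toy task loss."""
    return A @ (phi - c) + phi**3

def hess_task(phi):
    """Hessian of the toy task loss."""
    return A + np.diag(3.0 * phi**2)

def solve_inner(theta, lam, lr=0.05, steps=5000):
    """Treated as a black box: returns only the final minimizer of the regularized loss."""
    phi = theta.copy()
    for _ in range(steps):
        phi = phi - lr * (grad_task(phi) + lam * (phi - theta))
    return phi

def implicit_jacobian(phi_star, lam):
    """Closed-form d(phi)/d(theta) = (I + H / lam)^(-1), evaluated at the solution."""
    n = phi_star.size
    return np.linalg.inv(np.eye(n) + hess_task(phi_star) / lam)

def finite_difference_jacobian(theta, lam, eps=1e-5):
    """Numerical d(phi)/d(theta): re-run the black-box solver with a perturbed theta."""
    n = theta.size
    J = np.zeros((n, n))
    for j in range(n):
        e = np.zeros(n)
        e[j] = eps
        J[:, j] = (solve_inner(theta + e, lam) - solve_inner(theta - e, lam)) / (2 * eps)
    return J

theta = np.array([-1.0, 1.0])
lam = 2.0
phi_star = solve_inner(theta, lam)
J_implicit = implicit_jacobian(phi_star, lam)
J_fd = finite_difference_jacobian(theta, lam)
print("largest deviation:", np.max(np.abs(J_implicit - J_fd)))
```

Note that `implicit_jacobian` never sees the optimization path. It only needs `phi_star`, so the gradient-descent loop in `solve_inner` could be swapped for any other solver that reaches the same minimum.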

If this implicit differentiation is interesting to you, take a look at the paper "Efficient and Modular Implicit Differentiation". In this paper, the authors offer a more general framework for computing gradients without needing to backpropagate through the unrolled forward propagation. Instead, they use an optimality condition (in the iMAML case, it is given by the gradient of the inner-loop objective being zero) in order to calculate the gradient implicitly.

Welcome back to reality and its approximations

Up until now, we have been telling you the idealized version of iMAML. The reality is a little bit more complicated. In the idealized version, we assumed that the inner optimizer finds the exact optimum and that we can invert the matrix exactly. Neither assumption is entirely justified:

  1. We do not usually calculate the optimum for each task on the regularized loss. Instead, we usually compute a (hopefully) good approximation.
  2. Numerical matrix inversion is not that easy. In reality, this takes time and may be subject to numerical errors.

But do not despair! Rajeswaran et al. have got you covered. The authors realized that these assumptions would be problematic and offered an approach to mitigate the issue. In the following paragraphs, we want to briefly outline how to deal with both problems.

Let \( g \) be the meta-gradient that we want to find. Then we know from the equations above that the following equations hold true: $$ \begin{align} &&g &= \frac{\mathrm d \phi}{\mathrm d \theta}\, \nabla_\phi \mathcal L \left( \phi, \mathcal D ^{test} \right)\\ &\Rightarrow& \left(I + \frac{1}{\lambda} \nabla^2_\phi \mathcal L \left( \phi, \mathcal D^{train} \right)\right)\, g &= \nabla_\phi \mathcal L \left( \phi, \mathcal D ^{test} \right) \end{align} $$

The second equation has the basic form \( Ax = b \). Luckily, there are numerical methods for solving such linear systems that work better than explicitly inverting the matrix. One of them is an algorithm called "Conjugate Gradient" (short: CG). An explanation of how the algorithm works is outside the scope of this article, but you should know the following: for a symmetric positive-definite matrix, CG finds the exact solution in at most as many steps as the matrix has dimensions (in exact arithmetic), and a small number of steps often already gives a good approximation. Additionally, we never need to build the matrix explicitly. Instead, we only need to compute the product of the matrix with some vector \( v \). If you take a look at our implementation of iMAML, you will see how this is beneficial.

Once we have solved this equation (for \( g \)), we already have the meta-gradient. In their paper, Rajeswaran et al. prove theoretically that using approximations is enough: the errors do not grow with these calculations. Additionally, they show empirically that iMAML is competitive with the other methods we discussed.
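The linear solve described in this section can be sketched as follows. This is a minimal illustration, not the implementation used on this page or by Rajeswaran et al.: the toy train and test losses, the values of `A`, `c_train`, `c_test`, `theta` and `lam`, and all function names are hypothetical, and the Hessian-vector product is approximated with finite differences of the gradient, where a real implementation would typically use automatic differentiation.

```python
import numpy as np

def conjugate_gradient(matvec, b, iters=20, tol=1e-8):
    """Solve M x = b for a symmetric positive-definite M, given only the map v -> M v."""
    x = np.zeros_like(b)
    r = b - matvec(x)   # residual
    p = r.copy()        # search direction
    rs = r @ r
    for _ in range(iters):
        if np.sqrt(rs) < tol:
            break
        Mp = matvec(p)
        alpha = rs / (p @ Mp)
        x = x + alpha * p
        r = r - alpha * Mp
        rs_new = r @ r
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x

# Hypothetical toy losses (assumptions for illustration only).
A = np.array([[2.0, 0.3], [0.3, 0.5]])
c_train = np.array([1.0, -1.0])
c_test = np.array([0.8, -1.2])

def grad_train(phi):
    """Gradient of the toy train loss 0.5 * (phi - c)^T A (phi - c) + 0.25 * sum(phi^4)."""
    return A @ (phi - c_train) + phi**3

def grad_test(phi):
    """Gradient of the toy test loss 0.5 * ||phi - c_test||^2."""
    return phi - c_test

def solve_inner(theta, lam, lr=0.05, steps=5000):
    """Approximate minimizer of the regularized train loss (plain gradient descent)."""
    phi = theta.copy()
    for _ in range(steps):
        phi = phi - lr * (grad_train(phi) + lam * (phi - theta))
    return phi

def hvp_train(phi, v, eps=1e-5):
    """Hessian-vector product via central differences; the Hessian is never formed."""
    return (grad_train(phi + eps * v) - grad_train(phi - eps * v)) / (2 * eps)

def imaml_meta_gradient(phi_star, lam):
    """Solve (I + H / lam) g = grad_test(phi_star) with CG."""
    matvec = lambda v: v + hvp_train(phi_star, v) / lam
    return conjugate_gradient(matvec, grad_test(phi_star))

theta = np.array([-1.0, 1.0])
lam = 2.0
phi_star = solve_inner(theta, lam)
g = imaml_meta_gradient(phi_star, lam)
print("meta-gradient:", g)
```

The function `matvec` is the only place where the matrix \( I + \frac{1}{\lambda} \nabla^2_\phi \mathcal L \) shows up, and it is only ever applied to a vector. For a model with millions of parameters, this is what keeps the memory footprint manageable.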

Discussion

As you can see in the gradient above, iMAML requires the computation of a second-order derivative. The huge benefit is that this second-order derivative only needs to be calculated at the last point the optimizer arrived at. We do not need to pass the gradient information through the individual steps of gradient descent.

While calculating the gradient is comparatively easy, iMAML requires an optimizer that finds a quasi-optimal solution. Rajeswaran et al. show that the gradient is still approximately correct as long as the solution provided by the inner-loop optimizer is approximately correct. Still, we need more gradient steps if we use SGD than in regular MAML, where even one step may suffice.

According to the same paper, iMAML produces better results than MAML without using more resources. Whereas iMAML requires more inner-loop steps, MAML requires either more outer-loop steps or the expensive computation of a long back-propagation chain. Compared to first-order MAML (FOMAML) and REPTILE, the authors report better results on Omniglot (the little exercise on the introduction page is based on Omniglot) and Mini-ImageNet, two common few-shot classification datasets.

As Ferenc Huszár points out in his wonderful blog post on iMAML, iMAML does not consider the stochasticity of Stochastic Gradient Descent: SGD may have non-zero probabilities of finding more than one task-level optimum, but iMAML will only derive the gradient with respect to the optimum that was actually found.

If you are interested in this consideration, you may also want to take a look at the paper titled "Probabilistic Model-Agnostic Meta-Learning".

In the next part, we will spend some time comparing MAML and its variants interactively. Better close some background tasks on your device, 'cuz it'll get computationally heavy 👩‍💻.