An Interactive Introduction to Model-Agnostic Meta-Learning

Exploring the world of model-agnostic meta-learning and its variants.

This page is part of a multi-part series on Model-Agnostic Meta-Learning. If you are already familiar with the topic, use the menu on the right side to jump straight to the part that is of interest to you. Otherwise, we suggest you start at the beginning.

How MAML works

Model-agnostic meta-learning (MAML) is a meta-learning approach that applies to a range of tasks, from simple regression to reinforcement learning and few-shot learning. To learn more about it, let us build an example from the ground up and then try to apply MAML. We will do this by alternating between mathematical walk-throughs, interactive figures, and coding examples.

If you have applied machine learning before, you have probably already solved or attempted to solve a problem like the following: learning a model to solve one specific task, for example, to distinguish cats from dogs or to teach an agent to find its way through a specific maze. In these settings, if we are able to define a loss \(\mathcal{L}_\tau\) for our task \(\tau\), which depends on the parameters \(\phi\) of a model, we can express our learning objective as \[ \underset{\phi}{\text{min}} \, \mathcal{L}_\tau (\phi) .\] We usually search for the optimal \(\phi\) by repeatedly stepping in the direction of the negative gradient of \(\mathcal{L}_\tau\) with respect to \(\phi\), i.e. \[ \phi \leftarrow \phi - \alpha \nabla_\phi \mathcal{L}_\tau (\phi) ,\] also known as gradient descent. \(\mathcal{L}_\tau\) usually also depends on some data, and \(\alpha\) is a fixed learning rate that controls the size of the steps we want to take.
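To make the update rule concrete, here is a minimal sketch in plain Python. The one-parameter linear model, the toy data, and the names `loss`, `grad` and `gradient_descent` are illustrative assumptions and do not appear in the original article.

```python
def loss(phi, data):
    """Mean squared error of the one-parameter model y = phi * x."""
    return sum((phi * x - y) ** 2 for x, y in data) / len(data)


def grad(phi, data):
    """Analytic derivative of `loss` with respect to phi."""
    return sum(2.0 * x * (phi * x - y) for x, y in data) / len(data)


def gradient_descent(phi, data, alpha=0.1, steps=100):
    """Repeat phi <- phi - alpha * grad(phi) for a fixed number of steps."""
    for _ in range(steps):
        phi = phi - alpha * grad(phi, data)
    return phi


# Toy data for a single task, generated by y = 2 * x.
data = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)]
phi_star = gradient_descent(0.0, data)
```

With enough data for a single parameter, `phi_star` ends up very close to the slope 2 that generated the data. The few-shot problem discussed next arises when the model has far more parameters than there are data points.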

Unfortunately, when applied in a few-shot setting (i.e., with a very small dataset), for example to a regression problem, the above method is known to perform poorly on models such as neural networks, since there is simply too little data for too many parameters, which leads to overfitting. The key idea of MAML is to mitigate this problem by learning not only from the data regarding exactly our task but also from data of similar tasks. To incorporate this, we make an additional assumption, namely that \(\tau\) comes from some distribution of tasks \(p(\tau)\) and that we can sample freely from this distribution. Eventually we want to use the data available from the other tasks in the distribution to be able to converge to a specific task \(\tau_i \sim p(\tau)\), which we can express in terms of an expectation over the distribution: \[ \underset{\phi}{\text{min}} \, \mathbb{E}_\tau [ \mathcal{L}_\tau (\phi_\tau) ] ,\] where \(\tau\) is now a random variable and \(\phi_\tau\) is a set of parameters for task \(\tau\). We may use different parameters for each task, use the same parameters for every task, or do something in between.

But one piece is still missing. We will not simply use the data from other tasks to find parameters that are optimal for all tasks, but keep the option to fine-tune our model, i.e. to take additional optimizer steps on data from the new task \(\tau_i\). After fine-tuning, we want the model to have converged to \(\tau_i\), while the version from before fine-tuning is reused as the starting point for each new task. Thus, we can express our optimization objective as \[ \underset{\theta}{\text{min}} \, \mathbb{E}_\tau [ \mathcal{L}_\tau (U_\tau(\theta)) ] ,\] where \(U_\tau\) is an optimization algorithm that maps \(\theta\) to a new parameter vector, namely the result of fine-tuning \(\theta\) on data from task \(\tau\). For the rest of this article we assume that \(U_\tau\) corresponds to performing gradient descent with a variable number of steps, but do not let this limit your imagination of what algorithm \(U_\tau\) could be.

A word on terminology: In conventional machine learning settings we consider trainable parameters that are tied to our task. However, the \(\theta\) in the above objective is learned with respect to a variety of tasks. This, together with the fact that it can be regarded as the initialization of the optimizer \(U_\tau\), lets us interpret \(\theta\) as sitting above the task level, which gives it the status of a meta-parameter. Consequently, optimizing such a meta-parameter corresponds to meta-learning.

Having set the above objective, we are already half-way there. The only thing that is left is to find a feasible optimizer for \(\theta\). Before we jump into how MAML solves this problem, we are going to take a look at a simple baseline, which will help us to digest the setting a bit better and which leads us directly to MAML.

Part 1: A simple baseline

Recalling our optimization objective \[ \underset{\theta}{\text{min}} \, \mathbb{E}_\tau [ \mathcal{L}_\tau (U_\tau(\theta)) ] ,\] the following approach avoids dealing with \(U_\tau\), mostly by ignoring that it exists, which makes the objective collapse to \[ \underset{\theta}{\text{min}} \, \mathbb{E}_\tau [ \mathcal{L}_\tau (\theta) ] ,\] i.e. the standard machine learning setting that we have already talked about. As a consequence, we are no longer operating on a few samples of a single task but have a whole distribution of tasks at our disposal, and we can hence reliably solve the simplified objective with gradient descent.

Having omitted \(U_\tau\), we hope that the final \(\theta\) is positioned such that fine-tuning it on some task \(\tau_i\) from the distribution converges to the optimal parameters for \(\tau_i\) with only a few samples. That might seem naive, considering that we did not reason about why \(U_\tau\) might be disregarded, but simply disregarded it. On the other hand, it is not called a "simple baseline" for no reason.

Expectations are commonly approximated by an empirical mean over samples from the respective distribution, an approach also known as Empirical Risk Minimization (ERM), and that is precisely what we can do here as well: \[ \theta \leftarrow \theta - \alpha \nabla_\theta \sum_i \mathcal{L}_{\tau_i} (\theta) .\]

Finn et al., the authors of MAML, call this type of baseline the pretrained model: we can simply pretrain over all available data and defer the problem of dealing with \(U_\tau\). Now, we can make use of the pretrained model, by simply fine-tuning the final \(\theta\), the result of our pretraining, on a new task - which is exactly what we will do a bit further down!

Moving on, we will take a little detour and talk about some implementation aspects of the pretrained baseline. It will also serve us as a starting point to later implement MAML. Afterwards, we have prepared a small interactive experiment where you can watch the pretrained model fail.

Implementing the Pretrained Model

If the above has gotten all too theoretical for you, take a look at the following gist. It contains a simplistic implementation of an update step for this pretrained model. It is implemented to emphasize that even if we differentiate between tasks when sampling the batch, the actual optimizer treats each sample the same.

The implementation is agnostic to the choice of optimizer. We use the Adam optimizer to be congruent with the original paper.
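A minimal sketch of such an update step could look as follows. It is not the implementation from the gist: it uses plain gradient descent instead of Adam, a hypothetical two-parameter model \(\theta_0 \sin(x) + \theta_1 \cos(x)\) instead of a neural network, and assumed sampling ranges for amplitude, phase and inputs. All function names are made up for illustration. The point it tries to show is that tasks matter only while sampling; once the points are pooled, the update sees nothing but \((x, y)\) pairs.

```python
import math
import random


def predict(theta, x):
    """Toy model (an assumption): theta[0] * sin(x) + theta[1] * cos(x)."""
    return theta[0] * math.sin(x) + theta[1] * math.cos(x)


def sample_task_batch(rng, n_tasks=4, k=5):
    """Sample n_tasks sinusoid tasks, draw k points from each, and pool them."""
    pooled = []
    for _ in range(n_tasks):
        amplitude = rng.uniform(0.1, 5.0)   # assumed range
        phase = rng.uniform(0.0, math.pi)   # assumed range
        for _ in range(k):
            x = rng.uniform(-5.0, 5.0)      # assumed input range
            pooled.append((x, amplitude * math.sin(x + phase)))
    return pooled  # task identity is gone: only (x, y) pairs remain


def mse(theta, points):
    """Mean squared error over the pooled points."""
    return sum((predict(theta, x) - y) ** 2 for x, y in points) / len(points)


def mse_grad(theta, points):
    """Gradient of `mse` with respect to both entries of theta."""
    g0 = sum(2.0 * (predict(theta, x) - y) * math.sin(x) for x, y in points)
    g1 = sum(2.0 * (predict(theta, x) - y) * math.cos(x) for x, y in points)
    return [g0 / len(points), g1 / len(points)]


def pretrain_step(theta, points, alpha=0.05):
    """One plain gradient-descent step on the pooled batch."""
    g = mse_grad(theta, points)
    return [t - alpha * gi for t, gi in zip(theta, g)]


rng = random.Random(0)
theta = [0.0, 0.0]
batch = sample_task_batch(rng)
new_theta = pretrain_step(theta, batch)
```

Repeating `sample_task_batch` and `pretrain_step` in a loop would give the pretraining procedure. Swapping the plain step for Adam only changes how the gradient is turned into an update, not how the batch is formed.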

Pretrained Model on a Sinusoid Problem (Regression)

In the following figure, you can experiment with a pretrained model trained on a collection of sinusoid regression tasks. The task distribution works as follows: Each task is represented by an amplitude \(A\) and a phase \(\varphi\) and requires the prediction of a sinusoid \(f\): \[ f(x) := A \sin(x + \varphi),\] where \(A, \varphi\) are sampled uniformly from some predefined range. Different parameters yield different functions, \(f_1\) and \(f_2\), with possibly completely different function values and gradients. Take, for example, the following two tasks: Tasks \(\tau_1, \tau_2\) are both regression tasks on the sinusoids \(f_1(x) := \sin (x - \frac{\pi}{2})\) and \(f_2(x) := \sin (x + \frac{\pi}{2})\) respectively. The function values of these two tasks give completely contradictory information, as \[ f_1(x) = - f_2(x). \]
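The contradiction between the two example tasks can be checked numerically. The sketch below is only an illustration; the helper `make_sinusoid_task` and the evaluation points are assumptions.

```python
import math


def make_sinusoid_task(amplitude, phase):
    """Return the target function f(x) = amplitude * sin(x + phase) of one task."""
    def f(x):
        return amplitude * math.sin(x + phase)
    return f


f1 = make_sinusoid_task(1.0, -math.pi / 2)
f2 = make_sinusoid_task(1.0, math.pi / 2)

xs = [-3.0, -1.0, 0.0, 0.5, 2.0]

# f1(x) + f2(x) vanishes (up to rounding) at every input, i.e. f1 = -f2.
max_gap = max(abs(f1(x) + f2(x)) for x in xs)

# The pointwise average of both targets is therefore the zero function.
average = [(f1(x) + f2(x)) / 2.0 for x in xs]
```

If both tasks contribute the same inputs to a pooled batch, a squared-error fit is pulled towards this average, which matches neither task.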

Before fitting the model, what do you expect to happen based on the position and the number of samples provided? Feel free to also experiment with the different settings: distributing the samples equispaced or squeezing all of them to a small range of the x-axis.

Ouch! That does not seem to work that well. Maybe you have already guessed that it would have been too easy. Remember our interpretation of what happens when omitting \(U_\tau\)? We said that we expect that the \(\theta\) minimizing the simplified objective can be fine-tuned easily to any task from the distribution. But as it seems, finding such a \(\theta\) this way is either impossible or, at the very least, incredibly difficult.

Let us recall the original optimization objective, i.e., \[ \underset{\theta}{\text{min}} \, \mathbb{E}_\tau [ \mathcal{L}_\tau (U_\tau(\theta)) ] .\] We can augment this notation by giving \(U_\tau\) a superscript, i.e. write \(U^{(m)}_\tau\), indicating that we perform \(m\) steps of gradient descent. Then we recover the simplified objective of the pretrained model by setting \(m = 0\), as \[ U^{(0)}_\tau(\theta) = \theta .\] We have already seen that for \(m = 0\) the loss space with respect to some task samples becomes \[ \sum_i \mathcal{L}_{\tau_i} (\theta) ,\] i.e. a simple sum of loss spaces. The following figure explores this representation visually, by letting you control \(m\) to see how the resulting loss space changes. From now on we have to carefully distinguish between the task loss spaces, which are defined by the individual \(\mathcal{L}_{\tau_1}\), ..., \(\mathcal{L}_{\tau_n}\) for tasks \(\tau_1\), ..., \(\tau_n\), and the accumulated loss space, defined by \[ \sum_i \mathcal{L}_{\tau_i} (U^{(m)}_{\tau_i}(\theta)) .\]
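To see how \(m\) reshapes the accumulated loss, consider a small sketch with two hypothetical one-dimensional tasks whose losses are quadratics of the form \(a (\theta - c)^2\). The task definitions, the inner learning rate, and all function names are assumptions chosen for illustration; they are not taken from the figure.

```python
def task_loss(theta, curvature, optimum):
    """Toy task loss: curvature * (theta - optimum) ** 2."""
    return curvature * (theta - optimum) ** 2


def task_grad(theta, curvature, optimum):
    """Derivative of `task_loss` with respect to theta."""
    return 2.0 * curvature * (theta - optimum)


def fine_tune(theta, task, m, alpha=0.1):
    """U^(m): m plain gradient-descent steps on one task, starting at theta."""
    for _ in range(m):
        theta = theta - alpha * task_grad(theta, *task)
    return theta


def accumulated_loss(theta, tasks, m):
    """Sum over tasks of the task loss evaluated after m fine-tuning steps."""
    return sum(task_loss(fine_tune(theta, task, m), *task) for task in tasks)


# Two hypothetical tasks as (curvature, optimum): a flat one and a steep one.
tasks = [(1.0, -1.0), (4.0, 1.0)]
grid = [i / 100.0 for i in range(-200, 201)]

# Minimizer of the accumulated loss on the grid, without and with one step.
best_m0 = min(grid, key=lambda t: accumulated_loss(t, tasks, 0))
best_m1 = min(grid, key=lambda t: accumulated_loss(t, tasks, 1))
```

For \(m = 0\) the accumulated loss is the plain sum of the task losses, and its minimizer lies close to the optimum of the steep task (about 0.6 for these numbers). With a single fine-tuning step the steep task is already nearly solved from almost anywhere, so the minimizer moves towards the flat task (about -0.6). In this toy setting, the best initialization depends on whether fine-tuning is taken into account.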

Now that we have established the problems arising from omitting the fine-tuning function \(U_\tau\) from the optimization objective, we will turn to MAML. MAML does not disregard \(U_\tau\) but rather optimizes through it. In the next part we will see how that works.

Part 2: Model-Agnostic Meta-Learning

We will now study MAML in detail, trying to optimize the previously established few-shot learning objective \[ \underset{\theta}{\text{min}} \, \mathbb{E}_\tau [ \mathcal{L}_\tau (U^{(m)}_\tau(\theta)) ] ,\] for \(m > 0\). In short, MAML optimizes the same \(\theta\) as the pretrained model, but its optimization strategy acknowledges the effect of the fine-tuning function \(U_\tau\) on the accumulated loss space.

Outline of the Algorithm

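Let us jump right in and take a look at the three main steps of the method, given a (current) meta-parameter \(\theta\):
1. Sample a number of tasks \(\tau_i\) from \(p(\tau)\).
2. For each task \(\tau_i\), obtain \(\phi_i = U_{\tau_i}(\theta)\) by minimizing \(\mathcal{L}_{\tau_i, \text{train}}\) on a few training samples, starting from \(\theta\).
3. Update \(\theta\) by gradient descent such that it minimizes \(\mathcal{L}(\theta) := \sum_i \mathcal{L}_{\tau_i, \text{test}}(\phi_i)\).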

Note that \(\mathcal{L}_{\tau_i, \text{train}}\) and \(\mathcal{L}_{\tau_i, \text{test}}\) are two instances of the same loss function on the same task \( \tau_i \), evaluated on training and test data from this task respectively (though \( \tau_i \) changes while iterating over \(i \)). The easiest way to obtain \(\phi_i\) is to do a single step of gradient descent (\( \phi_i \) will not be optimal but most likely better than \( \theta \)): \[ \phi_i = \theta - \alpha \nabla_\theta \mathcal{L}_{\tau_i, \text{train}}(\theta).\] Further, updating \(\theta\) requires us to evaluate the gradient of the individual task losses on a set of test data. We obtain the gradient of the overall loss as follows: \[ \nabla_\theta \mathcal{L}(\theta) = \sum_{i} \nabla_\theta \mathcal{L}_{\tau_i, \text{test}}(\phi_i) .\] Note that \(\phi_i = U_{\tau_i}(\theta)\) depends on \(\theta\), which means that we have to take a gradient through the optimizer \(U\). We can then update \(\theta\) via gradient descent, using a new learning rate \(\beta\): \[ \theta' = \theta - \beta \nabla_\theta \mathcal{L}(\theta).\] And that 🥁... is more or less everything that comprises the original MAML algorithm.
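As a rough illustration of the single-step update and the summed test loss, here is a sketch for a scalar \(\theta\) and the hypothetical model \(y = \theta x\). For a scalar parameter, taking the gradient through the inner step reduces to the chain rule \(\nabla_\theta \mathcal{L}_{\tau_i, \text{test}}(\phi_i) = \mathcal{L}'_{\tau_i, \text{test}}(\phi_i) \, (1 - \alpha \mathcal{L}''_{\tau_i, \text{train}}(\theta))\). The model, the toy tasks, and all function names are assumptions made for this example.

```python
def mse(w, points):
    """Mean squared error of the one-parameter model y = w * x."""
    return sum((w * x - y) ** 2 for x, y in points) / len(points)


def mse_grad(w, points):
    """First derivative of `mse` with respect to w."""
    return sum(2.0 * x * (w * x - y) for x, y in points) / len(points)


def mse_hess(points):
    """Second derivative of `mse` with respect to w (constant for this model)."""
    return sum(2.0 * x * x for x, _ in points) / len(points)


def inner_step(theta, train, alpha):
    """phi_i = theta - alpha * dL_train/dtheta, a single fine-tuning step."""
    return theta - alpha * mse_grad(theta, train)


def meta_loss(theta, tasks, alpha):
    """L(theta) = sum_i L_test(phi_i)."""
    return sum(mse(inner_step(theta, train, alpha), test) for train, test in tasks)


def meta_grad(theta, tasks, alpha):
    """Derivative of `meta_loss`, using the chain rule through the inner step."""
    total = 0.0
    for train, test in tasks:
        phi = inner_step(theta, train, alpha)
        dphi_dtheta = 1.0 - alpha * mse_hess(train)
        total += mse_grad(phi, test) * dphi_dtheta
    return total


def maml_step(theta, tasks, alpha=0.05, beta=0.05):
    """theta' = theta - beta * dL/dtheta."""
    return theta - beta * meta_grad(theta, tasks, alpha)


# Hypothetical tasks: each is (train points, test points) drawn from y = slope * x.
tasks = [
    ([(1.0, 1.0), (2.0, 2.0)], [(3.0, 3.0)]),    # slope 1
    ([(1.0, 3.0), (-1.0, -3.0)], [(2.0, 6.0)]),  # slope 3
]
theta = 0.0
new_theta = maml_step(theta, tasks)
```

The factor `dphi_dtheta` is what distinguishes this from the pretrained baseline: it accounts for how a change in \(\theta\) changes the fine-tuned parameter \(\phi_i\). For models with many parameters this factor becomes a matrix involving second derivatives of the training loss.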

Implementing the Algorithm

However, a machine learning algorithm is not very useful unless we can execute it on a computer. While implementing the pretrained model was more or less straightforward, implementing MAML requires some more thinking. Firstly, computing \(\phi_i\) is still straightforward; simply call the optimization algorithm of your choice (as long as it is gradient-based). However, how do we then take the gradient through that optimization algorithm? It is actually not that complicated. Almost every modern machine learning framework (e.g., TensorFlow) can differentiate through Python code, as long as the computation is built from the framework's differentiable operations. Hence, if we can express our optimizer as such a Python function, TensorFlow can differentiate through it.
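To give a feeling for why this is possible, the sketch below differentiates through a one-step optimizer with a tiny forward-mode automatic differentiation class based on dual numbers. This is a toy stand-in and not how TensorFlow works internally (TensorFlow records operations and uses reverse-mode differentiation). The `Dual` class, the model \(y = w x\), and all function names are assumptions for illustration.

```python
class Dual:
    """Number of the form value + deriv * eps, where eps ** 2 == 0."""

    def __init__(self, value, deriv=0.0):
        self.value, self.deriv = value, deriv

    @staticmethod
    def lift(other):
        """Treat plain numbers as constants with zero derivative."""
        return other if isinstance(other, Dual) else Dual(other)

    def __add__(self, other):
        other = Dual.lift(other)
        return Dual(self.value + other.value, self.deriv + other.deriv)
    __radd__ = __add__

    def __sub__(self, other):
        other = Dual.lift(other)
        return Dual(self.value - other.value, self.deriv - other.deriv)

    def __rsub__(self, other):
        return Dual.lift(other) - self

    def __mul__(self, other):
        other = Dual.lift(other)
        return Dual(self.value * other.value,
                    self.value * other.deriv + self.deriv * other.value)
    __rmul__ = __mul__


def loss(w, points):
    """Mean squared error of y = w * x, written with +, - and * only."""
    total = 0.0
    for x, y in points:
        err = w * x - y
        total = total + err * err
    return total * (1.0 / len(points))


def loss_grad(w, points):
    """Hand-written derivative of `loss`, again using +, - and * only."""
    total = 0.0
    for x, y in points:
        total = total + (w * x - y) * (2.0 * x)
    return total * (1.0 / len(points))


def fine_tune_then_test(theta, train, test, alpha=0.05):
    """The optimizer as a Python function: one inner step, then the test loss."""
    phi = theta - loss_grad(theta, train) * alpha
    return loss(phi, test)


train, test = [(1.0, 1.0), (2.0, 2.0)], [(3.0, 3.0)]

# Seeding deriv=1.0 asks for the derivative with respect to theta.
out = fine_tune_then_test(Dual(0.0, 1.0), train, test)
```

Here `out.value` holds the test loss after fine-tuning and `out.deriv` its derivative with respect to \(\theta\), taken through the inner update. For the data above, they should be close to 5.0625 and -10.125. The function `fine_tune_then_test` itself is ordinary Python and also accepts plain floats; only the number type passed in changes.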

Below you find a gist that implements a simplistic version of the MAML update step. The optimizer is encoded within the function fastWeights, but the function also directly applies an input tensor to the optimized weights. We did this mainly for simplicity, but if you are interested in a thorough reasoning about this design choice, you can read more about it in the comments under the gist.

Will it really work?

Before we study the MAML model on the sinusoid task distribution, let us spend some time on trying to see MAML in action. Consider the problem in the figure below. As already established, our few-shot optimization objective is to learn an optimal meta-parameter \(\theta\) from which we can fine-tune to any task and converge with only a few samples from that task. The figure considers a task distribution of two different tasks and lets you move \(\theta\) around to make sure, in the spirit of MAML, that a single step of fine-tuning results in nearly optimal parameters for each task. Optimizing \(\theta\) not directly on the tasks, as the pretrained model would, but in a way that respects the role of \(\theta\) as the initialization of the fine-tuning algorithm is what makes MAML both elegant and effective.

Returning to Sinusoids

After studying the math behind the MAML objective, as well as its intuition and implementation, it is time to evaluate it on the sinusoid example. Hopefully, MAML will produce better results than the pretrained model. You will have the opportunity to repeat the above experiments on a model that has been trained with MAML in this figure. Try to compare the optimization behavior of both the pretrained model and MAML and evaluate for yourself whether you think the MAML-trained model has found a good meta-initialization parameter \(\theta\).

As you were hopefully able to verify, MAML produces results that are much closer to the actual sinusoid, despite being exposed to at most five samples.

The rest of this article is dedicated to introducing interesting variants of MAML. The next page starts with a general discussion about the difficulty of obtaining the MAML-meta-gradient, which leads directly to FOMAML, a simple first-order version of MAML. A slightly different first-order approach, but still in the spirit of MAML, is Reptile, which obtains meta-knowledge without an explicit meta-gradient. Lastly, iMAML approximates the meta-update by creating a dependency between task-loss and meta-parameter \(\theta\) and thereby bypasses some of the computationally more expensive parts of the original MAML.