
MAML: Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks

ICML, ‘17

Summary

MAML is a general, model-agnostic algorithm that can be applied directly to any model trained with gradient descent.
MAML does not expand the number of learned parameters.
MAML does not place constraints on the model architecture.

Key words

Model agnostic
Fast adaptation
Optimization based approach
Learning good model parameters

Preliminaries

Common Approaches of Meta-Learning and MAML
Some terminology for meta-learning problems

1. Introduction

Goal of ideal artificial agent:
Learning and adapting quickly from only a few examples.
To do so, an agent must..
Integrate its prior experience with a small amount of new information.
Avoid overfitting to the new data.
→ Meta-learning shares the same goals.
MAML:
"The key idea of MAML is to train the model's initial parameters such that the model has maximal performance on a new task after the parameters have been updated through one or more gradient steps computed with a small amount of data from that new task."
Learning process of MAML:
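MAML maximizes the sensitivity of the loss functions of new tasks with respect to the parameters, so that small changes to the parameters produce large improvements in the task loss.
The authors demonstrated the algorithm in three problem domains: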
Few-shot regression
Image classification
Reinforcement learning

2. Model-Agnostic Meta Learning

2.1. Meta-Learning Problem Set-Up

To apply MAML to a variety of learning problems, authors introduce a generic notion of a learning task:
$\mathcal{T} = \{ \mathcal{L}(\mathbf{x}_1, \mathbf{a}_1, \ldots, \mathbf{x}_H, \mathbf{a}_H),\ q(\mathbf{x}_1),\ q(\mathbf{x}_{t+1} \mid \mathbf{x}_t, \mathbf{a}_t),\ H \}$
Each task $\mathcal{T}$ consists of..
$\mathcal{L}$: a loss function, might be misclassification loss or a cost function in a Markov decision process
$q(\mathbf{x}_1)$: a distribution over initial observations
$q(\mathbf{x}_{t+1} \mid \mathbf{x}_t, \mathbf{a}_t)$: a transition distribution
$H$: an episode length (e.g. in i.i.d. supervised learning problems, the length $H = 1$)
Authors consider a distribution over tasks $p(\mathcal{T})$
Meta-training:
A new task $\mathcal{T}_i$ is sampled from $p(\mathcal{T})$.
The model is trained with only $K$ samples drawn from $q_i$.
The loss $\mathcal{L}_{\mathcal{T}_i}$ is computed and fed back to the model.
The model $f$ is tested on new samples from $\mathcal{T}_i$.
The model $f$ is then improved by considering how the test error on new data from $q_i$ changes with respect to the parameters.
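As a concrete illustration of this set-up, the sketch below samples a task and its K-shot data using the sinusoid regression family from the paper's experiments (amplitude in [0.1, 5.0], phase in [0, π], inputs in [-5, 5]). The function names `sample_task` and `sample_points`, and the exact form `amplitude * sin(x + phase)`, are assumptions made for illustration.

```python
import math
import random

def sample_task(rng):
    """Draw one task T_i ~ p(T): a sine wave with random amplitude and phase."""
    amplitude = rng.uniform(0.1, 5.0)
    phase = rng.uniform(0.0, math.pi)
    return amplitude, phase

def sample_points(task, k, rng):
    """Draw k i.i.d. (x, y) pairs from the task's q_i (episode length H = 1)."""
    amplitude, phase = task
    xs = [rng.uniform(-5.0, 5.0) for _ in range(k)]
    return [(x, amplitude * math.sin(x + phase)) for x in xs]

rng = random.Random(0)
task = sample_task(rng)
train_points = sample_points(task, k=5, rng=rng)  # K samples used for adaptation
test_points = sample_points(task, k=5, rng=rng)   # fresh samples used to measure test error
```

Keeping the adaptation samples and the evaluation samples separate is what lets the test error act as the training signal for the meta-learner.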

2.2. A Model-Agnostic Meta-Learning Algorithm

Intuition: Some internal representations are more transferrable than others. How can we encourage the emergence of such general-purpose representations?
A model $f_\theta$ has parameters $\theta$.
For each task $\mathcal{T}_i$, the parameters $\theta$ of $f_\theta$ are adapted to task-specific parameters $\theta_i'$.
Algorithm
cf) Terminology for the description below (provisionally defined by JH Gu)
Divide tasks
1. Separate the tasks into a meta-training task set ($\{\mathcal{T}_i^{\text{tr}}\}$) and a meta-test task set ($\{\mathcal{T}_i^{\text{test}}\}$).
(We can think of $\{\mathcal{T}_i^{\text{tr}}\}$ as the monthly mock exams, and $\{\mathcal{T}_i^{\text{test}}\}$ as the annual college entrance exam (수능, the Korean CSAT).)
2. For each task, divide its samples into $\mathcal{D}_{\mathcal{T}_i}^{\text{study}}$ (task-specific samples for studying, also called the support set) and $\mathcal{D}_{\mathcal{T}_i}^{\text{check}}$ (task-specific samples for checking, also called the query set).
(We can think of $\mathcal{D}_{\mathcal{T}_i}^{\text{study}}$ as the worked examples in the math textbook 수학의 정석, and $\mathcal{D}_{\mathcal{T}_i}^{\text{check}}$ as the practice exercises in the same book.)
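The two splits described above can be sketched as follows. This is a minimal illustration on made-up toy data; `split_tasks`, `split_samples`, and the task family `y = slope * x` are hypothetical and not taken from the paper.

```python
import random

def split_tasks(tasks, n_meta_test, rng):
    """Shuffle the tasks and hold out n_meta_test of them as the meta-test task set."""
    shuffled = list(tasks)
    rng.shuffle(shuffled)
    return shuffled[n_meta_test:], shuffled[:n_meta_test]

def split_samples(samples, k):
    """Use the first k samples as the study (support) set, the rest as the check (query) set."""
    return samples[:k], samples[k:]

rng = random.Random(0)
# Hypothetical toy data: 10 tasks, each with 8 (x, y) samples from y = slope * x.
tasks = {slope: [(x, slope * x) for x in range(8)] for slope in range(1, 11)}
meta_train_ids, meta_test_ids = split_tasks(tasks.keys(), n_meta_test=2, rng=rng)
study, check = split_samples(tasks[meta_train_ids[0]], k=5)
```

The first split is across tasks and the second is within each task, so a meta-test task contributes neither study nor check samples during meta-training.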
Meta-training using the meta-training task set $\{\mathcal{T}_i^{\text{tr}}\}$
Inner loop (task-specific $K$-shot learning)
For each $\mathcal{T}_i$ in $\{\mathcal{T}_i^{\text{tr}}\}$, a new parameter $\theta_i'$ is created.
1. Each $\theta_i'$ is initialized as $\theta$.
2. With the task-specific samples for studying ($\mathcal{D}_{\mathcal{T}_i^{\text{tr}}}^{\text{study}}$), each $\theta_i'$ is updated by:
$\theta_i' = \theta - \alpha \nabla_\theta \mathcal{L}_{\mathcal{T}_i}(f_\theta)$
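Outer loop (meta-learning across tasks)
1. With the task-specific samples for checking ($\mathcal{D}_{\mathcal{T}_i^{\text{tr}}}^{\text{check}}$), $\theta$ is updated by:
$\theta \leftarrow \theta - \beta \nabla_\theta \sum_{\mathcal{T}_i \sim p(\mathcal{T})} \mathcal{L}_{\mathcal{T}_i}(f_{\theta_i'})$
cf) Second-order derivative (Hessian) problem: because each $\theta_i'$ is itself a function of $\theta$, the meta-gradient differentiates through the inner gradient step, which requires Hessian-vector products. The paper also evaluates a first-order approximation that omits these second-derivative terms.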
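To see where the second derivative enters, the outer-loop gradient for a single task and a single inner step can be expanded with the chain rule. This is a sketch using the same symbols as the two update rules; $I$ denotes the identity matrix, and the transpose on the Jacobian is dropped because the Hessian is symmetric.

```latex
\begin{aligned}
\theta_i' &= \theta - \alpha \nabla_\theta \mathcal{L}_{\mathcal{T}_i}(f_\theta) \\
\nabla_\theta \mathcal{L}_{\mathcal{T}_i}(f_{\theta_i'})
  &= \frac{\partial \theta_i'}{\partial \theta} \, \nabla_{\theta_i'} \mathcal{L}_{\mathcal{T}_i}(f_{\theta_i'}) \\
  &= \left( I - \alpha \nabla_\theta^2 \mathcal{L}_{\mathcal{T}_i}(f_\theta) \right) \nabla_{\theta_i'} \mathcal{L}_{\mathcal{T}_i}(f_{\theta_i'})
\end{aligned}
```

The factor $\alpha \nabla_\theta^2 \mathcal{L}_{\mathcal{T}_i}(f_\theta)$ is the only place second derivatives appear. Note that the inner loss is evaluated on the study set and the outer loss on the check set, although both are written $\mathcal{L}_{\mathcal{T}_i}$ here.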
Measure model performance using the meta-test task set $\{\mathcal{T}_i^{\text{test}}\}$
1. For each $\mathcal{T}_i$ in $\{\mathcal{T}_i^{\text{test}}\}$, adjust the task-specific parameters with $\mathcal{D}_{\mathcal{T}_i^{\text{test}}}^{\text{study}}$.
2. Test the performance with $\mathcal{D}_{\mathcal{T}_i^{\text{test}}}^{\text{check}}$.
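The whole procedure (inner loop, outer loop, then meta-test) can be sketched end to end on a toy problem. The example below is an illustration under strong assumptions, not the paper's implementation. The model is a hypothetical one-parameter regressor f_θ(x) = θ·x, the task family is y = slope·x with slope drawn from U(2, 4), gradients are written out by hand instead of using automatic differentiation, and the values of `ALPHA`, `BETA`, and `K` are arbitrary.

```python
import random

ALPHA, BETA, K = 0.05, 0.01, 5  # inner step size, meta step size, shots (illustrative values)

def loss(theta, data):
    """Mean squared error of f_theta(x) = theta * x."""
    return sum((theta * x - y) ** 2 for x, y in data) / len(data)

def grad(theta, data):
    """First derivative of the loss with respect to theta."""
    return sum(2 * x * (theta * x - y) for x, y in data) / len(data)

def hessian(data):
    """Second derivative of the loss (constant in theta for this model)."""
    return sum(2 * x * x for x, _ in data) / len(data)

def adapt(theta, study):
    """Inner loop: one gradient step on the study (support) set."""
    return theta - ALPHA * grad(theta, study)

def sample_task_data(rng):
    """Hypothetical task family: y = slope * x with slope ~ U(2, 4)."""
    slope = rng.uniform(2.0, 4.0)
    points = [(x, slope * x) for x in (rng.uniform(-2, 2) for _ in range(2 * K))]
    return points[:K], points[K:]  # study set, check set

def meta_train(theta, rng, steps=500, tasks_per_step=4):
    for _ in range(steps):
        meta_grad = 0.0
        for _ in range(tasks_per_step):
            study, check = sample_task_data(rng)
            theta_i = adapt(theta, study)
            # Chain rule through the inner step: d(theta_i)/d(theta) = 1 - ALPHA * hessian
            meta_grad += (1 - ALPHA * hessian(study)) * grad(theta_i, check)
        theta -= BETA * meta_grad  # outer loop update of the shared initialization
    return theta

rng = random.Random(0)
theta = meta_train(0.0, rng)

# Meta-test: adapt on the study set of an unseen task, evaluate on its check set.
study, check = sample_task_data(rng)
loss_before = loss(theta, check)
loss_after = loss(adapt(theta, study), check)
```

Replacing the factor `(1 - ALPHA * hessian(study))` with `1` gives the first-order variant, which skips the second-derivative computation.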

3. Species of MAML

3.1. Supervised Regression and Classification

Algorithm
Formalizing supervised regression and classification
Horizon $H = 1$
Drop the timestep subscript on $\mathbf{x}_t$ (since the model accepts a single input and produces a single output)
The task $\mathcal{T}_i$ generates $K$ i.i.d. observations $\mathbf{x}$ from $q_i$
The task loss is the error between the model's output for $\mathbf{x}$ and the corresponding target values $\mathbf{y}$.
Loss functions
MSE for regression
$\mathcal{L}_{\mathcal{T}_i}(f_\phi) = \sum_{\mathbf{x}^{(j)}, \mathbf{y}^{(j)} \sim \mathcal{T}_i} \| f_\phi(\mathbf{x}^{(j)}) - \mathbf{y}^{(j)} \|_2^2$
Cross-entropy loss for discrete classification
$\mathcal{L}_{\mathcal{T}_i}(f_\phi) = -\sum_{\mathbf{x}^{(j)}, \mathbf{y}^{(j)} \sim \mathcal{T}_i} \big\{ \mathbf{y}^{(j)} \log f_\phi(\mathbf{x}^{(j)}) + (1-\mathbf{y}^{(j)}) \log(1-f_\phi(\mathbf{x}^{(j)})) \big\}$
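Both losses can be sketched in plain Python for scalar outputs. The cross-entropy is written in its standard binary form, with a leading minus sign so that lower values are better. The function names and the clipping constant `eps` are illustrative choices, not part of the paper.

```python
import math

def mse_loss(predictions, targets):
    """Sum of squared errors over the samples of one task (scalar outputs)."""
    return sum((p - y) ** 2 for p, y in zip(predictions, targets))

def binary_cross_entropy(probabilities, labels, eps=1e-12):
    """Summed binary cross-entropy; probabilities are f_phi(x) in (0, 1), labels are 0 or 1."""
    total = 0.0
    for p, y in zip(probabilities, labels):
        p = min(max(p, eps), 1.0 - eps)  # clip to avoid log(0)
        total -= y * math.log(p) + (1 - y) * math.log(1 - p)
    return total
```

For example, `mse_loss([1.0, 2.0], [1.0, 4.0])` evaluates to `4.0`, and `binary_cross_entropy([0.5], [1])` equals `math.log(2)`.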

3.2. Reinforcement Learning

Algorithm
Goal of MAML in RL:
Quickly acquire a policy for a new test task using only a small amount of experience in the test setting.
Formalizing RL
Each RL task $\mathcal{T}_i$ contains..
Initial state distribution $q_i(\mathbf{x}_1)$
Transition distribution $q_i(\mathbf{x}_{t+1} \mid \mathbf{x}_t, \mathbf{a}_t)$
$\mathbf{a}_t$: action
Loss $\mathcal{L}_{\mathcal{T}_i}$, which corresponds to the negative of the reward function $R$
Therefore, the entire task is a Markov decision process (MDP) with horizon $H$
The model being learned, $f_\theta$, is a policy that maps from states $\mathbf{x}_t$ to a distribution over actions $\mathbf{a}_t$ at each timestep $t \in \{1, \ldots, H\}$
Loss function for task $\mathcal{T}_i$ and model $f_\phi$:
$\mathcal{L}_{\mathcal{T}_i}(f_\phi) = -\mathbb{E}_{\mathbf{x}_t, \mathbf{a}_t \sim f_\phi, q_{\mathcal{T}_i}} \left[ \sum_{t=1}^H R_i(\mathbf{x}_t, \mathbf{a}_t) \right]$
Policy gradient method
Since the expected reward is generally not differentiable due to unknown dynamics, authors used policy gradient methods to estimate the gradient.
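The policy gradient method is an on-policy algorithm.
→ Each additional gradient step during adaptation needs new samples from the current policy, hence the extra trajectory-sampling procedures in steps 5 and 8 of the algorithm.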
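The on-policy requirement can be illustrated with a score-function (REINFORCE-style) gradient estimate on a tiny example. This is a sketch under simplifying assumptions: a hypothetical two-action task with horizon H = 1 and deterministic rewards, and a one-parameter sigmoid policy. None of the names below come from the paper.

```python
import math
import random

REWARDS = {0: 0.0, 1: 1.0}  # hypothetical two-action task with horizon H = 1

def action_prob(theta):
    """Policy f_theta: probability of choosing action 1 (a sigmoid of theta)."""
    return 1.0 / (1.0 + math.exp(-theta))

def sample_rollouts(theta, n, rng):
    """On-policy sampling: actions are drawn from the current policy."""
    p1 = action_prob(theta)
    return [1 if rng.random() < p1 else 0 for _ in range(n)]

def loss_gradient_estimate(theta, actions):
    """Score-function estimate of the derivative of L = -E[R] with respect to theta."""
    p1 = action_prob(theta)
    # For this sigmoid policy, d/d(theta) log pi(a) = a - p1.
    return -sum((a - p1) * REWARDS[a] for a in actions) / len(actions)

rng = random.Random(0)
theta, alpha = 0.0, 0.5
g = loss_gradient_estimate(theta, sample_rollouts(theta, 20000, rng))
theta_adapted = theta - alpha * g  # one inner-loop adaptation step

# Evaluating the adapted policy needs new rollouts sampled from it,
# not a reuse of the rollouts collected with the old parameters.
new_actions = sample_rollouts(theta_adapted, 20000, rng)
```

At `theta = 0` the exact gradient of the loss is -0.25, so the estimate `g` should land close to that value, and the adapted policy picks the rewarding action more often than the initial one.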

4. Comparison with related works

Comparison with other popular approaches
Training a meta-learner that learns how to update the parameters of the learner's model
ex) On the optimization of a synaptic learning rule(Bengio et al. 1992)
→ Requires additional parameters, while MAML does not.
Training to compare new examples in a learned metric space
ex) Siamese networks(Koch, 2015), recurrence with attention mechanisms(Vinyals et al. 2016)
→ Difficult to extend directly to other problem settings, such as reinforcement learning.
Training memory-augmented models
ex) Meta-learning with memory-augmented neural networks(Santoro et al. 2016)
The recurrent learner is trained to adapt to new tasks as it is rolled out.
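→ Fine-tuning with additional gradient steps is less straightforward than with MAML, which uses the same gradient descent update for both the learner and the meta-update.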

5. Experimental Evaluation

Three questions
1.
Can MAML enable fast learning of new tasks?
2.
Can MAML be used for meta-learning in multiple different domains?
3.
Can a model learned with MAML continue to improve with additional gradient updates and/or examples?

5.1. Regression

5.2. Classification

5.3. Reinforcement Learning

References