FO-MAML-Quadratic
python
import torch
import matplotlib.pyplot as plt
from copy import deepcopy

# torch.manual_seed(0)

# Dataset: three quadratic-regression tasks.
# For each task, the first 9 points are the support set and the last 3 the query set.
task1 = {}
task1['x'] = [1, 2, 3, -3, -2, -1, 8, 9, -8, -9, 15, -15]
task1['y'] = [4*x**2 + 3*x + 2 for x in task1['x']]   # y = 4x^2 + 3x + 2

task2 = {}
task2['x'] = [1, 2, 3, -3, -2, -1, 8, 9, -8, -9, 15, -15]
task2['y'] = [4*x**2 - 2*x + 4 for x in task2['x']]   # y = 4x^2 - 2x + 4

task3 = {}
task3['x'] = [1, 2, 3, -3, -2, -1, 8, 9, -8, -9, 15, -15]
task3['y'] = [4*x**2 + 8*x - 12 for x in task3['x']]  # y = 4x^2 + 8x - 12

alpha = 0.00003  # outer (meta) learning rate
beta = 0.0001    # inner (task-adaptation) learning rate

# One shared meta-initialization plus a per-task copy for the adapted parameters.
model_init = torch.nn.Sequential(torch.nn.Linear(1, 5),
                                 torch.nn.ReLU(),
                                 torch.nn.Linear(5, 1))
model_dash_es = [deepcopy(model_init), deepcopy(model_init), deepcopy(model_init)]

loss_list = []

# Outer loop: meta-training iterations.
for i in range(5000):
    meta_loss = 0
    grads_list = []

    # Inner loop: adapt to each task on its support set.
    for j, task in enumerate([task1, task2, task3]):
        model_dash = model_dash_es[j]
        y_preds = model_init(torch.FloatTensor(task['x'][:9]).unsqueeze(1))
        loss = torch.nn.MSELoss()(y_preds, torch.FloatTensor(task['y'][:9]).unsqueeze(1))
        model_init.zero_grad()
        model_dash.zero_grad()
        grads = torch.autograd.grad(loss, model_init.parameters())

        # Inner update: dash = init - beta*grad (one SGD step from the shared init).
        with torch.no_grad():
            for param_init, param_dash, grad in zip(model_init.parameters(),
                                                    model_dash.parameters(), grads):
                param_dash.copy_(param_init - beta*grad)

        # Query loss of the adapted model; its gradient w.r.t. the adapted
        # parameters is the first-order MAML meta-gradient.
        y_preds = model_dash(torch.FloatTensor(task['x'][9:]).unsqueeze(1))
        query_loss = torch.nn.MSELoss()(y_preds, torch.FloatTensor(task['y'][9:]).unsqueeze(1))
        meta_loss += query_loss.item()
        grads = torch.autograd.grad(query_loss, model_dash.parameters())
        grads_list.append(grads)

    loss_list.append(meta_loss)
    print('meta_loss:', meta_loss)

    # Outer update: init = init - alpha * (sum of per-task query gradients);
    # the 1/3 averaging factor is absorbed into alpha.
    with torch.no_grad():
        for (param_init, grad1, grad2, grad3) in zip(model_init.parameters(),
                                                     grads_list[0],
                                                     grads_list[1],
                                                     grads_list[2]):
            param_init -= alpha*(grad1 + grad2 + grad3)

plt.plot(loss_list)
plt.show()
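The payoff of FO-MAML is that the meta-learned model_init should adapt to an unseen quadratic from only the 9 support points and a few gradient steps. The snippet below is a minimal evaluation sketch, not part of the original script: the held-out task (y = 4x^2 + x - 5), the 10 adaptation steps, and reusing beta as the step size are all illustrative assumptions.

# --- Few-shot adaptation sketch (assumed evaluation, not in the original script) ---
# Hypothetical held-out task: y = 4x^2 + x - 5.
new_x = [1, 2, 3, -3, -2, -1, 8, 9, -8, -9, 15, -15]
new_y = [4*x**2 + x - 5 for x in new_x]

adapted = deepcopy(model_init)                       # start from the meta-learned init
opt = torch.optim.SGD(adapted.parameters(), lr=beta)

# A few SGD steps on the 9 support points, mirroring the inner loop above.
for _ in range(10):
    preds = adapted(torch.FloatTensor(new_x[:9]).unsqueeze(1))
    loss = torch.nn.MSELoss()(preds, torch.FloatTensor(new_y[:9]).unsqueeze(1))
    opt.zero_grad()
    loss.backward()
    opt.step()

# Report the query-set error after adaptation.
with torch.no_grad():
    preds = adapted(torch.FloatTensor(new_x[9:]).unsqueeze(1))
    print('query MSE after adaptation:',
          torch.nn.MSELoss()(preds, torch.FloatTensor(new_y[9:]).unsqueeze(1)).item())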