import torch
import matplotlib.pyplot as plt
from copy import deepcopy
import random

torch.manual_seed(0)

# Define dataset: three quadratic-function tasks, 100 points each.
# For each task, the first 50 points form the support set and the
# last 50 points form the query set.
multiplier = 10

task1 = {}
task1['x'] = [random.random() * multiplier for _ in range(100)]
A, B, C = 4, 3, 2
# A, B, C = random.random()*multiplier, random.random()*multiplier, random.random()*multiplier
task1['y'] = [A * (x ** 2) + B * x + C for x in task1['x']]  # y = 4x^2 + 3x + 2

task2 = {}
task2['x'] = [random.random() * multiplier for _ in range(100)]
A, B, C = 4, -2, 4
# A, B, C = random.random()*multiplier, random.random()*multiplier, random.random()*multiplier
task2['y'] = [A * (x ** 2) + B * x + C for x in task2['x']]  # y = 4x^2 - 2x + 4

task3 = {}
task3['x'] = [random.random() * multiplier for _ in range(100)]
A, B, C = 4, 8, -12
# A, B, C = random.random()*multiplier, random.random()*multiplier, random.random()*multiplier
task3['y'] = [A * (x ** 2) + B * x + C for x in task3['x']]  # y = 4x^2 + 8x - 12

alpha = 0.00003  # outer-loop (meta) learning rate
beta = 0.00001   # inner-loop (task-adaptation) learning rate

# Define the meta model and one adapted copy per task.
model_init = torch.nn.Sequential(torch.nn.Linear(1, 32),
                                 torch.nn.ReLU(),
                                 torch.nn.Linear(32, 1))
model_dash_es = [deepcopy(model_init), deepcopy(model_init), deepcopy(model_init)]

loss_list = []

# Outer loop
for i in range(700):
    meta_loss = 0
    grads_list = []
    # Inner loop: adapt a copy of model_init to each task's support set.
    for j, task in enumerate([task1, task2, task3]):
        model_dash = model_dash_es[j]
        y_preds = model_init(torch.FloatTensor(task['x'][:50]).unsqueeze(1))
        loss = torch.nn.MSELoss()(y_preds, torch.FloatTensor(task['y'][:50]).unsqueeze(1))
        model_init.zero_grad()
        model_dash.zero_grad()
        grads = torch.autograd.grad(loss, model_init.parameters())
        # Inner optimization: dash = init - beta * grad
        with torch.no_grad():
            for param_init, param_dash, grad in zip(model_init.parameters(),
                                                    model_dash.parameters(),
                                                    grads):
                param_dash.copy_(param_init - beta * grad)
        # Evaluate the adapted model on the query set.
        y_preds = model_dash(torch.FloatTensor(task['x'][-50:]).unsqueeze(1))
        query_loss = torch.nn.MSELoss()(y_preds, torch.FloatTensor(task['y'][-50:]).unsqueeze(1))
        meta_loss += query_loss.item() / 3  # MSELoss already averages over points; divide by number of tasks
        grads = torch.autograd.grad(query_loss, model_dash.parameters())
        grads_list.append(grads)
    loss_list.append(meta_loss)
    print('meta_loss:', meta_loss)
    # Outer optimization (first-order update): init = init - alpha * sum of query gradients
    with torch.no_grad():
        for param_init, grad1, grad2, grad3 in zip(model_init.parameters(),
                                                   grads_list[0],
                                                   grads_list[1],
                                                   grads_list[2]):
            param_init -= alpha * (grad1 + grad2 + grad3)

plt.plot(loss_list)
plt.show()
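
# --- Usage sketch (added; not in the original paste) ---
# After meta-training, model_init should adapt to an unseen quadratic task
# with only a few gradient steps. The task coefficients and step count below
# are illustrative assumptions, not values from the original script.
new_x = [random.random() * multiplier for _ in range(20)]
new_y = [4 * (x ** 2) + 1 * x - 3 for x in new_x]  # assumed unseen task: y = 4x^2 + x - 3

adapted = deepcopy(model_init)
optimizer = torch.optim.SGD(adapted.parameters(), lr=beta)
xs = torch.FloatTensor(new_x).unsqueeze(1)
ys = torch.FloatTensor(new_y).unsqueeze(1)
for _ in range(10):  # a few adaptation steps on the new task's data
    optimizer.zero_grad()
    adapt_loss = torch.nn.MSELoss()(adapted(xs), ys)
    adapt_loss.backward()
    optimizer.step()
print('adapted task loss:', adapt_loss.item())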