FO-MAML-Quadratic

mail@pastecode.io avatar
unknown
python
3 years ago
2.1 kB
7
Indexable
Never
import torch
import matplotlib.pyplot as plt
from copy import deepcopy

# torch.manual_seed(0)

# Dataset: three quadratic-function regression tasks.
# Each task has 12 points: the first 9 are the support set, the last 3 the query set.
# BUGFIX: the original used `^` (bitwise XOR in Python) instead of `**`, so the
# targets were not quadratics at all; also task2/task3 now iterate their OWN x lists.
task1 = {}
task1['x'] = [1, 2, 3, -3, -2, -1, 8, 9, -8, -9, 15, -15]
task1['y'] = [4*x**2 + 3*x + 2 for x in task1['x']]  # y = 4x^2 + 3x + 2

task2 = {}
task2['x'] = [1, 2, 3, -3, -2, -1, 8, 9, -8, -9, 15, -15]
task2['y'] = [4*x**2 - 2*x + 4 for x in task2['x']]  # y = 4x^2 - 2x + 4

task3 = {}
task3['x'] = [1, 2, 3, -3, -2, -1, 8, 9, -8, -9, 15, -15]
task3['y'] = [4*x**2 + 8*x - 12 for x in task3['x']]  # y = 4x^2 + 8x - 12

# Learning rates: alpha is the outer (meta) step size, beta the inner (adaptation) step size.
alpha = 0.00003
beta = 0.0001

# Meta-model and its per-task clones: model_init holds the meta-parameters,
# model_dash_es holds one adapted copy per task (3 tasks).
model_init = torch.nn.Sequential(
    torch.nn.Linear(1, 5),
    torch.nn.ReLU(),
    torch.nn.Linear(5, 1),
)
model_dash_es = [deepcopy(model_init) for _ in range(3)]

# Per-iteration meta-loss history, recorded for the learning-curve plot.
loss_list = []

# FO-MAML outer (meta) loop. First-order MAML: the inner adaptation step is
# taken under torch.no_grad(), so no second-order gradients flow through it;
# the query-set gradients w.r.t. the ADAPTED parameters are applied directly
# to the meta-parameters.
mse = torch.nn.MSELoss()  # hoisted: no need to rebuild the criterion every step

for i in range(5000):
    meta_loss = 0.0
    grads_list = []  # per-task query-set gradient tuples

    # Inner loop: adapt one clone per task from the current meta-parameters.
    for j, task in enumerate([task1, task2, task3]):
        model_dash = model_dash_es[j]

        # Support-set loss (first 9 points) evaluated at the meta-parameters.
        y_preds = model_init(torch.FloatTensor(task['x'][:9]).unsqueeze(1))
        loss = mse(y_preds, torch.FloatTensor(task['y'][:9]).unsqueeze(1))
        # NOTE: the original also called zero_grad() here, but autograd.grad()
        # never populates .grad, so those calls were dead code and are removed.
        grads = torch.autograd.grad(loss, model_init.parameters())

        # Inner optimization step: dash = init - beta * grad.
        # copy_() replaces the original's convoluted in-place trick
        # (param_dash += -param_dash + param_init - beta*grad).
        with torch.no_grad():
            for param_init, param_dash, grad in zip(model_init.parameters(), model_dash.parameters(), grads):
                param_dash.copy_(param_init - beta * grad)

        # Query-set loss (last 3 points) evaluated at the adapted parameters.
        y_preds = model_dash(torch.FloatTensor(task['x'][9:]).unsqueeze(1))
        query_loss = mse(y_preds, torch.FloatTensor(task['y'][9:]).unsqueeze(1))
        meta_loss += query_loss.item()
        grads_list.append(torch.autograd.grad(query_loss, model_dash.parameters()))

    loss_list.append(meta_loss)
    print('meta_loss:', meta_loss)

    # Outer optimization: init -= alpha * SUM of per-task query gradients.
    # (The original comment said "average", but the code never divides by 3;
    # the 1/3 factor is equivalent to rescaling alpha, so the sum is kept.)
    with torch.no_grad():
        for (param_init, grad1, grad2, grad3) in zip(model_init.parameters(), grads_list[0], grads_list[1], grads_list[2]):
            param_init -= alpha * (grad1 + grad2 + grad3)

plt.plot(loss_list)
plt.show()