# First-order MAML (model-agnostic meta-learning) demo in PyTorch.
# Three quadratic regression tasks share one initialization; an inner loop
# adapts a per-task copy, and an outer loop updates the initialization from
# the query-set gradients of the adapted copies.
import torch
import matplotlib.pyplot as plt
from copy import deepcopy
import random

torch.manual_seed(0)

# Dataset: three quadratic regression tasks y = A*x^2 + B*x + C, 100 points
# each. The first 50 points serve as the support set and the last 50 as the
# query set (split consumed by the training loop below).
multiplier = 10  # inputs drawn uniformly from [0, multiplier)

# torch.manual_seed above does not seed Python's `random`; seed it too so the
# sampled task data is reproducible.
random.seed(0)


def _make_task(A, B, C, n=100):
    """Return {'x': [...], 'y': [...]} with n samples of y = A*x^2 + B*x + C."""
    xs = [random.random() * multiplier for _ in range(n)]
    # BUGFIX: the original computed task2/task3 targets from task1's inputs;
    # each task's y must come from its own x samples.
    return {'x': xs, 'y': [A * (x ** 2) + B * x + C for x in xs]}


task1 = _make_task(4, 3, 2)    # y = 4x^2 + 3x + 2
task2 = _make_task(4, -2, 4)   # y = 4x^2 - 2x + 4
task3 = _make_task(4, 8, -12)  # y = 4x^2 + 8x - 12

# Learning rates: alpha drives the outer (meta) update, beta the inner
# (per-task adaptation) step.
alpha = 0.00003
beta = 0.00001

# One shared initialization plus a per-task copy that will hold the
# inner-loop adapted weights for each of the three tasks.
model_init = torch.nn.Sequential(
    torch.nn.Linear(1, 32),
    torch.nn.ReLU(),
    torch.nn.Linear(32, 1),
)
model_dash_es = [deepcopy(model_init) for _ in range(3)]

loss_list = []  # mean query loss per outer iteration, collected for plotting
criterion = torch.nn.MSELoss()  # hoisted: no need to re-instantiate per step

# Pre-build the support/query tensors once — they are loop-invariant.
# BUGFIX: the original fed ALL 100 points (task['x'][:]) to the support pass,
# so the support set fully contained the query set; use the first 50 as
# support and the last 50 as query, matching the declared split.
_tasks = []
for _task in (task1, task2, task3):
    _tasks.append((
        torch.FloatTensor(_task['x'][:50]).unsqueeze(1),
        torch.FloatTensor(_task['y'][:50]).unsqueeze(1),
        torch.FloatTensor(_task['x'][-50:]).unsqueeze(1),
        torch.FloatTensor(_task['y'][-50:]).unsqueeze(1),
    ))

# Outer (meta) loop — first-order MAML.
for _ in range(700):
    meta_loss = 0.0
    grads_list = []  # per-task query gradients, one tuple per task

    # Inner loop: adapt a copy of model_init to each task.
    for j, (x_sup, y_sup, x_qry, y_qry) in enumerate(_tasks):
        model_dash = model_dash_es[j]

        # Support loss and its gradient w.r.t. the shared initialization.
        # (torch.autograd.grad does not touch .grad, so no zero_grad needed.)
        support_loss = criterion(model_init(x_sup), y_sup)
        grads = torch.autograd.grad(support_loss, model_init.parameters())

        # Inner step, written in place: dash = init - beta * grad.
        with torch.no_grad():
            for param_init, param_dash, grad in zip(
                    model_init.parameters(), model_dash.parameters(), grads):
                param_dash.copy_(param_init - beta * grad)

        # Query loss on the adapted model.
        query_loss = criterion(model_dash(x_qry), y_qry)
        # BUGFIX: MSELoss already averages over the 50 query points, so only
        # average over the 3 tasks (the original divided by 50*3, reporting a
        # loss 50x too small).
        meta_loss += query_loss.item() / 3
        grads_list.append(
            torch.autograd.grad(query_loss, model_dash.parameters()))

    loss_list.append(meta_loss)
    print('meta_loss:', meta_loss)

    # Outer step: init -= alpha * sum of per-task query gradients.
    # First-order MAML: gradients taken w.r.t. the adapted weights are applied
    # directly to the initialization (second-order terms dropped).
    with torch.no_grad():
        for param_init, grad1, grad2, grad3 in zip(
                model_init.parameters(), *grads_list):
            param_init -= alpha * (grad1 + grad2 + grad3)

plt.plot(loss_list)
plt.show()