Untitled

mail@pastecode.io avatar
unknown
plain_text
3 years ago
592 B
1
Indexable
def get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps,
    num_training_steps,
    num_cycles=0.5,
    last_epoch=-1,
):
    """Create a learning-rate schedule with linear warmup then cosine decay.

    The returned scheduler multiplies each parameter group's initial learning
    rate by a factor that:

    * rises linearly from 0 to 1 over the first ``num_warmup_steps`` steps, then
    * follows ``0.5 * (1 + cos(2 * pi * num_cycles * progress))``, where
      ``progress`` goes from 0 at the end of warmup to 1 at
      ``num_training_steps``. With the default ``num_cycles=0.5`` this is a
      single half-cosine going from 1 down to 0.

    The factor is clamped below at 0.0. ``progress`` is NOT clamped above 1,
    so for steps beyond ``num_training_steps`` the cosine keeps evolving
    (with the default ``num_cycles`` the factor starts rising again).

    Args:
        optimizer: The optimizer whose learning rate is scheduled; passed
            straight to ``torch.optim.lr_scheduler.LambdaLR``.
        num_warmup_steps: Number of linear warmup steps.
        num_training_steps: Total number of steps, warmup included (the decay
            phase spans ``num_training_steps - num_warmup_steps`` steps).
        num_cycles: Number of cosine periods over the decay phase; 0.5 gives
            a half cosine from 1 to 0.
        last_epoch: Index of the last step, forwarded to ``LambdaLR``
            (-1 means start from scratch).

    Returns:
        A ``torch.optim.lr_scheduler.LambdaLR`` implementing the schedule.
    """
    # Imported locally: the snippet has no module-level imports, so without
    # these the function would raise NameError when called.
    import math

    import torch

    def lr_lambda(current_step):
        # Warmup: linear ramp 0 -> 1; max(1, ...) guards against division by
        # zero when num_warmup_steps is 0.
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        # Decay: fraction of the post-warmup span completed so far
        # (max(1, ...) guards a zero-length decay span).
        progress = float(current_step - num_warmup_steps) / float(
            max(1, num_training_steps - num_warmup_steps)
        )
        return max(
            0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
        )

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch)