Optimization

mail@pastecode.io avatar
unknown
python
2 months ago
946 B
1
Indexable
Never
import torch
from torch.optim.lr_scheduler import LambdaLR

def configure_optimizers(self):
    """Create the AdamW optimizer and a per-step linear-warmup LR scheduler.

    Returns:
        dict: ``{"optimizer": AdamW, "lr_scheduler": {...}}`` in the
        Lightning convention, with the scheduler stepped every training
        step (``interval="step"``, ``frequency=1``).

    Notes:
        Reads ``self.min_lr``, ``self.max_lr``, ``self.warmup_steps``,
        ``self.use_hard_negatives`` and, when hard negatives are enabled,
        ``self.corrected_step``. Assumes ``self.min_lr > 0`` and
        ``self.warmup_steps > 0``: LambdaLR multiplies the optimizer's
        base lr (``min_lr``) by the returned factor, so the factor is the
        target lr divided by ``min_lr``.
    """
    optimizer = torch.optim.AdamW(
        self.parameters(), lr=self.min_lr, weight_decay=1e-4, betas=(0.9, 0.95)
    )

    def lr_lambda(step):
        # When training with hard negatives, the effective step counter is
        # maintained elsewhere on the module; warm up against that instead
        # of the raw scheduler step.
        if self.use_hard_negatives:
            step = self.corrected_step
        # Linear warmup from min_lr to max_lr over warmup_steps, then hold
        # flat at max_lr (min(1.0, ...) caps the ramp).
        lr = self.min_lr + (self.max_lr - self.min_lr) * min(
            1.0, step / self.warmup_steps
        )
        # LambdaLR scales the base lr (min_lr) by this factor, so divide
        # the target lr back by min_lr to obtain the multiplier.
        return lr / self.min_lr

    scheduler = {
        "scheduler": LambdaLR(optimizer, lr_lambda),
        "interval": "step",  # advance the scheduler per batch, not per epoch
        "frequency": 1,
    }

    return {"optimizer": optimizer, "lr_scheduler": scheduler}