Untitled

 avatar
unknown
plain_text
14 days ago
1.8 kB
4
Indexable
import os
import torch


class CheckpointSaver():
    """Write training checkpoints, keeping the latest one plus the N best.

    Each call to :meth:`save` refreshes a single ``latest_epoch_<epoch>.pt``
    file and, when a metric is supplied, may also store an
    ``epoch_<epoch>_<metric>.pt`` file. At most ``save_num_best`` of those
    metric-tagged files are kept; a lower metric is treated as better, and
    the checkpoint with the highest metric is evicted first.

    Attributes:
        checkpoints_dir: Directory all checkpoint files are written to.
        save_num_best: Maximum number of metric-tagged checkpoints to keep.
        best_checkpoints: List of ``{"metric", "path"}`` dicts, sorted by
            ascending metric (the last entry is the worst one kept).
    """

    def __init__(self, checkpoints_dir, save_num_best=5):
        self.checkpoints_dir = checkpoints_dir
        self.save_num_best = save_num_best
        self.best_checkpoints = []

    def save(self, state_dicts, metric_value, epoch):
        """Persist ``state_dicts`` as the latest (and possibly a best) checkpoint.

        Args:
            state_dicts: Object handed to ``torch.save``.
            metric_value: Score for this epoch, lower is better. Pass ``None``
                to update only the latest checkpoint. A value of ``0`` is a
                valid metric and is ranked like any other.
            epoch: Epoch identifier embedded in the file names.
        """
        os.makedirs(self.checkpoints_dir, exist_ok=True)
        self._save_latest(state_dicts, epoch)
        if metric_value is not None:
            self._save_best(state_dicts, metric_value, epoch)

    def _save_latest(self, state_dicts, epoch):
        """Replace the ``latest_epoch_*`` checkpoint with one for ``epoch``."""
        tmp_path = os.path.join(self.checkpoints_dir, "latest_tmp.pt")
        latest_name = f"latest_epoch_{epoch}.pt"

        # Write to a temporary name first so a crash mid-write never leaves a
        # truncated file under the "latest" name.
        torch.save(state_dicts, tmp_path)
        os.replace(tmp_path, os.path.join(self.checkpoints_dir, latest_name))

        # Only drop older "latest" files once the new one is in place, so
        # there is always at least one complete latest checkpoint on disk.
        for file_name in os.listdir(self.checkpoints_dir):
            if "latest_epoch_" in file_name and file_name != latest_name:
                os.remove(os.path.join(self.checkpoints_dir, file_name))

    def _save_best(self, state_dicts, metric_value, epoch):
        """Store the checkpoint if it ranks among the ``save_num_best`` best."""
        if self.save_num_best <= 0:
            return

        is_full = len(self.best_checkpoints) >= self.save_num_best
        # Written as ``not (a < b)`` so that NaN metrics and ties with the
        # current worst entry are both skipped without touching the disk.
        if is_full and not (metric_value < self.best_checkpoints[-1]["metric"]):
            return

        checkpoint_name = f"epoch_{epoch}_{metric_value:.4f}.pt"
        checkpoint_path = os.path.join(self.checkpoints_dir, checkpoint_name)
        torch.save(state_dicts, checkpoint_path)

        self.best_checkpoints.append({"metric": metric_value, "path": checkpoint_path})
        self.best_checkpoints.sort(reverse=False, key=lambda x: x["metric"])

        # Evict the worst (highest-metric) checkpoint when over capacity.
        if len(self.best_checkpoints) > self.save_num_best:
            os.remove(self.best_checkpoints.pop()["path"])
Editor is loading...
Leave a Comment