Untitled
unknown
plain_text
14 days ago
1.8 kB
4
Indexable
import math
import os

import torch


class CheckpointSaver:
    """Keep the most recent checkpoint plus the N best checkpoints by a metric.

    Lower metric values are treated as better: tracked checkpoints are kept
    sorted ascending and the largest (worst) one is evicted when the limit is
    exceeded.

    Files written into ``checkpoints_dir``:
      * ``latest_epoch_{epoch}.pt`` -- exactly one, the most recent save.
      * ``epoch_{epoch}_{metric:.4f}.pt`` -- up to ``save_num_best`` best saves.
    """

    def __init__(self, checkpoints_dir, save_num_best=5):
        """
        Args:
            checkpoints_dir: Existing directory the checkpoint files are written to.
            save_num_best: How many best-metric checkpoints to keep (<= 0 keeps none).
        """
        self.checkpoints_dir = checkpoints_dir
        self.save_num_best = save_num_best
        # Entries are {"metric": float, "path": str}, sorted ascending (best first).
        self.best_checkpoints = []

    def save(self, state_dicts, metric_value, epoch):
        """Save ``state_dicts`` as the latest checkpoint and, if good enough, as a best one.

        Args:
            state_dicts: Any object accepted by ``torch.save``.
            metric_value: Score for this epoch (lower is better), or ``None`` to
                skip best-checkpoint tracking. NaN scores are never tracked.
            epoch: Epoch identifier used in the file names.
        """
        self._save_latest(state_dicts, epoch)
        # Use `is not None` rather than truthiness: a metric of exactly 0 is a valid score.
        if metric_value is not None:
            self._save_best(state_dicts, float(metric_value), epoch)

    def _save_latest(self, state_dicts, epoch):
        """Write the latest checkpoint and remove any previous ``latest_epoch_`` files."""
        tmp_path = os.path.join(self.checkpoints_dir, "latest_tmp.pt")
        latest_name = f"latest_epoch_{epoch}.pt"
        torch.save(state_dicts, tmp_path)
        # Move the new file into place *before* deleting the old one, so a crash never
        # leaves the directory without a latest checkpoint; os.replace also overwrites
        # an existing destination on every platform (os.rename does not on Windows).
        os.replace(tmp_path, os.path.join(self.checkpoints_dir, latest_name))
        # Remove every stale latest file (not just one) so leftovers from an earlier
        # crash are cleaned up instead of failing every subsequent save.
        for fname in os.listdir(self.checkpoints_dir):
            if "latest_epoch_" in fname and fname != latest_name:
                os.remove(os.path.join(self.checkpoints_dir, fname))

    def _save_best(self, state_dicts, metric_value, epoch):
        """Save a best checkpoint if ``metric_value`` ranks within the top ``save_num_best``."""
        if self.save_num_best <= 0 or math.isnan(metric_value):
            return
        is_full = len(self.best_checkpoints) >= self.save_num_best
        # Only displace the current worst when strictly better; a tie would just be
        # written and immediately evicted again.
        if is_full and metric_value >= self.best_checkpoints[-1]["metric"]:
            return
        checkpoint_path = os.path.join(
            self.checkpoints_dir, f"epoch_{epoch}_{metric_value:.4f}.pt"
        )
        torch.save(state_dicts, checkpoint_path)
        self.best_checkpoints.append({"metric": metric_value, "path": checkpoint_path})
        self.best_checkpoints.sort(key=lambda ckpt: ckpt["metric"])
        if len(self.best_checkpoints) > self.save_num_best:
            os.remove(self.best_checkpoints.pop()["path"])
Editor is loading...
Leave a Comment