project_code
unknown
plain_text
2 years ago
3.2 kB
5
Indexable
# Training loop for the recognition (triplet) model.
#
# `config`, `device`, `triplet_model`, `optimizer`, `loss_for_recognition`,
# `early_stopper`, `checkpoint_dir`, `utils`, `Experiment` and both dataloaders
# are not defined in this snippet and must already be in scope.
rec_epochs = config['recognition_epochs']
rec_logging = config['recognition_logging']
# NOTE(review): the key is spelled "wieghts"; it is left untouched because it
# presumably matches the spelling used in the config file - confirm.
rec_log_interval = config["recognition_log_wieghts_interval"]

if rec_logging:
    # Read the Comet ML credentials.
    with open('secrets.json', encoding='utf-8') as secrets_file:
        secrets = json.load(secrets_file)

    # Initialise the experiment and record the run's hyper-parameters.
    experiment = Experiment(
        api_key=secrets["api_key"],
        project_name=secrets["project_name"],
        workspace="reu-ds-club",
        tags=["recognition"],
    )
    hyper_params = {
        "model_name": config["model"],
        "use_colab": config['use_colab'],
        "epochs": rec_epochs,
        "batch_size": config['batch_size'],
        "image_size": config['img_size'],
    }
    experiment.log_parameters(hyper_params)

for epoch in range(rec_epochs):
    # Sum of the per-batch training losses for this epoch.
    epoch_loss = 0
    # Every batch is processed: the previous unconditional `break` stopped the
    # epoch after a single batch while the printed average below was still
    # divided by the full dataloader length.
    for anc, pos, neg in train_recognition_dataloader:
        preds = triplet_model(anc.to(device), pos.to(device), neg.to(device))
        loss_val = loss_for_recognition(*preds)

        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        epoch_loss += loss_val.item()

    val_loss = utils.validate_model_rec(
        triplet_model, test_recognition_dataloader, loss_for_recognition, device
    )
    # NOTE(review): the +12 offset in the printed epoch and in the checkpoint
    # file name presumably continues the numbering of an earlier run - confirm.
    print(f"Epoch: {epoch+12}\tLoss: {epoch_loss / len(train_recognition_dataloader)}\tVal loss: {val_loss}")

    early_stopper(val_loss)
    if early_stopper.early_stop:
        print("Early stopping")
        break

    checkpoint_filename = os.path.join(checkpoint_dir, f"{epoch+12}_checkpoint_recognition.pth")
    torch.save({
        'epoch': epoch,
        'model_state_dict': triplet_model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss_train': epoch_loss,
        'loss_val': val_loss,
    }, checkpoint_filename)

    if rec_logging:
        # Logs the summed (not averaged) training loss, as stored in the checkpoint.
        experiment.log_metric("loss", epoch_loss, step=epoch)

    # Upload the model weights every `rec_log_interval` epochs and on the last
    # epoch. A zero/empty interval disables the periodic upload instead of
    # raising ZeroDivisionError.
    on_interval = bool(rec_log_interval) and epoch % rec_log_interval == 0
    is_last_epoch = epoch == rec_epochs - 1
    if rec_logging and (on_interval or is_last_epoch):
        torch.save(triplet_model, 'rec_model.pth')
        experiment.log_model(
            name=f"rec_model-epoch-{epoch}",
            file_or_folder='rec_model.pth',
            file_name=f"rec_model-epoch-{epoch}",
        )
        experiment.log_asset(
            file_data='rec_model.pth',
            file_name=f"rec_model-epoch-{epoch}",
        )
        print("save model")

if rec_logging:
    experiment.end()

# ---------
# NOTE(review): the function below is called above as `utils.validate_model_rec`,
# so it presumably lives in the `utils` module rather than in this script - confirm.


def validate_model_rec(model, test_dataloader, loss_fn, device):
    """Return the mean loss of ``model`` over ``test_dataloader``.

    The model is switched to evaluation mode and gradients are disabled while
    the batches are processed. Afterwards the model is put back into whichever
    mode (train or eval) it was in when the function was called, even if a
    batch raises.

    Args:
        model: Callable module taking ``(anchor, positive, negative)`` and
            returning an iterable that is unpacked into ``loss_fn``.
        test_dataloader: Sized iterable yielding 3-item batches whose items
            support ``.to(device)``.
        loss_fn: Callable applied as ``loss_fn(*preds)``; its result must
            support ``.item()``.
        device: Target passed to ``.to()`` for every batch item.

    Returns:
        The summed batch losses divided by ``len(test_dataloader)``.

    Raises:
        ZeroDivisionError: If ``test_dataloader`` has length 0.
    """
    was_training = model.training
    model.eval()  # switch the model to evaluation mode
    total_loss = 0.0
    try:
        with torch.no_grad():  # no gradient tracking during validation
            for anchor, pos, neg in test_dataloader:
                preds = model(anchor.to(device), pos.to(device), neg.to(device))
                total_loss += loss_fn(*preds).item()
    finally:
        # Restore the caller's mode instead of unconditionally forcing train mode.
        model.train(was_training)

    # Average loss over all validation batches.
    return total_loss / len(test_dataloader)
Editor is loading...