project_code
unknown
plain_text
2 years ago
3.2 kB
9
Indexable
# Recognition training script for the triplet model: optionally logs to
# Comet ML, trains for the configured number of epochs, validates after each
# epoch, writes a checkpoint per epoch and stops early when the early stopper
# says so.
rec_epochs = config['recognition_epochs']
rec_logging = config['recognition_logging']
# NOTE: the key is spelled "wieghts" on purpose here -- it has to match the
# spelling used in the config file.
rec_log_interval = config["recognition_log_wieghts_interval"]

# Epoch numbers shown in the console and used in checkpoint file names are
# shifted by this offset (presumably numbering continues from an earlier
# run -- TODO confirm). The raw, unshifted `epoch` is what gets stored inside
# the checkpoint and sent to the experiment tracker.
epoch_number_offset = 12

if rec_logging:
    # Read secrets for Comet ML logging.
    with open('secrets.json') as secrets_file:
        secrets = json.load(secrets_file)

    # Init experiment.
    experiment = Experiment(
        api_key=secrets["api_key"],
        project_name=secrets["project_name"],
        workspace="reu-ds-club",
        tags=["recognition"],
    )

    hyper_params = {
        "model_name": config["model"],
        "use_colab": config['use_colab'],
        "epochs": rec_epochs,
        "batch_size": config['batch_size'],
        "image_size": config['img_size'],
    }
    experiment.log_parameters(hyper_params)

for epoch in range(rec_epochs):
    shown_epoch = epoch + epoch_number_offset
    epoch_loss = 0

    # Go through every batch of the epoch: the average printed below divides
    # by the full dataloader length, so the sum must cover all batches.
    for anc, pos, neg in train_recognition_dataloader:
        preds = triplet_model(anc.to(device), pos.to(device), neg.to(device))
        loss_val = loss_for_recognition(*preds)

        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        epoch_loss += loss_val.item()

    val_loss = utils.validate_model_rec(triplet_model, test_recognition_dataloader, loss_for_recognition, device)
    print(f"Epoch: {shown_epoch}\tLoss: {epoch_loss / len(train_recognition_dataloader)}\tVal loss: {val_loss}")

    # The early stopper is fed the validation loss and exposes its decision
    # through the `early_stop` flag.
    early_stopper(val_loss)
    if early_stopper.early_stop:
        print("Early stopping")
        break

    checkpoint_filename = os.path.join(checkpoint_dir, f"{shown_epoch}_checkpoint_recognition.pth")
    torch.save({
        'epoch': epoch,
        'model_state_dict': triplet_model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss_train': epoch_loss,  # summed (not averaged) batch loss
        'loss_val': val_loss,
    }, checkpoint_filename)

    if rec_logging:
        experiment.log_metric("loss", epoch_loss, step=epoch)

    # Logging model weights (according to the configured interval + last epoch).
    if rec_logging and (epoch % rec_log_interval == 0 or epoch == rec_epochs - 1):
        torch.save(triplet_model, 'rec_model.pth')
        experiment.log_model(name=f"rec_model-epoch-{epoch}", file_or_folder='rec_model.pth', file_name=f"rec_model-epoch-{epoch}")
        experiment.log_asset(file_data='rec_model.pth', file_name=f"rec_model-epoch-{epoch}")
        print("save model")

if rec_logging:
    experiment.end()
---------
def validate_model_rec(model, test_dataloader, loss_fn, device):
    """Return the mean per-batch loss of a triplet model over a dataloader.

    The model is switched to evaluation mode and gradients are disabled while
    the batches are processed; afterwards the model is put back into whatever
    mode (train or eval) it was in when this function was called.

    Args:
        model: Callable module taking three inputs (anchor, positive,
            negative) and returning an iterable that is unpacked into
            ``loss_fn``.
        test_dataloader: Sized iterable yielding 3-item samples whose items
            support ``.to(device)``.
        loss_fn: Callable receiving the unpacked model outputs and returning
            an object with ``.item()``.
        device: Device passed to ``.to(...)`` for every input.

    Returns:
        The sum of the batch losses divided by ``len(test_dataloader)``.

    Raises:
        ValueError: If ``test_dataloader`` has length zero.
    """
    num_batches = len(test_dataloader)
    if num_batches == 0:
        raise ValueError("test_dataloader is empty; cannot compute a validation loss")

    # Remember the current mode so a model that was already in eval mode is
    # not silently flipped to training mode by running validation.
    was_training = model.training
    model.eval()  # switch the model to evaluation mode

    total_loss = 0.0
    try:
        with torch.no_grad():  # no gradient tracking during validation
            for anchor, pos, neg in test_dataloader:
                preds = model(anchor.to(device), pos.to(device), neg.to(device))
                total_loss += loss_fn(*preds).item()
    finally:
        # Restore the original mode even if a batch raised.
        model.train(was_training)

    return total_loss / num_batches  # average loss over the batches