project_code

mail@pastecode.io avatar
unknown
plain_text
7 months ago
3.2 kB
2
Indexable
Never
# Recognition-stage settings read from the shared config mapping.
# The last key is spelled exactly as the config file spells it.
rec_epochs, rec_logging, rec_log_interval = (
    config['recognition_epochs'],
    config['recognition_logging'],
    config["recognition_log_wieghts_interval"],
)

if rec_logging:
    # Comet ML credentials are loaded from a local JSON file.
    with open('secrets.json') as secrets_file:
        secrets = json.loads(secrets_file.read())

    # Create the Comet experiment that the training loop reports to.
    experiment = Experiment(
        api_key=secrets["api_key"],
        project_name=secrets["project_name"],
        workspace="reu-ds-club",
        tags=["recognition"],
    )

    # Record the run's hyper-parameters once, before training starts.
    hyper_params = dict(
        model_name=config["model"],
        use_colab=config['use_colab'],
        epochs=rec_epochs,
        batch_size=config['batch_size'],
        image_size=config['img_size'],
    )
    experiment.log_parameters(hyper_params)

# Train the triplet model for up to rec_epochs epochs: validate after each
# epoch, stop early when the early stopper says so, write a checkpoint per
# epoch and, when logging is enabled, report to the Comet experiment.
for epoch in range(rec_epochs):
    # Epoch number shown in the console output and used in the checkpoint
    # filename. NOTE(review): the +12 offset is hard-coded - presumably this
    # run continues an earlier one; confirm. The checkpoint's 'epoch' field
    # and the Comet step below still use the zero-based `epoch`.
    epoch_label = epoch + 12
    epoch_loss = 0

    # Iterate over every batch so that the average printed below
    # (sum of batch losses / number of batches) is meaningful.
    for anc, pos, neg in train_recognition_dataloader:
        preds = triplet_model(anc.to(device), pos.to(device), neg.to(device))
        loss_val = loss_for_recognition(*preds)

        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        epoch_loss += loss_val.item()

    val_loss = utils.validate_model_rec(triplet_model, test_recognition_dataloader, loss_for_recognition, device)

    print(f"Epoch: {epoch_label}\tLoss: {epoch_loss / len(train_recognition_dataloader)}\tVal loss: {val_loss}")

    # Stopping here skips this epoch's checkpoint and the Comet logging below.
    early_stopper(val_loss)
    if early_stopper.early_stop:
        print("Early stopping")
        break

    checkpoint_filename = os.path.join(checkpoint_dir, f"{epoch_label}_checkpoint_recognition.pth")

    # NOTE(review): 'loss_train' is the summed batch loss, not the average
    # that is printed above.
    torch.save({
        'epoch': epoch,
        'model_state_dict': triplet_model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss_train': epoch_loss,
        'loss_val': val_loss,
    }, checkpoint_filename)

    if rec_logging:
        experiment.log_metric("loss", epoch_loss, step=epoch)

        # Upload the model every rec_log_interval epochs and on the last epoch.
        if epoch % rec_log_interval == 0 or epoch == rec_epochs - 1:
            # Serializes the whole module object, not just its state_dict.
            torch.save(triplet_model, 'rec_model.pth')
            experiment.log_model(name = f"rec_model-epoch-{epoch}", file_or_folder = 'rec_model.pth', file_name = f"rec_model-epoch-{epoch}")
            experiment.log_asset(file_data = 'rec_model.pth', file_name = f"rec_model-epoch-{epoch}")
            print("save model")


if rec_logging:
    experiment.end()


---------


def validate_model_rec(model, test_dataloader, loss_fn, device):
    """Return the mean per-batch loss of ``model`` over ``test_dataloader``.

    The model is put into evaluation mode and gradients are disabled while
    the batches are processed. Afterwards the model is returned to whatever
    mode (train or eval) it was in when the function was called, even if an
    exception interrupts the loop.

    Args:
        model: Module called as ``model(anchor, positive, negative)``; its
            outputs are unpacked into ``loss_fn``.
        test_dataloader: Sized iterable yielding
            ``(anchor, positive, negative)`` batches.
        loss_fn: Called as ``loss_fn(*preds)``; its result must support
            ``.item()``.
        device: Target passed to ``.to()`` for every batch element.

    Returns:
        The sum of the batch losses divided by ``len(test_dataloader)``.

    Raises:
        ZeroDivisionError: If ``test_dataloader`` has length zero.
    """
    # Remember the caller's mode so validation has no lasting side effect.
    was_training = model.training
    model.eval()
    total_loss = 0.0
    try:
        with torch.no_grad():
            for anc, pos, neg in test_dataloader:
                preds = model(anc.to(device), pos.to(device), neg.to(device))
                total_loss += loss_fn(*preds).item()
    finally:
        model.train(was_training)

    return total_loss / len(test_dataloader)