# Untitled
#
# mail@pastecode.io avatar
# unknown
# plain_text
# a year ago
# 4.3 kB
# 2
# Indexable
# Never
def _extract_normalized_feats(images, model, model2, batch_size, desc):
    """Run both models over ``images`` in batches and return L2-normalized features.

    Args:
        images: Indexable batch container exposing ``.shape[0]`` and slicing
            (presumably a tensor of stacked images -- confirm with caller).
        model: First network; called as ``model(batch, return_feats=True)``.
        model2: Second network; called the same way as ``model``.
        batch_size: Number of images sent through the models per step.
        desc: Label shown on the progress bar.

    Returns:
        Tuple ``(feats, feats2)``: the per-batch outputs of ``model`` and
        ``model2``, each normalized along ``dim=1`` and concatenated along
        ``dim=0``.
    """
    feats, feats2 = [], []
    num_images = images.shape[0]
    # Ceiling division so a trailing partial batch is counted in the total.
    num_batches = (num_images + batch_size - 1) // batch_size
    batch_bar = tqdm(total=num_batches, dynamic_ncols=True, position=0, leave=False, desc=desc)

    # Batched inference keeps peak memory bounded and avoids CUDA OOM errors.
    for start in range(0, num_images, batch_size):
        batch = images[start:start + batch_size].float().to(DEVICE)

        with torch.no_grad():
            feat = torch.nn.functional.normalize(model(batch, return_feats=True), dim=1)
            feat2 = torch.nn.functional.normalize(model2(batch, return_feats=True), dim=1)

        feats.append(feat)
        feats2.append(feat2)
        batch_bar.update()

    batch_bar.close()
    return torch.cat(feats, dim=0), torch.cat(feats2, dim=0)


def eval_verification(unknown_images, known_images, model, model2, similarity, batch_size=config['batch_size'], mode='val', threshold=0.5,
                      val_labels_csv='/content/data/11-785-s23-hw2p2-verification/verification_dev.csv'):
    """Match each unknown image to a known identity using a two-model ensemble.

    Features from ``model`` and ``model2`` are extracted separately, normalized,
    scored against every known feature with ``similarity``, and the two score
    matrices are summed. Each unknown image is assigned the known identity with
    the highest combined score, or ``'n000000'`` when that score is below
    ``threshold``.

    Args:
        unknown_images: Images to identify (sliceable, with ``.shape[0]``).
        known_images: Reference images, one per entry of the module-level
            ``known_paths`` (assumed to be index-aligned -- confirm with caller).
        model: First feature extractor.
        model2: Second feature extractor.
        similarity: Callable ``similarity(unknown_feats, known_feature)``
            returning a tensor of scores; higher is treated as more similar.
        batch_size: Images per forward pass.
        mode: Progress-bar label; when ``'val'`` accuracy is computed against
            the labels in ``val_labels_csv``.
        threshold: Minimum combined score to accept a match. NOTE: the score is
            the *sum* of two similarities, so if ``similarity`` is cosine the
            range is presumably [-2, 2] rather than [-1, 1] -- tune accordingly.
        val_labels_csv: CSV file with a ``label`` column holding ground-truth
            identities for validation.

    Returns:
        Tuple ``(pred_id_strings, accuracy)``; ``accuracy`` is 0 unless
        ``mode == 'val'``.
    """
    NO_CORRESPONDENCE_LABEL = 'n000000'

    # Both networks must be in inference mode; the original only switched `model`.
    model.eval()
    model2.eval()

    unknown_feats, unknown_feats2 = _extract_normalized_feats(unknown_images, model, model2, batch_size, mode)
    known_feats, known_feats2 = _extract_normalized_feats(known_images, model, model2, batch_size, mode)

    # One row of scores per known identity, for each model.
    similarity_values = torch.stack([similarity(unknown_feats, known_feature) for known_feature in known_feats])
    similarity_values2 = torch.stack([similarity(unknown_feats2, known_feature2) for known_feature2 in known_feats2])

    # Ensemble by summing the two models' scores.
    similarity_values = similarity_values + similarity_values2

    # Reduce over the known-identity axis: best score and its index per unknown image.
    max_similarity_values, predictions = similarity_values.max(0)
    max_similarity_values, predictions = max_similarity_values.cpu().numpy(), predictions.cpu().numpy()

    # Some unknown identities have no counterpart among the known ones; their
    # best score should fall below the threshold, so they get the sentinel label.
    pred_id_strings = [
        NO_CORRESPONDENCE_LABEL if best_score < threshold else known_paths[known_idx]
        for best_score, known_idx in zip(max_similarity_values, predictions)
    ]

    accuracy = 0
    if mode == 'val':
        true_ids = pd.read_csv(val_labels_csv)['label'].tolist()
        # accuracy_score expects (y_true, y_pred); the value is the same either way.
        accuracy = accuracy_score(true_ids, pred_id_strings)
        print("Verification Accuracy = {}".format(accuracy))

    return pred_id_strings, accuracy