test.py

 avatar
unknown
plain_text
6 months ago
5.3 kB
3
Indexable
# CREATE MODELS AND PREDICT
"""
Run this file by :
python test.py --model_weights path1_to_your_model1.pt path2_to_your_model2.pt --test_path data/small_val/images

"""
import argparse
from ultralytics import YOLO
import os
import time
from tqdm import tqdm

from utils.wbf import *
from utils.visualize import *


# create model lists and namesa
def create_models(model_weights_list):
    """
    - Create a list of models used to predict
    - Params:
        model_weights_list: list of paths to weights files
    """
    models_list = []
    model_names_list = []
    for model_weights in model_weights_list:
        model = YOLO(model_weights)
        model_name = os.path.basename(model_weights)

        models_list.append(model)
        model_names_list.append(model_name)

    return models_list, model_names_list


def get_models_predictions(model_weights_list, test_path, single_image=False):
    """
    Get predictions of models for all images in test_path or just single image

    Params:
    - model_weights_list: list of paths to model weights
    - test_path: path to test_folder or single image
    - single_image: predict for single image or not, (default = False)
    """
    # check valid path
    for path in model_weights_list:
        if not os.path.exists(path):
            print(f"\nModel with path {path} does not exist!")
            return None
    if not os.path.exists(test_path):
        print("Test path does not exist!")
        return None

    # create models list and get model names
    models_list, model_names_list = create_models(model_weights_list)

    # store models predictions
    predictions = {}

    for idx, model in enumerate(models_list):
        predictions[model_names_list[idx]] = {}
        model.overrides["verbose"] = False  # make terminal shorter

        # Predict for every image in test_path
        if not single_image:
            print(f"Processing with model: {model_names_list[idx]} ")

            # start counting time
            start_time = time.time()

            for image_name in tqdm(
                os.listdir(test_path), desc=f"Processing {model_names_list[idx]}"
            ):
                image_path = os.path.join(test_path, image_name)
                results = detect_image(image_path, model)

                predictions[model_names_list[idx]][image_name] = results

            # Print process time for current model
            elapsed_time = time.time() - start_time
            print(f"Completed {model_names_list[idx]} in {elapsed_time:.2f} seconds")

        # Predict for single image
        else:
            results = detect_image(test_path, model)
            predictions[model_names_list[idx]][os.path.basename(test_path)] = results

    # predictions[model_name][image_name] = [ [x1,y1,x2,y2,img_w, img_h, label, score], [..] ] non-normalized
    return predictions


def run_test(model_weights_list, test_path, plot=True):
    predictions = get_models_predictions(
        model_weights_list, test_path, single_image=False
    )

    models_names_list = [os.path.basename(path) for path in model_weights_list]

    # fuse results
    results = fuse(
        models_names_list,
        test_path,
        predictions,
        single_image=False,
        iou_thr=0.5,
        skip_box_thr=0.05,
    )

    # visualize
    if plot:
        images = os.listdir(test_path)
        sample_image = images[10]
        visualize(os.path.join(test_path, sample_image), results[sample_image])

    return results


def run_test_on_single_image(model_weights_list, image_path, plot=True):
    predictions = get_models_predictions(
        model_weights_list, image_path, single_image=True
    )
    models_names_list = [os.path.basename(path) for path in model_weights_list]

    # fuse results
    results = fuse(
        models_names_list,
        image_path,
        predictions,
        single_image=True,
        iou_thr=0.5,
        skip_box_thr=0.05,
    )

    # visualize
    if plot:
        visualize(image_path, results[os.path.basename(image_path)])

    return results


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="YOLO Model Predictions")
    parser.add_argument(
        "--model_weights",
        nargs="+",
        default=["weights/v10s/v10s_best5.pt", "weights/v11s/v11s_best5.pt"],
        help="List of paths to model weight files",
    )
    parser.add_argument(
        "--test_path",
        default="../data/small_val/images",
        help="Path to test images folder or a single image",
    )
    # parser.add_argument(
    #     "--single_image",
    #     action="store_true",
    #     help="Flag to run prediction on a single image",
    # )
    args = parser.parse_args()

    # CALL FUNCTIONS
    # run_test(model_weights_list, test_path)

    sample_image = os.path.join(args.test_path, os.listdir(args.test_path)[10])
    run_test_on_single_image(args.model_weights_list, sample_image)

    # sample_images = os.listdir(test_path)
    # for sample_image in sample_images[:10]:
    #     run_test_on_single_image(model_weights_list, os.path.join(test_path, sample_image))
Editor is loading...
Leave a Comment