Untitled

mail@pastecode.io avatar
unknown
python
8 months ago
1.3 kB
5
Indexable
Never
import time
from typing import List, Optional, Tuple


class Dataset:
    """Interface stub for a dataset; every method body is elided (``...``)."""

    @staticmethod
    def load_from_file(filename: str) -> 'Dataset':
        """Return a ``Dataset`` for the file named ``filename``."""
        ...

    def get_rows(self, number_of_rows: Optional[int] = None) -> List:
        """Return rows of this dataset as a list.

        ``number_of_rows`` defaults to ``None`` so the method can be called
        without arguments, which matches its ``Optional[int]`` annotation.
        NOTE(review): presumably ``None`` means "all rows" -- confirm against
        the real implementation.
        """
        ...

    def split(self, test_size=0.2) -> Tuple['Dataset', 'Dataset', 'Dataset', 'Dataset']:
        """Split this dataset into four datasets.

        The caller in this file unpacks the result as
        ``(X_train, X_test, y_train, y_test)``.
        NOTE(review): ``test_size`` is presumably the fraction of rows held
        out for testing -- confirm against the real implementation.
        """
        ...


class MLModel:
    """Interface stub for a trainable model; every method body is elided (``...``)."""

    def __init__(self, model_type: str):
        """Create a model of the given ``model_type``.

        This file passes ``"classification"`` and ``"regression"``.
        """
        ...

    def fit(self, X: Dataset, y: Dataset) -> None:
        """Train the model on features ``X`` and targets ``y``."""
        ...

    def predict(self, series: Optional[List] = None) -> Optional[List]:
        """Return predictions for ``series``; the annotation allows ``None``.

        The default is ``None`` rather than ``[]``: a list default is created
        once at definition time and shared by every call, so any mutation
        would leak between calls.
        NOTE(review): the implementation should treat ``None`` as "no input".
        """
        ...


def accuracy(test_data: Dataset, predictions: Optional[List] = None) -> float:
    """Return a score for ``predictions`` against ``test_data`` as a float.

    The default for ``predictions`` is ``None`` rather than ``[]``: a list
    default is created once and shared across calls.
    NOTE(review): the implementation (elided here) should treat ``None`` as
    "no predictions".
    """
    ...


def benchmark():
    """Fit and evaluate each model type on each dataset, printing the results.

    For every (dataset, model type) pair this prints one line with the score
    returned by ``accuracy`` and one line with the measured fit and predict
    durations in seconds.

    NOTE(review): ``accuracy`` is reported for the "regression" model type as
    well -- confirm that metric is meaningful for it.
    """
    dataset_files = ("iris.csv", "credit_risk.csv")
    model_types = ("classification", "regression")

    for filename in dataset_files:
        dataset = Dataset.load_from_file(filename)

        # Split once per dataset so every model is trained and scored on the
        # same partition; splitting per model could hand each one different
        # data and make the scores incomparable.
        X_train, X_test, y_train, y_test = dataset.split(test_size=0.3)

        # Pass None explicitly: get_rows declares number_of_rows as a
        # positional parameter, so a bare get_rows() may raise TypeError.
        # Fetched outside the timed region so test_time covers predict() only.
        test_rows = X_test.get_rows(None)

        for model_type in model_types:
            # A fresh model per dataset, so no fitted state carries over.
            ml_model = MLModel(model_type=model_type)

            # perf_counter is monotonic and high-resolution, unlike time.time,
            # which can jump when the system clock is adjusted.
            start_time = time.perf_counter()
            ml_model.fit(X_train, y_train)
            train_time = time.perf_counter() - start_time

            start_time = time.perf_counter()
            y_pred = ml_model.predict(test_rows)
            test_time = time.perf_counter() - start_time

            # Must not be named `accuracy`: assigning to that name would make
            # it local to this function and the call below would raise
            # UnboundLocalError.
            # NOTE(review): predict() is annotated Optional[List]; a None
            # result is forwarded to accuracy() unchanged.
            score = accuracy(y_test, y_pred)

            print(f"Model {ml_model} accuracy for dataset {dataset} is {score}.")
            print(
                f"Model {ml_model} training took {train_time:.4f}s, "
                f"prediction took {test_time:.4f}s."
            )


# Script entry point: run the benchmark only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    benchmark()
Leave a Comment