import pytest
from prometheus_client import CollectorRegistry, Counter, Histogram
from metrics_utils import MetricsUtils  # Adjust import as needed

@pytest.fixture
def prometheus_registry():
    """Fixture to provide an isolated Prometheus registry for testing"""
    return CollectorRegistry()

def test_record_step_latency(prometheus_registry):
    """Test if step latency is recorded correctly"""
    # Create a test instance of Histogram with an isolated registry
    test_histogram = Histogram(
        "tetris_step_latency_test",
        "Latency of major steps in the tetris execution",
        labelnames=["step", "algorithm", "centerId", "groupId"],
        registry=prometheus_registry,
        buckets=MetricsUtils.DEFAULT_BUCKETS
    )

    # Override the class attribute so record_step_latency observes into the
    # isolated test registry (note: not restored afterwards; pytest's
    # monkeypatch.setattr could be used if cleanup between tests matters)
    MetricsUtils.tetris_step_latency = test_histogram

    # Act: Call record_step_latency
    MetricsUtils.record_step_latency("Model execution", "TestAlgo", 123, "group1", 500)

    # Assert: Check if the metric was recorded.  Bucket samples carry an extra
    # "le" label, so the exact label-dict comparison below matches only the
    # non-bucket series (_count and _sum, plus _created where exposed).
    metric_samples = test_histogram.collect()[0].samples
    assert any(sample.labels == {
        "step": "Model execution",
        "algorithm": "TestAlgo",
        "centerId": "123",  # Label values are stored as strings in Prometheus
        "groupId": "group1"
    } and sample.value > 0 for sample in metric_samples), "Latency metric not recorded properly"
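    # A more concise equivalent is prometheus_client's built-in test helper
    # CollectorRegistry.get_sample_value, e.g. (sketch, assuming exactly one
    # observation was recorded for this label combination):
    #   assert prometheus_registry.get_sample_value(
    #       "tetris_step_latency_test_count",
    #       {"step": "Model execution", "algorithm": "TestAlgo",
    #        "centerId": "123", "groupId": "group1"}) == 1.0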

def test_record_model_exception(prometheus_registry):
    """Test if model exception count increments correctly"""
    # Create a test instance of Counter with an isolated registry
    test_counter = Counter(
        "tetris_model_exception_count_test",
        "model exception count",
        labelnames=["algorithm", "centerId", "groupId", "errorType"],
        registry=prometheus_registry
    )

    # Override class attribute for testing
    MetricsUtils.model_exception_count = test_counter

    # Act: Call record_model_exception
    MetricsUtils.record_model_exception("TestAlgo", 123, "group1", "ValueError")

    # Assert: Check if the metric was incremented.  The counter exposes a
    # _total sample per label set; after a single call its value should be 1.0.
    metric_samples = test_counter.collect()[0].samples
    assert any(sample.labels == {
        "algorithm": "TestAlgo",
        "centerId": "123",  # Label values are stored as strings in Prometheus
        "groupId": "group1",
        "errorType": "ValueError"
    } and sample.value == 1.0 for sample in metric_samples), "Exception counter not incremented properly"
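
# ---------------------------------------------------------------------------
# Reference sketch (illustration only, not the real module): the tests above
# assume that metrics_utils.MetricsUtils exposes DEFAULT_BUCKETS, a
# tetris_step_latency Histogram, a model_exception_count Counter, and the two
# record_* class methods.  A minimal stand-in compatible with those tests
# might look like the class below; the bucket values, label order, and method
# signatures are assumptions, not the actual implementation.  Its metrics are
# registered in a private CollectorRegistry so the sketch cannot collide with
# the real module's metrics in the default registry.
# ---------------------------------------------------------------------------
_sketch_registry = CollectorRegistry()

class MetricsUtilsSketch:
    # Hypothetical latency buckets in milliseconds (assumption)
    DEFAULT_BUCKETS = (10, 50, 100, 250, 500, 1000, 2500, 5000)

    tetris_step_latency = Histogram(
        "tetris_step_latency",
        "Latency of major steps in the tetris execution",
        labelnames=["step", "algorithm", "centerId", "groupId"],
        registry=_sketch_registry,
        buckets=DEFAULT_BUCKETS,
    )

    model_exception_count = Counter(
        "tetris_model_exception_count",
        "model exception count",
        labelnames=["algorithm", "centerId", "groupId", "errorType"],
        registry=_sketch_registry,
    )

    @classmethod
    def record_step_latency(cls, step, algorithm, center_id, group_id, latency_ms):
        """Observe one latency sample for the given step."""
        cls.tetris_step_latency.labels(
            step=step,
            algorithm=algorithm,
            centerId=str(center_id),
            groupId=group_id,
        ).observe(latency_ms)

    @classmethod
    def record_model_exception(cls, algorithm, center_id, group_id, error_type):
        """Increment the exception counter for the given error type."""
        cls.model_exception_count.labels(
            algorithm=algorithm,
            centerId=str(center_id),
            groupId=group_id,
            errorType=error_type,
        ).inc()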