From 2ce940259362e0b916c370ab9eeeb3c411867a9c Mon Sep 17 00:00:00 2001 From: Wen Zhou Date: Thu, 3 Jul 2025 17:11:52 +0200 Subject: [PATCH] update: code review - change type from str to Field with description - remove test from test_distrubution.py, keep in test_providers.py Signed-off-by: Wen Zhou --- llama_stack/apis/providers/providers.py | 6 +- tests/unit/distribution/test_distribution.py | 125 ------------------- tests/unit/distribution/test_providers.py | 26 ++++ 3 files changed, 30 insertions(+), 127 deletions(-) diff --git a/llama_stack/apis/providers/providers.py b/llama_stack/apis/providers/providers.py index b8bba6919..8c1a8d73e 100644 --- a/llama_stack/apis/providers/providers.py +++ b/llama_stack/apis/providers/providers.py @@ -6,7 +6,7 @@ from typing import Any, Protocol, runtime_checkable -from pydantic import BaseModel, HttpUrl, field_validator +from pydantic import BaseModel, Field, HttpUrl, field_validator from llama_stack.providers.datatypes import HealthResponse from llama_stack.schema_utils import json_schema_type, webmethod @@ -19,7 +19,9 @@ class ProviderInfo(BaseModel): provider_type: str config: dict[str, Any] health: HealthResponse - metrics: str | None = None # define as string type than httpurl for openapi compatibility + metrics: str | None = Field( + default=None, description="endpoint for metrics from providers. Must be a valid HTTP URL if provided." + ) @field_validator("metrics") @classmethod diff --git a/tests/unit/distribution/test_distribution.py b/tests/unit/distribution/test_distribution.py index a77c5cb18..ae24602d7 100644 --- a/tests/unit/distribution/test_distribution.py +++ b/tests/unit/distribution/test_distribution.py @@ -13,134 +13,9 @@ from pydantic import BaseModel, Field, ValidationError from llama_stack.distribution.datatypes import Api, Provider, StackRunConfig from llama_stack.distribution.distribution import get_provider_registry -from llama_stack.distribution.providers import ProviderImpl, ProviderImplConfig from llama_stack.providers.datatypes import ProviderSpec -class TestProviderMetrics: - """Test suite for provider metrics.""" - - @pytest.mark.asyncio - async def test_provider_with_valid_metrics_in_config(self): - """Test provider with valid metrics in config.""" - run_config = StackRunConfig( - image_name="test_image", - providers={ - "inference": [ - Provider( - provider_id="test_provider", - provider_type="test_type", - config={"url": "http://localhost:8000", "metrics": "http://localhost:9090/metrics"}, - ) - ] - }, - ) - provider_config = ProviderImplConfig(run_config=run_config) - provider_impl = ProviderImpl(provider_config, {}) - - response = await provider_impl.list_providers() - assert len(response.data) == 1 - - provider_info = response.data[0] - assert provider_info.provider_id == "test_provider" - assert provider_info.provider_type == "test_type" - assert provider_info.metrics == "http://localhost:9090/metrics" - assert "metrics" not in provider_info.config - assert provider_info.config["url"] == "http://localhost:8000" - - @pytest.mark.asyncio - async def test_provider_with_invalid_metrics_in_config(self): - """Test invalid metrics in config fails when access /providers.""" - run_config = StackRunConfig( - image_name="test_image", - providers={ - "inference": [ - Provider( - provider_id="test_provider", - provider_type="test_type", - config={"url": "http://localhost:8000", "metrics": "abcde-llama-stack"}, - ) - ] - }, - ) - provider_config = ProviderImplConfig(run_config=run_config) - provider_impl = ProviderImpl(provider_config, {}) - - with pytest.raises(ValidationError): - await provider_impl.list_providers() - - @pytest.mark.asyncio - async def test_provider_with_invalid_metrics_in_config2(self): - """Test invalid metrics2 in config fails when access /providers.""" - run_config = StackRunConfig( - image_name="test_image", - providers={ - "inference": [ - Provider( - provider_id="test_provider", - provider_type="test_type", - config={"url": "http://localhost:8000", "metrics": 123}, - ) - ] - }, - ) - provider_config = ProviderImplConfig(run_config=run_config) - provider_impl = ProviderImpl(provider_config, {}) - - with pytest.raises(ValidationError): - await provider_impl.list_providers() - - @pytest.mark.asyncio - async def test_provider_without_metrics_in_config(self): - """Test provider without metrics in config returns None.""" - run_config = StackRunConfig( - image_name="test_image", - providers={ - "inference": [ - Provider( - provider_id="test_provider", provider_type="test_type", config={"url": "http://localhost:8000"} - ) - ] - }, - ) - provider_config = ProviderImplConfig(run_config=run_config) - provider_impl = ProviderImpl(provider_config, {}) - - response = await provider_impl.list_providers() - assert len(response.data) == 1 - - provider_info = response.data[0] - assert provider_info.provider_id == "test_provider" - assert provider_info.provider_type == "test_type" - assert provider_info.metrics is None - assert "metrics" not in provider_info.config - assert provider_info.config["url"] == "http://localhost:8000" - - @pytest.mark.asyncio - async def test_inspect_provider_with_metrics(self): - """Test inspect_provider returns correct metrics info.""" - run_config = StackRunConfig( - image_name="test_image", - providers={ - "inference": [ - Provider( - provider_id="test_provider", - provider_type="test_type", - config={"url": "http://localhost:8000", "metrics": "http://localhost:9090/metrics"}, - ) - ] - }, - ) - provider_config = ProviderImplConfig(run_config=run_config) - provider_impl = ProviderImpl(provider_config, {}) - - # Test the inspect_provider API - provider_info = await provider_impl.inspect_provider("test_provider") - assert provider_info.provider_id == "test_provider" - assert provider_info.metrics == "http://localhost:9090/metrics" - assert "metrics" not in provider_info.config - - class SampleConfig(BaseModel): foo: str = Field( default="bar", diff --git a/tests/unit/distribution/test_providers.py b/tests/unit/distribution/test_providers.py index 0a81c4b31..4f1ee5cb9 100644 --- a/tests/unit/distribution/test_providers.py +++ b/tests/unit/distribution/test_providers.py @@ -6,6 +6,7 @@ import pytest +from pydantic import ValidationError from llama_stack.apis.providers import ProviderInfo from llama_stack.distribution.datatypes import Provider, StackRunConfig @@ -37,6 +38,23 @@ class TestProviderImpl: ) return ProviderImplConfig(run_config=run_config) + @pytest.fixture + def mock_config_malformed_metrics(self): + """Create a mock configuration with invalid metrics URL for testing.""" + run_config = StackRunConfig( + image_name="test_image", + providers={ + "inference": [ + Provider( + provider_id="test_provider_malformed_metrics", + provider_type="test_type3", + config={"url": "http://localhost:8000", "metrics": "abcde-llama-stack"}, + ), + ] + }, + ) + return ProviderImplConfig(run_config=run_config) + @pytest.fixture def mock_deps(self): """Create mock dependencies.""" @@ -111,3 +129,11 @@ class TestProviderImpl: with pytest.raises(ValueError, match="Provider nonexistent not found"): await provider_impl.inspect_provider("nonexistent") + + @pytest.mark.asyncio + async def test_inspect_provider_malformed_metrics(self, mock_config_malformed_metrics, mock_deps): + """Test inspect_provider with invalid metrics URL raises validation error.""" + provider_impl = ProviderImpl(mock_config_malformed_metrics, mock_deps) + + with pytest.raises(ValidationError): + await provider_impl.inspect_provider("test_provider_malformed_metrics")