From 0b051c037b005ab7d2331d68c3d2eb25638b0d28 Mon Sep 17 00:00:00 2001 From: Wen Zhou Date: Mon, 30 Jun 2025 10:12:48 +0200 Subject: [PATCH] update: add validation on non-string type Signed-off-by: Wen Zhou --- llama_stack/apis/providers/providers.py | 7 +++--- tests/unit/distribution/test_distribution.py | 23 +++++++++++++++++++- 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/llama_stack/apis/providers/providers.py b/llama_stack/apis/providers/providers.py index 3d7955970..b8bba6919 100644 --- a/llama_stack/apis/providers/providers.py +++ b/llama_stack/apis/providers/providers.py @@ -26,9 +26,10 @@ class ProviderInfo(BaseModel): def validate_metrics_url(cls, v): if v is None: return None - if isinstance(v, str): - HttpUrl(v) - return v + if not isinstance(v, str): + raise ValueError("metrics must be a string URL or None") + HttpUrl(v) + return v class ListProvidersResponse(BaseModel): diff --git a/tests/unit/distribution/test_distribution.py b/tests/unit/distribution/test_distribution.py index b860c4f88..a77c5cb18 100644 --- a/tests/unit/distribution/test_distribution.py +++ b/tests/unit/distribution/test_distribution.py @@ -50,7 +50,7 @@ class TestProviderMetrics: @pytest.mark.asyncio async def test_provider_with_invalid_metrics_in_config(self): - """Test that invalid metrics in config fails when access /providers.""" + """Test invalid metrics in config fails when access /providers.""" run_config = StackRunConfig( image_name="test_image", providers={ @@ -69,6 +69,27 @@ class TestProviderMetrics: with pytest.raises(ValidationError): await provider_impl.list_providers() + @pytest.mark.asyncio + async def test_provider_with_invalid_metrics_in_config2(self): + """Test invalid metrics2 in config fails when access /providers.""" + run_config = StackRunConfig( + image_name="test_image", + providers={ + "inference": [ + Provider( + provider_id="test_provider", + provider_type="test_type", + config={"url": "http://localhost:8000", "metrics": 123}, + ) + ] + }, + ) + provider_config = ProviderImplConfig(run_config=run_config) + provider_impl = ProviderImpl(provider_config, {}) + + with pytest.raises(ValidationError): + await provider_impl.list_providers() + @pytest.mark.asyncio async def test_provider_without_metrics_in_config(self): """Test provider without metrics in config returns None."""