From 8563c76f8885b5eda9d7ae26f4d4b7fe2381c92c Mon Sep 17 00:00:00 2001 From: Wen Zhou Date: Mon, 30 Jun 2025 09:10:43 +0200 Subject: [PATCH 1/5] feat: add optional metrics under API /providers - add field "metrics" under API "/providers" - each provider can configure "metrics" in run.yaml in "config" - if no "metrics": /provider/ shows "metrics: null" in response - if has "metrics": /provider/ shows result in response - if has "metrics" but is not string type and not httpurl, raise ValidationError - add unit tests for providers - update "docs" Signed-off-by: Wen Zhou --- docs/_static/llama-stack-spec.html | 3 + docs/_static/llama-stack-spec.yaml | 2 + llama_stack/apis/providers/providers.py | 12 +- llama_stack/distribution/providers.py | 6 +- tests/integration/providers/test_providers.py | 10 ++ tests/unit/distribution/test_distribution.py | 104 ++++++++++++++++ tests/unit/distribution/test_providers.py | 113 ++++++++++++++++++ 7 files changed, 248 insertions(+), 2 deletions(-) create mode 100644 tests/unit/distribution/test_providers.py diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index 6794d1fbb..e620cf6dd 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -11658,6 +11658,9 @@ } ] } + }, + "metrics": { + "type": "string" } }, "additionalProperties": false, diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 548c5a988..7f302a442 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -8215,6 +8215,8 @@ components: - type: string - type: array - type: object + metrics: + type: string additionalProperties: false required: - api diff --git a/llama_stack/apis/providers/providers.py b/llama_stack/apis/providers/providers.py index 4bc977bf1..3d7955970 100644 --- a/llama_stack/apis/providers/providers.py +++ b/llama_stack/apis/providers/providers.py @@ -6,7 +6,7 @@ from typing import Any, Protocol, runtime_checkable -from 
pydantic import BaseModel +from pydantic import BaseModel, HttpUrl, field_validator from llama_stack.providers.datatypes import HealthResponse from llama_stack.schema_utils import json_schema_type, webmethod @@ -19,6 +19,16 @@ class ProviderInfo(BaseModel): provider_type: str config: dict[str, Any] health: HealthResponse + metrics: str | None = None # define as string type than httpurl for openapi compatibility + + @field_validator("metrics") + @classmethod + def validate_metrics_url(cls, v): + if v is None: + return None + if isinstance(v, str): + HttpUrl(v) + return v class ListProvidersResponse(BaseModel): diff --git a/llama_stack/distribution/providers.py b/llama_stack/distribution/providers.py index 7095ffd18..af92592e7 100644 --- a/llama_stack/distribution/providers.py +++ b/llama_stack/distribution/providers.py @@ -51,18 +51,22 @@ class ProviderImpl(Providers): # Skip providers that are not enabled if p.provider_id is None: continue + # Filter out "metrics" to be shown in config duplicated + metrics_url = p.config.get("metrics") + config = {k: v for k, v in p.config.items() if k != "metrics"} ret.append( ProviderInfo( api=api, provider_id=p.provider_id, provider_type=p.provider_type, - config=p.config, + config=config, health=providers_health.get(api, {}).get( p.provider_id, HealthResponse( status=HealthStatus.NOT_IMPLEMENTED, message="Provider does not implement health check" ), ), + metrics=metrics_url, ) ) diff --git a/tests/integration/providers/test_providers.py b/tests/integration/providers/test_providers.py index fc65e2a10..968b52724 100644 --- a/tests/integration/providers/test_providers.py +++ b/tests/integration/providers/test_providers.py @@ -19,3 +19,13 @@ class TestProviders: pid = provider.provider_id provider = llama_stack_client.providers.retrieve(pid) assert provider is not None + + @pytest.mark.asyncio + def test_providers_metrics_field(self, llama_stack_client: LlamaStackAsLibraryClient | LlamaStackClient): + """Test metrics field is in 
provider responses.""" + provider_list = llama_stack_client.providers.list() + assert provider_list is not None + assert len(provider_list) > 0 + + for provider in provider_list: + assert provider.metrics is None or isinstance(provider.metrics, str) diff --git a/tests/unit/distribution/test_distribution.py b/tests/unit/distribution/test_distribution.py index ae24602d7..b860c4f88 100644 --- a/tests/unit/distribution/test_distribution.py +++ b/tests/unit/distribution/test_distribution.py @@ -13,9 +13,113 @@ from pydantic import BaseModel, Field, ValidationError from llama_stack.distribution.datatypes import Api, Provider, StackRunConfig from llama_stack.distribution.distribution import get_provider_registry +from llama_stack.distribution.providers import ProviderImpl, ProviderImplConfig from llama_stack.providers.datatypes import ProviderSpec +class TestProviderMetrics: + """Test suite for provider metrics.""" + + @pytest.mark.asyncio + async def test_provider_with_valid_metrics_in_config(self): + """Test provider with valid metrics in config.""" + run_config = StackRunConfig( + image_name="test_image", + providers={ + "inference": [ + Provider( + provider_id="test_provider", + provider_type="test_type", + config={"url": "http://localhost:8000", "metrics": "http://localhost:9090/metrics"}, + ) + ] + }, + ) + provider_config = ProviderImplConfig(run_config=run_config) + provider_impl = ProviderImpl(provider_config, {}) + + response = await provider_impl.list_providers() + assert len(response.data) == 1 + + provider_info = response.data[0] + assert provider_info.provider_id == "test_provider" + assert provider_info.provider_type == "test_type" + assert provider_info.metrics == "http://localhost:9090/metrics" + assert "metrics" not in provider_info.config + assert provider_info.config["url"] == "http://localhost:8000" + + @pytest.mark.asyncio + async def test_provider_with_invalid_metrics_in_config(self): + """Test that invalid metrics in config fails when access 
/providers.""" + run_config = StackRunConfig( + image_name="test_image", + providers={ + "inference": [ + Provider( + provider_id="test_provider", + provider_type="test_type", + config={"url": "http://localhost:8000", "metrics": "abcde-llama-stack"}, + ) + ] + }, + ) + provider_config = ProviderImplConfig(run_config=run_config) + provider_impl = ProviderImpl(provider_config, {}) + + with pytest.raises(ValidationError): + await provider_impl.list_providers() + + @pytest.mark.asyncio + async def test_provider_without_metrics_in_config(self): + """Test provider without metrics in config returns None.""" + run_config = StackRunConfig( + image_name="test_image", + providers={ + "inference": [ + Provider( + provider_id="test_provider", provider_type="test_type", config={"url": "http://localhost:8000"} + ) + ] + }, + ) + provider_config = ProviderImplConfig(run_config=run_config) + provider_impl = ProviderImpl(provider_config, {}) + + response = await provider_impl.list_providers() + assert len(response.data) == 1 + + provider_info = response.data[0] + assert provider_info.provider_id == "test_provider" + assert provider_info.provider_type == "test_type" + assert provider_info.metrics is None + assert "metrics" not in provider_info.config + assert provider_info.config["url"] == "http://localhost:8000" + + @pytest.mark.asyncio + async def test_inspect_provider_with_metrics(self): + """Test inspect_provider returns correct metrics info.""" + run_config = StackRunConfig( + image_name="test_image", + providers={ + "inference": [ + Provider( + provider_id="test_provider", + provider_type="test_type", + config={"url": "http://localhost:8000", "metrics": "http://localhost:9090/metrics"}, + ) + ] + }, + ) + provider_config = ProviderImplConfig(run_config=run_config) + provider_impl = ProviderImpl(provider_config, {}) + + # Test the inspect_provider API + provider_info = await provider_impl.inspect_provider("test_provider") + assert provider_info.provider_id == "test_provider" + 
assert provider_info.metrics == "http://localhost:9090/metrics" + assert "metrics" not in provider_info.config + + class SampleConfig(BaseModel): foo: str = Field( default="bar", diff --git a/tests/unit/distribution/test_providers.py b/tests/unit/distribution/test_providers.py new file mode 100644 index 000000000..0a81c4b31 --- /dev/null +++ b/tests/unit/distribution/test_providers.py @@ -0,0 +1,113 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + + +import pytest + +from llama_stack.apis.providers import ProviderInfo +from llama_stack.distribution.datatypes import Provider, StackRunConfig +from llama_stack.distribution.providers import ProviderImpl, ProviderImplConfig + + +class TestProviderImpl: + """Test suite for ProviderImpl class.""" + + @pytest.fixture + def mock_config(self): + """Create a mock configuration for testing.""" + run_config = StackRunConfig( + image_name="test_image", + providers={ + "inference": [ + Provider( + provider_id="test_provider_with_metrics_url", + provider_type="test_type1", + config={"url": "http://localhost:8000", "metrics": "http://localhost:9090/metrics"}, + ), + Provider( + provider_id="test_provider_no_metrics_url", + provider_type="test_type2", + config={"url": "http://localhost:8080"}, + ), + ] + }, + ) + return ProviderImplConfig(run_config=run_config) + + @pytest.fixture + def mock_deps(self): + """Create mock dependencies.""" + return {} + + @pytest.mark.asyncio + async def test_provider_info_structure(self, mock_config, mock_deps): + """Test ProviderInfo objects""" + provider_impl = ProviderImpl(mock_config, mock_deps) + + response = await provider_impl.list_providers() + provider = response.data[0] + + # Check all required fields + assert hasattr(provider, "api") + assert isinstance(provider.api, str) + + assert hasattr(provider, "provider_id") + assert 
isinstance(provider.provider_id, str) + + assert hasattr(provider, "provider_type") + assert isinstance(provider.provider_type, str) + + assert hasattr(provider, "config") + assert isinstance(provider.config, dict) + + assert hasattr(provider, "health") + + assert provider.metrics is None or isinstance(provider.metrics, str) + + @pytest.mark.asyncio + async def test_list_providers_with_metrics(self, mock_config, mock_deps): + """Test list_providers includes metrics field.""" + provider_impl = ProviderImpl(mock_config, mock_deps) + + response = await provider_impl.list_providers() + + assert response is not None + assert len(response.data) == 2 + + # Check provider with metrics + provider1 = response.data[0] + assert isinstance(provider1, ProviderInfo) + assert provider1.provider_id == "test_provider_with_metrics_url" + assert provider1.metrics == "http://localhost:9090/metrics" + + # Check provider without metrics + provider2 = response.data[1] + assert isinstance(provider2, ProviderInfo) + assert provider2.provider_id == "test_provider_no_metrics_url" + assert provider2.metrics is None + + @pytest.mark.asyncio + async def test_inspect_provider_with_metrics(self, mock_config, mock_deps): + """Test inspect_provider includes metrics field.""" + provider_impl = ProviderImpl(mock_config, mock_deps) + + # Test provider with metrics + provider_info = await provider_impl.inspect_provider("test_provider_with_metrics_url") + assert provider_info.provider_id == "test_provider_with_metrics_url" + assert provider_info.metrics == "http://localhost:9090/metrics" + + # Test provider without metrics + provider_info = await provider_impl.inspect_provider("test_provider_no_metrics_url") + assert provider_info.provider_id == "test_provider_no_metrics_url" + assert provider_info.metrics is None + + @pytest.mark.asyncio + async def test_inspect_provider_not_found(self, mock_config, mock_deps): + """Test inspect_provider raises error for non-existent provider.""" + provider_impl = 
ProviderImpl(mock_config, mock_deps) + + with pytest.raises(ValueError, match="Provider nonexistent not found"): + await provider_impl.inspect_provider("nonexistent") From 0b051c037b005ab7d2331d68c3d2eb25638b0d28 Mon Sep 17 00:00:00 2001 From: Wen Zhou Date: Mon, 30 Jun 2025 10:12:48 +0200 Subject: [PATCH 2/5] update: add validation on non-string type Signed-off-by: Wen Zhou --- llama_stack/apis/providers/providers.py | 7 +++--- tests/unit/distribution/test_distribution.py | 23 +++++++++++++++++++- 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/llama_stack/apis/providers/providers.py b/llama_stack/apis/providers/providers.py index 3d7955970..b8bba6919 100644 --- a/llama_stack/apis/providers/providers.py +++ b/llama_stack/apis/providers/providers.py @@ -26,9 +26,10 @@ class ProviderInfo(BaseModel): def validate_metrics_url(cls, v): if v is None: return None - if isinstance(v, str): - HttpUrl(v) - return v + if not isinstance(v, str): + raise ValueError("metrics must be a string URL or None") + HttpUrl(v) + return v class ListProvidersResponse(BaseModel): diff --git a/tests/unit/distribution/test_distribution.py b/tests/unit/distribution/test_distribution.py index b860c4f88..a77c5cb18 100644 --- a/tests/unit/distribution/test_distribution.py +++ b/tests/unit/distribution/test_distribution.py @@ -50,7 +50,7 @@ class TestProviderMetrics: @pytest.mark.asyncio async def test_provider_with_invalid_metrics_in_config(self): - """Test that invalid metrics in config fails when access /providers.""" + """Test invalid metrics in config fails when access /providers.""" run_config = StackRunConfig( image_name="test_image", providers={ @@ -69,6 +69,27 @@ class TestProviderMetrics: with pytest.raises(ValidationError): await provider_impl.list_providers() + @pytest.mark.asyncio + async def test_provider_with_invalid_metrics_in_config2(self): + """Test invalid metrics2 in config fails when access /providers.""" + run_config = StackRunConfig( + image_name="test_image", 
+ providers={ + "inference": [ + Provider( + provider_id="test_provider", + provider_type="test_type", + config={"url": "http://localhost:8000", "metrics": 123}, + ) + ] + }, + ) + provider_config = ProviderImplConfig(run_config=run_config) + provider_impl = ProviderImpl(provider_config, {}) + + with pytest.raises(ValidationError): + await provider_impl.list_providers() + @pytest.mark.asyncio async def test_provider_without_metrics_in_config(self): """Test provider without metrics in config returns None.""" From 2ce940259362e0b916c370ab9eeeb3c411867a9c Mon Sep 17 00:00:00 2001 From: Wen Zhou Date: Thu, 3 Jul 2025 17:11:52 +0200 Subject: [PATCH 3/5] update: code review - change type from str to Field with description - remove test from test_distribution.py, keep in test_providers.py Signed-off-by: Wen Zhou --- llama_stack/apis/providers/providers.py | 6 +- tests/unit/distribution/test_distribution.py | 125 ------------------- tests/unit/distribution/test_providers.py | 26 ++++ 3 files changed, 30 insertions(+), 127 deletions(-) diff --git a/llama_stack/apis/providers/providers.py b/llama_stack/apis/providers/providers.py index b8bba6919..8c1a8d73e 100644 --- a/llama_stack/apis/providers/providers.py +++ b/llama_stack/apis/providers/providers.py @@ -6,7 +6,7 @@ from typing import Any, Protocol, runtime_checkable -from pydantic import BaseModel, HttpUrl, field_validator +from pydantic import BaseModel, Field, HttpUrl, field_validator from llama_stack.providers.datatypes import HealthResponse from llama_stack.schema_utils import json_schema_type, webmethod @@ -19,7 +19,9 @@ class ProviderInfo(BaseModel): provider_type: str config: dict[str, Any] health: HealthResponse - metrics: str | None = None # define as string type than httpurl for openapi compatibility + metrics: str | None = Field( + default=None, description="endpoint for metrics from providers. Must be a valid HTTP URL if provided." 
+ ) @field_validator("metrics") @classmethod diff --git a/tests/unit/distribution/test_distribution.py b/tests/unit/distribution/test_distribution.py index a77c5cb18..ae24602d7 100644 --- a/tests/unit/distribution/test_distribution.py +++ b/tests/unit/distribution/test_distribution.py @@ -13,134 +13,9 @@ from pydantic import BaseModel, Field, ValidationError from llama_stack.distribution.datatypes import Api, Provider, StackRunConfig from llama_stack.distribution.distribution import get_provider_registry -from llama_stack.distribution.providers import ProviderImpl, ProviderImplConfig from llama_stack.providers.datatypes import ProviderSpec -class TestProviderMetrics: - """Test suite for provider metrics.""" - - @pytest.mark.asyncio - async def test_provider_with_valid_metrics_in_config(self): - """Test provider with valid metrics in config.""" - run_config = StackRunConfig( - image_name="test_image", - providers={ - "inference": [ - Provider( - provider_id="test_provider", - provider_type="test_type", - config={"url": "http://localhost:8000", "metrics": "http://localhost:9090/metrics"}, - ) - ] - }, - ) - provider_config = ProviderImplConfig(run_config=run_config) - provider_impl = ProviderImpl(provider_config, {}) - - response = await provider_impl.list_providers() - assert len(response.data) == 1 - - provider_info = response.data[0] - assert provider_info.provider_id == "test_provider" - assert provider_info.provider_type == "test_type" - assert provider_info.metrics == "http://localhost:9090/metrics" - assert "metrics" not in provider_info.config - assert provider_info.config["url"] == "http://localhost:8000" - - @pytest.mark.asyncio - async def test_provider_with_invalid_metrics_in_config(self): - """Test invalid metrics in config fails when access /providers.""" - run_config = StackRunConfig( - image_name="test_image", - providers={ - "inference": [ - Provider( - provider_id="test_provider", - provider_type="test_type", - config={"url": 
"http://localhost:8000", "metrics": "abcde-llama-stack"}, - ) - ] - }, - ) - provider_config = ProviderImplConfig(run_config=run_config) - provider_impl = ProviderImpl(provider_config, {}) - - with pytest.raises(ValidationError): - await provider_impl.list_providers() - - @pytest.mark.asyncio - async def test_provider_with_invalid_metrics_in_config2(self): - """Test invalid metrics2 in config fails when access /providers.""" - run_config = StackRunConfig( - image_name="test_image", - providers={ - "inference": [ - Provider( - provider_id="test_provider", - provider_type="test_type", - config={"url": "http://localhost:8000", "metrics": 123}, - ) - ] - }, - ) - provider_config = ProviderImplConfig(run_config=run_config) - provider_impl = ProviderImpl(provider_config, {}) - - with pytest.raises(ValidationError): - await provider_impl.list_providers() - - @pytest.mark.asyncio - async def test_provider_without_metrics_in_config(self): - """Test provider without metrics in config returns None.""" - run_config = StackRunConfig( - image_name="test_image", - providers={ - "inference": [ - Provider( - provider_id="test_provider", provider_type="test_type", config={"url": "http://localhost:8000"} - ) - ] - }, - ) - provider_config = ProviderImplConfig(run_config=run_config) - provider_impl = ProviderImpl(provider_config, {}) - - response = await provider_impl.list_providers() - assert len(response.data) == 1 - - provider_info = response.data[0] - assert provider_info.provider_id == "test_provider" - assert provider_info.provider_type == "test_type" - assert provider_info.metrics is None - assert "metrics" not in provider_info.config - assert provider_info.config["url"] == "http://localhost:8000" - - @pytest.mark.asyncio - async def test_inspect_provider_with_metrics(self): - """Test inspect_provider returns correct metrics info.""" - run_config = StackRunConfig( - image_name="test_image", - providers={ - "inference": [ - Provider( - provider_id="test_provider", - 
provider_type="test_type", - config={"url": "http://localhost:8000", "metrics": "http://localhost:9090/metrics"}, - ) - ] - }, - ) - provider_config = ProviderImplConfig(run_config=run_config) - provider_impl = ProviderImpl(provider_config, {}) - - # Test the inspect_provider API - provider_info = await provider_impl.inspect_provider("test_provider") - assert provider_info.provider_id == "test_provider" - assert provider_info.metrics == "http://localhost:9090/metrics" - assert "metrics" not in provider_info.config - - class SampleConfig(BaseModel): foo: str = Field( default="bar", diff --git a/tests/unit/distribution/test_providers.py b/tests/unit/distribution/test_providers.py index 0a81c4b31..4f1ee5cb9 100644 --- a/tests/unit/distribution/test_providers.py +++ b/tests/unit/distribution/test_providers.py @@ -6,6 +6,7 @@ import pytest +from pydantic import ValidationError from llama_stack.apis.providers import ProviderInfo from llama_stack.distribution.datatypes import Provider, StackRunConfig @@ -37,6 +38,23 @@ class TestProviderImpl: ) return ProviderImplConfig(run_config=run_config) + @pytest.fixture + def mock_config_malformed_metrics(self): + """Create a mock configuration with invalid metrics URL for testing.""" + run_config = StackRunConfig( + image_name="test_image", + providers={ + "inference": [ + Provider( + provider_id="test_provider_malformed_metrics", + provider_type="test_type3", + config={"url": "http://localhost:8000", "metrics": "abcde-llama-stack"}, + ), + ] + }, + ) + return ProviderImplConfig(run_config=run_config) + @pytest.fixture def mock_deps(self): """Create mock dependencies.""" @@ -111,3 +129,11 @@ class TestProviderImpl: with pytest.raises(ValueError, match="Provider nonexistent not found"): await provider_impl.inspect_provider("nonexistent") + + @pytest.mark.asyncio + async def test_inspect_provider_malformed_metrics(self, mock_config_malformed_metrics, mock_deps): + """Test inspect_provider with invalid metrics URL raises validation 
error.""" + provider_impl = ProviderImpl(mock_config_malformed_metrics, mock_deps) + + with pytest.raises(ValidationError): + await provider_impl.inspect_provider("test_provider_malformed_metrics") From 201824946ec662d17c030b47eb3046ee2ff2da88 Mon Sep 17 00:00:00 2001 From: Wen Zhou Date: Tue, 15 Jul 2025 13:29:25 +0200 Subject: [PATCH 4/5] update: code review - format for error message - wrapper into a handler for httpurl validation Signed-off-by: Wen Zhou --- llama_stack/apis/providers/providers.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/llama_stack/apis/providers/providers.py b/llama_stack/apis/providers/providers.py index 8c1a8d73e..34b3e6e44 100644 --- a/llama_stack/apis/providers/providers.py +++ b/llama_stack/apis/providers/providers.py @@ -7,6 +7,7 @@ from typing import Any, Protocol, runtime_checkable from pydantic import BaseModel, Field, HttpUrl, field_validator +from pydantic_core import PydanticCustomError from llama_stack.providers.datatypes import HealthResponse from llama_stack.schema_utils import json_schema_type, webmethod @@ -20,7 +21,7 @@ class ProviderInfo(BaseModel): config: dict[str, Any] health: HealthResponse metrics: str | None = Field( - default=None, description="endpoint for metrics from providers. Must be a valid HTTP URL if provided." + default=None, description="Endpoint for metrics from providers. Must be a valid HTTP URL if provided." 
) @field_validator("metrics") @@ -29,9 +30,12 @@ class ProviderInfo(BaseModel): if v is None: return None if not isinstance(v, str): - raise ValueError("metrics must be a string URL or None") - HttpUrl(v) - return v + raise ValueError("'metrics' must be a string URL or None") + try: + HttpUrl(v) # Validate the URL + return v + except (PydanticCustomError, ValueError) as e: + raise ValueError(f"'metrics' must be a valid HTTP or HTTPS URL: {str(e)}") from e class ListProvidersResponse(BaseModel): From daeed865dffbc4632f7acb02666f158227ef1af9 Mon Sep 17 00:00:00 2001 From: Wen Zhou Date: Tue, 15 Jul 2025 13:47:49 +0200 Subject: [PATCH 5/5] fix: rebase error Signed-off-by: Wen Zhou --- tests/integration/providers/test_providers.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/integration/providers/test_providers.py b/tests/integration/providers/test_providers.py index 968b52724..36ac17e5c 100644 --- a/tests/integration/providers/test_providers.py +++ b/tests/integration/providers/test_providers.py @@ -20,7 +20,6 @@ class TestProviders: provider = llama_stack_client.providers.retrieve(pid) assert provider is not None - @pytest.mark.asyncio def test_providers_metrics_field(self, llama_stack_client: LlamaStackAsLibraryClient | LlamaStackClient): """Test metrics field is in provider responses.""" provider_list = llama_stack_client.providers.list()