diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index 38e53a438..812c360bd 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -11661,6 +11661,9 @@ } ] } + }, + "metrics": { + "type": "string" } }, "additionalProperties": false, diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 0df60ddf4..32b537782 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -8217,6 +8217,8 @@ components: - type: string - type: array - type: object + metrics: + type: string additionalProperties: false required: - api diff --git a/llama_stack/apis/providers/providers.py b/llama_stack/apis/providers/providers.py index 4bc977bf1..34b3e6e44 100644 --- a/llama_stack/apis/providers/providers.py +++ b/llama_stack/apis/providers/providers.py @@ -6,7 +6,8 @@ from typing import Any, Protocol, runtime_checkable -from pydantic import BaseModel +from pydantic import BaseModel, Field, HttpUrl, field_validator +from pydantic_core import PydanticCustomError from llama_stack.providers.datatypes import HealthResponse from llama_stack.schema_utils import json_schema_type, webmethod @@ -19,6 +20,22 @@ class ProviderInfo(BaseModel): provider_type: str config: dict[str, Any] health: HealthResponse + metrics: str | None = Field( + default=None, description="Endpoint for metrics from providers. Must be a valid HTTP URL if provided." + ) + + @field_validator("metrics") + @classmethod + def validate_metrics_url(cls, v): + if v is None: + return None + if not isinstance(v, str): + raise ValueError("'metrics' must be a string URL or None") + try: + HttpUrl(v) # Validate the URL + return v + except (PydanticCustomError, ValueError) as e: + raise ValueError(f"'metrics' must be a valid HTTP or HTTPS URL: {str(e)}") from e class ListProvidersResponse(BaseModel): diff --git a/llama_stack/distribution/providers.py b/llama_stack/distribution/providers.py index 7095ffd18..af92592e7 100644 --- a/llama_stack/distribution/providers.py +++ b/llama_stack/distribution/providers.py @@ -51,18 +51,22 @@ class ProviderImpl(Providers): # Skip providers that are not enabled if p.provider_id is None: continue + # Filter out "metrics" to be shown in config duplicated + metrics_url = p.config.get("metrics") + config = {k: v for k, v in p.config.items() if k != "metrics"} ret.append( ProviderInfo( api=api, provider_id=p.provider_id, provider_type=p.provider_type, - config=p.config, + config=config, health=providers_health.get(api, {}).get( p.provider_id, HealthResponse( status=HealthStatus.NOT_IMPLEMENTED, message="Provider does not implement health check" ), ), + metrics=metrics_url, ) ) diff --git a/tests/integration/providers/test_providers.py b/tests/integration/providers/test_providers.py index fc65e2a10..36ac17e5c 100644 --- a/tests/integration/providers/test_providers.py +++ b/tests/integration/providers/test_providers.py @@ -19,3 +19,12 @@ class TestProviders: pid = provider.provider_id provider = llama_stack_client.providers.retrieve(pid) assert provider is not None + + def test_providers_metrics_field(self, llama_stack_client: LlamaStackAsLibraryClient | LlamaStackClient): + """Test metrics field is in provider responses.""" + provider_list = llama_stack_client.providers.list() + assert provider_list is not None + assert len(provider_list) > 0 + + for provider in provider_list: + assert provider.metrics is None or isinstance(provider.metrics, str) diff --git a/tests/unit/distribution/test_providers.py b/tests/unit/distribution/test_providers.py new file mode 100644 index 000000000..4f1ee5cb9 --- /dev/null +++ b/tests/unit/distribution/test_providers.py @@ -0,0 +1,139 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + + +import pytest +from pydantic import ValidationError + +from llama_stack.apis.providers import ProviderInfo +from llama_stack.distribution.datatypes import Provider, StackRunConfig +from llama_stack.distribution.providers import ProviderImpl, ProviderImplConfig + + +class TestProviderImpl: + """Test suite for ProviderImpl class.""" + + @pytest.fixture + def mock_config(self): + """Create a mock configuration for testing.""" + run_config = StackRunConfig( + image_name="test_image", + providers={ + "inference": [ + Provider( + provider_id="test_provider_with_metrics_url", + provider_type="test_type1", + config={"url": "http://localhost:8000", "metrics": "http://localhost:9090/metrics"}, + ), + Provider( + provider_id="test_provider_no_metrics_url", + provider_type="test_type2", + config={"url": "http://localhost:8080"}, + ), + ] + }, + ) + return ProviderImplConfig(run_config=run_config) + + @pytest.fixture + def mock_config_malformed_metrics(self): + """Create a mock configuration with invalid metrics URL for testing.""" + run_config = StackRunConfig( + image_name="test_image", + providers={ + "inference": [ + Provider( + provider_id="test_provider_malformed_metrics", + provider_type="test_type3", + config={"url": "http://localhost:8000", "metrics": "abcde-llama-stack"}, + ), + ] + }, + ) + return ProviderImplConfig(run_config=run_config) + + @pytest.fixture + def mock_deps(self): + """Create mock dependencies.""" + return {} + + @pytest.mark.asyncio + async def test_provider_info_structure(self, mock_config, mock_deps): + """Test ProviderInfo objects""" + provider_impl = ProviderImpl(mock_config, mock_deps) + + response = await provider_impl.list_providers() + provider = response.data[0] + + # Check all required fields + assert hasattr(provider, "api") + assert isinstance(provider.api, str) + + assert hasattr(provider, "provider_id") + assert isinstance(provider.provider_id, str) + + assert hasattr(provider, "provider_type") + assert isinstance(provider.provider_type, str) + + assert hasattr(provider, "config") + assert isinstance(provider.config, dict) + + assert hasattr(provider, "health") + + assert provider.metrics is None or isinstance(provider.metrics, str) + + @pytest.mark.asyncio + async def test_list_providers_with_metrics(self, mock_config, mock_deps): + """Test list_providers includes metrics field.""" + provider_impl = ProviderImpl(mock_config, mock_deps) + + response = await provider_impl.list_providers() + + assert response is not None + assert len(response.data) == 2 + + # Check provider with metrics + provider1 = response.data[0] + assert isinstance(provider1, ProviderInfo) + assert provider1.provider_id == "test_provider_with_metrics_url" + assert provider1.metrics == "http://localhost:9090/metrics" + + # Check provider without metrics + provider2 = response.data[1] + assert isinstance(provider2, ProviderInfo) + assert provider2.provider_id == "test_provider_no_metrics_url" + assert provider2.metrics is None + + @pytest.mark.asyncio + async def test_inspect_provider_with_metrics(self, mock_config, mock_deps): + """Test inspect_provider includes metrics field.""" + provider_impl = ProviderImpl(mock_config, mock_deps) + + # Test provider with metrics + provider_info = await provider_impl.inspect_provider("test_provider_with_metrics_url") + assert provider_info.provider_id == "test_provider_with_metrics_url" + assert provider_info.metrics == "http://localhost:9090/metrics" + + # Test provider without metrics + provider_info = await provider_impl.inspect_provider("test_provider_no_metrics_url") + assert provider_info.provider_id == "test_provider_no_metrics_url" + assert provider_info.metrics is None + + @pytest.mark.asyncio + async def test_inspect_provider_not_found(self, mock_config, mock_deps): + """Test inspect_provider raises error for non-existent provider.""" + provider_impl = ProviderImpl(mock_config, mock_deps) + + with pytest.raises(ValueError, match="Provider nonexistent not found"): + await provider_impl.inspect_provider("nonexistent") + + @pytest.mark.asyncio + async def test_inspect_provider_malformed_metrics(self, mock_config_malformed_metrics, mock_deps): + """Test inspect_provider with invalid metrics URL raises validation error.""" + provider_impl = ProviderImpl(mock_config_malformed_metrics, mock_deps) + + with pytest.raises(ValidationError): + await provider_impl.inspect_provider("test_provider_malformed_metrics")