diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index 6794d1fbb..e620cf6dd 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -11658,6 +11658,9 @@ } ] } + }, + "metrics": { + "type": "string" } }, "additionalProperties": false, diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 548c5a988..7f302a442 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -8215,6 +8215,8 @@ components: - type: string - type: array - type: object + metrics: + type: string additionalProperties: false required: - api diff --git a/llama_stack/apis/providers/providers.py b/llama_stack/apis/providers/providers.py index 4bc977bf1..3d7955970 100644 --- a/llama_stack/apis/providers/providers.py +++ b/llama_stack/apis/providers/providers.py @@ -6,7 +6,7 @@ from typing import Any, Protocol, runtime_checkable -from pydantic import BaseModel +from pydantic import BaseModel, HttpUrl, field_validator from llama_stack.providers.datatypes import HealthResponse from llama_stack.schema_utils import json_schema_type, webmethod @@ -19,6 +19,16 @@ class ProviderInfo(BaseModel): provider_type: str config: dict[str, Any] health: HealthResponse + metrics: str | None = None # define as string type than httpurl for openapi compatibility + + @field_validator("metrics") + @classmethod + def validate_metrics_url(cls, v): + if v is None: + return None + if isinstance(v, str): + HttpUrl(v) + return v class ListProvidersResponse(BaseModel): diff --git a/llama_stack/distribution/providers.py b/llama_stack/distribution/providers.py index 7095ffd18..af92592e7 100644 --- a/llama_stack/distribution/providers.py +++ b/llama_stack/distribution/providers.py @@ -51,18 +51,22 @@ class ProviderImpl(Providers): # Skip providers that are not enabled if p.provider_id is None: continue + # Filter out "metrics" to be shown in config duplicated + metrics_url = p.config.get("metrics") + config = {k: v for k, v in p.config.items() if k != "metrics"} ret.append( ProviderInfo( api=api, provider_id=p.provider_id, provider_type=p.provider_type, - config=p.config, + config=config, health=providers_health.get(api, {}).get( p.provider_id, HealthResponse( status=HealthStatus.NOT_IMPLEMENTED, message="Provider does not implement health check" ), ), + metrics=metrics_url, ) ) diff --git a/tests/integration/providers/test_providers.py b/tests/integration/providers/test_providers.py index fc65e2a10..968b52724 100644 --- a/tests/integration/providers/test_providers.py +++ b/tests/integration/providers/test_providers.py @@ -19,3 +19,13 @@ class TestProviders: pid = provider.provider_id provider = llama_stack_client.providers.retrieve(pid) assert provider is not None + + @pytest.mark.asyncio + def test_providers_metrics_field(self, llama_stack_client: LlamaStackAsLibraryClient | LlamaStackClient): + """Test metrics field is in provider responses.""" + provider_list = llama_stack_client.providers.list() + assert provider_list is not None + assert len(provider_list) > 0 + + for provider in provider_list: + assert provider.metrics is None or isinstance(provider.metrics, str) diff --git a/tests/unit/distribution/test_distribution.py b/tests/unit/distribution/test_distribution.py index ae24602d7..b860c4f88 100644 --- a/tests/unit/distribution/test_distribution.py +++ b/tests/unit/distribution/test_distribution.py @@ -13,9 +13,113 @@ from pydantic import BaseModel, Field, ValidationError from llama_stack.distribution.datatypes import Api, Provider, StackRunConfig from llama_stack.distribution.distribution import get_provider_registry +from llama_stack.distribution.providers import ProviderImpl, ProviderImplConfig from llama_stack.providers.datatypes import ProviderSpec +class TestProviderMetrics: + """Test suite for provider metrics.""" + + @pytest.mark.asyncio + async def test_provider_with_valid_metrics_in_config(self): + """Test provider with valid metrics in config.""" + run_config = StackRunConfig( + image_name="test_image", + providers={ + "inference": [ + Provider( + provider_id="test_provider", + provider_type="test_type", + config={"url": "http://localhost:8000", "metrics": "http://localhost:9090/metrics"}, + ) + ] + }, + ) + provider_config = ProviderImplConfig(run_config=run_config) + provider_impl = ProviderImpl(provider_config, {}) + + response = await provider_impl.list_providers() + assert len(response.data) == 1 + + provider_info = response.data[0] + assert provider_info.provider_id == "test_provider" + assert provider_info.provider_type == "test_type" + assert provider_info.metrics == "http://localhost:9090/metrics" + assert "metrics" not in provider_info.config + assert provider_info.config["url"] == "http://localhost:8000" + + @pytest.mark.asyncio + async def test_provider_with_invalid_metrics_in_config(self): + """Test that invalid metrics in config fails when access /providers.""" + run_config = StackRunConfig( + image_name="test_image", + providers={ + "inference": [ + Provider( + provider_id="test_provider", + provider_type="test_type", + config={"url": "http://localhost:8000", "metrics": "abcde-llama-stack"}, + ) + ] + }, + ) + provider_config = ProviderImplConfig(run_config=run_config) + provider_impl = ProviderImpl(provider_config, {}) + + with pytest.raises(ValidationError): + await provider_impl.list_providers() + + @pytest.mark.asyncio + async def test_provider_without_metrics_in_config(self): + """Test provider without metrics in config returns None.""" + run_config = StackRunConfig( + image_name="test_image", + providers={ + "inference": [ + Provider( + provider_id="test_provider", provider_type="test_type", config={"url": "http://localhost:8000"} + ) + ] + }, + ) + provider_config = ProviderImplConfig(run_config=run_config) + provider_impl = ProviderImpl(provider_config, {}) + + response = await provider_impl.list_providers() + assert len(response.data) == 1 + + provider_info = response.data[0] + assert provider_info.provider_id == "test_provider" + assert provider_info.provider_type == "test_type" + assert provider_info.metrics is None + assert "metrics" not in provider_info.config + assert provider_info.config["url"] == "http://localhost:8000" + + @pytest.mark.asyncio + async def test_inspect_provider_with_metrics(self): + """Test inspect_provider returns correct metrics info.""" + run_config = StackRunConfig( + image_name="test_image", + providers={ + "inference": [ + Provider( + provider_id="test_provider", + provider_type="test_type", + config={"url": "http://localhost:8000", "metrics": "http://localhost:9090/metrics"}, + ) + ] + }, + ) + provider_config = ProviderImplConfig(run_config=run_config) + provider_impl = ProviderImpl(provider_config, {}) + + # Test the inspect_provider API + provider_info = await provider_impl.inspect_provider("test_provider") + assert provider_info.provider_id == "test_provider" + assert provider_info.metrics == "http://localhost:9090/metrics" + assert "metrics" not in provider_info.config + + class SampleConfig(BaseModel): foo: str = Field( default="bar", diff --git a/tests/unit/distribution/test_providers.py b/tests/unit/distribution/test_providers.py new file mode 100644 index 000000000..0a81c4b31 --- /dev/null +++ b/tests/unit/distribution/test_providers.py @@ -0,0 +1,113 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + + +import pytest + +from llama_stack.apis.providers import ProviderInfo +from llama_stack.distribution.datatypes import Provider, StackRunConfig +from llama_stack.distribution.providers import ProviderImpl, ProviderImplConfig + + +class TestProviderImpl: + """Test suite for ProviderImpl class.""" + + @pytest.fixture + def mock_config(self): + """Create a mock configuration for testing.""" + run_config = StackRunConfig( + image_name="test_image", + providers={ + "inference": [ + Provider( + provider_id="test_provider_with_metrics_url", + provider_type="test_type1", + config={"url": "http://localhost:8000", "metrics": "http://localhost:9090/metrics"}, + ), + Provider( + provider_id="test_provider_no_metrics_url", + provider_type="test_type2", + config={"url": "http://localhost:8080"}, + ), + ] + }, + ) + return ProviderImplConfig(run_config=run_config) + + @pytest.fixture + def mock_deps(self): + """Create mock dependencies.""" + return {} + + @pytest.mark.asyncio + async def test_provider_info_structure(self, mock_config, mock_deps): + """Test ProviderInfo objects""" + provider_impl = ProviderImpl(mock_config, mock_deps) + + response = await provider_impl.list_providers() + provider = response.data[0] + + # Check all required fields + assert hasattr(provider, "api") + assert isinstance(provider.api, str) + + assert hasattr(provider, "provider_id") + assert isinstance(provider.provider_id, str) + + assert hasattr(provider, "provider_type") + assert isinstance(provider.provider_type, str) + + assert hasattr(provider, "config") + assert isinstance(provider.config, dict) + + assert hasattr(provider, "health") + + assert provider.metrics is None or isinstance(provider.metrics, str) + + @pytest.mark.asyncio + async def test_list_providers_with_metrics(self, mock_config, mock_deps): + """Test list_providers includes metrics field.""" + provider_impl = ProviderImpl(mock_config, mock_deps) + + response = await provider_impl.list_providers() + + assert response is not None + assert len(response.data) == 2 + + # Check provider with metrics + provider1 = response.data[0] + assert isinstance(provider1, ProviderInfo) + assert provider1.provider_id == "test_provider_with_metrics_url" + assert provider1.metrics == "http://localhost:9090/metrics" + + # Check provider without metrics + provider2 = response.data[1] + assert isinstance(provider2, ProviderInfo) + assert provider2.provider_id == "test_provider_no_metrics_url" + assert provider2.metrics is None + + @pytest.mark.asyncio + async def test_inspect_provider_with_metrics(self, mock_config, mock_deps): + """Test inspect_provider includes metrics field.""" + provider_impl = ProviderImpl(mock_config, mock_deps) + + # Test provider with metrics + provider_info = await provider_impl.inspect_provider("test_provider_with_metrics_url") + assert provider_info.provider_id == "test_provider_with_metrics_url" + assert provider_info.metrics == "http://localhost:9090/metrics" + + # Test provider without metrics + provider_info = await provider_impl.inspect_provider("test_provider_no_metrics_url") + assert provider_info.provider_id == "test_provider_no_metrics_url" + assert provider_info.metrics is None + + @pytest.mark.asyncio + async def test_inspect_provider_not_found(self, mock_config, mock_deps): + """Test inspect_provider raises error for non-existent provider.""" + provider_impl = ProviderImpl(mock_config, mock_deps) + + with pytest.raises(ValueError, match="Provider nonexistent not found"): + await provider_impl.inspect_provider("nonexistent")