mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-27 06:28:50 +00:00
feat: add optional metrics under API /providers
- add field "metrics" under API "/providers" - each provider can configure "metrics" in run.yaml under "config" - if there is no "metrics": /provider/<provider_id> shows "metrics: null" in the response - if "metrics" is set: /provider/<provider_id> shows its value in the response - if "metrics" is set but is not a string or not an HTTP URL, a ValidationError is raised - add unit tests for providers - update "docs" Signed-off-by: Wen Zhou <wenzhou@redhat.com>
This commit is contained in:
parent
5400a2e2b1
commit
8563c76f88
7 changed files with 248 additions and 2 deletions
3
docs/_static/llama-stack-spec.html
vendored
3
docs/_static/llama-stack-spec.html
vendored
|
@ -11658,6 +11658,9 @@
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
"metrics": {
|
||||||
|
"type": "string"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"additionalProperties": false,
|
"additionalProperties": false,
|
||||||
|
|
2
docs/_static/llama-stack-spec.yaml
vendored
2
docs/_static/llama-stack-spec.yaml
vendored
|
@ -8215,6 +8215,8 @@ components:
|
||||||
- type: string
|
- type: string
|
||||||
- type: array
|
- type: array
|
||||||
- type: object
|
- type: object
|
||||||
|
metrics:
|
||||||
|
type: string
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
required:
|
required:
|
||||||
- api
|
- api
|
||||||
|
|
|
@ -6,7 +6,7 @@
|
||||||
|
|
||||||
from typing import Any, Protocol, runtime_checkable
|
from typing import Any, Protocol, runtime_checkable
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, HttpUrl, field_validator
|
||||||
|
|
||||||
from llama_stack.providers.datatypes import HealthResponse
|
from llama_stack.providers.datatypes import HealthResponse
|
||||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||||
|
@ -19,6 +19,16 @@ class ProviderInfo(BaseModel):
|
||||||
provider_type: str
|
provider_type: str
|
||||||
config: dict[str, Any]
|
config: dict[str, Any]
|
||||||
health: HealthResponse
|
health: HealthResponse
|
||||||
|
metrics: str | None = None  # declared as a string rather than HttpUrl for OpenAPI compatibility
|
||||||
|
|
||||||
|
@field_validator("metrics")
@classmethod
def validate_metrics_url(cls, v):
    """Check that ``metrics``, when set, is a well-formed HTTP(S) URL.

    The URL type is used only for checking; the value that was passed in is
    returned as-is, so the field keeps holding a plain string.

    :param v: the value supplied for ``metrics`` (``None`` when not configured)
    :returns: ``v`` unchanged
    :raises ValueError: if ``v`` is a string that does not validate as ``HttpUrl``;
        pydantic surfaces this to the caller as a ``ValidationError``
    """
    if v is None:
        return None
    if isinstance(v, str):
        # Imported locally so this validator does not depend on a change to the
        # module-level import line.
        from pydantic import TypeAdapter

        # Validate through a TypeAdapter rather than calling HttpUrl(v) directly:
        # on pydantic releases where HttpUrl is an Annotated alias, the direct call
        # constructs a plain Url and skips the http/https scheme restriction.
        try:
            TypeAdapter(HttpUrl).validate_python(v)
        except ValueError as err:  # pydantic's ValidationError derives from ValueError
            raise ValueError(f"metrics must be a valid HTTP(S) URL, got {v!r}") from err
    return v
|
||||||
|
|
||||||
|
|
||||||
class ListProvidersResponse(BaseModel):
|
class ListProvidersResponse(BaseModel):
|
||||||
|
|
|
@ -51,18 +51,22 @@ class ProviderImpl(Providers):
|
||||||
# Skip providers that are not enabled
|
# Skip providers that are not enabled
|
||||||
if p.provider_id is None:
|
if p.provider_id is None:
|
||||||
continue
|
continue
|
||||||
|
# Strip "metrics" from config so it is not shown twice (it is returned in its own field)
|
||||||
|
metrics_url = p.config.get("metrics")
|
||||||
|
config = {k: v for k, v in p.config.items() if k != "metrics"}
|
||||||
ret.append(
|
ret.append(
|
||||||
ProviderInfo(
|
ProviderInfo(
|
||||||
api=api,
|
api=api,
|
||||||
provider_id=p.provider_id,
|
provider_id=p.provider_id,
|
||||||
provider_type=p.provider_type,
|
provider_type=p.provider_type,
|
||||||
config=p.config,
|
config=config,
|
||||||
health=providers_health.get(api, {}).get(
|
health=providers_health.get(api, {}).get(
|
||||||
p.provider_id,
|
p.provider_id,
|
||||||
HealthResponse(
|
HealthResponse(
|
||||||
status=HealthStatus.NOT_IMPLEMENTED, message="Provider does not implement health check"
|
status=HealthStatus.NOT_IMPLEMENTED, message="Provider does not implement health check"
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
|
metrics=metrics_url,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -19,3 +19,13 @@ class TestProviders:
|
||||||
pid = provider.provider_id
|
pid = provider.provider_id
|
||||||
provider = llama_stack_client.providers.retrieve(pid)
|
provider = llama_stack_client.providers.retrieve(pid)
|
||||||
assert provider is not None
|
assert provider is not None
|
||||||
|
|
||||||
|
def test_providers_metrics_field(self, llama_stack_client: LlamaStackAsLibraryClient | LlamaStackClient):
    """Every provider returned by the list call exposes ``metrics`` as ``None`` or a string."""
    # No asyncio marker here: the test body is synchronous, and pytest-asyncio
    # warns when the marker is applied to a function that is not a coroutine.
    provider_list = llama_stack_client.providers.list()
    assert provider_list is not None
    assert len(provider_list) > 0

    for provider in provider_list:
        assert provider.metrics is None or isinstance(provider.metrics, str)
|
||||||
|
|
|
@ -13,9 +13,113 @@ from pydantic import BaseModel, Field, ValidationError
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import Api, Provider, StackRunConfig
|
from llama_stack.distribution.datatypes import Api, Provider, StackRunConfig
|
||||||
from llama_stack.distribution.distribution import get_provider_registry
|
from llama_stack.distribution.distribution import get_provider_registry
|
||||||
|
from llama_stack.distribution.providers import ProviderImpl, ProviderImplConfig
|
||||||
from llama_stack.providers.datatypes import ProviderSpec
|
from llama_stack.providers.datatypes import ProviderSpec
|
||||||
|
|
||||||
|
|
||||||
|
class TestProviderMetrics:
    """Checks how the optional ``metrics`` config entry is surfaced by ProviderImpl."""

    @staticmethod
    def _build_impl(config):
        """Return a ProviderImpl wrapping a single inference provider that uses *config*."""
        provider = Provider(provider_id="test_provider", provider_type="test_type", config=config)
        stack = StackRunConfig(image_name="test_image", providers={"inference": [provider]})
        return ProviderImpl(ProviderImplConfig(run_config=stack), {})

    @pytest.mark.asyncio
    async def test_provider_with_valid_metrics_in_config(self):
        """A valid metrics URL is reported as ``metrics`` and is absent from ``config``."""
        impl = self._build_impl({"url": "http://localhost:8000", "metrics": "http://localhost:9090/metrics"})

        listing = await impl.list_providers()
        assert len(listing.data) == 1

        info = listing.data[0]
        assert info.provider_id == "test_provider"
        assert info.provider_type == "test_type"
        assert info.metrics == "http://localhost:9090/metrics"
        assert "metrics" not in info.config
        assert info.config["url"] == "http://localhost:8000"

    @pytest.mark.asyncio
    async def test_provider_with_invalid_metrics_in_config(self):
        """Listing providers raises ValidationError when ``metrics`` is not a URL."""
        impl = self._build_impl({"url": "http://localhost:8000", "metrics": "abcde-llama-stack"})

        with pytest.raises(ValidationError):
            await impl.list_providers()

    @pytest.mark.asyncio
    async def test_provider_without_metrics_in_config(self):
        """With no ``metrics`` entry in config, the reported ``metrics`` is None."""
        impl = self._build_impl({"url": "http://localhost:8000"})

        listing = await impl.list_providers()
        assert len(listing.data) == 1

        info = listing.data[0]
        assert info.provider_id == "test_provider"
        assert info.provider_type == "test_type"
        assert info.metrics is None
        assert "metrics" not in info.config
        assert info.config["url"] == "http://localhost:8000"

    @pytest.mark.asyncio
    async def test_inspect_provider_with_metrics(self):
        """inspect_provider reports the metrics URL and keeps it out of ``config``."""
        impl = self._build_impl({"url": "http://localhost:8000", "metrics": "http://localhost:9090/metrics"})

        # Look the provider up by id rather than through the listing
        info = await impl.inspect_provider("test_provider")
        assert info.provider_id == "test_provider"
        assert info.metrics == "http://localhost:9090/metrics"
        assert "metrics" not in info.config
|
||||||
|
|
||||||
|
|
||||||
class SampleConfig(BaseModel):
|
class SampleConfig(BaseModel):
|
||||||
foo: str = Field(
|
foo: str = Field(
|
||||||
default="bar",
|
default="bar",
|
||||||
|
|
113
tests/unit/distribution/test_providers.py
Normal file
113
tests/unit/distribution/test_providers.py
Normal file
|
@ -0,0 +1,113 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from llama_stack.apis.providers import ProviderInfo
|
||||||
|
from llama_stack.distribution.datatypes import Provider, StackRunConfig
|
||||||
|
from llama_stack.distribution.providers import ProviderImpl, ProviderImplConfig
|
||||||
|
|
||||||
|
|
||||||
|
class TestProviderImpl:
    """Exercises ProviderImpl listing and inspection with two inference providers."""

    @pytest.fixture
    def mock_config(self):
        """Build a ProviderImplConfig with one provider that sets ``metrics`` and one that does not."""
        with_metrics = Provider(
            provider_id="test_provider_with_metrics_url",
            provider_type="test_type1",
            config={"url": "http://localhost:8000", "metrics": "http://localhost:9090/metrics"},
        )
        without_metrics = Provider(
            provider_id="test_provider_no_metrics_url",
            provider_type="test_type2",
            config={"url": "http://localhost:8080"},
        )
        stack = StackRunConfig(
            image_name="test_image",
            providers={"inference": [with_metrics, without_metrics]},
        )
        return ProviderImplConfig(run_config=stack)

    @pytest.fixture
    def mock_deps(self):
        """Provide an empty dependency mapping."""
        return {}

    @pytest.mark.asyncio
    async def test_provider_info_structure(self, mock_config, mock_deps):
        """The first listed provider carries the expected attributes and types."""
        impl = ProviderImpl(mock_config, mock_deps)

        listing = await impl.list_providers()
        first = listing.data[0]

        # String-valued attributes, checked in declaration order
        for attr in ("api", "provider_id", "provider_type"):
            assert hasattr(first, attr)
            assert isinstance(getattr(first, attr), str)

        assert hasattr(first, "config")
        assert isinstance(first.config, dict)

        assert hasattr(first, "health")

        assert first.metrics is None or isinstance(first.metrics, str)

    @pytest.mark.asyncio
    async def test_list_providers_with_metrics(self, mock_config, mock_deps):
        """list_providers returns both providers with their respective ``metrics`` values."""
        impl = ProviderImpl(mock_config, mock_deps)

        listing = await impl.list_providers()

        assert listing is not None
        assert len(listing.data) == 2

        # Entry configured with a metrics URL
        has_metrics = listing.data[0]
        assert isinstance(has_metrics, ProviderInfo)
        assert has_metrics.provider_id == "test_provider_with_metrics_url"
        assert has_metrics.metrics == "http://localhost:9090/metrics"

        # Entry configured without one
        no_metrics = listing.data[1]
        assert isinstance(no_metrics, ProviderInfo)
        assert no_metrics.provider_id == "test_provider_no_metrics_url"
        assert no_metrics.metrics is None

    @pytest.mark.asyncio
    async def test_inspect_provider_with_metrics(self, mock_config, mock_deps):
        """inspect_provider returns the matching provider along with its ``metrics`` value."""
        impl = ProviderImpl(mock_config, mock_deps)

        # Provider that has a metrics URL configured
        found = await impl.inspect_provider("test_provider_with_metrics_url")
        assert found.provider_id == "test_provider_with_metrics_url"
        assert found.metrics == "http://localhost:9090/metrics"

        # Provider that has none
        found = await impl.inspect_provider("test_provider_no_metrics_url")
        assert found.provider_id == "test_provider_no_metrics_url"
        assert found.metrics is None

    @pytest.mark.asyncio
    async def test_inspect_provider_not_found(self, mock_config, mock_deps):
        """inspect_provider raises ValueError when the id matches no provider."""
        impl = ProviderImpl(mock_config, mock_deps)

        with pytest.raises(ValueError, match="Provider nonexistent not found"):
            await impl.inspect_provider("nonexistent")
|
Loading…
Add table
Add a link
Reference in a new issue