mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-27 06:28:50 +00:00
update: code review
- change type from str to Field with description - remove test from test_distrubution.py, keep in test_providers.py Signed-off-by: Wen Zhou <wenzhou@redhat.com>
This commit is contained in:
parent
0b051c037b
commit
2ce9402593
3 changed files with 30 additions and 127 deletions
|
@ -6,7 +6,7 @@
|
||||||
|
|
||||||
from typing import Any, Protocol, runtime_checkable
|
from typing import Any, Protocol, runtime_checkable
|
||||||
|
|
||||||
from pydantic import BaseModel, HttpUrl, field_validator
|
from pydantic import BaseModel, Field, HttpUrl, field_validator
|
||||||
|
|
||||||
from llama_stack.providers.datatypes import HealthResponse
|
from llama_stack.providers.datatypes import HealthResponse
|
||||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||||
|
@ -19,7 +19,9 @@ class ProviderInfo(BaseModel):
|
||||||
provider_type: str
|
provider_type: str
|
||||||
config: dict[str, Any]
|
config: dict[str, Any]
|
||||||
health: HealthResponse
|
health: HealthResponse
|
||||||
metrics: str | None = None # define as string type than httpurl for openapi compatibility
|
metrics: str | None = Field(
|
||||||
|
default=None, description="endpoint for metrics from providers. Must be a valid HTTP URL if provided."
|
||||||
|
)
|
||||||
|
|
||||||
@field_validator("metrics")
|
@field_validator("metrics")
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
|
@ -13,134 +13,9 @@ from pydantic import BaseModel, Field, ValidationError
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import Api, Provider, StackRunConfig
|
from llama_stack.distribution.datatypes import Api, Provider, StackRunConfig
|
||||||
from llama_stack.distribution.distribution import get_provider_registry
|
from llama_stack.distribution.distribution import get_provider_registry
|
||||||
from llama_stack.distribution.providers import ProviderImpl, ProviderImplConfig
|
|
||||||
from llama_stack.providers.datatypes import ProviderSpec
|
from llama_stack.providers.datatypes import ProviderSpec
|
||||||
|
|
||||||
|
|
||||||
class TestProviderMetrics:
|
|
||||||
"""Test suite for provider metrics."""
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_provider_with_valid_metrics_in_config(self):
|
|
||||||
"""Test provider with valid metrics in config."""
|
|
||||||
run_config = StackRunConfig(
|
|
||||||
image_name="test_image",
|
|
||||||
providers={
|
|
||||||
"inference": [
|
|
||||||
Provider(
|
|
||||||
provider_id="test_provider",
|
|
||||||
provider_type="test_type",
|
|
||||||
config={"url": "http://localhost:8000", "metrics": "http://localhost:9090/metrics"},
|
|
||||||
)
|
|
||||||
]
|
|
||||||
},
|
|
||||||
)
|
|
||||||
provider_config = ProviderImplConfig(run_config=run_config)
|
|
||||||
provider_impl = ProviderImpl(provider_config, {})
|
|
||||||
|
|
||||||
response = await provider_impl.list_providers()
|
|
||||||
assert len(response.data) == 1
|
|
||||||
|
|
||||||
provider_info = response.data[0]
|
|
||||||
assert provider_info.provider_id == "test_provider"
|
|
||||||
assert provider_info.provider_type == "test_type"
|
|
||||||
assert provider_info.metrics == "http://localhost:9090/metrics"
|
|
||||||
assert "metrics" not in provider_info.config
|
|
||||||
assert provider_info.config["url"] == "http://localhost:8000"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_provider_with_invalid_metrics_in_config(self):
|
|
||||||
"""Test invalid metrics in config fails when access /providers."""
|
|
||||||
run_config = StackRunConfig(
|
|
||||||
image_name="test_image",
|
|
||||||
providers={
|
|
||||||
"inference": [
|
|
||||||
Provider(
|
|
||||||
provider_id="test_provider",
|
|
||||||
provider_type="test_type",
|
|
||||||
config={"url": "http://localhost:8000", "metrics": "abcde-llama-stack"},
|
|
||||||
)
|
|
||||||
]
|
|
||||||
},
|
|
||||||
)
|
|
||||||
provider_config = ProviderImplConfig(run_config=run_config)
|
|
||||||
provider_impl = ProviderImpl(provider_config, {})
|
|
||||||
|
|
||||||
with pytest.raises(ValidationError):
|
|
||||||
await provider_impl.list_providers()
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_provider_with_invalid_metrics_in_config2(self):
|
|
||||||
"""Test invalid metrics2 in config fails when access /providers."""
|
|
||||||
run_config = StackRunConfig(
|
|
||||||
image_name="test_image",
|
|
||||||
providers={
|
|
||||||
"inference": [
|
|
||||||
Provider(
|
|
||||||
provider_id="test_provider",
|
|
||||||
provider_type="test_type",
|
|
||||||
config={"url": "http://localhost:8000", "metrics": 123},
|
|
||||||
)
|
|
||||||
]
|
|
||||||
},
|
|
||||||
)
|
|
||||||
provider_config = ProviderImplConfig(run_config=run_config)
|
|
||||||
provider_impl = ProviderImpl(provider_config, {})
|
|
||||||
|
|
||||||
with pytest.raises(ValidationError):
|
|
||||||
await provider_impl.list_providers()
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_provider_without_metrics_in_config(self):
|
|
||||||
"""Test provider without metrics in config returns None."""
|
|
||||||
run_config = StackRunConfig(
|
|
||||||
image_name="test_image",
|
|
||||||
providers={
|
|
||||||
"inference": [
|
|
||||||
Provider(
|
|
||||||
provider_id="test_provider", provider_type="test_type", config={"url": "http://localhost:8000"}
|
|
||||||
)
|
|
||||||
]
|
|
||||||
},
|
|
||||||
)
|
|
||||||
provider_config = ProviderImplConfig(run_config=run_config)
|
|
||||||
provider_impl = ProviderImpl(provider_config, {})
|
|
||||||
|
|
||||||
response = await provider_impl.list_providers()
|
|
||||||
assert len(response.data) == 1
|
|
||||||
|
|
||||||
provider_info = response.data[0]
|
|
||||||
assert provider_info.provider_id == "test_provider"
|
|
||||||
assert provider_info.provider_type == "test_type"
|
|
||||||
assert provider_info.metrics is None
|
|
||||||
assert "metrics" not in provider_info.config
|
|
||||||
assert provider_info.config["url"] == "http://localhost:8000"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_inspect_provider_with_metrics(self):
|
|
||||||
"""Test inspect_provider returns correct metrics info."""
|
|
||||||
run_config = StackRunConfig(
|
|
||||||
image_name="test_image",
|
|
||||||
providers={
|
|
||||||
"inference": [
|
|
||||||
Provider(
|
|
||||||
provider_id="test_provider",
|
|
||||||
provider_type="test_type",
|
|
||||||
config={"url": "http://localhost:8000", "metrics": "http://localhost:9090/metrics"},
|
|
||||||
)
|
|
||||||
]
|
|
||||||
},
|
|
||||||
)
|
|
||||||
provider_config = ProviderImplConfig(run_config=run_config)
|
|
||||||
provider_impl = ProviderImpl(provider_config, {})
|
|
||||||
|
|
||||||
# Test the inspect_provider API
|
|
||||||
provider_info = await provider_impl.inspect_provider("test_provider")
|
|
||||||
assert provider_info.provider_id == "test_provider"
|
|
||||||
assert provider_info.metrics == "http://localhost:9090/metrics"
|
|
||||||
assert "metrics" not in provider_info.config
|
|
||||||
|
|
||||||
|
|
||||||
class SampleConfig(BaseModel):
|
class SampleConfig(BaseModel):
|
||||||
foo: str = Field(
|
foo: str = Field(
|
||||||
default="bar",
|
default="bar",
|
||||||
|
|
|
@ -6,6 +6,7 @@
|
||||||
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
from llama_stack.apis.providers import ProviderInfo
|
from llama_stack.apis.providers import ProviderInfo
|
||||||
from llama_stack.distribution.datatypes import Provider, StackRunConfig
|
from llama_stack.distribution.datatypes import Provider, StackRunConfig
|
||||||
|
@ -37,6 +38,23 @@ class TestProviderImpl:
|
||||||
)
|
)
|
||||||
return ProviderImplConfig(run_config=run_config)
|
return ProviderImplConfig(run_config=run_config)
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_config_malformed_metrics(self):
|
||||||
|
"""Create a mock configuration with invalid metrics URL for testing."""
|
||||||
|
run_config = StackRunConfig(
|
||||||
|
image_name="test_image",
|
||||||
|
providers={
|
||||||
|
"inference": [
|
||||||
|
Provider(
|
||||||
|
provider_id="test_provider_malformed_metrics",
|
||||||
|
provider_type="test_type3",
|
||||||
|
config={"url": "http://localhost:8000", "metrics": "abcde-llama-stack"},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return ProviderImplConfig(run_config=run_config)
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_deps(self):
|
def mock_deps(self):
|
||||||
"""Create mock dependencies."""
|
"""Create mock dependencies."""
|
||||||
|
@ -111,3 +129,11 @@ class TestProviderImpl:
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="Provider nonexistent not found"):
|
with pytest.raises(ValueError, match="Provider nonexistent not found"):
|
||||||
await provider_impl.inspect_provider("nonexistent")
|
await provider_impl.inspect_provider("nonexistent")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_inspect_provider_malformed_metrics(self, mock_config_malformed_metrics, mock_deps):
|
||||||
|
"""Test inspect_provider with invalid metrics URL raises validation error."""
|
||||||
|
provider_impl = ProviderImpl(mock_config_malformed_metrics, mock_deps)
|
||||||
|
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
await provider_impl.inspect_provider("test_provider_malformed_metrics")
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue