mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-27 06:28:50 +00:00
- add field "metrcis" under API "/providers" - each provider can config "metrics" in run.yaml in "config" - if no "metrics": /provider/<provider_id> shows "metrics: null" in response - if has "metrics": /provider/<provider_id> show result in response - if has "metrics" but is not string type and not httpurl, raise ValidationError - add unit tests for providers - update "docs" Signed-off-by: Wen Zhou <wenzhou@redhat.com>
327 lines
12 KiB
Python
327 lines
12 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
from typing import Any
|
|
from unittest.mock import patch
|
|
|
|
import pytest
|
|
import yaml
|
|
from pydantic import BaseModel, Field, ValidationError
|
|
|
|
from llama_stack.distribution.datatypes import Api, Provider, StackRunConfig
|
|
from llama_stack.distribution.distribution import get_provider_registry
|
|
from llama_stack.distribution.providers import ProviderImpl, ProviderImplConfig
|
|
from llama_stack.providers.datatypes import ProviderSpec
|
|
|
|
|
|
class TestProviderMetrics:
|
|
"""Test suite for provider metrics."""
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_provider_with_valid_metrics_in_config(self):
|
|
"""Test provider with valid metrics in config."""
|
|
run_config = StackRunConfig(
|
|
image_name="test_image",
|
|
providers={
|
|
"inference": [
|
|
Provider(
|
|
provider_id="test_provider",
|
|
provider_type="test_type",
|
|
config={"url": "http://localhost:8000", "metrics": "http://localhost:9090/metrics"},
|
|
)
|
|
]
|
|
},
|
|
)
|
|
provider_config = ProviderImplConfig(run_config=run_config)
|
|
provider_impl = ProviderImpl(provider_config, {})
|
|
|
|
response = await provider_impl.list_providers()
|
|
assert len(response.data) == 1
|
|
|
|
provider_info = response.data[0]
|
|
assert provider_info.provider_id == "test_provider"
|
|
assert provider_info.provider_type == "test_type"
|
|
assert provider_info.metrics == "http://localhost:9090/metrics"
|
|
assert "metrics" not in provider_info.config
|
|
assert provider_info.config["url"] == "http://localhost:8000"
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_provider_with_invalid_metrics_in_config(self):
|
|
"""Test that invalid metrics in config fails when access /providers."""
|
|
run_config = StackRunConfig(
|
|
image_name="test_image",
|
|
providers={
|
|
"inference": [
|
|
Provider(
|
|
provider_id="test_provider",
|
|
provider_type="test_type",
|
|
config={"url": "http://localhost:8000", "metrics": "abcde-llama-stack"},
|
|
)
|
|
]
|
|
},
|
|
)
|
|
provider_config = ProviderImplConfig(run_config=run_config)
|
|
provider_impl = ProviderImpl(provider_config, {})
|
|
|
|
with pytest.raises(ValidationError):
|
|
await provider_impl.list_providers()
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_provider_without_metrics_in_config(self):
|
|
"""Test provider without metrics in config returns None."""
|
|
run_config = StackRunConfig(
|
|
image_name="test_image",
|
|
providers={
|
|
"inference": [
|
|
Provider(
|
|
provider_id="test_provider", provider_type="test_type", config={"url": "http://localhost:8000"}
|
|
)
|
|
]
|
|
},
|
|
)
|
|
provider_config = ProviderImplConfig(run_config=run_config)
|
|
provider_impl = ProviderImpl(provider_config, {})
|
|
|
|
response = await provider_impl.list_providers()
|
|
assert len(response.data) == 1
|
|
|
|
provider_info = response.data[0]
|
|
assert provider_info.provider_id == "test_provider"
|
|
assert provider_info.provider_type == "test_type"
|
|
assert provider_info.metrics is None
|
|
assert "metrics" not in provider_info.config
|
|
assert provider_info.config["url"] == "http://localhost:8000"
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_inspect_provider_with_metrics(self):
|
|
"""Test inspect_provider returns correct metrics info."""
|
|
run_config = StackRunConfig(
|
|
image_name="test_image",
|
|
providers={
|
|
"inference": [
|
|
Provider(
|
|
provider_id="test_provider",
|
|
provider_type="test_type",
|
|
config={"url": "http://localhost:8000", "metrics": "http://localhost:9090/metrics"},
|
|
)
|
|
]
|
|
},
|
|
)
|
|
provider_config = ProviderImplConfig(run_config=run_config)
|
|
provider_impl = ProviderImpl(provider_config, {})
|
|
|
|
# Test the inspect_provider API
|
|
provider_info = await provider_impl.inspect_provider("test_provider")
|
|
assert provider_info.provider_id == "test_provider"
|
|
assert provider_info.metrics == "http://localhost:9090/metrics"
|
|
assert "metrics" not in provider_info.config
|
|
|
|
|
|
class SampleConfig(BaseModel):
|
|
foo: str = Field(
|
|
default="bar",
|
|
description="foo",
|
|
)
|
|
|
|
@classmethod
|
|
def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]:
|
|
return {
|
|
"foo": "baz",
|
|
}
|
|
|
|
|
|
@pytest.fixture
|
|
def mock_providers():
|
|
"""Mock the available_providers function to return test providers."""
|
|
with patch("llama_stack.providers.registry.inference.available_providers") as mock:
|
|
mock.return_value = [
|
|
ProviderSpec(
|
|
provider_type="test_provider",
|
|
api=Api.inference,
|
|
adapter_type="test_adapter",
|
|
config_class="test_provider.config.TestProviderConfig",
|
|
)
|
|
]
|
|
yield mock
|
|
|
|
|
|
@pytest.fixture
|
|
def base_config(tmp_path):
|
|
"""Create a base StackRunConfig with common settings."""
|
|
return StackRunConfig(
|
|
image_name="test_image",
|
|
providers={
|
|
"inference": [
|
|
Provider(
|
|
provider_id="sample_provider",
|
|
provider_type="sample",
|
|
config=SampleConfig.sample_run_config(),
|
|
)
|
|
]
|
|
},
|
|
external_providers_dir=str(tmp_path),
|
|
)
|
|
|
|
|
|
@pytest.fixture
|
|
def provider_spec_yaml():
|
|
"""Common provider spec YAML for testing."""
|
|
return """
|
|
adapter:
|
|
adapter_type: test_provider
|
|
config_class: test_provider.config.TestProviderConfig
|
|
module: test_provider
|
|
api_dependencies:
|
|
- safety
|
|
"""
|
|
|
|
|
|
@pytest.fixture
|
|
def inline_provider_spec_yaml():
|
|
"""Common inline provider spec YAML for testing."""
|
|
return """
|
|
module: test_provider
|
|
config_class: test_provider.config.TestProviderConfig
|
|
pip_packages:
|
|
- test-package
|
|
api_dependencies:
|
|
- safety
|
|
optional_api_dependencies:
|
|
- vector_io
|
|
provider_data_validator: test_provider.validator.TestValidator
|
|
container_image: test-image:latest
|
|
"""
|
|
|
|
|
|
@pytest.fixture
|
|
def api_directories(tmp_path):
|
|
"""Create the API directory structure for testing."""
|
|
# Create remote provider directory
|
|
remote_inference_dir = tmp_path / "remote" / "inference"
|
|
remote_inference_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
# Create inline provider directory
|
|
inline_inference_dir = tmp_path / "inline" / "inference"
|
|
inline_inference_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
return remote_inference_dir, inline_inference_dir
|
|
|
|
|
|
class TestProviderRegistry:
|
|
"""Test suite for provider registry functionality."""
|
|
|
|
def test_builtin_providers(self, mock_providers):
|
|
"""Test loading built-in providers."""
|
|
registry = get_provider_registry(None)
|
|
|
|
assert Api.inference in registry
|
|
assert "test_provider" in registry[Api.inference]
|
|
assert registry[Api.inference]["test_provider"].provider_type == "test_provider"
|
|
assert registry[Api.inference]["test_provider"].api == Api.inference
|
|
|
|
def test_external_remote_providers(self, api_directories, mock_providers, base_config, provider_spec_yaml):
|
|
"""Test loading external remote providers from YAML files."""
|
|
remote_dir, _ = api_directories
|
|
with open(remote_dir / "test_provider.yaml", "w") as f:
|
|
f.write(provider_spec_yaml)
|
|
|
|
registry = get_provider_registry(base_config)
|
|
assert len(registry[Api.inference]) == 2
|
|
|
|
assert Api.inference in registry
|
|
assert "remote::test_provider" in registry[Api.inference]
|
|
provider = registry[Api.inference]["remote::test_provider"]
|
|
assert provider.adapter.adapter_type == "test_provider"
|
|
assert provider.adapter.module == "test_provider"
|
|
assert provider.adapter.config_class == "test_provider.config.TestProviderConfig"
|
|
assert Api.safety in provider.api_dependencies
|
|
|
|
def test_external_inline_providers(self, api_directories, mock_providers, base_config, inline_provider_spec_yaml):
|
|
"""Test loading external inline providers from YAML files."""
|
|
_, inline_dir = api_directories
|
|
with open(inline_dir / "test_provider.yaml", "w") as f:
|
|
f.write(inline_provider_spec_yaml)
|
|
|
|
registry = get_provider_registry(base_config)
|
|
assert len(registry[Api.inference]) == 2
|
|
|
|
assert Api.inference in registry
|
|
assert "inline::test_provider" in registry[Api.inference]
|
|
provider = registry[Api.inference]["inline::test_provider"]
|
|
assert provider.provider_type == "inline::test_provider"
|
|
assert provider.module == "test_provider"
|
|
assert provider.config_class == "test_provider.config.TestProviderConfig"
|
|
assert provider.pip_packages == ["test-package"]
|
|
assert Api.safety in provider.api_dependencies
|
|
assert Api.vector_io in provider.optional_api_dependencies
|
|
assert provider.provider_data_validator == "test_provider.validator.TestValidator"
|
|
assert provider.container_image == "test-image:latest"
|
|
|
|
def test_invalid_yaml(self, api_directories, mock_providers, base_config):
|
|
"""Test handling of invalid YAML files."""
|
|
remote_dir, inline_dir = api_directories
|
|
with open(remote_dir / "invalid.yaml", "w") as f:
|
|
f.write("invalid: yaml: content: -")
|
|
with open(inline_dir / "invalid.yaml", "w") as f:
|
|
f.write("invalid: yaml: content: -")
|
|
|
|
with pytest.raises(yaml.YAMLError):
|
|
get_provider_registry(base_config)
|
|
|
|
def test_missing_directory(self, mock_providers):
|
|
"""Test handling of missing external providers directory."""
|
|
config = StackRunConfig(
|
|
image_name="test_image",
|
|
providers={
|
|
"inference": [
|
|
Provider(
|
|
provider_id="sample_provider",
|
|
provider_type="sample",
|
|
config=SampleConfig.sample_run_config(),
|
|
)
|
|
]
|
|
},
|
|
external_providers_dir="/nonexistent/dir",
|
|
)
|
|
with pytest.raises(FileNotFoundError):
|
|
get_provider_registry(config)
|
|
|
|
def test_empty_api_directory(self, api_directories, mock_providers, base_config):
|
|
"""Test handling of empty API directory."""
|
|
registry = get_provider_registry(base_config)
|
|
assert len(registry[Api.inference]) == 1 # Only built-in provider
|
|
|
|
def test_malformed_remote_provider_spec(self, api_directories, mock_providers, base_config):
|
|
"""Test handling of malformed remote provider spec (missing required fields)."""
|
|
remote_dir, _ = api_directories
|
|
malformed_spec = """
|
|
adapter:
|
|
adapter_type: test_provider
|
|
# Missing required fields
|
|
api_dependencies:
|
|
- safety
|
|
"""
|
|
with open(remote_dir / "malformed.yaml", "w") as f:
|
|
f.write(malformed_spec)
|
|
|
|
with pytest.raises(ValidationError):
|
|
get_provider_registry(base_config)
|
|
|
|
def test_malformed_inline_provider_spec(self, api_directories, mock_providers, base_config):
|
|
"""Test handling of malformed inline provider spec (missing required fields)."""
|
|
_, inline_dir = api_directories
|
|
malformed_spec = """
|
|
module: test_provider
|
|
# Missing required config_class
|
|
pip_packages:
|
|
- test-package
|
|
"""
|
|
with open(inline_dir / "malformed.yaml", "w") as f:
|
|
f.write(malformed_spec)
|
|
|
|
with pytest.raises(KeyError) as exc_info:
|
|
get_provider_registry(base_config)
|
|
assert "config_class" in str(exc_info.value)
|