mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-12 12:06:04 +00:00
add tests
This commit is contained in:
parent
d11edf6fee
commit
1885beb864
2 changed files with 758 additions and 0 deletions
354
tests/integration/providers/test_dynamic_providers.py
Normal file
354
tests/integration/providers/test_dynamic_providers.py
Normal file
|
|
@ -0,0 +1,354 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from llama_stack_client import LlamaStackClient
|
||||||
|
|
||||||
|
from llama_stack import LlamaStackAsLibraryClient
|
||||||
|
|
||||||
|
|
||||||
|
class TestDynamicProviderManagement:
|
||||||
|
"""Integration tests for dynamic provider registration, update, and unregistration."""
|
||||||
|
|
||||||
|
def test_register_and_unregister_inference_provider(
|
||||||
|
self, llama_stack_client: LlamaStackAsLibraryClient | LlamaStackClient
|
||||||
|
):
|
||||||
|
"""Test registering and unregistering an inference provider."""
|
||||||
|
provider_id = "test-dynamic-inference"
|
||||||
|
|
||||||
|
# Clean up if exists from previous test
|
||||||
|
try:
|
||||||
|
llama_stack_client.providers.unregister(provider_id)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Register a new inference provider (using Ollama since it's available in test setup)
|
||||||
|
response = llama_stack_client.providers.register(
|
||||||
|
provider_id=provider_id,
|
||||||
|
api="inference",
|
||||||
|
provider_type="remote::ollama",
|
||||||
|
config={
|
||||||
|
"url": "http://localhost:11434",
|
||||||
|
"api_token": "",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify registration
|
||||||
|
assert response.provider.provider_id == provider_id
|
||||||
|
assert response.provider.api == "inference"
|
||||||
|
assert response.provider.provider_type == "remote::ollama"
|
||||||
|
assert response.provider.status in ["connected", "initializing"]
|
||||||
|
|
||||||
|
# Verify provider appears in list
|
||||||
|
providers = llama_stack_client.providers.list()
|
||||||
|
provider_ids = [p.provider_id for p in providers]
|
||||||
|
assert provider_id in provider_ids
|
||||||
|
|
||||||
|
# Verify we can retrieve it
|
||||||
|
provider = llama_stack_client.providers.retrieve(provider_id)
|
||||||
|
assert provider.provider_id == provider_id
|
||||||
|
|
||||||
|
# Unregister the provider
|
||||||
|
llama_stack_client.providers.unregister(provider_id)
|
||||||
|
|
||||||
|
# Verify it's no longer in the list
|
||||||
|
providers = llama_stack_client.providers.list()
|
||||||
|
provider_ids = [p.provider_id for p in providers]
|
||||||
|
assert provider_id not in provider_ids
|
||||||
|
|
||||||
|
def test_register_and_unregister_vector_store_provider(
|
||||||
|
self, llama_stack_client: LlamaStackAsLibraryClient | LlamaStackClient
|
||||||
|
):
|
||||||
|
"""Test registering and unregistering a vector store provider."""
|
||||||
|
provider_id = "test-dynamic-vector-store"
|
||||||
|
|
||||||
|
# Clean up if exists
|
||||||
|
try:
|
||||||
|
llama_stack_client.providers.unregister(provider_id)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Register a new vector_io provider (using Faiss inline)
|
||||||
|
response = llama_stack_client.providers.register(
|
||||||
|
provider_id=provider_id,
|
||||||
|
api="vector_io",
|
||||||
|
provider_type="inline::faiss",
|
||||||
|
config={
|
||||||
|
"embedding_dimension": 768,
|
||||||
|
"kvstore": {
|
||||||
|
"type": "sqlite",
|
||||||
|
"namespace": f"test_vector_store_{provider_id}",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify registration
|
||||||
|
assert response.provider.provider_id == provider_id
|
||||||
|
assert response.provider.api == "vector_io"
|
||||||
|
assert response.provider.provider_type == "inline::faiss"
|
||||||
|
|
||||||
|
# Verify provider appears in list
|
||||||
|
providers = llama_stack_client.providers.list()
|
||||||
|
provider_ids = [p.provider_id for p in providers]
|
||||||
|
assert provider_id in provider_ids
|
||||||
|
|
||||||
|
# Unregister
|
||||||
|
llama_stack_client.providers.unregister(provider_id)
|
||||||
|
|
||||||
|
# Verify removal
|
||||||
|
providers = llama_stack_client.providers.list()
|
||||||
|
provider_ids = [p.provider_id for p in providers]
|
||||||
|
assert provider_id not in provider_ids
|
||||||
|
|
||||||
|
def test_update_provider_config(self, llama_stack_client: LlamaStackAsLibraryClient | LlamaStackClient):
|
||||||
|
"""Test updating a provider's configuration."""
|
||||||
|
provider_id = "test-update-config"
|
||||||
|
|
||||||
|
# Clean up if exists
|
||||||
|
try:
|
||||||
|
llama_stack_client.providers.unregister(provider_id)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Register provider
|
||||||
|
llama_stack_client.providers.register(
|
||||||
|
provider_id=provider_id,
|
||||||
|
api="inference",
|
||||||
|
provider_type="remote::ollama",
|
||||||
|
config={
|
||||||
|
"url": "http://localhost:11434",
|
||||||
|
"api_token": "old-token",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update the configuration
|
||||||
|
response = llama_stack_client.providers.update(
|
||||||
|
provider_id=provider_id,
|
||||||
|
config={
|
||||||
|
"url": "http://localhost:11434",
|
||||||
|
"api_token": "new-token",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify update
|
||||||
|
assert response.provider.provider_id == provider_id
|
||||||
|
assert response.provider.config["api_token"] == "new-token"
|
||||||
|
|
||||||
|
# Clean up
|
||||||
|
llama_stack_client.providers.unregister(provider_id)
|
||||||
|
|
||||||
|
def test_update_provider_attributes(self, llama_stack_client: LlamaStackAsLibraryClient | LlamaStackClient):
|
||||||
|
"""Test updating a provider's ABAC attributes."""
|
||||||
|
provider_id = "test-update-attributes"
|
||||||
|
|
||||||
|
# Clean up if exists
|
||||||
|
try:
|
||||||
|
llama_stack_client.providers.unregister(provider_id)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Register provider with initial attributes
|
||||||
|
llama_stack_client.providers.register(
|
||||||
|
provider_id=provider_id,
|
||||||
|
api="inference",
|
||||||
|
provider_type="remote::ollama",
|
||||||
|
config={
|
||||||
|
"url": "http://localhost:11434",
|
||||||
|
},
|
||||||
|
attributes={"team": ["team-a"]},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update attributes
|
||||||
|
response = llama_stack_client.providers.update(
|
||||||
|
provider_id=provider_id,
|
||||||
|
attributes={"team": ["team-a", "team-b"], "environment": ["test"]},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify attributes were updated
|
||||||
|
assert response.provider.attributes["team"] == ["team-a", "team-b"]
|
||||||
|
assert response.provider.attributes["environment"] == ["test"]
|
||||||
|
|
||||||
|
# Clean up
|
||||||
|
llama_stack_client.providers.unregister(provider_id)
|
||||||
|
|
||||||
|
def test_test_provider_connection(self, llama_stack_client: LlamaStackAsLibraryClient | LlamaStackClient):
|
||||||
|
"""Test the connection testing functionality."""
|
||||||
|
provider_id = "test-connection-check"
|
||||||
|
|
||||||
|
# Clean up if exists
|
||||||
|
try:
|
||||||
|
llama_stack_client.providers.unregister(provider_id)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Register provider
|
||||||
|
llama_stack_client.providers.register(
|
||||||
|
provider_id=provider_id,
|
||||||
|
api="inference",
|
||||||
|
provider_type="remote::ollama",
|
||||||
|
config={
|
||||||
|
"url": "http://localhost:11434",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test the connection
|
||||||
|
response = llama_stack_client.providers.test_connection(provider_id)
|
||||||
|
|
||||||
|
# Verify response structure
|
||||||
|
assert hasattr(response, "success")
|
||||||
|
assert hasattr(response, "health")
|
||||||
|
|
||||||
|
# Note: success may be True or False depending on whether Ollama is actually running
|
||||||
|
# but the test should at least verify the API works
|
||||||
|
|
||||||
|
# Clean up
|
||||||
|
llama_stack_client.providers.unregister(provider_id)
|
||||||
|
|
||||||
|
def test_register_duplicate_provider_fails(self, llama_stack_client: LlamaStackAsLibraryClient | LlamaStackClient):
|
||||||
|
"""Test that registering a duplicate provider ID fails."""
|
||||||
|
provider_id = "test-duplicate"
|
||||||
|
|
||||||
|
# Clean up if exists
|
||||||
|
try:
|
||||||
|
llama_stack_client.providers.unregister(provider_id)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Register first provider
|
||||||
|
llama_stack_client.providers.register(
|
||||||
|
provider_id=provider_id,
|
||||||
|
api="inference",
|
||||||
|
provider_type="remote::ollama",
|
||||||
|
config={"url": "http://localhost:11434"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Try to register with same ID - should fail
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
llama_stack_client.providers.register(
|
||||||
|
provider_id=provider_id,
|
||||||
|
api="inference",
|
||||||
|
provider_type="remote::ollama",
|
||||||
|
config={"url": "http://localhost:11435"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify error message mentions the provider already exists
|
||||||
|
assert "already exists" in str(exc_info.value).lower() or "duplicate" in str(exc_info.value).lower()
|
||||||
|
|
||||||
|
# Clean up
|
||||||
|
llama_stack_client.providers.unregister(provider_id)
|
||||||
|
|
||||||
|
def test_unregister_nonexistent_provider_fails(
|
||||||
|
self, llama_stack_client: LlamaStackAsLibraryClient | LlamaStackClient
|
||||||
|
):
|
||||||
|
"""Test that unregistering a non-existent provider fails."""
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
llama_stack_client.providers.unregister("nonexistent-provider-12345")
|
||||||
|
|
||||||
|
# Verify error message mentions provider not found
|
||||||
|
assert "not found" in str(exc_info.value).lower() or "does not exist" in str(exc_info.value).lower()
|
||||||
|
|
||||||
|
def test_update_nonexistent_provider_fails(self, llama_stack_client: LlamaStackAsLibraryClient | LlamaStackClient):
|
||||||
|
"""Test that updating a non-existent provider fails."""
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
llama_stack_client.providers.update(
|
||||||
|
provider_id="nonexistent-provider-12345",
|
||||||
|
config={"url": "http://localhost:11434"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify error message mentions provider not found
|
||||||
|
assert "not found" in str(exc_info.value).lower() or "does not exist" in str(exc_info.value).lower()
|
||||||
|
|
||||||
|
def test_provider_lifecycle_with_inference(
|
||||||
|
self, llama_stack_client: LlamaStackAsLibraryClient | LlamaStackClient
|
||||||
|
):
|
||||||
|
"""Test full lifecycle: register, use for inference (if Ollama available), update, unregister."""
|
||||||
|
provider_id = "test-lifecycle-inference"
|
||||||
|
|
||||||
|
# Clean up if exists
|
||||||
|
try:
|
||||||
|
llama_stack_client.providers.unregister(provider_id)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Register provider
|
||||||
|
response = llama_stack_client.providers.register(
|
||||||
|
provider_id=provider_id,
|
||||||
|
api="inference",
|
||||||
|
provider_type="remote::ollama",
|
||||||
|
config={
|
||||||
|
"url": "http://localhost:11434",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.provider.status in ["connected", "initializing"]
|
||||||
|
|
||||||
|
# Test connection
|
||||||
|
conn_test = llama_stack_client.providers.test_connection(provider_id)
|
||||||
|
assert hasattr(conn_test, "success")
|
||||||
|
|
||||||
|
# Update configuration
|
||||||
|
update_response = llama_stack_client.providers.update(
|
||||||
|
provider_id=provider_id,
|
||||||
|
config={
|
||||||
|
"url": "http://localhost:11434",
|
||||||
|
"api_token": "updated-token",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert update_response.provider.config["api_token"] == "updated-token"
|
||||||
|
|
||||||
|
# Unregister
|
||||||
|
llama_stack_client.providers.unregister(provider_id)
|
||||||
|
|
||||||
|
# Verify it's gone
|
||||||
|
providers = llama_stack_client.providers.list()
|
||||||
|
provider_ids = [p.provider_id for p in providers]
|
||||||
|
assert provider_id not in provider_ids
|
||||||
|
|
||||||
|
def test_multiple_providers_same_type(self, llama_stack_client: LlamaStackAsLibraryClient | LlamaStackClient):
|
||||||
|
"""Test registering multiple providers of the same type with different IDs."""
|
||||||
|
provider_id_1 = "test-multi-ollama-1"
|
||||||
|
provider_id_2 = "test-multi-ollama-2"
|
||||||
|
|
||||||
|
# Clean up if exists
|
||||||
|
for pid in [provider_id_1, provider_id_2]:
|
||||||
|
try:
|
||||||
|
llama_stack_client.providers.unregister(pid)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Register first provider
|
||||||
|
response1 = llama_stack_client.providers.register(
|
||||||
|
provider_id=provider_id_1,
|
||||||
|
api="inference",
|
||||||
|
provider_type="remote::ollama",
|
||||||
|
config={"url": "http://localhost:11434"},
|
||||||
|
)
|
||||||
|
assert response1.provider.provider_id == provider_id_1
|
||||||
|
|
||||||
|
# Register second provider with same type but different ID
|
||||||
|
response2 = llama_stack_client.providers.register(
|
||||||
|
provider_id=provider_id_2,
|
||||||
|
api="inference",
|
||||||
|
provider_type="remote::ollama",
|
||||||
|
config={"url": "http://localhost:11434"},
|
||||||
|
)
|
||||||
|
assert response2.provider.provider_id == provider_id_2
|
||||||
|
|
||||||
|
# Verify both are in the list
|
||||||
|
providers = llama_stack_client.providers.list()
|
||||||
|
provider_ids = [p.provider_id for p in providers]
|
||||||
|
assert provider_id_1 in provider_ids
|
||||||
|
assert provider_id_2 in provider_ids
|
||||||
|
|
||||||
|
# Clean up both
|
||||||
|
llama_stack_client.providers.unregister(provider_id_1)
|
||||||
|
llama_stack_client.providers.unregister(provider_id_2)
|
||||||
|
|
||||||
|
# Verify both are gone
|
||||||
|
providers = llama_stack_client.providers.list()
|
||||||
|
provider_ids = [p.provider_id for p in providers]
|
||||||
|
assert provider_id_1 not in provider_ids
|
||||||
|
assert provider_id_2 not in provider_ids
|
||||||
404
tests/unit/core/test_dynamic_providers.py
Normal file
404
tests/unit/core/test_dynamic_providers.py
Normal file
|
|
@ -0,0 +1,404 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from llama_stack.apis.providers.connection import ProviderConnectionInfo, ProviderConnectionStatus, ProviderHealth
|
||||||
|
from llama_stack.core.datatypes import StackRunConfig
|
||||||
|
from llama_stack.core.providers import ProviderImpl, ProviderImplConfig
|
||||||
|
from llama_stack.core.storage.datatypes import KVStoreReference, ServerStoresConfig, SqliteKVStoreConfig, StorageConfig
|
||||||
|
from llama_stack.providers.datatypes import Api, HealthStatus
|
||||||
|
from llama_stack.providers.utils.kvstore.sqlite import SqliteKVStoreImpl
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def kvstore(tmp_path):
|
||||||
|
"""Create a temporary kvstore for testing."""
|
||||||
|
db_path = tmp_path / "test_providers.db"
|
||||||
|
kvstore_config = SqliteKVStoreConfig(db_path=db_path.as_posix())
|
||||||
|
kvstore = SqliteKVStoreImpl(kvstore_config)
|
||||||
|
await kvstore.initialize()
|
||||||
|
yield kvstore
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def provider_impl(kvstore, tmp_path):
|
||||||
|
"""Create a ProviderImpl instance with mocked dependencies."""
|
||||||
|
db_path = (tmp_path / "test_providers.db").as_posix()
|
||||||
|
|
||||||
|
# Create storage config with required structure
|
||||||
|
storage_config = StorageConfig(
|
||||||
|
backends={
|
||||||
|
"default": SqliteKVStoreConfig(db_path=db_path),
|
||||||
|
},
|
||||||
|
stores=ServerStoresConfig(
|
||||||
|
metadata=KVStoreReference(backend="default", namespace="test_metadata"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create minimal run config with storage
|
||||||
|
run_config = StackRunConfig(
|
||||||
|
image_name="test",
|
||||||
|
apis=[],
|
||||||
|
providers={},
|
||||||
|
storage=storage_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock provider registry
|
||||||
|
mock_provider_registry = MagicMock()
|
||||||
|
|
||||||
|
config = ProviderImplConfig(
|
||||||
|
run_config=run_config,
|
||||||
|
provider_registry=mock_provider_registry,
|
||||||
|
dist_registry=None,
|
||||||
|
policy=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
impl = ProviderImpl(config, deps={})
|
||||||
|
|
||||||
|
# Manually set the kvstore instead of going through initialize
|
||||||
|
# This avoids the complex backend registration logic
|
||||||
|
impl.kvstore = kvstore
|
||||||
|
impl.provider_registry = mock_provider_registry
|
||||||
|
impl.dist_registry = None
|
||||||
|
impl.policy = []
|
||||||
|
|
||||||
|
yield impl
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
class TestDynamicProviderManagement:
|
||||||
|
"""Unit tests for dynamic provider registration, update, and unregistration."""
|
||||||
|
|
||||||
|
async def test_register_inference_provider(self, provider_impl):
|
||||||
|
"""Test registering a new inference provider."""
|
||||||
|
# Mock the provider instantiation
|
||||||
|
mock_provider_instance = AsyncMock()
|
||||||
|
mock_provider_instance.health = AsyncMock(return_value={"status": HealthStatus.OK})
|
||||||
|
|
||||||
|
with patch.object(provider_impl, "_instantiate_provider", return_value=mock_provider_instance):
|
||||||
|
# Register a mock inference provider
|
||||||
|
response = await provider_impl.register_provider(
|
||||||
|
provider_id="test-inference-1",
|
||||||
|
api=Api.inference.value,
|
||||||
|
provider_type="remote::openai",
|
||||||
|
config={"api_key": "test-key", "url": "https://api.openai.com/v1"},
|
||||||
|
attributes={"team": ["test-team"]},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify response
|
||||||
|
assert response.provider.provider_id == "test-inference-1"
|
||||||
|
assert response.provider.api == Api.inference.value
|
||||||
|
assert response.provider.provider_type == "remote::openai"
|
||||||
|
assert response.provider.status == ProviderConnectionStatus.connected
|
||||||
|
assert response.provider.config["api_key"] == "test-key"
|
||||||
|
assert response.provider.attributes == {"team": ["test-team"]}
|
||||||
|
|
||||||
|
# Verify provider is stored
|
||||||
|
assert "test-inference-1" in provider_impl.dynamic_providers
|
||||||
|
assert "test-inference-1" in provider_impl.dynamic_provider_impls
|
||||||
|
|
||||||
|
async def test_register_vector_store_provider(self, provider_impl):
|
||||||
|
"""Test registering a new vector store provider."""
|
||||||
|
# Mock the provider instantiation
|
||||||
|
mock_provider_instance = AsyncMock()
|
||||||
|
mock_provider_instance.health = AsyncMock(return_value={"status": HealthStatus.OK})
|
||||||
|
|
||||||
|
with patch.object(provider_impl, "_instantiate_provider", return_value=mock_provider_instance):
|
||||||
|
# Register a mock vector_io provider
|
||||||
|
response = await provider_impl.register_provider(
|
||||||
|
provider_id="test-vector-store-1",
|
||||||
|
api=Api.vector_io.value,
|
||||||
|
provider_type="inline::faiss",
|
||||||
|
config={"dimension": 768, "index_path": "/tmp/faiss_index"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify response
|
||||||
|
assert response.provider.provider_id == "test-vector-store-1"
|
||||||
|
assert response.provider.api == Api.vector_io.value
|
||||||
|
assert response.provider.provider_type == "inline::faiss"
|
||||||
|
assert response.provider.status == ProviderConnectionStatus.connected
|
||||||
|
assert response.provider.config["dimension"] == 768
|
||||||
|
|
||||||
|
async def test_register_duplicate_provider_fails(self, provider_impl):
|
||||||
|
"""Test that registering a duplicate provider_id fails."""
|
||||||
|
mock_provider_instance = AsyncMock()
|
||||||
|
mock_provider_instance.health = AsyncMock(return_value={"status": HealthStatus.OK})
|
||||||
|
|
||||||
|
with patch.object(provider_impl, "_instantiate_provider", return_value=mock_provider_instance):
|
||||||
|
# Register first provider
|
||||||
|
await provider_impl.register_provider(
|
||||||
|
provider_id="test-duplicate",
|
||||||
|
api=Api.inference.value,
|
||||||
|
provider_type="remote::openai",
|
||||||
|
config={"api_key": "key1"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Try to register with same ID
|
||||||
|
with pytest.raises(ValueError, match="already exists"):
|
||||||
|
await provider_impl.register_provider(
|
||||||
|
provider_id="test-duplicate",
|
||||||
|
api=Api.inference.value,
|
||||||
|
provider_type="remote::openai",
|
||||||
|
config={"api_key": "key2"},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def test_update_provider_config(self, provider_impl):
|
||||||
|
"""Test updating a provider's configuration."""
|
||||||
|
mock_provider_instance = AsyncMock()
|
||||||
|
mock_provider_instance.health = AsyncMock(return_value={"status": HealthStatus.OK})
|
||||||
|
|
||||||
|
with patch.object(provider_impl, "_instantiate_provider", return_value=mock_provider_instance):
|
||||||
|
# Register provider
|
||||||
|
await provider_impl.register_provider(
|
||||||
|
provider_id="test-update",
|
||||||
|
api=Api.inference.value,
|
||||||
|
provider_type="remote::openai",
|
||||||
|
config={"api_key": "old-key", "timeout": 30},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update configuration
|
||||||
|
response = await provider_impl.update_provider(
|
||||||
|
provider_id="test-update",
|
||||||
|
config={"api_key": "new-key", "timeout": 60},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify updated config
|
||||||
|
assert response.provider.provider_id == "test-update"
|
||||||
|
assert response.provider.config["api_key"] == "new-key"
|
||||||
|
assert response.provider.config["timeout"] == 60
|
||||||
|
assert response.provider.status == ProviderConnectionStatus.connected
|
||||||
|
|
||||||
|
async def test_update_provider_attributes(self, provider_impl):
|
||||||
|
"""Test updating a provider's attributes."""
|
||||||
|
mock_provider_instance = AsyncMock()
|
||||||
|
mock_provider_instance.health = AsyncMock(return_value={"status": HealthStatus.OK})
|
||||||
|
|
||||||
|
with patch.object(provider_impl, "_instantiate_provider", return_value=mock_provider_instance):
|
||||||
|
# Register provider with initial attributes
|
||||||
|
await provider_impl.register_provider(
|
||||||
|
provider_id="test-attributes",
|
||||||
|
api=Api.inference.value,
|
||||||
|
provider_type="remote::openai",
|
||||||
|
config={"api_key": "test-key"},
|
||||||
|
attributes={"team": ["team-a"]},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update attributes
|
||||||
|
response = await provider_impl.update_provider(
|
||||||
|
provider_id="test-attributes",
|
||||||
|
attributes={"team": ["team-a", "team-b"], "environment": ["prod"]},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify updated attributes
|
||||||
|
assert response.provider.attributes == {"team": ["team-a", "team-b"], "environment": ["prod"]}
|
||||||
|
|
||||||
|
async def test_update_nonexistent_provider_fails(self, provider_impl):
|
||||||
|
"""Test that updating a non-existent provider fails."""
|
||||||
|
with pytest.raises(ValueError, match="not found"):
|
||||||
|
await provider_impl.update_provider(
|
||||||
|
provider_id="nonexistent",
|
||||||
|
config={"api_key": "new-key"},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def test_unregister_provider(self, provider_impl):
|
||||||
|
"""Test unregistering a provider."""
|
||||||
|
mock_provider_instance = AsyncMock()
|
||||||
|
mock_provider_instance.health = AsyncMock(return_value={"status": HealthStatus.OK})
|
||||||
|
mock_provider_instance.shutdown = AsyncMock()
|
||||||
|
|
||||||
|
with patch.object(provider_impl, "_instantiate_provider", return_value=mock_provider_instance):
|
||||||
|
# Register provider
|
||||||
|
await provider_impl.register_provider(
|
||||||
|
provider_id="test-unregister",
|
||||||
|
api=Api.inference.value,
|
||||||
|
provider_type="remote::openai",
|
||||||
|
config={"api_key": "test-key"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify it exists
|
||||||
|
assert "test-unregister" in provider_impl.dynamic_providers
|
||||||
|
|
||||||
|
# Unregister provider
|
||||||
|
await provider_impl.unregister_provider(provider_id="test-unregister")
|
||||||
|
|
||||||
|
# Verify it's removed
|
||||||
|
assert "test-unregister" not in provider_impl.dynamic_providers
|
||||||
|
assert "test-unregister" not in provider_impl.dynamic_provider_impls
|
||||||
|
|
||||||
|
# Verify shutdown was called
|
||||||
|
mock_provider_instance.shutdown.assert_called_once()
|
||||||
|
|
||||||
|
async def test_unregister_nonexistent_provider_fails(self, provider_impl):
|
||||||
|
"""Test that unregistering a non-existent provider fails."""
|
||||||
|
with pytest.raises(ValueError, match="not found"):
|
||||||
|
await provider_impl.unregister_provider(provider_id="nonexistent")
|
||||||
|
|
||||||
|
async def test_test_provider_connection_healthy(self, provider_impl):
|
||||||
|
"""Test testing a healthy provider connection."""
|
||||||
|
mock_provider_instance = AsyncMock()
|
||||||
|
mock_provider_instance.health = AsyncMock(return_value={"status": HealthStatus.OK, "message": "All good"})
|
||||||
|
|
||||||
|
with patch.object(provider_impl, "_instantiate_provider", return_value=mock_provider_instance):
|
||||||
|
# Register provider
|
||||||
|
await provider_impl.register_provider(
|
||||||
|
provider_id="test-health",
|
||||||
|
api=Api.inference.value,
|
||||||
|
provider_type="remote::openai",
|
||||||
|
config={"api_key": "test-key"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test connection
|
||||||
|
response = await provider_impl.test_provider_connection(provider_id="test-health")
|
||||||
|
|
||||||
|
# Verify response
|
||||||
|
assert response.success is True
|
||||||
|
assert response.health["status"] == HealthStatus.OK
|
||||||
|
assert response.health["message"] == "All good"
|
||||||
|
assert response.error_message is None
|
||||||
|
|
||||||
|
async def test_test_provider_connection_unhealthy(self, provider_impl):
|
||||||
|
"""Test testing an unhealthy provider connection."""
|
||||||
|
mock_provider_instance = AsyncMock()
|
||||||
|
mock_provider_instance.health = AsyncMock(
|
||||||
|
return_value={"status": HealthStatus.ERROR, "message": "Connection failed"}
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(provider_impl, "_instantiate_provider", return_value=mock_provider_instance):
|
||||||
|
# Register provider
|
||||||
|
await provider_impl.register_provider(
|
||||||
|
provider_id="test-unhealthy",
|
||||||
|
api=Api.inference.value,
|
||||||
|
provider_type="remote::openai",
|
||||||
|
config={"api_key": "invalid-key"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test connection
|
||||||
|
response = await provider_impl.test_provider_connection(provider_id="test-unhealthy")
|
||||||
|
|
||||||
|
# Verify response shows unhealthy status
|
||||||
|
assert response.success is False
|
||||||
|
assert response.health["status"] == HealthStatus.ERROR
|
||||||
|
|
||||||
|
async def test_list_providers_includes_dynamic(self, provider_impl):
|
||||||
|
"""Test that list_providers includes dynamically registered providers."""
|
||||||
|
mock_provider_instance = AsyncMock()
|
||||||
|
mock_provider_instance.health = AsyncMock(return_value={"status": HealthStatus.OK})
|
||||||
|
|
||||||
|
with patch.object(provider_impl, "_instantiate_provider", return_value=mock_provider_instance):
|
||||||
|
# Register multiple providers
|
||||||
|
await provider_impl.register_provider(
|
||||||
|
provider_id="dynamic-1",
|
||||||
|
api=Api.inference.value,
|
||||||
|
provider_type="remote::openai",
|
||||||
|
config={"api_key": "key1"},
|
||||||
|
)
|
||||||
|
|
||||||
|
await provider_impl.register_provider(
|
||||||
|
provider_id="dynamic-2",
|
||||||
|
api=Api.vector_io.value,
|
||||||
|
provider_type="inline::faiss",
|
||||||
|
config={"dimension": 768},
|
||||||
|
)
|
||||||
|
|
||||||
|
# List all providers
|
||||||
|
response = await provider_impl.list_providers()
|
||||||
|
|
||||||
|
# Verify both dynamic providers are in the list
|
||||||
|
provider_ids = [p.provider_id for p in response.data]
|
||||||
|
assert "dynamic-1" in provider_ids
|
||||||
|
assert "dynamic-2" in provider_ids
|
||||||
|
|
||||||
|
async def test_inspect_provider(self, provider_impl):
|
||||||
|
"""Test inspecting a specific provider."""
|
||||||
|
mock_provider_instance = AsyncMock()
|
||||||
|
mock_provider_instance.health = AsyncMock(return_value={"status": HealthStatus.OK})
|
||||||
|
|
||||||
|
with patch.object(provider_impl, "_instantiate_provider", return_value=mock_provider_instance):
|
||||||
|
# Register provider
|
||||||
|
await provider_impl.register_provider(
|
||||||
|
provider_id="test-inspect",
|
||||||
|
api=Api.inference.value,
|
||||||
|
provider_type="remote::openai",
|
||||||
|
config={"api_key": "test-key", "model": "gpt-4"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update the stored health info to reflect OK status
|
||||||
|
# (In reality, the health check happens during registration,
|
||||||
|
# but our mock may not have been properly called)
|
||||||
|
conn_info = provider_impl.dynamic_providers["test-inspect"]
|
||||||
|
from llama_stack.apis.providers.connection import ProviderHealth
|
||||||
|
|
||||||
|
conn_info.health = ProviderHealth.from_health_response({"status": HealthStatus.OK})
|
||||||
|
|
||||||
|
# Inspect provider
|
||||||
|
provider_info = await provider_impl.inspect_provider(provider_id="test-inspect")
|
||||||
|
|
||||||
|
# Verify provider info
|
||||||
|
assert provider_info.provider_id == "test-inspect"
|
||||||
|
assert provider_info.api == Api.inference.value
|
||||||
|
assert provider_info.provider_type == "remote::openai"
|
||||||
|
assert provider_info.config["model"] == "gpt-4"
|
||||||
|
assert provider_info.health["status"] == HealthStatus.OK
|
||||||
|
|
||||||
|
async def test_provider_persistence(self, provider_impl, kvstore, tmp_path):
|
||||||
|
"""Test that providers persist across restarts."""
|
||||||
|
mock_provider_instance = AsyncMock()
|
||||||
|
mock_provider_instance.health = AsyncMock(return_value={"status": HealthStatus.OK})
|
||||||
|
|
||||||
|
with patch.object(provider_impl, "_instantiate_provider", return_value=mock_provider_instance):
|
||||||
|
# Register provider
|
||||||
|
await provider_impl.register_provider(
|
||||||
|
provider_id="test-persist",
|
||||||
|
api=Api.inference.value,
|
||||||
|
provider_type="remote::openai",
|
||||||
|
config={"api_key": "persist-key"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create a new provider impl (simulating restart) - reuse the same kvstore
|
||||||
|
db_path = (tmp_path / "test_providers.db").as_posix()
|
||||||
|
|
||||||
|
storage_config = StorageConfig(
|
||||||
|
backends={
|
||||||
|
"default": SqliteKVStoreConfig(db_path=db_path),
|
||||||
|
},
|
||||||
|
stores=ServerStoresConfig(
|
||||||
|
metadata=KVStoreReference(backend="default", namespace="test_metadata"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
run_config = StackRunConfig(
|
||||||
|
image_name="test",
|
||||||
|
apis=[],
|
||||||
|
providers={},
|
||||||
|
storage=storage_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
config = ProviderImplConfig(
|
||||||
|
run_config=run_config,
|
||||||
|
provider_registry=MagicMock(),
|
||||||
|
dist_registry=None,
|
||||||
|
policy=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
new_impl = ProviderImpl(config, deps={})
|
||||||
|
|
||||||
|
# Manually set the kvstore (reusing the same one)
|
||||||
|
new_impl.kvstore = kvstore
|
||||||
|
new_impl.provider_registry = MagicMock()
|
||||||
|
new_impl.dist_registry = None
|
||||||
|
new_impl.policy = []
|
||||||
|
|
||||||
|
# Load providers from kvstore
|
||||||
|
with patch.object(new_impl, "_instantiate_provider", return_value=mock_provider_instance):
|
||||||
|
await new_impl._load_dynamic_providers()
|
||||||
|
|
||||||
|
# Verify the provider was loaded from kvstore
|
||||||
|
assert "test-persist" in new_impl.dynamic_providers
|
||||||
|
assert new_impl.dynamic_providers["test-persist"].config["api_key"] == "persist-key"
|
||||||
Loading…
Add table
Add a link
Reference in a new issue