mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-12 20:12:33 +00:00
test: add unit tests to verify graceful stack startup
Added tests to verify that the stack:
1. Continues initialization when providers fail to register models
2. Skips invalid models instead of crashing
3. Handles provider listing failures gracefully
4. Maintains partial functionality with mixed success/failure

Example:
- OpenAI provider fails to list models
- Stack logs the error and continues with registered models
- Other providers remain functional

This prevents the entire stack from crashing when:
- Provider API keys are invalid
- Models are misconfigured
- The provider API is temporarily unavailable
This commit is contained in:
parent
a271e3abae
commit
a1010e52e9
2 changed files with 304 additions and 207 deletions
304
tests/unit/core/test_stack.py
Normal file
304
tests/unit/core/test_stack.py
Normal file
|
|
@ -0,0 +1,304 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.models import Model, ModelInput, ModelType
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.stack import register_resources
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def setup_caplog(caplog):
    """Configure caplog to capture all log levels.

    Autouse: applied to every test in this module so assertions on
    ``caplog.text`` can see DEBUG-and-above records from the stack.
    """
    caplog.set_level("DEBUG")
|
||||
|
||||
|
||||
def assert_graceful_failure(message: str) -> Callable[[Exception], None]:
    """Create a failure handler with consistent error message format.

    Args:
        message: Error message with optional example.

    Returns:
        A callable that fails the current test with ``message`` plus the
        caught exception. The previous version called ``pytest.fail(message)``
        and discarded the exception, hiding the root cause from the report.
    """

    def _fail(exc: Exception) -> None:
        # Append the actual exception (repr) so the pytest output shows
        # why initialization was interrupted, not just that it was.
        pytest.fail(f"{message}: {exc!r}")

    return _fail
|
||||
|
||||
|
||||
def create_mock_impl(
|
||||
register_result: Any = None,
|
||||
register_error: Exception | None = None,
|
||||
list_result: Any = None,
|
||||
list_error: Exception | None = None,
|
||||
) -> AsyncMock:
|
||||
"""Create a mock implementation with configurable behavior.
|
||||
|
||||
Args:
|
||||
register_result: Return value for register_model
|
||||
register_error: Error to raise from register_model
|
||||
list_result: Return value for list_models
|
||||
list_error: Error to raise from list_models
|
||||
"""
|
||||
mock_impl = AsyncMock()
|
||||
if register_error:
|
||||
mock_impl.register_model = AsyncMock(side_effect=register_error)
|
||||
elif isinstance(register_result, list):
|
||||
mock_impl.register_model = AsyncMock(side_effect=register_result)
|
||||
else:
|
||||
mock_impl.register_model = AsyncMock(return_value=register_result)
|
||||
|
||||
if list_error:
|
||||
mock_impl.list_models = AsyncMock(side_effect=list_error)
|
||||
else:
|
||||
mock_impl.list_models = AsyncMock(return_value=list_result or [])
|
||||
|
||||
return mock_impl
|
||||
|
||||
|
||||
@pytest.fixture
def mock_model():
    """Create a valid model registration input for testing.

    Uses ModelInput (not Model): StackRunConfig.models expects the input
    type. ModelInput is already imported at module scope, so the redundant
    function-local import was removed.
    """
    return ModelInput(
        provider_id="test_provider",
        provider_model_id="test_model",
        identifier="test_model",
        model_id="test_model",  # Required field
        model_type=ModelType.llm,
    )
|
||||
|
||||
|
||||
@pytest.fixture
def mock_provider():
    """A minimal valid provider for the tests in this module."""
    # provider_type must use the ``namespace::name`` format.
    return Provider(
        provider_id="test_provider",
        provider_type="remote::test",
        config={},
    )
|
||||
|
||||
|
||||
@pytest.fixture
def mock_run_config(mock_model, mock_provider):
    """Wire the mock provider and mock model into a minimal run config.

    Declaring mock_provider as a parameter ensures pytest injects the
    fixture *value*, not the fixture function object.
    """
    return StackRunConfig(
        image_name="test",
        apis=["inference"],
        providers={"inference": [mock_provider]},
        models=[mock_model],
    )
|
||||
|
||||
|
||||
async def test_register_resources_success(mock_run_config, mock_model):
    """Test successful registration of resources."""
    mock_impl = create_mock_impl(
        register_result=mock_model,
        list_result=[mock_model],
    )
    impls = {Api.models: mock_impl}

    try:
        await register_resources(mock_run_config, impls)
    except AttributeError as exc:
        # Pass the caught exception through so the failure report shows the
        # root cause (the old code passed None, discarding the exception).
        assert_graceful_failure(
            "Stack interrupted initialization: tried to access model fields in wrong format"
        )(exc)

    # The single configured model is registered, then models are listed once.
    mock_impl.register_model.assert_called_once()
    mock_impl.list_models.assert_called_once()
|
||||
|
||||
|
||||
async def test_register_resources_failed_registration(caplog):
    """Test that stack continues when model registration fails.

    Before fix:
    - A provider failing to validate a model would crash the entire stack
      (e.g. an OpenAI provider with an invalid API key prevented startup).

    After fix:
    - Provider validation errors are caught and logged.
    - The stack continues initializing; the failed provider is skipped.

    Test strategy:
    1. Build a model whose registration the provider will reject.
    2. Use a valid provider to show the failure is not the provider's config.
    3. Verify the stack handles the validation error and continues.
    """
    # A syntactically valid model whose registration the mock will reject.
    invalid_model = ModelInput(
        provider_id="test_provider",
        provider_model_id="test_model",
        identifier="test_model",
        model_id="test_model",  # Required field
        model_type=ModelType.llm,
    )

    valid_provider = Provider(
        provider_id="test_provider",
        provider_type="remote::test",
        config={},
    )

    mock_run_config = StackRunConfig(
        image_name="test",
        apis=["inference"],
        providers={"inference": [valid_provider]},
        models=[invalid_model],
    )

    mock_impl = create_mock_impl(
        register_error=ValueError(
            "Provider failed to validate model: expected ModelInput but got Model\n"
            "This would previously crash the stack, but should now be handled gracefully"
        ),
    )
    impls = {Api.models: mock_impl}

    # Before fix the stack would crash here; after fix the error is handled.
    try:
        await register_resources(mock_run_config, impls)
    except Exception as e:
        assert_graceful_failure("Stack interrupted initialization: provider received model in wrong format")(e)

    # Registration should be attempted even though the provider rejects it.
    # (The previous `(call(), "message")` tuples were no-op expressions —
    # the trailing strings were dead, so they are dropped here.)
    mock_impl.register_model.assert_called_once_with(
        **{k: getattr(invalid_model, k) for k in invalid_model.model_dump()}
    )

    # The stack keeps operating after the validation failure.
    mock_impl.list_models.assert_called_once()

    # The failure must be surfaced in the logs.
    assert "Failed to register models" in caplog.text, "Error should be logged when model registration fails"
    assert "Provider failed to validate model" in caplog.text, "Specific error message should be logged"
|
||||
|
||||
|
||||
async def test_register_resources_failed_listing(mock_run_config, mock_model):
    """Test that stack continues when model listing fails."""
    mock_impl = create_mock_impl(
        register_result=mock_model,
        list_error=ValueError("Listing failed"),
    )
    impls = {Api.models: mock_impl}

    # register_resources must swallow the listing error, not propagate it.
    try:
        await register_resources(mock_run_config, impls)
    except Exception as e:
        assert_graceful_failure("Stack interrupted initialization: provider failed to list available models")(e)

    # Registration completed successfully before the listing failure.
    # (The previous `(call(), "message")` tuples were no-op expressions;
    # mock assertion methods take no message argument.)
    mock_impl.register_model.assert_called_once_with(
        **{k: getattr(mock_model, k) for k in mock_model.model_dump()}
    )

    # Listing was attempted and its failure handled gracefully.
    mock_impl.list_models.assert_called_once()
|
||||
|
||||
|
||||
async def test_register_resources_mixed_success(mock_run_config):
    """Test mixed success/failure scenario with multiple models."""
    # Two models on the same provider; only the first registers cleanly.
    model1 = Model(
        provider_id="test_provider",
        provider_model_id="model1",
        identifier="model1",
        model_type=ModelType.llm,
    )
    model2 = Model(
        provider_id="test_provider",
        provider_model_id="model2",
        identifier="model2",
        model_type=ModelType.llm,
    )

    # Update run config to include both models
    mock_run_config.models = [model1, model2]

    mock_impl = create_mock_impl(
        register_result=[model1, ValueError("Second registration failed")],
        list_result=[model1],  # Only first model listed
    )
    impls = {Api.models: mock_impl}

    # Should not raise despite the second registration failing.
    try:
        await register_resources(mock_run_config, impls)
    except Exception as e:
        assert_graceful_failure(
            "Stack interrupted initialization: some models registered successfully but others failed"
        )(e)

    # Both registrations attempted regardless of the mid-sequence failure.
    assert mock_impl.register_model.call_count == 2, (
        "register_model() should be called twice, once for each model, regardless of failures"
    )

    # First model registered with the expected parameters; the second was
    # still attempted after the first succeeded. (The old `(call(), "msg")`
    # tuples were no-op expressions and have been removed.)
    mock_impl.register_model.assert_any_call(**{k: getattr(model1, k) for k in model1.model_dump()})
    mock_impl.register_model.assert_any_call(**{k: getattr(model2, k) for k in model2.model_dump()})

    # Listing is still performed after all registration attempts.
    mock_impl.list_models.assert_called_once()

    # NOTE: the previous asserts on mock_impl.list_models.return_value only
    # re-read the value this test configured above — tautological, removed.
|
||||
|
||||
|
||||
async def test_register_resources_disabled_provider(mock_run_config, mock_model):
    """Test that disabled providers are skipped."""
    # Mark the model's provider as disabled.
    mock_model.provider_id = "__disabled__"
    mock_impl = create_mock_impl()
    impls = {Api.models: mock_impl}

    try:
        await register_resources(mock_run_config, impls)
    except Exception as e:
        assert_graceful_failure("Stack interrupted initialization: provider is marked as disabled")(e)

    # No registration for a disabled provider. (The previous trailing
    # `, "message"` strings formed dead tuple expressions, not assert
    # messages — mock assertion methods take no message argument.)
    mock_impl.register_model.assert_not_called()

    # Listing still happens even when the provider is disabled.
    mock_impl.list_models.assert_called_once()
|
||||
|
|
@ -1,207 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.stack import Stack, register_resources
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
|
||||
@pytest.fixture
def mock_model():
    """A minimal registered Model shared by the tests below."""
    return Model(
        provider_id="test_provider",
        provider_model_id="test_model",
        identifier="test_model",
        model_type=ModelType.llm,
    )
|
||||
|
||||
|
||||
@pytest.fixture
def mock_provider():
    """A provider stub pointing at the mock inference backend."""
    return Provider(
        provider_id="test_provider",
        provider_type="test_type",
        config={},
    )
|
||||
|
||||
|
||||
@pytest.fixture
def mock_run_config(mock_model, mock_provider):
    """Run config wiring the mock provider and model together.

    Fix: mock_provider must be declared as a fixture parameter. The old
    signature omitted it, so the bare name resolved to the fixture
    *function object* rather than the Provider instance, putting a
    function into the providers mapping.
    """
    return StackRunConfig(
        image_name="test",
        apis=["inference"],
        providers={"inference": [mock_provider]},
        models=[mock_model],
    )
|
||||
|
||||
|
||||
async def test_register_resources_success(mock_run_config, mock_model):
    """Test successful registration of resources."""
    impl = AsyncMock()
    impl.register_model = AsyncMock(return_value=mock_model)
    impl.list_models = AsyncMock(return_value=[mock_model])

    await register_resources(mock_run_config, {Api.models: impl})

    # One registration for the single configured model, one listing pass.
    impl.register_model.assert_called_once()
    impl.list_models.assert_called_once()
|
||||
|
||||
|
||||
async def test_register_resources_failed_registration(mock_run_config, mock_model):
    """Test that stack continues when model registration fails."""
    impl = AsyncMock()
    impl.register_model = AsyncMock(side_effect=ValueError("Registration failed"))
    impl.list_models = AsyncMock(return_value=[])

    # The registration error must be swallowed, not propagated.
    await register_resources(mock_run_config, {Api.models: impl})

    impl.register_model.assert_called_once()
    impl.list_models.assert_called_once()
|
||||
|
||||
|
||||
async def test_register_resources_failed_listing(mock_run_config, mock_model):
    """Test that stack continues when model listing fails."""
    impl = AsyncMock()
    impl.register_model = AsyncMock(return_value=mock_model)
    impl.list_models = AsyncMock(side_effect=ValueError("Listing failed"))

    # The listing error must be swallowed, not propagated.
    await register_resources(mock_run_config, {Api.models: impl})

    impl.register_model.assert_called_once()
    impl.list_models.assert_called_once()
|
||||
|
||||
|
||||
async def test_register_resources_mixed_success(mock_run_config):
    """Test mixed success/failure scenario with multiple models."""
    # Two models on the same provider; only the first will register cleanly.
    model1, model2 = (
        Model(
            provider_id="test_provider",
            provider_model_id=name,
            identifier=name,
            model_type=ModelType.llm,
        )
        for name in ("model1", "model2")
    )

    # Run config now carries both models.
    mock_run_config.models = [model1, model2]

    impl = AsyncMock()
    # First registration succeeds, second raises.
    impl.register_model = AsyncMock(side_effect=[model1, ValueError("Second registration failed")])
    impl.list_models = AsyncMock(return_value=[model1])  # Only first model listed

    # Should not raise despite the second model failing.
    await register_resources(mock_run_config, {Api.models: impl})

    assert impl.register_model.call_count == 2
    impl.list_models.assert_called_once()
|
||||
|
||||
|
||||
async def test_register_resources_disabled_provider(mock_run_config, mock_model):
    """Test that disabled providers are skipped."""
    # Mark the model's provider as disabled.
    mock_model.provider_id = "__disabled__"
    impl = AsyncMock()

    await register_resources(mock_run_config, {Api.models: impl})

    # Disabled provider: no registration attempt, but listing still runs.
    impl.register_model.assert_not_called()
    impl.list_models.assert_called_once()
|
||||
|
||||
|
||||
class MockFailingProvider:
    """A mock provider that fails every registration but initializes cleanly."""

    def __init__(self, *args, **kwargs):
        # Lifecycle flags so tests can verify init/shutdown were invoked.
        self.shutdown_called = False
        self.initialize_called = False

    async def initialize(self):
        self.initialize_called = True

    async def register_model(self, *args, **kwargs):
        # Every registration attempt is rejected.
        raise ValueError("Mock registration failure")

    async def list_models(self):
        # Nothing ever registers successfully, so nothing is listed.
        return []

    async def shutdown(self):
        self.shutdown_called = True
|
||||
|
||||
|
||||
async def test_stack_initialization_with_failed_registration():
    """Test full stack initialization with failed model registration using a mock provider.

    End-to-end variant of the register_resources tests: builds a real Stack
    with a provider registry whose provider always raises from
    register_model, then checks the stack still initializes and shuts down.
    """
    # Model points at the failing provider configured below.
    mock_model = Model(
        provider_id="mock_failing",
        provider_model_id="test_model",
        identifier="test_model",
        model_type=ModelType.llm,
    )

    mock_run_config = StackRunConfig(
        image_name="test",
        apis=["inference"],
        providers={
            "inference": [
                Provider(
                    provider_id="mock_failing",
                    provider_type="mock::failing",
                    config={},  # No need for real config in mock
                )
            ]
        },
        models=[mock_model],
    )

    # Create a mock provider registry that returns our failing provider.
    # NOTE(review): assumes Stack resolves provider_class/config_class from
    # registry entries keyed by provider_type — confirm against Stack impl.
    mock_registry = {
        Api.inference: {
            "mock::failing": MagicMock(
                provider_class=MockFailingProvider,
                config_class=MagicMock(),
            )
        }
    }

    stack = Stack(mock_run_config, provider_registry=mock_registry)

    # Should not raise exception during initialization
    await stack.initialize()

    # Stack should still be initialized
    assert stack.impls is not None

    # Verify the provider was properly initialized
    inference_impl = stack.impls.get(Api.inference)
    assert inference_impl is not None
    assert inference_impl.initialize_called

    # Clean up
    await stack.shutdown()
    assert inference_impl.shutdown_called
|
||||
Loading…
Add table
Add a link
Reference in a new issue