test: add unit tests to verify graceful stack startup

Added tests to verify that the stack:
1. Continues initialization when providers fail to register models
2. Skips invalid models instead of crashing
3. Handles provider listing failures gracefully
4. Maintains partial functionality with mixed success/failure

Example:
- OpenAI provider fails to list models
- Stack logs error and continues with registered models
- Other providers remain functional

This prevents the entire stack from crashing when:
- Provider API keys are invalid
- Models are misconfigured
- Provider API is temporarily unavailable
This commit is contained in:
Akram Ben Aissi 2025-10-10 22:59:18 +02:00
parent a271e3abae
commit a1010e52e9
2 changed files with 304 additions and 207 deletions

View file

@ -0,0 +1,304 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import Callable
from typing import Any
from unittest.mock import AsyncMock
import pytest
from llama_stack.apis.models import Model, ModelInput, ModelType
from llama_stack.core.datatypes import Provider, StackRunConfig
from llama_stack.core.stack import register_resources
from llama_stack.providers.datatypes import Api
@pytest.fixture(autouse=True)
def setup_caplog(caplog):
"""Configure caplog to capture all log levels"""
caplog.set_level("DEBUG")
def assert_graceful_failure(message: str) -> Callable[[Exception], None]:
"""Create a failure handler with consistent error message format.
Args:
message: Error message with optional example
"""
def _fail(_: Exception) -> None:
pytest.fail(message)
return _fail
def create_mock_impl(
register_result: Any = None,
register_error: Exception | None = None,
list_result: Any = None,
list_error: Exception | None = None,
) -> AsyncMock:
"""Create a mock implementation with configurable behavior.
Args:
register_result: Return value for register_model
register_error: Error to raise from register_model
list_result: Return value for list_models
list_error: Error to raise from list_models
"""
mock_impl = AsyncMock()
if register_error:
mock_impl.register_model = AsyncMock(side_effect=register_error)
elif isinstance(register_result, list):
mock_impl.register_model = AsyncMock(side_effect=register_result)
else:
mock_impl.register_model = AsyncMock(return_value=register_result)
if list_error:
mock_impl.list_models = AsyncMock(side_effect=list_error)
else:
mock_impl.list_models = AsyncMock(return_value=list_result or [])
return mock_impl
@pytest.fixture
def mock_model():
"""Create a valid model for testing."""
from llama_stack.apis.models import ModelInput # Import the correct type
return ModelInput( # Use ModelInput instead of Model
provider_id="test_provider",
provider_model_id="test_model",
identifier="test_model",
model_id="test_model", # Required field
model_type=ModelType.llm,
)
@pytest.fixture
def mock_provider():
"""Create a valid provider for testing."""
return Provider(
provider_id="test_provider",
provider_type="remote::test", # Valid format: namespace::name
config={},
)
@pytest.fixture
def mock_run_config(mock_model, mock_provider): # Add mock_provider as dependency
return StackRunConfig(
image_name="test",
apis=["inference"],
providers={"inference": [mock_provider]}, # Use the Provider object directly
models=[mock_model],
)
async def test_register_resources_success(mock_run_config, mock_model):
"""Test successful registration of resources."""
mock_impl = create_mock_impl(
register_result=mock_model,
list_result=[mock_model],
)
impls = {Api.models: mock_impl}
try:
await register_resources(mock_run_config, impls)
except AttributeError:
assert_graceful_failure("Stack interrupted initialization: tried to access model fields in wrong format")(None)
mock_impl.register_model.assert_called_once()
mock_impl.list_models.assert_called_once()
async def test_register_resources_failed_registration(caplog):
"""Test that stack continues when model registration fails.
This test demonstrates how the stack handles validation errors:
Before fix:
- A provider failing to validate a model would crash the entire stack
- Example: OpenAI provider with invalid API key would prevent startup
- No way to start stack with other working providers
After fix:
- Provider validation errors are caught and logged
- Stack continues initializing with other providers
- Failed provider is skipped but doesn't block others
Test strategy:
1. Create an invalid model (wrong type) to trigger validation
2. Create a valid provider to show it's not provider's fault
3. Verify stack handles validation error and continues
"""
# Create a valid model that will fail registration
invalid_model = ModelInput(
provider_id="test_provider",
provider_model_id="test_model",
identifier="test_model",
model_id="test_model", # Required field
model_type=ModelType.llm,
)
# Create a valid provider
valid_provider = Provider(
provider_id="test_provider",
provider_type="remote::test",
config={},
)
# Create config with the model
mock_run_config = StackRunConfig(
image_name="test",
apis=["inference"],
providers={"inference": [valid_provider]},
models=[invalid_model],
)
mock_impl = create_mock_impl(
register_error=ValueError(
"Provider failed to validate model: expected ModelInput but got Model\n"
"This would previously crash the stack, but should now be handled gracefully"
),
)
impls = {Api.models: mock_impl}
# Before fix: Stack would crash here
# After fix: Should handle error and continue
try:
await register_resources(mock_run_config, impls)
except Exception as e:
assert_graceful_failure("Stack interrupted initialization: provider received model in wrong format")(e)
# Verify registration was attempted despite validation issues
(
mock_impl.register_model.assert_called_once_with(
**{k: getattr(invalid_model, k) for k in invalid_model.model_dump().keys()}
),
"Provider should attempt model registration even with invalid model type",
)
# Verify stack continues operating after validation failure
(
mock_impl.list_models.assert_called_once(),
"Stack should continue normal operation after handling model validation failure",
)
# Verify error was logged
assert "Failed to register models" in caplog.text, "Error should be logged when model registration fails"
assert "Provider failed to validate model" in caplog.text, "Specific error message should be logged"
async def test_register_resources_failed_listing(mock_run_config, mock_model):
"""Test that stack continues when model listing fails."""
mock_impl = create_mock_impl(
register_result=mock_model,
list_error=ValueError("Listing failed"),
)
impls = {Api.models: mock_impl}
# Should not raise exception
try:
await register_resources(mock_run_config, impls)
except Exception as e:
assert_graceful_failure("Stack interrupted initialization: provider failed to list available models")(e)
# Verify registration completed successfully
(
mock_impl.register_model.assert_called_once_with(
**{k: getattr(mock_model, k) for k in mock_model.model_dump().keys()}
),
"register_model() should complete successfully before the listing failure",
)
# Verify listing was attempted
(
mock_impl.list_models.assert_called_once(),
"list_models() should be called and its failure should be handled gracefully",
)
async def test_register_resources_mixed_success(mock_run_config):
"""Test mixed success/failure scenario with multiple models."""
# Create two models
model1 = Model(
provider_id="test_provider",
provider_model_id="model1",
identifier="model1",
model_type=ModelType.llm,
)
model2 = Model(
provider_id="test_provider",
provider_model_id="model2",
identifier="model2",
model_type=ModelType.llm,
)
# Update run config to include both models
mock_run_config.models = [model1, model2]
mock_impl = create_mock_impl(
register_result=[model1, ValueError("Second registration failed")],
list_result=[model1], # Only first model listed
)
impls = {Api.models: mock_impl}
# Should not raise exception
try:
await register_resources(mock_run_config, impls)
except Exception as e:
assert_graceful_failure(
"Stack interrupted initialization: some models registered successfully but others failed"
)(e)
# Verify both registration attempts were made
assert mock_impl.register_model.call_count == 2, (
"register_model() should be called twice, once for each model, regardless of failures"
)
# Verify the first call succeeded with model1
(
mock_impl.register_model.assert_any_call(**{k: getattr(model1, k) for k in model1.model_dump().keys()}),
"First model registration should be attempted with correct parameters",
)
# Verify the second call was attempted with model2
(
mock_impl.register_model.assert_any_call(**{k: getattr(model2, k) for k in model2.model_dump().keys()}),
"Second model registration should be attempted even after first success",
)
# Verify listing was still performed
mock_impl.list_models.assert_called_once(), "list_models() should be called once after all registration attempts"
# Verify listing returned only the successful model
assert len(mock_impl.list_models.return_value) == 1, (
"list_models() should return only the successfully registered model"
)
assert mock_impl.list_models.return_value == [model1], (
"list_models() should return the first model that registered successfully"
)
async def test_register_resources_disabled_provider(mock_run_config, mock_model):
"""Test that disabled providers are skipped."""
# Update model to be disabled
mock_model.provider_id = "__disabled__"
mock_impl = create_mock_impl()
impls = {Api.models: mock_impl}
try:
await register_resources(mock_run_config, impls)
except Exception as e:
assert_graceful_failure("Stack interrupted initialization: provider is marked as disabled")(e)
# Should not attempt registration for disabled provider
mock_impl.register_model.assert_not_called(), "register_model() should not be called for disabled providers"
# Should still perform listing
mock_impl.list_models.assert_called_once(), "list_models() should still be called even for disabled providers"

View file

@ -1,207 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import AsyncMock, MagicMock
import pytest
from llama_stack.apis.models import Model, ModelType
from llama_stack.core.datatypes import Provider, StackRunConfig
from llama_stack.core.stack import Stack, register_resources
from llama_stack.providers.datatypes import Api
@pytest.fixture
def mock_model():
return Model(
provider_id="test_provider",
provider_model_id="test_model",
identifier="test_model",
model_type=ModelType.llm,
)
@pytest.fixture
def mock_provider():
return Provider(
provider_id="test_provider",
provider_type="test_type",
config={},
)
@pytest.fixture
def mock_run_config(mock_model):
return StackRunConfig(
image_name="test",
apis=["inference"],
providers={"inference": [mock_provider]},
models=[mock_model],
)
async def test_register_resources_success(mock_run_config, mock_model):
"""Test successful registration of resources."""
mock_impl = AsyncMock()
mock_impl.register_model = AsyncMock(return_value=mock_model)
mock_impl.list_models = AsyncMock(return_value=[mock_model])
impls = {Api.models: mock_impl}
await register_resources(mock_run_config, impls)
mock_impl.register_model.assert_called_once()
mock_impl.list_models.assert_called_once()
async def test_register_resources_failed_registration(mock_run_config, mock_model):
"""Test that stack continues when model registration fails."""
mock_impl = AsyncMock()
mock_impl.register_model = AsyncMock(side_effect=ValueError("Registration failed"))
mock_impl.list_models = AsyncMock(return_value=[])
impls = {Api.models: mock_impl}
# Should not raise exception
await register_resources(mock_run_config, impls)
mock_impl.register_model.assert_called_once()
mock_impl.list_models.assert_called_once()
async def test_register_resources_failed_listing(mock_run_config, mock_model):
"""Test that stack continues when model listing fails."""
mock_impl = AsyncMock()
mock_impl.register_model = AsyncMock(return_value=mock_model)
mock_impl.list_models = AsyncMock(side_effect=ValueError("Listing failed"))
impls = {Api.models: mock_impl}
# Should not raise exception
await register_resources(mock_run_config, impls)
mock_impl.register_model.assert_called_once()
mock_impl.list_models.assert_called_once()
async def test_register_resources_mixed_success(mock_run_config):
"""Test mixed success/failure scenario with multiple models."""
# Create two models
model1 = Model(
provider_id="test_provider",
provider_model_id="model1",
identifier="model1",
model_type=ModelType.llm,
)
model2 = Model(
provider_id="test_provider",
provider_model_id="model2",
identifier="model2",
model_type=ModelType.llm,
)
# Update run config to include both models
mock_run_config.models = [model1, model2]
mock_impl = AsyncMock()
# Make first registration succeed, second fail
mock_impl.register_model = AsyncMock(side_effect=[model1, ValueError("Second registration failed")])
mock_impl.list_models = AsyncMock(return_value=[model1]) # Only first model listed
impls = {Api.models: mock_impl}
# Should not raise exception
await register_resources(mock_run_config, impls)
assert mock_impl.register_model.call_count == 2
mock_impl.list_models.assert_called_once()
async def test_register_resources_disabled_provider(mock_run_config, mock_model):
"""Test that disabled providers are skipped."""
# Update model to be disabled
mock_model.provider_id = "__disabled__"
mock_impl = AsyncMock()
impls = {Api.models: mock_impl}
await register_resources(mock_run_config, impls)
# Should not attempt registration for disabled provider
mock_impl.register_model.assert_not_called()
mock_impl.list_models.assert_called_once()
class MockFailingProvider:
"""A mock provider that fails registration but allows initialization"""
def __init__(self, *args, **kwargs):
self.initialize_called = False
self.shutdown_called = False
async def initialize(self):
self.initialize_called = True
async def shutdown(self):
self.shutdown_called = True
async def register_model(self, *args, **kwargs):
raise ValueError("Mock registration failure")
async def list_models(self):
return [] # Return empty list to simulate no models registered
async def test_stack_initialization_with_failed_registration():
"""Test full stack initialization with failed model registration using a mock provider."""
mock_model = Model(
provider_id="mock_failing",
provider_model_id="test_model",
identifier="test_model",
model_type=ModelType.llm,
)
mock_run_config = StackRunConfig(
image_name="test",
apis=["inference"],
providers={
"inference": [
Provider(
provider_id="mock_failing",
provider_type="mock::failing",
config={}, # No need for real config in mock
)
]
},
models=[mock_model],
)
# Create a mock provider registry that returns our failing provider
mock_registry = {
Api.inference: {
"mock::failing": MagicMock(
provider_class=MockFailingProvider,
config_class=MagicMock(),
)
}
}
stack = Stack(mock_run_config, provider_registry=mock_registry)
# Should not raise exception during initialization
await stack.initialize()
# Stack should still be initialized
assert stack.impls is not None
# Verify the provider was properly initialized
inference_impl = stack.impls.get(Api.inference)
assert inference_impl is not None
assert inference_impl.initialize_called
# Clean up
await stack.shutdown()
assert inference_impl.shutdown_called