llama-stack-mirror/tests/unit/test_stack.py
Akram Ben Aissi a271e3abae fix: handle provider registration failures gracefully
When a provider fails during model registration or listing, the stack
should continue initializing rather than crashing. This allows the
stack to start even if some providers are misconfigured.

- Added error handling in register_resources()
- Added unit tests to verify error handling behavior
- Improved error logging with provider context
- Removed @pytest.mark.asyncio decorators (pytest already configured with async-mode=auto)

Fixes #3769
2025-10-10 22:19:29 +02:00

207 lines
6.2 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import AsyncMock, MagicMock
import pytest
from llama_stack.apis.models import Model, ModelType
from llama_stack.core.datatypes import Provider, StackRunConfig
from llama_stack.core.stack import Stack, register_resources
from llama_stack.providers.datatypes import Api
@pytest.fixture
def mock_model():
return Model(
provider_id="test_provider",
provider_model_id="test_model",
identifier="test_model",
model_type=ModelType.llm,
)
@pytest.fixture
def mock_provider():
return Provider(
provider_id="test_provider",
provider_type="test_type",
config={},
)
@pytest.fixture
def mock_run_config(mock_model):
return StackRunConfig(
image_name="test",
apis=["inference"],
providers={"inference": [mock_provider]},
models=[mock_model],
)
async def test_register_resources_success(mock_run_config, mock_model):
"""Test successful registration of resources."""
mock_impl = AsyncMock()
mock_impl.register_model = AsyncMock(return_value=mock_model)
mock_impl.list_models = AsyncMock(return_value=[mock_model])
impls = {Api.models: mock_impl}
await register_resources(mock_run_config, impls)
mock_impl.register_model.assert_called_once()
mock_impl.list_models.assert_called_once()
async def test_register_resources_failed_registration(mock_run_config, mock_model):
"""Test that stack continues when model registration fails."""
mock_impl = AsyncMock()
mock_impl.register_model = AsyncMock(side_effect=ValueError("Registration failed"))
mock_impl.list_models = AsyncMock(return_value=[])
impls = {Api.models: mock_impl}
# Should not raise exception
await register_resources(mock_run_config, impls)
mock_impl.register_model.assert_called_once()
mock_impl.list_models.assert_called_once()
async def test_register_resources_failed_listing(mock_run_config, mock_model):
"""Test that stack continues when model listing fails."""
mock_impl = AsyncMock()
mock_impl.register_model = AsyncMock(return_value=mock_model)
mock_impl.list_models = AsyncMock(side_effect=ValueError("Listing failed"))
impls = {Api.models: mock_impl}
# Should not raise exception
await register_resources(mock_run_config, impls)
mock_impl.register_model.assert_called_once()
mock_impl.list_models.assert_called_once()
async def test_register_resources_mixed_success(mock_run_config):
"""Test mixed success/failure scenario with multiple models."""
# Create two models
model1 = Model(
provider_id="test_provider",
provider_model_id="model1",
identifier="model1",
model_type=ModelType.llm,
)
model2 = Model(
provider_id="test_provider",
provider_model_id="model2",
identifier="model2",
model_type=ModelType.llm,
)
# Update run config to include both models
mock_run_config.models = [model1, model2]
mock_impl = AsyncMock()
# Make first registration succeed, second fail
mock_impl.register_model = AsyncMock(side_effect=[model1, ValueError("Second registration failed")])
mock_impl.list_models = AsyncMock(return_value=[model1]) # Only first model listed
impls = {Api.models: mock_impl}
# Should not raise exception
await register_resources(mock_run_config, impls)
assert mock_impl.register_model.call_count == 2
mock_impl.list_models.assert_called_once()
async def test_register_resources_disabled_provider(mock_run_config, mock_model):
"""Test that disabled providers are skipped."""
# Update model to be disabled
mock_model.provider_id = "__disabled__"
mock_impl = AsyncMock()
impls = {Api.models: mock_impl}
await register_resources(mock_run_config, impls)
# Should not attempt registration for disabled provider
mock_impl.register_model.assert_not_called()
mock_impl.list_models.assert_called_once()
class MockFailingProvider:
"""A mock provider that fails registration but allows initialization"""
def __init__(self, *args, **kwargs):
self.initialize_called = False
self.shutdown_called = False
async def initialize(self):
self.initialize_called = True
async def shutdown(self):
self.shutdown_called = True
async def register_model(self, *args, **kwargs):
raise ValueError("Mock registration failure")
async def list_models(self):
return [] # Return empty list to simulate no models registered
async def test_stack_initialization_with_failed_registration():
"""Test full stack initialization with failed model registration using a mock provider."""
mock_model = Model(
provider_id="mock_failing",
provider_model_id="test_model",
identifier="test_model",
model_type=ModelType.llm,
)
mock_run_config = StackRunConfig(
image_name="test",
apis=["inference"],
providers={
"inference": [
Provider(
provider_id="mock_failing",
provider_type="mock::failing",
config={}, # No need for real config in mock
)
]
},
models=[mock_model],
)
# Create a mock provider registry that returns our failing provider
mock_registry = {
Api.inference: {
"mock::failing": MagicMock(
provider_class=MockFailingProvider,
config_class=MagicMock(),
)
}
}
stack = Stack(mock_run_config, provider_registry=mock_registry)
# Should not raise exception during initialization
await stack.initialize()
# Stack should still be initialized
assert stack.impls is not None
# Verify the provider was properly initialized
inference_impl = stack.impls.get(Api.inference)
assert inference_impl is not None
assert inference_impl.initialize_called
# Clean up
await stack.shutdown()
assert inference_impl.shutdown_called