mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-15 08:32:44 +00:00
When a provider fails during model registration or listing, the stack should continue initializing rather than crashing. This allows the stack to start even if some providers are misconfigured. - Added error handling in register_resources() - Added unit tests to verify error handling behavior - Improved error logging with provider context - Removed @pytest.mark.asyncio decorators (pytest already configured with async-mode=auto) Fixes #3769
207 lines
6.2 KiB
Python
207 lines
6.2 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
|
|
import pytest
|
|
|
|
from llama_stack.apis.models import Model, ModelType
|
|
from llama_stack.core.datatypes import Provider, StackRunConfig
|
|
from llama_stack.core.stack import Stack, register_resources
|
|
from llama_stack.providers.datatypes import Api
|
|
|
|
|
|
@pytest.fixture
|
|
def mock_model():
|
|
return Model(
|
|
provider_id="test_provider",
|
|
provider_model_id="test_model",
|
|
identifier="test_model",
|
|
model_type=ModelType.llm,
|
|
)
|
|
|
|
|
|
@pytest.fixture
|
|
def mock_provider():
|
|
return Provider(
|
|
provider_id="test_provider",
|
|
provider_type="test_type",
|
|
config={},
|
|
)
|
|
|
|
|
|
@pytest.fixture
|
|
def mock_run_config(mock_model):
|
|
return StackRunConfig(
|
|
image_name="test",
|
|
apis=["inference"],
|
|
providers={"inference": [mock_provider]},
|
|
models=[mock_model],
|
|
)
|
|
|
|
|
|
async def test_register_resources_success(mock_run_config, mock_model):
|
|
"""Test successful registration of resources."""
|
|
mock_impl = AsyncMock()
|
|
mock_impl.register_model = AsyncMock(return_value=mock_model)
|
|
mock_impl.list_models = AsyncMock(return_value=[mock_model])
|
|
|
|
impls = {Api.models: mock_impl}
|
|
|
|
await register_resources(mock_run_config, impls)
|
|
|
|
mock_impl.register_model.assert_called_once()
|
|
mock_impl.list_models.assert_called_once()
|
|
|
|
|
|
async def test_register_resources_failed_registration(mock_run_config, mock_model):
|
|
"""Test that stack continues when model registration fails."""
|
|
mock_impl = AsyncMock()
|
|
mock_impl.register_model = AsyncMock(side_effect=ValueError("Registration failed"))
|
|
mock_impl.list_models = AsyncMock(return_value=[])
|
|
|
|
impls = {Api.models: mock_impl}
|
|
|
|
# Should not raise exception
|
|
await register_resources(mock_run_config, impls)
|
|
|
|
mock_impl.register_model.assert_called_once()
|
|
mock_impl.list_models.assert_called_once()
|
|
|
|
|
|
async def test_register_resources_failed_listing(mock_run_config, mock_model):
|
|
"""Test that stack continues when model listing fails."""
|
|
mock_impl = AsyncMock()
|
|
mock_impl.register_model = AsyncMock(return_value=mock_model)
|
|
mock_impl.list_models = AsyncMock(side_effect=ValueError("Listing failed"))
|
|
|
|
impls = {Api.models: mock_impl}
|
|
|
|
# Should not raise exception
|
|
await register_resources(mock_run_config, impls)
|
|
|
|
mock_impl.register_model.assert_called_once()
|
|
mock_impl.list_models.assert_called_once()
|
|
|
|
|
|
async def test_register_resources_mixed_success(mock_run_config):
|
|
"""Test mixed success/failure scenario with multiple models."""
|
|
# Create two models
|
|
model1 = Model(
|
|
provider_id="test_provider",
|
|
provider_model_id="model1",
|
|
identifier="model1",
|
|
model_type=ModelType.llm,
|
|
)
|
|
model2 = Model(
|
|
provider_id="test_provider",
|
|
provider_model_id="model2",
|
|
identifier="model2",
|
|
model_type=ModelType.llm,
|
|
)
|
|
|
|
# Update run config to include both models
|
|
mock_run_config.models = [model1, model2]
|
|
|
|
mock_impl = AsyncMock()
|
|
# Make first registration succeed, second fail
|
|
mock_impl.register_model = AsyncMock(side_effect=[model1, ValueError("Second registration failed")])
|
|
mock_impl.list_models = AsyncMock(return_value=[model1]) # Only first model listed
|
|
|
|
impls = {Api.models: mock_impl}
|
|
|
|
# Should not raise exception
|
|
await register_resources(mock_run_config, impls)
|
|
|
|
assert mock_impl.register_model.call_count == 2
|
|
mock_impl.list_models.assert_called_once()
|
|
|
|
|
|
async def test_register_resources_disabled_provider(mock_run_config, mock_model):
|
|
"""Test that disabled providers are skipped."""
|
|
# Update model to be disabled
|
|
mock_model.provider_id = "__disabled__"
|
|
mock_impl = AsyncMock()
|
|
|
|
impls = {Api.models: mock_impl}
|
|
|
|
await register_resources(mock_run_config, impls)
|
|
|
|
# Should not attempt registration for disabled provider
|
|
mock_impl.register_model.assert_not_called()
|
|
mock_impl.list_models.assert_called_once()
|
|
|
|
|
|
class MockFailingProvider:
|
|
"""A mock provider that fails registration but allows initialization"""
|
|
|
|
def __init__(self, *args, **kwargs):
|
|
self.initialize_called = False
|
|
self.shutdown_called = False
|
|
|
|
async def initialize(self):
|
|
self.initialize_called = True
|
|
|
|
async def shutdown(self):
|
|
self.shutdown_called = True
|
|
|
|
async def register_model(self, *args, **kwargs):
|
|
raise ValueError("Mock registration failure")
|
|
|
|
async def list_models(self):
|
|
return [] # Return empty list to simulate no models registered
|
|
|
|
|
|
async def test_stack_initialization_with_failed_registration():
|
|
"""Test full stack initialization with failed model registration using a mock provider."""
|
|
mock_model = Model(
|
|
provider_id="mock_failing",
|
|
provider_model_id="test_model",
|
|
identifier="test_model",
|
|
model_type=ModelType.llm,
|
|
)
|
|
|
|
mock_run_config = StackRunConfig(
|
|
image_name="test",
|
|
apis=["inference"],
|
|
providers={
|
|
"inference": [
|
|
Provider(
|
|
provider_id="mock_failing",
|
|
provider_type="mock::failing",
|
|
config={}, # No need for real config in mock
|
|
)
|
|
]
|
|
},
|
|
models=[mock_model],
|
|
)
|
|
|
|
# Create a mock provider registry that returns our failing provider
|
|
mock_registry = {
|
|
Api.inference: {
|
|
"mock::failing": MagicMock(
|
|
provider_class=MockFailingProvider,
|
|
config_class=MagicMock(),
|
|
)
|
|
}
|
|
}
|
|
|
|
stack = Stack(mock_run_config, provider_registry=mock_registry)
|
|
|
|
# Should not raise exception during initialization
|
|
await stack.initialize()
|
|
|
|
# Stack should still be initialized
|
|
assert stack.impls is not None
|
|
|
|
# Verify the provider was properly initialized
|
|
inference_impl = stack.impls.get(Api.inference)
|
|
assert inference_impl is not None
|
|
assert inference_impl.initialize_called
|
|
|
|
# Clean up
|
|
await stack.shutdown()
|
|
assert inference_impl.shutdown_called
|