mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-16 21:19:27 +00:00
fix: handle provider registration failures gracefully
When a provider fails during model registration or listing, the stack should continue initializing rather than crashing. This allows the stack to start even if some providers are misconfigured. - Added error handling in register_resources() - Added unit tests to verify error handling behavior - Improved error logging with provider context - Removed @pytest.mark.asyncio decorators (pytest already configured with async-mode=auto) Fixes #3769
This commit is contained in:
parent
e7d21e1ee3
commit
a271e3abae
2 changed files with 235 additions and 11 deletions
|
|
@ -102,6 +102,12 @@ TEST_RECORDING_CONTEXT = None
|
||||||
|
|
||||||
|
|
||||||
async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
|
async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
|
||||||
|
"""Register resources from the run config with their respective providers.
|
||||||
|
|
||||||
|
This function attempts to register each resource (models, shields, etc.) with its provider.
|
||||||
|
If a registration fails, it logs the error and continues with other resources rather than
|
||||||
|
crashing the entire stack.
|
||||||
|
"""
|
||||||
for rsrc, api, register_method, list_method in RESOURCES:
|
for rsrc, api, register_method, list_method in RESOURCES:
|
||||||
objects = getattr(run_config, rsrc)
|
objects = getattr(run_config, rsrc)
|
||||||
if api not in impls:
|
if api not in impls:
|
||||||
|
|
@ -116,20 +122,31 @@ async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
|
||||||
continue
|
continue
|
||||||
logger.debug(f"registering {rsrc.capitalize()} {obj} for provider {obj.provider_id}")
|
logger.debug(f"registering {rsrc.capitalize()} {obj} for provider {obj.provider_id}")
|
||||||
|
|
||||||
# we want to maintain the type information in arguments to method.
|
try:
|
||||||
# instead of method(**obj.model_dump()), which may convert a typed attr to a dict,
|
# we want to maintain the type information in arguments to method.
|
||||||
# we use model_dump() to find all the attrs and then getattr to get the still typed value.
|
# instead of method(**obj.model_dump()), which may convert a typed attr to a dict,
|
||||||
await method(**{k: getattr(obj, k) for k in obj.model_dump().keys()})
|
# we use model_dump() to find all the attrs and then getattr to get the still typed value.
|
||||||
|
await method(**{k: getattr(obj, k) for k in obj.model_dump().keys()})
|
||||||
|
except Exception as e:
|
||||||
|
# Log the error but continue with other resources
|
||||||
|
logger.error(
|
||||||
|
f"Failed to register {rsrc} {obj} for provider {obj.provider_id if hasattr(obj, 'provider_id') else 'unknown'}: {e}"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
method = getattr(impls[api], list_method)
|
try:
|
||||||
response = await method()
|
method = getattr(impls[api], list_method)
|
||||||
|
response = await method()
|
||||||
|
|
||||||
objects_to_process = response.data if hasattr(response, "data") else response
|
objects_to_process = response.data if hasattr(response, "data") else response
|
||||||
|
|
||||||
for obj in objects_to_process:
|
for obj in objects_to_process:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"{rsrc.capitalize()}: {obj.identifier} served by {obj.provider_id}",
|
f"{rsrc.capitalize()}: {obj.identifier} served by {obj.provider_id}",
|
||||||
)
|
)
|
||||||
|
except Exception as e:
|
||||||
|
# Log the error but continue with other resource types
|
||||||
|
logger.error(f"Failed to list {rsrc}: {e}")
|
||||||
|
|
||||||
|
|
||||||
class EnvVarError(Exception):
|
class EnvVarError(Exception):
|
||||||
|
|
|
||||||
207
tests/unit/test_stack.py
Normal file
207
tests/unit/test_stack.py
Normal file
|
|
@ -0,0 +1,207 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from llama_stack.apis.models import Model, ModelType
|
||||||
|
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||||
|
from llama_stack.core.stack import Stack, register_resources
|
||||||
|
from llama_stack.providers.datatypes import Api
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_model():
|
||||||
|
return Model(
|
||||||
|
provider_id="test_provider",
|
||||||
|
provider_model_id="test_model",
|
||||||
|
identifier="test_model",
|
||||||
|
model_type=ModelType.llm,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_provider():
|
||||||
|
return Provider(
|
||||||
|
provider_id="test_provider",
|
||||||
|
provider_type="test_type",
|
||||||
|
config={},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_run_config(mock_model):
|
||||||
|
return StackRunConfig(
|
||||||
|
image_name="test",
|
||||||
|
apis=["inference"],
|
||||||
|
providers={"inference": [mock_provider]},
|
||||||
|
models=[mock_model],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_register_resources_success(mock_run_config, mock_model):
|
||||||
|
"""Test successful registration of resources."""
|
||||||
|
mock_impl = AsyncMock()
|
||||||
|
mock_impl.register_model = AsyncMock(return_value=mock_model)
|
||||||
|
mock_impl.list_models = AsyncMock(return_value=[mock_model])
|
||||||
|
|
||||||
|
impls = {Api.models: mock_impl}
|
||||||
|
|
||||||
|
await register_resources(mock_run_config, impls)
|
||||||
|
|
||||||
|
mock_impl.register_model.assert_called_once()
|
||||||
|
mock_impl.list_models.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_register_resources_failed_registration(mock_run_config, mock_model):
|
||||||
|
"""Test that stack continues when model registration fails."""
|
||||||
|
mock_impl = AsyncMock()
|
||||||
|
mock_impl.register_model = AsyncMock(side_effect=ValueError("Registration failed"))
|
||||||
|
mock_impl.list_models = AsyncMock(return_value=[])
|
||||||
|
|
||||||
|
impls = {Api.models: mock_impl}
|
||||||
|
|
||||||
|
# Should not raise exception
|
||||||
|
await register_resources(mock_run_config, impls)
|
||||||
|
|
||||||
|
mock_impl.register_model.assert_called_once()
|
||||||
|
mock_impl.list_models.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_register_resources_failed_listing(mock_run_config, mock_model):
|
||||||
|
"""Test that stack continues when model listing fails."""
|
||||||
|
mock_impl = AsyncMock()
|
||||||
|
mock_impl.register_model = AsyncMock(return_value=mock_model)
|
||||||
|
mock_impl.list_models = AsyncMock(side_effect=ValueError("Listing failed"))
|
||||||
|
|
||||||
|
impls = {Api.models: mock_impl}
|
||||||
|
|
||||||
|
# Should not raise exception
|
||||||
|
await register_resources(mock_run_config, impls)
|
||||||
|
|
||||||
|
mock_impl.register_model.assert_called_once()
|
||||||
|
mock_impl.list_models.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_register_resources_mixed_success(mock_run_config):
|
||||||
|
"""Test mixed success/failure scenario with multiple models."""
|
||||||
|
# Create two models
|
||||||
|
model1 = Model(
|
||||||
|
provider_id="test_provider",
|
||||||
|
provider_model_id="model1",
|
||||||
|
identifier="model1",
|
||||||
|
model_type=ModelType.llm,
|
||||||
|
)
|
||||||
|
model2 = Model(
|
||||||
|
provider_id="test_provider",
|
||||||
|
provider_model_id="model2",
|
||||||
|
identifier="model2",
|
||||||
|
model_type=ModelType.llm,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update run config to include both models
|
||||||
|
mock_run_config.models = [model1, model2]
|
||||||
|
|
||||||
|
mock_impl = AsyncMock()
|
||||||
|
# Make first registration succeed, second fail
|
||||||
|
mock_impl.register_model = AsyncMock(side_effect=[model1, ValueError("Second registration failed")])
|
||||||
|
mock_impl.list_models = AsyncMock(return_value=[model1]) # Only first model listed
|
||||||
|
|
||||||
|
impls = {Api.models: mock_impl}
|
||||||
|
|
||||||
|
# Should not raise exception
|
||||||
|
await register_resources(mock_run_config, impls)
|
||||||
|
|
||||||
|
assert mock_impl.register_model.call_count == 2
|
||||||
|
mock_impl.list_models.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_register_resources_disabled_provider(mock_run_config, mock_model):
|
||||||
|
"""Test that disabled providers are skipped."""
|
||||||
|
# Update model to be disabled
|
||||||
|
mock_model.provider_id = "__disabled__"
|
||||||
|
mock_impl = AsyncMock()
|
||||||
|
|
||||||
|
impls = {Api.models: mock_impl}
|
||||||
|
|
||||||
|
await register_resources(mock_run_config, impls)
|
||||||
|
|
||||||
|
# Should not attempt registration for disabled provider
|
||||||
|
mock_impl.register_model.assert_not_called()
|
||||||
|
mock_impl.list_models.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
class MockFailingProvider:
|
||||||
|
"""A mock provider that fails registration but allows initialization"""
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
self.initialize_called = False
|
||||||
|
self.shutdown_called = False
|
||||||
|
|
||||||
|
async def initialize(self):
|
||||||
|
self.initialize_called = True
|
||||||
|
|
||||||
|
async def shutdown(self):
|
||||||
|
self.shutdown_called = True
|
||||||
|
|
||||||
|
async def register_model(self, *args, **kwargs):
|
||||||
|
raise ValueError("Mock registration failure")
|
||||||
|
|
||||||
|
async def list_models(self):
|
||||||
|
return [] # Return empty list to simulate no models registered
|
||||||
|
|
||||||
|
|
||||||
|
async def test_stack_initialization_with_failed_registration():
|
||||||
|
"""Test full stack initialization with failed model registration using a mock provider."""
|
||||||
|
mock_model = Model(
|
||||||
|
provider_id="mock_failing",
|
||||||
|
provider_model_id="test_model",
|
||||||
|
identifier="test_model",
|
||||||
|
model_type=ModelType.llm,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_run_config = StackRunConfig(
|
||||||
|
image_name="test",
|
||||||
|
apis=["inference"],
|
||||||
|
providers={
|
||||||
|
"inference": [
|
||||||
|
Provider(
|
||||||
|
provider_id="mock_failing",
|
||||||
|
provider_type="mock::failing",
|
||||||
|
config={}, # No need for real config in mock
|
||||||
|
)
|
||||||
|
]
|
||||||
|
},
|
||||||
|
models=[mock_model],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create a mock provider registry that returns our failing provider
|
||||||
|
mock_registry = {
|
||||||
|
Api.inference: {
|
||||||
|
"mock::failing": MagicMock(
|
||||||
|
provider_class=MockFailingProvider,
|
||||||
|
config_class=MagicMock(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
stack = Stack(mock_run_config, provider_registry=mock_registry)
|
||||||
|
|
||||||
|
# Should not raise exception during initialization
|
||||||
|
await stack.initialize()
|
||||||
|
|
||||||
|
# Stack should still be initialized
|
||||||
|
assert stack.impls is not None
|
||||||
|
|
||||||
|
# Verify the provider was properly initialized
|
||||||
|
inference_impl = stack.impls.get(Api.inference)
|
||||||
|
assert inference_impl is not None
|
||||||
|
assert inference_impl.initialize_called
|
||||||
|
|
||||||
|
# Clean up
|
||||||
|
await stack.shutdown()
|
||||||
|
assert inference_impl.shutdown_called
|
||||||
Loading…
Add table
Add a link
Reference in a new issue