diff --git a/tests/unit/core/test_stack.py b/tests/unit/core/test_stack.py
new file mode 100644
index 000000000..05a98904c
--- /dev/null
+++ b/tests/unit/core/test_stack.py
@@ -0,0 +1,304 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from collections.abc import Callable
+from typing import Any
+from unittest.mock import AsyncMock
+
+import pytest
+
+from llama_stack.apis.models import Model, ModelInput, ModelType
+from llama_stack.core.datatypes import Provider, StackRunConfig
+from llama_stack.core.stack import register_resources
+from llama_stack.providers.datatypes import Api
+
+
+@pytest.fixture(autouse=True)
+def setup_caplog(caplog):
+    """Configure caplog to capture all log levels."""
+    caplog.set_level("DEBUG")
+
+
+def assert_graceful_failure(message: str) -> Callable[[Exception], None]:
+    """Create a failure handler that fails the test with a consistent message.
+
+    Args:
+        message: Description of the graceful-handling guarantee that was violated
+    """
+
+    def _fail(exc: Exception) -> None:
+        pytest.fail(f"{message}: {exc!r}")
+
+    return _fail
+
+
+def create_mock_impl(
+    register_result: Any = None,
+    register_error: Exception | None = None,
+    list_result: Any = None,
+    list_error: Exception | None = None,
+) -> AsyncMock:
+    """Create a mock implementation with configurable behavior.
+
+    Args:
+        register_result: Return value for register_model (a list becomes per-call side effects)
+        register_error: Error to raise from register_model
+        list_result: Return value for list_models
+        list_error: Error to raise from list_models
+    """
+    mock_impl = AsyncMock()
+    if register_error:
+        mock_impl.register_model = AsyncMock(side_effect=register_error)
+    elif isinstance(register_result, list):
+        mock_impl.register_model = AsyncMock(side_effect=register_result)
+    else:
+        mock_impl.register_model = AsyncMock(return_value=register_result)
+
+    if list_error:
+        mock_impl.list_models = AsyncMock(side_effect=list_error)
+    else:
+        mock_impl.list_models = AsyncMock(return_value=list_result or [])
+
+    return mock_impl
+
+
+@pytest.fixture
+def mock_model():
+    """Create a valid model input for testing."""
+    return ModelInput(
+        provider_id="test_provider",
+        provider_model_id="test_model",
+        identifier="test_model",
+        model_id="test_model",
+        model_type=ModelType.llm,
+    )
+
+
+@pytest.fixture
+def mock_provider():
+    """Create a valid provider for testing."""
+    return Provider(
+        provider_id="test_provider",
+        provider_type="remote::test",  # Valid format: namespace::name
+        config={},
+    )
+
+
+@pytest.fixture
+def mock_run_config(mock_model, mock_provider):
+    return StackRunConfig(
+        image_name="test",
+        apis=["inference"],
+        providers={"inference": [mock_provider]},
+        models=[mock_model],
+    )
+
+
+async def test_register_resources_success(mock_run_config, mock_model):
+    """Test successful registration of resources."""
+    mock_impl = create_mock_impl(
+        register_result=mock_model,
+        list_result=[mock_model],
+    )
+    impls = {Api.models: mock_impl}
+
+    try:
+        await register_resources(mock_run_config, impls)
+    except AttributeError as e:
+        assert_graceful_failure("Stack interrupted initialization: tried to access model fields in wrong format")(e)
+
+    mock_impl.register_model.assert_called_once()
+    mock_impl.list_models.assert_called_once()
+
+
+async def test_register_resources_failed_registration(caplog):
+    """Test that stack continues when model registration fails.
+
+    This test demonstrates how the stack handles provider validation errors:
+
+    Before fix:
+    - A provider failing to validate a model would crash the entire stack
+    - Example: OpenAI provider with invalid API key would prevent startup
+    - No way to start stack with other working providers
+
+    After fix:
+    - Provider validation errors are caught and logged
+    - Stack continues initializing with other providers
+    - Failed provider is skipped but doesn't block others
+
+    Test strategy:
+    1. Configure the provider mock to reject the model during registration
+    2. Use a valid model and provider config so the failure clearly comes from registration
+    3. Verify the stack logs the error and continues initializing
+    """
+    # A perfectly valid model; the provider mock below is what rejects it
+    model = ModelInput(
+        provider_id="test_provider",
+        provider_model_id="test_model",
+        identifier="test_model",
+        model_id="test_model",
+        model_type=ModelType.llm,
+    )
+
+    # Create a valid provider
+    valid_provider = Provider(
+        provider_id="test_provider",
+        provider_type="remote::test",
+        config={},
+    )
+
+    # Create config with the model
+    mock_run_config = StackRunConfig(
+        image_name="test",
+        apis=["inference"],
+        providers={"inference": [valid_provider]},
+        models=[model],
+    )
+
+    mock_impl = create_mock_impl(
+        register_error=ValueError("Provider failed to validate model: simulated validation failure"),
+    )
+    impls = {Api.models: mock_impl}
+
+    # Before fix: the stack would crash here; after fix it should handle the error and continue
+    try:
+        await register_resources(mock_run_config, impls)
+    except Exception as e:
+        assert_graceful_failure("Stack interrupted initialization: provider rejected a model during registration")(e)
+
+    # Verify registration was attempted even though the provider is configured to fail
+    mock_impl.register_model.assert_called_once_with(
+        **{k: getattr(model, k) for k in model.model_dump().keys()}
+    )
+
+    # Verify the stack continues normal operation after handling the registration failure
+    mock_impl.list_models.assert_called_once()
+
+    # Verify error was logged
+    assert "Failed to register models" in caplog.text, "Error should be logged when model registration fails"
+    assert "Provider failed to validate model" in caplog.text, "Specific error message should be logged"
+
+
+async def test_register_resources_failed_listing(mock_run_config, mock_model):
+    """Test that stack continues when model listing fails."""
+    mock_impl = create_mock_impl(
+        register_result=mock_model,
+        list_error=ValueError("Listing failed"),
+    )
+    impls = {Api.models: mock_impl}
+
+    # Should not raise an exception
+    try:
+        await register_resources(mock_run_config, impls)
+    except Exception as e:
+        assert_graceful_failure("Stack interrupted initialization: provider failed to list available models")(e)
+
+    # register_model() should complete successfully before the listing failure
+    mock_impl.register_model.assert_called_once_with(
+        **{k: getattr(mock_model, k) for k in mock_model.model_dump().keys()}
+    )
+
+    # list_models() should be called and its failure should be handled gracefully
+    mock_impl.list_models.assert_called_once()
+
+
handled gracefully", + ) + + +async def test_register_resources_mixed_success(mock_run_config): + """Test mixed success/failure scenario with multiple models.""" + # Create two models + model1 = Model( + provider_id="test_provider", + provider_model_id="model1", + identifier="model1", + model_type=ModelType.llm, + ) + model2 = Model( + provider_id="test_provider", + provider_model_id="model2", + identifier="model2", + model_type=ModelType.llm, + ) + + # Update run config to include both models + mock_run_config.models = [model1, model2] + + mock_impl = create_mock_impl( + register_result=[model1, ValueError("Second registration failed")], + list_result=[model1], # Only first model listed + ) + impls = {Api.models: mock_impl} + + # Should not raise exception + try: + await register_resources(mock_run_config, impls) + except Exception as e: + assert_graceful_failure( + "Stack interrupted initialization: some models registered successfully but others failed" + )(e) + + # Verify both registration attempts were made + assert mock_impl.register_model.call_count == 2, ( + "register_model() should be called twice, once for each model, regardless of failures" + ) + + # Verify the first call succeeded with model1 + ( + mock_impl.register_model.assert_any_call(**{k: getattr(model1, k) for k in model1.model_dump().keys()}), + "First model registration should be attempted with correct parameters", + ) + + # Verify the second call was attempted with model2 + ( + mock_impl.register_model.assert_any_call(**{k: getattr(model2, k) for k in model2.model_dump().keys()}), + "Second model registration should be attempted even after first success", + ) + + # Verify listing was still performed + mock_impl.list_models.assert_called_once(), "list_models() should be called once after all registration attempts" + + # Verify listing returned only the successful model + assert len(mock_impl.list_models.return_value) == 1, ( + "list_models() should return only the successfully registered model" + ) + assert mock_impl.list_models.return_value == [model1], ( + "list_models() should return the first model that registered successfully" + ) + + +async def test_register_resources_disabled_provider(mock_run_config, mock_model): + """Test that disabled providers are skipped.""" + # Update model to be disabled + mock_model.provider_id = "__disabled__" + mock_impl = create_mock_impl() + impls = {Api.models: mock_impl} + + try: + await register_resources(mock_run_config, impls) + except Exception as e: + assert_graceful_failure("Stack interrupted initialization: provider is marked as disabled")(e) + + # Should not attempt registration for disabled provider + mock_impl.register_model.assert_not_called(), "register_model() should not be called for disabled providers" + + # Should still perform listing + mock_impl.list_models.assert_called_once(), "list_models() should still be called even for disabled providers" diff --git a/tests/unit/test_stack.py b/tests/unit/test_stack.py deleted file mode 100644 index 768bd34c5..000000000 --- a/tests/unit/test_stack.py +++ /dev/null @@ -1,207 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
-
-from unittest.mock import AsyncMock, MagicMock
-
-import pytest
-
-from llama_stack.apis.models import Model, ModelType
-from llama_stack.core.datatypes import Provider, StackRunConfig
-from llama_stack.core.stack import Stack, register_resources
-from llama_stack.providers.datatypes import Api
-
-
-@pytest.fixture
-def mock_model():
-    return Model(
-        provider_id="test_provider",
-        provider_model_id="test_model",
-        identifier="test_model",
-        model_type=ModelType.llm,
-    )
-
-
-@pytest.fixture
-def mock_provider():
-    return Provider(
-        provider_id="test_provider",
-        provider_type="test_type",
-        config={},
-    )
-
-
-@pytest.fixture
-def mock_run_config(mock_model):
-    return StackRunConfig(
-        image_name="test",
-        apis=["inference"],
-        providers={"inference": [mock_provider]},
-        models=[mock_model],
-    )
-
-
-async def test_register_resources_success(mock_run_config, mock_model):
-    """Test successful registration of resources."""
-    mock_impl = AsyncMock()
-    mock_impl.register_model = AsyncMock(return_value=mock_model)
-    mock_impl.list_models = AsyncMock(return_value=[mock_model])
-
-    impls = {Api.models: mock_impl}
-
-    await register_resources(mock_run_config, impls)
-
-    mock_impl.register_model.assert_called_once()
-    mock_impl.list_models.assert_called_once()
-
-
-async def test_register_resources_failed_registration(mock_run_config, mock_model):
-    """Test that stack continues when model registration fails."""
-    mock_impl = AsyncMock()
-    mock_impl.register_model = AsyncMock(side_effect=ValueError("Registration failed"))
-    mock_impl.list_models = AsyncMock(return_value=[])
-
-    impls = {Api.models: mock_impl}
-
-    # Should not raise exception
-    await register_resources(mock_run_config, impls)
-
-    mock_impl.register_model.assert_called_once()
-    mock_impl.list_models.assert_called_once()
-
-
-async def test_register_resources_failed_listing(mock_run_config, mock_model):
-    """Test that stack continues when model listing fails."""
-    mock_impl = AsyncMock()
-    mock_impl.register_model = AsyncMock(return_value=mock_model)
-    mock_impl.list_models = AsyncMock(side_effect=ValueError("Listing failed"))
-
-    impls = {Api.models: mock_impl}
-
-    # Should not raise exception
-    await register_resources(mock_run_config, impls)
-
-    mock_impl.register_model.assert_called_once()
-    mock_impl.list_models.assert_called_once()
-
-
-async def test_register_resources_mixed_success(mock_run_config):
-    """Test mixed success/failure scenario with multiple models."""
-    # Create two models
-    model1 = Model(
-        provider_id="test_provider",
-        provider_model_id="model1",
-        identifier="model1",
-        model_type=ModelType.llm,
-    )
-    model2 = Model(
-        provider_id="test_provider",
-        provider_model_id="model2",
-        identifier="model2",
-        model_type=ModelType.llm,
-    )
-
-    # Update run config to include both models
-    mock_run_config.models = [model1, model2]
-
-    mock_impl = AsyncMock()
-    # Make first registration succeed, second fail
-    mock_impl.register_model = AsyncMock(side_effect=[model1, ValueError("Second registration failed")])
-    mock_impl.list_models = AsyncMock(return_value=[model1])  # Only first model listed
-
-    impls = {Api.models: mock_impl}
-
-    # Should not raise exception
-    await register_resources(mock_run_config, impls)
-
-    assert mock_impl.register_model.call_count == 2
-    mock_impl.list_models.assert_called_once()
-
-
-async def test_register_resources_disabled_provider(mock_run_config, mock_model):
-    """Test that disabled providers are skipped."""
-    # Update model to be disabled
-    mock_model.provider_id = "__disabled__"
-
-    mock_impl = AsyncMock()
-
-    impls = {Api.models: mock_impl}
-
-    await register_resources(mock_run_config, impls)
-
-    # Should not attempt registration for disabled provider
-    mock_impl.register_model.assert_not_called()
-    mock_impl.list_models.assert_called_once()
-
-
-class MockFailingProvider:
-    """A mock provider that fails registration but allows initialization"""
-
-    def __init__(self, *args, **kwargs):
-        self.initialize_called = False
-        self.shutdown_called = False
-
-    async def initialize(self):
-        self.initialize_called = True
-
-    async def shutdown(self):
-        self.shutdown_called = True
-
-    async def register_model(self, *args, **kwargs):
-        raise ValueError("Mock registration failure")
-
-    async def list_models(self):
-        return []  # Return empty list to simulate no models registered
-
-
-async def test_stack_initialization_with_failed_registration():
-    """Test full stack initialization with failed model registration using a mock provider."""
-    mock_model = Model(
-        provider_id="mock_failing",
-        provider_model_id="test_model",
-        identifier="test_model",
-        model_type=ModelType.llm,
-    )
-
-    mock_run_config = StackRunConfig(
-        image_name="test",
-        apis=["inference"],
-        providers={
-            "inference": [
-                Provider(
-                    provider_id="mock_failing",
-                    provider_type="mock::failing",
-                    config={},  # No need for real config in mock
-                )
-            ]
-        },
-        models=[mock_model],
-    )
-
-    # Create a mock provider registry that returns our failing provider
-    mock_registry = {
-        Api.inference: {
-            "mock::failing": MagicMock(
-                provider_class=MockFailingProvider,
-                config_class=MagicMock(),
-            )
-        }
-    }
-
-    stack = Stack(mock_run_config, provider_registry=mock_registry)
-
-    # Should not raise exception during initialization
-    await stack.initialize()
-
-    # Stack should still be initialized
-    assert stack.impls is not None
-
-    # Verify the provider was properly initialized
-    inference_impl = stack.impls.get(Api.inference)
-    assert inference_impl is not None
-    assert inference_impl.initialize_called
-
-    # Clean up
-    await stack.shutdown()
-    assert inference_impl.shutdown_called
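
Reviewer note: the graceful-handling behavior these tests exercise amounts to catching per-model registration errors and listing errors instead of letting them abort startup, and skipping providers marked "__disabled__". The snippet below is a minimal sketch of that assumed pattern, not the actual llama_stack implementation of register_resources(); only the names mirrored from the tests (register_resources, Api.models, register_model, list_models, the "Failed to register models" log line) come from the diff, the body is illustrative.

# Illustrative sketch only -- assumed shape of the behavior under test,
# not the actual llama_stack implementation.
import logging

from llama_stack.providers.datatypes import Api

logger = logging.getLogger(__name__)


async def register_resources(run_config, impls):
    models_impl = impls[Api.models]

    for model in run_config.models:
        # Disabled providers are skipped entirely (see test_register_resources_disabled_provider).
        if model.provider_id == "__disabled__":
            continue
        try:
            # The tests assert that every model field is forwarded as a keyword argument.
            await models_impl.register_model(**model.model_dump())
        except Exception as exc:
            # A failing provider is logged and skipped rather than crashing the stack.
            logger.error("Failed to register models: %s", exc)

    try:
        await models_impl.list_models()
    except Exception as exc:
        # Listing failures are likewise assumed to be non-fatal.
        logger.error("Failed to list models: %s", exc)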