diff --git a/src/llama_stack/providers/inline/agents/meta_reference/__init__.py b/src/llama_stack/providers/inline/agents/meta_reference/__init__.py
index 91287617a..b3fb814e3 100644
--- a/src/llama_stack/providers/inline/agents/meta_reference/__init__.py
+++ b/src/llama_stack/providers/inline/agents/meta_reference/__init__.py
@@ -23,7 +23,7 @@ async def get_provider_impl(
         config,
         deps[Api.inference],
         deps[Api.vector_io],
-        deps[Api.safety],
+        deps.get(Api.safety),
         deps[Api.tool_runtime],
         deps[Api.tool_groups],
         deps[Api.conversations],
diff --git a/src/llama_stack/providers/inline/agents/meta_reference/agents.py b/src/llama_stack/providers/inline/agents/meta_reference/agents.py
index ba83a9576..2d5aa6c04 100644
--- a/src/llama_stack/providers/inline/agents/meta_reference/agents.py
+++ b/src/llama_stack/providers/inline/agents/meta_reference/agents.py
@@ -41,7 +41,7 @@ class MetaReferenceAgentsImpl(Agents):
         config: MetaReferenceAgentsImplConfig,
         inference_api: Inference,
         vector_io_api: VectorIO,
-        safety_api: Safety,
+        safety_api: Safety | None,
         tool_runtime_api: ToolRuntime,
         tool_groups_api: ToolGroups,
         conversations_api: Conversations,
diff --git a/src/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py b/src/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py
index 7e080a675..11bfb1417 100644
--- a/src/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py
+++ b/src/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py
@@ -67,7 +67,7 @@ class OpenAIResponsesImpl:
         tool_runtime_api: ToolRuntime,
         responses_store: ResponsesStore,
         vector_io_api: VectorIO,  # VectorIO
-        safety_api: Safety,
+        safety_api: Safety | None,
         conversations_api: Conversations,
     ):
         self.inference_api = inference_api
@@ -273,6 +273,14 @@ class OpenAIResponsesImpl:
 
         guardrail_ids = extract_guardrail_ids(guardrails) if guardrails else []
 
+        # Validate that Safety API is available if guardrails are requested
+        if guardrail_ids and self.safety_api is None:
+            raise ValueError(
+                "Cannot process guardrails: Safety API is not configured.\n\n"
+                "To use guardrails, ensure the Safety API is configured in your stack, or remove "
+                "the 'guardrails' parameter from your request."
+            )
+
         if conversation is not None:
             if previous_response_id is not None:
                 raise ValueError(
diff --git a/src/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py b/src/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py
index cdbd87244..0ef74f1f1 100644
--- a/src/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py
+++ b/src/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py
@@ -66,6 +66,7 @@ from llama_stack_api import (
     OpenAIResponseUsage,
     OpenAIResponseUsageInputTokensDetails,
     OpenAIResponseUsageOutputTokensDetails,
+    Safety,
     WebSearchToolTypes,
 )
 
@@ -111,7 +112,7 @@ class StreamingResponseOrchestrator:
         max_infer_iters: int,
         tool_executor,  # Will be the tool execution logic from the main class
         instructions: str | None,
-        safety_api,
+        safety_api: Safety | None,
         guardrail_ids: list[str] | None = None,
         prompt: OpenAIResponsePrompt | None = None,
         parallel_tool_calls: bool | None = None,
diff --git a/src/llama_stack/providers/inline/agents/meta_reference/responses/utils.py b/src/llama_stack/providers/inline/agents/meta_reference/responses/utils.py
index 943bbae41..25460bcfe 100644
--- a/src/llama_stack/providers/inline/agents/meta_reference/responses/utils.py
+++ b/src/llama_stack/providers/inline/agents/meta_reference/responses/utils.py
@@ -320,11 +320,15 @@ def is_function_tool_call(
     return False
 
 
-async def run_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[str]) -> str | None:
+async def run_guardrails(safety_api: Safety | None, messages: str, guardrail_ids: list[str]) -> str | None:
     """Run guardrails against messages and return violation message if blocked."""
     if not messages:
         return None
 
+    # If the safety API is not available, skip guardrails
+    if safety_api is None:
+        return None
+
     # Look up shields to get their provider_resource_id (actual model ID)
     model_ids = []
     # TODO: list_shields not in Safety interface but available at runtime via API routing
diff --git a/src/llama_stack/providers/registry/agents.py b/src/llama_stack/providers/registry/agents.py
index 2c68750a6..e85be99d6 100644
--- a/src/llama_stack/providers/registry/agents.py
+++ b/src/llama_stack/providers/registry/agents.py
@@ -30,12 +30,14 @@ def available_providers() -> list[ProviderSpec]:
             config_class="llama_stack.providers.inline.agents.meta_reference.MetaReferenceAgentsImplConfig",
             api_dependencies=[
                 Api.inference,
-                Api.safety,
                 Api.vector_io,
                 Api.tool_runtime,
                 Api.tool_groups,
                 Api.conversations,
             ],
+            optional_api_dependencies=[
+                Api.safety,
+            ],
             description="Meta's reference implementation of an agent system that can use tools, access vector databases, and perform complex reasoning tasks.",
         ),
     ]
diff --git a/tests/unit/providers/agents/meta_reference/test_safety_optional.py b/tests/unit/providers/agents/meta_reference/test_safety_optional.py
new file mode 100644
index 000000000..b48d38b29
--- /dev/null
+++ b/tests/unit/providers/agents/meta_reference/test_safety_optional.py
@@ -0,0 +1,206 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+"""Tests for making the Safety API optional in the meta-reference agents provider.
+
+This test suite validates the changes introduced to fix issue #4165, which
+allows running the meta-reference agents provider without the Safety API.
+The Safety API is now an optional dependency, and errors are raised at request
+time when guardrails are explicitly requested without the Safety API configured.
+"""
+
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+
+from llama_stack.core.datatypes import Api
+from llama_stack.core.storage.datatypes import KVStoreReference, ResponsesStoreReference
+from llama_stack.providers.inline.agents.meta_reference import get_provider_impl
+from llama_stack.providers.inline.agents.meta_reference.config import (
+    AgentPersistenceConfig,
+    MetaReferenceAgentsImplConfig,
+)
+from llama_stack.providers.inline.agents.meta_reference.responses.utils import (
+    run_guardrails,
+)
+
+
+@pytest.fixture
+def mock_persistence_config():
+    """Create a mock persistence configuration."""
+    return AgentPersistenceConfig(
+        agent_state=KVStoreReference(
+            backend="kv_default",
+            namespace="agents",
+        ),
+        responses=ResponsesStoreReference(
+            backend="sql_default",
+            table_name="responses",
+        ),
+    )
+
+
+@pytest.fixture
+def mock_deps():
+    """Create mock dependencies for the agents provider."""
+    # Create mock APIs
+    inference_api = AsyncMock()
+    vector_io_api = AsyncMock()
+    tool_runtime_api = AsyncMock()
+    tool_groups_api = AsyncMock()
+    conversations_api = AsyncMock()
+
+    return {
+        Api.inference: inference_api,
+        Api.vector_io: vector_io_api,
+        Api.tool_runtime: tool_runtime_api,
+        Api.tool_groups: tool_groups_api,
+        Api.conversations: conversations_api,
+    }
+
+
+class TestProviderInitialization:
+    """Test provider initialization with different safety API configurations."""
+
+    async def test_initialization_with_safety_api_present(self, mock_persistence_config, mock_deps):
+        """Test successful initialization when the Safety API is configured."""
+        config = MetaReferenceAgentsImplConfig(persistence=mock_persistence_config)
+
+        # Add the safety API to deps
+        safety_api = AsyncMock()
+        mock_deps[Api.safety] = safety_api
+
+        # Mock the initialize method to avoid actual initialization
+        with patch(
+            "llama_stack.providers.inline.agents.meta_reference.agents.MetaReferenceAgentsImpl.initialize",
+            new_callable=AsyncMock,
+        ):
+            # Should not raise any exception
+            provider = await get_provider_impl(config, mock_deps, policy=[], telemetry_enabled=False)
+            assert provider is not None
+
+    async def test_initialization_without_safety_api(self, mock_persistence_config, mock_deps):
+        """Test successful initialization when the Safety API is not configured."""
+        config = MetaReferenceAgentsImplConfig(persistence=mock_persistence_config)
+
+        # Safety API is NOT in mock_deps - the provider should still start.
+        # Mock the initialize method to avoid actual initialization
+        with patch(
+            "llama_stack.providers.inline.agents.meta_reference.agents.MetaReferenceAgentsImpl.initialize",
+            new_callable=AsyncMock,
+        ):
+            # Should not raise any exception
+            provider = await get_provider_impl(config, mock_deps, policy=[], telemetry_enabled=False)
+            assert provider is not None
+            assert provider.safety_api is None
+
+
+class TestGuardrailsFunctionality:
+    """Test the run_guardrails function with an optional safety API."""
+
+    async def test_run_guardrails_with_none_safety_api(self):
+        """Test that run_guardrails returns None when safety_api is None."""
+        result = await run_guardrails(safety_api=None, messages="test message", guardrail_ids=["llama-guard"])
+        assert result is None
+
+    async def test_run_guardrails_with_empty_messages(self):
+        """Test that run_guardrails returns None for empty messages."""
+        # Test with None safety API
+        result = await run_guardrails(safety_api=None, messages="", guardrail_ids=["llama-guard"])
+        assert result is None
+
+        # Test with a mock safety API
+        mock_safety_api = AsyncMock()
+        result = await run_guardrails(safety_api=mock_safety_api, messages="", guardrail_ids=["llama-guard"])
+        assert result is None
+
+    async def test_run_guardrails_with_none_safety_api_ignores_guardrails(self):
+        """Test that guardrails are skipped when safety_api is None, even if guardrail_ids are provided."""
+        # Should not raise an exception, just return None
+        result = await run_guardrails(
+            safety_api=None,
+            messages="potentially harmful content",
+            guardrail_ids=["llama-guard", "content-filter"],
+        )
+        assert result is None
+
+    async def test_create_response_rejects_guardrails_without_safety_api(self, mock_persistence_config, mock_deps):
+        """Test that create_openai_response raises an error when guardrails are requested but the Safety API is unavailable."""
+        from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
+            OpenAIResponsesImpl,
+        )
+        from llama_stack_api import ResponseGuardrailSpec
+
+        # Create OpenAIResponsesImpl with no safety API
+        with patch("llama_stack.providers.inline.agents.meta_reference.responses.openai_responses.ResponsesStore"):
+            impl = OpenAIResponsesImpl(
+                inference_api=mock_deps[Api.inference],
+                tool_groups_api=mock_deps[Api.tool_groups],
+                tool_runtime_api=mock_deps[Api.tool_runtime],
+                responses_store=MagicMock(),
+                vector_io_api=mock_deps[Api.vector_io],
+                safety_api=None,  # No Safety API
+                conversations_api=mock_deps[Api.conversations],
+            )
+
+        # Test with a string guardrail
+        with pytest.raises(ValueError) as exc_info:
+            await impl.create_openai_response(
+                input="test input",
+                model="test-model",
+                guardrails=["llama-guard"],
+            )
+        assert "Cannot process guardrails: Safety API is not configured" in str(exc_info.value)
+
+        # Test with a ResponseGuardrailSpec
+        with pytest.raises(ValueError) as exc_info:
+            await impl.create_openai_response(
+                input="test input",
+                model="test-model",
+                guardrails=[ResponseGuardrailSpec(type="llama-guard")],
+            )
+        assert "Cannot process guardrails: Safety API is not configured" in str(exc_info.value)
+
+    async def test_create_response_succeeds_without_guardrails_and_no_safety_api(
+        self, mock_persistence_config, mock_deps
+    ):
+        """Test that create_openai_response works when no guardrails are requested and the Safety API is unavailable."""
+        from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
+            OpenAIResponsesImpl,
+        )
+
+        # Create OpenAIResponsesImpl with no safety API
+        with (
+            patch("llama_stack.providers.inline.agents.meta_reference.responses.openai_responses.ResponsesStore"),
+            patch.object(OpenAIResponsesImpl, "_create_streaming_response", new_callable=AsyncMock) as mock_stream,
+        ):
+            # Mock the streaming response to return a simple async generator
+            async def mock_generator():
+                yield MagicMock()
+
+            mock_stream.return_value = mock_generator()
+
+            impl = OpenAIResponsesImpl(
+                inference_api=mock_deps[Api.inference],
+                tool_groups_api=mock_deps[Api.tool_groups],
+                tool_runtime_api=mock_deps[Api.tool_runtime],
+                responses_store=MagicMock(),
+                vector_io_api=mock_deps[Api.vector_io],
+                safety_api=None,  # No Safety API
+                conversations_api=mock_deps[Api.conversations],
+            )
+
+            # Should not raise when no guardrails are requested.
+            # Note: this will still fail later in execution due to mocking, but it should pass the validation.
+            try:
+                await impl.create_openai_response(
+                    input="test input",
+                    model="test-model",
+                    guardrails=None,  # No guardrails
+                )
+            except Exception as e:
+                # Ensure the error is NOT about a missing Safety API
+                assert "Cannot process guardrails: Safety API is not configured" not in str(e)
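
For reviewers, here is a minimal sketch (not part of the patch) of the contract these changes establish, mirroring the tests above. It assumes `impl` is an `OpenAIResponsesImpl` constructed with `safety_api=None`; the `demo` wrapper and model name are hypothetical, while `create_openai_response` and `run_guardrails` come from the diff:

    from llama_stack.providers.inline.agents.meta_reference.responses.utils import run_guardrails

    async def demo(impl):
        # Explicitly requesting guardrails without a configured Safety API now
        # fails fast at request-validation time with a descriptive ValueError.
        try:
            await impl.create_openai_response(
                input="hello",
                model="example-model",  # hypothetical model id
                guardrails=["llama-guard"],
            )
        except ValueError as err:
            assert "Safety API is not configured" in str(err)

        # Requests without guardrails proceed normally: run_guardrails()
        # short-circuits to None whenever safety_api is None, so no safety
        # checks are attempted and nothing is blocked.
        assert await run_guardrails(None, "hello", []) is None
        await impl.create_openai_response(input="hello", model="example-model")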