mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 01:48:05 +00:00
feat: Make Safety API an optional dependency for meta-reference agents provider (#4169)
# What does this PR do?
Change Safety API from required to optional dependency, following the
established pattern used for other optional dependencies in Llama Stack.
The provider now starts successfully without Safety API configured.
Requests that explicitly include guardrails will receive a clear error
message when Safety API is unavailable.
This enables local development and testing without Safety API while
maintaining clear error messages when guardrail features are requested.
Closes #4165
Signed-off-by: Anik Bhattacharjee <anbhatta@redhat.com>
## Test Plan
<!-- Describe the tests you ran to verify your changes with result
summaries. *Provide clear instructions so the plan can be easily
re-executed.* -->
1. New unit tests added in
`tests/unit/providers/agents/meta_reference/test_safety_optional.py`
2. Integration tests performed with the files in
https://gist.github.com/anik120/c33cef497ec7085e1fe2164e0705b8d6
(i) test with `test_integration_no_safety_fail.yaml`:
Config WITHOUT Safety API, should fail with helpful error since
`required_safety_api` is `true` by default
```
$ uv run llama stack run test_integration_no_safety_fail.yaml 2>&1 | grep -B 5 -A 15 "ValueError.*Safety\|Safety API is
required"
File "/Users/anbhatta/go/src/github.com/llamastack/llama-stack/src/llama_stack/providers/inline/agents/meta_reference
/__init__.py", line 27, in get_provider_impl
raise ValueError(
...<9 lines>...
)
ValueError: Safety API is required but not configured.
To run without safety checks, explicitly set in your configuration:
providers:
agents:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
require_safety_api: false
Warning: This disables all safety guardrails for this agents provider.
```
(ii) test with `test_integration_no_safety_works.yaml`
Config WITHOUT Safety API, **but** `require_safety_api=false` is
explicitly set, should succeed
```
$ uv run llama stack run test_integration_no_safety_works.yaml
INFO 2025-11-16 09:49:10,044 llama_stack.cli.stack.run:169 cli: Using run configuration:
/Users/anbhatta/go/src/github.com/llamastack/llama-stack/test_integration_no_safety_works.yaml
INFO 2025-11-16 09:49:10,052 llama_stack.cli.stack.run:228 cli: HTTPS enabled with certificates:
Key: None
Cert: None
.
.
.
INFO 2025-11-16 09:49:38,528 llama_stack.core.stack:495 core: starting registry refresh task
INFO 2025-11-16 09:49:38,534 uvicorn.error:62 uncategorized: Application startup complete.
INFO 2025-11-16 09:49:38,535 uvicorn.error:216 uncategorized: Uvicorn running on http://0.0.0.0:8321 (Press CTRL+C
```
Signed-off-by: Anik Bhattacharjee <anbhatta@redhat.com>
Signed-off-by: Anik Bhattacharjee <anbhatta@redhat.com>
This commit is contained in:
parent
d5cd0eea14
commit
4e9633f7c3
7 changed files with 227 additions and 6 deletions
|
|
@ -23,7 +23,7 @@ async def get_provider_impl(
|
|||
config,
|
||||
deps[Api.inference],
|
||||
deps[Api.vector_io],
|
||||
deps[Api.safety],
|
||||
deps.get(Api.safety),
|
||||
deps[Api.tool_runtime],
|
||||
deps[Api.tool_groups],
|
||||
deps[Api.conversations],
|
||||
|
|
|
|||
|
|
@ -41,7 +41,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
config: MetaReferenceAgentsImplConfig,
|
||||
inference_api: Inference,
|
||||
vector_io_api: VectorIO,
|
||||
safety_api: Safety,
|
||||
safety_api: Safety | None,
|
||||
tool_runtime_api: ToolRuntime,
|
||||
tool_groups_api: ToolGroups,
|
||||
conversations_api: Conversations,
|
||||
|
|
|
|||
|
|
@ -67,7 +67,7 @@ class OpenAIResponsesImpl:
|
|||
tool_runtime_api: ToolRuntime,
|
||||
responses_store: ResponsesStore,
|
||||
vector_io_api: VectorIO, # VectorIO
|
||||
safety_api: Safety,
|
||||
safety_api: Safety | None,
|
||||
conversations_api: Conversations,
|
||||
):
|
||||
self.inference_api = inference_api
|
||||
|
|
@ -273,6 +273,14 @@ class OpenAIResponsesImpl:
|
|||
|
||||
guardrail_ids = extract_guardrail_ids(guardrails) if guardrails else []
|
||||
|
||||
# Validate that Safety API is available if guardrails are requested
|
||||
if guardrail_ids and self.safety_api is None:
|
||||
raise ValueError(
|
||||
"Cannot process guardrails: Safety API is not configured.\n\n"
|
||||
"To use guardrails, ensure the Safety API is configured in your stack, or remove "
|
||||
"the 'guardrails' parameter from your request."
|
||||
)
|
||||
|
||||
if conversation is not None:
|
||||
if previous_response_id is not None:
|
||||
raise ValueError(
|
||||
|
|
|
|||
|
|
@ -66,6 +66,7 @@ from llama_stack_api import (
|
|||
OpenAIResponseUsage,
|
||||
OpenAIResponseUsageInputTokensDetails,
|
||||
OpenAIResponseUsageOutputTokensDetails,
|
||||
Safety,
|
||||
WebSearchToolTypes,
|
||||
)
|
||||
|
||||
|
|
@ -111,7 +112,7 @@ class StreamingResponseOrchestrator:
|
|||
max_infer_iters: int,
|
||||
tool_executor, # Will be the tool execution logic from the main class
|
||||
instructions: str | None,
|
||||
safety_api,
|
||||
safety_api: Safety | None,
|
||||
guardrail_ids: list[str] | None = None,
|
||||
prompt: OpenAIResponsePrompt | None = None,
|
||||
parallel_tool_calls: bool | None = None,
|
||||
|
|
|
|||
|
|
@ -320,11 +320,15 @@ def is_function_tool_call(
|
|||
return False
|
||||
|
||||
|
||||
async def run_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[str]) -> str | None:
|
||||
async def run_guardrails(safety_api: Safety | None, messages: str, guardrail_ids: list[str]) -> str | None:
|
||||
"""Run guardrails against messages and return violation message if blocked."""
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
# If safety API is not available, skip guardrails
|
||||
if safety_api is None:
|
||||
return None
|
||||
|
||||
# Look up shields to get their provider_resource_id (actual model ID)
|
||||
model_ids = []
|
||||
# TODO: list_shields not in Safety interface but available at runtime via API routing
|
||||
|
|
|
|||
|
|
@ -30,12 +30,14 @@ def available_providers() -> list[ProviderSpec]:
|
|||
config_class="llama_stack.providers.inline.agents.meta_reference.MetaReferenceAgentsImplConfig",
|
||||
api_dependencies=[
|
||||
Api.inference,
|
||||
Api.safety,
|
||||
Api.vector_io,
|
||||
Api.tool_runtime,
|
||||
Api.tool_groups,
|
||||
Api.conversations,
|
||||
],
|
||||
optional_api_dependencies=[
|
||||
Api.safety,
|
||||
],
|
||||
description="Meta's reference implementation of an agent system that can use tools, access vector databases, and perform complex reasoning tasks.",
|
||||
),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,206 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""Tests for making Safety API optional in meta-reference agents provider.
|
||||
|
||||
This test suite validates the changes introduced to fix issue #4165, which
|
||||
allows running the meta-reference agents provider without the Safety API.
|
||||
Safety API is now an optional dependency, and errors are raised at request time
|
||||
when guardrails are explicitly requested without Safety API configured.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.core.datatypes import Api
|
||||
from llama_stack.core.storage.datatypes import KVStoreReference, ResponsesStoreReference
|
||||
from llama_stack.providers.inline.agents.meta_reference import get_provider_impl
|
||||
from llama_stack.providers.inline.agents.meta_reference.config import (
|
||||
AgentPersistenceConfig,
|
||||
MetaReferenceAgentsImplConfig,
|
||||
)
|
||||
from llama_stack.providers.inline.agents.meta_reference.responses.utils import (
|
||||
run_guardrails,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_persistence_config():
|
||||
"""Create a mock persistence configuration."""
|
||||
return AgentPersistenceConfig(
|
||||
agent_state=KVStoreReference(
|
||||
backend="kv_default",
|
||||
namespace="agents",
|
||||
),
|
||||
responses=ResponsesStoreReference(
|
||||
backend="sql_default",
|
||||
table_name="responses",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_deps():
|
||||
"""Create mock dependencies for the agents provider."""
|
||||
# Create mock APIs
|
||||
inference_api = AsyncMock()
|
||||
vector_io_api = AsyncMock()
|
||||
tool_runtime_api = AsyncMock()
|
||||
tool_groups_api = AsyncMock()
|
||||
conversations_api = AsyncMock()
|
||||
|
||||
return {
|
||||
Api.inference: inference_api,
|
||||
Api.vector_io: vector_io_api,
|
||||
Api.tool_runtime: tool_runtime_api,
|
||||
Api.tool_groups: tool_groups_api,
|
||||
Api.conversations: conversations_api,
|
||||
}
|
||||
|
||||
|
||||
class TestProviderInitialization:
|
||||
"""Test provider initialization with different safety API configurations."""
|
||||
|
||||
async def test_initialization_with_safety_api_present(self, mock_persistence_config, mock_deps):
|
||||
"""Test successful initialization when Safety API is configured."""
|
||||
config = MetaReferenceAgentsImplConfig(persistence=mock_persistence_config)
|
||||
|
||||
# Add safety API to deps
|
||||
safety_api = AsyncMock()
|
||||
mock_deps[Api.safety] = safety_api
|
||||
|
||||
# Mock the initialize method to avoid actual initialization
|
||||
with patch(
|
||||
"llama_stack.providers.inline.agents.meta_reference.agents.MetaReferenceAgentsImpl.initialize",
|
||||
new_callable=AsyncMock,
|
||||
):
|
||||
# Should not raise any exception
|
||||
provider = await get_provider_impl(config, mock_deps, policy=[], telemetry_enabled=False)
|
||||
assert provider is not None
|
||||
|
||||
async def test_initialization_without_safety_api(self, mock_persistence_config, mock_deps):
|
||||
"""Test successful initialization when Safety API is not configured."""
|
||||
config = MetaReferenceAgentsImplConfig(persistence=mock_persistence_config)
|
||||
|
||||
# Safety API is NOT in mock_deps - provider should still start
|
||||
# Mock the initialize method to avoid actual initialization
|
||||
with patch(
|
||||
"llama_stack.providers.inline.agents.meta_reference.agents.MetaReferenceAgentsImpl.initialize",
|
||||
new_callable=AsyncMock,
|
||||
):
|
||||
# Should not raise any exception
|
||||
provider = await get_provider_impl(config, mock_deps, policy=[], telemetry_enabled=False)
|
||||
assert provider is not None
|
||||
assert provider.safety_api is None
|
||||
|
||||
|
||||
class TestGuardrailsFunctionality:
|
||||
"""Test run_guardrails function with optional safety API."""
|
||||
|
||||
async def test_run_guardrails_with_none_safety_api(self):
|
||||
"""Test that run_guardrails returns None when safety_api is None."""
|
||||
result = await run_guardrails(safety_api=None, messages="test message", guardrail_ids=["llama-guard"])
|
||||
assert result is None
|
||||
|
||||
async def test_run_guardrails_with_empty_messages(self):
|
||||
"""Test that run_guardrails returns None for empty messages."""
|
||||
# Test with None safety API
|
||||
result = await run_guardrails(safety_api=None, messages="", guardrail_ids=["llama-guard"])
|
||||
assert result is None
|
||||
|
||||
# Test with mock safety API
|
||||
mock_safety_api = AsyncMock()
|
||||
result = await run_guardrails(safety_api=mock_safety_api, messages="", guardrail_ids=["llama-guard"])
|
||||
assert result is None
|
||||
|
||||
async def test_run_guardrails_with_none_safety_api_ignores_guardrails(self):
|
||||
"""Test that guardrails are skipped when safety_api is None, even if guardrail_ids are provided."""
|
||||
# Should not raise exception, just return None
|
||||
result = await run_guardrails(
|
||||
safety_api=None,
|
||||
messages="potentially harmful content",
|
||||
guardrail_ids=["llama-guard", "content-filter"],
|
||||
)
|
||||
assert result is None
|
||||
|
||||
async def test_create_response_rejects_guardrails_without_safety_api(self, mock_persistence_config, mock_deps):
|
||||
"""Test that create_openai_response raises error when guardrails requested but Safety API unavailable."""
|
||||
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
|
||||
OpenAIResponsesImpl,
|
||||
)
|
||||
from llama_stack_api import ResponseGuardrailSpec
|
||||
|
||||
# Create OpenAIResponsesImpl with no safety API
|
||||
with patch("llama_stack.providers.inline.agents.meta_reference.responses.openai_responses.ResponsesStore"):
|
||||
impl = OpenAIResponsesImpl(
|
||||
inference_api=mock_deps[Api.inference],
|
||||
tool_groups_api=mock_deps[Api.tool_groups],
|
||||
tool_runtime_api=mock_deps[Api.tool_runtime],
|
||||
responses_store=MagicMock(),
|
||||
vector_io_api=mock_deps[Api.vector_io],
|
||||
safety_api=None, # No Safety API
|
||||
conversations_api=mock_deps[Api.conversations],
|
||||
)
|
||||
|
||||
# Test with string guardrail
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await impl.create_openai_response(
|
||||
input="test input",
|
||||
model="test-model",
|
||||
guardrails=["llama-guard"],
|
||||
)
|
||||
assert "Cannot process guardrails: Safety API is not configured" in str(exc_info.value)
|
||||
|
||||
# Test with ResponseGuardrailSpec
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await impl.create_openai_response(
|
||||
input="test input",
|
||||
model="test-model",
|
||||
guardrails=[ResponseGuardrailSpec(type="llama-guard")],
|
||||
)
|
||||
assert "Cannot process guardrails: Safety API is not configured" in str(exc_info.value)
|
||||
|
||||
async def test_create_response_succeeds_without_guardrails_and_no_safety_api(
|
||||
self, mock_persistence_config, mock_deps
|
||||
):
|
||||
"""Test that create_openai_response works when no guardrails requested and Safety API unavailable."""
|
||||
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
|
||||
OpenAIResponsesImpl,
|
||||
)
|
||||
|
||||
# Create OpenAIResponsesImpl with no safety API
|
||||
with (
|
||||
patch("llama_stack.providers.inline.agents.meta_reference.responses.openai_responses.ResponsesStore"),
|
||||
patch.object(OpenAIResponsesImpl, "_create_streaming_response", new_callable=AsyncMock) as mock_stream,
|
||||
):
|
||||
# Mock the streaming response to return a simple async generator
|
||||
async def mock_generator():
|
||||
yield MagicMock()
|
||||
|
||||
mock_stream.return_value = mock_generator()
|
||||
|
||||
impl = OpenAIResponsesImpl(
|
||||
inference_api=mock_deps[Api.inference],
|
||||
tool_groups_api=mock_deps[Api.tool_groups],
|
||||
tool_runtime_api=mock_deps[Api.tool_runtime],
|
||||
responses_store=MagicMock(),
|
||||
vector_io_api=mock_deps[Api.vector_io],
|
||||
safety_api=None, # No Safety API
|
||||
conversations_api=mock_deps[Api.conversations],
|
||||
)
|
||||
|
||||
# Should not raise when no guardrails requested
|
||||
# Note: This will still fail later in execution due to mocking, but should pass the validation
|
||||
try:
|
||||
await impl.create_openai_response(
|
||||
input="test input",
|
||||
model="test-model",
|
||||
guardrails=None, # No guardrails
|
||||
)
|
||||
except Exception as e:
|
||||
# Ensure the error is NOT about missing Safety API
|
||||
assert "Cannot process guardrails: Safety API is not configured" not in str(e)
|
||||
Loading…
Add table
Add a link
Reference in a new issue