mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-19 15:38:52 +00:00
feat: Add responses and safety impl extra_body (#3781)
Some checks failed
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 0s
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 0s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 2s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Python Package Build Test / build (3.13) (push) Failing after 1s
Test Llama Stack Build / generate-matrix (push) Successful in 3s
Integration Tests (Replay) / Integration Tests (, , , client=, ) (push) Failing after 6s
Test Llama Stack Build / build-custom-container-distribution (push) Failing after 3s
Test Llama Stack Build / build-single-provider (push) Failing after 4s
Python Package Build Test / build (3.12) (push) Failing after 6s
Vector IO Integration Tests / test-matrix (push) Failing after 9s
Unit Tests / unit-tests (3.13) (push) Failing after 6s
Test Llama Stack Build / build-ubi9-container-distribution (push) Failing after 9s
Test External API and Providers / test-external (venv) (push) Failing after 8s
Test Llama Stack Build / build (push) Failing after 7s
Unit Tests / unit-tests (3.12) (push) Failing after 9s
API Conformance Tests / check-schema-compatibility (push) Successful in 19s
UI Tests / ui-tests (22) (push) Successful in 37s
Pre-commit / pre-commit (push) Successful in 1m33s
Some checks failed
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 0s
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 0s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 2s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Python Package Build Test / build (3.13) (push) Failing after 1s
Test Llama Stack Build / generate-matrix (push) Successful in 3s
Integration Tests (Replay) / Integration Tests (, , , client=, ) (push) Failing after 6s
Test Llama Stack Build / build-custom-container-distribution (push) Failing after 3s
Test Llama Stack Build / build-single-provider (push) Failing after 4s
Python Package Build Test / build (3.12) (push) Failing after 6s
Vector IO Integration Tests / test-matrix (push) Failing after 9s
Unit Tests / unit-tests (3.13) (push) Failing after 6s
Test Llama Stack Build / build-ubi9-container-distribution (push) Failing after 9s
Test External API and Providers / test-external (venv) (push) Failing after 8s
Test Llama Stack Build / build (push) Failing after 7s
Unit Tests / unit-tests (3.12) (push) Failing after 9s
API Conformance Tests / check-schema-compatibility (push) Successful in 19s
UI Tests / ui-tests (22) (push) Successful in 37s
Pre-commit / pre-commit (push) Successful in 1m33s
# What does this PR do? Have closed the previous PR due to merge conflicts with multiple PRs Addressed all comments from https://github.com/llamastack/llama-stack/pull/3768 (sorry for carrying over to this one) ## Test Plan Added UTs and integration tests
This commit is contained in:
parent
8e7e0ddfec
commit
99141c29b1
244 changed files with 36829 additions and 235 deletions
|
@ -91,6 +91,12 @@ def mock_conversations_api():
|
|||
return mock_api
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_safety_api():
|
||||
safety_api = AsyncMock()
|
||||
return safety_api
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def openai_responses_impl(
|
||||
mock_inference_api,
|
||||
|
@ -98,6 +104,7 @@ def openai_responses_impl(
|
|||
mock_tool_runtime_api,
|
||||
mock_responses_store,
|
||||
mock_vector_io_api,
|
||||
mock_safety_api,
|
||||
mock_conversations_api,
|
||||
):
|
||||
return OpenAIResponsesImpl(
|
||||
|
@ -106,6 +113,7 @@ def openai_responses_impl(
|
|||
tool_runtime_api=mock_tool_runtime_api,
|
||||
responses_store=mock_responses_store,
|
||||
vector_io_api=mock_vector_io_api,
|
||||
safety_api=mock_safety_api,
|
||||
conversations_api=mock_conversations_api,
|
||||
)
|
||||
|
||||
|
|
|
@ -38,6 +38,7 @@ def responses_impl_with_conversations(
|
|||
mock_responses_store,
|
||||
mock_vector_io_api,
|
||||
mock_conversations_api,
|
||||
mock_safety_api,
|
||||
):
|
||||
"""Create OpenAIResponsesImpl instance with conversations API."""
|
||||
return OpenAIResponsesImpl(
|
||||
|
@ -47,6 +48,7 @@ def responses_impl_with_conversations(
|
|||
responses_store=mock_responses_store,
|
||||
vector_io_api=mock_vector_io_api,
|
||||
conversations_api=mock_conversations_api,
|
||||
safety_api=mock_safety_api,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,155 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.agents.agents import ResponseGuardrailSpec
|
||||
from llama_stack.apis.safety import ModerationObject, ModerationObjectResults
|
||||
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
|
||||
OpenAIResponsesImpl,
|
||||
)
|
||||
from llama_stack.providers.inline.agents.meta_reference.responses.utils import (
|
||||
extract_guardrail_ids,
|
||||
run_guardrails,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_apis():
|
||||
"""Create mock APIs for testing."""
|
||||
return {
|
||||
"inference_api": AsyncMock(),
|
||||
"tool_groups_api": AsyncMock(),
|
||||
"tool_runtime_api": AsyncMock(),
|
||||
"responses_store": AsyncMock(),
|
||||
"vector_io_api": AsyncMock(),
|
||||
"conversations_api": AsyncMock(),
|
||||
"safety_api": AsyncMock(),
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def responses_impl(mock_apis):
|
||||
"""Create OpenAIResponsesImpl instance with mocked dependencies."""
|
||||
return OpenAIResponsesImpl(**mock_apis)
|
||||
|
||||
|
||||
def test_extract_guardrail_ids_from_strings(responses_impl):
|
||||
"""Test extraction from simple string guardrail IDs."""
|
||||
guardrails = ["llama-guard", "content-filter", "nsfw-detector"]
|
||||
result = extract_guardrail_ids(guardrails)
|
||||
assert result == ["llama-guard", "content-filter", "nsfw-detector"]
|
||||
|
||||
|
||||
def test_extract_guardrail_ids_from_objects(responses_impl):
|
||||
"""Test extraction from ResponseGuardrailSpec objects."""
|
||||
guardrails = [
|
||||
ResponseGuardrailSpec(type="llama-guard"),
|
||||
ResponseGuardrailSpec(type="content-filter"),
|
||||
]
|
||||
result = extract_guardrail_ids(guardrails)
|
||||
assert result == ["llama-guard", "content-filter"]
|
||||
|
||||
|
||||
def test_extract_guardrail_ids_mixed_formats(responses_impl):
|
||||
"""Test extraction from mixed string and object formats."""
|
||||
guardrails = [
|
||||
"llama-guard",
|
||||
ResponseGuardrailSpec(type="content-filter"),
|
||||
"nsfw-detector",
|
||||
]
|
||||
result = extract_guardrail_ids(guardrails)
|
||||
assert result == ["llama-guard", "content-filter", "nsfw-detector"]
|
||||
|
||||
|
||||
def test_extract_guardrail_ids_none_input(responses_impl):
|
||||
"""Test extraction with None input."""
|
||||
result = extract_guardrail_ids(None)
|
||||
assert result == []
|
||||
|
||||
|
||||
def test_extract_guardrail_ids_empty_list(responses_impl):
|
||||
"""Test extraction with empty list."""
|
||||
result = extract_guardrail_ids([])
|
||||
assert result == []
|
||||
|
||||
|
||||
def test_extract_guardrail_ids_unknown_format(responses_impl):
|
||||
"""Test extraction with unknown guardrail format raises ValueError."""
|
||||
# Create an object that's neither string nor ResponseGuardrailSpec
|
||||
unknown_object = {"invalid": "format"} # Plain dict, not ResponseGuardrailSpec
|
||||
guardrails = ["valid-guardrail", unknown_object, "another-guardrail"]
|
||||
with pytest.raises(ValueError, match="Unknown guardrail format.*expected str or ResponseGuardrailSpec"):
|
||||
extract_guardrail_ids(guardrails)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_safety_api():
|
||||
"""Create mock safety API for guardrails testing."""
|
||||
safety_api = AsyncMock()
|
||||
# Mock the routing table and shields list for guardrails lookup
|
||||
safety_api.routing_table = AsyncMock()
|
||||
shield = AsyncMock()
|
||||
shield.identifier = "llama-guard"
|
||||
shield.provider_resource_id = "llama-guard-model"
|
||||
safety_api.routing_table.list_shields.return_value = AsyncMock(data=[shield])
|
||||
return safety_api
|
||||
|
||||
|
||||
async def test_run_guardrails_no_violation(mock_safety_api):
|
||||
"""Test guardrails validation with no violations."""
|
||||
text = "Hello world"
|
||||
guardrail_ids = ["llama-guard"]
|
||||
|
||||
# Mock moderation to return non-flagged content
|
||||
unflagged_result = ModerationObjectResults(flagged=False, categories={"violence": False})
|
||||
mock_moderation_object = ModerationObject(id="test-mod-id", model="llama-guard-model", results=[unflagged_result])
|
||||
mock_safety_api.run_moderation.return_value = mock_moderation_object
|
||||
|
||||
result = await run_guardrails(mock_safety_api, text, guardrail_ids)
|
||||
|
||||
assert result is None
|
||||
# Verify run_moderation was called with the correct model
|
||||
mock_safety_api.run_moderation.assert_called_once()
|
||||
call_args = mock_safety_api.run_moderation.call_args
|
||||
assert call_args[1]["model"] == "llama-guard-model"
|
||||
|
||||
|
||||
async def test_run_guardrails_with_violation(mock_safety_api):
|
||||
"""Test guardrails validation with safety violation."""
|
||||
text = "Harmful content"
|
||||
guardrail_ids = ["llama-guard"]
|
||||
|
||||
# Mock moderation to return flagged content
|
||||
flagged_result = ModerationObjectResults(
|
||||
flagged=True,
|
||||
categories={"violence": True},
|
||||
user_message="Content flagged by moderation",
|
||||
metadata={"violation_type": ["S1"]},
|
||||
)
|
||||
mock_moderation_object = ModerationObject(id="test-mod-id", model="llama-guard-model", results=[flagged_result])
|
||||
mock_safety_api.run_moderation.return_value = mock_moderation_object
|
||||
|
||||
result = await run_guardrails(mock_safety_api, text, guardrail_ids)
|
||||
|
||||
assert result == "Content flagged by moderation (flagged for: violence) (violation type: S1)"
|
||||
|
||||
|
||||
async def test_run_guardrails_empty_inputs(mock_safety_api):
|
||||
"""Test guardrails validation with empty inputs."""
|
||||
# Test empty guardrail_ids
|
||||
result = await run_guardrails(mock_safety_api, "test", [])
|
||||
assert result is None
|
||||
|
||||
# Test empty text
|
||||
result = await run_guardrails(mock_safety_api, "", ["llama-guard"])
|
||||
assert result is None
|
||||
|
||||
# Test both empty
|
||||
result = await run_guardrails(mock_safety_api, "", [])
|
||||
assert result is None
|
Loading…
Add table
Add a link
Reference in a new issue