From 06dcfd1915ee098320e666fbb080e288d9c87093 Mon Sep 17 00:00:00 2001
From: Swapna Lekkala
Date: Mon, 13 Oct 2025 13:41:35 -0700
Subject: [PATCH] fix tests and remove unwanted changes

---
 docs/static/deprecated-llama-stack-spec.html  |  4 +-
 docs/static/deprecated-llama-stack-spec.yaml  |  4 +-
 .../static/experimental-llama-stack-spec.html |  4 +-
 .../static/experimental-llama-stack-spec.yaml |  4 +-
 docs/static/stainless-llama-stack-spec.html   |  4 +-
 docs/static/stainless-llama-stack-spec.yaml   |  4 +-
 llama_stack/apis/agents/agents.py             |  4 +-
 .../responses/test_streaming.py               | 38 +++++++++++--------
 8 files changed, 36 insertions(+), 30 deletions(-)

diff --git a/docs/static/deprecated-llama-stack-spec.html b/docs/static/deprecated-llama-stack-spec.html
index 426312a2f..e5c02381b 100644
--- a/docs/static/deprecated-llama-stack-spec.html
+++ b/docs/static/deprecated-llama-stack-spec.html
@@ -3615,13 +3615,13 @@
                 "sampling_params": {
                     "$ref": "#/components/schemas/SamplingParams"
                 },
-                "input_guardrails": {
+                "input_shields": {
                     "type": "array",
                     "items": {
                         "type": "string"
                     }
                 },
-                "output_guardrails": {
+                "output_shields": {
                     "type": "array",
                     "items": {
                         "type": "string"
diff --git a/docs/static/deprecated-llama-stack-spec.yaml b/docs/static/deprecated-llama-stack-spec.yaml
index 2da7af8aa..43f748d14 100644
--- a/docs/static/deprecated-llama-stack-spec.yaml
+++ b/docs/static/deprecated-llama-stack-spec.yaml
@@ -2667,11 +2667,11 @@ components:
     properties:
       sampling_params:
         $ref: '#/components/schemas/SamplingParams'
-      input_guardrails:
+      input_shields:
         type: array
         items:
           type: string
-      output_guardrails:
+      output_shields:
         type: array
         items:
           type: string
diff --git a/docs/static/experimental-llama-stack-spec.html b/docs/static/experimental-llama-stack-spec.html
index faac1b91e..e3edf2ffc 100644
--- a/docs/static/experimental-llama-stack-spec.html
+++ b/docs/static/experimental-llama-stack-spec.html
@@ -2090,13 +2090,13 @@
                 "sampling_params": {
                     "$ref": "#/components/schemas/SamplingParams"
                 },
-                "input_guardrails": {
+                "input_shields": {
                     "type": "array",
                     "items": {
                         "type": "string"
                     }
                 },
-                "output_guardrails": {
+                "output_shields": {
                     "type": "array",
                     "items": {
                         "type": "string"
diff --git a/docs/static/experimental-llama-stack-spec.yaml b/docs/static/experimental-llama-stack-spec.yaml
index 68d85f1df..7ee5a6cdf 100644
--- a/docs/static/experimental-llama-stack-spec.yaml
+++ b/docs/static/experimental-llama-stack-spec.yaml
@@ -1500,11 +1500,11 @@ components:
     properties:
      sampling_params:
        $ref: '#/components/schemas/SamplingParams'
-      input_guardrails:
+      input_shields:
        type: array
        items:
          type: string
-      output_guardrails:
+      output_shields:
        type: array
        items:
          type: string
diff --git a/docs/static/stainless-llama-stack-spec.html b/docs/static/stainless-llama-stack-spec.html
index 78e32ab2b..4b3b22886 100644
--- a/docs/static/stainless-llama-stack-spec.html
+++ b/docs/static/stainless-llama-stack-spec.html
@@ -15192,13 +15192,13 @@
                 "sampling_params": {
                     "$ref": "#/components/schemas/SamplingParams"
                 },
-                "input_guardrails": {
+                "input_shields": {
                     "type": "array",
                     "items": {
                         "type": "string"
                     }
                 },
-                "output_guardrails": {
+                "output_shields": {
                     "type": "array",
                     "items": {
                         "type": "string"
diff --git a/docs/static/stainless-llama-stack-spec.yaml b/docs/static/stainless-llama-stack-spec.yaml
index 6d7faa3d5..2e3d59ae0 100644
--- a/docs/static/stainless-llama-stack-spec.yaml
+++ b/docs/static/stainless-llama-stack-spec.yaml
@@ -11478,11 +11478,11 @@ components:
     properties:
      sampling_params:
        $ref: '#/components/schemas/SamplingParams'
-      input_guardrails:
+      input_shields:
        type: array
        items:
          type: string
-      output_guardrails:
+      output_shields:
        type: array
        items:
          type: string
diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py
index bf6fe507c..6ad45cf99 100644
--- a/llama_stack/apis/agents/agents.py
+++ b/llama_stack/apis/agents/agents.py
@@ -218,8 +218,8 @@ register_schema(AgentToolGroup, name="AgentTool")

 class AgentConfigCommon(BaseModel):
     sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)
-    input_guardrails: list[str] | None = Field(default_factory=lambda: [])
-    output_guardrails: list[str] | None = Field(default_factory=lambda: [])
+    input_shields: list[str] | None = Field(default_factory=lambda: [])
+    output_shields: list[str] | None = Field(default_factory=lambda: [])
     toolgroups: list[AgentToolGroup] | None = Field(default_factory=lambda: [])
     client_tools: list[ToolDef] | None = Field(default_factory=lambda: [])
     tool_choice: ToolChoice | None = Field(default=None, deprecated="use tool_config instead")
diff --git a/tests/unit/providers/inline/agents/meta_reference/responses/test_streaming.py b/tests/unit/providers/inline/agents/meta_reference/responses/test_streaming.py
index 343ebc0b3..8f657eed3 100644
--- a/tests/unit/providers/inline/agents/meta_reference/responses/test_streaming.py
+++ b/tests/unit/providers/inline/agents/meta_reference/responses/test_streaming.py
@@ -13,7 +13,6 @@ from llama_stack.apis.agents.openai_responses import (
     OpenAIResponseText,
 )
 from llama_stack.apis.inference import UserMessage
-from llama_stack.apis.safety import SafetyViolation, ViolationLevel
 from llama_stack.apis.tools import ToolDef
 from llama_stack.providers.inline.agents.meta_reference.responses.streaming import (
     StreamingResponseOrchestrator,
@@ -25,6 +24,14 @@ from llama_stack.providers.inline.agents.meta_reference.responses.types import C
 @pytest.fixture
 def mock_safety_api():
     safety_api = AsyncMock()
+    # Mock the routing table and shields list for guardrails lookup
+    safety_api.routing_table = AsyncMock()
+    shield = AsyncMock()
+    shield.identifier = "llama-guard"
+    shield.provider_resource_id = "llama-guard-model"
+    safety_api.routing_table.list_shields.return_value = AsyncMock(data=[shield])
+    # Mock run_moderation to return non-flagged result by default
+    safety_api.run_moderation.return_value = AsyncMock(flagged=False)
     return safety_api


@@ -75,7 +82,7 @@ def test_convert_tooldef_to_chat_tool_preserves_items_field():
 async def test_check_input_safety_no_violation(mock_safety_api, mock_inference_api, mock_context):
     """Test input shield validation with no violations."""
     messages = [UserMessage(content="Hello world")]
-    shield_ids = ["llama-guard"]
+    guardrail_ids = ["llama-guard"]

     # Mock successful shield validation (no violation)
     mock_response = AsyncMock()
@@ -92,27 +99,26 @@
         max_infer_iters=5,
         tool_executor=AsyncMock(),
         safety_api=mock_safety_api,
-        shield_ids=shield_ids,
+        guardrail_ids=guardrail_ids,
     )

     result = await orchestrator._check_input_safety(messages)

     assert result is None
-    mock_safety_api.run_shield.assert_called_once_with(shield_id="llama-guard", messages=messages, params={})
+    # Verify run_moderation was called with the correct model
+    mock_safety_api.run_moderation.assert_called_once()
+    # Get the actual call arguments
+    call_args = mock_safety_api.run_moderation.call_args
+    assert call_args[1]["model"] == "llama-guard-model"  # The provider_resource_id from our mock


 async def test_check_input_safety_with_violation(mock_safety_api, mock_inference_api, mock_context):
     """Test input shield validation with safety violation."""
     messages = [UserMessage(content="Harmful content")]
-    shield_ids = ["llama-guard"]
+    guardrail_ids = ["llama-guard"]

-    # Mock shield violation
-    violation = SafetyViolation(
-        violation_level=ViolationLevel.ERROR, user_message="Content violates safety guidelines", metadata={}
-    )
-    mock_response = AsyncMock()
-    mock_response.violation = violation
-    mock_safety_api.run_shield.return_value = mock_response
+    # Mock moderation to return flagged content
+    mock_safety_api.run_moderation.return_value = AsyncMock(flagged=True, categories={"violence": True})

     # Create orchestrator with safety components
     orchestrator = StreamingResponseOrchestrator(
@@ -124,13 +130,13 @@ async def test_check_input_safety_with_violation(mock_safety_api, mock_inference
         max_infer_iters=5,
         tool_executor=AsyncMock(),
         safety_api=mock_safety_api,
-        shield_ids=shield_ids,
+        guardrail_ids=guardrail_ids,
     )

     result = await orchestrator._check_input_safety(messages)

     assert isinstance(result, OpenAIResponseContentPartRefusal)
-    assert result.refusal == "Content violates safety guidelines"
+    assert result.refusal == "Content flagged by moderation"


 async def test_check_input_safety_empty_inputs(mock_safety_api, mock_inference_api, mock_context):
@@ -145,7 +151,7 @@ async def test_check_input_safety_empty_inputs(mock_safety_api, mock_inference_a
         max_infer_iters=5,
         tool_executor=AsyncMock(),
         safety_api=mock_safety_api,
-        shield_ids=[],
+        guardrail_ids=[],
     )

     # Test empty shield_ids
@@ -152,7 +158,7 @@ async def test_check_input_safety_empty_inputs(mock_safety_api, mock_inference_a
     result = await orchestrator._check_input_safety([UserMessage(content="Hello")])
     assert result is None

     # Test empty messages
-    orchestrator.shield_ids = ["llama-guard"]
+    orchestrator.guardrail_ids = ["llama-guard"]
     result = await orchestrator._check_input_safety([])
     assert result is None
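
Reviewer note, not part of the patch: the updated tests pin down the new guardrail flow end to end. Below is a minimal sketch of that flow for orientation, assuming only the shapes the mocks imply (routing_table.list_shields(), a shield's provider_resource_id, run_moderation called with model=..., and a flagged result becoming an OpenAIResponseContentPartRefusal). The helper name, the input= keyword, and the import path are illustrative guesses, not the orchestrator's actual implementation.

# Hypothetical sketch of the guardrail check the tests above exercise; shapes
# are inferred from the mocks, not copied from StreamingResponseOrchestrator.
from llama_stack.apis.agents.openai_responses import OpenAIResponseContentPartRefusal


async def check_input_safety_sketch(safety_api, guardrail_ids, messages):
    # No guardrails or no messages short-circuits to "no refusal"
    # (see test_check_input_safety_empty_inputs).
    if not guardrail_ids or not messages:
        return None

    # Resolve guardrail ids to registered shields via the safety routing table
    # (the fixture mocks safety_api.routing_table.list_shields for this).
    shields = (await safety_api.routing_table.list_shields()).data
    shields_by_id = {shield.identifier: shield for shield in shields}

    text = " ".join(str(message.content) for message in messages)
    for guardrail_id in guardrail_ids:
        shield = shields_by_id[guardrail_id]
        # run_moderation receives the shield's provider_resource_id as the
        # model (test_check_input_safety_no_violation asserts exactly this).
        result = await safety_api.run_moderation(input=text, model=shield.provider_resource_id)
        if result.flagged:
            # Refusal text matches the assertion in
            # test_check_input_safety_with_violation.
            return OpenAIResponseContentPartRefusal(refusal="Content flagged by moderation")
    return None

Resolving guardrail_ids through the routing table keeps the orchestrator parameter a plain list of identifiers, while each shield's provider_resource_id selects the underlying moderation model.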