fix tests and remove unwanted changes

This commit is contained in:
Swapna Lekkala 2025-10-13 13:41:35 -07:00
parent c10db23d7a
commit 06dcfd1915
8 changed files with 36 additions and 30 deletions

View file

@ -3615,13 +3615,13 @@
"sampling_params": {
"$ref": "#/components/schemas/SamplingParams"
},
"input_guardrails": {
"input_shields": {
"type": "array",
"items": {
"type": "string"
}
},
"output_guardrails": {
"output_shields": {
"type": "array",
"items": {
"type": "string"

View file

@ -2667,11 +2667,11 @@ components:
properties:
sampling_params:
$ref: '#/components/schemas/SamplingParams'
input_guardrails:
input_shields:
type: array
items:
type: string
output_guardrails:
output_shields:
type: array
items:
type: string

View file

@ -2090,13 +2090,13 @@
"sampling_params": {
"$ref": "#/components/schemas/SamplingParams"
},
"input_guardrails": {
"input_shields": {
"type": "array",
"items": {
"type": "string"
}
},
"output_guardrails": {
"output_shields": {
"type": "array",
"items": {
"type": "string"

View file

@ -1500,11 +1500,11 @@ components:
properties:
sampling_params:
$ref: '#/components/schemas/SamplingParams'
input_guardrails:
input_shields:
type: array
items:
type: string
output_guardrails:
output_shields:
type: array
items:
type: string

View file

@ -15192,13 +15192,13 @@
"sampling_params": {
"$ref": "#/components/schemas/SamplingParams"
},
"input_guardrails": {
"input_shields": {
"type": "array",
"items": {
"type": "string"
}
},
"output_guardrails": {
"output_shields": {
"type": "array",
"items": {
"type": "string"

View file

@ -11478,11 +11478,11 @@ components:
properties:
sampling_params:
$ref: '#/components/schemas/SamplingParams'
input_guardrails:
input_shields:
type: array
items:
type: string
output_guardrails:
output_shields:
type: array
items:
type: string

View file

@ -218,8 +218,8 @@ register_schema(AgentToolGroup, name="AgentTool")
class AgentConfigCommon(BaseModel):
sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)
input_guardrails: list[str] | None = Field(default_factory=lambda: [])
output_guardrails: list[str] | None = Field(default_factory=lambda: [])
input_shields: list[str] | None = Field(default_factory=lambda: [])
output_shields: list[str] | None = Field(default_factory=lambda: [])
toolgroups: list[AgentToolGroup] | None = Field(default_factory=lambda: [])
client_tools: list[ToolDef] | None = Field(default_factory=lambda: [])
tool_choice: ToolChoice | None = Field(default=None, deprecated="use tool_config instead")

View file

@ -13,7 +13,6 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseText,
)
from llama_stack.apis.inference import UserMessage
from llama_stack.apis.safety import SafetyViolation, ViolationLevel
from llama_stack.apis.tools import ToolDef
from llama_stack.providers.inline.agents.meta_reference.responses.streaming import (
StreamingResponseOrchestrator,
@ -25,6 +24,14 @@ from llama_stack.providers.inline.agents.meta_reference.responses.types import C
@pytest.fixture
def mock_safety_api():
safety_api = AsyncMock()
# Mock the routing table and shields list for guardrails lookup
safety_api.routing_table = AsyncMock()
shield = AsyncMock()
shield.identifier = "llama-guard"
shield.provider_resource_id = "llama-guard-model"
safety_api.routing_table.list_shields.return_value = AsyncMock(data=[shield])
# Mock run_moderation to return non-flagged result by default
safety_api.run_moderation.return_value = AsyncMock(flagged=False)
return safety_api
@ -75,7 +82,7 @@ def test_convert_tooldef_to_chat_tool_preserves_items_field():
async def test_check_input_safety_no_violation(mock_safety_api, mock_inference_api, mock_context):
"""Test input shield validation with no violations."""
messages = [UserMessage(content="Hello world")]
shield_ids = ["llama-guard"]
guardrail_ids = ["llama-guard"]
# Mock successful shield validation (no violation)
mock_response = AsyncMock()
@ -92,27 +99,26 @@ async def test_check_input_safety_no_violation(mock_safety_api, mock_inference_a
max_infer_iters=5,
tool_executor=AsyncMock(),
safety_api=mock_safety_api,
shield_ids=shield_ids,
guardrail_ids=guardrail_ids,
)
result = await orchestrator._check_input_safety(messages)
assert result is None
mock_safety_api.run_shield.assert_called_once_with(shield_id="llama-guard", messages=messages, params={})
# Verify run_moderation was called with the correct model
mock_safety_api.run_moderation.assert_called_once()
# Get the actual call arguments
call_args = mock_safety_api.run_moderation.call_args
assert call_args[1]["model"] == "llama-guard-model" # The provider_resource_id from our mock
async def test_check_input_safety_with_violation(mock_safety_api, mock_inference_api, mock_context):
"""Test input shield validation with safety violation."""
messages = [UserMessage(content="Harmful content")]
shield_ids = ["llama-guard"]
guardrail_ids = ["llama-guard"]
# Mock shield violation
violation = SafetyViolation(
violation_level=ViolationLevel.ERROR, user_message="Content violates safety guidelines", metadata={}
)
mock_response = AsyncMock()
mock_response.violation = violation
mock_safety_api.run_shield.return_value = mock_response
# Mock moderation to return flagged content
mock_safety_api.run_moderation.return_value = AsyncMock(flagged=True, categories={"violence": True})
# Create orchestrator with safety components
orchestrator = StreamingResponseOrchestrator(
@ -124,13 +130,13 @@ async def test_check_input_safety_with_violation(mock_safety_api, mock_inference
max_infer_iters=5,
tool_executor=AsyncMock(),
safety_api=mock_safety_api,
shield_ids=shield_ids,
guardrail_ids=guardrail_ids,
)
result = await orchestrator._check_input_safety(messages)
assert isinstance(result, OpenAIResponseContentPartRefusal)
assert result.refusal == "Content violates safety guidelines"
assert result.refusal == "Content flagged by moderation"
async def test_check_input_safety_empty_inputs(mock_safety_api, mock_inference_api, mock_context):
@ -145,7 +151,7 @@ async def test_check_input_safety_empty_inputs(mock_safety_api, mock_inference_a
max_infer_iters=5,
tool_executor=AsyncMock(),
safety_api=mock_safety_api,
shield_ids=[],
guardrail_ids=[],
)
# Test empty shield_ids
@ -153,6 +159,6 @@ async def test_check_input_safety_empty_inputs(mock_safety_api, mock_inference_a
assert result is None
# Test empty messages
orchestrator.shield_ids = ["llama-guard"]
orchestrator.guardrail_ids = ["llama-guard"]
result = await orchestrator._check_input_safety([])
assert result is None