fix tests and remove unwanted changes

Swapna Lekkala 2025-10-13 13:41:35 -07:00
parent c10db23d7a
commit 06dcfd1915
8 changed files with 36 additions and 30 deletions


@@ -3615,13 +3615,13 @@
         "sampling_params": {
           "$ref": "#/components/schemas/SamplingParams"
         },
-        "input_guardrails": {
+        "input_shields": {
           "type": "array",
           "items": {
             "type": "string"
           }
         },
-        "output_guardrails": {
+        "output_shields": {
           "type": "array",
           "items": {
             "type": "string"


@@ -2667,11 +2667,11 @@ components:
       properties:
         sampling_params:
           $ref: '#/components/schemas/SamplingParams'
-        input_guardrails:
+        input_shields:
           type: array
           items:
             type: string
-        output_guardrails:
+        output_shields:
           type: array
           items:
             type: string


@@ -2090,13 +2090,13 @@
         "sampling_params": {
           "$ref": "#/components/schemas/SamplingParams"
         },
-        "input_guardrails": {
+        "input_shields": {
           "type": "array",
           "items": {
             "type": "string"
           }
         },
-        "output_guardrails": {
+        "output_shields": {
           "type": "array",
           "items": {
             "type": "string"


@@ -1500,11 +1500,11 @@ components:
       properties:
         sampling_params:
           $ref: '#/components/schemas/SamplingParams'
-        input_guardrails:
+        input_shields:
           type: array
           items:
             type: string
-        output_guardrails:
+        output_shields:
           type: array
           items:
             type: string


@@ -15192,13 +15192,13 @@
         "sampling_params": {
           "$ref": "#/components/schemas/SamplingParams"
         },
-        "input_guardrails": {
+        "input_shields": {
           "type": "array",
           "items": {
             "type": "string"
           }
         },
-        "output_guardrails": {
+        "output_shields": {
           "type": "array",
           "items": {
             "type": "string"


@@ -11478,11 +11478,11 @@ components:
       properties:
         sampling_params:
           $ref: '#/components/schemas/SamplingParams'
-        input_guardrails:
+        input_shields:
           type: array
           items:
             type: string
-        output_guardrails:
+        output_shields:
           type: array
           items:
             type: string


@@ -218,8 +218,8 @@ register_schema(AgentToolGroup, name="AgentTool")
 class AgentConfigCommon(BaseModel):
     sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)

-    input_guardrails: list[str] | None = Field(default_factory=lambda: [])
-    output_guardrails: list[str] | None = Field(default_factory=lambda: [])
+    input_shields: list[str] | None = Field(default_factory=lambda: [])
+    output_shields: list[str] | None = Field(default_factory=lambda: [])
     toolgroups: list[AgentToolGroup] | None = Field(default_factory=lambda: [])
     client_tools: list[ToolDef] | None = Field(default_factory=lambda: [])
     tool_choice: ToolChoice | None = Field(default=None, deprecated="use tool_config instead")
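
For reference, a minimal usage sketch of the restored field names (not part of this commit; the import path is assumed, and the remaining fields are assumed to keep the defaults shown above):

    from llama_stack.apis.agents import AgentConfigCommon

    # Only the renamed fields are set here; everything else falls back to its default.
    config = AgentConfigCommon(
        input_shields=["llama-guard"],   # previously input_guardrails
        output_shields=["llama-guard"],  # previously output_guardrails
    )

    # Serializing shows keys that match the OpenAPI hunks above:
    # {"input_shields": ["llama-guard"], "output_shields": ["llama-guard"]}
    print(config.model_dump_json(exclude_unset=True))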


@@ -13,7 +13,6 @@ from llama_stack.apis.agents.openai_responses import (
     OpenAIResponseText,
 )
 from llama_stack.apis.inference import UserMessage
-from llama_stack.apis.safety import SafetyViolation, ViolationLevel
 from llama_stack.apis.tools import ToolDef
 from llama_stack.providers.inline.agents.meta_reference.responses.streaming import (
     StreamingResponseOrchestrator,
@@ -25,6 +24,14 @@ from llama_stack.providers.inline.agents.meta_reference.responses.types import C
 @pytest.fixture
 def mock_safety_api():
     safety_api = AsyncMock()
+    # Mock the routing table and shields list for guardrails lookup
+    safety_api.routing_table = AsyncMock()
+    shield = AsyncMock()
+    shield.identifier = "llama-guard"
+    shield.provider_resource_id = "llama-guard-model"
+    safety_api.routing_table.list_shields.return_value = AsyncMock(data=[shield])
+    # Mock run_moderation to return non-flagged result by default
+    safety_api.run_moderation.return_value = AsyncMock(flagged=False)
     return safety_api
@@ -75,7 +82,7 @@ def test_convert_tooldef_to_chat_tool_preserves_items_field():
 async def test_check_input_safety_no_violation(mock_safety_api, mock_inference_api, mock_context):
     """Test input shield validation with no violations."""
     messages = [UserMessage(content="Hello world")]
-    shield_ids = ["llama-guard"]
+    guardrail_ids = ["llama-guard"]

     # Mock successful shield validation (no violation)
     mock_response = AsyncMock()
@@ -92,27 +99,26 @@ async def test_check_input_safety_no_violation(mock_safety_api, mock_inference_a
         max_infer_iters=5,
         tool_executor=AsyncMock(),
         safety_api=mock_safety_api,
-        shield_ids=shield_ids,
+        guardrail_ids=guardrail_ids,
     )

     result = await orchestrator._check_input_safety(messages)

     assert result is None
-    mock_safety_api.run_shield.assert_called_once_with(shield_id="llama-guard", messages=messages, params={})
+    # Verify run_moderation was called with the correct model
+    mock_safety_api.run_moderation.assert_called_once()
+    # Get the actual call arguments
+    call_args = mock_safety_api.run_moderation.call_args
+    assert call_args[1]["model"] == "llama-guard-model"  # The provider_resource_id from our mock


 async def test_check_input_safety_with_violation(mock_safety_api, mock_inference_api, mock_context):
     """Test input shield validation with safety violation."""
     messages = [UserMessage(content="Harmful content")]
-    shield_ids = ["llama-guard"]
+    guardrail_ids = ["llama-guard"]

-    # Mock shield violation
-    violation = SafetyViolation(
-        violation_level=ViolationLevel.ERROR, user_message="Content violates safety guidelines", metadata={}
-    )
-    mock_response = AsyncMock()
-    mock_response.violation = violation
-    mock_safety_api.run_shield.return_value = mock_response
+    # Mock moderation to return flagged content
+    mock_safety_api.run_moderation.return_value = AsyncMock(flagged=True, categories={"violence": True})

     # Create orchestrator with safety components
     orchestrator = StreamingResponseOrchestrator(
@@ -124,13 +130,13 @@ async def test_check_input_safety_with_violation(mock_safety_api, mock_inference
         max_infer_iters=5,
         tool_executor=AsyncMock(),
         safety_api=mock_safety_api,
-        shield_ids=shield_ids,
+        guardrail_ids=guardrail_ids,
     )

     result = await orchestrator._check_input_safety(messages)

     assert isinstance(result, OpenAIResponseContentPartRefusal)
-    assert result.refusal == "Content violates safety guidelines"
+    assert result.refusal == "Content flagged by moderation"


 async def test_check_input_safety_empty_inputs(mock_safety_api, mock_inference_api, mock_context):
@@ -145,7 +151,7 @@ async def test_check_input_safety_empty_inputs(mock_safety_api, mock_inference_a
         max_infer_iters=5,
         tool_executor=AsyncMock(),
         safety_api=mock_safety_api,
-        shield_ids=[],
+        guardrail_ids=[],
     )

     # Test empty shield_ids
@@ -153,6 +159,6 @@ async def test_check_input_safety_empty_inputs(mock_safety_api, mock_inference_a
     assert result is None

     # Test empty messages
-    orchestrator.shield_ids = ["llama-guard"]
+    orchestrator.guardrail_ids = ["llama-guard"]
     result = await orchestrator._check_input_safety([])
     assert result is None
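
Taken together, the fixture and assertions above imply the shape of the orchestrator's guardrail check. The following is an illustrative sketch inferred from these tests only, not the actual implementation; the function name, the `input=` argument shape, and the refusal construction are assumptions:

    from llama_stack.apis.agents.openai_responses import OpenAIResponseContentPartRefusal

    async def _check_input_safety_sketch(safety_api, guardrail_ids, messages):
        # Nothing configured or nothing to check.
        if not guardrail_ids or not messages:
            return None

        # Resolve each guardrail id to a registered shield so run_moderation can be
        # called with the shield's provider_resource_id as the moderation model.
        shields = (await safety_api.routing_table.list_shields()).data
        shields_by_id = {s.identifier: s for s in shields}

        for guardrail_id in guardrail_ids:
            shield = shields_by_id.get(guardrail_id)
            if shield is None:
                continue  # unknown guardrail id; the real code may raise instead
            result = await safety_api.run_moderation(
                input=[m.content for m in messages],
                model=shield.provider_resource_id,
            )
            # The tests mock the moderation result as an object with a `flagged` attribute.
            if getattr(result, "flagged", False):
                return OpenAIResponseContentPartRefusal(refusal="Content flagged by moderation")
        return None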