mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-14 23:12:46 +00:00
fix tests and remove unwanted changes
This commit is contained in:
parent
c10db23d7a
commit
06dcfd1915
8 changed files with 36 additions and 30 deletions
4
docs/static/deprecated-llama-stack-spec.html
vendored
4
docs/static/deprecated-llama-stack-spec.html
vendored
|
|
@ -3615,13 +3615,13 @@
|
|||
"sampling_params": {
|
||||
"$ref": "#/components/schemas/SamplingParams"
|
||||
},
|
||||
"input_guardrails": {
|
||||
"input_shields": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"output_guardrails": {
|
||||
"output_shields": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
|
|
|
|||
4
docs/static/deprecated-llama-stack-spec.yaml
vendored
4
docs/static/deprecated-llama-stack-spec.yaml
vendored
|
|
@ -2667,11 +2667,11 @@ components:
|
|||
properties:
|
||||
sampling_params:
|
||||
$ref: '#/components/schemas/SamplingParams'
|
||||
input_guardrails:
|
||||
input_shields:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
output_guardrails:
|
||||
output_shields:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
|
|
|
|||
|
|
@ -2090,13 +2090,13 @@
|
|||
"sampling_params": {
|
||||
"$ref": "#/components/schemas/SamplingParams"
|
||||
},
|
||||
"input_guardrails": {
|
||||
"input_shields": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"output_guardrails": {
|
||||
"output_shields": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
|
|
|
|||
|
|
@ -1500,11 +1500,11 @@ components:
|
|||
properties:
|
||||
sampling_params:
|
||||
$ref: '#/components/schemas/SamplingParams'
|
||||
input_guardrails:
|
||||
input_shields:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
output_guardrails:
|
||||
output_shields:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
|
|
|
|||
4
docs/static/stainless-llama-stack-spec.html
vendored
4
docs/static/stainless-llama-stack-spec.html
vendored
|
|
@ -15192,13 +15192,13 @@
|
|||
"sampling_params": {
|
||||
"$ref": "#/components/schemas/SamplingParams"
|
||||
},
|
||||
"input_guardrails": {
|
||||
"input_shields": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"output_guardrails": {
|
||||
"output_shields": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
|
|
|
|||
4
docs/static/stainless-llama-stack-spec.yaml
vendored
4
docs/static/stainless-llama-stack-spec.yaml
vendored
|
|
@ -11478,11 +11478,11 @@ components:
|
|||
properties:
|
||||
sampling_params:
|
||||
$ref: '#/components/schemas/SamplingParams'
|
||||
input_guardrails:
|
||||
input_shields:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
output_guardrails:
|
||||
output_shields:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
|
|
|
|||
|
|
@ -218,8 +218,8 @@ register_schema(AgentToolGroup, name="AgentTool")
|
|||
class AgentConfigCommon(BaseModel):
|
||||
sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)
|
||||
|
||||
input_guardrails: list[str] | None = Field(default_factory=lambda: [])
|
||||
output_guardrails: list[str] | None = Field(default_factory=lambda: [])
|
||||
input_shields: list[str] | None = Field(default_factory=lambda: [])
|
||||
output_shields: list[str] | None = Field(default_factory=lambda: [])
|
||||
toolgroups: list[AgentToolGroup] | None = Field(default_factory=lambda: [])
|
||||
client_tools: list[ToolDef] | None = Field(default_factory=lambda: [])
|
||||
tool_choice: ToolChoice | None = Field(default=None, deprecated="use tool_config instead")
|
||||
|
|
|
|||
|
|
@ -13,7 +13,6 @@ from llama_stack.apis.agents.openai_responses import (
|
|||
OpenAIResponseText,
|
||||
)
|
||||
from llama_stack.apis.inference import UserMessage
|
||||
from llama_stack.apis.safety import SafetyViolation, ViolationLevel
|
||||
from llama_stack.apis.tools import ToolDef
|
||||
from llama_stack.providers.inline.agents.meta_reference.responses.streaming import (
|
||||
StreamingResponseOrchestrator,
|
||||
|
|
@ -25,6 +24,14 @@ from llama_stack.providers.inline.agents.meta_reference.responses.types import C
|
|||
@pytest.fixture
|
||||
def mock_safety_api():
|
||||
safety_api = AsyncMock()
|
||||
# Mock the routing table and shields list for guardrails lookup
|
||||
safety_api.routing_table = AsyncMock()
|
||||
shield = AsyncMock()
|
||||
shield.identifier = "llama-guard"
|
||||
shield.provider_resource_id = "llama-guard-model"
|
||||
safety_api.routing_table.list_shields.return_value = AsyncMock(data=[shield])
|
||||
# Mock run_moderation to return non-flagged result by default
|
||||
safety_api.run_moderation.return_value = AsyncMock(flagged=False)
|
||||
return safety_api
|
||||
|
||||
|
||||
|
|
@ -75,7 +82,7 @@ def test_convert_tooldef_to_chat_tool_preserves_items_field():
|
|||
async def test_check_input_safety_no_violation(mock_safety_api, mock_inference_api, mock_context):
|
||||
"""Test input shield validation with no violations."""
|
||||
messages = [UserMessage(content="Hello world")]
|
||||
shield_ids = ["llama-guard"]
|
||||
guardrail_ids = ["llama-guard"]
|
||||
|
||||
# Mock successful shield validation (no violation)
|
||||
mock_response = AsyncMock()
|
||||
|
|
@ -92,27 +99,26 @@ async def test_check_input_safety_no_violation(mock_safety_api, mock_inference_a
|
|||
max_infer_iters=5,
|
||||
tool_executor=AsyncMock(),
|
||||
safety_api=mock_safety_api,
|
||||
shield_ids=shield_ids,
|
||||
guardrail_ids=guardrail_ids,
|
||||
)
|
||||
|
||||
result = await orchestrator._check_input_safety(messages)
|
||||
|
||||
assert result is None
|
||||
mock_safety_api.run_shield.assert_called_once_with(shield_id="llama-guard", messages=messages, params={})
|
||||
# Verify run_moderation was called with the correct model
|
||||
mock_safety_api.run_moderation.assert_called_once()
|
||||
# Get the actual call arguments
|
||||
call_args = mock_safety_api.run_moderation.call_args
|
||||
assert call_args[1]["model"] == "llama-guard-model" # The provider_resource_id from our mock
|
||||
|
||||
|
||||
async def test_check_input_safety_with_violation(mock_safety_api, mock_inference_api, mock_context):
|
||||
"""Test input shield validation with safety violation."""
|
||||
messages = [UserMessage(content="Harmful content")]
|
||||
shield_ids = ["llama-guard"]
|
||||
guardrail_ids = ["llama-guard"]
|
||||
|
||||
# Mock shield violation
|
||||
violation = SafetyViolation(
|
||||
violation_level=ViolationLevel.ERROR, user_message="Content violates safety guidelines", metadata={}
|
||||
)
|
||||
mock_response = AsyncMock()
|
||||
mock_response.violation = violation
|
||||
mock_safety_api.run_shield.return_value = mock_response
|
||||
# Mock moderation to return flagged content
|
||||
mock_safety_api.run_moderation.return_value = AsyncMock(flagged=True, categories={"violence": True})
|
||||
|
||||
# Create orchestrator with safety components
|
||||
orchestrator = StreamingResponseOrchestrator(
|
||||
|
|
@ -124,13 +130,13 @@ async def test_check_input_safety_with_violation(mock_safety_api, mock_inference
|
|||
max_infer_iters=5,
|
||||
tool_executor=AsyncMock(),
|
||||
safety_api=mock_safety_api,
|
||||
shield_ids=shield_ids,
|
||||
guardrail_ids=guardrail_ids,
|
||||
)
|
||||
|
||||
result = await orchestrator._check_input_safety(messages)
|
||||
|
||||
assert isinstance(result, OpenAIResponseContentPartRefusal)
|
||||
assert result.refusal == "Content violates safety guidelines"
|
||||
assert result.refusal == "Content flagged by moderation"
|
||||
|
||||
|
||||
async def test_check_input_safety_empty_inputs(mock_safety_api, mock_inference_api, mock_context):
|
||||
|
|
@ -145,7 +151,7 @@ async def test_check_input_safety_empty_inputs(mock_safety_api, mock_inference_a
|
|||
max_infer_iters=5,
|
||||
tool_executor=AsyncMock(),
|
||||
safety_api=mock_safety_api,
|
||||
shield_ids=[],
|
||||
guardrail_ids=[],
|
||||
)
|
||||
|
||||
# Test empty shield_ids
|
||||
|
|
@ -153,6 +159,6 @@ async def test_check_input_safety_empty_inputs(mock_safety_api, mock_inference_a
|
|||
assert result is None
|
||||
|
||||
# Test empty messages
|
||||
orchestrator.shield_ids = ["llama-guard"]
|
||||
orchestrator.guardrail_ids = ["llama-guard"]
|
||||
result = await orchestrator._check_input_safety([])
|
||||
assert result is None
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue