Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-13 13:02:38 +00:00)

queue and stream events for safe chunk

parent 31105c450a
commit ada18ec399

3 changed files with 79 additions and 23 deletions
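The diff below touches the StreamingResponseOrchestrator streaming loop, the run_guardrails helper, and the responses integration tests. The core change: each chunk's delta events are queued, output guardrails are run on the text accumulated so far, and only a safe chunk's queued events are streamed; on a violation the queue is dropped and a refusal response is emitted. A minimal, standalone sketch of that buffer-and-flush pattern (the Event type, the chunk shape, and the check_text callback are illustrative assumptions, not the orchestrator's actual API):

    import asyncio
    from dataclasses import dataclass

    @dataclass
    class Event:
        # Stand-in for an OpenAIResponseObjectStream event (illustrative only).
        delta: str

    async def stream_with_guardrails(chunks, check_text):
        # Buffer each chunk's delta events, check the accumulated text, then
        # either flush the buffered events or emit a refusal and stop.
        accumulated: list[str] = []
        for chunk in chunks:
            chunk_events: list[Event] = []         # events queued for this chunk
            for delta in chunk:
                chunk_events.append(Event(delta))  # buffer instead of yielding right away
                accumulated.append(delta)
            violation = await check_text("".join(accumulated))
            if violation:
                chunk_events.clear()               # drop the unsafe chunk's events
                yield Event(f"[refusal] {violation}")
                return
            for event in chunk_events:             # safe chunk: stream the queued events
                yield event

    async def main():
        async def check(text):
            return "flagged" if "unsafe" in text else None
        async for event in stream_with_guardrails([["Hello ", "world"]], check):
            print(event.delta)

    asyncio.run(main())
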
@@ -566,6 +566,9 @@ class StreamingResponseOrchestrator:
             # Accumulate usage from chunks (typically in final chunk with stream_options)
             self._accumulate_chunk_usage(chunk)
 
+            # Track deltas for this specific chunk for guardrail validation
+            chunk_events: list[OpenAIResponseObjectStream] = []
+
             for chunk_choice in chunk.choices:
                 # Emit incremental text content as delta events
                 if chunk_choice.delta.content:

@@ -601,15 +604,19 @@ class StreamingResponseOrchestrator:
                             sequence_number=self.sequence_number,
                         )
                     self.sequence_number += 1
-                    # Skip Emitting text content delta event if guardrails are configured, only emits chunks after guardrails are applied
-                    if not self.guardrail_ids:
-                        yield OpenAIResponseObjectStreamResponseOutputTextDelta(
-                            content_index=content_index,
-                            delta=chunk_choice.delta.content,
-                            item_id=message_item_id,
-                            output_index=message_output_index,
-                            sequence_number=self.sequence_number,
-                        )
+                    text_delta_event = OpenAIResponseObjectStreamResponseOutputTextDelta(
+                        content_index=content_index,
+                        delta=chunk_choice.delta.content,
+                        item_id=message_item_id,
+                        output_index=message_output_index,
+                        sequence_number=self.sequence_number,
+                    )
+
+                    # Buffer text delta events for guardrail check
+                    if self.guardrail_ids:
+                        chunk_events.append(text_delta_event)
+                    else:
+                        yield text_delta_event
 
                     # Collect content for final response
                     chat_response_content.append(chunk_choice.delta.content or "")

@@ -625,7 +632,11 @@ class StreamingResponseOrchestrator:
                        message_item_id=message_item_id,
                        message_output_index=message_output_index,
                    ):
-                        yield event
+                        # Buffer reasoning events for guardrail check
+                        if self.guardrail_ids:
+                            chunk_events.append(event)
+                        else:
+                            yield event
                    reasoning_part_emitted = True
                    reasoning_text_accumulated.append(chunk_choice.delta.reasoning_content)
 

@@ -707,15 +718,21 @@ class StreamingResponseOrchestrator:
                             response_tool_call.function.arguments or ""
                         ) + tool_call.function.arguments
 
-            # Output Safety Validation for a chunk
+            # Output Safety Validation for this chunk
             if self.guardrail_ids:
+                # Check guardrails on accumulated text so far
                 accumulated_text = "".join(chat_response_content)
                 violation_message = await run_guardrails(self.safety_api, accumulated_text, self.guardrail_ids)
                 if violation_message:
                     logger.info(f"Output guardrail violation: {violation_message}")
+                    chunk_events.clear()
                     yield await self._create_refusal_response(violation_message)
                     self.violation_detected = True
                     return
+                else:
+                    # No violation detected, emit all content events for this chunk
+                    for event in chunk_events:
+                        yield event
 
             # Emit arguments.done events for completed tool calls (differentiate between MCP and function calls)
             for tool_call_index in sorted(chat_response_tool_calls.keys()):

@@ -323,13 +323,11 @@ async def run_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[
     shields_list = await safety_api.routing_table.list_shields()
 
     for guardrail_id in guardrail_ids:
-        # Find the shield with this identifier
         matching_shields = [shield for shield in shields_list.data if shield.identifier == guardrail_id]
         if matching_shields:
             model_id = matching_shields[0].provider_resource_id
             model_ids.append(model_id)
         else:
-            # If no shield found, raise an error
             raise ValueError(f"No shield found with identifier '{guardrail_id}'")
 
     guardrail_tasks = [safety_api.run_moderation(messages, model=model_id) for model_id in model_ids]

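For context, a hedged usage sketch of the helper above. It assumes run_guardrails and a configured Safety API instance are available from the patched module, and it reuses the "llama-guard" shield identifier that the tests below pass via extra_body:

    async def moderate_output(safety_api, text: str) -> bool:
        # run_guardrails maps each guardrail id to a shield's provider_resource_id,
        # runs moderation against it, and returns a violation message (falsy when
        # the text passes). "llama-guard" mirrors the identifier used in the tests.
        violation_message = await run_guardrails(safety_api, text, ["llama-guard"])
        if violation_message:
            print(f"Output guardrail violation: {violation_message}")
            return False
        return True
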
@@ -346,22 +346,18 @@ def test_input_safety_guardrails(compat_client, text_model_id, stream):
 
 
 @pytest.mark.parametrize("stream", [True, False])
-def test_output_safety_guardrails(compat_client, text_model_id, stream):
-    """Test that output safety guardrails can block generated content."""
+def test_output_safety_guardrails_unsafe_content(compat_client, text_model_id, stream):
+    """Test that output safety guardrails block unsafe content."""
     if not isinstance(compat_client, OpenAI):
         pytest.skip("OpenAI client is required until responses API exists in llama-stack-client")
 
-    # Use an ambiguous prompt that could potentially generate content requiring guardrail evaluation
-    # This prompt: "Describe violent content."
-    # Could result in:
-    # 1. Educational violent content → guardrail allows output → normal response
-    # 2. Violent techniques content → guardrail blocks output → refusal response
+    # Use an unsafe prompt that should be blocked by guardrail evaluation
     response = compat_client.responses.create(
         model=text_model_id,
         input=[
             {
                 "role": "user",
-                "content": "Describe violent content.",
+                "content": "Describe violent content",
             }
         ],
         stream=stream,

@@ -369,7 +365,7 @@ def test_output_safety_guardrails(compat_client, text_model_id, stream):
     )
 
     if stream:
-        # For streaming, violations cause stream to end with refusal completion
+        # For streaming, collect the final response
         last_chunk = None
         for chunk in response:
             last_chunk = chunk

@@ -385,7 +381,52 @@ def test_output_safety_guardrails(compat_client, text_model_id, stream):
 
     assert len(message.content) > 0, "Message should have content"
     content_item = message.content[0]
-    assert content_item.type == "refusal", f"Content type should be 'refusal', got {content_item.type}"
+    assert content_item.type == "refusal", (
+        f"Content type should be 'refusal' for unsafe output, got {content_item.type}"
+    )
+    assert len(content_item.refusal.strip()) > 0, "Refusal message should not be empty"
+
+
+@pytest.mark.parametrize("stream", [True, False])
+def test_output_safety_guardrails_safe_content(compat_client, text_model_id, stream):
+    """Test that output safety guardrails allow safe content."""
+    if not isinstance(compat_client, OpenAI):
+        pytest.skip("OpenAI client is required until responses API exists in llama-stack-client")
+
+    # Use a safe prompt that should pass guardrail evaluation
+    response = compat_client.responses.create(
+        model=text_model_id,
+        input=[
+            {
+                "role": "user",
+                "content": "What's your name?",
+            }
+        ],
+        stream=stream,
+        extra_body={"guardrails": ["llama-guard"]},  # Output guardrail validation
+    )
+
+    if stream:
+        # For streaming, collect the final response
+        last_chunk = None
+        for chunk in response:
+            last_chunk = chunk
+
+        assert last_chunk is not None
+        assert last_chunk.type == "response.completed", f"Expected final chunk to be completion, got {last_chunk.type}"
+        response_to_check = last_chunk.response
+    else:
+        response_to_check = response
+
+    assert response_to_check.output[0].type == "message"
+    message = response_to_check.output[0]
+
+    assert len(message.content) > 0, "Message should have content"
+    content_item = message.content[0]
+    assert content_item.type == "output_text", (
+        f"Content type should be 'output_text' for safe output, got {content_item.type}"
+    )
+    assert len(content_item.text.strip()) > 0, "Text content should not be empty"
 
 
 def test_guardrails_with_tools(compat_client, text_model_id):

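For reference, a sketch of the client-side call the tests exercise: create a response with guardrails requested via extra_body, then branch on whether the first content item is a refusal or ordinary output text. The base_url, api_key, and model id are placeholders for a local Llama Stack deployment, not values taken from this commit:

    from openai import OpenAI

    # Placeholder endpoint and credentials for a local Llama Stack server.
    client = OpenAI(base_url="http://localhost:8321/v1", api_key="none")

    response = client.responses.create(
        model="my-text-model",  # placeholder model id
        input=[{"role": "user", "content": "What's your name?"}],
        extra_body={"guardrails": ["llama-guard"]},  # shield identifiers to apply
    )

    content_item = response.output[0].content[0]
    if content_item.type == "refusal":
        print(f"Blocked by output guardrail: {content_item.refusal}")
    else:
        print(f"Model output: {content_item.text}")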