mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-16 12:02:37 +00:00
queue and stream events for safe chunk
This commit is contained in:
parent
31105c450a
commit
ada18ec399
3 changed files with 79 additions and 23 deletions
|
|
@ -566,6 +566,9 @@ class StreamingResponseOrchestrator:
|
|||
# Accumulate usage from chunks (typically in final chunk with stream_options)
|
||||
self._accumulate_chunk_usage(chunk)
|
||||
|
||||
# Track deltas for this specific chunk for guardrail validation
|
||||
chunk_events: list[OpenAIResponseObjectStream] = []
|
||||
|
||||
for chunk_choice in chunk.choices:
|
||||
# Emit incremental text content as delta events
|
||||
if chunk_choice.delta.content:
|
||||
|
|
@ -601,15 +604,19 @@ class StreamingResponseOrchestrator:
|
|||
sequence_number=self.sequence_number,
|
||||
)
|
||||
self.sequence_number += 1
|
||||
# Skip Emitting text content delta event if guardrails are configured, only emits chunks after guardrails are applied
|
||||
if not self.guardrail_ids:
|
||||
yield OpenAIResponseObjectStreamResponseOutputTextDelta(
|
||||
content_index=content_index,
|
||||
delta=chunk_choice.delta.content,
|
||||
item_id=message_item_id,
|
||||
output_index=message_output_index,
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
text_delta_event = OpenAIResponseObjectStreamResponseOutputTextDelta(
|
||||
content_index=content_index,
|
||||
delta=chunk_choice.delta.content,
|
||||
item_id=message_item_id,
|
||||
output_index=message_output_index,
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
# Buffer text delta events for guardrail check
|
||||
if self.guardrail_ids:
|
||||
chunk_events.append(text_delta_event)
|
||||
else:
|
||||
yield text_delta_event
|
||||
|
||||
# Collect content for final response
|
||||
chat_response_content.append(chunk_choice.delta.content or "")
|
||||
|
|
@ -625,7 +632,11 @@ class StreamingResponseOrchestrator:
|
|||
message_item_id=message_item_id,
|
||||
message_output_index=message_output_index,
|
||||
):
|
||||
yield event
|
||||
# Buffer reasoning events for guardrail check
|
||||
if self.guardrail_ids:
|
||||
chunk_events.append(event)
|
||||
else:
|
||||
yield event
|
||||
reasoning_part_emitted = True
|
||||
reasoning_text_accumulated.append(chunk_choice.delta.reasoning_content)
|
||||
|
||||
|
|
@ -707,15 +718,21 @@ class StreamingResponseOrchestrator:
|
|||
response_tool_call.function.arguments or ""
|
||||
) + tool_call.function.arguments
|
||||
|
||||
# Output Safety Validation for a chunk
|
||||
# Output Safety Validation for this chunk
|
||||
if self.guardrail_ids:
|
||||
# Check guardrails on accumulated text so far
|
||||
accumulated_text = "".join(chat_response_content)
|
||||
violation_message = await run_guardrails(self.safety_api, accumulated_text, self.guardrail_ids)
|
||||
if violation_message:
|
||||
logger.info(f"Output guardrail violation: {violation_message}")
|
||||
chunk_events.clear()
|
||||
yield await self._create_refusal_response(violation_message)
|
||||
self.violation_detected = True
|
||||
return
|
||||
else:
|
||||
# No violation detected, emit all content events for this chunk
|
||||
for event in chunk_events:
|
||||
yield event
|
||||
|
||||
# Emit arguments.done events for completed tool calls (differentiate between MCP and function calls)
|
||||
for tool_call_index in sorted(chat_response_tool_calls.keys()):
|
||||
|
|
|
|||
|
|
@ -323,13 +323,11 @@ async def run_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[
|
|||
shields_list = await safety_api.routing_table.list_shields()
|
||||
|
||||
for guardrail_id in guardrail_ids:
|
||||
# Find the shield with this identifier
|
||||
matching_shields = [shield for shield in shields_list.data if shield.identifier == guardrail_id]
|
||||
if matching_shields:
|
||||
model_id = matching_shields[0].provider_resource_id
|
||||
model_ids.append(model_id)
|
||||
else:
|
||||
# If no shield found, raise an error
|
||||
raise ValueError(f"No shield found with identifier '{guardrail_id}'")
|
||||
|
||||
guardrail_tasks = [safety_api.run_moderation(messages, model=model_id) for model_id in model_ids]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue