Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-10-23 08:33:09 +00:00)

Commit 90ee3001d9: Merge branch 'main' into responses-and-safety
33 changed files with 16970 additions and 713 deletions
@@ -41,12 +41,16 @@ from llama_stack.apis.agents.openai_responses import (
    OpenAIResponseOutputMessageFunctionToolCall,
    OpenAIResponseOutputMessageMCPListTools,
    OpenAIResponseText,
    OpenAIResponseUsage,
    OpenAIResponseUsageInputTokensDetails,
    OpenAIResponseUsageOutputTokensDetails,
    WebSearchToolTypes,
)
from llama_stack.apis.inference import (
    Inference,
    OpenAIAssistantMessageParam,
    OpenAIChatCompletion,
    OpenAIChatCompletionChunk,
    OpenAIChatCompletionToolCall,
    OpenAIChoice,
    OpenAIMessageParam,
@@ -116,43 +120,8 @@ class StreamingResponseOrchestrator:
        self.final_messages: list[OpenAIMessageParam] = []
        # mapping for annotations
        self.citation_files: dict[str, str] = {}
        # Track accumulated text for shield validation
        self.accumulated_text = ""
        # Track if we've sent a refusal response
        self.violation_detected = False

    async def _check_output_stream_safety(self, text_delta: str) -> str | None:
        """Check streaming text content against shields. Returns violation message if blocked."""
        if not self.shield_ids:
            return None

        self.accumulated_text += text_delta

        # Check accumulated text periodically for violations (every 50 characters or at word boundaries)
        if len(self.accumulated_text) > 50 or text_delta.endswith((" ", "\n", ".", "!", "?")):
            temp_messages = [{"role": "assistant", "content": self.accumulated_text}]
            messages = convert_openai_to_inference_messages(temp_messages)

            try:
                await run_multiple_shields(self.safety_api, messages, self.shield_ids)
            except SafetyException as e:
                logger.info(f"Output shield violation: {e.violation.user_message}")
                return e.violation.user_message or "Generated content blocked by safety shields"

    async def _create_refusal_response(self, violation_message: str) -> OpenAIResponseObjectStream:
        """Create a refusal response to replace streaming content."""
        refusal_content = OpenAIResponseContentPartRefusal(refusal=violation_message)

        # Create a completed refusal response
        refusal_response = OpenAIResponseObject(
            id=self.response_id,
            created_at=self.created_at,
            model=self.ctx.model,
            status="completed",
            output=[OpenAIResponseMessage(role="assistant", content=[refusal_content], type="message")],
        )

        return OpenAIResponseObjectStreamResponseCompleted(response=refusal_response)
        # Track accumulated usage across all inference calls
        self.accumulated_usage: OpenAIResponseUsage | None = None

    def _clone_outputs(self, outputs: list[OpenAIResponseOutput]) -> list[OpenAIResponseOutput]:
        cloned: list[OpenAIResponseOutput] = []
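The two helpers in this hunk check the accumulated output text against the configured shields (roughly every 50 characters or at a word boundary) and, on a violation, replace the stream with a single completed refusal response. Below is a minimal sketch of how such helpers could gate a delta stream; the driver function and its parameter names are illustrative assumptions, not the orchestrator's actual control flow.

# Hypothetical driver loop showing how an output shield check could gate a text stream.
# `check_safety` and `make_refusal` stand in for _check_output_stream_safety and
# _create_refusal_response; both names here are assumptions for illustration only.
from collections.abc import AsyncIterator, Awaitable, Callable


async def stream_with_output_shields(
    text_deltas: AsyncIterator[str],
    check_safety: Callable[[str], Awaitable[str | None]],
    make_refusal: Callable[[str], Awaitable[object]],
) -> AsyncIterator[object]:
    """Yield text deltas until a shield reports a violation, then yield one refusal and stop."""
    async for delta in text_deltas:
        violation = await check_safety(delta)  # returns a violation message or None
        if violation is not None:
            # Stop forwarding model text and emit a single terminal refusal event instead.
            yield await make_refusal(violation)
            return
        yield delta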
@@ -180,6 +149,7 @@ class StreamingResponseOrchestrator:
            text=self.text,
            tools=self.ctx.available_tools(),
            error=error,
            usage=self.accumulated_usage,
        )

    async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]:
@@ -217,6 +187,9 @@ class StreamingResponseOrchestrator:
                stream=True,
                temperature=self.ctx.temperature,
                response_format=response_format,
                stream_options={
                    "include_usage": True,
                },
            )

            # Process streaming chunks and build complete response
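Passing stream_options={"include_usage": True} asks the OpenAI-compatible backend to append a final chunk whose usage field carries the aggregate token counts (its choices list is normally empty); that is the chunk _accumulate_chunk_usage later picks up. A minimal consumption sketch against an OpenAI-compatible endpoint, assuming the server honours the option; the model name and base URL are placeholders, not values from this commit.

# Minimal sketch: read the usage-bearing final chunk from a streamed chat completion.
import asyncio

from openai import AsyncOpenAI


async def main() -> None:
    # Placeholder endpoint and credentials; point these at any OpenAI-compatible server.
    client = AsyncOpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")
    stream = await client.chat.completions.create(
        model="my-model",  # placeholder model id
        messages=[{"role": "user", "content": "Hello"}],
        stream=True,
        stream_options={"include_usage": True},
    )
    async for chunk in stream:
        if chunk.usage is not None:
            # Typically only the final chunk carries usage.
            print(chunk.usage.prompt_tokens, chunk.usage.completion_tokens, chunk.usage.total_tokens)


asyncio.run(main())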
@@ -352,6 +325,51 @@ class StreamingResponseOrchestrator:

        return function_tool_calls, non_function_tool_calls, approvals, next_turn_messages

    def _accumulate_chunk_usage(self, chunk: OpenAIChatCompletionChunk) -> None:
        """Accumulate usage from a streaming chunk into the response usage format."""
        if not chunk.usage:
            return

        if self.accumulated_usage is None:
            # Convert from chat completion format to response format
            self.accumulated_usage = OpenAIResponseUsage(
                input_tokens=chunk.usage.prompt_tokens,
                output_tokens=chunk.usage.completion_tokens,
                total_tokens=chunk.usage.total_tokens,
                input_tokens_details=(
                    OpenAIResponseUsageInputTokensDetails(cached_tokens=chunk.usage.prompt_tokens_details.cached_tokens)
                    if chunk.usage.prompt_tokens_details
                    else None
                ),
                output_tokens_details=(
                    OpenAIResponseUsageOutputTokensDetails(
                        reasoning_tokens=chunk.usage.completion_tokens_details.reasoning_tokens
                    )
                    if chunk.usage.completion_tokens_details
                    else None
                ),
            )
        else:
            # Accumulate across multiple inference calls
            self.accumulated_usage = OpenAIResponseUsage(
                input_tokens=self.accumulated_usage.input_tokens + chunk.usage.prompt_tokens,
                output_tokens=self.accumulated_usage.output_tokens + chunk.usage.completion_tokens,
                total_tokens=self.accumulated_usage.total_tokens + chunk.usage.total_tokens,
                # Use latest non-null details
                input_tokens_details=(
                    OpenAIResponseUsageInputTokensDetails(cached_tokens=chunk.usage.prompt_tokens_details.cached_tokens)
                    if chunk.usage.prompt_tokens_details
                    else self.accumulated_usage.input_tokens_details
                ),
                output_tokens_details=(
                    OpenAIResponseUsageOutputTokensDetails(
                        reasoning_tokens=chunk.usage.completion_tokens_details.reasoning_tokens
                    )
                    if chunk.usage.completion_tokens_details
                    else self.accumulated_usage.output_tokens_details
                ),
            )

    async def _process_streaming_chunks(
        self, completion_result, output_messages: list[OpenAIResponseOutput]
    ) -> AsyncIterator[OpenAIResponseObjectStream | ChatCompletionResult]:
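_accumulate_chunk_usage seeds self.accumulated_usage from the first usage-bearing chunk and, on later inference calls, adds the token counts while keeping the latest non-null details. A simplified stand-alone sketch of that folding behaviour, using a hypothetical dataclass in place of the real OpenAIResponseUsage types:

# Simplified stand-in type to illustrate the folding done by _accumulate_chunk_usage;
# the real usage types live in llama_stack.apis and carry extra detail fields.
from dataclasses import dataclass


@dataclass
class Usage:
    input_tokens: int
    output_tokens: int
    total_tokens: int


def accumulate(acc: Usage | None, chunk_usage: Usage | None) -> Usage | None:
    """The first usage-bearing chunk seeds the accumulator; later ones add token counts."""
    if chunk_usage is None:
        return acc
    if acc is None:
        return chunk_usage
    return Usage(
        input_tokens=acc.input_tokens + chunk_usage.input_tokens,
        output_tokens=acc.output_tokens + chunk_usage.output_tokens,
        total_tokens=acc.total_tokens + chunk_usage.total_tokens,
    )


# Two inference calls (e.g. a tool-call turn followed by the final answer):
acc = accumulate(None, Usage(120, 30, 150))
acc = accumulate(acc, Usage(180, 45, 225))
assert acc == Usage(300, 75, 375)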
@@ -377,6 +395,10 @@ class StreamingResponseOrchestrator:
            chat_response_id = chunk.id
            chunk_created = chunk.created
            chunk_model = chunk.model

            # Accumulate usage from chunks (typically in final chunk with stream_options)
            self._accumulate_chunk_usage(chunk)

            for chunk_choice in chunk.choices:
                # Emit incremental text content as delta events
                if chunk_choice.delta.content:
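Each streamed chunk first updates the response metadata and accumulated usage, then every choice's delta.content is turned into an incremental text delta event. An illustrative-only sketch of that per-chunk translation; the event type and field names below are assumptions, not llama-stack's actual stream schema.

# Hypothetical per-chunk translation of chat-completion deltas into text delta events.
from dataclasses import dataclass


@dataclass
class TextDeltaEvent:
    item_id: str
    delta: str


def deltas_from_chunk(chunk, item_id: str) -> list[TextDeltaEvent]:
    """Collect one text delta event per choice that carries new content."""
    events: list[TextDeltaEvent] = []
    for choice in chunk.choices:
        if choice.delta.content:
            events.append(TextDeltaEvent(item_id=item_id, delta=choice.delta.content))
    return events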