mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00

minor fix

commit 376f0fcd23 (parent 1143db0f64)
2 changed files with 79 additions and 34 deletions
File 1 of 2: StreamingResponseOrchestrator module

@@ -11,7 +11,7 @@ from typing import Any
 from llama_stack.apis.agents.openai_responses import (
     AllowedToolsFilter,
     ApprovalFilter,
-    MCPAuthentication,
+    MCPAuthorization,
     MCPListToolsTool,
     OpenAIResponseContentPartOutputText,
     OpenAIResponseContentPartReasoningText,
@@ -69,7 +69,9 @@ from llama_stack.apis.inference import (
 )
 from llama_stack.core.telemetry import tracing
 from llama_stack.log import get_logger
-from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    interleaved_content_as_str,
+)
 
 from .types import ChatCompletionContext, ChatCompletionResult
 from .utils import (
@@ -81,14 +83,14 @@ from .utils import (
 logger = get_logger(name=__name__, category="agents::meta_reference")
 
 
-def _convert_authentication_to_headers(auth: MCPAuthentication) -> dict[str, str]:
-    """Convert MCPAuthentication config to HTTP headers.
+def _convert_authentication_to_headers(auth: MCPAuthorization) -> dict[str, str]:
+    """Convert MCPAuthorization config to HTTP headers.
 
     Args:
-        auth: Authentication configuration
+        auth: Authorization configuration
 
     Returns:
-        Dictionary of HTTP headers for authentication
+        Dictionary of HTTP headers for authorization
     """
     headers = {}
 
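The body of _convert_authentication_to_headers is mostly elided by this hunk; only `headers = {}` is visible. As a rough illustration of the config-to-headers shape, here is a minimal sketch. The FakeAuthorization stand-in, its access_token field, and the Bearer scheme are assumptions for illustration, not the actual MCPAuthorization schema.

# Sketch only: FakeAuthorization is a hypothetical stand-in for
# MCPAuthorization; the real model lives in
# llama_stack.apis.agents.openai_responses and may differ.
from dataclasses import dataclass


@dataclass
class FakeAuthorization:
    access_token: str | None = None               # assumed field, not the real schema
    custom_headers: dict[str, str] | None = None  # assumed field


def convert_authorization_to_headers(auth: FakeAuthorization) -> dict[str, str]:
    """Map an authorization config onto HTTP request headers."""
    headers: dict[str, str] = {}
    if auth.access_token:
        # Standard OAuth-style bearer credential
        headers["Authorization"] = f"Bearer {auth.access_token}"
    if auth.custom_headers:
        headers.update(auth.custom_headers)
    return headers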
@@ -120,7 +122,9 @@ def convert_tooldef_to_chat_tool(tool_def):
     """
 
     from llama_stack.models.llama.datatypes import ToolDefinition
-    from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
+    from llama_stack.providers.utils.inference.openai_compat import (
+        convert_tooldef_to_openai_tool,
+    )
 
     internal_tool_def = ToolDefinition(
         tool_name=tool_def.name,
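For orientation, the end product of convert_tooldef_to_chat_tool is an OpenAI-style function tool. A minimal sketch of that envelope under stated assumptions: FakeToolDef and its fields stand in for the real ToolDef/ToolDefinition models, which this hunk only shows in part.

# Sketch only: FakeToolDef is a hypothetical stand-in for the tool models
# that convert_tooldef_to_chat_tool actually consumes.
from dataclasses import dataclass
from typing import Any


@dataclass
class FakeToolDef:
    name: str
    description: str | None = None
    input_schema: dict[str, Any] | None = None  # JSON Schema for parameters


def tooldef_to_openai_tool(tool: FakeToolDef) -> dict[str, Any]:
    """Wrap a tool definition in the OpenAI function-tool envelope."""
    return {
        "type": "function",
        "function": {
            "name": tool.name,
            "description": tool.description or "",
            "parameters": tool.input_schema or {"type": "object", "properties": {}},
        },
    }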
@@ -298,7 +302,9 @@ class StreamingResponseOrchestrator:
         # add any approval requests required
         for tool_call in approvals:
             async for evt in self._add_mcp_approval_request(
-                tool_call.function.name, tool_call.function.arguments, output_messages
+                tool_call.function.name,
+                tool_call.function.arguments,
+                output_messages,
             ):
                 yield evt
 
@@ -407,7 +413,12 @@ class StreamingResponseOrchestrator:
             else:
                 non_function_tool_calls.append(tool_call)
 
-        return function_tool_calls, non_function_tool_calls, approvals, next_turn_messages
+        return (
+            function_tool_calls,
+            non_function_tool_calls,
+            approvals,
+            next_turn_messages,
+        )
 
     def _accumulate_chunk_usage(self, chunk: OpenAIChatCompletionChunk) -> None:
         """Accumulate usage from a streaming chunk into the response usage format."""
@@ -718,12 +729,15 @@ class StreamingResponseOrchestrator:
                 # Emit output_item.added event for the new function call
                 self.sequence_number += 1
                 is_mcp_tool = tool_call.function.name and tool_call.function.name in self.mcp_tool_to_server
-                if not is_mcp_tool and tool_call.function.name not in ["web_search", "knowledge_search"]:
+                if not is_mcp_tool and tool_call.function.name not in [
+                    "web_search",
+                    "knowledge_search",
+                ]:
                     # for MCP tools (and even other non-function tools) we emit an output message item later
                     function_call_item = OpenAIResponseOutputMessageFunctionToolCall(
                         arguments="",  # Will be filled incrementally via delta events
                         call_id=tool_call.id or "",
-                        name=tool_call.function.name if tool_call.function else "",
+                        name=(tool_call.function.name if tool_call.function else ""),
                         id=tool_call_item_id,
                         status="in_progress",
                     )
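The `arguments=""` placeholder above reflects how streamed tool calls arrive: the argument string is assembled from successive chunk deltas. A small self-contained sketch of that accumulation (the helper name is illustrative, not part of the orchestrator):

def accumulate_tool_call_arguments(deltas: list[str]) -> str:
    """Concatenate argument fragments as they arrive in chunk deltas."""
    return "".join(deltas)


# The fragments only form valid JSON once the stream is complete:
assert accumulate_tool_call_arguments(['{"ci', 'ty": "Par', 'is"}']) == '{"city": "Paris"}'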
@@ -1035,14 +1049,18 @@ class StreamingResponseOrchestrator:
         )
 
     async def _process_new_tools(
-        self, tools: list[OpenAIResponseInputTool], output_messages: list[OpenAIResponseOutput]
+        self,
+        tools: list[OpenAIResponseInputTool],
+        output_messages: list[OpenAIResponseOutput],
     ) -> AsyncIterator[OpenAIResponseObjectStream]:
         """Process all tools and emit appropriate streaming events."""
         from openai.types.chat import ChatCompletionToolParam
 
         from llama_stack.apis.tools import ToolDef
         from llama_stack.models.llama.datatypes import ToolDefinition
-        from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
+        from llama_stack.providers.utils.inference.openai_compat import (
+            convert_tooldef_to_openai_tool,
+        )
 
         def make_openai_tool(tool_name: str, tool: ToolDef) -> ChatCompletionToolParam:
             tool_def = ToolDefinition(
@@ -1079,7 +1097,9 @@ class StreamingResponseOrchestrator:
             raise ValueError(f"Llama Stack OpenAI Responses does not yet support tool type: {input_tool.type}")
 
     async def _process_mcp_tool(
-        self, mcp_tool: OpenAIResponseInputToolMCP, output_messages: list[OpenAIResponseOutput]
+        self,
+        mcp_tool: OpenAIResponseInputToolMCP,
+        output_messages: list[OpenAIResponseOutput],
     ) -> AsyncIterator[OpenAIResponseObjectStream]:
         """Process an MCP tool configuration and emit appropriate streaming events."""
         from llama_stack.providers.utils.tools.mcp import list_mcp_tools
@@ -1108,10 +1128,10 @@ class StreamingResponseOrchestrator:
             "server_url": mcp_tool.server_url,
             "mcp_list_tools_id": list_id,
         }
-        # Prepare headers with authentication from tool config
+        # Prepare headers with authorization from tool config
         headers = dict(mcp_tool.headers or {})
-        if mcp_tool.authentication:
-            auth_headers = _convert_authentication_to_headers(mcp_tool.authentication)
+        if mcp_tool.authorization:
+            auth_headers = _convert_authentication_to_headers(mcp_tool.authorization)
             # Don't override existing headers (case-insensitive check)
             existing_keys_lower = {k.lower() for k in headers.keys()}
             for key, value in auth_headers.items():
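The merge loop is truncated at the end of this hunk. A self-contained sketch of the same case-insensitive, do-not-override merge (the function name _merge_auth_headers is illustrative):

def _merge_auth_headers(headers: dict[str, str], auth_headers: dict[str, str]) -> dict[str, str]:
    """Add authorization-derived headers without clobbering caller-supplied ones.

    Header names are compared case-insensitively, mirroring the
    existing_keys_lower check in the hunk above.
    """
    merged = dict(headers)
    existing_keys_lower = {k.lower() for k in merged}
    for key, value in auth_headers.items():
        if key.lower() not in existing_keys_lower:
            merged[key] = value
    return merged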
@@ -1200,7 +1220,10 @@ class StreamingResponseOrchestrator:
             return True
 
     async def _add_mcp_approval_request(
-        self, tool_name: str, arguments: str, output_messages: list[OpenAIResponseOutput]
+        self,
+        tool_name: str,
+        arguments: str,
+        output_messages: list[OpenAIResponseOutput],
     ) -> AsyncIterator[OpenAIResponseObjectStream]:
         mcp_server = self.mcp_tool_to_server[tool_name]
         mcp_approval_request = OpenAIResponseMCPApprovalRequest(
@@ -1227,7 +1250,9 @@ class StreamingResponseOrchestrator:
         )
 
     async def _add_mcp_list_tools(
-        self, mcp_list_message: OpenAIResponseOutputMessageMCPListTools, output_messages: list[OpenAIResponseOutput]
+        self,
+        mcp_list_message: OpenAIResponseOutputMessageMCPListTools,
+        output_messages: list[OpenAIResponseOutput],
     ) -> AsyncIterator[OpenAIResponseObjectStream]:
         # Add the MCP list message to output
         output_messages.append(mcp_list_message)
@@ -1260,11 +1285,15 @@ class StreamingResponseOrchestrator:
         )
 
     async def _reuse_mcp_list_tools(
-        self, original: OpenAIResponseOutputMessageMCPListTools, output_messages: list[OpenAIResponseOutput]
+        self,
+        original: OpenAIResponseOutputMessageMCPListTools,
+        output_messages: list[OpenAIResponseOutput],
     ) -> AsyncIterator[OpenAIResponseObjectStream]:
         for t in original.tools:
             from llama_stack.models.llama.datatypes import ToolDefinition
-            from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
+            from llama_stack.providers.utils.inference.openai_compat import (
+                convert_tooldef_to_openai_tool,
+            )
 
             # convert from input_schema to map of ToolParamDefinitions...
             tool_def = ToolDefinition(
File 2 of 2: ToolExecutor module
@@ -10,7 +10,7 @@ from collections.abc import AsyncIterator
 from typing import Any
 
 from llama_stack.apis.agents.openai_responses import (
-    MCPAuthentication,
+    MCPAuthorization,
     OpenAIResponseInputToolFileSearch,
     OpenAIResponseInputToolMCP,
     OpenAIResponseObjectStreamResponseFileSearchCallCompleted,
@@ -27,10 +27,7 @@ from llama_stack.apis.agents.openai_responses import (
     OpenAIResponseOutputMessageMCPCall,
     OpenAIResponseOutputMessageWebSearchToolCall,
 )
-from llama_stack.apis.common.content_types import (
-    ImageContentItem,
-    TextContentItem,
-)
+from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
 from llama_stack.apis.inference import (
     OpenAIChatCompletionContentPartImageParam,
     OpenAIChatCompletionContentPartTextParam,
@@ -48,8 +45,8 @@ from .types import ChatCompletionContext, ToolExecutionResult
 logger = get_logger(name=__name__, category="agents::meta_reference")
 
 
-def _convert_authentication_to_headers(auth: MCPAuthentication) -> dict[str, str]:
-    """Convert MCPAuthentication config to HTTP headers.
+def _convert_authentication_to_headers(auth: MCPAuthorization) -> dict[str, str]:
+    """Convert MCPAuthorization config to HTTP headers.
 
     Args:
         auth: Authentication configuration
@@ -106,7 +103,12 @@ class ToolExecutor:
 
         # Emit progress events for tool execution start
         async for event_result in self._emit_progress_events(
-            function.name, ctx, sequence_number, output_index, item_id, mcp_tool_to_server
+            function.name,
+            ctx,
+            sequence_number,
+            output_index,
+            item_id,
+            mcp_tool_to_server,
         ):
             sequence_number = event_result.sequence_number
             yield event_result
@@ -126,14 +128,28 @@ class ToolExecutor:
             )
         )
         async for event_result in self._emit_completion_events(
-            function.name, ctx, sequence_number, output_index, item_id, has_error, mcp_tool_to_server
+            function.name,
+            ctx,
+            sequence_number,
+            output_index,
+            item_id,
+            has_error,
+            mcp_tool_to_server,
         ):
             sequence_number = event_result.sequence_number
             yield event_result
 
         # Build result messages from tool execution
         output_message, input_message = await self._build_result_messages(
-            function, tool_call_id, item_id, tool_kwargs, ctx, error_exc, result, has_error, mcp_tool_to_server
+            function,
+            tool_call_id,
+            item_id,
+            tool_kwargs,
+            ctx,
+            error_exc,
+            result,
+            has_error,
+            mcp_tool_to_server,
         )
 
         # Yield the final result
@@ -328,10 +344,10 @@ class ToolExecutor:
             "server_url": mcp_tool.server_url,
             "tool_name": function_name,
         }
-        # Prepare headers with authentication from tool config
+        # Prepare headers with authorization from tool config
         headers = dict(mcp_tool.headers or {})
-        if mcp_tool.authentication:
-            auth_headers = _convert_authentication_to_headers(mcp_tool.authentication)
+        if mcp_tool.authorization:
+            auth_headers = _convert_authentication_to_headers(mcp_tool.authorization)
             # Don't override existing headers (case-insensitive check)
             existing_keys_lower = {k.lower() for k in headers.keys()}
             for key, value in auth_headers.items():
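The ToolExecutor applies the same precedence rule as the orchestrator: headers supplied on the tool config win over headers derived from `authorization`. Using the _merge_auth_headers sketch from earlier, the observable behavior would be:

user_headers = {"Authorization": "Bearer user-supplied-token"}
derived_headers = {"authorization": "Bearer derived-from-config"}

# The caller's Authorization header survives; the derived one is dropped.
assert _merge_auth_headers(user_headers, derived_headers) == {
    "Authorization": "Bearer user-supplied-token"
}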