Merge a93130e323 into sapling-pr-archive-ehhuang

This commit is contained in:
ehhuang 2025-10-09 13:53:45 -07:00 committed by GitHub
commit 9e70492078
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 40 additions and 19 deletions

View file

@@ -47,6 +47,7 @@ from llama_stack.apis.inference import (
OpenAIMessageParam, OpenAIMessageParam,
) )
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.providers.utils.telemetry import tracing
from .types import ChatCompletionContext, ChatCompletionResult from .types import ChatCompletionContext, ChatCompletionResult
from .utils import convert_chat_choice_to_response_message, is_function_tool_call from .utils import convert_chat_choice_to_response_message, is_function_tool_call
@@ -597,14 +598,22 @@ class StreamingResponseOrchestrator:
never_allowed = mcp_tool.allowed_tools.never never_allowed = mcp_tool.allowed_tools.never
# Call list_mcp_tools # Call list_mcp_tools
tool_defs = await list_mcp_tools( tool_defs = None
endpoint=mcp_tool.server_url, list_id = f"mcp_list_{uuid.uuid4()}"
headers=mcp_tool.headers or {}, attributes = {
) "server_label": mcp_tool.server_label,
"server_url": mcp_tool.server_url,
"mcp_list_tools_id": list_id,
}
async with tracing.span("list_mcp_tools", attributes):
tool_defs = await list_mcp_tools(
endpoint=mcp_tool.server_url,
headers=mcp_tool.headers or {},
)
# Create the MCP list tools message # Create the MCP list tools message
mcp_list_message = OpenAIResponseOutputMessageMCPListTools( mcp_list_message = OpenAIResponseOutputMessageMCPListTools(
id=f"mcp_list_{uuid.uuid4()}", id=list_id,
server_label=mcp_tool.server_label, server_label=mcp_tool.server_label,
tools=[], tools=[],
) )

View file

@@ -35,6 +35,7 @@ from llama_stack.apis.inference import (
from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime
from llama_stack.apis.vector_io import VectorIO from llama_stack.apis.vector_io import VectorIO
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.providers.utils.telemetry import tracing
from .types import ChatCompletionContext, ToolExecutionResult from .types import ChatCompletionContext, ToolExecutionResult
@@ -251,12 +252,18 @@ class ToolExecutor:
from llama_stack.providers.utils.tools.mcp import invoke_mcp_tool from llama_stack.providers.utils.tools.mcp import invoke_mcp_tool
mcp_tool = mcp_tool_to_server[function_name] mcp_tool = mcp_tool_to_server[function_name]
result = await invoke_mcp_tool( attributes = {
endpoint=mcp_tool.server_url, "server_label": mcp_tool.server_label,
headers=mcp_tool.headers or {}, "server_url": mcp_tool.server_url,
tool_name=function_name, "tool_name": function_name,
kwargs=tool_kwargs, }
) async with tracing.span("invoke_mcp_tool", attributes):
result = await invoke_mcp_tool(
endpoint=mcp_tool.server_url,
headers=mcp_tool.headers or {},
tool_name=function_name,
kwargs=tool_kwargs,
)
elif function_name == "knowledge_search": elif function_name == "knowledge_search":
response_file_search_tool = next( response_file_search_tool = next(
(t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)), (t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)),
@@ -266,15 +273,20 @@ class ToolExecutor:
# Use vector_stores.search API instead of knowledge_search tool # Use vector_stores.search API instead of knowledge_search tool
# to support filters and ranking_options # to support filters and ranking_options
query = tool_kwargs.get("query", "") query = tool_kwargs.get("query", "")
result = await self._execute_knowledge_search_via_vector_store( async with tracing.span("knowledge_search", {}):
query=query, result = await self._execute_knowledge_search_via_vector_store(
response_file_search_tool=response_file_search_tool, query=query,
) response_file_search_tool=response_file_search_tool,
)
else: else:
result = await self.tool_runtime_api.invoke_tool( attributes = {
tool_name=function_name, "tool_name": function_name,
kwargs=tool_kwargs, }
) async with tracing.span("invoke_tool", attributes):
result = await self.tool_runtime_api.invoke_tool(
tool_name=function_name,
kwargs=tool_kwargs,
)
except Exception as e: except Exception as e:
error_exc = e error_exc = e