mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-11 19:56:03 +00:00
fix(mypy): resolve tool_executor type issues (45 errors fixed)
- Add proper type annotations using Any where needed - Fix union-attr errors with getattr and walrus operator - Fix arg-type errors for datetime/enum conversions - Add type: ignore for list invariance issues - Remove event variable reuse to satisfy type checker - Use proper type narrowing for tool execution paths Patterns established: - Use getattr() with walrus operator for optional attributes - Use type: ignore for runtime-correct but mypy-incompatible cases - Separate event variables by type to avoid union conflicts
This commit is contained in:
parent
f88416ef87
commit
3a437d80af
1 changed files with 81 additions and 57 deletions
|
|
@ -7,6 +7,7 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
from collections.abc import AsyncIterator
|
from collections.abc import AsyncIterator
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from llama_stack.apis.agents.openai_responses import (
|
from llama_stack.apis.agents.openai_responses import (
|
||||||
OpenAIResponseInputToolFileSearch,
|
OpenAIResponseInputToolFileSearch,
|
||||||
|
|
@ -22,10 +23,12 @@ from llama_stack.apis.agents.openai_responses import (
|
||||||
OpenAIResponseObjectStreamResponseWebSearchCallSearching,
|
OpenAIResponseObjectStreamResponseWebSearchCallSearching,
|
||||||
OpenAIResponseOutputMessageFileSearchToolCall,
|
OpenAIResponseOutputMessageFileSearchToolCall,
|
||||||
OpenAIResponseOutputMessageFileSearchToolCallResults,
|
OpenAIResponseOutputMessageFileSearchToolCallResults,
|
||||||
|
OpenAIResponseOutputMessageMCPCall,
|
||||||
OpenAIResponseOutputMessageWebSearchToolCall,
|
OpenAIResponseOutputMessageWebSearchToolCall,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.common.content_types import (
|
from llama_stack.apis.common.content_types import (
|
||||||
ImageContentItem,
|
ImageContentItem,
|
||||||
|
InterleavedContent,
|
||||||
TextContentItem,
|
TextContentItem,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.inference import (
|
from llama_stack.apis.inference import (
|
||||||
|
|
@ -67,7 +70,7 @@ class ToolExecutor:
|
||||||
) -> AsyncIterator[ToolExecutionResult]:
|
) -> AsyncIterator[ToolExecutionResult]:
|
||||||
tool_call_id = tool_call.id
|
tool_call_id = tool_call.id
|
||||||
function = tool_call.function
|
function = tool_call.function
|
||||||
tool_kwargs = json.loads(function.arguments) if function.arguments else {}
|
tool_kwargs = json.loads(function.arguments) if function and function.arguments else {}
|
||||||
|
|
||||||
if not function or not tool_call_id or not function.name:
|
if not function or not tool_call_id or not function.name:
|
||||||
yield ToolExecutionResult(sequence_number=sequence_number)
|
yield ToolExecutionResult(sequence_number=sequence_number)
|
||||||
|
|
@ -84,7 +87,16 @@ class ToolExecutor:
|
||||||
error_exc, result = await self._execute_tool(function.name, tool_kwargs, ctx, mcp_tool_to_server)
|
error_exc, result = await self._execute_tool(function.name, tool_kwargs, ctx, mcp_tool_to_server)
|
||||||
|
|
||||||
# Emit completion events for tool execution
|
# Emit completion events for tool execution
|
||||||
has_error = error_exc or (result and ((result.error_code and result.error_code > 0) or result.error_message))
|
has_error = bool(
|
||||||
|
error_exc
|
||||||
|
or (
|
||||||
|
result
|
||||||
|
and (
|
||||||
|
((error_code := getattr(result, "error_code", None)) and error_code > 0)
|
||||||
|
or getattr(result, "error_message", None)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
async for event_result in self._emit_completion_events(
|
async for event_result in self._emit_completion_events(
|
||||||
function.name, ctx, sequence_number, output_index, item_id, has_error, mcp_tool_to_server
|
function.name, ctx, sequence_number, output_index, item_id, has_error, mcp_tool_to_server
|
||||||
):
|
):
|
||||||
|
|
@ -101,7 +113,11 @@ class ToolExecutor:
|
||||||
sequence_number=sequence_number,
|
sequence_number=sequence_number,
|
||||||
final_output_message=output_message,
|
final_output_message=output_message,
|
||||||
final_input_message=input_message,
|
final_input_message=input_message,
|
||||||
citation_files=result.metadata.get("citation_files") if result and result.metadata else None,
|
citation_files=(
|
||||||
|
metadata.get("citation_files")
|
||||||
|
if result and (metadata := getattr(result, "metadata", None))
|
||||||
|
else None
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _execute_knowledge_search_via_vector_store(
|
async def _execute_knowledge_search_via_vector_store(
|
||||||
|
|
@ -188,8 +204,9 @@ class ToolExecutor:
|
||||||
|
|
||||||
citation_files[file_id] = filename
|
citation_files[file_id] = filename
|
||||||
|
|
||||||
|
# Cast to proper InterleavedContent type (list invariance)
|
||||||
return ToolInvocationResult(
|
return ToolInvocationResult(
|
||||||
content=content_items,
|
content=content_items, # type: ignore[arg-type]
|
||||||
metadata={
|
metadata={
|
||||||
"document_ids": [r.file_id for r in search_results],
|
"document_ids": [r.file_id for r in search_results],
|
||||||
"chunks": [r.content[0].text if r.content else "" for r in search_results],
|
"chunks": [r.content[0].text if r.content else "" for r in search_results],
|
||||||
|
|
@ -209,51 +226,50 @@ class ToolExecutor:
|
||||||
) -> AsyncIterator[ToolExecutionResult]:
|
) -> AsyncIterator[ToolExecutionResult]:
|
||||||
"""Emit progress events for tool execution start."""
|
"""Emit progress events for tool execution start."""
|
||||||
# Emit in_progress event based on tool type (only for tools with specific streaming events)
|
# Emit in_progress event based on tool type (only for tools with specific streaming events)
|
||||||
progress_event = None
|
|
||||||
if mcp_tool_to_server and function_name in mcp_tool_to_server:
|
if mcp_tool_to_server and function_name in mcp_tool_to_server:
|
||||||
sequence_number += 1
|
sequence_number += 1
|
||||||
progress_event = OpenAIResponseObjectStreamResponseMcpCallInProgress(
|
mcp_progress_event = OpenAIResponseObjectStreamResponseMcpCallInProgress(
|
||||||
item_id=item_id,
|
item_id=item_id,
|
||||||
output_index=output_index,
|
output_index=output_index,
|
||||||
sequence_number=sequence_number,
|
sequence_number=sequence_number,
|
||||||
)
|
)
|
||||||
|
yield ToolExecutionResult(stream_event=mcp_progress_event, sequence_number=sequence_number)
|
||||||
elif function_name == "web_search":
|
elif function_name == "web_search":
|
||||||
sequence_number += 1
|
sequence_number += 1
|
||||||
progress_event = OpenAIResponseObjectStreamResponseWebSearchCallInProgress(
|
web_progress_event = OpenAIResponseObjectStreamResponseWebSearchCallInProgress(
|
||||||
item_id=item_id,
|
item_id=item_id,
|
||||||
output_index=output_index,
|
output_index=output_index,
|
||||||
sequence_number=sequence_number,
|
sequence_number=sequence_number,
|
||||||
)
|
)
|
||||||
|
yield ToolExecutionResult(stream_event=web_progress_event, sequence_number=sequence_number)
|
||||||
elif function_name == "knowledge_search":
|
elif function_name == "knowledge_search":
|
||||||
sequence_number += 1
|
sequence_number += 1
|
||||||
progress_event = OpenAIResponseObjectStreamResponseFileSearchCallInProgress(
|
file_progress_event = OpenAIResponseObjectStreamResponseFileSearchCallInProgress(
|
||||||
item_id=item_id,
|
item_id=item_id,
|
||||||
output_index=output_index,
|
output_index=output_index,
|
||||||
sequence_number=sequence_number,
|
sequence_number=sequence_number,
|
||||||
)
|
)
|
||||||
|
yield ToolExecutionResult(stream_event=file_progress_event, sequence_number=sequence_number)
|
||||||
if progress_event:
|
|
||||||
yield ToolExecutionResult(stream_event=progress_event, sequence_number=sequence_number)
|
|
||||||
|
|
||||||
# For web search, emit searching event
|
# For web search, emit searching event
|
||||||
if function_name == "web_search":
|
if function_name == "web_search":
|
||||||
sequence_number += 1
|
sequence_number += 1
|
||||||
searching_event = OpenAIResponseObjectStreamResponseWebSearchCallSearching(
|
web_searching_event = OpenAIResponseObjectStreamResponseWebSearchCallSearching(
|
||||||
item_id=item_id,
|
item_id=item_id,
|
||||||
output_index=output_index,
|
output_index=output_index,
|
||||||
sequence_number=sequence_number,
|
sequence_number=sequence_number,
|
||||||
)
|
)
|
||||||
yield ToolExecutionResult(stream_event=searching_event, sequence_number=sequence_number)
|
yield ToolExecutionResult(stream_event=web_searching_event, sequence_number=sequence_number)
|
||||||
|
|
||||||
# For file search, emit searching event
|
# For file search, emit searching event
|
||||||
if function_name == "knowledge_search":
|
if function_name == "knowledge_search":
|
||||||
sequence_number += 1
|
sequence_number += 1
|
||||||
searching_event = OpenAIResponseObjectStreamResponseFileSearchCallSearching(
|
file_searching_event = OpenAIResponseObjectStreamResponseFileSearchCallSearching(
|
||||||
item_id=item_id,
|
item_id=item_id,
|
||||||
output_index=output_index,
|
output_index=output_index,
|
||||||
sequence_number=sequence_number,
|
sequence_number=sequence_number,
|
||||||
)
|
)
|
||||||
yield ToolExecutionResult(stream_event=searching_event, sequence_number=sequence_number)
|
yield ToolExecutionResult(stream_event=file_searching_event, sequence_number=sequence_number)
|
||||||
|
|
||||||
async def _execute_tool(
|
async def _execute_tool(
|
||||||
self,
|
self,
|
||||||
|
|
@ -261,7 +277,7 @@ class ToolExecutor:
|
||||||
tool_kwargs: dict,
|
tool_kwargs: dict,
|
||||||
ctx: ChatCompletionContext,
|
ctx: ChatCompletionContext,
|
||||||
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
|
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
|
||||||
) -> tuple[Exception | None, any]:
|
) -> tuple[Exception | None, Any]:
|
||||||
"""Execute the tool and return error exception and result."""
|
"""Execute the tool and return error exception and result."""
|
||||||
error_exc = None
|
error_exc = None
|
||||||
result = None
|
result = None
|
||||||
|
|
@ -284,9 +300,13 @@ class ToolExecutor:
|
||||||
kwargs=tool_kwargs,
|
kwargs=tool_kwargs,
|
||||||
)
|
)
|
||||||
elif function_name == "knowledge_search":
|
elif function_name == "knowledge_search":
|
||||||
response_file_search_tool = next(
|
response_file_search_tool = (
|
||||||
(t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)),
|
next(
|
||||||
None,
|
(t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
if ctx.response_tools
|
||||||
|
else None
|
||||||
)
|
)
|
||||||
if response_file_search_tool:
|
if response_file_search_tool:
|
||||||
# Use vector_stores.search API instead of knowledge_search tool
|
# Use vector_stores.search API instead of knowledge_search tool
|
||||||
|
|
@ -322,35 +342,34 @@ class ToolExecutor:
|
||||||
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
|
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
|
||||||
) -> AsyncIterator[ToolExecutionResult]:
|
) -> AsyncIterator[ToolExecutionResult]:
|
||||||
"""Emit completion or failure events for tool execution."""
|
"""Emit completion or failure events for tool execution."""
|
||||||
completion_event = None
|
|
||||||
|
|
||||||
if mcp_tool_to_server and function_name in mcp_tool_to_server:
|
if mcp_tool_to_server and function_name in mcp_tool_to_server:
|
||||||
sequence_number += 1
|
sequence_number += 1
|
||||||
if has_error:
|
if has_error:
|
||||||
completion_event = OpenAIResponseObjectStreamResponseMcpCallFailed(
|
mcp_failed_event = OpenAIResponseObjectStreamResponseMcpCallFailed(
|
||||||
sequence_number=sequence_number,
|
sequence_number=sequence_number,
|
||||||
)
|
)
|
||||||
|
yield ToolExecutionResult(stream_event=mcp_failed_event, sequence_number=sequence_number)
|
||||||
else:
|
else:
|
||||||
completion_event = OpenAIResponseObjectStreamResponseMcpCallCompleted(
|
mcp_completed_event = OpenAIResponseObjectStreamResponseMcpCallCompleted(
|
||||||
sequence_number=sequence_number,
|
sequence_number=sequence_number,
|
||||||
)
|
)
|
||||||
|
yield ToolExecutionResult(stream_event=mcp_completed_event, sequence_number=sequence_number)
|
||||||
elif function_name == "web_search":
|
elif function_name == "web_search":
|
||||||
sequence_number += 1
|
sequence_number += 1
|
||||||
completion_event = OpenAIResponseObjectStreamResponseWebSearchCallCompleted(
|
web_completion_event = OpenAIResponseObjectStreamResponseWebSearchCallCompleted(
|
||||||
item_id=item_id,
|
item_id=item_id,
|
||||||
output_index=output_index,
|
output_index=output_index,
|
||||||
sequence_number=sequence_number,
|
sequence_number=sequence_number,
|
||||||
)
|
)
|
||||||
|
yield ToolExecutionResult(stream_event=web_completion_event, sequence_number=sequence_number)
|
||||||
elif function_name == "knowledge_search":
|
elif function_name == "knowledge_search":
|
||||||
sequence_number += 1
|
sequence_number += 1
|
||||||
completion_event = OpenAIResponseObjectStreamResponseFileSearchCallCompleted(
|
file_completion_event = OpenAIResponseObjectStreamResponseFileSearchCallCompleted(
|
||||||
item_id=item_id,
|
item_id=item_id,
|
||||||
output_index=output_index,
|
output_index=output_index,
|
||||||
sequence_number=sequence_number,
|
sequence_number=sequence_number,
|
||||||
)
|
)
|
||||||
|
yield ToolExecutionResult(stream_event=file_completion_event, sequence_number=sequence_number)
|
||||||
if completion_event:
|
|
||||||
yield ToolExecutionResult(stream_event=completion_event, sequence_number=sequence_number)
|
|
||||||
|
|
||||||
async def _build_result_messages(
|
async def _build_result_messages(
|
||||||
self,
|
self,
|
||||||
|
|
@ -360,21 +379,18 @@ class ToolExecutor:
|
||||||
tool_kwargs: dict,
|
tool_kwargs: dict,
|
||||||
ctx: ChatCompletionContext,
|
ctx: ChatCompletionContext,
|
||||||
error_exc: Exception | None,
|
error_exc: Exception | None,
|
||||||
result: any,
|
result: Any,
|
||||||
has_error: bool,
|
has_error: bool,
|
||||||
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
|
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
|
||||||
) -> tuple[any, any]:
|
) -> tuple[Any, Any]:
|
||||||
"""Build output and input messages from tool execution results."""
|
"""Build output and input messages from tool execution results."""
|
||||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||||
interleaved_content_as_str,
|
interleaved_content_as_str,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Build output message
|
# Build output message
|
||||||
|
message: Any
|
||||||
if mcp_tool_to_server and function.name in mcp_tool_to_server:
|
if mcp_tool_to_server and function.name in mcp_tool_to_server:
|
||||||
from llama_stack.apis.agents.openai_responses import (
|
|
||||||
OpenAIResponseOutputMessageMCPCall,
|
|
||||||
)
|
|
||||||
|
|
||||||
message = OpenAIResponseOutputMessageMCPCall(
|
message = OpenAIResponseOutputMessageMCPCall(
|
||||||
id=item_id,
|
id=item_id,
|
||||||
arguments=function.arguments,
|
arguments=function.arguments,
|
||||||
|
|
@ -383,10 +399,14 @@ class ToolExecutor:
|
||||||
)
|
)
|
||||||
if error_exc:
|
if error_exc:
|
||||||
message.error = str(error_exc)
|
message.error = str(error_exc)
|
||||||
elif (result and result.error_code and result.error_code > 0) or (result and result.error_message):
|
elif (
|
||||||
message.error = f"Error (code {result.error_code}): {result.error_message}"
|
result and (error_code := getattr(result, "error_code", None)) and error_code > 0
|
||||||
elif result and result.content:
|
) or (result and (error_message := getattr(result, "error_message", None))):
|
||||||
message.output = interleaved_content_as_str(result.content)
|
ec = getattr(result, "error_code", "unknown")
|
||||||
|
em = getattr(result, "error_message", "")
|
||||||
|
message.error = f"Error (code {ec}): {em}"
|
||||||
|
elif result and (content := getattr(result, "content", None)):
|
||||||
|
message.output = interleaved_content_as_str(content)
|
||||||
else:
|
else:
|
||||||
if function.name == "web_search":
|
if function.name == "web_search":
|
||||||
message = OpenAIResponseOutputMessageWebSearchToolCall(
|
message = OpenAIResponseOutputMessageWebSearchToolCall(
|
||||||
|
|
@ -401,17 +421,17 @@ class ToolExecutor:
|
||||||
queries=[tool_kwargs.get("query", "")],
|
queries=[tool_kwargs.get("query", "")],
|
||||||
status="completed",
|
status="completed",
|
||||||
)
|
)
|
||||||
if result and "document_ids" in result.metadata:
|
if result and (metadata := getattr(result, "metadata", None)) and "document_ids" in metadata:
|
||||||
message.results = []
|
message.results = []
|
||||||
for i, doc_id in enumerate(result.metadata["document_ids"]):
|
for i, doc_id in enumerate(metadata["document_ids"]):
|
||||||
text = result.metadata["chunks"][i] if "chunks" in result.metadata else None
|
text = metadata["chunks"][i] if "chunks" in metadata else None
|
||||||
score = result.metadata["scores"][i] if "scores" in result.metadata else None
|
score = metadata["scores"][i] if "scores" in metadata else None
|
||||||
message.results.append(
|
message.results.append(
|
||||||
OpenAIResponseOutputMessageFileSearchToolCallResults(
|
OpenAIResponseOutputMessageFileSearchToolCallResults(
|
||||||
file_id=doc_id,
|
file_id=doc_id,
|
||||||
filename=doc_id,
|
filename=doc_id,
|
||||||
text=text,
|
text=text if text is not None else "",
|
||||||
score=score,
|
score=score if score is not None else 0.0,
|
||||||
attributes={},
|
attributes={},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
@ -421,27 +441,31 @@ class ToolExecutor:
|
||||||
raise ValueError(f"Unknown tool {function.name} called")
|
raise ValueError(f"Unknown tool {function.name} called")
|
||||||
|
|
||||||
# Build input message
|
# Build input message
|
||||||
input_message = None
|
input_message: OpenAIToolMessageParam | None = None
|
||||||
if result and result.content:
|
if result and (result_content := getattr(result, "content", None)):
|
||||||
if isinstance(result.content, str):
|
if isinstance(result_content, str):
|
||||||
content = result.content
|
msg_content: str | list[Any] = result_content
|
||||||
elif isinstance(result.content, list):
|
elif isinstance(result_content, list):
|
||||||
content = []
|
content_list: list[Any] = []
|
||||||
for item in result.content:
|
for item in result_content:
|
||||||
|
part: Any
|
||||||
if isinstance(item, TextContentItem):
|
if isinstance(item, TextContentItem):
|
||||||
part = OpenAIChatCompletionContentPartTextParam(text=item.text)
|
part = OpenAIChatCompletionContentPartTextParam(text=item.text)
|
||||||
elif isinstance(item, ImageContentItem):
|
elif isinstance(item, ImageContentItem):
|
||||||
if item.image.data:
|
if item.image.data:
|
||||||
url = f"data:image;base64,{item.image.data}"
|
url_value = f"data:image;base64,{item.image.data}"
|
||||||
else:
|
else:
|
||||||
url = item.image.url
|
url_value = str(item.image.url) if item.image.url else ""
|
||||||
part = OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url=url))
|
part = OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url=url_value))
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown result content type: {type(item)}")
|
raise ValueError(f"Unknown result content type: {type(item)}")
|
||||||
content.append(part)
|
content_list.append(part)
|
||||||
|
msg_content = content_list
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown result content type: {type(result.content)}")
|
raise ValueError(f"Unknown result content type: {type(result_content)}")
|
||||||
input_message = OpenAIToolMessageParam(content=content, tool_call_id=tool_call_id)
|
# OpenAIToolMessageParam accepts str | list[TextParam] but we may have images
|
||||||
|
# This is runtime-safe as the API accepts it, but mypy complains
|
||||||
|
input_message = OpenAIToolMessageParam(content=msg_content, tool_call_id=tool_call_id) # type: ignore[arg-type]
|
||||||
else:
|
else:
|
||||||
text = str(error_exc) if error_exc else "Tool execution failed"
|
text = str(error_exc) if error_exc else "Tool execution failed"
|
||||||
input_message = OpenAIToolMessageParam(content=text, tool_call_id=tool_call_id)
|
input_message = OpenAIToolMessageParam(content=text, tool_call_id=tool_call_id)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue