minor linting change

Omar Abdelwahab 2025-11-04 12:51:19 -08:00
parent 1db14ca4a3
commit 59793ac63b


@@ -26,10 +26,7 @@ from llama_stack.apis.agents.openai_responses import (
     OpenAIResponseOutputMessageMCPCall,
     OpenAIResponseOutputMessageWebSearchToolCall,
 )
-from llama_stack.apis.common.content_types import (
-    ImageContentItem,
-    TextContentItem,
-)
+from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
 from llama_stack.apis.inference import (
     OpenAIChatCompletionContentPartImageParam,
     OpenAIChatCompletionContentPartTextParam,
@@ -69,7 +66,9 @@ class ToolExecutor:
     ) -> AsyncIterator[ToolExecutionResult]:
         tool_call_id = tool_call.id
         function = tool_call.function
-        tool_kwargs = json.loads(function.arguments) if function and function.arguments else {}
+        tool_kwargs = (
+            json.loads(function.arguments) if function and function.arguments else {}
+        )
 
         if not function or not tool_call_id or not function.name:
             yield ToolExecutionResult(sequence_number=sequence_number)
@@ -77,13 +76,20 @@ class ToolExecutor:
 
         # Emit progress events for tool execution start
         async for event_result in self._emit_progress_events(
-            function.name, ctx, sequence_number, output_index, item_id, mcp_tool_to_server
+            function.name,
+            ctx,
+            sequence_number,
+            output_index,
+            item_id,
+            mcp_tool_to_server,
         ):
             sequence_number = event_result.sequence_number
             yield event_result
 
         # Execute the actual tool call
-        error_exc, result = await self._execute_tool(function.name, tool_kwargs, ctx, mcp_tool_to_server)
+        error_exc, result = await self._execute_tool(
+            function.name, tool_kwargs, ctx, mcp_tool_to_server
+        )
 
         # Emit completion events for tool execution
         has_error = bool(
@@ -91,20 +97,23 @@ class ToolExecutor:
             or (
                 result
                 and (
-                    ((error_code := getattr(result, "error_code", None)) and error_code > 0)
+                    (
+                        (error_code := getattr(result, "error_code", None))
+                        and error_code > 0
+                    )
                     or getattr(result, "error_message", None)
                 )
             )
         )
         async for event_result in self._emit_completion_events(
-            function.name, ctx, sequence_number, output_index, item_id, has_error, mcp_tool_to_server,
+            function.name, ctx, sequence_number, output_index, item_id, has_error, mcp_tool_to_server
         ):
             sequence_number = event_result.sequence_number
             yield event_result
 
         # Build result messages from tool execution
         output_message, input_message = await self._build_result_messages(
-            function, tool_call_id, item_id, tool_kwargs, ctx, error_exc, result, has_error, mcp_tool_to_server,
+            function, tool_call_id, item_id, tool_kwargs, ctx, error_exc, result, has_error, mcp_tool_to_server
         )
 
         # Yield the final result
@@ -113,7 +122,9 @@ class ToolExecutor:
             final_output_message=output_message,
             final_input_message=input_message,
             citation_files=(
-                metadata.get("citation_files") if result and (metadata := getattr(result, "metadata", None)) else None
+                metadata.get("citation_files")
+                if result and (metadata := getattr(result, "metadata", None))
+                else None
             ),
         )
 
@@ -142,7 +153,10 @@ class ToolExecutor:
                 return []
 
         # Run all searches in parallel using gather
-        search_tasks = [search_single_store(vid) for vid in response_file_search_tool.vector_store_ids]
+        search_tasks = [
+            search_single_store(vid)
+            for vid in response_file_search_tool.vector_store_ids
+        ]
         all_results = await asyncio.gather(*search_tasks)
 
         # Flatten results
@@ -161,17 +175,23 @@ class ToolExecutor:
             chunk_text = result_item.content[0].text if result_item.content else ""
             # Get file_id from attributes if result_item.file_id is empty
             file_id = result_item.file_id or (
-                result_item.attributes.get("document_id") if result_item.attributes else None
+                result_item.attributes.get("document_id")
+                if result_item.attributes
+                else None
             )
 
             metadata_text = f"document_id: {file_id}, score: {result_item.score}"
             if result_item.attributes:
                 metadata_text += f", attributes: {result_item.attributes}"
-            text_content = f"[{i + 1}] {metadata_text} (cite as <|{file_id}|>)\n{chunk_text}\n"
+            text_content = (
+                f"[{i + 1}] {metadata_text} (cite as <|{file_id}|>)\n{chunk_text}\n"
+            )
             content_items.append(TextContentItem(text=text_content))
             unique_files.add(file_id)
 
-        content_items.append(TextContentItem(text="END of knowledge_search tool results.\n"))
+        content_items.append(
+            TextContentItem(text="END of knowledge_search tool results.\n")
+        )
 
         citation_instruction = ""
         if unique_files:
@@ -206,7 +226,9 @@ class ToolExecutor:
             content=content_items,  # type: ignore[arg-type]
             metadata={
                 "document_ids": [r.file_id for r in search_results],
-                "chunks": [r.content[0].text if r.content else "" for r in search_results],
+                "chunks": [
+                    r.content[0].text if r.content else "" for r in search_results
+                ],
                 "scores": [r.score for r in search_results],
                 "citation_files": citation_files,
             },
@@ -317,7 +339,11 @@ class ToolExecutor:
         elif function_name == "knowledge_search":
             response_file_search_tool = (
                 next(
-                    (t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)),
+                    (
+                        t
+                        for t in ctx.response_tools
+                        if isinstance(t, OpenAIResponseInputToolFileSearch)
+                    ),
                     None,
                 )
                 if ctx.response_tools
@@ -363,28 +389,42 @@ class ToolExecutor:
                 mcp_failed_event = OpenAIResponseObjectStreamResponseMcpCallFailed(
                     sequence_number=sequence_number,
                 )
-                yield ToolExecutionResult(stream_event=mcp_failed_event, sequence_number=sequence_number)
+                yield ToolExecutionResult(
+                    stream_event=mcp_failed_event, sequence_number=sequence_number
+                )
             else:
-                mcp_completed_event = OpenAIResponseObjectStreamResponseMcpCallCompleted(
-                    sequence_number=sequence_number,
-                )
-                yield ToolExecutionResult(stream_event=mcp_completed_event, sequence_number=sequence_number)
+                mcp_completed_event = (
+                    OpenAIResponseObjectStreamResponseMcpCallCompleted(
+                        sequence_number=sequence_number,
+                    )
+                )
+                yield ToolExecutionResult(
+                    stream_event=mcp_completed_event, sequence_number=sequence_number
+                )
         elif function_name == "web_search":
             sequence_number += 1
-            web_completion_event = OpenAIResponseObjectStreamResponseWebSearchCallCompleted(
-                item_id=item_id,
-                output_index=output_index,
-                sequence_number=sequence_number,
-            )
-            yield ToolExecutionResult(stream_event=web_completion_event, sequence_number=sequence_number)
+            web_completion_event = (
+                OpenAIResponseObjectStreamResponseWebSearchCallCompleted(
+                    item_id=item_id,
+                    output_index=output_index,
+                    sequence_number=sequence_number,
+                )
+            )
+            yield ToolExecutionResult(
+                stream_event=web_completion_event, sequence_number=sequence_number
+            )
         elif function_name == "knowledge_search":
             sequence_number += 1
-            file_completion_event = OpenAIResponseObjectStreamResponseFileSearchCallCompleted(
-                item_id=item_id,
-                output_index=output_index,
-                sequence_number=sequence_number,
-            )
-            yield ToolExecutionResult(stream_event=file_completion_event, sequence_number=sequence_number)
+            file_completion_event = (
+                OpenAIResponseObjectStreamResponseFileSearchCallCompleted(
+                    item_id=item_id,
+                    output_index=output_index,
+                    sequence_number=sequence_number,
+                )
+            )
+            yield ToolExecutionResult(
+                stream_event=file_completion_event, sequence_number=sequence_number
+            )
 
     async def _build_result_messages(
         self,
@@ -414,9 +454,11 @@ class ToolExecutor:
             )
             if error_exc:
                 message.error = str(error_exc)
-            elif (result and (error_code := getattr(result, "error_code", None)) and error_code > 0) or (
-                result and getattr(result, "error_message", None)
-            ):
+            elif (
+                result
+                and (error_code := getattr(result, "error_code", None))
+                and error_code > 0
+            ) or (result and getattr(result, "error_message", None)):
                 ec = getattr(result, "error_code", "unknown")
                 em = getattr(result, "error_message", "")
                 message.error = f"Error (code {ec}): {em}"
@@ -436,7 +478,11 @@ class ToolExecutor:
                 queries=[tool_kwargs.get("query", "")],
                 status="completed",
             )
-            if result and (metadata := getattr(result, "metadata", None)) and "document_ids" in metadata:
+            if (
+                result
+                and (metadata := getattr(result, "metadata", None))
+                and "document_ids" in metadata
+            ):
                 message.results = []
                 for i, doc_id in enumerate(metadata["document_ids"]):
                     text = metadata["chunks"][i] if "chunks" in metadata else None
@@ -472,7 +518,9 @@ class ToolExecutor:
                         url_value = f"data:image;base64,{item.image.data}"
                     else:
                         url_value = str(item.image.url) if item.image.url else ""
-                    part = OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url=url_value))
+                    part = OpenAIChatCompletionContentPartImageParam(
+                        image_url=OpenAIImageURL(url=url_value)
+                    )
                 else:
                     raise ValueError(f"Unknown result content type: {type(item)}")
                 content_list.append(part)
@@ -484,6 +532,8 @@ class ToolExecutor:
             input_message = OpenAIToolMessageParam(content=msg_content, tool_call_id=tool_call_id)  # type: ignore[arg-type]
         else:
             text = str(error_exc) if error_exc else "Tool execution failed"
-            input_message = OpenAIToolMessageParam(content=text, tool_call_id=tool_call_id)
+            input_message = OpenAIToolMessageParam(
+                content=text, tool_call_id=tool_call_id
+            )
 
         return message, input_message