mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-04 02:03:44 +00:00
minor linting change
This commit is contained in:
parent
1db14ca4a3
commit
59793ac63b
1 changed files with 88 additions and 38 deletions
|
|
@ -26,10 +26,7 @@ from llama_stack.apis.agents.openai_responses import (
|
||||||
OpenAIResponseOutputMessageMCPCall,
|
OpenAIResponseOutputMessageMCPCall,
|
||||||
OpenAIResponseOutputMessageWebSearchToolCall,
|
OpenAIResponseOutputMessageWebSearchToolCall,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.common.content_types import (
|
from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
|
||||||
ImageContentItem,
|
|
||||||
TextContentItem,
|
|
||||||
)
|
|
||||||
from llama_stack.apis.inference import (
|
from llama_stack.apis.inference import (
|
||||||
OpenAIChatCompletionContentPartImageParam,
|
OpenAIChatCompletionContentPartImageParam,
|
||||||
OpenAIChatCompletionContentPartTextParam,
|
OpenAIChatCompletionContentPartTextParam,
|
||||||
|
|
@ -69,7 +66,9 @@ class ToolExecutor:
|
||||||
) -> AsyncIterator[ToolExecutionResult]:
|
) -> AsyncIterator[ToolExecutionResult]:
|
||||||
tool_call_id = tool_call.id
|
tool_call_id = tool_call.id
|
||||||
function = tool_call.function
|
function = tool_call.function
|
||||||
tool_kwargs = json.loads(function.arguments) if function and function.arguments else {}
|
tool_kwargs = (
|
||||||
|
json.loads(function.arguments) if function and function.arguments else {}
|
||||||
|
)
|
||||||
|
|
||||||
if not function or not tool_call_id or not function.name:
|
if not function or not tool_call_id or not function.name:
|
||||||
yield ToolExecutionResult(sequence_number=sequence_number)
|
yield ToolExecutionResult(sequence_number=sequence_number)
|
||||||
|
|
@ -77,13 +76,20 @@ class ToolExecutor:
|
||||||
|
|
||||||
# Emit progress events for tool execution start
|
# Emit progress events for tool execution start
|
||||||
async for event_result in self._emit_progress_events(
|
async for event_result in self._emit_progress_events(
|
||||||
function.name, ctx, sequence_number, output_index, item_id, mcp_tool_to_server
|
function.name,
|
||||||
|
ctx,
|
||||||
|
sequence_number,
|
||||||
|
output_index,
|
||||||
|
item_id,
|
||||||
|
mcp_tool_to_server,
|
||||||
):
|
):
|
||||||
sequence_number = event_result.sequence_number
|
sequence_number = event_result.sequence_number
|
||||||
yield event_result
|
yield event_result
|
||||||
|
|
||||||
# Execute the actual tool call
|
# Execute the actual tool call
|
||||||
error_exc, result = await self._execute_tool(function.name, tool_kwargs, ctx, mcp_tool_to_server)
|
error_exc, result = await self._execute_tool(
|
||||||
|
function.name, tool_kwargs, ctx, mcp_tool_to_server
|
||||||
|
)
|
||||||
|
|
||||||
# Emit completion events for tool execution
|
# Emit completion events for tool execution
|
||||||
has_error = bool(
|
has_error = bool(
|
||||||
|
|
@ -91,20 +97,23 @@ class ToolExecutor:
|
||||||
or (
|
or (
|
||||||
result
|
result
|
||||||
and (
|
and (
|
||||||
((error_code := getattr(result, "error_code", None)) and error_code > 0)
|
(
|
||||||
|
(error_code := getattr(result, "error_code", None))
|
||||||
|
and error_code > 0
|
||||||
|
)
|
||||||
or getattr(result, "error_message", None)
|
or getattr(result, "error_message", None)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
async for event_result in self._emit_completion_events(
|
async for event_result in self._emit_completion_events(
|
||||||
function.name, ctx, sequence_number, output_index, item_id, has_error, mcp_tool_to_server,
|
function.name, ctx, sequence_number, output_index, item_id, has_error, mcp_tool_to_server
|
||||||
):
|
):
|
||||||
sequence_number = event_result.sequence_number
|
sequence_number = event_result.sequence_number
|
||||||
yield event_result
|
yield event_result
|
||||||
|
|
||||||
# Build result messages from tool execution
|
# Build result messages from tool execution
|
||||||
output_message, input_message = await self._build_result_messages(
|
output_message, input_message = await self._build_result_messages(
|
||||||
function, tool_call_id, item_id, tool_kwargs, ctx, error_exc, result, has_error, mcp_tool_to_server,
|
function, tool_call_id, item_id, tool_kwargs, ctx, error_exc, result, has_error, mcp_tool_to_server
|
||||||
)
|
)
|
||||||
|
|
||||||
# Yield the final result
|
# Yield the final result
|
||||||
|
|
@ -113,7 +122,9 @@ class ToolExecutor:
|
||||||
final_output_message=output_message,
|
final_output_message=output_message,
|
||||||
final_input_message=input_message,
|
final_input_message=input_message,
|
||||||
citation_files=(
|
citation_files=(
|
||||||
metadata.get("citation_files") if result and (metadata := getattr(result, "metadata", None)) else None
|
metadata.get("citation_files")
|
||||||
|
if result and (metadata := getattr(result, "metadata", None))
|
||||||
|
else None
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -142,7 +153,10 @@ class ToolExecutor:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# Run all searches in parallel using gather
|
# Run all searches in parallel using gather
|
||||||
search_tasks = [search_single_store(vid) for vid in response_file_search_tool.vector_store_ids]
|
search_tasks = [
|
||||||
|
search_single_store(vid)
|
||||||
|
for vid in response_file_search_tool.vector_store_ids
|
||||||
|
]
|
||||||
all_results = await asyncio.gather(*search_tasks)
|
all_results = await asyncio.gather(*search_tasks)
|
||||||
|
|
||||||
# Flatten results
|
# Flatten results
|
||||||
|
|
@ -161,17 +175,23 @@ class ToolExecutor:
|
||||||
chunk_text = result_item.content[0].text if result_item.content else ""
|
chunk_text = result_item.content[0].text if result_item.content else ""
|
||||||
# Get file_id from attributes if result_item.file_id is empty
|
# Get file_id from attributes if result_item.file_id is empty
|
||||||
file_id = result_item.file_id or (
|
file_id = result_item.file_id or (
|
||||||
result_item.attributes.get("document_id") if result_item.attributes else None
|
result_item.attributes.get("document_id")
|
||||||
|
if result_item.attributes
|
||||||
|
else None
|
||||||
)
|
)
|
||||||
metadata_text = f"document_id: {file_id}, score: {result_item.score}"
|
metadata_text = f"document_id: {file_id}, score: {result_item.score}"
|
||||||
if result_item.attributes:
|
if result_item.attributes:
|
||||||
metadata_text += f", attributes: {result_item.attributes}"
|
metadata_text += f", attributes: {result_item.attributes}"
|
||||||
|
|
||||||
text_content = f"[{i + 1}] {metadata_text} (cite as <|{file_id}|>)\n{chunk_text}\n"
|
text_content = (
|
||||||
|
f"[{i + 1}] {metadata_text} (cite as <|{file_id}|>)\n{chunk_text}\n"
|
||||||
|
)
|
||||||
content_items.append(TextContentItem(text=text_content))
|
content_items.append(TextContentItem(text=text_content))
|
||||||
unique_files.add(file_id)
|
unique_files.add(file_id)
|
||||||
|
|
||||||
content_items.append(TextContentItem(text="END of knowledge_search tool results.\n"))
|
content_items.append(
|
||||||
|
TextContentItem(text="END of knowledge_search tool results.\n")
|
||||||
|
)
|
||||||
|
|
||||||
citation_instruction = ""
|
citation_instruction = ""
|
||||||
if unique_files:
|
if unique_files:
|
||||||
|
|
@ -206,7 +226,9 @@ class ToolExecutor:
|
||||||
content=content_items, # type: ignore[arg-type]
|
content=content_items, # type: ignore[arg-type]
|
||||||
metadata={
|
metadata={
|
||||||
"document_ids": [r.file_id for r in search_results],
|
"document_ids": [r.file_id for r in search_results],
|
||||||
"chunks": [r.content[0].text if r.content else "" for r in search_results],
|
"chunks": [
|
||||||
|
r.content[0].text if r.content else "" for r in search_results
|
||||||
|
],
|
||||||
"scores": [r.score for r in search_results],
|
"scores": [r.score for r in search_results],
|
||||||
"citation_files": citation_files,
|
"citation_files": citation_files,
|
||||||
},
|
},
|
||||||
|
|
@ -317,7 +339,11 @@ class ToolExecutor:
|
||||||
elif function_name == "knowledge_search":
|
elif function_name == "knowledge_search":
|
||||||
response_file_search_tool = (
|
response_file_search_tool = (
|
||||||
next(
|
next(
|
||||||
(t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)),
|
(
|
||||||
|
t
|
||||||
|
for t in ctx.response_tools
|
||||||
|
if isinstance(t, OpenAIResponseInputToolFileSearch)
|
||||||
|
),
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
if ctx.response_tools
|
if ctx.response_tools
|
||||||
|
|
@ -363,28 +389,42 @@ class ToolExecutor:
|
||||||
mcp_failed_event = OpenAIResponseObjectStreamResponseMcpCallFailed(
|
mcp_failed_event = OpenAIResponseObjectStreamResponseMcpCallFailed(
|
||||||
sequence_number=sequence_number,
|
sequence_number=sequence_number,
|
||||||
)
|
)
|
||||||
yield ToolExecutionResult(stream_event=mcp_failed_event, sequence_number=sequence_number)
|
yield ToolExecutionResult(
|
||||||
|
stream_event=mcp_failed_event, sequence_number=sequence_number
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
mcp_completed_event = OpenAIResponseObjectStreamResponseMcpCallCompleted(
|
mcp_completed_event = (
|
||||||
|
OpenAIResponseObjectStreamResponseMcpCallCompleted(
|
||||||
sequence_number=sequence_number,
|
sequence_number=sequence_number,
|
||||||
)
|
)
|
||||||
yield ToolExecutionResult(stream_event=mcp_completed_event, sequence_number=sequence_number)
|
)
|
||||||
|
yield ToolExecutionResult(
|
||||||
|
stream_event=mcp_completed_event, sequence_number=sequence_number
|
||||||
|
)
|
||||||
elif function_name == "web_search":
|
elif function_name == "web_search":
|
||||||
sequence_number += 1
|
sequence_number += 1
|
||||||
web_completion_event = OpenAIResponseObjectStreamResponseWebSearchCallCompleted(
|
web_completion_event = (
|
||||||
|
OpenAIResponseObjectStreamResponseWebSearchCallCompleted(
|
||||||
item_id=item_id,
|
item_id=item_id,
|
||||||
output_index=output_index,
|
output_index=output_index,
|
||||||
sequence_number=sequence_number,
|
sequence_number=sequence_number,
|
||||||
)
|
)
|
||||||
yield ToolExecutionResult(stream_event=web_completion_event, sequence_number=sequence_number)
|
)
|
||||||
|
yield ToolExecutionResult(
|
||||||
|
stream_event=web_completion_event, sequence_number=sequence_number
|
||||||
|
)
|
||||||
elif function_name == "knowledge_search":
|
elif function_name == "knowledge_search":
|
||||||
sequence_number += 1
|
sequence_number += 1
|
||||||
file_completion_event = OpenAIResponseObjectStreamResponseFileSearchCallCompleted(
|
file_completion_event = (
|
||||||
|
OpenAIResponseObjectStreamResponseFileSearchCallCompleted(
|
||||||
item_id=item_id,
|
item_id=item_id,
|
||||||
output_index=output_index,
|
output_index=output_index,
|
||||||
sequence_number=sequence_number,
|
sequence_number=sequence_number,
|
||||||
)
|
)
|
||||||
yield ToolExecutionResult(stream_event=file_completion_event, sequence_number=sequence_number)
|
)
|
||||||
|
yield ToolExecutionResult(
|
||||||
|
stream_event=file_completion_event, sequence_number=sequence_number
|
||||||
|
)
|
||||||
|
|
||||||
async def _build_result_messages(
|
async def _build_result_messages(
|
||||||
self,
|
self,
|
||||||
|
|
@ -414,9 +454,11 @@ class ToolExecutor:
|
||||||
)
|
)
|
||||||
if error_exc:
|
if error_exc:
|
||||||
message.error = str(error_exc)
|
message.error = str(error_exc)
|
||||||
elif (result and (error_code := getattr(result, "error_code", None)) and error_code > 0) or (
|
elif (
|
||||||
result and getattr(result, "error_message", None)
|
result
|
||||||
):
|
and (error_code := getattr(result, "error_code", None))
|
||||||
|
and error_code > 0
|
||||||
|
) or (result and getattr(result, "error_message", None)):
|
||||||
ec = getattr(result, "error_code", "unknown")
|
ec = getattr(result, "error_code", "unknown")
|
||||||
em = getattr(result, "error_message", "")
|
em = getattr(result, "error_message", "")
|
||||||
message.error = f"Error (code {ec}): {em}"
|
message.error = f"Error (code {ec}): {em}"
|
||||||
|
|
@ -436,7 +478,11 @@ class ToolExecutor:
|
||||||
queries=[tool_kwargs.get("query", "")],
|
queries=[tool_kwargs.get("query", "")],
|
||||||
status="completed",
|
status="completed",
|
||||||
)
|
)
|
||||||
if result and (metadata := getattr(result, "metadata", None)) and "document_ids" in metadata:
|
if (
|
||||||
|
result
|
||||||
|
and (metadata := getattr(result, "metadata", None))
|
||||||
|
and "document_ids" in metadata
|
||||||
|
):
|
||||||
message.results = []
|
message.results = []
|
||||||
for i, doc_id in enumerate(metadata["document_ids"]):
|
for i, doc_id in enumerate(metadata["document_ids"]):
|
||||||
text = metadata["chunks"][i] if "chunks" in metadata else None
|
text = metadata["chunks"][i] if "chunks" in metadata else None
|
||||||
|
|
@ -472,7 +518,9 @@ class ToolExecutor:
|
||||||
url_value = f"data:image;base64,{item.image.data}"
|
url_value = f"data:image;base64,{item.image.data}"
|
||||||
else:
|
else:
|
||||||
url_value = str(item.image.url) if item.image.url else ""
|
url_value = str(item.image.url) if item.image.url else ""
|
||||||
part = OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url=url_value))
|
part = OpenAIChatCompletionContentPartImageParam(
|
||||||
|
image_url=OpenAIImageURL(url=url_value)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown result content type: {type(item)}")
|
raise ValueError(f"Unknown result content type: {type(item)}")
|
||||||
content_list.append(part)
|
content_list.append(part)
|
||||||
|
|
@ -484,6 +532,8 @@ class ToolExecutor:
|
||||||
input_message = OpenAIToolMessageParam(content=msg_content, tool_call_id=tool_call_id) # type: ignore[arg-type]
|
input_message = OpenAIToolMessageParam(content=msg_content, tool_call_id=tool_call_id) # type: ignore[arg-type]
|
||||||
else:
|
else:
|
||||||
text = str(error_exc) if error_exc else "Tool execution failed"
|
text = str(error_exc) if error_exc else "Tool execution failed"
|
||||||
input_message = OpenAIToolMessageParam(content=text, tool_call_id=tool_call_id)
|
input_message = OpenAIToolMessageParam(
|
||||||
|
content=text, tool_call_id=tool_call_id
|
||||||
|
)
|
||||||
|
|
||||||
return message, input_message
|
return message, input_message
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue