fix(mypy-cleanup): part-01 resolve meta reference agent type issues (126 errors) (#3945)

Error fixes in the Agents implementation (`meta-reference` provider):
adding proper type annotations and using type narrowing for optional
attributes. Essentially a bunch of
`if x and (x_foo := getattr(x, "foo", None))` checks instead of accessing
`x.foo` directly.

Part of ongoing mypy remediation effort.

---------

Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
Ashwin Bharambe 2025-10-29 07:54:30 -07:00 committed by GitHub
parent 22bf0d0471
commit ce31aa1704
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 204 additions and 143 deletions

View file

@ -11,6 +11,7 @@ import uuid
import warnings import warnings
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from datetime import UTC, datetime from datetime import UTC, datetime
from typing import Any
import httpx import httpx
@ -125,12 +126,12 @@ class ChatAgent(ShieldRunnerMixin):
) )
def turn_to_messages(self, turn: Turn) -> list[Message]: def turn_to_messages(self, turn: Turn) -> list[Message]:
messages = [] messages: list[Message] = []
# NOTE: if a toolcall response is in a step, we do not add it when processing the input messages # NOTE: if a toolcall response is in a step, we do not add it when processing the input messages
tool_call_ids = set() tool_call_ids = set()
for step in turn.steps: for step in turn.steps:
if step.step_type == StepType.tool_execution.value: if step.step_type == StepType.tool_execution.value and isinstance(step, ToolExecutionStep):
for response in step.tool_responses: for response in step.tool_responses:
tool_call_ids.add(response.call_id) tool_call_ids.add(response.call_id)
@ -149,9 +150,9 @@ class ChatAgent(ShieldRunnerMixin):
messages.append(msg) messages.append(msg)
for step in turn.steps: for step in turn.steps:
if step.step_type == StepType.inference.value: if step.step_type == StepType.inference.value and isinstance(step, InferenceStep):
messages.append(step.model_response) messages.append(step.model_response)
elif step.step_type == StepType.tool_execution.value: elif step.step_type == StepType.tool_execution.value and isinstance(step, ToolExecutionStep):
for response in step.tool_responses: for response in step.tool_responses:
messages.append( messages.append(
ToolResponseMessage( ToolResponseMessage(
@ -159,8 +160,8 @@ class ChatAgent(ShieldRunnerMixin):
content=response.content, content=response.content,
) )
) )
elif step.step_type == StepType.shield_call.value: elif step.step_type == StepType.shield_call.value and isinstance(step, ShieldCallStep):
if step.violation: if step.violation and step.violation.user_message:
# CompletionMessage itself in the ShieldResponse # CompletionMessage itself in the ShieldResponse
messages.append( messages.append(
CompletionMessage( CompletionMessage(
@ -174,7 +175,7 @@ class ChatAgent(ShieldRunnerMixin):
return await self.storage.create_session(name) return await self.storage.create_session(name)
async def get_messages_from_turns(self, turns: list[Turn]) -> list[Message]: async def get_messages_from_turns(self, turns: list[Turn]) -> list[Message]:
messages = [] messages: list[Message] = []
if self.agent_config.instructions != "": if self.agent_config.instructions != "":
messages.append(SystemMessage(content=self.agent_config.instructions)) messages.append(SystemMessage(content=self.agent_config.instructions))
@ -231,7 +232,9 @@ class ChatAgent(ShieldRunnerMixin):
steps = [] steps = []
messages = await self.get_messages_from_turns(turns) messages = await self.get_messages_from_turns(turns)
if is_resume: if is_resume:
assert isinstance(request, AgentTurnResumeRequest)
tool_response_messages = [ tool_response_messages = [
ToolResponseMessage(call_id=x.call_id, content=x.content) for x in request.tool_responses ToolResponseMessage(call_id=x.call_id, content=x.content) for x in request.tool_responses
] ]
@ -252,42 +255,52 @@ class ChatAgent(ShieldRunnerMixin):
in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step( in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step(
request.session_id, request.turn_id request.session_id, request.turn_id
) )
now = datetime.now(UTC).isoformat() now_dt = datetime.now(UTC)
tool_execution_step = ToolExecutionStep( tool_execution_step = ToolExecutionStep(
step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())), step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())),
turn_id=request.turn_id, turn_id=request.turn_id,
tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []), tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []),
tool_responses=request.tool_responses, tool_responses=request.tool_responses,
completed_at=now, completed_at=now_dt,
started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now), started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now_dt),
) )
steps.append(tool_execution_step) steps.append(tool_execution_step)
yield AgentTurnResponseStreamChunk( yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent( event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload( payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.tool_execution.value, step_type=StepType.tool_execution,
step_id=tool_execution_step.step_id, step_id=tool_execution_step.step_id,
step_details=tool_execution_step, step_details=tool_execution_step,
) )
) )
) )
input_messages = last_turn.input_messages # Cast needed due to list invariance - last_turn.input_messages is the right type
input_messages = last_turn.input_messages # type: ignore[assignment]
turn_id = request.turn_id actual_turn_id = request.turn_id
start_time = last_turn.started_at start_time = last_turn.started_at
else: else:
assert isinstance(request, AgentTurnCreateRequest)
messages.extend(request.messages) messages.extend(request.messages)
start_time = datetime.now(UTC).isoformat() start_time = datetime.now(UTC)
input_messages = request.messages # Cast needed due to list invariance - request.messages is the right type
input_messages = request.messages # type: ignore[assignment]
# Use the generated turn_id from beginning of function
actual_turn_id = turn_id if turn_id else str(uuid.uuid4())
output_message = None output_message = None
req_documents = request.documents if isinstance(request, AgentTurnCreateRequest) and not is_resume else None
req_sampling = (
self.agent_config.sampling_params if self.agent_config.sampling_params is not None else SamplingParams()
)
async for chunk in self.run( async for chunk in self.run(
session_id=request.session_id, session_id=request.session_id,
turn_id=turn_id, turn_id=actual_turn_id,
input_messages=messages, input_messages=messages,
sampling_params=self.agent_config.sampling_params, sampling_params=req_sampling,
stream=request.stream, stream=request.stream,
documents=request.documents if not is_resume else None, documents=req_documents,
): ):
if isinstance(chunk, CompletionMessage): if isinstance(chunk, CompletionMessage):
output_message = chunk output_message = chunk
@ -295,20 +308,23 @@ class ChatAgent(ShieldRunnerMixin):
assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}" assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
event = chunk.event event = chunk.event
if event.payload.event_type == AgentTurnResponseEventType.step_complete.value: if event.payload.event_type == AgentTurnResponseEventType.step_complete.value and hasattr(
steps.append(event.payload.step_details) event.payload, "step_details"
):
step_details = event.payload.step_details
steps.append(step_details)
yield chunk yield chunk
assert output_message is not None assert output_message is not None
turn = Turn( turn = Turn(
turn_id=turn_id, turn_id=actual_turn_id,
session_id=request.session_id, session_id=request.session_id,
input_messages=input_messages, input_messages=input_messages, # type: ignore[arg-type]
output_message=output_message, output_message=output_message,
started_at=start_time, started_at=start_time,
completed_at=datetime.now(UTC).isoformat(), completed_at=datetime.now(UTC),
steps=steps, steps=steps,
) )
await self.storage.add_turn_to_session(request.session_id, turn) await self.storage.add_turn_to_session(request.session_id, turn)
@ -345,7 +361,7 @@ class ChatAgent(ShieldRunnerMixin):
# return a "final value" for the `yield from` statement. we simulate that by yielding a # return a "final value" for the `yield from` statement. we simulate that by yielding a
# final boolean (to see whether an exception happened) and then explicitly testing for it. # final boolean (to see whether an exception happened) and then explicitly testing for it.
if len(self.input_shields) > 0: if self.input_shields:
async for res in self.run_multiple_shields_wrapper( async for res in self.run_multiple_shields_wrapper(
turn_id, input_messages, self.input_shields, "user-input" turn_id, input_messages, self.input_shields, "user-input"
): ):
@ -374,7 +390,7 @@ class ChatAgent(ShieldRunnerMixin):
# for output shields run on the full input and output combination # for output shields run on the full input and output combination
messages = input_messages + [final_response] messages = input_messages + [final_response]
if len(self.output_shields) > 0: if self.output_shields:
async for res in self.run_multiple_shields_wrapper( async for res in self.run_multiple_shields_wrapper(
turn_id, messages, self.output_shields, "assistant-output" turn_id, messages, self.output_shields, "assistant-output"
): ):
@ -402,12 +418,12 @@ class ChatAgent(ShieldRunnerMixin):
return return
step_id = str(uuid.uuid4()) step_id = str(uuid.uuid4())
shield_call_start_time = datetime.now(UTC).isoformat() shield_call_start_time = datetime.now(UTC)
try: try:
yield AgentTurnResponseStreamChunk( yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent( event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepStartPayload( payload=AgentTurnResponseStepStartPayload(
step_type=StepType.shield_call.value, step_type=StepType.shield_call,
step_id=step_id, step_id=step_id,
metadata=dict(touchpoint=touchpoint), metadata=dict(touchpoint=touchpoint),
) )
@ -419,14 +435,14 @@ class ChatAgent(ShieldRunnerMixin):
yield AgentTurnResponseStreamChunk( yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent( event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload( payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.shield_call.value, step_type=StepType.shield_call,
step_id=step_id, step_id=step_id,
step_details=ShieldCallStep( step_details=ShieldCallStep(
step_id=step_id, step_id=step_id,
turn_id=turn_id, turn_id=turn_id,
violation=e.violation, violation=e.violation,
started_at=shield_call_start_time, started_at=shield_call_start_time,
completed_at=datetime.now(UTC).isoformat(), completed_at=datetime.now(UTC),
), ),
) )
) )
@ -443,14 +459,14 @@ class ChatAgent(ShieldRunnerMixin):
yield AgentTurnResponseStreamChunk( yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent( event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload( payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.shield_call.value, step_type=StepType.shield_call,
step_id=step_id, step_id=step_id,
step_details=ShieldCallStep( step_details=ShieldCallStep(
step_id=step_id, step_id=step_id,
turn_id=turn_id, turn_id=turn_id,
violation=None, violation=None,
started_at=shield_call_start_time, started_at=shield_call_start_time,
completed_at=datetime.now(UTC).isoformat(), completed_at=datetime.now(UTC),
), ),
) )
) )
@ -496,21 +512,22 @@ class ChatAgent(ShieldRunnerMixin):
else: else:
self.tool_name_to_args[tool_name]["vector_store_ids"].append(session_info.vector_store_id) self.tool_name_to_args[tool_name]["vector_store_ids"].append(session_info.vector_store_id)
output_attachments = [] output_attachments: list[Attachment] = []
n_iter = await self.storage.get_num_infer_iters_in_turn(session_id, turn_id) or 0 n_iter = await self.storage.get_num_infer_iters_in_turn(session_id, turn_id) or 0
# Build a map of custom tools to their definitions for faster lookup # Build a map of custom tools to their definitions for faster lookup
client_tools = {} client_tools = {}
for tool in self.agent_config.client_tools: if self.agent_config.client_tools:
client_tools[tool.name] = tool for tool in self.agent_config.client_tools:
client_tools[tool.name] = tool
while True: while True:
step_id = str(uuid.uuid4()) step_id = str(uuid.uuid4())
inference_start_time = datetime.now(UTC).isoformat() inference_start_time = datetime.now(UTC)
yield AgentTurnResponseStreamChunk( yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent( event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepStartPayload( payload=AgentTurnResponseStepStartPayload(
step_type=StepType.inference.value, step_type=StepType.inference,
step_id=step_id, step_id=step_id,
) )
) )
@ -538,7 +555,7 @@ class ChatAgent(ShieldRunnerMixin):
else: else:
return value return value
def _add_type(openai_msg: dict) -> OpenAIMessageParam: def _add_type(openai_msg: Any) -> OpenAIMessageParam:
# Serialize any nested Pydantic models to plain dicts # Serialize any nested Pydantic models to plain dicts
openai_msg = _serialize_nested(openai_msg) openai_msg = _serialize_nested(openai_msg)
@ -588,7 +605,7 @@ class ChatAgent(ShieldRunnerMixin):
messages=openai_messages, messages=openai_messages,
tools=openai_tools if openai_tools else None, tools=openai_tools if openai_tools else None,
tool_choice=tool_choice, tool_choice=tool_choice,
response_format=self.agent_config.response_format, response_format=self.agent_config.response_format, # type: ignore[arg-type]
temperature=temperature, temperature=temperature,
top_p=top_p, top_p=top_p,
max_tokens=max_tokens, max_tokens=max_tokens,
@ -598,7 +615,8 @@ class ChatAgent(ShieldRunnerMixin):
# Convert OpenAI stream back to Llama Stack format # Convert OpenAI stream back to Llama Stack format
response_stream = convert_openai_chat_completion_stream( response_stream = convert_openai_chat_completion_stream(
openai_stream, enable_incremental_tool_calls=True openai_stream, # type: ignore[arg-type]
enable_incremental_tool_calls=True,
) )
async for chunk in response_stream: async for chunk in response_stream:
@ -620,7 +638,7 @@ class ChatAgent(ShieldRunnerMixin):
yield AgentTurnResponseStreamChunk( yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent( event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload( payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.inference.value, step_type=StepType.inference,
step_id=step_id, step_id=step_id,
delta=delta, delta=delta,
) )
@ -633,7 +651,7 @@ class ChatAgent(ShieldRunnerMixin):
yield AgentTurnResponseStreamChunk( yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent( event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload( payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.inference.value, step_type=StepType.inference,
step_id=step_id, step_id=step_id,
delta=delta, delta=delta,
) )
@ -651,7 +669,9 @@ class ChatAgent(ShieldRunnerMixin):
output_attr = json.dumps( output_attr = json.dumps(
{ {
"content": content, "content": content,
"tool_calls": [json.loads(t.model_dump_json()) for t in tool_calls], "tool_calls": [
json.loads(t.model_dump_json()) for t in tool_calls if isinstance(t, ToolCall)
],
} }
) )
span.set_attribute("output", output_attr) span.set_attribute("output", output_attr)
@ -667,16 +687,18 @@ class ChatAgent(ShieldRunnerMixin):
if tool_calls: if tool_calls:
content = "" content = ""
# Filter out string tool calls for CompletionMessage (only keep ToolCall objects)
valid_tool_calls = [t for t in tool_calls if isinstance(t, ToolCall)]
message = CompletionMessage( message = CompletionMessage(
content=content, content=content,
stop_reason=stop_reason, stop_reason=stop_reason,
tool_calls=tool_calls, tool_calls=valid_tool_calls if valid_tool_calls else None,
) )
yield AgentTurnResponseStreamChunk( yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent( event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload( payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.inference.value, step_type=StepType.inference,
step_id=step_id, step_id=step_id,
step_details=InferenceStep( step_details=InferenceStep(
# somewhere deep, we are re-assigning message or closing over some # somewhere deep, we are re-assigning message or closing over some
@ -686,13 +708,14 @@ class ChatAgent(ShieldRunnerMixin):
turn_id=turn_id, turn_id=turn_id,
model_response=copy.deepcopy(message), model_response=copy.deepcopy(message),
started_at=inference_start_time, started_at=inference_start_time,
completed_at=datetime.now(UTC).isoformat(), completed_at=datetime.now(UTC),
), ),
) )
) )
) )
if n_iter >= self.agent_config.max_infer_iters: max_iters = self.agent_config.max_infer_iters if self.agent_config.max_infer_iters is not None else 10
if n_iter >= max_iters:
logger.info(f"done with MAX iterations ({n_iter}), exiting.") logger.info(f"done with MAX iterations ({n_iter}), exiting.")
# NOTE: mark end_of_turn to indicate to client that we are done with the turn # NOTE: mark end_of_turn to indicate to client that we are done with the turn
# Do not continue the tool call loop after this point # Do not continue the tool call loop after this point
@ -705,14 +728,16 @@ class ChatAgent(ShieldRunnerMixin):
yield message yield message
break break
if len(message.tool_calls) == 0: if not message.tool_calls or len(message.tool_calls) == 0:
if stop_reason == StopReason.end_of_turn: if stop_reason == StopReason.end_of_turn:
# TODO: UPDATE RETURN TYPE TO SEND A TUPLE OF (MESSAGE, ATTACHMENTS) # TODO: UPDATE RETURN TYPE TO SEND A TUPLE OF (MESSAGE, ATTACHMENTS)
if len(output_attachments) > 0: if len(output_attachments) > 0:
if isinstance(message.content, list): if isinstance(message.content, list):
message.content += output_attachments # List invariance - attachments are compatible at runtime
message.content += output_attachments # type: ignore[arg-type]
else: else:
message.content = [message.content] + output_attachments # List invariance - attachments are compatible at runtime
message.content = [message.content] + output_attachments # type: ignore[assignment]
yield message yield message
else: else:
logger.debug(f"completion message with EOM (iter: {n_iter}): {str(message)}") logger.debug(f"completion message with EOM (iter: {n_iter}): {str(message)}")
@ -725,11 +750,12 @@ class ChatAgent(ShieldRunnerMixin):
non_client_tool_calls = [] non_client_tool_calls = []
# Separate client and non-client tool calls # Separate client and non-client tool calls
for tool_call in message.tool_calls: if message.tool_calls:
if tool_call.tool_name in client_tools: for tool_call in message.tool_calls:
client_tool_calls.append(tool_call) if tool_call.tool_name in client_tools:
else: client_tool_calls.append(tool_call)
non_client_tool_calls.append(tool_call) else:
non_client_tool_calls.append(tool_call)
# Process non-client tool calls first # Process non-client tool calls first
for tool_call in non_client_tool_calls: for tool_call in non_client_tool_calls:
@ -737,7 +763,7 @@ class ChatAgent(ShieldRunnerMixin):
yield AgentTurnResponseStreamChunk( yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent( event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepStartPayload( payload=AgentTurnResponseStepStartPayload(
step_type=StepType.tool_execution.value, step_type=StepType.tool_execution,
step_id=step_id, step_id=step_id,
) )
) )
@ -746,7 +772,7 @@ class ChatAgent(ShieldRunnerMixin):
yield AgentTurnResponseStreamChunk( yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent( event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload( payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.tool_execution.value, step_type=StepType.tool_execution,
step_id=step_id, step_id=step_id,
delta=ToolCallDelta( delta=ToolCallDelta(
parse_status=ToolCallParseStatus.in_progress, parse_status=ToolCallParseStatus.in_progress,
@ -766,7 +792,7 @@ class ChatAgent(ShieldRunnerMixin):
if self.telemetry_enabled if self.telemetry_enabled
else {}, else {},
) as span: ) as span:
tool_execution_start_time = datetime.now(UTC).isoformat() tool_execution_start_time = datetime.now(UTC)
tool_result = await self.execute_tool_call_maybe( tool_result = await self.execute_tool_call_maybe(
session_id, session_id,
tool_call, tool_call,
@ -796,14 +822,14 @@ class ChatAgent(ShieldRunnerMixin):
) )
], ],
started_at=tool_execution_start_time, started_at=tool_execution_start_time,
completed_at=datetime.now(UTC).isoformat(), completed_at=datetime.now(UTC),
) )
# Yield the step completion event # Yield the step completion event
yield AgentTurnResponseStreamChunk( yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent( event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload( payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.tool_execution.value, step_type=StepType.tool_execution,
step_id=step_id, step_id=step_id,
step_details=tool_execution_step, step_details=tool_execution_step,
) )
@ -833,7 +859,7 @@ class ChatAgent(ShieldRunnerMixin):
turn_id=turn_id, turn_id=turn_id,
tool_calls=client_tool_calls, tool_calls=client_tool_calls,
tool_responses=[], tool_responses=[],
started_at=datetime.now(UTC).isoformat(), started_at=datetime.now(UTC),
), ),
) )
@ -868,19 +894,20 @@ class ChatAgent(ShieldRunnerMixin):
toolgroup_to_args = toolgroup_to_args or {} toolgroup_to_args = toolgroup_to_args or {}
tool_name_to_def = {} tool_name_to_def: dict[str, ToolDefinition] = {}
tool_name_to_args = {} tool_name_to_args = {}
for tool_def in self.agent_config.client_tools: if self.agent_config.client_tools:
if tool_name_to_def.get(tool_def.name, None): for tool_def in self.agent_config.client_tools:
raise ValueError(f"Tool {tool_def.name} already exists") if tool_name_to_def.get(tool_def.name, None):
raise ValueError(f"Tool {tool_def.name} already exists")
# Use input_schema from ToolDef directly # Use input_schema from ToolDef directly
tool_name_to_def[tool_def.name] = ToolDefinition( tool_name_to_def[tool_def.name] = ToolDefinition(
tool_name=tool_def.name, tool_name=tool_def.name,
description=tool_def.description, description=tool_def.description,
input_schema=tool_def.input_schema, input_schema=tool_def.input_schema,
) )
for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups: for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups:
toolgroup_name, input_tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name) toolgroup_name, input_tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name)
tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name) tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name)
@ -908,15 +935,17 @@ class ChatAgent(ShieldRunnerMixin):
else: else:
identifier = None identifier = None
if tool_name_to_def.get(identifier, None):
raise ValueError(f"Tool {identifier} already exists")
if identifier: if identifier:
tool_name_to_def[identifier] = ToolDefinition( # Convert BuiltinTool to string for dictionary key
tool_name=identifier, identifier_str = identifier.value if isinstance(identifier, BuiltinTool) else identifier
if tool_name_to_def.get(identifier_str, None):
raise ValueError(f"Tool {identifier_str} already exists")
tool_name_to_def[identifier_str] = ToolDefinition(
tool_name=identifier_str,
description=tool_def.description, description=tool_def.description,
input_schema=tool_def.input_schema, input_schema=tool_def.input_schema,
) )
tool_name_to_args[identifier] = toolgroup_to_args.get(toolgroup_name, {}) tool_name_to_args[identifier_str] = toolgroup_to_args.get(toolgroup_name, {})
self.tool_defs, self.tool_name_to_args = ( self.tool_defs, self.tool_name_to_args = (
list(tool_name_to_def.values()), list(tool_name_to_def.values()),
@ -1017,7 +1046,7 @@ def _interpret_content_as_attachment(
snippet = match.group(1) snippet = match.group(1)
data = json.loads(snippet) data = json.loads(snippet)
return Attachment( return Attachment(
url=URL(uri="file://" + data["filepath"]), content=URL(uri="file://" + data["filepath"]),
mime_type=data["mimetype"], mime_type=data["mimetype"],
) )

View file

@ -7,6 +7,7 @@
import asyncio import asyncio
import json import json
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
from typing import Any
from llama_stack.apis.agents.openai_responses import ( from llama_stack.apis.agents.openai_responses import (
OpenAIResponseInputToolFileSearch, OpenAIResponseInputToolFileSearch,
@ -22,6 +23,7 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseObjectStreamResponseWebSearchCallSearching, OpenAIResponseObjectStreamResponseWebSearchCallSearching,
OpenAIResponseOutputMessageFileSearchToolCall, OpenAIResponseOutputMessageFileSearchToolCall,
OpenAIResponseOutputMessageFileSearchToolCallResults, OpenAIResponseOutputMessageFileSearchToolCallResults,
OpenAIResponseOutputMessageMCPCall,
OpenAIResponseOutputMessageWebSearchToolCall, OpenAIResponseOutputMessageWebSearchToolCall,
) )
from llama_stack.apis.common.content_types import ( from llama_stack.apis.common.content_types import (
@ -67,7 +69,7 @@ class ToolExecutor:
) -> AsyncIterator[ToolExecutionResult]: ) -> AsyncIterator[ToolExecutionResult]:
tool_call_id = tool_call.id tool_call_id = tool_call.id
function = tool_call.function function = tool_call.function
tool_kwargs = json.loads(function.arguments) if function.arguments else {} tool_kwargs = json.loads(function.arguments) if function and function.arguments else {}
if not function or not tool_call_id or not function.name: if not function or not tool_call_id or not function.name:
yield ToolExecutionResult(sequence_number=sequence_number) yield ToolExecutionResult(sequence_number=sequence_number)
@ -84,7 +86,16 @@ class ToolExecutor:
error_exc, result = await self._execute_tool(function.name, tool_kwargs, ctx, mcp_tool_to_server) error_exc, result = await self._execute_tool(function.name, tool_kwargs, ctx, mcp_tool_to_server)
# Emit completion events for tool execution # Emit completion events for tool execution
has_error = error_exc or (result and ((result.error_code and result.error_code > 0) or result.error_message)) has_error = bool(
error_exc
or (
result
and (
((error_code := getattr(result, "error_code", None)) and error_code > 0)
or getattr(result, "error_message", None)
)
)
)
async for event_result in self._emit_completion_events( async for event_result in self._emit_completion_events(
function.name, ctx, sequence_number, output_index, item_id, has_error, mcp_tool_to_server function.name, ctx, sequence_number, output_index, item_id, has_error, mcp_tool_to_server
): ):
@ -101,7 +112,9 @@ class ToolExecutor:
sequence_number=sequence_number, sequence_number=sequence_number,
final_output_message=output_message, final_output_message=output_message,
final_input_message=input_message, final_input_message=input_message,
citation_files=result.metadata.get("citation_files") if result and result.metadata else None, citation_files=(
metadata.get("citation_files") if result and (metadata := getattr(result, "metadata", None)) else None
),
) )
async def _execute_knowledge_search_via_vector_store( async def _execute_knowledge_search_via_vector_store(
@ -188,8 +201,9 @@ class ToolExecutor:
citation_files[file_id] = filename citation_files[file_id] = filename
# Cast to proper InterleavedContent type (list invariance)
return ToolInvocationResult( return ToolInvocationResult(
content=content_items, content=content_items, # type: ignore[arg-type]
metadata={ metadata={
"document_ids": [r.file_id for r in search_results], "document_ids": [r.file_id for r in search_results],
"chunks": [r.content[0].text if r.content else "" for r in search_results], "chunks": [r.content[0].text if r.content else "" for r in search_results],
@ -209,51 +223,60 @@ class ToolExecutor:
) -> AsyncIterator[ToolExecutionResult]: ) -> AsyncIterator[ToolExecutionResult]:
"""Emit progress events for tool execution start.""" """Emit progress events for tool execution start."""
# Emit in_progress event based on tool type (only for tools with specific streaming events) # Emit in_progress event based on tool type (only for tools with specific streaming events)
progress_event = None
if mcp_tool_to_server and function_name in mcp_tool_to_server: if mcp_tool_to_server and function_name in mcp_tool_to_server:
sequence_number += 1 sequence_number += 1
progress_event = OpenAIResponseObjectStreamResponseMcpCallInProgress( yield ToolExecutionResult(
item_id=item_id, stream_event=OpenAIResponseObjectStreamResponseMcpCallInProgress(
output_index=output_index, item_id=item_id,
output_index=output_index,
sequence_number=sequence_number,
),
sequence_number=sequence_number, sequence_number=sequence_number,
) )
elif function_name == "web_search": elif function_name == "web_search":
sequence_number += 1 sequence_number += 1
progress_event = OpenAIResponseObjectStreamResponseWebSearchCallInProgress( yield ToolExecutionResult(
item_id=item_id, stream_event=OpenAIResponseObjectStreamResponseWebSearchCallInProgress(
output_index=output_index, item_id=item_id,
output_index=output_index,
sequence_number=sequence_number,
),
sequence_number=sequence_number, sequence_number=sequence_number,
) )
elif function_name == "knowledge_search": elif function_name == "knowledge_search":
sequence_number += 1 sequence_number += 1
progress_event = OpenAIResponseObjectStreamResponseFileSearchCallInProgress( yield ToolExecutionResult(
item_id=item_id, stream_event=OpenAIResponseObjectStreamResponseFileSearchCallInProgress(
output_index=output_index, item_id=item_id,
output_index=output_index,
sequence_number=sequence_number,
),
sequence_number=sequence_number, sequence_number=sequence_number,
) )
if progress_event:
yield ToolExecutionResult(stream_event=progress_event, sequence_number=sequence_number)
# For web search, emit searching event # For web search, emit searching event
if function_name == "web_search": if function_name == "web_search":
sequence_number += 1 sequence_number += 1
searching_event = OpenAIResponseObjectStreamResponseWebSearchCallSearching( yield ToolExecutionResult(
item_id=item_id, stream_event=OpenAIResponseObjectStreamResponseWebSearchCallSearching(
output_index=output_index, item_id=item_id,
output_index=output_index,
sequence_number=sequence_number,
),
sequence_number=sequence_number, sequence_number=sequence_number,
) )
yield ToolExecutionResult(stream_event=searching_event, sequence_number=sequence_number)
# For file search, emit searching event # For file search, emit searching event
if function_name == "knowledge_search": if function_name == "knowledge_search":
sequence_number += 1 sequence_number += 1
searching_event = OpenAIResponseObjectStreamResponseFileSearchCallSearching( yield ToolExecutionResult(
item_id=item_id, stream_event=OpenAIResponseObjectStreamResponseFileSearchCallSearching(
output_index=output_index, item_id=item_id,
output_index=output_index,
sequence_number=sequence_number,
),
sequence_number=sequence_number, sequence_number=sequence_number,
) )
yield ToolExecutionResult(stream_event=searching_event, sequence_number=sequence_number)
async def _execute_tool( async def _execute_tool(
self, self,
@ -261,7 +284,7 @@ class ToolExecutor:
tool_kwargs: dict, tool_kwargs: dict,
ctx: ChatCompletionContext, ctx: ChatCompletionContext,
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None, mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
) -> tuple[Exception | None, any]: ) -> tuple[Exception | None, Any]:
"""Execute the tool and return error exception and result.""" """Execute the tool and return error exception and result."""
error_exc = None error_exc = None
result = None result = None
@ -284,9 +307,13 @@ class ToolExecutor:
kwargs=tool_kwargs, kwargs=tool_kwargs,
) )
elif function_name == "knowledge_search": elif function_name == "knowledge_search":
response_file_search_tool = next( response_file_search_tool = (
(t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)), next(
None, (t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)),
None,
)
if ctx.response_tools
else None
) )
if response_file_search_tool: if response_file_search_tool:
# Use vector_stores.search API instead of knowledge_search tool # Use vector_stores.search API instead of knowledge_search tool
@ -322,35 +349,34 @@ class ToolExecutor:
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None, mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
) -> AsyncIterator[ToolExecutionResult]: ) -> AsyncIterator[ToolExecutionResult]:
"""Emit completion or failure events for tool execution.""" """Emit completion or failure events for tool execution."""
completion_event = None
if mcp_tool_to_server and function_name in mcp_tool_to_server: if mcp_tool_to_server and function_name in mcp_tool_to_server:
sequence_number += 1 sequence_number += 1
if has_error: if has_error:
completion_event = OpenAIResponseObjectStreamResponseMcpCallFailed( mcp_failed_event = OpenAIResponseObjectStreamResponseMcpCallFailed(
sequence_number=sequence_number, sequence_number=sequence_number,
) )
yield ToolExecutionResult(stream_event=mcp_failed_event, sequence_number=sequence_number)
else: else:
completion_event = OpenAIResponseObjectStreamResponseMcpCallCompleted( mcp_completed_event = OpenAIResponseObjectStreamResponseMcpCallCompleted(
sequence_number=sequence_number, sequence_number=sequence_number,
) )
yield ToolExecutionResult(stream_event=mcp_completed_event, sequence_number=sequence_number)
elif function_name == "web_search": elif function_name == "web_search":
sequence_number += 1 sequence_number += 1
completion_event = OpenAIResponseObjectStreamResponseWebSearchCallCompleted( web_completion_event = OpenAIResponseObjectStreamResponseWebSearchCallCompleted(
item_id=item_id, item_id=item_id,
output_index=output_index, output_index=output_index,
sequence_number=sequence_number, sequence_number=sequence_number,
) )
yield ToolExecutionResult(stream_event=web_completion_event, sequence_number=sequence_number)
elif function_name == "knowledge_search": elif function_name == "knowledge_search":
sequence_number += 1 sequence_number += 1
completion_event = OpenAIResponseObjectStreamResponseFileSearchCallCompleted( file_completion_event = OpenAIResponseObjectStreamResponseFileSearchCallCompleted(
item_id=item_id, item_id=item_id,
output_index=output_index, output_index=output_index,
sequence_number=sequence_number, sequence_number=sequence_number,
) )
yield ToolExecutionResult(stream_event=file_completion_event, sequence_number=sequence_number)
if completion_event:
yield ToolExecutionResult(stream_event=completion_event, sequence_number=sequence_number)
async def _build_result_messages( async def _build_result_messages(
self, self,
@ -360,21 +386,18 @@ class ToolExecutor:
tool_kwargs: dict, tool_kwargs: dict,
ctx: ChatCompletionContext, ctx: ChatCompletionContext,
error_exc: Exception | None, error_exc: Exception | None,
result: any, result: Any,
has_error: bool, has_error: bool,
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None, mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
) -> tuple[any, any]: ) -> tuple[Any, Any]:
"""Build output and input messages from tool execution results.""" """Build output and input messages from tool execution results."""
from llama_stack.providers.utils.inference.prompt_adapter import ( from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str, interleaved_content_as_str,
) )
# Build output message # Build output message
message: Any
if mcp_tool_to_server and function.name in mcp_tool_to_server: if mcp_tool_to_server and function.name in mcp_tool_to_server:
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseOutputMessageMCPCall,
)
message = OpenAIResponseOutputMessageMCPCall( message = OpenAIResponseOutputMessageMCPCall(
id=item_id, id=item_id,
arguments=function.arguments, arguments=function.arguments,
@ -383,10 +406,14 @@ class ToolExecutor:
) )
if error_exc: if error_exc:
message.error = str(error_exc) message.error = str(error_exc)
elif (result and result.error_code and result.error_code > 0) or (result and result.error_message): elif (result and (error_code := getattr(result, "error_code", None)) and error_code > 0) or (
message.error = f"Error (code {result.error_code}): {result.error_message}" result and getattr(result, "error_message", None)
elif result and result.content: ):
message.output = interleaved_content_as_str(result.content) ec = getattr(result, "error_code", "unknown")
em = getattr(result, "error_message", "")
message.error = f"Error (code {ec}): {em}"
elif result and (content := getattr(result, "content", None)):
message.output = interleaved_content_as_str(content)
else: else:
if function.name == "web_search": if function.name == "web_search":
message = OpenAIResponseOutputMessageWebSearchToolCall( message = OpenAIResponseOutputMessageWebSearchToolCall(
@ -401,17 +428,17 @@ class ToolExecutor:
queries=[tool_kwargs.get("query", "")], queries=[tool_kwargs.get("query", "")],
status="completed", status="completed",
) )
if result and "document_ids" in result.metadata: if result and (metadata := getattr(result, "metadata", None)) and "document_ids" in metadata:
message.results = [] message.results = []
for i, doc_id in enumerate(result.metadata["document_ids"]): for i, doc_id in enumerate(metadata["document_ids"]):
text = result.metadata["chunks"][i] if "chunks" in result.metadata else None text = metadata["chunks"][i] if "chunks" in metadata else None
score = result.metadata["scores"][i] if "scores" in result.metadata else None score = metadata["scores"][i] if "scores" in metadata else None
message.results.append( message.results.append(
OpenAIResponseOutputMessageFileSearchToolCallResults( OpenAIResponseOutputMessageFileSearchToolCallResults(
file_id=doc_id, file_id=doc_id,
filename=doc_id, filename=doc_id,
text=text, text=text if text is not None else "",
score=score, score=score if score is not None else 0.0,
attributes={}, attributes={},
) )
) )
@ -421,27 +448,32 @@ class ToolExecutor:
raise ValueError(f"Unknown tool {function.name} called") raise ValueError(f"Unknown tool {function.name} called")
# Build input message # Build input message
input_message = None input_message: OpenAIToolMessageParam | None = None
if result and result.content: if result and (result_content := getattr(result, "content", None)):
if isinstance(result.content, str): # all the mypy contortions here are still unsatisfactory with random Any typing
content = result.content if isinstance(result_content, str):
elif isinstance(result.content, list): msg_content: str | list[Any] = result_content
content = [] elif isinstance(result_content, list):
for item in result.content: content_list: list[Any] = []
for item in result_content:
part: Any
if isinstance(item, TextContentItem): if isinstance(item, TextContentItem):
part = OpenAIChatCompletionContentPartTextParam(text=item.text) part = OpenAIChatCompletionContentPartTextParam(text=item.text)
elif isinstance(item, ImageContentItem): elif isinstance(item, ImageContentItem):
if item.image.data: if item.image.data:
url = f"data:image;base64,{item.image.data}" url_value = f"data:image;base64,{item.image.data}"
else: else:
url = item.image.url url_value = str(item.image.url) if item.image.url else ""
part = OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url=url)) part = OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url=url_value))
else: else:
raise ValueError(f"Unknown result content type: {type(item)}") raise ValueError(f"Unknown result content type: {type(item)}")
content.append(part) content_list.append(part)
msg_content = content_list
else: else:
raise ValueError(f"Unknown result content type: {type(result.content)}") raise ValueError(f"Unknown result content type: {type(result_content)}")
input_message = OpenAIToolMessageParam(content=content, tool_call_id=tool_call_id) # OpenAIToolMessageParam accepts str | list[TextParam] but we may have images
# This is runtime-safe as the API accepts it, but mypy complains
input_message = OpenAIToolMessageParam(content=msg_content, tool_call_id=tool_call_id) # type: ignore[arg-type]
else: else:
text = str(error_exc) if error_exc else "Tool execution failed" text = str(error_exc) if error_exc else "Tool execution failed"
input_message = OpenAIToolMessageParam(content=text, tool_call_id=tool_call_id) input_message = OpenAIToolMessageParam(content=text, tool_call_id=tool_call_id)