Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-03 18:00:36 +00:00)
fix(mypy-cleanup): part-01 resolve meta reference agent type issues (126 errors) (#3945)
Error fixes in the Agents implementation (`meta-reference` provider): adding proper type annotations and using type narrowing for optional attributes. Essentially a bunch of `if x and (x_foo := getattr(x, "foo"))` instead of `x.foo` directly. Part of the ongoing mypy remediation effort.

---------

Co-authored-by: Claude <noreply@anthropic.com>
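As a minimal sketch of the narrowing pattern the message describes (the `describe_error` helper is illustrative and not part of the diff; the `error_code`/`error_message` names mirror the `has_error` change further below):

```python
from typing import Any


def describe_error(result: Any) -> str | None:
    # getattr with a default plus a walrus assignment binds a local that
    # mypy can reason about, instead of an attribute it cannot verify on Any.
    if result and (error_code := getattr(result, "error_code", None)) and error_code > 0:
        return f"Error (code {error_code}): {getattr(result, 'error_message', '')}"
    return None
```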
parent 22bf0d0471
commit ce31aa1704
2 changed files with 204 additions and 143 deletions
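The other recurring fix in the diff pairs the existing `step_type` comparison with an `isinstance` check, since the string comparison alone carries no type information for mypy. A rough sketch with simplified stand-in step classes (not the real llama-stack models):

```python
from dataclasses import dataclass, field


@dataclass
class InferenceStep:
    model_response: str = ""


@dataclass
class ToolExecutionStep:
    tool_responses: list[str] = field(default_factory=list)


def collect(steps: list[InferenceStep | ToolExecutionStep]) -> list[str]:
    out: list[str] = []
    for step in steps:
        # isinstance narrows the union, so the attribute access type-checks.
        if isinstance(step, InferenceStep):
            out.append(step.model_response)
        elif isinstance(step, ToolExecutionStep):
            out.extend(step.tool_responses)
    return out
```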
@@ -11,6 +11,7 @@ import uuid
 import warnings
 from collections.abc import AsyncGenerator
 from datetime import UTC, datetime
+from typing import Any

 import httpx

@@ -125,12 +126,12 @@ class ChatAgent(ShieldRunnerMixin):
         )

     def turn_to_messages(self, turn: Turn) -> list[Message]:
-        messages = []
+        messages: list[Message] = []

         # NOTE: if a toolcall response is in a step, we do not add it when processing the input messages
         tool_call_ids = set()
         for step in turn.steps:
-            if step.step_type == StepType.tool_execution.value:
+            if step.step_type == StepType.tool_execution.value and isinstance(step, ToolExecutionStep):
                 for response in step.tool_responses:
                     tool_call_ids.add(response.call_id)

@@ -149,9 +150,9 @@ class ChatAgent(ShieldRunnerMixin):
             messages.append(msg)

         for step in turn.steps:
-            if step.step_type == StepType.inference.value:
+            if step.step_type == StepType.inference.value and isinstance(step, InferenceStep):
                 messages.append(step.model_response)
-            elif step.step_type == StepType.tool_execution.value:
+            elif step.step_type == StepType.tool_execution.value and isinstance(step, ToolExecutionStep):
                 for response in step.tool_responses:
                     messages.append(
                         ToolResponseMessage(
@@ -159,8 +160,8 @@ class ChatAgent(ShieldRunnerMixin):
                             content=response.content,
                         )
                     )
-            elif step.step_type == StepType.shield_call.value:
-                if step.violation:
+            elif step.step_type == StepType.shield_call.value and isinstance(step, ShieldCallStep):
+                if step.violation and step.violation.user_message:
                     # CompletionMessage itself in the ShieldResponse
                     messages.append(
                         CompletionMessage(
@@ -174,7 +175,7 @@ class ChatAgent(ShieldRunnerMixin):
         return await self.storage.create_session(name)

     async def get_messages_from_turns(self, turns: list[Turn]) -> list[Message]:
-        messages = []
+        messages: list[Message] = []
         if self.agent_config.instructions != "":
             messages.append(SystemMessage(content=self.agent_config.instructions))

@@ -231,7 +232,9 @@ class ChatAgent(ShieldRunnerMixin):

         steps = []
         messages = await self.get_messages_from_turns(turns)

         if is_resume:
+            assert isinstance(request, AgentTurnResumeRequest)
             tool_response_messages = [
                 ToolResponseMessage(call_id=x.call_id, content=x.content) for x in request.tool_responses
             ]
@@ -252,42 +255,52 @@ class ChatAgent(ShieldRunnerMixin):
             in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step(
                 request.session_id, request.turn_id
             )
-            now = datetime.now(UTC).isoformat()
+            now_dt = datetime.now(UTC)
             tool_execution_step = ToolExecutionStep(
                 step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())),
                 turn_id=request.turn_id,
                 tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []),
                 tool_responses=request.tool_responses,
-                completed_at=now,
-                started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now),
+                completed_at=now_dt,
+                started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now_dt),
             )
             steps.append(tool_execution_step)
             yield AgentTurnResponseStreamChunk(
                 event=AgentTurnResponseEvent(
                     payload=AgentTurnResponseStepCompletePayload(
-                        step_type=StepType.tool_execution.value,
+                        step_type=StepType.tool_execution,
                         step_id=tool_execution_step.step_id,
                         step_details=tool_execution_step,
                     )
                 )
             )
-            input_messages = last_turn.input_messages
+            # Cast needed due to list invariance - last_turn.input_messages is the right type
+            input_messages = last_turn.input_messages  # type: ignore[assignment]

-            turn_id = request.turn_id
+            actual_turn_id = request.turn_id
             start_time = last_turn.started_at
         else:
+            assert isinstance(request, AgentTurnCreateRequest)
             messages.extend(request.messages)
-            start_time = datetime.now(UTC).isoformat()
-            input_messages = request.messages
+            start_time = datetime.now(UTC)
+            # Cast needed due to list invariance - request.messages is the right type
+            input_messages = request.messages  # type: ignore[assignment]
+            # Use the generated turn_id from beginning of function
+            actual_turn_id = turn_id if turn_id else str(uuid.uuid4())

         output_message = None
+        req_documents = request.documents if isinstance(request, AgentTurnCreateRequest) and not is_resume else None
+        req_sampling = (
+            self.agent_config.sampling_params if self.agent_config.sampling_params is not None else SamplingParams()
+        )

         async for chunk in self.run(
             session_id=request.session_id,
-            turn_id=turn_id,
+            turn_id=actual_turn_id,
             input_messages=messages,
-            sampling_params=self.agent_config.sampling_params,
+            sampling_params=req_sampling,
             stream=request.stream,
-            documents=request.documents if not is_resume else None,
+            documents=req_documents,
         ):
             if isinstance(chunk, CompletionMessage):
                 output_message = chunk
@@ -295,20 +308,23 @@ class ChatAgent(ShieldRunnerMixin):

             assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
             event = chunk.event
-            if event.payload.event_type == AgentTurnResponseEventType.step_complete.value:
-                steps.append(event.payload.step_details)
+            if event.payload.event_type == AgentTurnResponseEventType.step_complete.value and hasattr(
+                event.payload, "step_details"
+            ):
+                step_details = event.payload.step_details
+                steps.append(step_details)

             yield chunk

         assert output_message is not None

         turn = Turn(
-            turn_id=turn_id,
+            turn_id=actual_turn_id,
             session_id=request.session_id,
-            input_messages=input_messages,
+            input_messages=input_messages,  # type: ignore[arg-type]
             output_message=output_message,
             started_at=start_time,
-            completed_at=datetime.now(UTC).isoformat(),
+            completed_at=datetime.now(UTC),
             steps=steps,
         )
         await self.storage.add_turn_to_session(request.session_id, turn)
@@ -345,7 +361,7 @@ class ChatAgent(ShieldRunnerMixin):
         # return a "final value" for the `yield from` statement. we simulate that by yielding a
         # final boolean (to see whether an exception happened) and then explicitly testing for it.

-        if len(self.input_shields) > 0:
+        if self.input_shields:
             async for res in self.run_multiple_shields_wrapper(
                 turn_id, input_messages, self.input_shields, "user-input"
             ):
@@ -374,7 +390,7 @@ class ChatAgent(ShieldRunnerMixin):
         # for output shields run on the full input and output combination
         messages = input_messages + [final_response]

-        if len(self.output_shields) > 0:
+        if self.output_shields:
             async for res in self.run_multiple_shields_wrapper(
                 turn_id, messages, self.output_shields, "assistant-output"
             ):
@@ -402,12 +418,12 @@ class ChatAgent(ShieldRunnerMixin):
             return

         step_id = str(uuid.uuid4())
-        shield_call_start_time = datetime.now(UTC).isoformat()
+        shield_call_start_time = datetime.now(UTC)
         try:
             yield AgentTurnResponseStreamChunk(
                 event=AgentTurnResponseEvent(
                     payload=AgentTurnResponseStepStartPayload(
-                        step_type=StepType.shield_call.value,
+                        step_type=StepType.shield_call,
                         step_id=step_id,
                         metadata=dict(touchpoint=touchpoint),
                     )
@@ -419,14 +435,14 @@ class ChatAgent(ShieldRunnerMixin):
             yield AgentTurnResponseStreamChunk(
                 event=AgentTurnResponseEvent(
                     payload=AgentTurnResponseStepCompletePayload(
-                        step_type=StepType.shield_call.value,
+                        step_type=StepType.shield_call,
                         step_id=step_id,
                         step_details=ShieldCallStep(
                             step_id=step_id,
                             turn_id=turn_id,
                             violation=e.violation,
                             started_at=shield_call_start_time,
-                            completed_at=datetime.now(UTC).isoformat(),
+                            completed_at=datetime.now(UTC),
                         ),
                     )
                 )
@@ -443,14 +459,14 @@ class ChatAgent(ShieldRunnerMixin):
             yield AgentTurnResponseStreamChunk(
                 event=AgentTurnResponseEvent(
                     payload=AgentTurnResponseStepCompletePayload(
-                        step_type=StepType.shield_call.value,
+                        step_type=StepType.shield_call,
                         step_id=step_id,
                         step_details=ShieldCallStep(
                             step_id=step_id,
                             turn_id=turn_id,
                             violation=None,
                             started_at=shield_call_start_time,
-                            completed_at=datetime.now(UTC).isoformat(),
+                            completed_at=datetime.now(UTC),
                         ),
                     )
                 )
@@ -496,21 +512,22 @@ class ChatAgent(ShieldRunnerMixin):
                 else:
                     self.tool_name_to_args[tool_name]["vector_store_ids"].append(session_info.vector_store_id)

-        output_attachments = []
+        output_attachments: list[Attachment] = []

         n_iter = await self.storage.get_num_infer_iters_in_turn(session_id, turn_id) or 0

         # Build a map of custom tools to their definitions for faster lookup
         client_tools = {}
-        for tool in self.agent_config.client_tools:
-            client_tools[tool.name] = tool
+        if self.agent_config.client_tools:
+            for tool in self.agent_config.client_tools:
+                client_tools[tool.name] = tool
         while True:
             step_id = str(uuid.uuid4())
-            inference_start_time = datetime.now(UTC).isoformat()
+            inference_start_time = datetime.now(UTC)
             yield AgentTurnResponseStreamChunk(
                 event=AgentTurnResponseEvent(
                     payload=AgentTurnResponseStepStartPayload(
-                        step_type=StepType.inference.value,
+                        step_type=StepType.inference,
                         step_id=step_id,
                     )
                 )
@@ -538,7 +555,7 @@ class ChatAgent(ShieldRunnerMixin):
             else:
                 return value

-        def _add_type(openai_msg: dict) -> OpenAIMessageParam:
+        def _add_type(openai_msg: Any) -> OpenAIMessageParam:
             # Serialize any nested Pydantic models to plain dicts
             openai_msg = _serialize_nested(openai_msg)

@@ -588,7 +605,7 @@ class ChatAgent(ShieldRunnerMixin):
                 messages=openai_messages,
                 tools=openai_tools if openai_tools else None,
                 tool_choice=tool_choice,
-                response_format=self.agent_config.response_format,
+                response_format=self.agent_config.response_format,  # type: ignore[arg-type]
                 temperature=temperature,
                 top_p=top_p,
                 max_tokens=max_tokens,
@@ -598,7 +615,8 @@ class ChatAgent(ShieldRunnerMixin):

             # Convert OpenAI stream back to Llama Stack format
             response_stream = convert_openai_chat_completion_stream(
-                openai_stream, enable_incremental_tool_calls=True
+                openai_stream,  # type: ignore[arg-type]
+                enable_incremental_tool_calls=True,
             )

             async for chunk in response_stream:
@@ -620,7 +638,7 @@ class ChatAgent(ShieldRunnerMixin):
                     yield AgentTurnResponseStreamChunk(
                         event=AgentTurnResponseEvent(
                             payload=AgentTurnResponseStepProgressPayload(
-                                step_type=StepType.inference.value,
+                                step_type=StepType.inference,
                                 step_id=step_id,
                                 delta=delta,
                             )
@@ -633,7 +651,7 @@ class ChatAgent(ShieldRunnerMixin):
                     yield AgentTurnResponseStreamChunk(
                         event=AgentTurnResponseEvent(
                             payload=AgentTurnResponseStepProgressPayload(
-                                step_type=StepType.inference.value,
+                                step_type=StepType.inference,
                                 step_id=step_id,
                                 delta=delta,
                             )
@@ -651,7 +669,9 @@ class ChatAgent(ShieldRunnerMixin):
                 output_attr = json.dumps(
                     {
                         "content": content,
-                        "tool_calls": [json.loads(t.model_dump_json()) for t in tool_calls],
+                        "tool_calls": [
+                            json.loads(t.model_dump_json()) for t in tool_calls if isinstance(t, ToolCall)
+                        ],
                     }
                 )
                 span.set_attribute("output", output_attr)
@@ -667,16 +687,18 @@ class ChatAgent(ShieldRunnerMixin):
             if tool_calls:
                 content = ""

+            # Filter out string tool calls for CompletionMessage (only keep ToolCall objects)
+            valid_tool_calls = [t for t in tool_calls if isinstance(t, ToolCall)]
             message = CompletionMessage(
                 content=content,
                 stop_reason=stop_reason,
-                tool_calls=tool_calls,
+                tool_calls=valid_tool_calls if valid_tool_calls else None,
             )

             yield AgentTurnResponseStreamChunk(
                 event=AgentTurnResponseEvent(
                     payload=AgentTurnResponseStepCompletePayload(
-                        step_type=StepType.inference.value,
+                        step_type=StepType.inference,
                         step_id=step_id,
                         step_details=InferenceStep(
                             # somewhere deep, we are re-assigning message or closing over some
||||||
|
|
@ -686,13 +708,14 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
turn_id=turn_id,
|
turn_id=turn_id,
|
||||||
model_response=copy.deepcopy(message),
|
model_response=copy.deepcopy(message),
|
||||||
started_at=inference_start_time,
|
started_at=inference_start_time,
|
||||||
completed_at=datetime.now(UTC).isoformat(),
|
completed_at=datetime.now(UTC),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if n_iter >= self.agent_config.max_infer_iters:
|
max_iters = self.agent_config.max_infer_iters if self.agent_config.max_infer_iters is not None else 10
|
||||||
|
if n_iter >= max_iters:
|
||||||
logger.info(f"done with MAX iterations ({n_iter}), exiting.")
|
logger.info(f"done with MAX iterations ({n_iter}), exiting.")
|
||||||
# NOTE: mark end_of_turn to indicate to client that we are done with the turn
|
# NOTE: mark end_of_turn to indicate to client that we are done with the turn
|
||||||
# Do not continue the tool call loop after this point
|
# Do not continue the tool call loop after this point
|
||||||
|
|
@@ -705,14 +728,16 @@ class ChatAgent(ShieldRunnerMixin):
                 yield message
                 break

-            if len(message.tool_calls) == 0:
+            if not message.tool_calls or len(message.tool_calls) == 0:
                 if stop_reason == StopReason.end_of_turn:
                     # TODO: UPDATE RETURN TYPE TO SEND A TUPLE OF (MESSAGE, ATTACHMENTS)
                     if len(output_attachments) > 0:
                         if isinstance(message.content, list):
-                            message.content += output_attachments
+                            # List invariance - attachments are compatible at runtime
+                            message.content += output_attachments  # type: ignore[arg-type]
                         else:
-                            message.content = [message.content] + output_attachments
+                            # List invariance - attachments are compatible at runtime
+                            message.content = [message.content] + output_attachments  # type: ignore[assignment]
                     yield message
                 else:
                     logger.debug(f"completion message with EOM (iter: {n_iter}): {str(message)}")
@@ -725,11 +750,12 @@ class ChatAgent(ShieldRunnerMixin):
             non_client_tool_calls = []

             # Separate client and non-client tool calls
-            for tool_call in message.tool_calls:
-                if tool_call.tool_name in client_tools:
-                    client_tool_calls.append(tool_call)
-                else:
-                    non_client_tool_calls.append(tool_call)
+            if message.tool_calls:
+                for tool_call in message.tool_calls:
+                    if tool_call.tool_name in client_tools:
+                        client_tool_calls.append(tool_call)
+                    else:
+                        non_client_tool_calls.append(tool_call)

             # Process non-client tool calls first
             for tool_call in non_client_tool_calls:
@@ -737,7 +763,7 @@ class ChatAgent(ShieldRunnerMixin):
                 yield AgentTurnResponseStreamChunk(
                     event=AgentTurnResponseEvent(
                         payload=AgentTurnResponseStepStartPayload(
-                            step_type=StepType.tool_execution.value,
+                            step_type=StepType.tool_execution,
                             step_id=step_id,
                         )
                     )
@@ -746,7 +772,7 @@ class ChatAgent(ShieldRunnerMixin):
                 yield AgentTurnResponseStreamChunk(
                     event=AgentTurnResponseEvent(
                         payload=AgentTurnResponseStepProgressPayload(
-                            step_type=StepType.tool_execution.value,
+                            step_type=StepType.tool_execution,
                             step_id=step_id,
                             delta=ToolCallDelta(
                                 parse_status=ToolCallParseStatus.in_progress,
@@ -766,7 +792,7 @@ class ChatAgent(ShieldRunnerMixin):
                     if self.telemetry_enabled
                     else {},
                 ) as span:
-                    tool_execution_start_time = datetime.now(UTC).isoformat()
+                    tool_execution_start_time = datetime.now(UTC)
                     tool_result = await self.execute_tool_call_maybe(
                         session_id,
                         tool_call,
@@ -796,14 +822,14 @@ class ChatAgent(ShieldRunnerMixin):
                         )
                     ],
                     started_at=tool_execution_start_time,
-                    completed_at=datetime.now(UTC).isoformat(),
+                    completed_at=datetime.now(UTC),
                 )

                 # Yield the step completion event
                 yield AgentTurnResponseStreamChunk(
                     event=AgentTurnResponseEvent(
                         payload=AgentTurnResponseStepCompletePayload(
-                            step_type=StepType.tool_execution.value,
+                            step_type=StepType.tool_execution,
                             step_id=step_id,
                             step_details=tool_execution_step,
                         )
@@ -833,7 +859,7 @@ class ChatAgent(ShieldRunnerMixin):
                     turn_id=turn_id,
                     tool_calls=client_tool_calls,
                     tool_responses=[],
-                    started_at=datetime.now(UTC).isoformat(),
+                    started_at=datetime.now(UTC),
                 ),
             )

@@ -868,19 +894,20 @@ class ChatAgent(ShieldRunnerMixin):

         toolgroup_to_args = toolgroup_to_args or {}

-        tool_name_to_def = {}
+        tool_name_to_def: dict[str, ToolDefinition] = {}
         tool_name_to_args = {}

-        for tool_def in self.agent_config.client_tools:
-            if tool_name_to_def.get(tool_def.name, None):
-                raise ValueError(f"Tool {tool_def.name} already exists")
-
-            # Use input_schema from ToolDef directly
-            tool_name_to_def[tool_def.name] = ToolDefinition(
-                tool_name=tool_def.name,
-                description=tool_def.description,
-                input_schema=tool_def.input_schema,
-            )
+        if self.agent_config.client_tools:
+            for tool_def in self.agent_config.client_tools:
+                if tool_name_to_def.get(tool_def.name, None):
+                    raise ValueError(f"Tool {tool_def.name} already exists")
+
+                # Use input_schema from ToolDef directly
+                tool_name_to_def[tool_def.name] = ToolDefinition(
+                    tool_name=tool_def.name,
+                    description=tool_def.description,
+                    input_schema=tool_def.input_schema,
+                )
         for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups:
             toolgroup_name, input_tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name)
             tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name)
@@ -908,15 +935,17 @@ class ChatAgent(ShieldRunnerMixin):
             else:
                 identifier = None

-            if tool_name_to_def.get(identifier, None):
-                raise ValueError(f"Tool {identifier} already exists")
             if identifier:
-                tool_name_to_def[identifier] = ToolDefinition(
-                    tool_name=identifier,
+                # Convert BuiltinTool to string for dictionary key
+                identifier_str = identifier.value if isinstance(identifier, BuiltinTool) else identifier
+                if tool_name_to_def.get(identifier_str, None):
+                    raise ValueError(f"Tool {identifier_str} already exists")
+                tool_name_to_def[identifier_str] = ToolDefinition(
+                    tool_name=identifier_str,
                     description=tool_def.description,
                     input_schema=tool_def.input_schema,
                 )
-                tool_name_to_args[identifier] = toolgroup_to_args.get(toolgroup_name, {})
+                tool_name_to_args[identifier_str] = toolgroup_to_args.get(toolgroup_name, {})

         self.tool_defs, self.tool_name_to_args = (
             list(tool_name_to_def.values()),
@@ -1017,7 +1046,7 @@ def _interpret_content_as_attachment(
         snippet = match.group(1)
         data = json.loads(snippet)
         return Attachment(
-            url=URL(uri="file://" + data["filepath"]),
+            content=URL(uri="file://" + data["filepath"]),
             mime_type=data["mimetype"],
         )

@@ -7,6 +7,7 @@
 import asyncio
 import json
 from collections.abc import AsyncIterator
+from typing import Any

 from llama_stack.apis.agents.openai_responses import (
     OpenAIResponseInputToolFileSearch,
@@ -22,6 +23,7 @@ from llama_stack.apis.agents.openai_responses import (
     OpenAIResponseObjectStreamResponseWebSearchCallSearching,
     OpenAIResponseOutputMessageFileSearchToolCall,
     OpenAIResponseOutputMessageFileSearchToolCallResults,
+    OpenAIResponseOutputMessageMCPCall,
     OpenAIResponseOutputMessageWebSearchToolCall,
 )
 from llama_stack.apis.common.content_types import (
@@ -67,7 +69,7 @@ class ToolExecutor:
     ) -> AsyncIterator[ToolExecutionResult]:
         tool_call_id = tool_call.id
         function = tool_call.function
-        tool_kwargs = json.loads(function.arguments) if function.arguments else {}
+        tool_kwargs = json.loads(function.arguments) if function and function.arguments else {}

         if not function or not tool_call_id or not function.name:
             yield ToolExecutionResult(sequence_number=sequence_number)
@@ -84,7 +86,16 @@
         error_exc, result = await self._execute_tool(function.name, tool_kwargs, ctx, mcp_tool_to_server)

         # Emit completion events for tool execution
-        has_error = error_exc or (result and ((result.error_code and result.error_code > 0) or result.error_message))
+        has_error = bool(
+            error_exc
+            or (
+                result
+                and (
+                    ((error_code := getattr(result, "error_code", None)) and error_code > 0)
+                    or getattr(result, "error_message", None)
+                )
+            )
+        )
         async for event_result in self._emit_completion_events(
             function.name, ctx, sequence_number, output_index, item_id, has_error, mcp_tool_to_server
         ):
@@ -101,7 +112,9 @@ class ToolExecutor:
             sequence_number=sequence_number,
             final_output_message=output_message,
             final_input_message=input_message,
-            citation_files=result.metadata.get("citation_files") if result and result.metadata else None,
+            citation_files=(
+                metadata.get("citation_files") if result and (metadata := getattr(result, "metadata", None)) else None
+            ),
         )

     async def _execute_knowledge_search_via_vector_store(
@@ -188,8 +201,9 @@ class ToolExecutor:

             citation_files[file_id] = filename

+        # Cast to proper InterleavedContent type (list invariance)
         return ToolInvocationResult(
-            content=content_items,
+            content=content_items,  # type: ignore[arg-type]
             metadata={
                 "document_ids": [r.file_id for r in search_results],
                 "chunks": [r.content[0].text if r.content else "" for r in search_results],
@@ -209,51 +223,60 @@ class ToolExecutor:
     ) -> AsyncIterator[ToolExecutionResult]:
         """Emit progress events for tool execution start."""
         # Emit in_progress event based on tool type (only for tools with specific streaming events)
-        progress_event = None
         if mcp_tool_to_server and function_name in mcp_tool_to_server:
             sequence_number += 1
-            progress_event = OpenAIResponseObjectStreamResponseMcpCallInProgress(
-                item_id=item_id,
-                output_index=output_index,
+            yield ToolExecutionResult(
+                stream_event=OpenAIResponseObjectStreamResponseMcpCallInProgress(
+                    item_id=item_id,
+                    output_index=output_index,
+                    sequence_number=sequence_number,
+                ),
                 sequence_number=sequence_number,
             )
         elif function_name == "web_search":
             sequence_number += 1
-            progress_event = OpenAIResponseObjectStreamResponseWebSearchCallInProgress(
-                item_id=item_id,
-                output_index=output_index,
+            yield ToolExecutionResult(
+                stream_event=OpenAIResponseObjectStreamResponseWebSearchCallInProgress(
+                    item_id=item_id,
+                    output_index=output_index,
+                    sequence_number=sequence_number,
+                ),
                 sequence_number=sequence_number,
             )
         elif function_name == "knowledge_search":
             sequence_number += 1
-            progress_event = OpenAIResponseObjectStreamResponseFileSearchCallInProgress(
-                item_id=item_id,
-                output_index=output_index,
+            yield ToolExecutionResult(
+                stream_event=OpenAIResponseObjectStreamResponseFileSearchCallInProgress(
+                    item_id=item_id,
+                    output_index=output_index,
+                    sequence_number=sequence_number,
+                ),
                 sequence_number=sequence_number,
             )

-        if progress_event:
-            yield ToolExecutionResult(stream_event=progress_event, sequence_number=sequence_number)
-
         # For web search, emit searching event
         if function_name == "web_search":
             sequence_number += 1
-            searching_event = OpenAIResponseObjectStreamResponseWebSearchCallSearching(
-                item_id=item_id,
-                output_index=output_index,
+            yield ToolExecutionResult(
+                stream_event=OpenAIResponseObjectStreamResponseWebSearchCallSearching(
+                    item_id=item_id,
+                    output_index=output_index,
+                    sequence_number=sequence_number,
+                ),
                 sequence_number=sequence_number,
             )
-            yield ToolExecutionResult(stream_event=searching_event, sequence_number=sequence_number)

         # For file search, emit searching event
         if function_name == "knowledge_search":
             sequence_number += 1
-            searching_event = OpenAIResponseObjectStreamResponseFileSearchCallSearching(
-                item_id=item_id,
-                output_index=output_index,
+            yield ToolExecutionResult(
+                stream_event=OpenAIResponseObjectStreamResponseFileSearchCallSearching(
+                    item_id=item_id,
+                    output_index=output_index,
+                    sequence_number=sequence_number,
+                ),
                 sequence_number=sequence_number,
             )
-            yield ToolExecutionResult(stream_event=searching_event, sequence_number=sequence_number)

     async def _execute_tool(
         self,
@@ -261,7 +284,7 @@ class ToolExecutor:
         tool_kwargs: dict,
         ctx: ChatCompletionContext,
         mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
-    ) -> tuple[Exception | None, any]:
+    ) -> tuple[Exception | None, Any]:
         """Execute the tool and return error exception and result."""
         error_exc = None
         result = None
@@ -284,9 +307,13 @@ class ToolExecutor:
                 kwargs=tool_kwargs,
             )
         elif function_name == "knowledge_search":
-            response_file_search_tool = next(
-                (t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)),
-                None,
+            response_file_search_tool = (
+                next(
+                    (t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)),
+                    None,
+                )
+                if ctx.response_tools
+                else None
             )
             if response_file_search_tool:
                 # Use vector_stores.search API instead of knowledge_search tool
@@ -322,35 +349,34 @@ class ToolExecutor:
         mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
     ) -> AsyncIterator[ToolExecutionResult]:
         """Emit completion or failure events for tool execution."""
-        completion_event = None

         if mcp_tool_to_server and function_name in mcp_tool_to_server:
             sequence_number += 1
             if has_error:
-                completion_event = OpenAIResponseObjectStreamResponseMcpCallFailed(
+                mcp_failed_event = OpenAIResponseObjectStreamResponseMcpCallFailed(
                     sequence_number=sequence_number,
                 )
+                yield ToolExecutionResult(stream_event=mcp_failed_event, sequence_number=sequence_number)
             else:
-                completion_event = OpenAIResponseObjectStreamResponseMcpCallCompleted(
+                mcp_completed_event = OpenAIResponseObjectStreamResponseMcpCallCompleted(
                     sequence_number=sequence_number,
                 )
+                yield ToolExecutionResult(stream_event=mcp_completed_event, sequence_number=sequence_number)
         elif function_name == "web_search":
             sequence_number += 1
-            completion_event = OpenAIResponseObjectStreamResponseWebSearchCallCompleted(
+            web_completion_event = OpenAIResponseObjectStreamResponseWebSearchCallCompleted(
                 item_id=item_id,
                 output_index=output_index,
                 sequence_number=sequence_number,
            )
+            yield ToolExecutionResult(stream_event=web_completion_event, sequence_number=sequence_number)
         elif function_name == "knowledge_search":
             sequence_number += 1
-            completion_event = OpenAIResponseObjectStreamResponseFileSearchCallCompleted(
+            file_completion_event = OpenAIResponseObjectStreamResponseFileSearchCallCompleted(
                 item_id=item_id,
                 output_index=output_index,
                 sequence_number=sequence_number,
             )
+            yield ToolExecutionResult(stream_event=file_completion_event, sequence_number=sequence_number)

-        if completion_event:
-            yield ToolExecutionResult(stream_event=completion_event, sequence_number=sequence_number)
-
     async def _build_result_messages(
         self,
async def _build_result_messages(
|
async def _build_result_messages(
|
||||||
self,
|
self,
|
||||||
|
|
@ -360,21 +386,18 @@ class ToolExecutor:
|
||||||
tool_kwargs: dict,
|
tool_kwargs: dict,
|
||||||
ctx: ChatCompletionContext,
|
ctx: ChatCompletionContext,
|
||||||
error_exc: Exception | None,
|
error_exc: Exception | None,
|
||||||
result: any,
|
result: Any,
|
||||||
has_error: bool,
|
has_error: bool,
|
||||||
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
|
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
|
||||||
) -> tuple[any, any]:
|
) -> tuple[Any, Any]:
|
||||||
"""Build output and input messages from tool execution results."""
|
"""Build output and input messages from tool execution results."""
|
||||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||||
interleaved_content_as_str,
|
interleaved_content_as_str,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Build output message
|
# Build output message
|
||||||
|
message: Any
|
||||||
if mcp_tool_to_server and function.name in mcp_tool_to_server:
|
if mcp_tool_to_server and function.name in mcp_tool_to_server:
|
||||||
from llama_stack.apis.agents.openai_responses import (
|
|
||||||
OpenAIResponseOutputMessageMCPCall,
|
|
||||||
)
|
|
||||||
|
|
||||||
message = OpenAIResponseOutputMessageMCPCall(
|
message = OpenAIResponseOutputMessageMCPCall(
|
||||||
id=item_id,
|
id=item_id,
|
||||||
arguments=function.arguments,
|
arguments=function.arguments,
|
||||||
|
|
@ -383,10 +406,14 @@ class ToolExecutor:
|
||||||
)
|
)
|
||||||
if error_exc:
|
if error_exc:
|
||||||
message.error = str(error_exc)
|
message.error = str(error_exc)
|
||||||
elif (result and result.error_code and result.error_code > 0) or (result and result.error_message):
|
elif (result and (error_code := getattr(result, "error_code", None)) and error_code > 0) or (
|
||||||
message.error = f"Error (code {result.error_code}): {result.error_message}"
|
result and getattr(result, "error_message", None)
|
||||||
elif result and result.content:
|
):
|
||||||
message.output = interleaved_content_as_str(result.content)
|
ec = getattr(result, "error_code", "unknown")
|
||||||
|
em = getattr(result, "error_message", "")
|
||||||
|
message.error = f"Error (code {ec}): {em}"
|
||||||
|
elif result and (content := getattr(result, "content", None)):
|
||||||
|
message.output = interleaved_content_as_str(content)
|
||||||
else:
|
else:
|
||||||
if function.name == "web_search":
|
if function.name == "web_search":
|
||||||
message = OpenAIResponseOutputMessageWebSearchToolCall(
|
message = OpenAIResponseOutputMessageWebSearchToolCall(
|
||||||
|
|
@@ -401,17 +428,17 @@ class ToolExecutor:
                     queries=[tool_kwargs.get("query", "")],
                     status="completed",
                 )
-                if result and "document_ids" in result.metadata:
+                if result and (metadata := getattr(result, "metadata", None)) and "document_ids" in metadata:
                     message.results = []
-                    for i, doc_id in enumerate(result.metadata["document_ids"]):
-                        text = result.metadata["chunks"][i] if "chunks" in result.metadata else None
-                        score = result.metadata["scores"][i] if "scores" in result.metadata else None
+                    for i, doc_id in enumerate(metadata["document_ids"]):
+                        text = metadata["chunks"][i] if "chunks" in metadata else None
+                        score = metadata["scores"][i] if "scores" in metadata else None
                         message.results.append(
                             OpenAIResponseOutputMessageFileSearchToolCallResults(
                                 file_id=doc_id,
                                 filename=doc_id,
-                                text=text,
-                                score=score,
+                                text=text if text is not None else "",
+                                score=score if score is not None else 0.0,
                                 attributes={},
                             )
                         )
@@ -421,27 +448,32 @@ class ToolExecutor:
                 raise ValueError(f"Unknown tool {function.name} called")

         # Build input message
-        input_message = None
-        if result and result.content:
-            if isinstance(result.content, str):
-                content = result.content
-            elif isinstance(result.content, list):
-                content = []
-                for item in result.content:
+        input_message: OpenAIToolMessageParam | None = None
+        if result and (result_content := getattr(result, "content", None)):
+            # all the mypy contortions here are still unsatisfactory with random Any typing
+            if isinstance(result_content, str):
+                msg_content: str | list[Any] = result_content
+            elif isinstance(result_content, list):
+                content_list: list[Any] = []
+                for item in result_content:
+                    part: Any
                     if isinstance(item, TextContentItem):
                         part = OpenAIChatCompletionContentPartTextParam(text=item.text)
                     elif isinstance(item, ImageContentItem):
                         if item.image.data:
-                            url = f"data:image;base64,{item.image.data}"
+                            url_value = f"data:image;base64,{item.image.data}"
                         else:
-                            url = item.image.url
-                        part = OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url=url))
+                            url_value = str(item.image.url) if item.image.url else ""
+                        part = OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url=url_value))
                     else:
                         raise ValueError(f"Unknown result content type: {type(item)}")
-                    content.append(part)
+                    content_list.append(part)
+                msg_content = content_list
             else:
-                raise ValueError(f"Unknown result content type: {type(result.content)}")
-            input_message = OpenAIToolMessageParam(content=content, tool_call_id=tool_call_id)
+                raise ValueError(f"Unknown result content type: {type(result_content)}")
+            # OpenAIToolMessageParam accepts str | list[TextParam] but we may have images
+            # This is runtime-safe as the API accepts it, but mypy complains
+            input_message = OpenAIToolMessageParam(content=msg_content, tool_call_id=tool_call_id)  # type: ignore[arg-type]
         else:
             text = str(error_exc) if error_exc else "Tool execution failed"
             input_message = OpenAIToolMessageParam(content=text, tool_call_id=tool_call_id)