Summary:
Process every tool call in a turn instead of only the first one, and support parallel tool calls in the OpenAI-compatible chat completion stream converter.

Test Plan:
Eric Huang 2025-02-26 20:44:26 -08:00
parent 270d64007a
commit f389afe024
4 changed files with 250 additions and 166 deletions


@@ -656,30 +656,27 @@ class ChatAgent(ShieldRunnerMixin):
)
)
)
tool_call = message.tool_calls[0]
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.tool_execution.value,
step_id=step_id,
tool_call=tool_call,
delta=ToolCallDelta(
parse_status=ToolCallParseStatus.in_progress,
tool_call=tool_call,
),
)
)
)
# If tool is a client tool, yield CompletionMessage and return
if tool_call.tool_name in client_tools:
# Process all tool calls instead of just the first one
tool_responses = []
tool_execution_start_time = datetime.now().astimezone().isoformat()
# Check if any tool is a client tool
client_tool_found = False
for tool_call in message.tool_calls:
if tool_call.tool_name in client_tools:
client_tool_found = True
break
# If any tool is a client tool, yield CompletionMessage and return
if client_tool_found:
await self.storage.set_in_progress_tool_call_step(
session_id,
turn_id,
ToolExecutionStep(
step_id=step_id,
turn_id=turn_id,
tool_calls=[tool_call],
tool_calls=message.tool_calls,
tool_responses=[],
started_at=datetime.now().astimezone().isoformat(),
),
@@ -687,41 +684,86 @@ class ChatAgent(ShieldRunnerMixin):
yield message
return
# If tool is a builtin server tool, execute it
tool_name = tool_call.tool_name
if isinstance(tool_name, BuiltinTool):
tool_name = tool_name.value
with tracing.span(
"tool_execution",
{
"tool_name": tool_name,
"input": message.model_dump_json(),
},
) as span:
tool_execution_start_time = datetime.now().astimezone().isoformat()
tool_call = message.tool_calls[0]
tool_result = await execute_tool_call_maybe(
self.tool_runtime_api,
session_id,
tool_call,
toolgroup_args,
tool_to_group,
)
if tool_result.content is None:
raise ValueError(
f"Tool call result (id: {tool_call.call_id}, name: {tool_call.tool_name}) does not have any content"
# Add the original message with tool calls to input_messages before processing tool calls
input_messages.append(message)
# Process all tool calls
for tool_call in message.tool_calls:
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.tool_execution.value,
step_id=step_id,
tool_call=tool_call,
delta=ToolCallDelta(
parse_status=ToolCallParseStatus.in_progress,
tool_call=tool_call,
),
)
)
result_messages = [
ToolResponseMessage(
)
# Execute the tool call
tool_name = tool_call.tool_name
if isinstance(tool_name, BuiltinTool):
tool_name = tool_name.value
with tracing.span(
"tool_execution",
{
"tool_name": tool_name,
"input": tool_call.model_dump_json()
if hasattr(tool_call, "model_dump_json")
else str(tool_call),
},
) as span:
tool_result = await execute_tool_call_maybe(
self.tool_runtime_api,
session_id,
tool_call,
toolgroup_args,
tool_to_group,
)
if tool_result.content is None:
raise ValueError(
f"Tool call result (id: {tool_call.call_id}, name: {tool_call.tool_name}) does not have any content"
)
result_message = ToolResponseMessage(
call_id=tool_call.call_id,
tool_name=tool_call.tool_name,
content=tool_result.content,
)
]
assert len(result_messages) == 1, "Currently not supporting multiple messages"
result_message = result_messages[0]
span.set_attribute("output", result_message.model_dump_json())
tool_responses.append(
ToolResponse(
call_id=result_message.call_id,
tool_name=result_message.tool_name,
content=result_message.content,
metadata=tool_result.metadata,
)
)
span.set_attribute(
"output",
result_message.model_dump_json()
if hasattr(result_message, "model_dump_json")
else str(result_message),
)
# TODO: add tool-input touchpoint and a "start" event for this step also
# but that needs a lot more refactoring of Tool code potentially
if (type(result_message.content) is str) and (
out_attachment := _interpret_content_as_attachment(result_message.content)
):
# NOTE: when we push this message back to the model, the model may ignore the
# attached file path etc. since the model is trained to only provide a user message
# with the summary. We keep all generated attachments and then attach them to final message
output_attachments.append(out_attachment)
# Add the result message to input_messages
input_messages.append(result_message)
# Complete the tool execution step
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
@@ -730,15 +772,8 @@ class ChatAgent(ShieldRunnerMixin):
step_details=ToolExecutionStep(
step_id=step_id,
turn_id=turn_id,
tool_calls=[tool_call],
tool_responses=[
ToolResponse(
call_id=result_message.call_id,
tool_name=result_message.tool_name,
content=result_message.content,
metadata=tool_result.metadata,
)
],
tool_calls=message.tool_calls,
tool_responses=tool_responses,
started_at=tool_execution_start_time,
completed_at=datetime.now().astimezone().isoformat(),
),
@@ -746,18 +781,6 @@
)
)
# TODO: add tool-input touchpoint and a "start" event for this step also
# but that needs a lot more refactoring of Tool code potentially
if (type(result_message.content) is str) and (
out_attachment := _interpret_content_as_attachment(result_message.content)
):
# NOTE: when we push this message back to the model, the model may ignore the
# attached file path etc. since the model is trained to only provide a user message
# with the summary. We keep all generated attachments and then attach them to final message
output_attachments.append(out_attachment)
input_messages = input_messages + [message, result_message]
n_iter += 1
async def _get_tool_defs(
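The hunks above replace the single-tool-call path (tool_call = message.tool_calls[0]) with a loop over message.tool_calls that accumulates one ToolResponse per call and reports them all in a single ToolExecutionStep. Below is a minimal, self-contained sketch of that control flow; ToolCall, ToolResponse, and execute_tool are simplified stand-ins, not the repo's actual types or API.

import asyncio
from dataclasses import dataclass

@dataclass
class ToolCall:
    call_id: str
    tool_name: str
    arguments: dict

@dataclass
class ToolResponse:
    call_id: str
    tool_name: str
    content: str

async def execute_tool(call: ToolCall) -> str:
    # Stand-in for execute_tool_call_maybe(); a real implementation would
    # dispatch to the tool runtime.
    return f"result of {call.tool_name}({call.arguments})"

async def run_all_tool_calls(tool_calls: list[ToolCall]) -> list[ToolResponse]:
    responses = []
    for call in tool_calls:  # previously only tool_calls[0] was executed
        content = await execute_tool(call)
        if content is None:
            raise ValueError(f"Tool call {call.call_id} returned no content")
        responses.append(ToolResponse(call.call_id, call.tool_name, content))
    return responses

calls = [
    ToolCall("1", "knowledge_search", {"query": "llama"}),
    ToolCall("2", "web_search", {"query": "stack"}),
]
print(asyncio.run(run_all_tool_calls(calls)))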


@@ -127,7 +127,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
tokens = 0
picked = [
TextContentItem(
text=f"knowledge_search tool found {len(chunks)} chunks:\nBEGIN of knowledge_search tool results.\n"
text=f"knowledge_search tool found {len(chunks)} chunks for query:\n{query}\nBEGIN of knowledge_search tool results.\n"
)
]
for i, c in enumerate(chunks):
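The one-line change above threads the original query into the header that the knowledge_search tool returns, so the model can see which query produced the chunks. A small sketch of the resulting header, using hypothetical example values:

chunks = ["chunk-a", "chunk-b"]  # hypothetical retrieved chunks
query = "how do I configure the agent?"  # hypothetical user query

# Before: the header only reported the chunk count.
# After: it also echoes the query that produced the chunks.
header = (
    f"knowledge_search tool found {len(chunks)} chunks for query:\n{query}\n"
    "BEGIN of knowledge_search tool results.\n"
)
print(header)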


@@ -99,6 +99,10 @@ class LiteLLMOpenAIMixin(
params = await self._get_params(request)
# unfortunately, we need to use synchronous litellm.completion here because litellm
# caches various httpx.client objects in a non-eventloop aware manner
from rich.pretty import pprint
pprint(params)
response = litellm.completion(**params)
if stream:
return self._stream_chat_completion(response)
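The comment above explains why litellm.completion is called synchronously: litellm caches httpx client objects in a way that is not event-loop aware, so the usual offloading pattern is avoided here. For contrast, this is a sketch of that usual pattern for a blocking SDK call inside async code; blocking_completion is a hypothetical stand-in, not litellm's API.

import asyncio
import time

def blocking_completion(**params) -> dict:
    time.sleep(0.1)  # simulate a blocking network call
    return {"echo": params}

async def chat_completion(**params) -> dict:
    # Offload the blocking call to a worker thread so the event loop
    # keeps serving other tasks while the request is in flight.
    return await asyncio.to_thread(blocking_completion, **params)

print(asyncio.run(chat_completion(model="m", messages=[])))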


@@ -523,10 +523,11 @@ async def convert_message_to_openai_dict_new(message: Message | Dict) -> OpenAIC
) -> Union[str, Iterable[OpenAIChatCompletionContentPartParam]]:
# Llama Stack and OpenAI spec match for str and text input
if isinstance(content, str):
return OpenAIChatCompletionContentPartTextParam(
type="text",
text=content,
)
# return OpenAIChatCompletionContentPartTextParam(
# type="text",
# text=content,
# )
return content
elif isinstance(content, TextContentItem):
return OpenAIChatCompletionContentPartTextParam(
type="text",
@@ -568,12 +569,12 @@ async def convert_message_to_openai_dict_new(message: Message | Dict) -> OpenAIC
out = OpenAIChatCompletionToolMessage(
role="tool",
tool_call_id=message.call_id,
content=message.content,
content=await _convert_user_message_content(message.content),
)
elif isinstance(message, SystemMessage):
out = OpenAIChatCompletionSystemMessage(
role="system",
content=message.content,
content=await _convert_user_message_content(message.content),
)
else:
raise ValueError(f"Unsupported message type: {type(message)}")
@@ -831,18 +832,26 @@ async def convert_openai_chat_completion_stream(
Convert a stream of OpenAI chat completion chunks into a stream
of ChatCompletionResponseStreamChunk.
"""
# generate a stream of ChatCompletionResponseEventType: start -> progress -> progress -> ...
def _event_type_generator() -> Generator[ChatCompletionResponseEventType, None, None]:
yield ChatCompletionResponseEventType.start
while True:
yield ChatCompletionResponseEventType.progress
event_type = _event_type_generator()
stop_reason = None
toolcall_buffer = {}
# Use a dictionary to track multiple tool calls by their index
toolcall_buffers = {}
# Track which tool calls have been completed
completed_tool_indices = set()
# Track the highest index seen so far
highest_index_seen = -1
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta=TextDelta(text=""),
stop_reason=stop_reason,
)
)
async for chunk in stream:
from rich.pretty import pprint
pprint(chunk)
choice = chunk.choices[0] # assuming only one choice per chunk
# we assume there's only one finish_reason in the stream
@@ -851,112 +860,108 @@ async def convert_openai_chat_completion_stream(
# if there's a tool call, emit an event for each tool in the list
# if tool call and content, emit both separately
if choice.delta.tool_calls:
# the call may have content and a tool call. ChatCompletionResponseEvent
# does not support both, so we emit the content first
if choice.delta.content:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=next(event_type),
event_type=ChatCompletionResponseEventType.progress,
delta=TextDelta(text=choice.delta.content),
logprobs=_convert_openai_logprobs(logprobs),
)
)
# it is possible to have parallel tool calls in stream, but
# ChatCompletionResponseEvent only supports one per stream
if len(choice.delta.tool_calls) > 1:
warnings.warn("multiple tool calls found in a single delta, using the first, ignoring the rest")
# Process each tool call in the delta
for tool_call in choice.delta.tool_calls:
# Get the tool call index
tool_index = getattr(tool_call, "index", 0)
if not enable_incremental_tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=next(event_type),
delta=ToolCallDelta(
tool_call=_convert_openai_tool_calls(choice.delta.tool_calls)[0],
parse_status=ToolCallParseStatus.succeeded,
),
logprobs=_convert_openai_logprobs(logprobs),
# If we see a new higher index, complete all previous tool calls
if tool_index > highest_index_seen:
# Complete all previous tool calls
for prev_index in range(highest_index_seen + 1):
if prev_index in toolcall_buffers and prev_index not in completed_tool_indices:
# Complete this tool call
async for event in _complete_tool_call(
toolcall_buffers[prev_index],
logprobs,
None, # No stop_reason for intermediate tool calls
):
yield event
completed_tool_indices.add(prev_index)
highest_index_seen = tool_index
# Skip if this tool call has already been completed
if tool_index in completed_tool_indices:
continue
# Initialize buffer for this tool call if it doesn't exist
if tool_index not in toolcall_buffers:
toolcall_buffers[tool_index] = {
"call_id": tool_call.id,
"name": None,
"content": "",
"arguments": "",
"complete": False,
}
buffer = toolcall_buffers[tool_index]
# Handle function name
if tool_call.function and tool_call.function.name:
buffer["name"] = tool_call.function.name
delta = f"{buffer['name']}("
buffer["content"] += delta
# Emit the function name
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
tool_call=delta,
parse_status=ToolCallParseStatus.in_progress,
),
logprobs=_convert_openai_logprobs(logprobs),
)
)
)
else:
tool_call = choice.delta.tool_calls[0]
if "name" not in toolcall_buffer:
toolcall_buffer["call_id"] = tool_call.id
toolcall_buffer["name"] = None
toolcall_buffer["content"] = ""
if "arguments" not in toolcall_buffer:
toolcall_buffer["arguments"] = ""
if tool_call.function.name:
toolcall_buffer["name"] = tool_call.function.name
delta = f"{toolcall_buffer['name']}("
if tool_call.function.arguments:
toolcall_buffer["arguments"] += tool_call.function.arguments
delta = toolcall_buffer["arguments"]
# Handle function arguments
if tool_call.function and tool_call.function.arguments:
delta = tool_call.function.arguments
buffer["arguments"] += delta
buffer["content"] += delta
toolcall_buffer["content"] += delta
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=next(event_type),
delta=ToolCallDelta(
tool_call=delta,
parse_status=ToolCallParseStatus.in_progress,
),
logprobs=_convert_openai_logprobs(logprobs),
# Emit the argument fragment
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
tool_call=delta,
parse_status=ToolCallParseStatus.in_progress,
),
logprobs=_convert_openai_logprobs(logprobs),
)
)
)
else:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=next(event_type),
event_type=ChatCompletionResponseEventType.progress,
delta=TextDelta(text=choice.delta.content or ""),
logprobs=_convert_openai_logprobs(logprobs),
)
)
if toolcall_buffer:
delta = ")"
toolcall_buffer["content"] += delta
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=next(event_type),
delta=ToolCallDelta(
tool_call=delta,
parse_status=ToolCallParseStatus.in_progress,
),
logprobs=_convert_openai_logprobs(logprobs),
)
)
try:
arguments = json.loads(toolcall_buffer["arguments"])
tool_call = ToolCall(
call_id=toolcall_buffer["call_id"],
tool_name=toolcall_buffer["name"],
arguments=arguments,
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta=ToolCallDelta(
tool_call=tool_call,
parse_status=ToolCallParseStatus.succeeded,
),
stop_reason=stop_reason,
)
)
except json.JSONDecodeError:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta=ToolCallDelta(
tool_call=toolcall_buffer["content"],
parse_status=ToolCallParseStatus.failed,
),
stop_reason=stop_reason,
)
)
# Final complete event if no tool calls were processed
if toolcall_buffers:
# Process all tool calls that haven't been completed yet
for tool_index in sorted(toolcall_buffers.keys()):
if tool_index not in completed_tool_indices:
# Complete this tool call
async for event in _complete_tool_call(toolcall_buffers[tool_index], logprobs, stop_reason):
yield event
completed_tool_indices.add(tool_index)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
@@ -965,3 +970,55 @@ async def convert_openai_chat_completion_stream(
stop_reason=stop_reason,
)
)
async def _complete_tool_call(buffer, logprobs, stop_reason):
"""Helper function to complete a tool call and yield the appropriate events."""
# Add closing parenthesis
delta = ")"
buffer["content"] += delta
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
tool_call=delta,
parse_status=ToolCallParseStatus.in_progress,
),
logprobs=_convert_openai_logprobs(logprobs),
)
)
try:
# Parse the arguments
arguments = json.loads(buffer["arguments"])
tool_call = ToolCall(
call_id=buffer["call_id"],
tool_name=buffer["name"],
arguments=arguments,
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
tool_call=tool_call,
parse_status=ToolCallParseStatus.succeeded,
),
stop_reason=stop_reason,
)
)
except json.JSONDecodeError:
print(f"Failed to parse tool call arguments: {buffer['arguments']}")
event_type_to_use = ChatCompletionResponseEventType.complete
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=event_type_to_use,
delta=ToolCallDelta(
tool_call=buffer["content"],
parse_status=ToolCallParseStatus.failed,
),
stop_reason=stop_reason,
)
)
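Taken together, the streaming changes replace the single toolcall_buffer with per-index buffers so parallel tool calls can interleave in one stream. The sketch below isolates that bookkeeping: fragments accumulate per index, a buffer is finalized once a higher index appears, and leftovers are flushed at end of stream. The delta dicts are simplified stand-ins for OpenAI stream chunks.

import json

deltas = [  # hypothetical stream of tool-call fragments
    {"index": 0, "id": "call-0", "name": "get_weather", "arguments": '{"city": '},
    {"index": 0, "id": None, "name": None, "arguments": '"Paris"}'},
    {"index": 1, "id": "call-1", "name": "get_time", "arguments": '{"tz": "CET"}'},
]

buffers: dict[int, dict] = {}
completed: set[int] = set()
highest = -1

def finalize(index: int) -> None:
    buf = buffers[index]
    # A json.JSONDecodeError here maps to parse_status=failed upstream.
    args = json.loads(buf["arguments"])
    print(f"completed {buf['name']}({args}) [{buf['id']}]")
    completed.add(index)

for delta in deltas:
    idx = delta["index"]
    if idx > highest:
        # A new, higher index means every earlier call is finished streaming.
        for prev in range(highest + 1):
            if prev in buffers and prev not in completed:
                finalize(prev)
        highest = idx
    buf = buffers.setdefault(idx, {"id": delta["id"], "name": None, "arguments": ""})
    if delta["name"]:
        buf["name"] = delta["name"]
    if delta["arguments"]:
        buf["arguments"] += delta["arguments"]

# Flush whatever is still open once the stream ends.
for idx in sorted(buffers):
    if idx not in completed:
        finalize(idx)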