From f389afe024f99f49db3091966fdacfbd755f8800 Mon Sep 17 00:00:00 2001 From: Eric Huang Date: Wed, 26 Feb 2025 20:44:26 -0800 Subject: [PATCH] temp fix Summary: Test Plan: --- .../agents/meta_reference/agent_instance.py | 157 ++++++----- .../inline/tool_runtime/rag/memory.py | 2 +- .../utils/inference/litellm_openai_mixin.py | 4 + .../utils/inference/openai_compat.py | 253 +++++++++++------- 4 files changed, 250 insertions(+), 166 deletions(-) diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index b17179463..474c72245 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -656,30 +656,27 @@ class ChatAgent(ShieldRunnerMixin): ) ) ) - tool_call = message.tool_calls[0] - yield AgentTurnResponseStreamChunk( - event=AgentTurnResponseEvent( - payload=AgentTurnResponseStepProgressPayload( - step_type=StepType.tool_execution.value, - step_id=step_id, - tool_call=tool_call, - delta=ToolCallDelta( - parse_status=ToolCallParseStatus.in_progress, - tool_call=tool_call, - ), - ) - ) - ) - # If tool is a client tool, yield CompletionMessage and return - if tool_call.tool_name in client_tools: + # Process all tool calls instead of just the first one + tool_responses = [] + tool_execution_start_time = datetime.now().astimezone().isoformat() + + # Check if any tool is a client tool + client_tool_found = False + for tool_call in message.tool_calls: + if tool_call.tool_name in client_tools: + client_tool_found = True + break + + # If any tool is a client tool, yield CompletionMessage and return + if client_tool_found: await self.storage.set_in_progress_tool_call_step( session_id, turn_id, ToolExecutionStep( step_id=step_id, turn_id=turn_id, - tool_calls=[tool_call], + tool_calls=message.tool_calls, tool_responses=[], started_at=datetime.now().astimezone().isoformat(), ), @@ -687,41 +684,86 @@ class ChatAgent(ShieldRunnerMixin): yield message return - # If tool is a builtin server tool, execute it - tool_name = tool_call.tool_name - if isinstance(tool_name, BuiltinTool): - tool_name = tool_name.value - with tracing.span( - "tool_execution", - { - "tool_name": tool_name, - "input": message.model_dump_json(), - }, - ) as span: - tool_execution_start_time = datetime.now().astimezone().isoformat() - tool_call = message.tool_calls[0] - tool_result = await execute_tool_call_maybe( - self.tool_runtime_api, - session_id, - tool_call, - toolgroup_args, - tool_to_group, - ) - if tool_result.content is None: - raise ValueError( - f"Tool call result (id: {tool_call.call_id}, name: {tool_call.tool_name}) does not have any content" + # Add the original message with tool calls to input_messages before processing tool calls + input_messages.append(message) + + # Process all tool calls + for tool_call in message.tool_calls: + yield AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseStepProgressPayload( + step_type=StepType.tool_execution.value, + step_id=step_id, + tool_call=tool_call, + delta=ToolCallDelta( + parse_status=ToolCallParseStatus.in_progress, + tool_call=tool_call, + ), + ) ) - result_messages = [ - ToolResponseMessage( + ) + + # Execute the tool call + tool_name = tool_call.tool_name + if isinstance(tool_name, BuiltinTool): + tool_name = tool_name.value + with tracing.span( + "tool_execution", + { + "tool_name": tool_name, + "input": tool_call.model_dump_json() + if 
hasattr(tool_call, "model_dump_json") + else str(tool_call), + }, + ) as span: + tool_result = await execute_tool_call_maybe( + self.tool_runtime_api, + session_id, + tool_call, + toolgroup_args, + tool_to_group, + ) + if tool_result.content is None: + raise ValueError( + f"Tool call result (id: {tool_call.call_id}, name: {tool_call.tool_name}) does not have any content" + ) + + result_message = ToolResponseMessage( call_id=tool_call.call_id, tool_name=tool_call.tool_name, content=tool_result.content, ) - ] - assert len(result_messages) == 1, "Currently not supporting multiple messages" - result_message = result_messages[0] - span.set_attribute("output", result_message.model_dump_json()) + tool_responses.append( + ToolResponse( + call_id=result_message.call_id, + tool_name=result_message.tool_name, + content=result_message.content, + metadata=tool_result.metadata, + ) + ) + + span.set_attribute( + "output", + result_message.model_dump_json() + if hasattr(result_message, "model_dump_json") + else str(result_message), + ) + + # TODO: add tool-input touchpoint and a "start" event for this step also + # but that needs a lot more refactoring of Tool code potentially + if (type(result_message.content) is str) and ( + out_attachment := _interpret_content_as_attachment(result_message.content) + ): + # NOTE: when we push this message back to the model, the model may ignore the + # attached file path etc. since the model is trained to only provide a user message + # with the summary. We keep all generated attachments and then attach them to final message + output_attachments.append(out_attachment) + + # Add the result message to input_messages + input_messages.append(result_message) + + # Complete the tool execution step yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( payload=AgentTurnResponseStepCompletePayload( @@ -730,15 +772,8 @@ class ChatAgent(ShieldRunnerMixin): step_details=ToolExecutionStep( step_id=step_id, turn_id=turn_id, - tool_calls=[tool_call], - tool_responses=[ - ToolResponse( - call_id=result_message.call_id, - tool_name=result_message.tool_name, - content=result_message.content, - metadata=tool_result.metadata, - ) - ], + tool_calls=message.tool_calls, + tool_responses=tool_responses, started_at=tool_execution_start_time, completed_at=datetime.now().astimezone().isoformat(), ), @@ -746,18 +781,6 @@ class ChatAgent(ShieldRunnerMixin): ) ) - # TODO: add tool-input touchpoint and a "start" event for this step also - # but that needs a lot more refactoring of Tool code potentially - if (type(result_message.content) is str) and ( - out_attachment := _interpret_content_as_attachment(result_message.content) - ): - # NOTE: when we push this message back to the model, the model may ignore the - # attached file path etc. since the model is trained to only provide a user message - # with the summary. 
We keep all generated attachments and then attach them to final message - output_attachments.append(out_attachment) - - input_messages = input_messages + [message, result_message] - n_iter += 1 async def _get_tool_defs( diff --git a/llama_stack/providers/inline/tool_runtime/rag/memory.py b/llama_stack/providers/inline/tool_runtime/rag/memory.py index 4b3f7d9e7..801579121 100644 --- a/llama_stack/providers/inline/tool_runtime/rag/memory.py +++ b/llama_stack/providers/inline/tool_runtime/rag/memory.py @@ -127,7 +127,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime): tokens = 0 picked = [ TextContentItem( - text=f"knowledge_search tool found {len(chunks)} chunks:\nBEGIN of knowledge_search tool results.\n" + text=f"knowledge_search tool found {len(chunks)} chunks for query:\n{query}\nBEGIN of knowledge_search tool results.\n" ) ] for i, c in enumerate(chunks): diff --git a/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/llama_stack/providers/utils/inference/litellm_openai_mixin.py index a916e4f99..58aad4281 100644 --- a/llama_stack/providers/utils/inference/litellm_openai_mixin.py +++ b/llama_stack/providers/utils/inference/litellm_openai_mixin.py @@ -99,6 +99,10 @@ class LiteLLMOpenAIMixin( params = await self._get_params(request) # unfortunately, we need to use synchronous litellm.completion here because litellm # caches various httpx.client objects in a non-eventloop aware manner + + from rich.pretty import pprint + + pprint(params) response = litellm.completion(**params) if stream: return self._stream_chat_completion(response) diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index 1f1306f0d..e98b25fb8 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -523,10 +523,11 @@ async def convert_message_to_openai_dict_new(message: Message | Dict) -> OpenAIC ) -> Union[str, Iterable[OpenAIChatCompletionContentPartParam]]: # Llama Stack and OpenAI spec match for str and text input if isinstance(content, str): - return OpenAIChatCompletionContentPartTextParam( - type="text", - text=content, - ) + # return OpenAIChatCompletionContentPartTextParam( + # type="text", + # text=content, + # ) + return content elif isinstance(content, TextContentItem): return OpenAIChatCompletionContentPartTextParam( type="text", @@ -568,12 +569,12 @@ async def convert_message_to_openai_dict_new(message: Message | Dict) -> OpenAIC out = OpenAIChatCompletionToolMessage( role="tool", tool_call_id=message.call_id, - content=message.content, + content=await _convert_user_message_content(message.content), ) elif isinstance(message, SystemMessage): out = OpenAIChatCompletionSystemMessage( role="system", - content=message.content, + content=await _convert_user_message_content(message.content), ) else: raise ValueError(f"Unsupported message type: {type(message)}") @@ -831,18 +832,26 @@ async def convert_openai_chat_completion_stream( Convert a stream of OpenAI chat completion chunks into a stream of ChatCompletionResponseStreamChunk. """ - - # generate a stream of ChatCompletionResponseEventType: start -> progress -> progress -> ... 
- def _event_type_generator() -> Generator[ChatCompletionResponseEventType, None, None]: - yield ChatCompletionResponseEventType.start - while True: - yield ChatCompletionResponseEventType.progress - - event_type = _event_type_generator() - stop_reason = None - toolcall_buffer = {} + # Use a dictionary to track multiple tool calls by their index + toolcall_buffers = {} + # Track which tool calls have been completed + completed_tool_indices = set() + # Track the highest index seen so far + highest_index_seen = -1 + + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.start, + delta=TextDelta(text=""), + stop_reason=stop_reason, + ) + ) + async for chunk in stream: + from rich.pretty import pprint + + pprint(chunk) choice = chunk.choices[0] # assuming only one choice per chunk # we assume there's only one finish_reason in the stream @@ -851,112 +860,108 @@ async def convert_openai_chat_completion_stream( # if there's a tool call, emit an event for each tool in the list # if tool call and content, emit both separately - if choice.delta.tool_calls: # the call may have content and a tool call. ChatCompletionResponseEvent # does not support both, so we emit the content first if choice.delta.content: yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( - event_type=next(event_type), + event_type=ChatCompletionResponseEventType.progress, delta=TextDelta(text=choice.delta.content), logprobs=_convert_openai_logprobs(logprobs), ) ) - # it is possible to have parallel tool calls in stream, but - # ChatCompletionResponseEvent only supports one per stream - if len(choice.delta.tool_calls) > 1: - warnings.warn("multiple tool calls found in a single delta, using the first, ignoring the rest") + # Process each tool call in the delta + for tool_call in choice.delta.tool_calls: + # Get the tool call index + tool_index = getattr(tool_call, "index", 0) - if not enable_incremental_tool_calls: - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=next(event_type), - delta=ToolCallDelta( - tool_call=_convert_openai_tool_calls(choice.delta.tool_calls)[0], - parse_status=ToolCallParseStatus.succeeded, - ), - logprobs=_convert_openai_logprobs(logprobs), + # If we see a new higher index, complete all previous tool calls + if tool_index > highest_index_seen: + # Complete all previous tool calls + for prev_index in range(highest_index_seen + 1): + if prev_index in toolcall_buffers and prev_index not in completed_tool_indices: + # Complete this tool call + async for event in _complete_tool_call( + toolcall_buffers[prev_index], + logprobs, + None, # No stop_reason for intermediate tool calls + ): + yield event + completed_tool_indices.add(prev_index) + + highest_index_seen = tool_index + + # Skip if this tool call has already been completed + if tool_index in completed_tool_indices: + continue + + # Initialize buffer for this tool call if it doesn't exist + if tool_index not in toolcall_buffers: + toolcall_buffers[tool_index] = { + "call_id": tool_call.id, + "name": None, + "content": "", + "arguments": "", + "complete": False, + } + + buffer = toolcall_buffers[tool_index] + + # Handle function name + if tool_call.function and tool_call.function.name: + buffer["name"] = tool_call.function.name + delta = f"{buffer['name']}(" + buffer["content"] += delta + + # Emit the function name + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + 
event_type=ChatCompletionResponseEventType.progress, + delta=ToolCallDelta( + tool_call=delta, + parse_status=ToolCallParseStatus.in_progress, + ), + logprobs=_convert_openai_logprobs(logprobs), + ) ) - ) - else: - tool_call = choice.delta.tool_calls[0] - if "name" not in toolcall_buffer: - toolcall_buffer["call_id"] = tool_call.id - toolcall_buffer["name"] = None - toolcall_buffer["content"] = "" - if "arguments" not in toolcall_buffer: - toolcall_buffer["arguments"] = "" - if tool_call.function.name: - toolcall_buffer["name"] = tool_call.function.name - delta = f"{toolcall_buffer['name']}(" - if tool_call.function.arguments: - toolcall_buffer["arguments"] += tool_call.function.arguments - delta = toolcall_buffer["arguments"] + # Handle function arguments + if tool_call.function and tool_call.function.arguments: + delta = tool_call.function.arguments + buffer["arguments"] += delta + buffer["content"] += delta - toolcall_buffer["content"] += delta - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=next(event_type), - delta=ToolCallDelta( - tool_call=delta, - parse_status=ToolCallParseStatus.in_progress, - ), - logprobs=_convert_openai_logprobs(logprobs), + # Emit the argument fragment + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.progress, + delta=ToolCallDelta( + tool_call=delta, + parse_status=ToolCallParseStatus.in_progress, + ), + logprobs=_convert_openai_logprobs(logprobs), + ) ) - ) else: yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( - event_type=next(event_type), + event_type=ChatCompletionResponseEventType.progress, delta=TextDelta(text=choice.delta.content or ""), logprobs=_convert_openai_logprobs(logprobs), ) ) - if toolcall_buffer: - delta = ")" - toolcall_buffer["content"] += delta - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=next(event_type), - delta=ToolCallDelta( - tool_call=delta, - parse_status=ToolCallParseStatus.in_progress, - ), - logprobs=_convert_openai_logprobs(logprobs), - ) - ) - try: - arguments = json.loads(toolcall_buffer["arguments"]) - tool_call = ToolCall( - call_id=toolcall_buffer["call_id"], - tool_name=toolcall_buffer["name"], - arguments=arguments, - ) - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.complete, - delta=ToolCallDelta( - tool_call=tool_call, - parse_status=ToolCallParseStatus.succeeded, - ), - stop_reason=stop_reason, - ) - ) - except json.JSONDecodeError: - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.complete, - delta=ToolCallDelta( - tool_call=toolcall_buffer["content"], - parse_status=ToolCallParseStatus.failed, - ), - stop_reason=stop_reason, - ) - ) + # Final complete event if no tool calls were processed + if toolcall_buffers: + # Process all tool calls that haven't been completed yet + for tool_index in sorted(toolcall_buffers.keys()): + if tool_index not in completed_tool_indices: + # Complete this tool call + async for event in _complete_tool_call(toolcall_buffers[tool_index], logprobs, stop_reason): + yield event + completed_tool_indices.add(tool_index) yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( @@ -965,3 +970,55 @@ async def convert_openai_chat_completion_stream( stop_reason=stop_reason, ) ) + + +async def _complete_tool_call(buffer, logprobs, 
stop_reason): + """Helper function to complete a tool call and yield the appropriate events.""" + # Add closing parenthesis + delta = ")" + buffer["content"] += delta + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.progress, + delta=ToolCallDelta( + tool_call=delta, + parse_status=ToolCallParseStatus.in_progress, + ), + logprobs=_convert_openai_logprobs(logprobs), + ) + ) + + try: + # Parse the arguments + arguments = json.loads(buffer["arguments"]) + tool_call = ToolCall( + call_id=buffer["call_id"], + tool_name=buffer["name"], + arguments=arguments, + ) + + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.progress, + delta=ToolCallDelta( + tool_call=tool_call, + parse_status=ToolCallParseStatus.succeeded, + ), + stop_reason=stop_reason, + ) + ) + except json.JSONDecodeError: + print(f"Failed to parse tool call arguments: {buffer['arguments']}") + + event_type_to_use = ChatCompletionResponseEventType.complete + + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=event_type_to_use, + delta=ToolCallDelta( + tool_call=buffer["content"], + parse_status=ToolCallParseStatus.failed, + ), + stop_reason=stop_reason, + ) + )
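
Note: the core idea of the openai_compat change above is to stop assuming a single tool call per stream and instead accumulate streamed fragments into one buffer per tool-call index, completing each buffer independently. The following is a minimal standalone sketch of that buffering strategy, for illustration only; ToolCallDeltaChunk and buffer_parallel_tool_calls are hypothetical names, not llama_stack or OpenAI client APIs, and the sketch parses buffers after the stream ends rather than replicating the incremental event emission in convert_openai_chat_completion_stream.

# Standalone sketch (assumptions noted above): per-index buffering of
# parallel tool-call fragments from a streamed chat completion.
import json
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class ToolCallDeltaChunk:
    index: int                      # which parallel tool call this fragment belongs to
    call_id: Optional[str] = None   # present on the first fragment of a call
    name: Optional[str] = None      # function name, present on the first fragment
    arguments: str = ""             # partial JSON arguments


def buffer_parallel_tool_calls(chunks: List[ToolCallDeltaChunk]) -> List[dict]:
    """Accumulate fragments into one buffer per tool-call index, then parse
    each buffer's arguments once all fragments have been seen."""
    buffers: Dict[int, dict] = {}
    for chunk in chunks:
        buf = buffers.setdefault(chunk.index, {"call_id": None, "name": None, "arguments": ""})
        if chunk.call_id:
            buf["call_id"] = chunk.call_id
        if chunk.name:
            buf["name"] = chunk.name
        buf["arguments"] += chunk.arguments

    completed = []
    for index in sorted(buffers):
        buf = buffers[index]
        try:
            args = json.loads(buf["arguments"])
            completed.append({"call_id": buf["call_id"], "tool_name": buf["name"], "arguments": args})
        except json.JSONDecodeError:
            # Mirror the patch's failure path: keep the raw content instead of dropping the call.
            completed.append({"call_id": buf["call_id"], "tool_name": buf["name"], "error": buf["arguments"]})
    return completed


if __name__ == "__main__":
    # Two parallel tool calls whose argument fragments arrive interleaved.
    stream = [
        ToolCallDeltaChunk(index=0, call_id="call_0", name="get_weather", arguments='{"city": '),
        ToolCallDeltaChunk(index=1, call_id="call_1", name="get_time", arguments='{"tz": "CET"}'),
        ToolCallDeltaChunk(index=0, arguments='"Paris"}'),
    ]
    for call in buffer_parallel_tool_calls(stream):
        print(call)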