diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
index 03692bcc7..aa27f421c 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
@@ -891,16 +891,14 @@ class ChatAgent(ShieldRunnerMixin):
             if memory_tool and code_interpreter_tool:
                 # if both memory and code_interpreter are available, we download the URLs
                 # and attach the data to the last message.
-                msg = await attachment_message(self.tempdir, url_items)
-                input_messages.append(msg)
+                await attachment_message(self.tempdir, url_items, input_messages[-1])
                 # Since memory is present, add all the data to the memory bank
                 await self.add_to_session_vector_db(session_id, documents)
             elif code_interpreter_tool:
                 # if only code_interpreter is available, we download the URLs to a tempdir
                 # and attach the path to them as a message to inference with the
                 # assumption that the model invokes the code_interpreter tool with the path
-                msg = await attachment_message(self.tempdir, url_items)
-                input_messages.append(msg)
+                await attachment_message(self.tempdir, url_items, input_messages[-1])
             elif memory_tool:
                 # if only memory is available, we load the data from the URLs and content items to the memory bank
                 await self.add_to_session_vector_db(session_id, documents)
@@ -967,8 +965,8 @@ async def load_data_from_urls(urls: List[URL]) -> List[str]:
     return data
 
 
-async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessage:
-    content = []
+async def attachment_message(tempdir: str, urls: List[URL], message: UserMessage) -> None:
+    contents = []
 
     for url in urls:
         uri = url.uri
@@ -988,16 +986,19 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa
         else:
             raise ValueError(f"Unsupported URL {url}")
 
-        content.append(
+        contents.append(
             TextContentItem(
                 text=f'# User provided a file accessible to you at "{filepath}"\nYou can use code_interpreter to load and inspect it.'
             )
         )
 
-    return ToolResponseMessage(
-        call_id="",
-        content=content,
-    )
+    if isinstance(message.content, list):
+        message.content.extend(contents)
+    else:
+        if isinstance(message.content, str):
+            message.content = [TextContentItem(text=message.content)] + contents
+        else:
+            message.content = [message.content] + contents
 
 
 def _interpret_content_as_attachment(
diff --git a/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/llama_stack/providers/utils/inference/litellm_openai_mixin.py
index d88dc5a9e..f99883990 100644
--- a/llama_stack/providers/utils/inference/litellm_openai_mixin.py
+++ b/llama_stack/providers/utils/inference/litellm_openai_mixin.py
@@ -192,7 +192,11 @@ class LiteLLMOpenAIMixin(
         if request.tools:
             input_dict["tools"] = [convert_tooldef_to_openai_tool(tool) for tool in request.tools]
             if request.tool_config.tool_choice:
-                input_dict["tool_choice"] = request.tool_config.tool_choice.value
+                input_dict["tool_choice"] = (
+                    request.tool_config.tool_choice.value
+                    if isinstance(request.tool_config.tool_choice, ToolChoice)
+                    else request.tool_config.tool_choice
+                )
 
         provider_data = self.get_request_provider_data()
         key_field = self.provider_data_api_key_field
diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index ac37171c9..2a362f8cb 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -527,26 +527,30 @@ async def convert_message_to_openai_dict_new(
     async def _convert_message_content(
         content: InterleavedContent,
     ) -> Union[str, Iterable[OpenAIChatCompletionContentPartParam]]:
-        async def impl():
+        async def impl(
+            content_: InterleavedContent,
+        ) -> Union[str, OpenAIChatCompletionContentPartParam, List[OpenAIChatCompletionContentPartParam]]:
             # Llama Stack and OpenAI spec match for str and text input
-            if isinstance(content, str):
-                return content
-            elif isinstance(content, TextContentItem):
+            if isinstance(content_, str):
+                return content_
+            elif isinstance(content_, TextContentItem):
                 return OpenAIChatCompletionContentPartTextParam(
                     type="text",
-                    text=content.text,
+                    text=content_.text,
                 )
-            elif isinstance(content, ImageContentItem):
+            elif isinstance(content_, ImageContentItem):
                 return OpenAIChatCompletionContentPartImageParam(
                     type="image_url",
-                    image_url=OpenAIImageURL(url=await convert_image_content_to_url(content)),
+                    image_url=OpenAIImageURL(url=await convert_image_content_to_url(content_)),
                 )
-            elif isinstance(content, list):
-                return [await _convert_message_content(item) for item in content]
+            elif isinstance(content_, list):
+                return [await impl(item) for item in content_]
             else:
-                raise ValueError(f"Unsupported content type: {type(content)}")
+                raise ValueError(f"Unsupported content type: {type(content_)}")
 
-        ret = await impl()
+        ret = await impl(content)
+
+        # OpenAI*Message expects a str or list
         if isinstance(ret, str) or isinstance(ret, list):
             return ret
         else:
@@ -566,13 +570,14 @@ async def convert_message_to_openai_dict_new(
                 OpenAIChatCompletionMessageToolCall(
                     id=tool.call_id,
                     function=OpenAIFunction(
-                        name=tool.tool_name,
+                        name=tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value,
                         arguments=json.dumps(tool.arguments),
                     ),
                     type="function",
                 )
                 for tool in message.tool_calls
-            ],
+            ]
+            or None,
         )
     elif isinstance(message, ToolResponseMessage):
         out = OpenAIChatCompletionToolMessage(
@@ -858,7 +863,8 @@ async def convert_openai_chat_completion_stream(
     event_type = ChatCompletionResponseEventType.progress
 
     stop_reason = None
-    toolcall_buffer = {}
+    tool_call_idx_to_buffer = {}
+
     async for chunk in stream:
         choice = chunk.choices[0]  # assuming only one choice per chunk
 
@@ -868,7 +874,6 @@ async def convert_openai_chat_completion_stream(
 
         # if there's a tool call, emit an event for each tool in the list
         # if tool call and content, emit both separately
-
         if choice.delta.tool_calls:
             # the call may have content and a tool call. ChatCompletionResponseEvent
             # does not support both, so we emit the content first
@@ -889,44 +894,53 @@ async def convert_openai_chat_completion_stream(
                 )
 
             if not enable_incremental_tool_calls:
-                yield ChatCompletionResponseStreamChunk(
-                    event=ChatCompletionResponseEvent(
-                        event_type=next(event_type),
-                        delta=ToolCallDelta(
-                            tool_call=_convert_openai_tool_calls(choice.delta.tool_calls)[0],
-                            parse_status=ToolCallParseStatus.succeeded,
-                        ),
-                        logprobs=_convert_openai_logprobs(logprobs),
+                for tool_call in choice.delta.tool_calls:
+                    yield ChatCompletionResponseStreamChunk(
+                        event=ChatCompletionResponseEvent(
+                            event_type=event_type,
+                            delta=ToolCallDelta(
+                                tool_call=_convert_openai_tool_calls([tool_call])[0],
+                                parse_status=ToolCallParseStatus.succeeded,
+                            ),
+                            logprobs=_convert_openai_logprobs(logprobs),
+                        )
                     )
-                )
             else:
-                tool_call = choice.delta.tool_calls[0]
-                if "name" not in toolcall_buffer:
-                    toolcall_buffer["call_id"] = tool_call.id
-                    toolcall_buffer["name"] = None
-                    toolcall_buffer["content"] = ""
-                if "arguments" not in toolcall_buffer:
-                    toolcall_buffer["arguments"] = ""
+                for tool_call in choice.delta.tool_calls:
+                    idx = tool_call.index if hasattr(tool_call, "index") else 0
 
-                if tool_call.function.name:
-                    toolcall_buffer["name"] = tool_call.function.name
-                    delta = f"{toolcall_buffer['name']}("
-                if tool_call.function.arguments:
-                    toolcall_buffer["arguments"] += tool_call.function.arguments
-                    delta = toolcall_buffer["arguments"]
+                    if idx not in tool_call_idx_to_buffer:
+                        tool_call_idx_to_buffer[idx] = {
+                            "call_id": tool_call.id,
+                            "name": None,
+                            "arguments": "",
+                            "content": "",
+                        }
 
-                toolcall_buffer["content"] += delta
-                yield ChatCompletionResponseStreamChunk(
-                    event=ChatCompletionResponseEvent(
-                        event_type=event_type,
-                        delta=ToolCallDelta(
-                            tool_call=delta,
-                            parse_status=ToolCallParseStatus.in_progress,
-                        ),
-                        logprobs=_convert_openai_logprobs(logprobs),
-                    )
-                )
-        else:
+                    buffer = tool_call_idx_to_buffer[idx]
+
+                    if tool_call.function:
+                        if tool_call.function.name:
+                            buffer["name"] = tool_call.function.name
+                            delta = f"{buffer['name']}("
+                            buffer["content"] += delta
+
+                        if tool_call.function.arguments:
+                            delta = tool_call.function.arguments
+                            buffer["arguments"] += delta
+                            buffer["content"] += delta
+
+                        yield ChatCompletionResponseStreamChunk(
+                            event=ChatCompletionResponseEvent(
+                                event_type=event_type,
+                                delta=ToolCallDelta(
+                                    tool_call=delta,
+                                    parse_status=ToolCallParseStatus.in_progress,
+                                ),
+                                logprobs=_convert_openai_logprobs(logprobs),
+                            )
+                        )
+        elif choice.delta.content:
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
                     event_type=event_type,
@@ -935,47 +949,51 @@ async def convert_openai_chat_completion_stream(
                 )
             )
 
-    if toolcall_buffer:
-        delta = ")"
-        toolcall_buffer["content"] += delta
-        yield ChatCompletionResponseStreamChunk(
-            event=ChatCompletionResponseEvent(
-                event_type=event_type,
-                delta=ToolCallDelta(
-                    tool_call=delta,
-                    parse_status=ToolCallParseStatus.in_progress,
-                ),
-                logprobs=_convert_openai_logprobs(logprobs),
-            )
-        )
-        try:
-            arguments = json.loads(toolcall_buffer["arguments"])
-            tool_call = ToolCall(
-                call_id=toolcall_buffer["call_id"],
-                tool_name=toolcall_buffer["name"],
-                arguments=arguments,
-            )
+    for idx, buffer in tool_call_idx_to_buffer.items():
+        logger.debug(f"toolcall_buffer[{idx}]: {buffer}")
+        if buffer["name"]:
+            delta = ")"
+            buffer["content"] += delta
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
-                    event_type=ChatCompletionResponseEventType.progress,
+                    event_type=event_type,
                     delta=ToolCallDelta(
-                        tool_call=tool_call,
-                        parse_status=ToolCallParseStatus.succeeded,
+                        tool_call=delta,
+                        parse_status=ToolCallParseStatus.in_progress,
                     ),
-                    stop_reason=stop_reason,
+                    logprobs=None,
                 )
             )
-        except json.JSONDecodeError:
-            yield ChatCompletionResponseStreamChunk(
-                event=ChatCompletionResponseEvent(
-                    event_type=ChatCompletionResponseEventType.complete,
-                    delta=ToolCallDelta(
-                        tool_call=toolcall_buffer["content"],
-                        parse_status=ToolCallParseStatus.failed,
-                    ),
-                    stop_reason=stop_reason,
+
+            try:
+                arguments = json.loads(buffer["arguments"])
+                tool_call = ToolCall(
+                    call_id=buffer["call_id"],
+                    tool_name=buffer["name"],
+                    arguments=arguments,
+                )
+                yield ChatCompletionResponseStreamChunk(
+                    event=ChatCompletionResponseEvent(
+                        event_type=ChatCompletionResponseEventType.progress,
+                        delta=ToolCallDelta(
+                            tool_call=tool_call,
+                            parse_status=ToolCallParseStatus.succeeded,
+                        ),
+                        stop_reason=stop_reason,
+                    )
+                )
+            except json.JSONDecodeError as e:
+                print(f"Failed to parse arguments: {e}")
+                yield ChatCompletionResponseStreamChunk(
+                    event=ChatCompletionResponseEvent(
+                        event_type=ChatCompletionResponseEventType.progress,
+                        delta=ToolCallDelta(
+                            tool_call=buffer["content"],
+                            parse_status=ToolCallParseStatus.failed,
+                        ),
+                        stop_reason=stop_reason,
+                    )
                 )
-            )
 
     yield ChatCompletionResponseStreamChunk(
         event=ChatCompletionResponseEvent(
diff --git a/tests/integration/agents/test_agents.py b/tests/integration/agents/test_agents.py
index 61249ad17..ef0e8c05e 100644
--- a/tests/integration/agents/test_agents.py
+++ b/tests/integration/agents/test_agents.py
@@ -271,7 +271,7 @@ def test_custom_tool(llama_stack_client_with_mocked_inference, agent_config):
     client_tool = get_boiling_point
     agent_config = {
         **agent_config,
-        "tools": ["builtin::websearch", client_tool],
+        "tools": [client_tool],
     }
 
     agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
@@ -320,42 +320,55 @@ def test_custom_tool_infinite_loop(llama_stack_client_with_mocked_inference, age
     assert num_tool_calls <= 5
 
 
-def test_tool_choice(llama_stack_client_with_mocked_inference, agent_config):
-    def run_agent(tool_choice):
-        client_tool = get_boiling_point
-
-        test_agent_config = {
-            **agent_config,
-            "tool_config": {"tool_choice": tool_choice},
-            "tools": [client_tool],
-        }
-
-        agent = Agent(llama_stack_client_with_mocked_inference, **test_agent_config)
-        session_id = agent.create_session(f"test-session-{uuid4()}")
-
-        response = agent.create_turn(
-            messages=[
-                {
-                    "role": "user",
-                    "content": "What is the boiling point of polyjuice?",
-                },
-            ],
-            session_id=session_id,
-            stream=False,
-        )
-
-        return [step for step in response.steps if step.step_type == "tool_execution"]
-
-    tool_execution_steps = run_agent("required")
+def test_tool_choice_required(llama_stack_client_with_mocked_inference, agent_config):
+    tool_execution_steps = run_agent_with_tool_choice(
+        llama_stack_client_with_mocked_inference, agent_config, "required"
+    )
     assert len(tool_execution_steps) > 0
 
-    tool_execution_steps = run_agent("none")
+
+def test_tool_choice_none(llama_stack_client_with_mocked_inference, agent_config):
+    tool_execution_steps = run_agent_with_tool_choice(llama_stack_client_with_mocked_inference, agent_config, "none")
     assert len(tool_execution_steps) == 0
 
-    tool_execution_steps = run_agent("get_boiling_point")
+
+def test_tool_choice_get_boiling_point(llama_stack_client_with_mocked_inference, agent_config):
+    if "llama" not in agent_config["model"].lower():
+        pytest.xfail("NotImplemented for non-llama models")
+
+    tool_execution_steps = run_agent_with_tool_choice(
+        llama_stack_client_with_mocked_inference, agent_config, "get_boiling_point"
+    )
     assert len(tool_execution_steps) >= 1 and tool_execution_steps[0].tool_calls[0].tool_name == "get_boiling_point"
 
 
+def run_agent_with_tool_choice(client, agent_config, tool_choice):
+    client_tool = get_boiling_point
+
+    test_agent_config = {
+        **agent_config,
+        "tool_config": {"tool_choice": tool_choice},
+        "tools": [client_tool],
+        "max_infer_iters": 2,
+    }
+
+    agent = Agent(client, **test_agent_config)
+    session_id = agent.create_session(f"test-session-{uuid4()}")
+
+    response = agent.create_turn(
+        messages=[
+            {
+                "role": "user",
+                "content": "What is the boiling point of polyjuice?",
+            },
+        ],
+        session_id=session_id,
+        stream=False,
+    )
+
+    return [step for step in response.steps if step.step_type == "tool_execution"]
+
+
 @pytest.mark.parametrize("rag_tool_name", ["builtin::rag/knowledge_search", "builtin::rag"])
 def test_rag_agent(llama_stack_client_with_mocked_inference, agent_config, rag_tool_name):
     urls = ["chat.rst", "llama3.rst", "memory_optimizations.rst", "lora_finetune.rst"]