diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 51691c546..2f397f438 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -513,6 +513,9 @@ class ChatAgent(ShieldRunnerMixin): if delta.type == "tool_call": if delta.parse_status == ToolCallParseStatus.succeeded: tool_calls.append(delta.tool_call) + elif delta.parse_status == ToolCallParseStatus.failed: + # If we cannot parse the tools, set the content to the unparsed raw text + content = delta.tool_call if stream: yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py index 44666fa70..d978cb02e 100644 --- a/llama_stack/providers/remote/inference/fireworks/fireworks.py +++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py @@ -244,7 +244,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv yield chunk stream = _to_async_generator() - async for chunk in process_chat_completion_stream_response(stream, self.formatter): + async for chunk in process_chat_completion_stream_response(stream, self.formatter, request): yield chunk async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict: diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index 066fda2c1..01e0cbb6d 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -251,7 +251,9 @@ async def process_completion_stream_response( async def process_chat_completion_stream_response( - stream: AsyncGenerator[OpenAICompatCompletionResponse, None], formatter: ChatFormat + stream: AsyncGenerator[OpenAICompatCompletionResponse, None], + formatter: ChatFormat, + request: ChatCompletionRequest, ) -> AsyncGenerator: yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( @@ -334,7 +336,6 @@ async def process_chat_completion_stream_response( # parse tool calls and report errors message = formatter.decode_assistant_message_from_content(buffer, stop_reason) - print(f"Parse TOOL CALLS message: {message}") parsed_tool_calls = len(message.tool_calls) > 0 if ipython and not parsed_tool_calls: @@ -349,17 +350,33 @@ async def process_chat_completion_stream_response( ) ) + request_tools = {t.tool_name: t for t in request.tools} for tool_call in message.tool_calls: - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=ToolCallDelta( - tool_call=tool_call, - parse_status=ToolCallParseStatus.succeeded, - ), - stop_reason=stop_reason, + if tool_call.tool_name in request_tools: + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.progress, + delta=ToolCallDelta( + tool_call=tool_call, + parse_status=ToolCallParseStatus.succeeded, + ), + stop_reason=stop_reason, + ) + ) + else: + logger.warning(f"Tool {tool_call.tool_name} not found in request tools") + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.progress, + delta=ToolCallDelta( + # Parsing tool call failed due to tool call not being found in request tools, + # We still add the raw message text inside tool_call for responding back to the user + tool_call=buffer, + parse_status=ToolCallParseStatus.failed, + ), + stop_reason=stop_reason, + ) ) - ) yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( diff --git a/tests/client-sdk/inference/test_text_inference.py b/tests/client-sdk/inference/test_text_inference.py index 81b476218..206629602 100644 --- a/tests/client-sdk/inference/test_text_inference.py +++ b/tests/client-sdk/inference/test_text_inference.py @@ -158,7 +158,10 @@ def test_text_completion_structured_output(llama_stack_client, text_model_id, in "question,expected", [ ("Which planet do humans live on?", "Earth"), - ("Which planet has rings around it with a name starting with letter S?", "Saturn"), + ( + "Which planet has rings around it with a name starting with letter S?", + "Saturn", + ), ], ) def test_text_chat_completion_non_streaming(llama_stack_client, text_model_id, question, expected): @@ -280,3 +283,82 @@ def test_text_chat_completion_structured_output(llama_stack_client, text_model_i assert answer.last_name == "Jordan" assert answer.year_of_birth == 1963 assert answer.num_seasons_in_nba == 15 + + +@pytest.mark.parametrize( + "streaming", + [ + True, + False, + ], +) +def test_text_chat_completion_tool_calling_tools_not_in_request(llama_stack_client, text_model_id, streaming): + # TODO: more dynamic lookup on tool_prompt_format for model family + tool_prompt_format = "json" if "3.1" in text_model_id else "python_list" + request = { + "model_id": text_model_id, + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + { + "role": "user", + "content": "What pods are in the namespace openshift-lightspeed?", + }, + { + "role": "assistant", + "content": "", + "stop_reason": "end_of_turn", + "tool_calls": [ + { + "call_id": "1", + "tool_name": "get_object_namespace_list", + "arguments": { + "kind": "pod", + "namespace": "openshift-lightspeed", + }, + } + ], + }, + { + "role": "tool", + "call_id": "1", + "tool_name": "get_object_namespace_list", + "content": "the objects are pod1, pod2, pod3", + }, + ], + "tools": [ + { + "tool_name": "get_object_namespace_list", + "description": "Get the list of objects in a namespace", + "parameters": { + "kind": { + "param_type": "string", + "description": "the type of object", + "required": True, + }, + "namespace": { + "param_type": "string", + "description": "the name of the namespace", + "required": True, + }, + }, + } + ], + "tool_choice": "auto", + "tool_prompt_format": tool_prompt_format, + "stream": streaming, + } + + response = llama_stack_client.inference.chat_completion(**request) + + if streaming: + for chunk in response: + delta = chunk.event.delta + if delta.type == "tool_call" and delta.parse_status == "succeeded": + assert delta.tool_call.tool_name == "get_object_namespace_list" + if delta.type == "tool_call" and delta.parse_status == "failed": + # expect raw message that failed to parse in tool_call + assert type(delta.tool_call) == str + assert len(delta.tool_call) > 0 + else: + for tc in response.completion_message.tool_calls: + assert tc.tool_name == "get_object_namespace_list"