diff --git a/llama_stack/providers/remote/inference/groq/groq_utils.py b/llama_stack/providers/remote/inference/groq/groq_utils.py index bd1a07d7c..99fa8219c 100644 --- a/llama_stack/providers/remote/inference/groq/groq_utils.py +++ b/llama_stack/providers/remote/inference/groq/groq_utils.py @@ -6,7 +6,7 @@ import json import warnings -from typing import AsyncGenerator, Literal +from typing import AsyncGenerator, Literal, Union from groq import Stream from groq.types.chat.chat_completion import ChatCompletion @@ -30,6 +30,8 @@ from groq.types.shared.function_definition import FunctionDefinition from llama_models.llama3.api.datatypes import ToolParamDefinition +from pydantic import BaseModel + from llama_stack.apis.common.content_types import ( TextDelta, ToolCallDelta, @@ -150,15 +152,26 @@ def convert_chat_completion_response( _convert_groq_tool_call(tool_call) for tool_call in choice.message.tool_calls ] - return ChatCompletionResponse( - completion_message=CompletionMessage( - tool_calls=tool_calls, - stop_reason=StopReason.end_of_message, - # Content is not optional - content="", - ), - logprobs=None, - ) + if any(isinstance(tool_call, UnparseableToolCall) for tool_call in tool_calls): + # If we couldn't parse a tool call, jsonify the tool calls and return them + return ChatCompletionResponse( + completion_message=CompletionMessage( + stop_reason=StopReason.end_of_message, + content=json.dumps(tool_calls, default=lambda x: x.model_dump()), + ), + logprobs=None, + ) + else: + # Otherwise, return tool calls as normal + return ChatCompletionResponse( + completion_message=CompletionMessage( + tool_calls=tool_calls, + stop_reason=StopReason.end_of_message, + # Content is not optional + content="", + ), + logprobs=None, + ) else: return ChatCompletionResponse( completion_message=CompletionMessage( @@ -214,15 +227,27 @@ async def convert_chat_completion_response_stream( # We assume Groq produces fully formed tool calls for each chunk tool_call = _convert_groq_tool_call(choice.delta.tool_calls[0]) - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=event_type, - delta=ToolCallDelta( - tool_call=tool_call, - parse_status=ToolCallParseStatus.succeeded, - ), + if isinstance(tool_call, ToolCall): + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=event_type, + delta=ToolCallDelta( + tool_call=tool_call, + parse_status=ToolCallParseStatus.succeeded, + ), + ) + ) + else: + # Otherwise it's an UnparseableToolCall - return the raw tool call + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=event_type, + delta=ToolCallDelta( + tool_call=tool_call.model_dump_json(), + parse_status=ToolCallParseStatus.failed, + ), + ) ) - ) else: yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( @@ -234,12 +259,35 @@ async def convert_chat_completion_response_stream( event_type = ChatCompletionResponseEventType.progress -def _convert_groq_tool_call(tool_call: ChatCompletionMessageToolCall) -> ToolCall: +class UnparseableToolCall(BaseModel): + """ + A ToolCall with arguments that are not valid JSON. + Mirrors the ToolCall schema, but with arguments as a string. + """ + + call_id: str + tool_name: str + arguments: str + + +def _convert_groq_tool_call( + tool_call: ChatCompletionMessageToolCall, +) -> Union[ToolCall, UnparseableToolCall]: + """ + Convert a Groq tool call to a ToolCall. + Returns an UnparseableToolCall if the tool call is not valid JSON. + """ + try: + arguments = json.loads(tool_call.function.arguments) + except Exception as e: + return UnparseableToolCall( + call_id=tool_call.id, + tool_name=tool_call.function.name, + arguments=tool_call.function.arguments, + ) + return ToolCall( call_id=tool_call.id, tool_name=tool_call.function.name, - # Note that Groq may return a string that is not valid JSON here - # So this may raise a 500 error. Going to leave this as is to see - # how big of an issue this is and what we can do about it. - arguments=json.loads(tool_call.function.arguments), + arguments=arguments, ) diff --git a/llama_stack/providers/tests/inference/groq/test_groq_utils.py b/llama_stack/providers/tests/inference/groq/test_groq_utils.py index f6f593f16..5e0797871 100644 --- a/llama_stack/providers/tests/inference/groq/test_groq_utils.py +++ b/llama_stack/providers/tests/inference/groq/test_groq_utils.py @@ -23,6 +23,7 @@ from groq.types.chat.chat_completion_message_tool_call import ( from groq.types.shared.function_definition import FunctionDefinition from llama_models.datatypes import GreedySamplingStrategy, TopPSamplingStrategy from llama_models.llama3.api.datatypes import ToolParamDefinition +from llama_stack.apis.common.content_types import ToolCallParseStatus from llama_stack.apis.inference import ( ChatCompletionRequest, ChatCompletionResponseEventType, @@ -347,6 +348,26 @@ class TestConvertNonStreamChatCompletionResponse: ), ] + def test_converts_unparseable_tool_calls(self): + response = self._dummy_chat_completion_response_with_tool_call() + response.choices[0].message.tool_calls = [ + ChatCompletionMessageToolCall( + id="tool_call_id", + type="function", + function=Function( + name="log", + arguments="(number=10, base=2)", + ), + ), + ] + + converted = convert_chat_completion_response(response) + + assert ( + converted.completion_message.content + == '[{"call_id": "tool_call_id", "tool_name": "log", "arguments": "(number=10, base=2)"}]' + ) + def _dummy_chat_completion_response(self): return ChatCompletion( id="chatcmpl-123", @@ -478,6 +499,40 @@ class TestConvertStreamChatCompletionResponse: arguments={"origin": "AU", "destination": "LAX"}, ) + @pytest.mark.asyncio + async def test_returns_tool_calls_stream_with_unparseable_tool_calls(self): + def tool_call_stream(): + chunk = self._dummy_chat_completion_chunk_with_tool_call() + chunk.choices[0].delta.tool_calls = [ + ChoiceDeltaToolCall( + index=0, + type="function", + id="tool_call_id", + function=ChoiceDeltaToolCallFunction( + name="get_flight_info", + arguments="(origin=AU, destination=LAX)", + ), + ), + ] + yield chunk + + chunk = self._dummy_chat_completion_chunk_with_tool_call() + chunk.choices[0].delta.content = None + chunk.choices[0].finish_reason = "stop" + yield chunk + + stream = tool_call_stream() + converted = convert_chat_completion_response_stream(stream) + + iter = converted.__aiter__() + chunk = await iter.__anext__() + assert chunk.event.event_type == ChatCompletionResponseEventType.start + assert ( + chunk.event.delta.content + == '{"call_id":"tool_call_id","tool_name":"get_flight_info","arguments":"(origin=AU, destination=LAX)"}' + ) + assert chunk.event.delta.parse_status == ToolCallParseStatus.failed + def _dummy_chat_completion_chunk(self): return ChatCompletionChunk( id="chatcmpl-123",