From cf87262e9c294c123d92660b3ee26ba386e5938f Mon Sep 17 00:00:00 2001
From: Aidan Do
Date: Sat, 14 Dec 2024 22:20:54 +1100
Subject: [PATCH] Add tool calls to groq inference adapter

---
 .../providers/remote/inference/groq/groq.py   |  10 +-
 .../remote/inference/groq/groq_utils.py       | 144 +++++++--
 .../tests/inference/groq/test_groq_utils.py   | 283 ++++++++++++++++--
 .../tests/inference/test_text_inference.py    |  23 +-
 4 files changed, 400 insertions(+), 60 deletions(-)

diff --git a/llama_stack/providers/remote/inference/groq/groq.py b/llama_stack/providers/remote/inference/groq/groq.py
index 4723fc31f..80b05e4d3 100644
--- a/llama_stack/providers/remote/inference/groq/groq.py
+++ b/llama_stack/providers/remote/inference/groq/groq.py
@@ -7,6 +7,7 @@
 import warnings
 from typing import AsyncIterator, List, Optional, Union
 
+import groq
 from groq import Groq
 from llama_models.datatypes import SamplingParams
 from llama_models.llama3.api.datatypes import (
@@ -126,7 +127,14 @@ class GroqInferenceAdapter(Inference, ModelRegistryHelper):
             )
         )
 
-        response = self._client.chat.completions.create(**request)
+        try:
+            response = self._client.chat.completions.create(**request)
+        except groq.BadRequestError as e:
+            if e.body.get("error", {}).get("code") == "tool_use_failed":
+                # For smaller models, Groq may fail to call a tool even when the request is well formed
+                raise ValueError("Groq failed to call a tool", e.body.get("error", {}))
+            else:
+                raise e
 
         if stream:
             return convert_chat_completion_response_stream(response)
diff --git a/llama_stack/providers/remote/inference/groq/groq_utils.py b/llama_stack/providers/remote/inference/groq/groq_utils.py
index 2aabb80ba..98ecbe2f2 100644
--- a/llama_stack/providers/remote/inference/groq/groq_utils.py
+++ b/llama_stack/providers/remote/inference/groq/groq_utils.py
@@ -6,7 +6,7 @@
 
 import warnings
 from typing import AsyncGenerator, Generator, Literal
-
+import json
 from groq import Stream
 from groq.types.chat.chat_completion import ChatCompletion
 from groq.types.chat.chat_completion_assistant_message_param import (
@@ -20,9 +20,13 @@ from groq.types.chat.chat_completion_system_message_param import (
 from groq.types.chat.chat_completion_user_message_param import (
     ChatCompletionUserMessageParam,
 )
-
 from groq.types.chat.completion_create_params import CompletionCreateParams
-
+from groq.types.chat.chat_completion_message_tool_call import (
+    ChatCompletionMessageToolCall,
+)
+from groq.types.chat.chat_completion_tool_param import ChatCompletionToolParam
+from groq.types.shared.function_definition import FunctionDefinition
+from groq.types.shared.function_parameters import FunctionParameters
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -33,9 +37,14 @@ from llama_stack.apis.inference import (
     Message,
     Role,
     StopReason,
+    ToolCall,
+    ToolDefinition,
+    ToolParamDefinition,
+    ToolCallParseStatus,
+    ToolCallDelta,
+    ToolPromptFormat,
 )
 
-
 def convert_chat_completion_request(
     request: ChatCompletionRequest,
 ) -> CompletionCreateParams:
@@ -60,8 +69,8 @@
         # so we exclude it for now
         warnings.warn("repetition_penalty is not supported")
 
-    if request.tools:
-        warnings.warn("tools are not supported yet")
+    if request.tool_prompt_format != ToolPromptFormat.json:
+        warnings.warn("tool_prompt_format is not used by Groq. Ignoring.")
 
     return CompletionCreateParams(
         model=request.model,
@@ -72,9 +81,10 @@
         max_tokens=request.sampling_params.max_tokens or None,
         temperature=request.sampling_params.temperature,
         top_p=request.sampling_params.top_p,
+        tools=[_convert_groq_tool_definition(tool) for tool in request.tools or []],
+        tool_choice=request.tool_choice.value if request.tool_choice else None,
     )
 
-
 def _convert_message(message: Message) -> ChatCompletionMessageParam:
     if message.role == Role.system.value:
         return ChatCompletionSystemMessageParam(role="system", content=message.content)
@@ -88,17 +98,67 @@ def _convert_message(message: Message) -> ChatCompletionMessageParam:
         raise ValueError(f"Invalid message role: {message.role}")
 
 
+def _convert_groq_tool_definition(tool_definition: ToolDefinition) -> dict:
+    # Groq requires a description for function tools
+    if tool_definition.description is None:
+        raise AssertionError("tool_definition.description is required")
+
+    tool_parameters = tool_definition.parameters or {}
+    return ChatCompletionToolParam(
+        type="function",
+        function=FunctionDefinition(
+            name=tool_definition.tool_name,
+            description=tool_definition.description,
+            parameters={
+                key: _convert_groq_tool_parameter(param)
+                for key, param in tool_parameters.items()
+            },
+        )
+    )
+
+
+def _convert_groq_tool_parameter(
+    tool_parameter: ToolParamDefinition
+) -> dict:
+    param = {
+        "type": tool_parameter.param_type,
+    }
+    if tool_parameter.description is not None:
+        param["description"] = tool_parameter.description
+    if tool_parameter.required is not None:
+        param["required"] = tool_parameter.required
+    if tool_parameter.default is not None:
+        param["default"] = tool_parameter.default
+    return param
+
+
 def convert_chat_completion_response(
     response: ChatCompletion,
 ) -> ChatCompletionResponse:
     # groq only supports n=1 at time of writing, so there is only one choice
     choice = response.choices[0]
-    return ChatCompletionResponse(
-        completion_message=CompletionMessage(
-            content=choice.message.content,
-            stop_reason=_map_finish_reason_to_stop_reason(choice.finish_reason),
-        ),
-    )
+    if choice.finish_reason == "tool_calls":
+        tool_calls = [
+            _convert_groq_tool_call(tool_call)
+            for tool_call in choice.message.tool_calls
+        ]
+        return ChatCompletionResponse(
+            completion_message=CompletionMessage(
+                tool_calls=tool_calls,
+                stop_reason=StopReason.end_of_message,
+                # Content is not optional
+                content="",
+            ),
+            logprobs=None,
+        )
+    else:
+        return ChatCompletionResponse(
+            completion_message=CompletionMessage(
+                content=choice.message.content,
+                stop_reason=_map_finish_reason_to_stop_reason(choice.finish_reason),
+            ),
+        )
 
 
 def _map_finish_reason_to_stop_reason(
@@ -117,7 +177,7 @@
     elif finish_reason == "length":
         return StopReason.out_of_tokens
     elif finish_reason == "tool_calls":
-        raise NotImplementedError("tool_calls is not supported yet")
+        return StopReason.end_of_message
     else:
         raise ValueError(f"Invalid finish reason: {finish_reason}")
 
@@ -139,24 +199,46 @@
     for chunk in stream:
         choice = chunk.choices[0]
 
-        # We assume there's only one finish_reason for the entire stream.
-        # We collect the last finish_reason
         if choice.finish_reason:
-            stop_reason = _map_finish_reason_to_stop_reason(choice.finish_reason)
-
-        yield ChatCompletionResponseStreamChunk(
-            event=ChatCompletionResponseEvent(
-                event_type=next(event_types),
-                delta=choice.delta.content or "",
-                logprobs=None,
+            yield ChatCompletionResponseStreamChunk(
+                event=ChatCompletionResponseEvent(
+                    event_type=ChatCompletionResponseEventType.complete,
+                    delta=choice.delta.content or "",
+                    logprobs=None,
+                    stop_reason=_map_finish_reason_to_stop_reason(choice.finish_reason),
+                )
             )
-        )
+        elif choice.delta.tool_calls:
+            # We assume there is only one tool call per chunk, but emit a warning in case we're wrong
+            if len(choice.delta.tool_calls) > 1:
+                warnings.warn("Groq returned multiple tool calls in one chunk. Using the first one, ignoring the rest.")
 
-    yield ChatCompletionResponseStreamChunk(
-        event=ChatCompletionResponseEvent(
-            event_type=ChatCompletionResponseEventType.complete,
-            delta="",
-            logprobs=None,
-            stop_reason=stop_reason,
-        )
+            # We assume Groq produces fully formed tool calls for each chunk
+            tool_call = _convert_groq_tool_call(choice.delta.tool_calls[0])
+            yield ChatCompletionResponseStreamChunk(
+                event=ChatCompletionResponseEvent(
+                    event_type=next(event_types),
+                    delta=ToolCallDelta(
+                        content=tool_call,
+                        parse_status=ToolCallParseStatus.success,
+                    ),
+                )
+            )
+        else:
+            yield ChatCompletionResponseStreamChunk(
+                event=ChatCompletionResponseEvent(
+                    event_type=next(event_types),
+                    delta=choice.delta.content or "",
+                    logprobs=None,
+                )
+            )
+
+def _convert_groq_tool_call(tool_call: ChatCompletionMessageToolCall) -> ToolCall:
+    return ToolCall(
+        call_id=tool_call.id,
+        tool_name=tool_call.function.name,
+        # Note that Groq may return arguments that are not valid JSON,
+        # in which case json.loads raises and surfaces as a 500 error.
+        # We leave this unhandled for now to gauge how often it occurs.
+        arguments=json.loads(tool_call.function.arguments),
     )
diff --git a/llama_stack/providers/tests/inference/groq/test_groq_utils.py b/llama_stack/providers/tests/inference/groq/test_groq_utils.py
index 53b5c29cb..b9191a1ba 100644
--- a/llama_stack/providers/tests/inference/groq/test_groq_utils.py
+++ b/llama_stack/providers/tests/inference/groq/test_groq_utils.py
@@ -5,6 +5,7 @@
 # the root directory of this source tree.
 
 import pytest
+import json
 from groq.types.chat.chat_completion import ChatCompletion, Choice
 from groq.types.chat.chat_completion_chunk import (
     ChatCompletionChunk,
@@ -12,7 +13,15 @@ from groq.types.chat.chat_completion_chunk import (
     ChoiceDelta,
 )
 from groq.types.chat.chat_completion_message import ChatCompletionMessage
-
+from groq.types.chat.chat_completion_message_tool_call import (
+    ChatCompletionMessageToolCall,
+    Function,
+)
+from groq.types.chat.chat_completion_chunk import (
+    ChoiceDeltaToolCall,
+    ChoiceDeltaToolCallFunction,
+)
+from groq.types.shared.function_definition import FunctionDefinition
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponseEventType,
@@ -20,6 +29,10 @@ from llama_stack.apis.inference import (
     StopReason,
     SystemMessage,
     UserMessage,
+    ToolChoice,
+    ToolDefinition,
+    ToolParamDefinition,
+    ToolCall,
 )
 from llama_stack.providers.remote.inference.groq.groq_utils import (
     convert_chat_completion_request,
@@ -140,12 +153,6 @@ class TestConvertChatCompletionRequest:
 
         assert converted["max_tokens"] == 100
 
-    def _dummy_chat_completion_request(self):
-        return ChatCompletionRequest(
-            model="Llama-3.2-3B",
-            messages=[UserMessage(content="Hello World")],
-        )
-
     def test_includes_temperature(self):
         request = self._dummy_chat_completion_request()
         request.sampling_params.temperature = 0.5
@@ -162,6 +169,112 @@ class TestConvertChatCompletionRequest:
 
         assert converted["top_p"] == 0.95
 
+    def test_includes_tool_choice(self):
+        request = self._dummy_chat_completion_request()
+        request.tool_choice = ToolChoice.required
+
+        converted = convert_chat_completion_request(request)
+
+        assert converted["tool_choice"] == "required"
+
+    def test_includes_tools(self):
+        request = self._dummy_chat_completion_request()
+        request.tools = [
+            ToolDefinition(
+                tool_name="get_flight_info",
+                description="Get flight information between two destinations.",
+                parameters={
+                    "origin": ToolParamDefinition(
+                        param_type="string",
+                        description="The origin airport code. E.g., AU",
+                        required=True,
+                    ),
+                    "destination": ToolParamDefinition(
+                        param_type="string",
+                        description="The destination airport code. E.g., 'LAX'",
+                        required=True,
+                    ),
+                    "passengers": ToolParamDefinition(
+                        param_type="array",
+                        description="The passengers",
+                        required=False,
+                    ),
+                },
+            ),
+            ToolDefinition(
+                tool_name="log",
+                description="Calculate the logarithm of a number",
+                parameters={
+                    "number": ToolParamDefinition(
+                        param_type="float",
+                        description="The number to calculate the logarithm of",
+                        required=True,
+                    ),
+                    "base": ToolParamDefinition(
+                        param_type="integer",
+                        description="The base of the logarithm",
+                        required=False,
+                        default=10,
+                    ),
+                },
+            ),
+        ]
+
+        converted = convert_chat_completion_request(request)
+
+        assert converted["tools"] == [
+            {
+                "type": "function",
+                "function": FunctionDefinition(
+                    name="get_flight_info",
+                    description="Get flight information between two destinations.",
+                    parameters={
+                        "origin": {
+                            "type": "string",
+                            "description": "The origin airport code. E.g., AU",
+                            "required": True,
+                        },
+                        "destination": {
+                            "type": "string",
+                            "description": "The destination airport code. E.g., 'LAX'",
+                            "required": True,
+                        },
+                        "passengers": {
+                            "type": "array",
+                            "description": "The passengers",
+                            "required": False,
+                        },
+                    },
+                ),
+            },
+            {
+                "type": "function",
+                "function": FunctionDefinition(
+                    name="log",
+                    description="Calculate the logarithm of a number",
+                    parameters={
+                        "number": {
+                            "type": "float",
+                            "description": "The number to calculate the logarithm of",
+                            "required": True,
+                        },
+                        "base": {
+                            "type": "integer",
+                            "description": "The base of the logarithm",
+                            "required": False,
+                            "default": 10,
+                        },
+                    },
+                ),
+            },
+        ]
+
+    def _dummy_chat_completion_request(self):
+        return ChatCompletionRequest(
+            model="Llama-3.2-3B",
+            messages=[UserMessage(content="Hello World")],
+        )
+
 
 class TestConvertNonStreamChatCompletionResponse:
     def test_returns_response(self):
@@ -188,6 +301,49 @@ class TestConvertNonStreamChatCompletionResponse:
 
         assert converted.completion_message.stop_reason == StopReason.out_of_tokens
 
+    def test_maps_tool_call_to_end_of_message(self):
+        response = self._dummy_chat_completion_response_with_tool_call()
+
+        converted = convert_chat_completion_response(response)
+
+        assert converted.completion_message.stop_reason == StopReason.end_of_message
+
+    def test_converts_multiple_tool_calls(self):
+        response = self._dummy_chat_completion_response_with_tool_call()
+        response.choices[0].message.tool_calls = [
+            ChatCompletionMessageToolCall(
+                id="tool_call_id",
+                type="function",
+                function=Function(
+                    name="get_flight_info",
+                    arguments='{"origin": "AU", "destination": "LAX"}',
+                ),
+            ),
+            ChatCompletionMessageToolCall(
+                id="tool_call_id_2",
+                type="function",
+                function=Function(
+                    name="log",
+                    arguments='{"number": 10, "base": 2}',
+                ),
+            ),
+        ]
+
+        converted = convert_chat_completion_response(response)
+
+        assert converted.completion_message.tool_calls == [
+            ToolCall(
+                call_id="tool_call_id",
+                tool_name="get_flight_info",
+                arguments={"origin": "AU", "destination": "LAX"},
+            ),
+            ToolCall(
+                call_id="tool_call_id_2",
+                tool_name="log",
+                arguments={"number": 10, "base": 2},
+            ),
+        ]
+
     def _dummy_chat_completion_response(self):
         return ChatCompletion(
             id="chatcmpl-123",
@@ -205,6 +361,33 @@ class TestConvertNonStreamChatCompletionResponse:
             object="chat.completion",
         )
 
+    def _dummy_chat_completion_response_with_tool_call(self):
+        return ChatCompletion(
+            id="chatcmpl-123",
+            model="Llama-3.2-3B",
+            choices=[
+                Choice(
+                    index=0,
+                    message=ChatCompletionMessage(
+                        role="assistant",
+                        tool_calls=[
+                            ChatCompletionMessageToolCall(
+                                id="tool_call_id",
+                                type="function",
+                                function=Function(
+                                    name="get_flight_info",
+                                    arguments='{"origin": "AU", "destination": "LAX"}',
+                                ),
+                            )
+                        ],
+                    ),
+                    finish_reason="tool_calls",
+                )
+            ],
+            created=1729382400,
+            object="chat.completion",
+        )
+
 
 class TestConvertStreamChatCompletionResponse:
     @pytest.mark.asyncio
@@ -214,10 +397,6 @@ class TestConvertStreamChatCompletionResponse:
         for i, message in enumerate(messages):
             chunk = self._dummy_chat_completion_chunk()
             chunk.choices[0].delta.content = message
-            if i == len(messages) - 1:
-                chunk.choices[0].finish_reason = "stop"
-            else:
-                chunk.choices[0].finish_reason = None
             yield chunk
 
         chunk = self._dummy_chat_completion_chunk()
@@ -241,12 +420,6 @@ class TestConvertStreamChatCompletionResponse:
         assert chunk.event.event_type == ChatCompletionResponseEventType.progress
         assert chunk.event.delta == " !"
 
-        # Dummy chunk to ensure the last chunk is really the end of the stream
-        # This one technically maps to Groq's final "stop" chunk
-        chunk = await iter.__anext__()
-        assert chunk.event.event_type == ChatCompletionResponseEventType.progress
-        assert chunk.event.delta == ""
-
         chunk = await iter.__anext__()
         assert chunk.event.event_type == ChatCompletionResponseEventType.complete
         assert chunk.event.delta == ""
@@ -255,6 +428,54 @@
         with pytest.raises(StopAsyncIteration):
             await iter.__anext__()
 
+    @pytest.mark.asyncio
+    async def test_returns_tool_calls_stream(self):
+        def tool_call_stream():
+            tool_calls = [
+                ToolCall(
+                    call_id="tool_call_id",
+                    tool_name="get_flight_info",
+                    arguments={"origin": "AU", "destination": "LAX"},
+                ),
+                ToolCall(
+                    call_id="tool_call_id_2",
+                    tool_name="log",
+                    arguments={"number": 10, "base": 2},
+                ),
+            ]
+            for i, tool_call in enumerate(tool_calls):
+                chunk = self._dummy_chat_completion_chunk_with_tool_call()
+                chunk.choices[0].delta.tool_calls = [
+                    ChoiceDeltaToolCall(
+                        index=0,
+                        type="function",
+                        id=tool_call.call_id,
+                        function=ChoiceDeltaToolCallFunction(
+                            name=tool_call.tool_name,
+                            arguments=json.dumps(tool_call.arguments),
+                        ),
+                    ),
+                ]
+                yield chunk
+
+            chunk = self._dummy_chat_completion_chunk_with_tool_call()
+            chunk.choices[0].delta.content = None
+            chunk.choices[0].finish_reason = "stop"
+            yield chunk
+
+        stream = tool_call_stream()
+        converted = convert_chat_completion_response_stream(stream)
+
+        iter = converted.__aiter__()
+        chunk = await iter.__anext__()
+        print(chunk)
+        assert chunk.event.event_type == ChatCompletionResponseEventType.start
+        assert chunk.event.delta.content == ToolCall(
+            call_id="tool_call_id",
+            tool_name="get_flight_info",
+            arguments={"origin": "AU", "destination": "LAX"},
+        )
+
     def _dummy_chat_completion_chunk(self):
         return ChatCompletionChunk(
             id="chatcmpl-123",
@@ -269,3 +490,31 @@
             object="chat.completion.chunk",
             x_groq=None,
         )
+
+    def _dummy_chat_completion_chunk_with_tool_call(self):
+        return ChatCompletionChunk(
+            id="chatcmpl-123",
+            model="Llama-3.2-3B",
+            choices=[
+                StreamChoice(
+                    index=0,
+                    delta=ChoiceDelta(
+                        role="assistant",
+                        content="Hello World",
+                        tool_calls=[
+                            ChoiceDeltaToolCall(
+                                index=0,
+                                type="function",
+                                function=ChoiceDeltaToolCallFunction(
+                                    name="get_flight_info",
+                                    arguments='{"origin": "AU", "destination": "LAX"}',
+                                ),
+                            )
+                        ],
+                    ),
+                )
+            ],
+            created=1729382400,
+            object="chat.completion.chunk",
+            x_groq=None,
+        )
diff --git a/llama_stack/providers/tests/inference/test_text_inference.py b/llama_stack/providers/tests/inference/test_text_inference.py
index 02851830b..d0c52c753 100644
--- a/llama_stack/providers/tests/inference/test_text_inference.py
+++ b/llama_stack/providers/tests/inference/test_text_inference.py
@@ -352,13 +352,13 @@ class TestInference:
     ):
         inference_impl, _ = inference_stack
         provider = inference_impl.routing_table.get_provider_impl(inference_model)
-        if provider.__provider_spec__.provider_type in ("remote::groq",):
-            pytest.skip(
-                provider.__provider_spec__.provider_type
-                + " doesn't support tool calling yet"
-            )
+        if (
+            provider.__provider_spec__.provider_type == "remote::groq"
+            and "Llama-3.2" in inference_model
+        ):
+            # TODO(aidand): Remove this skip once Groq's tool calling for Llama3.2 works better
+            pytest.skip("Groq's tool calling for Llama3.2 doesn't work very well")
 
-        inference_impl, _ = inference_stack
         messages = sample_messages + [
             UserMessage(
content="What's the weather like in San Francisco?", @@ -399,11 +399,12 @@ class TestInference: ): inference_impl, _ = inference_stack provider = inference_impl.routing_table.get_provider_impl(inference_model) - if provider.__provider_spec__.provider_type in ("remote::groq",): - pytest.skip( - provider.__provider_spec__.provider_type - + " doesn't support tool calling yet" - ) + if ( + provider.__provider_spec__.provider_type == "remote::groq" + and "Llama-3.2" in inference_model + ): + # TODO(aidand): Remove this skip once Groq's tool calling for Llama3.2 works better + pytest.skip("Groq's tool calling for Llama3.2 doesn't work very well") messages = sample_messages + [ UserMessage(