Add tool calls to groq inference adapter

Aidan Do 2024-12-14 22:20:54 +11:00
parent 78912e663b
commit cf87262e9c
4 changed files with 400 additions and 60 deletions
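
For reviewers, here is a minimal sketch of the request/response flow this commit enables, assembled from the types exercised in the diffs below. The model name is the one used in the unit tests, the surrounding adapter/client wiring is omitted, and the commented-out result is an assumption about a typical tool-call response rather than output from a real run.

from llama_stack.apis.inference import (
    ChatCompletionRequest,
    ToolChoice,
    ToolDefinition,
    ToolParamDefinition,
    UserMessage,
)

# Build a request the same way the unit tests below do: construct it, then
# attach tools and a tool_choice.
request = ChatCompletionRequest(
    model="Llama-3.2-3B",
    messages=[UserMessage(content="Find me a flight to LAX.")],
)
request.tools = [
    ToolDefinition(
        tool_name="get_flight_info",
        description="Get flight information between two destinations.",
        parameters={
            "origin": ToolParamDefinition(
                param_type="string",
                description="The origin airport code.",
                required=True,
            ),
            "destination": ToolParamDefinition(
                param_type="string",
                description="The destination airport code.",
                required=True,
            ),
        },
    ),
]
request.tool_choice = ToolChoice.required  # mirrors the new tool_choice pass-through

# When Groq finishes with "tool_calls", the adapter now returns a
# CompletionMessage whose tool_calls field is populated, e.g. (illustrative):
#   completion_message.tool_calls == [
#       ToolCall(call_id="...", tool_name="get_flight_info",
#                arguments={"origin": "AU", "destination": "LAX"}),
#   ]
# and stop_reason == StopReason.end_of_message.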

View file

@@ -7,6 +7,7 @@
import warnings
from typing import AsyncIterator, List, Optional, Union
import groq
from groq import Groq
from llama_models.datatypes import SamplingParams
from llama_models.llama3.api.datatypes import (
@@ -126,7 +127,14 @@ class GroqInferenceAdapter(Inference, ModelRegistryHelper):
)
)
response = self._client.chat.completions.create(**request)
try:
response = self._client.chat.completions.create(**request)
except groq.BadRequestError as e:
if e.body.get("error", {}).get("code") == "tool_use_failed":
# For smaller models, Groq may fail to call a tool even when the request is well formed
raise ValueError("Groq failed to call a tool", e.body.get("error", {}))
else:
raise e
if stream:
return convert_chat_completion_response_stream(response)
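
From the caller's perspective, the new except branch above means a model-side tool failure surfaces as a ValueError rather than a raw groq.BadRequestError. A hedged sketch follows: the error body structure and the ValueError come from the code above, while the adapter object and its chat_completion signature are assumptions.

# Illustrative only: the Groq error body shape the new except branch inspects.
# Everything except the "code" key is a placeholder.
tool_use_failed_body = {
    "error": {
        "code": "tool_use_failed",
        "message": "<Groq's explanation of the malformed tool call>",
    }
}


async def chat_with_tool_fallback(adapter, request):
    """Hypothetical wrapper: retry without tools when Groq's model emits a
    malformed tool call. adapter.chat_completion(request) is an assumed
    signature, not necessarily the adapter's actual one."""
    try:
        return await adapter.chat_completion(request)
    except ValueError:
        # The adapter raises ValueError("Groq failed to call a tool", ...)
        # when Groq reports code == "tool_use_failed", which happens mostly
        # with smaller models even when the request itself is well formed.
        request.tools = None
        return await adapter.chat_completion(request)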

View file

@@ -6,7 +6,7 @@
import warnings
from typing import AsyncGenerator, Generator, Literal
import json
from groq import Stream
from groq.types.chat.chat_completion import ChatCompletion
from groq.types.chat.chat_completion_assistant_message_param import (
@@ -20,9 +20,13 @@ from groq.types.chat.chat_completion_system_message_param import (
from groq.types.chat.chat_completion_user_message_param import (
ChatCompletionUserMessageParam,
)
from groq.types.chat.completion_create_params import CompletionCreateParams
from groq.types.chat.chat_completion_message_tool_call import (
ChatCompletionMessageToolCall,
)
from groq.types.chat.chat_completion_tool_param import ChatCompletionToolParam
from groq.types.shared.function_definition import FunctionDefinition
from groq.types.shared.function_parameters import FunctionParameters
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
@@ -33,9 +37,14 @@ from llama_stack.apis.inference import (
Message,
Role,
StopReason,
ToolCall,
ToolDefinition,
ToolParamDefinition,
ToolCallParseStatus,
ToolCallDelta,
ToolPromptFormat,
)
def convert_chat_completion_request(
request: ChatCompletionRequest,
) -> CompletionCreateParams:
@@ -60,8 +69,8 @@ def convert_chat_completion_request(
# so we exclude it for now
warnings.warn("repetition_penalty is not supported")
if request.tools:
warnings.warn("tools are not supported yet")
if request.tool_prompt_format != ToolPromptFormat.json:
warnings.warn("tool_prompt_format is not used by Groq. Ignoring.")
return CompletionCreateParams(
model=request.model,
@@ -72,9 +81,10 @@ def convert_chat_completion_request(
max_tokens=request.sampling_params.max_tokens or None,
temperature=request.sampling_params.temperature,
top_p=request.sampling_params.top_p,
tools=[_convert_groq_tool_definition(tool) for tool in request.tools or []],
tool_choice=request.tool_choice.value if request.tool_choice else None,
)
def _convert_message(message: Message) -> ChatCompletionMessageParam:
if message.role == Role.system.value:
return ChatCompletionSystemMessageParam(role="system", content=message.content)
@@ -88,17 +98,67 @@ def _convert_message(message: Message) -> ChatCompletionMessageParam:
raise ValueError(f"Invalid message role: {message.role}")
def _convert_groq_tool_definition(tool_definition: ToolDefinition) -> dict:
# Groq requires a description for function tools
if tool_definition.description is None:
raise AssertionError("tool_definition.description is required")
tool_parameters = tool_definition.parameters or {}
return ChatCompletionToolParam(
type="function",
function=FunctionDefinition(
name=tool_definition.tool_name,
description=tool_definition.description,
parameters={
key: _convert_groq_tool_parameter(param)
for key, param in tool_parameters.items()
},
)
)
def _convert_groq_tool_parameter(
tool_parameter: ToolParamDefinition
) -> dict:
param = {
"type": tool_parameter.param_type,
}
if tool_parameter.description is not None:
param["description"] = tool_parameter.description
if tool_parameter.required is not None:
param["required"] = tool_parameter.required
if tool_parameter.default is not None:
param["default"] = tool_parameter.default
return param
def convert_chat_completion_response(
response: ChatCompletion,
) -> ChatCompletionResponse:
# Groq only supports n=1 at the time of writing, so there is only one choice
choice = response.choices[0]
return ChatCompletionResponse(
completion_message=CompletionMessage(
content=choice.message.content,
stop_reason=_map_finish_reason_to_stop_reason(choice.finish_reason),
),
)
if choice.finish_reason == "tool_calls":
tool_calls = [
_convert_groq_tool_call(tool_call)
for tool_call in choice.message.tool_calls
]
return ChatCompletionResponse(
completion_message=CompletionMessage(
tool_calls=tool_calls,
stop_reason=StopReason.end_of_message,
# Content is not optional
content="",
),
logprobs=None,
)
else:
return ChatCompletionResponse(
completion_message=CompletionMessage(
content=choice.message.content,
stop_reason=_map_finish_reason_to_stop_reason(choice.finish_reason),
),
)
def _map_finish_reason_to_stop_reason(
@@ -117,7 +177,7 @@ def _map_finish_reason_to_stop_reason(
elif finish_reason == "length":
return StopReason.out_of_tokens
elif finish_reason == "tool_calls":
raise NotImplementedError("tool_calls is not supported yet")
return StopReason.end_of_message
else:
raise ValueError(f"Invalid finish reason: {finish_reason}")
@@ -139,24 +199,46 @@ async def convert_chat_completion_response_stream(
for chunk in stream:
choice = chunk.choices[0]
# We assume there's only one finish_reason for the entire stream.
# We collect the last finish_reason
if choice.finish_reason:
stop_reason = _map_finish_reason_to_stop_reason(choice.finish_reason)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=next(event_types),
delta=choice.delta.content or "",
logprobs=None,
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta=choice.delta.content or "",
logprobs=None,
stop_reason=_map_finish_reason_to_stop_reason(choice.finish_reason),
)
)
)
elif choice.delta.tool_calls:
# We assume there is only one tool call per chunk, but emit a warning in case we're wrong
if len(choice.delta.tool_calls) > 1:
warnings.warn("Groq returned multiple tool calls in one chunk. Using the first one, ignoring the rest.")
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta="",
logprobs=None,
stop_reason=stop_reason,
)
# We assume Groq produces fully formed tool calls for each chunk
tool_call = _convert_groq_tool_call(choice.delta.tool_calls[0])
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=next(event_types),
delta=ToolCallDelta(
content=tool_call,
parse_status=ToolCallParseStatus.success,
),
)
)
else:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=next(event_types),
delta=choice.delta.content or "",
logprobs=None,
)
)
def _convert_groq_tool_call(tool_call: ChatCompletionMessageToolCall) -> ToolCall:
return ToolCall(
call_id=tool_call.id,
tool_name=tool_call.function.name,
# Note that Groq may return a string that is not valid JSON here
# So this may raise a 500 error. Going to leave this as is to see
# how big of an issue this is and what we can do about it.
arguments=json.loads(tool_call.function.arguments),
)
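
To summarise the helpers above: a llama-stack ToolDefinition is flattened into a Groq function tool whose per-parameter entries carry type, description, required, and default, and a Groq tool call is parsed back into a ToolCall with JSON-decoded arguments. Roughly, as plain data (illustrative values; the real code wraps the request side in the Groq SDK's ChatCompletionToolParam and FunctionDefinition types):

import json

# Request direction: approximately what _convert_groq_tool_definition builds
# for the get_flight_info example used in the tests below.
groq_tool = {
    "type": "function",
    "function": {
        "name": "get_flight_info",
        "description": "Get flight information between two destinations.",
        "parameters": {
            "origin": {"type": "string", "required": True},
            "destination": {"type": "string", "required": True},
        },
    },
}

# Response direction: _convert_groq_tool_call JSON-decodes the arguments
# string that Groq returns.
raw_arguments = '{"origin": "AU", "destination": "LAX"}'
arguments = json.loads(raw_arguments)  # {"origin": "AU", "destination": "LAX"}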

View file

@@ -5,6 +5,7 @@
# the root directory of this source tree.
import pytest
import json
from groq.types.chat.chat_completion import ChatCompletion, Choice
from groq.types.chat.chat_completion_chunk import (
ChatCompletionChunk,
@@ -12,7 +13,15 @@ from groq.types.chat.chat_completion_chunk import (
ChoiceDelta,
)
from groq.types.chat.chat_completion_message import ChatCompletionMessage
from groq.types.chat.chat_completion_message_tool_call import (
ChatCompletionMessageToolCall,
Function,
)
from groq.types.chat.chat_completion_chunk import (
ChoiceDeltaToolCall,
ChoiceDeltaToolCallFunction,
)
from groq.types.shared.function_definition import FunctionDefinition
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponseEventType,
@@ -20,6 +29,10 @@ from llama_stack.apis.inference import (
StopReason,
SystemMessage,
UserMessage,
ToolChoice,
ToolDefinition,
ToolParamDefinition,
ToolCall,
)
from llama_stack.providers.remote.inference.groq.groq_utils import (
convert_chat_completion_request,
@@ -140,12 +153,6 @@ class TestConvertChatCompletionRequest:
assert converted["max_tokens"] == 100
def _dummy_chat_completion_request(self):
return ChatCompletionRequest(
model="Llama-3.2-3B",
messages=[UserMessage(content="Hello World")],
)
def test_includes_temperature(self):
request = self._dummy_chat_completion_request()
request.sampling_params.temperature = 0.5
@@ -162,6 +169,112 @@ class TestConvertChatCompletionRequest:
assert converted["top_p"] == 0.95
def test_includes_tool_choice(self):
request = self._dummy_chat_completion_request()
request.tool_choice = ToolChoice.required
converted = convert_chat_completion_request(request)
assert converted["tool_choice"] == "required"
def test_includes_tools(self):
request = self._dummy_chat_completion_request()
request.tools = [
ToolDefinition(
tool_name="get_flight_info",
description="Get fight information between two destinations.",
parameters={
"origin": ToolParamDefinition(
param_type="string",
description="The origin airport code. E.g., AU",
required=True,
),
"destination": ToolParamDefinition(
param_type="string",
description="The destination airport code. E.g., 'LAX'",
required=True,
),
"passengers": ToolParamDefinition(
param_type="array",
description="The passengers",
required=False,
),
},
),
ToolDefinition(
tool_name="log",
description="Calulate the logarithm of a number",
parameters={
"number": ToolParamDefinition(
param_type="float",
description="The number to calculate the logarithm of",
required=True,
),
"base": ToolParamDefinition(
param_type="integer",
description="The base of the logarithm",
required=False,
default=10,
),
},
),
]
converted = convert_chat_completion_request(request)
assert converted["tools"] == [
{
"type": "function",
"function": FunctionDefinition(
name="get_flight_info",
description="Get fight information between two destinations.",
parameters={
"origin": {
"type": "string",
"description": "The origin airport code. E.g., AU",
"required": True,
},
"destination": {
"type": "string",
"description": "The destination airport code. E.g., 'LAX'",
"required": True,
},
"passengers": {
"type": "array",
"description": "The passengers",
"required": False,
},
},
),
},
{
"type": "function",
"function": FunctionDefinition(
name="log",
description="Calulate the logarithm of a number",
parameters={
"number": {
"type": "float",
"description": "The number to calculate the logarithm of",
"required": True,
},
"base": {
"type": "integer",
"description": "The base of the logarithm",
"required": False,
"default": 10,
},
},
),
},
]
def _dummy_chat_completion_request(self):
return ChatCompletionRequest(
model="Llama-3.2-3B",
messages=[UserMessage(content="Hello World")],
)
class TestConvertNonStreamChatCompletionResponse:
def test_returns_response(self):
@@ -188,6 +301,49 @@ class TestConvertNonStreamChatCompletionResponse:
assert converted.completion_message.stop_reason == StopReason.out_of_tokens
def test_maps_tool_call_to_end_of_message(self):
response = self._dummy_chat_completion_response_with_tool_call()
converted = convert_chat_completion_response(response)
assert converted.completion_message.stop_reason == StopReason.end_of_message
def test_converts_multiple_tool_calls(self):
response = self._dummy_chat_completion_response_with_tool_call()
response.choices[0].message.tool_calls = [
ChatCompletionMessageToolCall(
id="tool_call_id",
type="function",
function=Function(
name="get_flight_info",
arguments='{"origin": "AU", "destination": "LAX"}',
),
),
ChatCompletionMessageToolCall(
id="tool_call_id_2",
type="function",
function=Function(
name="log",
arguments='{"number": 10, "base": 2}',
),
),
]
converted = convert_chat_completion_response(response)
assert converted.completion_message.tool_calls == [
ToolCall(
call_id="tool_call_id",
tool_name="get_flight_info",
arguments={"origin": "AU", "destination": "LAX"},
),
ToolCall(
call_id="tool_call_id_2",
tool_name="log",
arguments={"number": 10, "base": 2},
),
]
def _dummy_chat_completion_response(self):
return ChatCompletion(
id="chatcmpl-123",
@@ -205,6 +361,33 @@ class TestConvertNonStreamChatCompletionResponse:
object="chat.completion",
)
def _dummy_chat_completion_response_with_tool_call(self):
return ChatCompletion(
id="chatcmpl-123",
model="Llama-3.2-3B",
choices=[
Choice(
index=0,
message=ChatCompletionMessage(
role="assistant",
tool_calls=[
ChatCompletionMessageToolCall(
id="tool_call_id",
type="function",
function=Function(
name="get_flight_info",
arguments='{"origin": "AU", "destination": "LAX"}',
),
)
],
),
finish_reason="tool_calls",
)
],
created=1729382400,
object="chat.completion",
)
class TestConvertStreamChatCompletionResponse:
@pytest.mark.asyncio
@@ -214,10 +397,6 @@ class TestConvertStreamChatCompletionResponse:
for i, message in enumerate(messages):
chunk = self._dummy_chat_completion_chunk()
chunk.choices[0].delta.content = message
if i == len(messages) - 1:
chunk.choices[0].finish_reason = "stop"
else:
chunk.choices[0].finish_reason = None
yield chunk
chunk = self._dummy_chat_completion_chunk()
@@ -241,12 +420,6 @@ class TestConvertStreamChatCompletionResponse:
assert chunk.event.event_type == ChatCompletionResponseEventType.progress
assert chunk.event.delta == " !"
# Dummy chunk to ensure the last chunk is really the end of the stream
# This one technically maps to Groq's final "stop" chunk
chunk = await iter.__anext__()
assert chunk.event.event_type == ChatCompletionResponseEventType.progress
assert chunk.event.delta == ""
chunk = await iter.__anext__()
assert chunk.event.event_type == ChatCompletionResponseEventType.complete
assert chunk.event.delta == ""
@@ -255,6 +428,54 @@ class TestConvertStreamChatCompletionResponse:
with pytest.raises(StopAsyncIteration):
await iter.__anext__()
@pytest.mark.asyncio
async def test_returns_tool_calls_stream(self):
def tool_call_stream():
tool_calls = [
ToolCall(
call_id="tool_call_id",
tool_name="get_flight_info",
arguments={"origin": "AU", "destination": "LAX"},
),
ToolCall(
call_id="tool_call_id_2",
tool_name="log",
arguments={"number": 10, "base": 2},
),
]
for i, tool_call in enumerate(tool_calls):
chunk = self._dummy_chat_completion_chunk_with_tool_call()
chunk.choices[0].delta.tool_calls = [
ChoiceDeltaToolCall(
index=0,
type="function",
id=tool_call.call_id,
function=ChoiceDeltaToolCallFunction(
name=tool_call.tool_name,
arguments=json.dumps(tool_call.arguments),
),
),
]
yield chunk
chunk = self._dummy_chat_completion_chunk_with_tool_call()
chunk.choices[0].delta.content = None
chunk.choices[0].finish_reason = "stop"
yield chunk
stream = tool_call_stream()
converted = convert_chat_completion_response_stream(stream)
iter = converted.__aiter__()
chunk = await iter.__anext__()
assert chunk.event.event_type == ChatCompletionResponseEventType.start
assert chunk.event.delta.content == ToolCall(
call_id="tool_call_id",
tool_name="get_flight_info",
arguments={"origin": "AU", "destination": "LAX"},
)
def _dummy_chat_completion_chunk(self):
return ChatCompletionChunk(
id="chatcmpl-123",
@@ -269,3 +490,31 @@ class TestConvertStreamChatCompletionResponse:
object="chat.completion.chunk",
x_groq=None,
)
def _dummy_chat_completion_chunk_with_tool_call(self):
return ChatCompletionChunk(
id="chatcmpl-123",
model="Llama-3.2-3B",
choices=[
StreamChoice(
index=0,
delta=ChoiceDelta(
role="assistant",
content="Hello World",
tool_calls=[
ChoiceDeltaToolCall(
index=0,
type="function",
function=ChoiceDeltaToolCallFunction(
name="get_flight_info",
arguments='{"origin": "AU", "destination": "LAX"}',
),
)
],
),
)
],
created=1729382400,
object="chat.completion.chunk",
x_groq=None,
)

View file

@@ -352,13 +352,13 @@ class TestInference:
):
inference_impl, _ = inference_stack
provider = inference_impl.routing_table.get_provider_impl(inference_model)
if provider.__provider_spec__.provider_type in ("remote::groq",):
pytest.skip(
provider.__provider_spec__.provider_type
+ " doesn't support tool calling yet"
)
if (
provider.__provider_spec__.provider_type == "remote::groq"
and "Llama-3.2" in inference_model
):
# TODO(aidand): Remove this skip once Groq's tool calling for Llama3.2 works better
pytest.skip("Groq's tool calling for Llama3.2 doesn't work very well")
inference_impl, _ = inference_stack
messages = sample_messages + [
UserMessage(
content="What's the weather like in San Francisco?",
@@ -399,11 +399,12 @@ class TestInference:
):
inference_impl, _ = inference_stack
provider = inference_impl.routing_table.get_provider_impl(inference_model)
if provider.__provider_spec__.provider_type in ("remote::groq",):
pytest.skip(
provider.__provider_spec__.provider_type
+ " doesn't support tool calling yet"
)
if (
provider.__provider_spec__.provider_type == "remote::groq"
and "Llama-3.2" in inference_model
):
# TODO(aidand): Remove this skip once Groq's tool calling for Llama3.2 works better
pytest.skip("Groq's tool calling for Llama3.2 doesn't work very well")
messages = sample_messages + [
UserMessage(