mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-18 14:49:47 +00:00
Add tool calls to groq inference adapter
This commit is contained in:
parent
78912e663b
commit
cf87262e9c
4 changed files with 400 additions and 60 deletions
|
|
@ -5,6 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import pytest
|
||||
import json
|
||||
from groq.types.chat.chat_completion import ChatCompletion, Choice
|
||||
from groq.types.chat.chat_completion_chunk import (
|
||||
ChatCompletionChunk,
|
||||
|
|
@ -12,7 +13,15 @@ from groq.types.chat.chat_completion_chunk import (
|
|||
ChoiceDelta,
|
||||
)
|
||||
from groq.types.chat.chat_completion_message import ChatCompletionMessage
|
||||
|
||||
from groq.types.chat.chat_completion_message_tool_call import (
|
||||
ChatCompletionMessageToolCall,
|
||||
Function,
|
||||
)
|
||||
from groq.types.chat.chat_completion_chunk import (
|
||||
ChoiceDeltaToolCall,
|
||||
ChoiceDeltaToolCallFunction,
|
||||
)
|
||||
from groq.types.shared.function_definition import FunctionDefinition
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponseEventType,
|
||||
|
|
@ -20,6 +29,10 @@ from llama_stack.apis.inference import (
|
|||
StopReason,
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
ToolChoice,
|
||||
ToolDefinition,
|
||||
ToolParamDefinition,
|
||||
ToolCall,
|
||||
)
|
||||
from llama_stack.providers.remote.inference.groq.groq_utils import (
|
||||
convert_chat_completion_request,
|
||||
|
|
@ -140,12 +153,6 @@ class TestConvertChatCompletionRequest:
|
|||
|
||||
assert converted["max_tokens"] == 100
|
||||
|
||||
def _dummy_chat_completion_request(self):
|
||||
return ChatCompletionRequest(
|
||||
model="Llama-3.2-3B",
|
||||
messages=[UserMessage(content="Hello World")],
|
||||
)
|
||||
|
||||
def test_includes_temperature(self):
|
||||
request = self._dummy_chat_completion_request()
|
||||
request.sampling_params.temperature = 0.5
|
||||
|
|
@ -162,6 +169,112 @@ class TestConvertChatCompletionRequest:
|
|||
|
||||
assert converted["top_p"] == 0.95
|
||||
|
||||
def test_includes_tool_choice(self):
|
||||
request = self._dummy_chat_completion_request()
|
||||
request.tool_choice = ToolChoice.required
|
||||
|
||||
converted = convert_chat_completion_request(request)
|
||||
|
||||
assert converted["tool_choice"] == "required"
|
||||
|
||||
def test_includes_tools(self):
|
||||
request = self._dummy_chat_completion_request()
|
||||
request.tools = [
|
||||
ToolDefinition(
|
||||
tool_name="get_flight_info",
|
||||
description="Get fight information between two destinations.",
|
||||
parameters={
|
||||
"origin": ToolParamDefinition(
|
||||
param_type="string",
|
||||
description="The origin airport code. E.g., AU",
|
||||
required=True,
|
||||
),
|
||||
"destination": ToolParamDefinition(
|
||||
param_type="string",
|
||||
description="The destination airport code. E.g., 'LAX'",
|
||||
required=True,
|
||||
),
|
||||
"passengers": ToolParamDefinition(
|
||||
param_type="array",
|
||||
description="The passengers",
|
||||
required=False,
|
||||
),
|
||||
},
|
||||
),
|
||||
ToolDefinition(
|
||||
tool_name="log",
|
||||
description="Calulate the logarithm of a number",
|
||||
parameters={
|
||||
"number": ToolParamDefinition(
|
||||
param_type="float",
|
||||
description="The number to calculate the logarithm of",
|
||||
required=True,
|
||||
),
|
||||
"base": ToolParamDefinition(
|
||||
param_type="integer",
|
||||
description="The base of the logarithm",
|
||||
required=False,
|
||||
default=10,
|
||||
),
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
converted = convert_chat_completion_request(request)
|
||||
|
||||
assert converted["tools"] == [
|
||||
{
|
||||
"type": "function",
|
||||
"function": FunctionDefinition(
|
||||
name="get_flight_info",
|
||||
description="Get fight information between two destinations.",
|
||||
parameters={
|
||||
"origin": {
|
||||
"type": "string",
|
||||
"description": "The origin airport code. E.g., AU",
|
||||
"required": True,
|
||||
},
|
||||
"destination": {
|
||||
"type": "string",
|
||||
"description": "The destination airport code. E.g., 'LAX'",
|
||||
"required": True,
|
||||
},
|
||||
"passengers": {
|
||||
"type": "array",
|
||||
"description": "The passengers",
|
||||
"required": False,
|
||||
},
|
||||
},
|
||||
),
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": FunctionDefinition(
|
||||
name="log",
|
||||
description="Calulate the logarithm of a number",
|
||||
parameters={
|
||||
"number": {
|
||||
"type": "float",
|
||||
"description": "The number to calculate the logarithm of",
|
||||
"required": True,
|
||||
},
|
||||
"base": {
|
||||
"type": "integer",
|
||||
"description": "The base of the logarithm",
|
||||
"required": False,
|
||||
"default": 10,
|
||||
},
|
||||
},
|
||||
),
|
||||
},
|
||||
]
|
||||
|
||||
def _dummy_chat_completion_request(self):
|
||||
return ChatCompletionRequest(
|
||||
model="Llama-3.2-3B",
|
||||
messages=[UserMessage(content="Hello World")],
|
||||
)
|
||||
|
||||
|
||||
class TestConvertNonStreamChatCompletionResponse:
|
||||
def test_returns_response(self):
|
||||
|
|
@ -188,6 +301,49 @@ class TestConvertNonStreamChatCompletionResponse:
|
|||
|
||||
assert converted.completion_message.stop_reason == StopReason.out_of_tokens
|
||||
|
||||
def test_maps_tool_call_to_end_of_message(self):
|
||||
response = self._dummy_chat_completion_response_with_tool_call()
|
||||
|
||||
converted = convert_chat_completion_response(response)
|
||||
|
||||
assert converted.completion_message.stop_reason == StopReason.end_of_message
|
||||
|
||||
def test_converts_multiple_tool_calls(self):
|
||||
response = self._dummy_chat_completion_response_with_tool_call()
|
||||
response.choices[0].message.tool_calls = [
|
||||
ChatCompletionMessageToolCall(
|
||||
id="tool_call_id",
|
||||
type="function",
|
||||
function=Function(
|
||||
name="get_flight_info",
|
||||
arguments='{"origin": "AU", "destination": "LAX"}',
|
||||
),
|
||||
),
|
||||
ChatCompletionMessageToolCall(
|
||||
id="tool_call_id_2",
|
||||
type="function",
|
||||
function=Function(
|
||||
name="log",
|
||||
arguments='{"number": 10, "base": 2}',
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
converted = convert_chat_completion_response(response)
|
||||
|
||||
assert converted.completion_message.tool_calls == [
|
||||
ToolCall(
|
||||
call_id="tool_call_id",
|
||||
tool_name="get_flight_info",
|
||||
arguments={"origin": "AU", "destination": "LAX"},
|
||||
),
|
||||
ToolCall(
|
||||
call_id="tool_call_id_2",
|
||||
tool_name="log",
|
||||
arguments={"number": 10, "base": 2},
|
||||
),
|
||||
]
|
||||
|
||||
def _dummy_chat_completion_response(self):
|
||||
return ChatCompletion(
|
||||
id="chatcmpl-123",
|
||||
|
|
@ -205,6 +361,33 @@ class TestConvertNonStreamChatCompletionResponse:
|
|||
object="chat.completion",
|
||||
)
|
||||
|
||||
def _dummy_chat_completion_response_with_tool_call(self):
|
||||
return ChatCompletion(
|
||||
id="chatcmpl-123",
|
||||
model="Llama-3.2-3B",
|
||||
choices=[
|
||||
Choice(
|
||||
index=0,
|
||||
message=ChatCompletionMessage(
|
||||
role="assistant",
|
||||
tool_calls=[
|
||||
ChatCompletionMessageToolCall(
|
||||
id="tool_call_id",
|
||||
type="function",
|
||||
function=Function(
|
||||
name="get_flight_info",
|
||||
arguments='{"origin": "AU", "destination": "LAX"}',
|
||||
),
|
||||
)
|
||||
],
|
||||
),
|
||||
finish_reason="tool_calls",
|
||||
)
|
||||
],
|
||||
created=1729382400,
|
||||
object="chat.completion",
|
||||
)
|
||||
|
||||
|
||||
class TestConvertStreamChatCompletionResponse:
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -214,10 +397,6 @@ class TestConvertStreamChatCompletionResponse:
|
|||
for i, message in enumerate(messages):
|
||||
chunk = self._dummy_chat_completion_chunk()
|
||||
chunk.choices[0].delta.content = message
|
||||
if i == len(messages) - 1:
|
||||
chunk.choices[0].finish_reason = "stop"
|
||||
else:
|
||||
chunk.choices[0].finish_reason = None
|
||||
yield chunk
|
||||
|
||||
chunk = self._dummy_chat_completion_chunk()
|
||||
|
|
@ -241,12 +420,6 @@ class TestConvertStreamChatCompletionResponse:
|
|||
assert chunk.event.event_type == ChatCompletionResponseEventType.progress
|
||||
assert chunk.event.delta == " !"
|
||||
|
||||
# Dummy chunk to ensure the last chunk is really the end of the stream
|
||||
# This one technically maps to Groq's final "stop" chunk
|
||||
chunk = await iter.__anext__()
|
||||
assert chunk.event.event_type == ChatCompletionResponseEventType.progress
|
||||
assert chunk.event.delta == ""
|
||||
|
||||
chunk = await iter.__anext__()
|
||||
assert chunk.event.event_type == ChatCompletionResponseEventType.complete
|
||||
assert chunk.event.delta == ""
|
||||
|
|
@ -255,6 +428,54 @@ class TestConvertStreamChatCompletionResponse:
|
|||
with pytest.raises(StopAsyncIteration):
|
||||
await iter.__anext__()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_tool_calls_stream(self):
|
||||
def tool_call_stream():
|
||||
tool_calls = [
|
||||
ToolCall(
|
||||
call_id="tool_call_id",
|
||||
tool_name="get_flight_info",
|
||||
arguments={"origin": "AU", "destination": "LAX"},
|
||||
),
|
||||
ToolCall(
|
||||
call_id="tool_call_id_2",
|
||||
tool_name="log",
|
||||
arguments={"number": 10, "base": 2},
|
||||
),
|
||||
]
|
||||
for i, tool_call in enumerate(tool_calls):
|
||||
chunk = self._dummy_chat_completion_chunk_with_tool_call()
|
||||
chunk.choices[0].delta.tool_calls = [
|
||||
ChoiceDeltaToolCall(
|
||||
index=0,
|
||||
type="function",
|
||||
id=tool_call.call_id,
|
||||
function=ChoiceDeltaToolCallFunction(
|
||||
name=tool_call.tool_name,
|
||||
arguments=json.dumps(tool_call.arguments),
|
||||
),
|
||||
),
|
||||
]
|
||||
yield chunk
|
||||
|
||||
chunk = self._dummy_chat_completion_chunk_with_tool_call()
|
||||
chunk.choices[0].delta.content = None
|
||||
chunk.choices[0].finish_reason = "stop"
|
||||
yield chunk
|
||||
|
||||
stream = tool_call_stream()
|
||||
converted = convert_chat_completion_response_stream(stream)
|
||||
|
||||
iter = converted.__aiter__()
|
||||
chunk = await iter.__anext__()
|
||||
print(chunk)
|
||||
assert chunk.event.event_type == ChatCompletionResponseEventType.start
|
||||
assert chunk.event.delta.content == ToolCall(
|
||||
call_id="tool_call_id",
|
||||
tool_name="get_flight_info",
|
||||
arguments={"origin": "AU", "destination": "LAX"},
|
||||
)
|
||||
|
||||
def _dummy_chat_completion_chunk(self):
|
||||
return ChatCompletionChunk(
|
||||
id="chatcmpl-123",
|
||||
|
|
@ -269,3 +490,31 @@ class TestConvertStreamChatCompletionResponse:
|
|||
object="chat.completion.chunk",
|
||||
x_groq=None,
|
||||
)
|
||||
|
||||
def _dummy_chat_completion_chunk_with_tool_call(self):
|
||||
return ChatCompletionChunk(
|
||||
id="chatcmpl-123",
|
||||
model="Llama-3.2-3B",
|
||||
choices=[
|
||||
StreamChoice(
|
||||
index=0,
|
||||
delta=ChoiceDelta(
|
||||
role="assistant",
|
||||
content="Hello World",
|
||||
tool_calls=[
|
||||
ChoiceDeltaToolCall(
|
||||
index=0,
|
||||
type="function",
|
||||
function=ChoiceDeltaToolCallFunction(
|
||||
name="get_flight_info",
|
||||
arguments='{"origin": "AU", "destination": "LAX"}',
|
||||
),
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
],
|
||||
created=1729382400,
|
||||
object="chat.completion.chunk",
|
||||
x_groq=None,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -352,13 +352,13 @@ class TestInference:
|
|||
):
|
||||
inference_impl, _ = inference_stack
|
||||
provider = inference_impl.routing_table.get_provider_impl(inference_model)
|
||||
if provider.__provider_spec__.provider_type in ("remote::groq",):
|
||||
pytest.skip(
|
||||
provider.__provider_spec__.provider_type
|
||||
+ " doesn't support tool calling yet"
|
||||
)
|
||||
if (
|
||||
provider.__provider_spec__.provider_type == "remote::groq"
|
||||
and "Llama-3.2" in inference_model
|
||||
):
|
||||
# TODO(aidand): Remove this skip once Groq's tool calling for Llama3.2 works better
|
||||
pytest.skip("Groq's tool calling for Llama3.2 doesn't work very well")
|
||||
|
||||
inference_impl, _ = inference_stack
|
||||
messages = sample_messages + [
|
||||
UserMessage(
|
||||
content="What's the weather like in San Francisco?",
|
||||
|
|
@ -399,11 +399,12 @@ class TestInference:
|
|||
):
|
||||
inference_impl, _ = inference_stack
|
||||
provider = inference_impl.routing_table.get_provider_impl(inference_model)
|
||||
if provider.__provider_spec__.provider_type in ("remote::groq",):
|
||||
pytest.skip(
|
||||
provider.__provider_spec__.provider_type
|
||||
+ " doesn't support tool calling yet"
|
||||
)
|
||||
if (
|
||||
provider.__provider_spec__.provider_type == "remote::groq"
|
||||
and "Llama-3.2" in inference_model
|
||||
):
|
||||
# TODO(aidand): Remove this skip once Groq's tool calling for Llama3.2 works better
|
||||
pytest.skip("Groq's tool calling for Llama3.2 doesn't work very well")
|
||||
|
||||
messages = sample_messages + [
|
||||
UserMessage(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue