llama-stack/llama_stack/providers/tests/inference/groq/test_groq_utils.py
Aidan Do fdcc74fda2
[#432] Add Groq Provider - tool calls (#630)
# What does this PR do?

Contributes to issue #432

- Adds tool calls to Groq provider
- Enables tool call integration tests
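
For reference, here is a minimal sketch of the tool-call conversion this PR covers. The shapes are lifted from the unit tests in `test_groq_utils.py`; it is illustrative only, not a new public API:

```python
from llama_models.llama3.api.datatypes import ToolParamDefinition
from llama_stack.apis.inference import ChatCompletionRequest, ToolDefinition, UserMessage
from llama_stack.providers.remote.inference.groq.groq_utils import (
    convert_chat_completion_request,
)

request = ChatCompletionRequest(
    model="Llama-3.2-3B",
    messages=[UserMessage(content="What flights go from SYD to LAX?")],
    tools=[
        ToolDefinition(
            tool_name="get_flight_info",
            description="Get flight information between two destinations.",
            parameters={
                "origin": ToolParamDefinition(
                    param_type="string",
                    description="The origin airport code.",
                    required=True,
                ),
            },
        )
    ],
)

converted = convert_chat_completion_request(request)
# Tools are mapped into Groq's function-calling format:
# converted["tools"] == [{"type": "function", "function": FunctionDefinition(...)}]
```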

### PR Train

- https://github.com/meta-llama/llama-stack/pull/609 
- https://github.com/meta-llama/llama-stack/pull/630 👈

## Test Plan
Environment:

```shell
export GROQ_API_KEY=<api-key>

# build.yaml and run.yaml files
wget https://raw.githubusercontent.com/aidando73/llama-stack/9165502582cd7cb178bc1dcf89955b45768ab6c1/build.yaml
wget https://raw.githubusercontent.com/aidando73/llama-stack/9165502582cd7cb178bc1dcf89955b45768ab6c1/run.yaml

# Create the environment if it doesn't already exist
conda create --prefix ./envs python=3.10
conda activate ./envs

# Build
pip install -e . && llama stack build --config ./build.yaml --image-type conda

# Activate built environment
conda activate llamastack-groq
```

<details>
<summary>Unit tests</summary>

```shell
# Setup
conda activate llamastack-groq
pytest llama_stack/providers/tests/inference/groq/test_groq_utils.py -vv -k groq -s

# Result
llama_stack/providers/tests/inference/groq/test_groq_utils.py .....................

======================================== 21 passed, 1 warning in 0.05s ========================================
```
</details>

<details>
<summary>Integration tests</summary>

```shell
# Run
conda activate llamastack-groq
pytest llama_stack/providers/tests/inference/test_text_inference.py -k groq -s

# Result
llama_stack/providers/tests/inference/test_text_inference.py .sss.s.ss.sss.s...

========================== 8 passed, 10 skipped, 180 deselected, 7 warnings in 2.73s ==========================
```
</details>

<details>
<summary>Manual</summary>

```bash
llama stack run ./run.yaml --port 5001
```

Via this Jupyter notebook:
9165502582/hello.ipynb
</details>

## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [x] Ran pre-commit to handle lint / formatting issues.
- [x] Read the [contributor
guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md),
      Pull Request section?
- [x] Updated relevant documentation. (there doesn't seem to be any relevant
documentation)
- [x] Wrote necessary unit or integration tests.
2025-01-13 18:17:38 -08:00


# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json

import pytest
from groq.types.chat.chat_completion import ChatCompletion, Choice
from groq.types.chat.chat_completion_chunk import (
    ChatCompletionChunk,
    Choice as StreamChoice,
    ChoiceDelta,
    ChoiceDeltaToolCall,
    ChoiceDeltaToolCallFunction,
)
from groq.types.chat.chat_completion_message import ChatCompletionMessage
from groq.types.chat.chat_completion_message_tool_call import (
    ChatCompletionMessageToolCall,
    Function,
)
from groq.types.shared.function_definition import FunctionDefinition
from llama_models.llama3.api.datatypes import ToolParamDefinition
from llama_stack.apis.inference import (
    ChatCompletionRequest,
    ChatCompletionResponseEventType,
    CompletionMessage,
    StopReason,
    SystemMessage,
    ToolCall,
    ToolChoice,
    ToolDefinition,
    UserMessage,
)
from llama_stack.providers.remote.inference.groq.groq_utils import (
    convert_chat_completion_request,
    convert_chat_completion_response,
    convert_chat_completion_response_stream,
)
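
# These tests cover the three Groq conversion helpers: turning llama-stack chat
# completion requests into Groq's format, mapping non-streaming Groq responses
# back into llama-stack types, and adapting Groq's streaming chunks into a
# stream of llama-stack chat completion response events.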
class TestConvertChatCompletionRequest:
    def test_sets_model(self):
        request = self._dummy_chat_completion_request()
        request.model = "Llama-3.2-3B"
        converted = convert_chat_completion_request(request)
        assert converted["model"] == "Llama-3.2-3B"

    def test_converts_user_message(self):
        request = self._dummy_chat_completion_request()
        request.messages = [UserMessage(content="Hello World")]
        converted = convert_chat_completion_request(request)
        assert converted["messages"] == [
            {"role": "user", "content": "Hello World"},
        ]

    def test_converts_system_message(self):
        request = self._dummy_chat_completion_request()
        request.messages = [SystemMessage(content="You are a helpful assistant.")]
        converted = convert_chat_completion_request(request)
        assert converted["messages"] == [
            {"role": "system", "content": "You are a helpful assistant."},
        ]

    def test_converts_completion_message(self):
        request = self._dummy_chat_completion_request()
        request.messages = [
            UserMessage(content="Hello World"),
            CompletionMessage(
                content="Hello World! How can I help you today?",
                stop_reason=StopReason.end_of_message,
            ),
        ]
        converted = convert_chat_completion_request(request)
        assert converted["messages"] == [
            {"role": "user", "content": "Hello World"},
            {"role": "assistant", "content": "Hello World! How can I help you today?"},
        ]

    def test_does_not_include_logprobs(self):
        request = self._dummy_chat_completion_request()
        request.logprobs = True
        with pytest.warns(Warning) as warnings:
            converted = convert_chat_completion_request(request)
        assert "logprobs are not supported yet" in warnings[0].message.args[0]
        assert converted.get("logprobs") is None

    def test_does_not_include_response_format(self):
        request = self._dummy_chat_completion_request()
        request.response_format = {
            "type": "json_object",
            "json_schema": {
                "type": "object",
                "properties": {
                    "name": {"type": "string"},
                    "age": {"type": "number"},
                },
            },
        }
        with pytest.warns(Warning) as warnings:
            converted = convert_chat_completion_request(request)
        assert "response_format is not supported yet" in warnings[0].message.args[0]
        assert converted.get("response_format") is None

    def test_does_not_include_repetition_penalty(self):
        request = self._dummy_chat_completion_request()
        request.sampling_params.repetition_penalty = 1.5
        with pytest.warns(Warning) as warnings:
            converted = convert_chat_completion_request(request)
        assert "repetition_penalty is not supported" in warnings[0].message.args[0]
        assert converted.get("repetition_penalty") is None
        assert converted.get("frequency_penalty") is None

    def test_includes_stream(self):
        request = self._dummy_chat_completion_request()
        request.stream = True
        converted = convert_chat_completion_request(request)
        assert converted["stream"] is True

    def test_if_max_tokens_is_0_then_it_is_not_included(self):
        request = self._dummy_chat_completion_request()
        # 0 is the default value for max_tokens
        # So we assume that if it's 0, the user didn't set it
        request.sampling_params.max_tokens = 0
        converted = convert_chat_completion_request(request)
        assert converted.get("max_tokens") is None

    def test_includes_max_tokens_if_set(self):
        request = self._dummy_chat_completion_request()
        request.sampling_params.max_tokens = 100
        converted = convert_chat_completion_request(request)
        assert converted["max_tokens"] == 100

    def test_includes_temperature(self):
        request = self._dummy_chat_completion_request()
        request.sampling_params.temperature = 0.5
        converted = convert_chat_completion_request(request)
        assert converted["temperature"] == 0.5

    def test_includes_top_p(self):
        request = self._dummy_chat_completion_request()
        request.sampling_params.top_p = 0.95
        converted = convert_chat_completion_request(request)
        assert converted["top_p"] == 0.95

    def test_includes_tool_choice(self):
        request = self._dummy_chat_completion_request()
        request.tool_choice = ToolChoice.required
        converted = convert_chat_completion_request(request)
        assert converted["tool_choice"] == "required"
    def test_includes_tools(self):
        request = self._dummy_chat_completion_request()
        request.tools = [
            ToolDefinition(
                tool_name="get_flight_info",
                description="Get flight information between two destinations.",
                parameters={
                    "origin": ToolParamDefinition(
                        param_type="string",
                        description="The origin airport code. E.g., AU",
                        required=True,
                    ),
                    "destination": ToolParamDefinition(
                        param_type="string",
                        description="The destination airport code. E.g., 'LAX'",
                        required=True,
                    ),
                    "passengers": ToolParamDefinition(
                        param_type="array",
                        description="The passengers",
                        required=False,
                    ),
                },
            ),
            ToolDefinition(
                tool_name="log",
                description="Calculate the logarithm of a number",
                parameters={
                    "number": ToolParamDefinition(
                        param_type="float",
                        description="The number to calculate the logarithm of",
                        required=True,
                    ),
                    "base": ToolParamDefinition(
                        param_type="integer",
                        description="The base of the logarithm",
                        required=False,
                        default=10,
                    ),
                },
            ),
        ]
        converted = convert_chat_completion_request(request)
        assert converted["tools"] == [
            {
                "type": "function",
                "function": FunctionDefinition(
                    name="get_flight_info",
                    description="Get flight information between two destinations.",
                    parameters={
                        "origin": {
                            "type": "string",
                            "description": "The origin airport code. E.g., AU",
                            "required": True,
                        },
                        "destination": {
                            "type": "string",
                            "description": "The destination airport code. E.g., 'LAX'",
                            "required": True,
                        },
                        "passengers": {
                            "type": "array",
                            "description": "The passengers",
                            "required": False,
                        },
                    },
                ),
            },
            {
                "type": "function",
                "function": FunctionDefinition(
                    name="log",
                    description="Calculate the logarithm of a number",
                    parameters={
                        "number": {
                            "type": "float",
                            "description": "The number to calculate the logarithm of",
                            "required": True,
                        },
                        "base": {
                            "type": "integer",
                            "description": "The base of the logarithm",
                            "required": False,
                            "default": 10,
                        },
                    },
                ),
            },
        ]

    def _dummy_chat_completion_request(self):
        return ChatCompletionRequest(
            model="Llama-3.2-3B",
            messages=[UserMessage(content="Hello World")],
        )
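
# The tests below pin down the finish_reason -> StopReason mapping performed by
# convert_chat_completion_response: "stop" maps to end_of_turn, "length" to
# out_of_tokens, and "tool_calls" to end_of_message.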
class TestConvertNonStreamChatCompletionResponse:
    def test_returns_response(self):
        response = self._dummy_chat_completion_response()
        response.choices[0].message.content = "Hello World"
        converted = convert_chat_completion_response(response)
        assert converted.completion_message.content == "Hello World"

    def test_maps_stop_to_end_of_turn(self):
        response = self._dummy_chat_completion_response()
        response.choices[0].finish_reason = "stop"
        converted = convert_chat_completion_response(response)
        assert converted.completion_message.stop_reason == StopReason.end_of_turn

    def test_maps_length_to_out_of_tokens(self):
        response = self._dummy_chat_completion_response()
        response.choices[0].finish_reason = "length"
        converted = convert_chat_completion_response(response)
        assert converted.completion_message.stop_reason == StopReason.out_of_tokens

    def test_maps_tool_call_to_end_of_message(self):
        response = self._dummy_chat_completion_response_with_tool_call()
        converted = convert_chat_completion_response(response)
        assert converted.completion_message.stop_reason == StopReason.end_of_message

    def test_converts_multiple_tool_calls(self):
        response = self._dummy_chat_completion_response_with_tool_call()
        response.choices[0].message.tool_calls = [
            ChatCompletionMessageToolCall(
                id="tool_call_id",
                type="function",
                function=Function(
                    name="get_flight_info",
                    arguments='{"origin": "AU", "destination": "LAX"}',
                ),
            ),
            ChatCompletionMessageToolCall(
                id="tool_call_id_2",
                type="function",
                function=Function(
                    name="log",
                    arguments='{"number": 10, "base": 2}',
                ),
            ),
        ]
        converted = convert_chat_completion_response(response)
        assert converted.completion_message.tool_calls == [
            ToolCall(
                call_id="tool_call_id",
                tool_name="get_flight_info",
                arguments={"origin": "AU", "destination": "LAX"},
            ),
            ToolCall(
                call_id="tool_call_id_2",
                tool_name="log",
                arguments={"number": 10, "base": 2},
            ),
        ]

    def _dummy_chat_completion_response(self):
        return ChatCompletion(
            id="chatcmpl-123",
            model="Llama-3.2-3B",
            choices=[
                Choice(
                    index=0,
                    message=ChatCompletionMessage(
                        role="assistant", content="Hello World"
                    ),
                    finish_reason="stop",
                )
            ],
            created=1729382400,
            object="chat.completion",
        )

    def _dummy_chat_completion_response_with_tool_call(self):
        return ChatCompletion(
            id="chatcmpl-123",
            model="Llama-3.2-3B",
            choices=[
                Choice(
                    index=0,
                    message=ChatCompletionMessage(
                        role="assistant",
                        tool_calls=[
                            ChatCompletionMessageToolCall(
                                id="tool_call_id",
                                type="function",
                                function=Function(
                                    name="get_flight_info",
                                    arguments='{"origin": "AU", "destination": "LAX"}',
                                ),
                            )
                        ],
                    ),
                    finish_reason="tool_calls",
                )
            ],
            created=1729382400,
            object="chat.completion",
        )
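
# The stream converter yields a `start` event for the first chunk, `progress`
# events for subsequent chunks, and a final `complete` event carrying the stop
# reason once the upstream stream finishes, as asserted below.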
class TestConvertStreamChatCompletionResponse:
    @pytest.mark.asyncio
    async def test_returns_stream(self):
        def chat_completion_stream():
            messages = ["Hello ", "World ", " !"]
            for i, message in enumerate(messages):
                chunk = self._dummy_chat_completion_chunk()
                chunk.choices[0].delta.content = message
                yield chunk

            chunk = self._dummy_chat_completion_chunk()
            chunk.choices[0].delta.content = None
            chunk.choices[0].finish_reason = "stop"
            yield chunk

        stream = chat_completion_stream()
        converted = convert_chat_completion_response_stream(stream)

        iter = converted.__aiter__()
        chunk = await iter.__anext__()
        assert chunk.event.event_type == ChatCompletionResponseEventType.start
        assert chunk.event.delta == "Hello "

        chunk = await iter.__anext__()
        assert chunk.event.event_type == ChatCompletionResponseEventType.progress
        assert chunk.event.delta == "World "

        chunk = await iter.__anext__()
        assert chunk.event.event_type == ChatCompletionResponseEventType.progress
        assert chunk.event.delta == " !"

        chunk = await iter.__anext__()
        assert chunk.event.event_type == ChatCompletionResponseEventType.complete
        assert chunk.event.delta == ""
        assert chunk.event.stop_reason == StopReason.end_of_turn

        with pytest.raises(StopAsyncIteration):
            await iter.__anext__()

    @pytest.mark.asyncio
    async def test_returns_tool_calls_stream(self):
        def tool_call_stream():
            tool_calls = [
                ToolCall(
                    call_id="tool_call_id",
                    tool_name="get_flight_info",
                    arguments={"origin": "AU", "destination": "LAX"},
                ),
                ToolCall(
                    call_id="tool_call_id_2",
                    tool_name="log",
                    arguments={"number": 10, "base": 2},
                ),
            ]
            for i, tool_call in enumerate(tool_calls):
                chunk = self._dummy_chat_completion_chunk_with_tool_call()
                chunk.choices[0].delta.tool_calls = [
                    ChoiceDeltaToolCall(
                        index=0,
                        type="function",
                        id=tool_call.call_id,
                        function=ChoiceDeltaToolCallFunction(
                            name=tool_call.tool_name,
                            arguments=json.dumps(tool_call.arguments),
                        ),
                    ),
                ]
                yield chunk

            chunk = self._dummy_chat_completion_chunk_with_tool_call()
            chunk.choices[0].delta.content = None
            chunk.choices[0].finish_reason = "stop"
            yield chunk

        stream = tool_call_stream()
        converted = convert_chat_completion_response_stream(stream)

        iter = converted.__aiter__()
        chunk = await iter.__anext__()
        assert chunk.event.event_type == ChatCompletionResponseEventType.start
        assert chunk.event.delta.content == ToolCall(
            call_id="tool_call_id",
            tool_name="get_flight_info",
            arguments={"origin": "AU", "destination": "LAX"},
        )

    def _dummy_chat_completion_chunk(self):
        return ChatCompletionChunk(
            id="chatcmpl-123",
            model="Llama-3.2-3B",
            choices=[
                StreamChoice(
                    index=0,
                    delta=ChoiceDelta(role="assistant", content="Hello World"),
                )
            ],
            created=1729382400,
            object="chat.completion.chunk",
            x_groq=None,
        )

    def _dummy_chat_completion_chunk_with_tool_call(self):
        return ChatCompletionChunk(
            id="chatcmpl-123",
            model="Llama-3.2-3B",
            choices=[
                StreamChoice(
                    index=0,
                    delta=ChoiceDelta(
                        role="assistant",
                        content="Hello World",
                        tool_calls=[
                            ChoiceDeltaToolCall(
                                index=0,
                                type="function",
                                function=ChoiceDeltaToolCallFunction(
                                    name="get_flight_info",
                                    arguments='{"origin": "AU", "destination": "LAX"}',
                                ),
                            )
                        ],
                    ),
                )
            ],
            created=1729382400,
            object="chat.completion.chunk",
            x_groq=None,
        )