forked from phoenix-oss/llama-stack-mirror
# What does this PR do?
Contributes to issue #432
- Adds tool calls to Groq provider
- Enables tool call integration tests
### PR Train
- https://github.com/meta-llama/llama-stack/pull/609
- https://github.com/meta-llama/llama-stack/pull/630 👈
## Test Plan
Environment:
```shell
export GROQ_API_KEY=<api-key>
# build.yaml and run.yaml files
wget https://raw.githubusercontent.com/aidando73/llama-stack/9165502582cd7cb178bc1dcf89955b45768ab6c1/build.yaml
wget https://raw.githubusercontent.com/aidando73/llama-stack/9165502582cd7cb178bc1dcf89955b45768ab6c1/run.yaml
# Create environment if not already
conda create --prefix ./envs python=3.10
conda activate ./envs
# Build
pip install -e . && llama stack build --config ./build.yaml --image-type conda
# Activate built environment
conda activate llamastack-groq
```
<details>
<summary>Unit tests</summary>
```shell
# Setup
conda activate llamastack-groq
pytest llama_stack/providers/tests/inference/groq/test_groq_utils.py -vv -k groq -s
# Result
llama_stack/providers/tests/inference/groq/test_groq_utils.py .....................
======================================== 21 passed, 1 warning in 0.05s ========================================
```
</details>
<details>
<summary>Integration tests</summary>
```shell
# Run
conda activate llamastack-groq
pytest llama_stack/providers/tests/inference/test_text_inference.py -k groq -s
# Result
llama_stack/providers/tests/inference/test_text_inference.py .sss.s.ss.sss.s...
========================== 8 passed, 10 skipped, 180 deselected, 7 warnings in 2.73s ==========================
```
</details>
<details>
<summary>Manual</summary>
```bash
llama stack run ./run.yaml --port 5001
```
Via this Jupyter notebook:
9165502582/hello.ipynb
</details>
## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [x] Ran pre-commit to handle lint / formatting issues.
- [x] Read the [contributor
guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md),
Pull Request section?
- [x] Updated relevant documentation. (no relevant documentation it
seems)
- [x] Wrote necessary unit or integration tests.
This commit is contained in:
parent
ace8dd6087
commit
fdcc74fda2
4 changed files with 400 additions and 57 deletions
|
@ -7,6 +7,7 @@
|
||||||
import warnings
|
import warnings
|
||||||
from typing import AsyncIterator, List, Optional, Union
|
from typing import AsyncIterator, List, Optional, Union
|
||||||
|
|
||||||
|
import groq
|
||||||
from groq import Groq
|
from groq import Groq
|
||||||
from llama_models.datatypes import SamplingParams
|
from llama_models.datatypes import SamplingParams
|
||||||
from llama_models.llama3.api.datatypes import ToolDefinition, ToolPromptFormat
|
from llama_models.llama3.api.datatypes import ToolDefinition, ToolPromptFormat
|
||||||
|
@ -123,7 +124,16 @@ class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderD
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
response = self._get_client().chat.completions.create(**request)
|
try:
|
||||||
|
response = self._get_client().chat.completions.create(**request)
|
||||||
|
except groq.BadRequestError as e:
|
||||||
|
if e.body.get("error", {}).get("code") == "tool_use_failed":
|
||||||
|
# For smaller models, Groq may fail to call a tool even when the request is well formed
|
||||||
|
raise ValueError(
|
||||||
|
"Groq failed to call a tool", e.body.get("error", {})
|
||||||
|
) from e
|
||||||
|
else:
|
||||||
|
raise e
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
return convert_chat_completion_response_stream(response)
|
return convert_chat_completion_response_stream(response)
|
||||||
|
|
|
@ -4,6 +4,7 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import json
|
||||||
import warnings
|
import warnings
|
||||||
from typing import AsyncGenerator, Literal
|
from typing import AsyncGenerator, Literal
|
||||||
|
|
||||||
|
@ -14,14 +15,20 @@ from groq.types.chat.chat_completion_assistant_message_param import (
|
||||||
)
|
)
|
||||||
from groq.types.chat.chat_completion_chunk import ChatCompletionChunk
|
from groq.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||||
from groq.types.chat.chat_completion_message_param import ChatCompletionMessageParam
|
from groq.types.chat.chat_completion_message_param import ChatCompletionMessageParam
|
||||||
|
from groq.types.chat.chat_completion_message_tool_call import (
|
||||||
|
ChatCompletionMessageToolCall,
|
||||||
|
)
|
||||||
from groq.types.chat.chat_completion_system_message_param import (
|
from groq.types.chat.chat_completion_system_message_param import (
|
||||||
ChatCompletionSystemMessageParam,
|
ChatCompletionSystemMessageParam,
|
||||||
)
|
)
|
||||||
|
from groq.types.chat.chat_completion_tool_param import ChatCompletionToolParam
|
||||||
from groq.types.chat.chat_completion_user_message_param import (
|
from groq.types.chat.chat_completion_user_message_param import (
|
||||||
ChatCompletionUserMessageParam,
|
ChatCompletionUserMessageParam,
|
||||||
)
|
)
|
||||||
|
|
||||||
from groq.types.chat.completion_create_params import CompletionCreateParams
|
from groq.types.chat.completion_create_params import CompletionCreateParams
|
||||||
|
from groq.types.shared.function_definition import FunctionDefinition
|
||||||
|
|
||||||
|
from llama_models.llama3.api.datatypes import ToolParamDefinition
|
||||||
|
|
||||||
from llama_stack.apis.inference import (
|
from llama_stack.apis.inference import (
|
||||||
ChatCompletionRequest,
|
ChatCompletionRequest,
|
||||||
|
@ -32,6 +39,11 @@ from llama_stack.apis.inference import (
|
||||||
CompletionMessage,
|
CompletionMessage,
|
||||||
Message,
|
Message,
|
||||||
StopReason,
|
StopReason,
|
||||||
|
ToolCall,
|
||||||
|
ToolCallDelta,
|
||||||
|
ToolCallParseStatus,
|
||||||
|
ToolDefinition,
|
||||||
|
ToolPromptFormat,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -59,8 +71,8 @@ def convert_chat_completion_request(
|
||||||
# so we exclude it for now
|
# so we exclude it for now
|
||||||
warnings.warn("repetition_penalty is not supported")
|
warnings.warn("repetition_penalty is not supported")
|
||||||
|
|
||||||
if request.tools:
|
if request.tool_prompt_format != ToolPromptFormat.json:
|
||||||
warnings.warn("tools are not supported yet")
|
warnings.warn("tool_prompt_format is not used by Groq. Ignoring.")
|
||||||
|
|
||||||
return CompletionCreateParams(
|
return CompletionCreateParams(
|
||||||
model=request.model,
|
model=request.model,
|
||||||
|
@ -71,6 +83,8 @@ def convert_chat_completion_request(
|
||||||
max_tokens=request.sampling_params.max_tokens or None,
|
max_tokens=request.sampling_params.max_tokens or None,
|
||||||
temperature=request.sampling_params.temperature,
|
temperature=request.sampling_params.temperature,
|
||||||
top_p=request.sampling_params.top_p,
|
top_p=request.sampling_params.top_p,
|
||||||
|
tools=[_convert_groq_tool_definition(tool) for tool in request.tools or []],
|
||||||
|
tool_choice=request.tool_choice.value if request.tool_choice else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -87,17 +101,64 @@ def _convert_message(message: Message) -> ChatCompletionMessageParam:
|
||||||
raise ValueError(f"Invalid message role: {message.role}")
|
raise ValueError(f"Invalid message role: {message.role}")
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_groq_tool_definition(tool_definition: ToolDefinition) -> dict:
|
||||||
|
# Groq requires a description for function tools
|
||||||
|
if tool_definition.description is None:
|
||||||
|
raise AssertionError("tool_definition.description is required")
|
||||||
|
|
||||||
|
tool_parameters = tool_definition.parameters or {}
|
||||||
|
return ChatCompletionToolParam(
|
||||||
|
type="function",
|
||||||
|
function=FunctionDefinition(
|
||||||
|
name=tool_definition.tool_name,
|
||||||
|
description=tool_definition.description,
|
||||||
|
parameters={
|
||||||
|
key: _convert_groq_tool_parameter(param)
|
||||||
|
for key, param in tool_parameters.items()
|
||||||
|
},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_groq_tool_parameter(tool_parameter: ToolParamDefinition) -> dict:
|
||||||
|
param = {
|
||||||
|
"type": tool_parameter.param_type,
|
||||||
|
}
|
||||||
|
if tool_parameter.description is not None:
|
||||||
|
param["description"] = tool_parameter.description
|
||||||
|
if tool_parameter.required is not None:
|
||||||
|
param["required"] = tool_parameter.required
|
||||||
|
if tool_parameter.default is not None:
|
||||||
|
param["default"] = tool_parameter.default
|
||||||
|
return param
|
||||||
|
|
||||||
|
|
||||||
def convert_chat_completion_response(
|
def convert_chat_completion_response(
|
||||||
response: ChatCompletion,
|
response: ChatCompletion,
|
||||||
) -> ChatCompletionResponse:
|
) -> ChatCompletionResponse:
|
||||||
# groq only supports n=1 at time of writing, so there is only one choice
|
# groq only supports n=1 at time of writing, so there is only one choice
|
||||||
choice = response.choices[0]
|
choice = response.choices[0]
|
||||||
return ChatCompletionResponse(
|
if choice.finish_reason == "tool_calls":
|
||||||
completion_message=CompletionMessage(
|
tool_calls = [
|
||||||
content=choice.message.content,
|
_convert_groq_tool_call(tool_call)
|
||||||
stop_reason=_map_finish_reason_to_stop_reason(choice.finish_reason),
|
for tool_call in choice.message.tool_calls
|
||||||
),
|
]
|
||||||
)
|
return ChatCompletionResponse(
|
||||||
|
completion_message=CompletionMessage(
|
||||||
|
tool_calls=tool_calls,
|
||||||
|
stop_reason=StopReason.end_of_message,
|
||||||
|
# Content is not optional
|
||||||
|
content="",
|
||||||
|
),
|
||||||
|
logprobs=None,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return ChatCompletionResponse(
|
||||||
|
completion_message=CompletionMessage(
|
||||||
|
content=choice.message.content,
|
||||||
|
stop_reason=_map_finish_reason_to_stop_reason(choice.finish_reason),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _map_finish_reason_to_stop_reason(
|
def _map_finish_reason_to_stop_reason(
|
||||||
|
@ -116,7 +177,7 @@ def _map_finish_reason_to_stop_reason(
|
||||||
elif finish_reason == "length":
|
elif finish_reason == "length":
|
||||||
return StopReason.out_of_tokens
|
return StopReason.out_of_tokens
|
||||||
elif finish_reason == "tool_calls":
|
elif finish_reason == "tool_calls":
|
||||||
raise NotImplementedError("tool_calls is not supported yet")
|
return StopReason.end_of_message
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid finish reason: {finish_reason}")
|
raise ValueError(f"Invalid finish reason: {finish_reason}")
|
||||||
|
|
||||||
|
@ -129,25 +190,50 @@ async def convert_chat_completion_response_stream(
|
||||||
for chunk in stream:
|
for chunk in stream:
|
||||||
choice = chunk.choices[0]
|
choice = chunk.choices[0]
|
||||||
|
|
||||||
# We assume there's only one finish_reason for the entire stream.
|
|
||||||
# We collect the last finish_reason
|
|
||||||
if choice.finish_reason:
|
if choice.finish_reason:
|
||||||
stop_reason = _map_finish_reason_to_stop_reason(choice.finish_reason)
|
yield ChatCompletionResponseStreamChunk(
|
||||||
|
event=ChatCompletionResponseEvent(
|
||||||
yield ChatCompletionResponseStreamChunk(
|
event_type=ChatCompletionResponseEventType.complete,
|
||||||
event=ChatCompletionResponseEvent(
|
delta=choice.delta.content or "",
|
||||||
event_type=event_type,
|
logprobs=None,
|
||||||
delta=choice.delta.content or "",
|
stop_reason=_map_finish_reason_to_stop_reason(choice.finish_reason),
|
||||||
logprobs=None,
|
)
|
||||||
|
)
|
||||||
|
elif choice.delta.tool_calls:
|
||||||
|
# We assume there is only one tool call per chunk, but emit a warning in case we're wrong
|
||||||
|
if len(choice.delta.tool_calls) > 1:
|
||||||
|
warnings.warn(
|
||||||
|
"Groq returned multiple tool calls in one chunk. Using the first one, ignoring the rest."
|
||||||
|
)
|
||||||
|
|
||||||
|
# We assume Groq produces fully formed tool calls for each chunk
|
||||||
|
tool_call = _convert_groq_tool_call(choice.delta.tool_calls[0])
|
||||||
|
yield ChatCompletionResponseStreamChunk(
|
||||||
|
event=ChatCompletionResponseEvent(
|
||||||
|
event_type=event_type,
|
||||||
|
delta=ToolCallDelta(
|
||||||
|
content=tool_call,
|
||||||
|
parse_status=ToolCallParseStatus.success,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
yield ChatCompletionResponseStreamChunk(
|
||||||
|
event=ChatCompletionResponseEvent(
|
||||||
|
event_type=event_type,
|
||||||
|
delta=choice.delta.content or "",
|
||||||
|
logprobs=None,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
|
||||||
event_type = ChatCompletionResponseEventType.progress
|
event_type = ChatCompletionResponseEventType.progress
|
||||||
|
|
||||||
yield ChatCompletionResponseStreamChunk(
|
|
||||||
event=ChatCompletionResponseEvent(
|
def _convert_groq_tool_call(tool_call: ChatCompletionMessageToolCall) -> ToolCall:
|
||||||
event_type=ChatCompletionResponseEventType.complete,
|
return ToolCall(
|
||||||
delta="",
|
call_id=tool_call.id,
|
||||||
logprobs=None,
|
tool_name=tool_call.function.name,
|
||||||
stop_reason=stop_reason,
|
# Note that Groq may return a string that is not valid JSON here
|
||||||
)
|
# So this may raise a 500 error. Going to leave this as is to see
|
||||||
|
# how big of an issue this is and what we can do about it.
|
||||||
|
arguments=json.loads(tool_call.function.arguments),
|
||||||
)
|
)
|
||||||
|
|
|
@ -4,21 +4,33 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from groq.types.chat.chat_completion import ChatCompletion, Choice
|
from groq.types.chat.chat_completion import ChatCompletion, Choice
|
||||||
from groq.types.chat.chat_completion_chunk import (
|
from groq.types.chat.chat_completion_chunk import (
|
||||||
ChatCompletionChunk,
|
ChatCompletionChunk,
|
||||||
Choice as StreamChoice,
|
Choice as StreamChoice,
|
||||||
ChoiceDelta,
|
ChoiceDelta,
|
||||||
|
ChoiceDeltaToolCall,
|
||||||
|
ChoiceDeltaToolCallFunction,
|
||||||
)
|
)
|
||||||
from groq.types.chat.chat_completion_message import ChatCompletionMessage
|
from groq.types.chat.chat_completion_message import ChatCompletionMessage
|
||||||
|
from groq.types.chat.chat_completion_message_tool_call import (
|
||||||
|
ChatCompletionMessageToolCall,
|
||||||
|
Function,
|
||||||
|
)
|
||||||
|
from groq.types.shared.function_definition import FunctionDefinition
|
||||||
|
from llama_models.llama3.api.datatypes import ToolParamDefinition
|
||||||
from llama_stack.apis.inference import (
|
from llama_stack.apis.inference import (
|
||||||
ChatCompletionRequest,
|
ChatCompletionRequest,
|
||||||
ChatCompletionResponseEventType,
|
ChatCompletionResponseEventType,
|
||||||
CompletionMessage,
|
CompletionMessage,
|
||||||
StopReason,
|
StopReason,
|
||||||
SystemMessage,
|
SystemMessage,
|
||||||
|
ToolCall,
|
||||||
|
ToolChoice,
|
||||||
|
ToolDefinition,
|
||||||
UserMessage,
|
UserMessage,
|
||||||
)
|
)
|
||||||
from llama_stack.providers.remote.inference.groq.groq_utils import (
|
from llama_stack.providers.remote.inference.groq.groq_utils import (
|
||||||
|
@ -140,12 +152,6 @@ class TestConvertChatCompletionRequest:
|
||||||
|
|
||||||
assert converted["max_tokens"] == 100
|
assert converted["max_tokens"] == 100
|
||||||
|
|
||||||
def _dummy_chat_completion_request(self):
|
|
||||||
return ChatCompletionRequest(
|
|
||||||
model="Llama-3.2-3B",
|
|
||||||
messages=[UserMessage(content="Hello World")],
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_includes_temperature(self):
|
def test_includes_temperature(self):
|
||||||
request = self._dummy_chat_completion_request()
|
request = self._dummy_chat_completion_request()
|
||||||
request.sampling_params.temperature = 0.5
|
request.sampling_params.temperature = 0.5
|
||||||
|
@ -162,6 +168,112 @@ class TestConvertChatCompletionRequest:
|
||||||
|
|
||||||
assert converted["top_p"] == 0.95
|
assert converted["top_p"] == 0.95
|
||||||
|
|
||||||
|
def test_includes_tool_choice(self):
|
||||||
|
request = self._dummy_chat_completion_request()
|
||||||
|
request.tool_choice = ToolChoice.required
|
||||||
|
|
||||||
|
converted = convert_chat_completion_request(request)
|
||||||
|
|
||||||
|
assert converted["tool_choice"] == "required"
|
||||||
|
|
||||||
|
def test_includes_tools(self):
|
||||||
|
request = self._dummy_chat_completion_request()
|
||||||
|
request.tools = [
|
||||||
|
ToolDefinition(
|
||||||
|
tool_name="get_flight_info",
|
||||||
|
description="Get fight information between two destinations.",
|
||||||
|
parameters={
|
||||||
|
"origin": ToolParamDefinition(
|
||||||
|
param_type="string",
|
||||||
|
description="The origin airport code. E.g., AU",
|
||||||
|
required=True,
|
||||||
|
),
|
||||||
|
"destination": ToolParamDefinition(
|
||||||
|
param_type="string",
|
||||||
|
description="The destination airport code. E.g., 'LAX'",
|
||||||
|
required=True,
|
||||||
|
),
|
||||||
|
"passengers": ToolParamDefinition(
|
||||||
|
param_type="array",
|
||||||
|
description="The passengers",
|
||||||
|
required=False,
|
||||||
|
),
|
||||||
|
},
|
||||||
|
),
|
||||||
|
ToolDefinition(
|
||||||
|
tool_name="log",
|
||||||
|
description="Calulate the logarithm of a number",
|
||||||
|
parameters={
|
||||||
|
"number": ToolParamDefinition(
|
||||||
|
param_type="float",
|
||||||
|
description="The number to calculate the logarithm of",
|
||||||
|
required=True,
|
||||||
|
),
|
||||||
|
"base": ToolParamDefinition(
|
||||||
|
param_type="integer",
|
||||||
|
description="The base of the logarithm",
|
||||||
|
required=False,
|
||||||
|
default=10,
|
||||||
|
),
|
||||||
|
},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
converted = convert_chat_completion_request(request)
|
||||||
|
|
||||||
|
assert converted["tools"] == [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": FunctionDefinition(
|
||||||
|
name="get_flight_info",
|
||||||
|
description="Get fight information between two destinations.",
|
||||||
|
parameters={
|
||||||
|
"origin": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The origin airport code. E.g., AU",
|
||||||
|
"required": True,
|
||||||
|
},
|
||||||
|
"destination": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The destination airport code. E.g., 'LAX'",
|
||||||
|
"required": True,
|
||||||
|
},
|
||||||
|
"passengers": {
|
||||||
|
"type": "array",
|
||||||
|
"description": "The passengers",
|
||||||
|
"required": False,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": FunctionDefinition(
|
||||||
|
name="log",
|
||||||
|
description="Calulate the logarithm of a number",
|
||||||
|
parameters={
|
||||||
|
"number": {
|
||||||
|
"type": "float",
|
||||||
|
"description": "The number to calculate the logarithm of",
|
||||||
|
"required": True,
|
||||||
|
},
|
||||||
|
"base": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "The base of the logarithm",
|
||||||
|
"required": False,
|
||||||
|
"default": 10,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
def _dummy_chat_completion_request(self):
|
||||||
|
return ChatCompletionRequest(
|
||||||
|
model="Llama-3.2-3B",
|
||||||
|
messages=[UserMessage(content="Hello World")],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestConvertNonStreamChatCompletionResponse:
|
class TestConvertNonStreamChatCompletionResponse:
|
||||||
def test_returns_response(self):
|
def test_returns_response(self):
|
||||||
|
@ -188,6 +300,49 @@ class TestConvertNonStreamChatCompletionResponse:
|
||||||
|
|
||||||
assert converted.completion_message.stop_reason == StopReason.out_of_tokens
|
assert converted.completion_message.stop_reason == StopReason.out_of_tokens
|
||||||
|
|
||||||
|
def test_maps_tool_call_to_end_of_message(self):
|
||||||
|
response = self._dummy_chat_completion_response_with_tool_call()
|
||||||
|
|
||||||
|
converted = convert_chat_completion_response(response)
|
||||||
|
|
||||||
|
assert converted.completion_message.stop_reason == StopReason.end_of_message
|
||||||
|
|
||||||
|
def test_converts_multiple_tool_calls(self):
|
||||||
|
response = self._dummy_chat_completion_response_with_tool_call()
|
||||||
|
response.choices[0].message.tool_calls = [
|
||||||
|
ChatCompletionMessageToolCall(
|
||||||
|
id="tool_call_id",
|
||||||
|
type="function",
|
||||||
|
function=Function(
|
||||||
|
name="get_flight_info",
|
||||||
|
arguments='{"origin": "AU", "destination": "LAX"}',
|
||||||
|
),
|
||||||
|
),
|
||||||
|
ChatCompletionMessageToolCall(
|
||||||
|
id="tool_call_id_2",
|
||||||
|
type="function",
|
||||||
|
function=Function(
|
||||||
|
name="log",
|
||||||
|
arguments='{"number": 10, "base": 2}',
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
converted = convert_chat_completion_response(response)
|
||||||
|
|
||||||
|
assert converted.completion_message.tool_calls == [
|
||||||
|
ToolCall(
|
||||||
|
call_id="tool_call_id",
|
||||||
|
tool_name="get_flight_info",
|
||||||
|
arguments={"origin": "AU", "destination": "LAX"},
|
||||||
|
),
|
||||||
|
ToolCall(
|
||||||
|
call_id="tool_call_id_2",
|
||||||
|
tool_name="log",
|
||||||
|
arguments={"number": 10, "base": 2},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
def _dummy_chat_completion_response(self):
|
def _dummy_chat_completion_response(self):
|
||||||
return ChatCompletion(
|
return ChatCompletion(
|
||||||
id="chatcmpl-123",
|
id="chatcmpl-123",
|
||||||
|
@ -205,6 +360,33 @@ class TestConvertNonStreamChatCompletionResponse:
|
||||||
object="chat.completion",
|
object="chat.completion",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _dummy_chat_completion_response_with_tool_call(self):
|
||||||
|
return ChatCompletion(
|
||||||
|
id="chatcmpl-123",
|
||||||
|
model="Llama-3.2-3B",
|
||||||
|
choices=[
|
||||||
|
Choice(
|
||||||
|
index=0,
|
||||||
|
message=ChatCompletionMessage(
|
||||||
|
role="assistant",
|
||||||
|
tool_calls=[
|
||||||
|
ChatCompletionMessageToolCall(
|
||||||
|
id="tool_call_id",
|
||||||
|
type="function",
|
||||||
|
function=Function(
|
||||||
|
name="get_flight_info",
|
||||||
|
arguments='{"origin": "AU", "destination": "LAX"}',
|
||||||
|
),
|
||||||
|
)
|
||||||
|
],
|
||||||
|
),
|
||||||
|
finish_reason="tool_calls",
|
||||||
|
)
|
||||||
|
],
|
||||||
|
created=1729382400,
|
||||||
|
object="chat.completion",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestConvertStreamChatCompletionResponse:
|
class TestConvertStreamChatCompletionResponse:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
@ -214,10 +396,6 @@ class TestConvertStreamChatCompletionResponse:
|
||||||
for i, message in enumerate(messages):
|
for i, message in enumerate(messages):
|
||||||
chunk = self._dummy_chat_completion_chunk()
|
chunk = self._dummy_chat_completion_chunk()
|
||||||
chunk.choices[0].delta.content = message
|
chunk.choices[0].delta.content = message
|
||||||
if i == len(messages) - 1:
|
|
||||||
chunk.choices[0].finish_reason = "stop"
|
|
||||||
else:
|
|
||||||
chunk.choices[0].finish_reason = None
|
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
chunk = self._dummy_chat_completion_chunk()
|
chunk = self._dummy_chat_completion_chunk()
|
||||||
|
@ -241,12 +419,6 @@ class TestConvertStreamChatCompletionResponse:
|
||||||
assert chunk.event.event_type == ChatCompletionResponseEventType.progress
|
assert chunk.event.event_type == ChatCompletionResponseEventType.progress
|
||||||
assert chunk.event.delta == " !"
|
assert chunk.event.delta == " !"
|
||||||
|
|
||||||
# Dummy chunk to ensure the last chunk is really the end of the stream
|
|
||||||
# This one technically maps to Groq's final "stop" chunk
|
|
||||||
chunk = await iter.__anext__()
|
|
||||||
assert chunk.event.event_type == ChatCompletionResponseEventType.progress
|
|
||||||
assert chunk.event.delta == ""
|
|
||||||
|
|
||||||
chunk = await iter.__anext__()
|
chunk = await iter.__anext__()
|
||||||
assert chunk.event.event_type == ChatCompletionResponseEventType.complete
|
assert chunk.event.event_type == ChatCompletionResponseEventType.complete
|
||||||
assert chunk.event.delta == ""
|
assert chunk.event.delta == ""
|
||||||
|
@ -255,6 +427,53 @@ class TestConvertStreamChatCompletionResponse:
|
||||||
with pytest.raises(StopAsyncIteration):
|
with pytest.raises(StopAsyncIteration):
|
||||||
await iter.__anext__()
|
await iter.__anext__()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_returns_tool_calls_stream(self):
|
||||||
|
def tool_call_stream():
|
||||||
|
tool_calls = [
|
||||||
|
ToolCall(
|
||||||
|
call_id="tool_call_id",
|
||||||
|
tool_name="get_flight_info",
|
||||||
|
arguments={"origin": "AU", "destination": "LAX"},
|
||||||
|
),
|
||||||
|
ToolCall(
|
||||||
|
call_id="tool_call_id_2",
|
||||||
|
tool_name="log",
|
||||||
|
arguments={"number": 10, "base": 2},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
for i, tool_call in enumerate(tool_calls):
|
||||||
|
chunk = self._dummy_chat_completion_chunk_with_tool_call()
|
||||||
|
chunk.choices[0].delta.tool_calls = [
|
||||||
|
ChoiceDeltaToolCall(
|
||||||
|
index=0,
|
||||||
|
type="function",
|
||||||
|
id=tool_call.call_id,
|
||||||
|
function=ChoiceDeltaToolCallFunction(
|
||||||
|
name=tool_call.tool_name,
|
||||||
|
arguments=json.dumps(tool_call.arguments),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
chunk = self._dummy_chat_completion_chunk_with_tool_call()
|
||||||
|
chunk.choices[0].delta.content = None
|
||||||
|
chunk.choices[0].finish_reason = "stop"
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
stream = tool_call_stream()
|
||||||
|
converted = convert_chat_completion_response_stream(stream)
|
||||||
|
|
||||||
|
iter = converted.__aiter__()
|
||||||
|
chunk = await iter.__anext__()
|
||||||
|
assert chunk.event.event_type == ChatCompletionResponseEventType.start
|
||||||
|
assert chunk.event.delta.content == ToolCall(
|
||||||
|
call_id="tool_call_id",
|
||||||
|
tool_name="get_flight_info",
|
||||||
|
arguments={"origin": "AU", "destination": "LAX"},
|
||||||
|
)
|
||||||
|
|
||||||
def _dummy_chat_completion_chunk(self):
|
def _dummy_chat_completion_chunk(self):
|
||||||
return ChatCompletionChunk(
|
return ChatCompletionChunk(
|
||||||
id="chatcmpl-123",
|
id="chatcmpl-123",
|
||||||
|
@ -269,3 +488,31 @@ class TestConvertStreamChatCompletionResponse:
|
||||||
object="chat.completion.chunk",
|
object="chat.completion.chunk",
|
||||||
x_groq=None,
|
x_groq=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _dummy_chat_completion_chunk_with_tool_call(self):
|
||||||
|
return ChatCompletionChunk(
|
||||||
|
id="chatcmpl-123",
|
||||||
|
model="Llama-3.2-3B",
|
||||||
|
choices=[
|
||||||
|
StreamChoice(
|
||||||
|
index=0,
|
||||||
|
delta=ChoiceDelta(
|
||||||
|
role="assistant",
|
||||||
|
content="Hello World",
|
||||||
|
tool_calls=[
|
||||||
|
ChoiceDeltaToolCall(
|
||||||
|
index=0,
|
||||||
|
type="function",
|
||||||
|
function=ChoiceDeltaToolCallFunction(
|
||||||
|
name="get_flight_info",
|
||||||
|
arguments='{"origin": "AU", "destination": "LAX"}',
|
||||||
|
),
|
||||||
|
)
|
||||||
|
],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
],
|
||||||
|
created=1729382400,
|
||||||
|
object="chat.completion.chunk",
|
||||||
|
x_groq=None,
|
||||||
|
)
|
||||||
|
|
|
@ -375,13 +375,13 @@ class TestInference:
|
||||||
):
|
):
|
||||||
inference_impl, _ = inference_stack
|
inference_impl, _ = inference_stack
|
||||||
provider = inference_impl.routing_table.get_provider_impl(inference_model)
|
provider = inference_impl.routing_table.get_provider_impl(inference_model)
|
||||||
if provider.__provider_spec__.provider_type in ("remote::groq",):
|
if (
|
||||||
pytest.skip(
|
provider.__provider_spec__.provider_type == "remote::groq"
|
||||||
provider.__provider_spec__.provider_type
|
and "Llama-3.2" in inference_model
|
||||||
+ " doesn't support tool calling yet"
|
):
|
||||||
)
|
# TODO(aidand): Remove this skip once Groq's tool calling for Llama3.2 works better
|
||||||
|
pytest.skip("Groq's tool calling for Llama3.2 doesn't work very well")
|
||||||
|
|
||||||
inference_impl, _ = inference_stack
|
|
||||||
messages = sample_messages + [
|
messages = sample_messages + [
|
||||||
UserMessage(
|
UserMessage(
|
||||||
content="What's the weather like in San Francisco?",
|
content="What's the weather like in San Francisco?",
|
||||||
|
@ -422,11 +422,12 @@ class TestInference:
|
||||||
):
|
):
|
||||||
inference_impl, _ = inference_stack
|
inference_impl, _ = inference_stack
|
||||||
provider = inference_impl.routing_table.get_provider_impl(inference_model)
|
provider = inference_impl.routing_table.get_provider_impl(inference_model)
|
||||||
if provider.__provider_spec__.provider_type in ("remote::groq",):
|
if (
|
||||||
pytest.skip(
|
provider.__provider_spec__.provider_type == "remote::groq"
|
||||||
provider.__provider_spec__.provider_type
|
and "Llama-3.2" in inference_model
|
||||||
+ " doesn't support tool calling yet"
|
):
|
||||||
)
|
# TODO(aidand): Remove this skip once Groq's tool calling for Llama3.2 works better
|
||||||
|
pytest.skip("Groq's tool calling for Llama3.2 doesn't work very well")
|
||||||
|
|
||||||
messages = sample_messages + [
|
messages = sample_messages + [
|
||||||
UserMessage(
|
UserMessage(
|
||||||
|
@ -444,7 +445,6 @@ class TestInference:
|
||||||
**common_params,
|
**common_params,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
assert len(response) > 0
|
assert len(response) > 0
|
||||||
assert all(
|
assert all(
|
||||||
isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response
|
isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue