llama-stack-mirror/llama_stack/providers/remote/inference/groq/groq_utils.py
Aidan Do 39c34dd25f
[#432] Groq Provider tool call tweaks (#811)
# What does this PR do?

Follow-up addressing @ashwinb's review comments in
https://github.com/meta-llama/llama-stack/pull/630

- [x] Contributes to issue (#432)


## Test Plan
<details>
<summary>Environment</summary>

```shell
export GROQ_API_KEY=<api-key>

# Create environment if not already
conda create --name llamastack-groq python=3.10
conda activate llamastack-groq

wget https://raw.githubusercontent.com/aidando73/llama-stack/9165502582cd7cb178bc1dcf89955b45768ab6c1/build.yaml
wget https://raw.githubusercontent.com/meta-llama/llama-stack/918172c7fa92522c9ebc586bdb4f386b1d9ea224/run.yaml

# Build
pip install -e . && llama stack build --config ./build.yaml --image-type conda

# Activate built environment
conda activate llamastack-groq

# Test deps
pip install pytest pytest_html pytest_asyncio
```
</details>



<details>
<summary>Unit tests</summary>

```shell
# Setup
conda activate llamastack-groq
pytest llama_stack/providers/tests/inference/groq/test_groq_utils.py -vv -k groq -s

# Result
llama_stack/providers/tests/inference/groq/test_groq_utils.py .......................

========================================= 23 passed, 11 warnings in 0.06s =========================================
```
</details>

<details>
<summary>Integration tests</summary>

```shell
# Tests
 pytest llama_stack/providers/tests/inference/test_text_inference.py -k groq -s

# Results
___________________________ TestInference.test_chat_completion_with_tool_calling[-groq] ___________________________
llama_stack/providers/tests/inference/test_text_inference.py:403: in test_chat_completion_with_tool_calling
    assert len(message.tool_calls) > 0
E   assert 0 > 0
E    +  where 0 = len([])
E    +    where [] = CompletionMessage(role='assistant', content='<function=get_weather>{"location": "San Francisco, CA"}', stop_reason=<StopReason.end_of_turn: 'end_of_turn'>, tool_calls=[]).tool_calls
============================================= short test summary info =============================================
FAILED llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_with_tool_calling[-groq] - assert 0 > 0
======================== 1 failed, 3 passed, 5 skipped, 99 deselected, 7 warnings in 2.13s ========================
```

(One failure, as expected with the 3.2 3B model - see
https://github.com/meta-llama/llama-stack/pull/630#discussion_r1914056503)
</details>

## Sources

Please link relevant resources if necessary.


## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [x] Ran pre-commit to handle lint / formatting issues.
- [ ] Read the [contributor
guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md),
      Pull Request section?
- [ ] Updated relevant documentation.
- [x] Wrote necessary unit or integration tests.

Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
2025-01-29 12:02:12 -08:00

293 lines
11 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import warnings
from typing import AsyncGenerator, Literal, Union
from groq import Stream
from groq.types.chat.chat_completion import ChatCompletion
from groq.types.chat.chat_completion_assistant_message_param import (
ChatCompletionAssistantMessageParam,
)
from groq.types.chat.chat_completion_chunk import ChatCompletionChunk
from groq.types.chat.chat_completion_message_param import ChatCompletionMessageParam
from groq.types.chat.chat_completion_message_tool_call import (
ChatCompletionMessageToolCall,
)
from groq.types.chat.chat_completion_system_message_param import (
ChatCompletionSystemMessageParam,
)
from groq.types.chat.chat_completion_tool_param import ChatCompletionToolParam
from groq.types.chat.chat_completion_user_message_param import (
ChatCompletionUserMessageParam,
)
from groq.types.chat.completion_create_params import CompletionCreateParams
from groq.types.shared.function_definition import FunctionDefinition
from llama_models.llama3.api.datatypes import ToolParamDefinition
from pydantic import BaseModel
from llama_stack.apis.common.content_types import (
TextDelta,
ToolCallDelta,
ToolCallParseStatus,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseEvent,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
CompletionMessage,
Message,
StopReason,
ToolCall,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_strategy_options,
)
def convert_chat_completion_request(
    request: ChatCompletionRequest,
) -> CompletionCreateParams:
    """
    Build the Groq chat-completion parameters for a ChatCompletionRequest.

    Request features that are not forwarded to Groq never raise here; each one
    only emits a warning so the client knows it was ignored.
    """
    sampling_params = request.sampling_params

    if request.logprobs:
        # At the time of writing Groq has no logprobs support
        warnings.warn("logprobs are not supported yet")

    if request.response_format:
        # At the time of writing Groq's JSON mode is still in beta
        warnings.warn("response_format is not supported yet")

    if sampling_params.repetition_penalty != 1.0:
        # Groq does offer frequency_penalty, but its semantics appear to differ
        # from sampling_params.repetition_penalty:
        #   - frequency_penalty defaults to 0 and is a float between -2.0 and 2.0
        #   - repetition_penalty defaults to 1 and is usually set between 1.0 and 2.0
        # so no attempt is made to translate one into the other for now.
        warnings.warn("repetition_penalty is not supported")

    if request.tool_prompt_format != ToolPromptFormat.json:
        warnings.warn("tool_prompt_format is not used by Groq. Ignoring.")

    strategy_options = get_sampling_strategy_options(sampling_params)
    groq_messages = [_convert_message(item) for item in request.messages]
    groq_tools = [
        _convert_groq_tool_definition(item) for item in (request.tools or [])
    ]
    if request.tool_choice:
        groq_tool_choice = request.tool_choice.value
    else:
        groq_tool_choice = None

    return CompletionCreateParams(
        model=request.model,
        messages=groq_messages,
        logprobs=None,
        frequency_penalty=None,
        stream=request.stream,
        max_tokens=sampling_params.max_tokens or None,
        temperature=strategy_options.get("temperature", 1.0),
        top_p=strategy_options.get("top_p", 1.0),
        tools=groq_tools,
        tool_choice=groq_tool_choice,
    )
def _convert_message(message: Message) -> ChatCompletionMessageParam:
    """
    Convert a single system, user or assistant message to its Groq message param.

    Only the role and content are carried over.

    Raises:
        ValueError: if the message role is not system, user or assistant.
    """
    param_types = (
        ("system", ChatCompletionSystemMessageParam),
        ("user", ChatCompletionUserMessageParam),
        ("assistant", ChatCompletionAssistantMessageParam),
    )
    for role, param_type in param_types:
        if message.role == role:
            return param_type(role=role, content=message.content)
    raise ValueError(f"Invalid message role: {message.role}")
def _convert_groq_tool_definition(tool_definition: ToolDefinition) -> dict:
    """
    Convert a ToolDefinition into a Groq function-tool param.

    Each entry of the tool's parameters (if any) is converted with
    _convert_groq_tool_parameter and stored under the same key.

    Raises:
        AssertionError: if the tool definition has no description.
    """
    description = tool_definition.description
    if description is None:
        # Groq requires a description for function tools
        raise AssertionError("tool_definition.description is required")

    groq_parameters = {}
    for name, definition in (tool_definition.parameters or {}).items():
        groq_parameters[name] = _convert_groq_tool_parameter(definition)

    function = FunctionDefinition(
        name=tool_definition.tool_name,
        description=description,
        parameters=groq_parameters,
    )
    return ChatCompletionToolParam(type="function", function=function)
def _convert_groq_tool_parameter(tool_parameter: ToolParamDefinition) -> dict:
    """
    Describe one tool parameter as a plain dict.

    "type" is always present; "description", "required" and "default" are
    included only when the corresponding attribute is not None.
    """
    schema = {"type": tool_parameter.param_type}
    optional_fields = {
        "description": tool_parameter.description,
        "required": tool_parameter.required,
        "default": tool_parameter.default,
    }
    schema.update(
        (key, value) for key, value in optional_fields.items() if value is not None
    )
    return schema
def convert_chat_completion_response(
    response: ChatCompletion,
) -> ChatCompletionResponse:
    """
    Convert a non-streaming Groq ChatCompletion into a ChatCompletionResponse.

    Only the first choice is used. When the model finished with tool calls,
    each call is converted with _convert_groq_tool_call:
    - if every call parsed, they are returned in `tool_calls` with empty content
    - if any call could not be parsed, all calls are serialized to a JSON
      string and returned as the message content instead
    Otherwise the message text is returned with the mapped stop reason.
    """
    # groq only supports n=1 at time of writing, so there is only one choice
    choice = response.choices[0]

    if choice.finish_reason != "tool_calls":
        return ChatCompletionResponse(
            completion_message=CompletionMessage(
                # CompletionMessage content is not optional, so a missing (None)
                # message content is normalized to an empty string, matching
                # what the streaming converter does for delta content.
                content=choice.message.content or "",
                stop_reason=_map_finish_reason_to_stop_reason(choice.finish_reason),
            ),
        )

    # Guard against a "tool_calls" finish reason that carries no tool call list.
    tool_calls = [
        _convert_groq_tool_call(tool_call)
        for tool_call in choice.message.tool_calls or []
    ]

    if any(isinstance(tool_call, UnparseableToolCall) for tool_call in tool_calls):
        # If we couldn't parse a tool call, jsonify the tool calls and return them
        return ChatCompletionResponse(
            completion_message=CompletionMessage(
                stop_reason=StopReason.end_of_message,
                content=json.dumps(tool_calls, default=lambda x: x.model_dump()),
            ),
            logprobs=None,
        )

    # Otherwise, return tool calls as normal
    return ChatCompletionResponse(
        completion_message=CompletionMessage(
            tool_calls=tool_calls,
            stop_reason=StopReason.end_of_message,
            # Content is not optional
            content="",
        ),
        logprobs=None,
    )
def _map_finish_reason_to_stop_reason(
    finish_reason: Literal["stop", "length", "tool_calls"],
) -> StopReason:
    """
    Translate a Groq chat completion finish_reason into a StopReason.

    Mapping:
    - "stop" (natural stop point or a provided stop sequence) -> end_of_turn
    - "length" (requested maximum number of tokens reached) -> out_of_tokens
    - "tool_calls" (model called a tool) -> end_of_message

    Raises:
        ValueError: for any other finish_reason.
    """
    known_reasons = (
        ("stop", StopReason.end_of_turn),
        ("length", StopReason.out_of_tokens),
        ("tool_calls", StopReason.end_of_message),
    )
    for groq_reason, stop_reason in known_reasons:
        if finish_reason == groq_reason:
            return stop_reason
    raise ValueError(f"Invalid finish reason: {finish_reason}")
async def convert_chat_completion_response_stream(
    stream: Stream[ChatCompletionChunk],
) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
    """
    Convert a stream of Groq chunks into ChatCompletionResponseStreamChunks.

    Exactly one event is yielded per chunk, based on the chunk's first choice:
    - a chunk with a finish_reason yields a `complete` text event carrying the
      mapped stop reason
    - a chunk with tool calls yields a tool-call event for the first tool call,
      marked succeeded or failed depending on whether its arguments parsed
    - any other chunk yields a text event
    Non-`complete` events are tagged `start` for the first chunk and
    `progress` for every chunk after it.
    """
    # NOTE(review): the stream is consumed with a plain `for` inside an async
    # generator - confirm callers are fine with synchronous iteration here.
    pending_event_type = ChatCompletionResponseEventType.start
    for chunk in stream:
        choice = chunk.choices[0]

        if choice.finish_reason:
            event = ChatCompletionResponseEvent(
                event_type=ChatCompletionResponseEventType.complete,
                delta=TextDelta(text=choice.delta.content or ""),
                logprobs=None,
                stop_reason=_map_finish_reason_to_stop_reason(choice.finish_reason),
            )
        elif choice.delta.tool_calls:
            groq_tool_calls = choice.delta.tool_calls
            # One tool call per chunk is assumed; warn if that turns out to be wrong
            if len(groq_tool_calls) > 1:
                warnings.warn(
                    "Groq returned multiple tool calls in one chunk. Using the first one, ignoring the rest."
                )

            # Each chunk is assumed to hold a fully formed tool call
            converted = _convert_groq_tool_call(groq_tool_calls[0])
            if isinstance(converted, ToolCall):
                payload = converted
                status = ToolCallParseStatus.succeeded
            else:
                # An UnparseableToolCall is passed on as its raw JSON dump
                payload = converted.model_dump_json()
                status = ToolCallParseStatus.failed
            event = ChatCompletionResponseEvent(
                event_type=pending_event_type,
                delta=ToolCallDelta(tool_call=payload, parse_status=status),
            )
        else:
            event = ChatCompletionResponseEvent(
                event_type=pending_event_type,
                delta=TextDelta(text=choice.delta.content or ""),
                logprobs=None,
            )

        yield ChatCompletionResponseStreamChunk(event=event)
        pending_event_type = ChatCompletionResponseEventType.progress
class UnparseableToolCall(BaseModel):
    """
    Stand-in for a ToolCall whose arguments could not be parsed as JSON.

    It has the same fields as a ToolCall, except that the arguments are kept
    as the original string.
    """

    # Field order is kept as-is: it determines the key order when dumped.
    call_id: str
    tool_name: str
    # Raw, unparsed argument text
    arguments: str
def _convert_groq_tool_call(
    tool_call: ChatCompletionMessageToolCall,
) -> Union[ToolCall, UnparseableToolCall]:
    """
    Convert a Groq tool call to a ToolCall.

    The function arguments arrive as a JSON string and are decoded here.

    Returns:
        A ToolCall with the decoded arguments, or an UnparseableToolCall
        carrying the raw argument string when the arguments cannot be decoded.
    """
    function = tool_call.function
    raw_arguments = function.arguments
    try:
        arguments = json.loads(raw_arguments)
    except (ValueError, TypeError, RecursionError):
        # Only the failures json.loads produces for bad input are handled:
        # ValueError covers json.JSONDecodeError (malformed JSON), TypeError a
        # non-string payload, RecursionError pathologically nested input.
        return UnparseableToolCall(
            call_id=tool_call.id,
            tool_name=function.name,
            arguments=raw_arguments,
        )

    # NOTE(review): JSON that decodes to a non-object (e.g. "null" or "[]") is
    # passed through unchanged - confirm ToolCall accepts such arguments.
    return ToolCall(
        call_id=tool_call.id,
        tool_name=function.name,
        arguments=arguments,
    )