mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-17 21:39:48 +00:00
Add tool calls to groq inference adapter
This commit is contained in:
parent
78912e663b
commit
cf87262e9c
4 changed files with 400 additions and 60 deletions
|
|
@ -7,6 +7,7 @@
|
|||
import warnings
|
||||
from typing import AsyncIterator, List, Optional, Union
|
||||
|
||||
import groq
|
||||
from groq import Groq
|
||||
from llama_models.datatypes import SamplingParams
|
||||
from llama_models.llama3.api.datatypes import (
|
||||
|
|
@ -126,7 +127,14 @@ class GroqInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
)
|
||||
)
|
||||
|
||||
response = self._client.chat.completions.create(**request)
|
||||
try:
|
||||
response = self._client.chat.completions.create(**request)
|
||||
except groq.BadRequestError as e:
|
||||
if e.body.get("error", {}).get("code") == "tool_use_failed":
|
||||
# For smaller models, Groq may fail to call a tool even when the request is well formed
|
||||
raise ValueError("Groq failed to call a tool", e.body.get("error", {}))
|
||||
else:
|
||||
raise e
|
||||
|
||||
if stream:
|
||||
return convert_chat_completion_response_stream(response)
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
import warnings
|
||||
from typing import AsyncGenerator, Generator, Literal
|
||||
|
||||
import json
|
||||
from groq import Stream
|
||||
from groq.types.chat.chat_completion import ChatCompletion
|
||||
from groq.types.chat.chat_completion_assistant_message_param import (
|
||||
|
|
@ -20,9 +20,13 @@ from groq.types.chat.chat_completion_system_message_param import (
|
|||
from groq.types.chat.chat_completion_user_message_param import (
|
||||
ChatCompletionUserMessageParam,
|
||||
)
|
||||
|
||||
from groq.types.chat.completion_create_params import CompletionCreateParams
|
||||
|
||||
from groq.types.chat.chat_completion_message_tool_call import (
|
||||
ChatCompletionMessageToolCall,
|
||||
)
|
||||
from groq.types.chat.chat_completion_tool_param import ChatCompletionToolParam
|
||||
from groq.types.shared.function_definition import FunctionDefinition
|
||||
from groq.types.shared.function_parameters import FunctionParameters
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
|
|
@ -33,9 +37,14 @@ from llama_stack.apis.inference import (
|
|||
Message,
|
||||
Role,
|
||||
StopReason,
|
||||
ToolCall,
|
||||
ToolDefinition,
|
||||
ToolParamDefinition,
|
||||
ToolCallParseStatus,
|
||||
ToolCallDelta,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
|
||||
|
||||
def convert_chat_completion_request(
|
||||
request: ChatCompletionRequest,
|
||||
) -> CompletionCreateParams:
|
||||
|
|
@ -60,8 +69,8 @@ def convert_chat_completion_request(
|
|||
# so we exclude it for now
|
||||
warnings.warn("repetition_penalty is not supported")
|
||||
|
||||
if request.tools:
|
||||
warnings.warn("tools are not supported yet")
|
||||
if request.tool_prompt_format != ToolPromptFormat.json:
|
||||
warnings.warn("tool_prompt_format is not used by Groq. Ignoring.")
|
||||
|
||||
return CompletionCreateParams(
|
||||
model=request.model,
|
||||
|
|
@ -72,9 +81,10 @@ def convert_chat_completion_request(
|
|||
max_tokens=request.sampling_params.max_tokens or None,
|
||||
temperature=request.sampling_params.temperature,
|
||||
top_p=request.sampling_params.top_p,
|
||||
tools=[_convert_groq_tool_definition(tool) for tool in request.tools or []],
|
||||
tool_choice=request.tool_choice.value if request.tool_choice else None,
|
||||
)
|
||||
|
||||
|
||||
def _convert_message(message: Message) -> ChatCompletionMessageParam:
|
||||
if message.role == Role.system.value:
|
||||
return ChatCompletionSystemMessageParam(role="system", content=message.content)
|
||||
|
|
@ -88,17 +98,67 @@ def _convert_message(message: Message) -> ChatCompletionMessageParam:
|
|||
raise ValueError(f"Invalid message role: {message.role}")
|
||||
|
||||
|
||||
|
||||
def _convert_groq_tool_definition(tool_definition: ToolDefinition) -> dict:
|
||||
# Groq requires a description for function tools
|
||||
if tool_definition.description is None:
|
||||
raise AssertionError("tool_definition.description is required")
|
||||
|
||||
tool_parameters = tool_definition.parameters or {}
|
||||
return ChatCompletionToolParam(
|
||||
type="function",
|
||||
function=FunctionDefinition(
|
||||
name=tool_definition.tool_name,
|
||||
description=tool_definition.description,
|
||||
parameters={
|
||||
key: _convert_groq_tool_parameter(param)
|
||||
for key, param in tool_parameters.items()
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _convert_groq_tool_parameter(
|
||||
tool_parameter: ToolParamDefinition
|
||||
) -> dict:
|
||||
param = {
|
||||
"type": tool_parameter.param_type,
|
||||
}
|
||||
if tool_parameter.description is not None:
|
||||
param["description"] = tool_parameter.description
|
||||
if tool_parameter.required is not None:
|
||||
param["required"] = tool_parameter.required
|
||||
if tool_parameter.default is not None:
|
||||
param["default"] = tool_parameter.default
|
||||
return param
|
||||
|
||||
|
||||
def convert_chat_completion_response(
|
||||
response: ChatCompletion,
|
||||
) -> ChatCompletionResponse:
|
||||
# groq only supports n=1 at time of writing, so there is only one choice
|
||||
choice = response.choices[0]
|
||||
return ChatCompletionResponse(
|
||||
completion_message=CompletionMessage(
|
||||
content=choice.message.content,
|
||||
stop_reason=_map_finish_reason_to_stop_reason(choice.finish_reason),
|
||||
),
|
||||
)
|
||||
if choice.finish_reason == "tool_calls":
|
||||
tool_calls = [
|
||||
_convert_groq_tool_call(tool_call)
|
||||
for tool_call in choice.message.tool_calls
|
||||
]
|
||||
return ChatCompletionResponse(
|
||||
completion_message=CompletionMessage(
|
||||
tool_calls=tool_calls,
|
||||
stop_reason=StopReason.end_of_message,
|
||||
# Content is not optional
|
||||
content="",
|
||||
),
|
||||
logprobs=None,
|
||||
)
|
||||
else:
|
||||
return ChatCompletionResponse(
|
||||
completion_message=CompletionMessage(
|
||||
content=choice.message.content,
|
||||
stop_reason=_map_finish_reason_to_stop_reason(choice.finish_reason),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _map_finish_reason_to_stop_reason(
|
||||
|
|
@ -117,7 +177,7 @@ def _map_finish_reason_to_stop_reason(
|
|||
elif finish_reason == "length":
|
||||
return StopReason.out_of_tokens
|
||||
elif finish_reason == "tool_calls":
|
||||
raise NotImplementedError("tool_calls is not supported yet")
|
||||
return StopReason.end_of_message
|
||||
else:
|
||||
raise ValueError(f"Invalid finish reason: {finish_reason}")
|
||||
|
||||
|
|
@ -139,24 +199,46 @@ async def convert_chat_completion_response_stream(
|
|||
for chunk in stream:
|
||||
choice = chunk.choices[0]
|
||||
|
||||
# We assume there's only one finish_reason for the entire stream.
|
||||
# We collect the last finish_reason
|
||||
if choice.finish_reason:
|
||||
stop_reason = _map_finish_reason_to_stop_reason(choice.finish_reason)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=next(event_types),
|
||||
delta=choice.delta.content or "",
|
||||
logprobs=None,
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.complete,
|
||||
delta=choice.delta.content or "",
|
||||
logprobs=None,
|
||||
stop_reason=_map_finish_reason_to_stop_reason(choice.finish_reason),
|
||||
)
|
||||
)
|
||||
)
|
||||
elif choice.delta.tool_calls:
|
||||
# We assume there is only one tool call per chunk, but emit a warning in case we're wrong
|
||||
if len(choice.delta.tool_calls) > 1:
|
||||
warnings.warn("Groq returned multiple tool calls in one chunk. Using the first one, ignoring the rest.")
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.complete,
|
||||
delta="",
|
||||
logprobs=None,
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
# We assume Groq produces fully formed tool calls for each chunk
|
||||
tool_call = _convert_groq_tool_call(choice.delta.tool_calls[0])
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=next(event_types),
|
||||
delta=ToolCallDelta(
|
||||
content=tool_call,
|
||||
parse_status=ToolCallParseStatus.success,
|
||||
),
|
||||
)
|
||||
)
|
||||
else:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=next(event_types),
|
||||
delta=choice.delta.content or "",
|
||||
logprobs=None,
|
||||
)
|
||||
)
|
||||
|
||||
def _convert_groq_tool_call(tool_call: ChatCompletionMessageToolCall) -> ToolCall:
|
||||
return ToolCall(
|
||||
call_id=tool_call.id,
|
||||
tool_name=tool_call.function.name,
|
||||
# Note that Groq may return a string that is not valid JSON here
|
||||
# So this may raise a 500 error. Going to leave this as is to see
|
||||
# how big of an issue this is and what we can do about it.
|
||||
arguments=json.loads(tool_call.function.arguments),
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue