PR tool call followups

Aidan Do 2025-01-18 09:07:42 +11:00
parent 1f60c0286d
commit 76e08cfde0
3 changed files with 126 additions and 30 deletions

View file

@@ -6,7 +6,7 @@
 import json
 import warnings
-from typing import AsyncGenerator, Literal
+from typing import AsyncGenerator, Literal, Union

 from groq import Stream
 from groq.types.chat.chat_completion import ChatCompletion
@@ -30,6 +30,8 @@ from groq.types.shared.function_definition import FunctionDefinition
 from llama_models.llama3.api.datatypes import ToolParamDefinition
+from pydantic import BaseModel
+
 from llama_stack.apis.common.content_types import (
     TextDelta,
     ToolCallDelta,
@@ -150,15 +152,26 @@ def convert_chat_completion_response(
             _convert_groq_tool_call(tool_call)
             for tool_call in choice.message.tool_calls
         ]
-        return ChatCompletionResponse(
-            completion_message=CompletionMessage(
-                tool_calls=tool_calls,
-                stop_reason=StopReason.end_of_message,
-                # Content is not optional
-                content="",
-            ),
-            logprobs=None,
-        )
+        if any(isinstance(tool_call, UnparseableToolCall) for tool_call in tool_calls):
+            # If we couldn't parse a tool call, jsonify the tool calls and return them
+            return ChatCompletionResponse(
+                completion_message=CompletionMessage(
+                    stop_reason=StopReason.end_of_message,
+                    content=json.dumps(tool_calls, default=lambda x: x.model_dump()),
+                ),
+                logprobs=None,
+            )
+        else:
+            # Otherwise, return tool calls as normal
+            return ChatCompletionResponse(
+                completion_message=CompletionMessage(
+                    tool_calls=tool_calls,
+                    stop_reason=StopReason.end_of_message,
+                    # Content is not optional
+                    content="",
+                ),
+                logprobs=None,
+            )
     else:
         return ChatCompletionResponse(
             completion_message=CompletionMessage(
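For context, a minimal sketch (not part of the commit) of what the fallback branch above produces. `json.dumps` cannot serialize pydantic models directly, so `default=lambda x: x.model_dump()` converts each one to a dict first; the result is exactly the string asserted in the non-stream test further down. Assumes pydantic v2 (`model_dump`), as the diff itself does.

    import json
    from pydantic import BaseModel

    # Re-declared here only to keep the sketch self-contained
    class UnparseableToolCall(BaseModel):
        call_id: str
        tool_name: str
        arguments: str

    tool_calls = [
        UnparseableToolCall(
            call_id="tool_call_id",
            tool_name="log",
            arguments="(number=10, base=2)",
        )
    ]

    # default= is called for each object json.dumps can't serialize natively
    print(json.dumps(tool_calls, default=lambda x: x.model_dump()))
    # [{"call_id": "tool_call_id", "tool_name": "log", "arguments": "(number=10, base=2)"}]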
@@ -214,15 +227,27 @@ async def convert_chat_completion_response_stream(
             # We assume Groq produces fully formed tool calls for each chunk
             tool_call = _convert_groq_tool_call(choice.delta.tool_calls[0])
-            yield ChatCompletionResponseStreamChunk(
-                event=ChatCompletionResponseEvent(
-                    event_type=event_type,
-                    delta=ToolCallDelta(
-                        content=tool_call,
-                        parse_status=ToolCallParseStatus.succeeded,
-                    ),
-                )
-            )
+            if isinstance(tool_call, ToolCall):
+                yield ChatCompletionResponseStreamChunk(
+                    event=ChatCompletionResponseEvent(
+                        event_type=event_type,
+                        delta=ToolCallDelta(
+                            content=tool_call,
+                            parse_status=ToolCallParseStatus.succeeded,
+                        ),
+                    )
+                )
+            else:
+                # Otherwise it's an UnparseableToolCall - return the raw tool call
+                yield ChatCompletionResponseStreamChunk(
+                    event=ChatCompletionResponseEvent(
+                        event_type=event_type,
+                        delta=ToolCallDelta(
+                            content=tool_call.model_dump_json(),
+                            parse_status=ToolCallParseStatus.failed,
+                        ),
+                    )
+                )
         else:
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
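A hedged sketch of how a consumer might read the converted stream: on the failed path, `delta.content` carries the raw tool call serialized via `model_dump_json()` rather than a `ToolCall` object. `collect_tool_calls` is a hypothetical helper; the delta types are the ones imported in this file.

    async def collect_tool_calls(stream):
        # stream: an async iterable of Groq ChatCompletionChunk objects
        async for chunk in convert_chat_completion_response_stream(stream):
            delta = chunk.event.delta
            if isinstance(delta, ToolCallDelta):
                if delta.parse_status == ToolCallParseStatus.failed:
                    # Raw, unparsed tool call as a JSON string
                    print("unparseable tool call:", delta.content)
                else:
                    print("tool call:", delta.content)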
@@ -234,12 +259,35 @@ async def convert_chat_completion_response_stream(
             event_type = ChatCompletionResponseEventType.progress


-def _convert_groq_tool_call(tool_call: ChatCompletionMessageToolCall) -> ToolCall:
+class UnparseableToolCall(BaseModel):
+    """
+    A ToolCall with arguments that are not valid JSON.
+    Mirrors the ToolCall schema, but with arguments as a string.
+    """
+
+    call_id: str
+    tool_name: str
+    arguments: str
+
+
+def _convert_groq_tool_call(
+    tool_call: ChatCompletionMessageToolCall,
+) -> Union[ToolCall, UnparseableToolCall]:
+    """
+    Convert a Groq tool call to a ToolCall.
+    Returns an UnparseableToolCall if the tool call is not valid JSON.
+    """
+    try:
+        arguments = json.loads(tool_call.function.arguments)
+    except Exception as e:
+        return UnparseableToolCall(
+            call_id=tool_call.id,
+            tool_name=tool_call.function.name,
+            arguments=tool_call.function.arguments,
+        )
+
     return ToolCall(
         call_id=tool_call.id,
         tool_name=tool_call.function.name,
-        # Note that Groq may return a string that is not valid JSON here
-        # So this may raise a 500 error. Going to leave this as is to see
-        # how big of an issue this is and what we can do about it.
-        arguments=json.loads(tool_call.function.arguments),
+        arguments=arguments,
     )
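Illustrative usage of `_convert_groq_tool_call` (not part of the commit), assuming `Function` can be imported from the same groq SDK module as `ChatCompletionMessageToolCall`, which is how the tests below construct their fixtures:

    from groq.types.chat.chat_completion_message_tool_call import (
        ChatCompletionMessageToolCall,
        Function,
    )

    # Valid JSON arguments parse into a ToolCall
    parsed = _convert_groq_tool_call(
        ChatCompletionMessageToolCall(
            id="call-1",
            type="function",
            function=Function(name="log", arguments='{"number": 10, "base": 2}'),
        )
    )
    assert isinstance(parsed, ToolCall)  # arguments == {"number": 10, "base": 2}

    # Anything else falls back to UnparseableToolCall with the raw string
    fallback = _convert_groq_tool_call(
        ChatCompletionMessageToolCall(
            id="call-2",
            type="function",
            function=Function(name="log", arguments="(number=10, base=2)"),
        )
    )
    assert isinstance(fallback, UnparseableToolCall)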

View file

@@ -23,6 +23,7 @@ from groq.types.chat.chat_completion_message_tool_call import (
 from groq.types.shared.function_definition import FunctionDefinition
 from llama_models.datatypes import GreedySamplingStrategy, TopPSamplingStrategy
 from llama_models.llama3.api.datatypes import ToolParamDefinition
+from llama_stack.apis.common.content_types import ToolCallParseStatus
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponseEventType,
@@ -347,6 +348,26 @@ class TestConvertNonStreamChatCompletionResponse:
             ),
         ]

+    def test_converts_unparseable_tool_calls(self):
+        response = self._dummy_chat_completion_response_with_tool_call()
+        response.choices[0].message.tool_calls = [
+            ChatCompletionMessageToolCall(
+                id="tool_call_id",
+                type="function",
+                function=Function(
+                    name="log",
+                    arguments="(number=10, base=2)",
+                ),
+            ),
+        ]
+
+        converted = convert_chat_completion_response(response)
+
+        assert (
+            converted.completion_message.content
+            == '[{"call_id": "tool_call_id", "tool_name": "log", "arguments": "(number=10, base=2)"}]'
+        )
+
     def _dummy_chat_completion_response(self):
         return ChatCompletion(
             id="chatcmpl-123",
@@ -478,6 +499,40 @@
                 arguments={"origin": "AU", "destination": "LAX"},
             )

+    @pytest.mark.asyncio
+    async def test_returns_tool_calls_stream_with_unparseable_tool_calls(self):
+        def tool_call_stream():
+            chunk = self._dummy_chat_completion_chunk_with_tool_call()
+            chunk.choices[0].delta.tool_calls = [
+                ChoiceDeltaToolCall(
+                    index=0,
+                    type="function",
+                    id="tool_call_id",
+                    function=ChoiceDeltaToolCallFunction(
+                        name="get_flight_info",
+                        arguments="(origin=AU, destination=LAX)",
+                    ),
+                ),
+            ]
+            yield chunk
+
+            chunk = self._dummy_chat_completion_chunk_with_tool_call()
+            chunk.choices[0].delta.content = None
+            chunk.choices[0].finish_reason = "stop"
+            yield chunk
+
+        stream = tool_call_stream()
+        converted = convert_chat_completion_response_stream(stream)
+
+        iter = converted.__aiter__()
+        chunk = await iter.__anext__()
+        assert chunk.event.event_type == ChatCompletionResponseEventType.start
+        assert (
+            chunk.event.delta.content
+            == '{"call_id":"tool_call_id","tool_name":"get_flight_info","arguments":"(origin=AU, destination=LAX)"}'
+        )
+        assert chunk.event.delta.parse_status == ToolCallParseStatus.failed
+
     def _dummy_chat_completion_chunk(self):
         return ChatCompletionChunk(
             id="chatcmpl-123",

View file

@@ -377,13 +377,6 @@ class TestInference:
         sample_tool_definition,
     ):
         inference_impl, _ = inference_stack
-        provider = inference_impl.routing_table.get_provider_impl(inference_model)
-        if (
-            provider.__provider_spec__.provider_type == "remote::groq"
-            and "Llama-3.2" in inference_model
-        ):
-            # TODO(aidand): Remove this skip once Groq's tool calling for Llama3.2 works better
-            pytest.skip("Groq's tool calling for Llama3.2 doesn't work very well")

         messages = sample_messages + [
             UserMessage(