feat: Support tool calling for streaming chat completion in remote vLLM provider (#1063)

# What does this PR do?
[Provide a short summary of what this PR does and why. Link to relevant
issues if applicable.]

Closes https://github.com/meta-llama/llama-stack/issues/1046. 

## Test Plan
[Describe the tests you ran to verify your changes with result
summaries. *Provide clear instructions so the plan can be easily
re-executed.*]

```
LLAMA_STACK_BASE_URL=http://localhost:5002 pytest -v tests/client-sdk/inference/test_text_inference.py
================================================================= test session starts =================================================================
platform linux -- Python 3.10.16, pytest-8.3.4, pluggy-1.5.0 -- /home/yutang/.conda/envs/distribution-myenv/bin/python3.10
cachedir: .pytest_cache
rootdir: /home/yutang/repos/llama-stack
configfile: pyproject.toml
plugins: anyio-4.8.0
collected 14 items                                                                                                                                    

tests/client-sdk/inference/test_text_inference.py::test_text_completion_non_streaming[meta-llama/Llama-3.1-8B-Instruct] PASSED                  [  7%]
tests/client-sdk/inference/test_text_inference.py::test_text_completion_streaming[meta-llama/Llama-3.1-8B-Instruct] PASSED                      [ 14%]
tests/client-sdk/inference/test_text_inference.py::test_completion_log_probs_non_streaming[meta-llama/Llama-3.1-8B-Instruct] XFAIL (remote:...) [ 21%]
tests/client-sdk/inference/test_text_inference.py::test_completion_log_probs_streaming[meta-llama/Llama-3.1-8B-Instruct] XFAIL (remote::vll...) [ 28%]
tests/client-sdk/inference/test_text_inference.py::test_text_completion_structured_output[meta-llama/Llama-3.1-8B-Instruct] PASSED              [ 35%]
tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_non_streaming[meta-llama/Llama-3.1-8B-Instruct-Which planet do humans live on?-Earth] PASSED [ 42%]
tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_non_streaming[meta-llama/Llama-3.1-8B-Instruct-Which planet has rings around it with a name starting with letter S?-Saturn] PASSED [ 50%]
tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_streaming[meta-llama/Llama-3.1-8B-Instruct-What's the name of the Sun in latin?-Sol] PASSED [ 57%]
tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_streaming[meta-llama/Llama-3.1-8B-Instruct-What is the name of the US captial?-Washington] PASSED [ 64%]
tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_with_tool_calling_and_non_streaming[meta-llama/Llama-3.1-8B-Instruct] PASSED [ 71%]
tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_with_tool_calling_and_streaming[meta-llama/Llama-3.1-8B-Instruct] PASSED [ 78%]
tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_structured_output[meta-llama/Llama-3.1-8B-Instruct] PASSED       [ 85%]
tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_tool_calling_tools_not_in_request[meta-llama/Llama-3.1-8B-Instruct-True] PASSED [ 92%]
tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_tool_calling_tools_not_in_request[meta-llama/Llama-3.1-8B-Instruct-False] PASSED [100%]

=============================================== 12 passed, 2 xfailed, 1 warning in 366.56s (0:06:06) ================================================

```

---------

Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>
This commit is contained in:
Yuan Tang 2025-02-12 09:17:21 -05:00 committed by GitHub
parent bf11cc0450
commit 5e97dd9919
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 101 additions and 45 deletions

View file

@ -6,7 +6,7 @@
import json
import warnings
from typing import AsyncGenerator, Literal, Union
from typing import AsyncGenerator, Literal
from groq import Stream
from groq.types.chat.chat_completion import ChatCompletion
@ -15,9 +15,6 @@ from groq.types.chat.chat_completion_assistant_message_param import (
)
from groq.types.chat.chat_completion_chunk import ChatCompletionChunk
from groq.types.chat.chat_completion_message_param import ChatCompletionMessageParam
from groq.types.chat.chat_completion_message_tool_call import (
ChatCompletionMessageToolCall,
)
from groq.types.chat.chat_completion_system_message_param import (
ChatCompletionSystemMessageParam,
)
@ -30,7 +27,6 @@ from groq.types.shared.function_definition import FunctionDefinition
from llama_models.llama3.api.datatypes import ToolParamDefinition
from pydantic import BaseModel
from llama_stack.apis.common.content_types import (
TextDelta,
@ -52,6 +48,8 @@ from llama_stack.apis.inference import (
)
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_strategy_options,
convert_tool_call,
UnparseableToolCall,
)
@ -143,7 +141,7 @@ def convert_chat_completion_response(
# groq only supports n=1 at time of writing, so there is only one choice
choice = response.choices[0]
if choice.finish_reason == "tool_calls":
tool_calls = [_convert_groq_tool_call(tool_call) for tool_call in choice.message.tool_calls]
tool_calls = [convert_tool_call(tool_call) for tool_call in choice.message.tool_calls]
if any(isinstance(tool_call, UnparseableToolCall) for tool_call in tool_calls):
# If we couldn't parse a tool call, jsonify the tool calls and return them
return ChatCompletionResponse(
@ -216,7 +214,7 @@ async def convert_chat_completion_response_stream(
warnings.warn("Groq returned multiple tool calls in one chunk. Using the first one, ignoring the rest.")
# We assume Groq produces fully formed tool calls for each chunk
tool_call = _convert_groq_tool_call(choice.delta.tool_calls[0])
tool_call = convert_tool_call(choice.delta.tool_calls[0])
if isinstance(tool_call, ToolCall):
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
@ -247,37 +245,3 @@ async def convert_chat_completion_response_stream(
)
)
event_type = ChatCompletionResponseEventType.progress
class UnparseableToolCall(BaseModel):
"""
A ToolCall with arguments that are not valid JSON.
Mirrors the ToolCall schema, but with arguments as a string.
"""
call_id: str
tool_name: str
arguments: str
def _convert_groq_tool_call(
tool_call: ChatCompletionMessageToolCall,
) -> Union[ToolCall, UnparseableToolCall]:
"""
Convert a Groq tool call to a ToolCall.
Returns an UnparseableToolCall if the tool call is not valid JSON.
"""
try:
arguments = json.loads(tool_call.function.arguments)
except Exception as e:
return UnparseableToolCall(
call_id=tool_call.id,
tool_name=tool_call.function.name,
arguments=tool_call.function.arguments,
)
return ToolCall(
call_id=tool_call.id,
tool_name=tool_call.function.name,
arguments=arguments,
)