feat: Support tool calling for streaming chat completion in remote vLLM provider (#1063)

# What does this PR do?
[Provide a short summary of what this PR does and why. Link to relevant
issues if applicable.]

Closes https://github.com/meta-llama/llama-stack/issues/1046. 

## Test Plan
[Describe the tests you ran to verify your changes with result
summaries. *Provide clear instructions so the plan can be easily
re-executed.*]

```
LLAMA_STACK_BASE_URL=http://localhost:5002 pytest -v tests/client-sdk/inference/test_text_inference.py
================================================================= test session starts =================================================================
platform linux -- Python 3.10.16, pytest-8.3.4, pluggy-1.5.0 -- /home/yutang/.conda/envs/distribution-myenv/bin/python3.10
cachedir: .pytest_cache
rootdir: /home/yutang/repos/llama-stack
configfile: pyproject.toml
plugins: anyio-4.8.0
collected 14 items                                                                                                                                    

tests/client-sdk/inference/test_text_inference.py::test_text_completion_non_streaming[meta-llama/Llama-3.1-8B-Instruct] PASSED                  [  7%]
tests/client-sdk/inference/test_text_inference.py::test_text_completion_streaming[meta-llama/Llama-3.1-8B-Instruct] PASSED                      [ 14%]
tests/client-sdk/inference/test_text_inference.py::test_completion_log_probs_non_streaming[meta-llama/Llama-3.1-8B-Instruct] XFAIL (remote:...) [ 21%]
tests/client-sdk/inference/test_text_inference.py::test_completion_log_probs_streaming[meta-llama/Llama-3.1-8B-Instruct] XFAIL (remote::vll...) [ 28%]
tests/client-sdk/inference/test_text_inference.py::test_text_completion_structured_output[meta-llama/Llama-3.1-8B-Instruct] PASSED              [ 35%]
tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_non_streaming[meta-llama/Llama-3.1-8B-Instruct-Which planet do humans live on?-Earth] PASSED [ 42%]
tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_non_streaming[meta-llama/Llama-3.1-8B-Instruct-Which planet has rings around it with a name starting with letter S?-Saturn] PASSED [ 50%]
tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_streaming[meta-llama/Llama-3.1-8B-Instruct-What's the name of the Sun in latin?-Sol] PASSED [ 57%]
tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_streaming[meta-llama/Llama-3.1-8B-Instruct-What is the name of the US captial?-Washington] PASSED [ 64%]
tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_with_tool_calling_and_non_streaming[meta-llama/Llama-3.1-8B-Instruct] PASSED [ 71%]
tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_with_tool_calling_and_streaming[meta-llama/Llama-3.1-8B-Instruct] PASSED [ 78%]
tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_structured_output[meta-llama/Llama-3.1-8B-Instruct] PASSED       [ 85%]
tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_tool_calling_tools_not_in_request[meta-llama/Llama-3.1-8B-Instruct-True] PASSED [ 92%]
tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_tool_calling_tools_not_in_request[meta-llama/Llama-3.1-8B-Instruct-False] PASSED [100%]

=============================================== 12 passed, 2 xfailed, 1 warning in 366.56s (0:06:06) ================================================

```

---------

Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>
This commit is contained in:
Yuan Tang 2025-02-12 09:17:21 -05:00 committed by GitHub
parent bf11cc0450
commit 5e97dd9919
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 101 additions and 45 deletions

View file

@ -3,6 +3,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import logging
from typing import AsyncGenerator, Dict, List, Optional, Union
@ -14,7 +15,8 @@ from llama_models.datatypes import (
)
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import StopReason
from llama_models.llama3.api.datatypes import StopReason, ToolCall
from openai.types.chat import ChatCompletionMessageToolCall
from pydantic import BaseModel
from llama_stack.apis.common.content_types import (
@ -408,3 +410,38 @@ async def convert_message_to_openai_dict(message: Message, download: bool = Fals
"role": message.role,
"content": content,
}
class UnparseableToolCall(BaseModel):
"""
A ToolCall with arguments that are not valid JSON.
Mirrors the ToolCall schema, but with arguments as a string.
"""
call_id: str = ""
tool_name: str = ""
arguments: str = ""
def convert_tool_call(
tool_call: ChatCompletionMessageToolCall,
) -> Union[ToolCall, UnparseableToolCall]:
"""
Convert a ChatCompletionMessageToolCall tool call to either a
ToolCall or UnparseableToolCall. Returns an UnparseableToolCall
if the tool call is not valid JSON.
"""
try:
arguments = json.loads(tool_call.function.arguments)
except Exception as e:
return UnparseableToolCall(
call_id=tool_call.id or "",
tool_name=tool_call.function.name or "",
arguments=tool_call.function.arguments or "",
)
return ToolCall(
call_id=tool_call.id,
tool_name=tool_call.function.name,
arguments=arguments,
)