Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-28 02:53:30 +00:00)
feat: Support tool calling for streaming chat completion in remote vLLM provider (#1063)
# What does this PR do?

Closes https://github.com/meta-llama/llama-stack/issues/1046.

## Test Plan

```
LLAMA_STACK_BASE_URL=http://localhost:5002 pytest -v tests/client-sdk/inference/test_text_inference.py
================================================================= test session starts =================================================================
platform linux -- Python 3.10.16, pytest-8.3.4, pluggy-1.5.0 -- /home/yutang/.conda/envs/distribution-myenv/bin/python3.10
cachedir: .pytest_cache
rootdir: /home/yutang/repos/llama-stack
configfile: pyproject.toml
plugins: anyio-4.8.0
collected 14 items

tests/client-sdk/inference/test_text_inference.py::test_text_completion_non_streaming[meta-llama/Llama-3.1-8B-Instruct] PASSED                  [  7%]
tests/client-sdk/inference/test_text_inference.py::test_text_completion_streaming[meta-llama/Llama-3.1-8B-Instruct] PASSED                      [ 14%]
tests/client-sdk/inference/test_text_inference.py::test_completion_log_probs_non_streaming[meta-llama/Llama-3.1-8B-Instruct] XFAIL (remote:...) [ 21%]
tests/client-sdk/inference/test_text_inference.py::test_completion_log_probs_streaming[meta-llama/Llama-3.1-8B-Instruct] XFAIL (remote::vll...) [ 28%]
tests/client-sdk/inference/test_text_inference.py::test_text_completion_structured_output[meta-llama/Llama-3.1-8B-Instruct] PASSED              [ 35%]
tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_non_streaming[meta-llama/Llama-3.1-8B-Instruct-Which planet do humans live on?-Earth] PASSED [ 42%]
tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_non_streaming[meta-llama/Llama-3.1-8B-Instruct-Which planet has rings around it with a name starting with letter S?-Saturn] PASSED [ 50%]
tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_streaming[meta-llama/Llama-3.1-8B-Instruct-What's the name of the Sun in latin?-Sol] PASSED [ 57%]
tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_streaming[meta-llama/Llama-3.1-8B-Instruct-What is the name of the US captial?-Washington] PASSED [ 64%]
tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_with_tool_calling_and_non_streaming[meta-llama/Llama-3.1-8B-Instruct] PASSED [ 71%]
tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_with_tool_calling_and_streaming[meta-llama/Llama-3.1-8B-Instruct] PASSED [ 78%]
tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_structured_output[meta-llama/Llama-3.1-8B-Instruct] PASSED         [ 85%]
tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_tool_calling_tools_not_in_request[meta-llama/Llama-3.1-8B-Instruct-True] PASSED [ 92%]
tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_tool_calling_tools_not_in_request[meta-llama/Llama-3.1-8B-Instruct-False] PASSED [100%]
=============================================== 12 passed, 2 xfailed, 1 warning in 366.56s (0:06:06) ================================================
```

---------

Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>
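For context on what the provider has to handle: when an OpenAI-compatible server such as vLLM streams a chat completion, tool-call names and argument strings arrive as incremental deltas that must be accumulated before they can be parsed. A rough sketch of that client-side pattern using the `openai` package (the endpoint URL, tool definition, and prompt below are illustrative assumptions, not taken from this PR):

```python
from openai import OpenAI

# Assumed local vLLM OpenAI-compatible endpoint; adjust to your deployment.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="none")

# Hypothetical tool definition in the standard OpenAI function-tool format.
tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Get the current weather for a city",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}]

stream = client.chat.completions.create(
    model="meta-llama/Llama-3.1-8B-Instruct",
    messages=[{"role": "user", "content": "What is the weather in Tokyo?"}],
    tools=tools,
    stream=True,
)

# The tool-call name and arguments arrive in fragments across chunks;
# collect them until the stream ends, then parse the accumulated string.
name, arguments = "", ""
for chunk in stream:
    if not chunk.choices:
        continue
    delta = chunk.choices[0].delta
    for tc in delta.tool_calls or []:
        if tc.function and tc.function.name:
            name = tc.function.name
        if tc.function and tc.function.arguments:
            arguments += tc.function.arguments

print(name, arguments)
```

If the model stops mid-generation, the accumulated `arguments` string can still be invalid JSON; that is exactly the failure mode the `UnparseableToolCall` type introduced below is designed to represent.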
Commit 5e97dd9919 (parent bf11cc0450)

3 changed files with 101 additions and 45 deletions
```diff
@@ -3,6 +3,7 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+import json
 import logging
 from typing import AsyncGenerator, Dict, List, Optional, Union
 
@@ -14,7 +15,8 @@ from llama_models.datatypes import (
 )
 
 from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.datatypes import StopReason
+from llama_models.llama3.api.datatypes import StopReason, ToolCall
+from openai.types.chat import ChatCompletionMessageToolCall
 from pydantic import BaseModel
 
 from llama_stack.apis.common.content_types import (
@@ -408,3 +410,38 @@ async def convert_message_to_openai_dict(message: Message, download: bool = False
         "role": message.role,
         "content": content,
     }
+
+
+class UnparseableToolCall(BaseModel):
+    """
+    A ToolCall with arguments that are not valid JSON.
+    Mirrors the ToolCall schema, but with arguments as a string.
+    """
+
+    call_id: str = ""
+    tool_name: str = ""
+    arguments: str = ""
+
+
+def convert_tool_call(
+    tool_call: ChatCompletionMessageToolCall,
+) -> Union[ToolCall, UnparseableToolCall]:
+    """
+    Convert a ChatCompletionMessageToolCall tool call to either a
+    ToolCall or UnparseableToolCall. Returns an UnparseableToolCall
+    if the tool call is not valid JSON.
+    """
+    try:
+        arguments = json.loads(tool_call.function.arguments)
+    except Exception as e:
+        return UnparseableToolCall(
+            call_id=tool_call.id or "",
+            tool_name=tool_call.function.name or "",
+            arguments=tool_call.function.arguments or "",
+        )
+
+    return ToolCall(
+        call_id=tool_call.id,
+        tool_name=tool_call.function.name,
+        arguments=arguments,
+    )
```
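A minimal usage sketch for the new helper. The tool-call objects below are constructed by hand for illustration (they are not from the PR), and the import path for `convert_tool_call` is an assumption based on the function shown in the diff's hunk context:

```python
from openai.types.chat import ChatCompletionMessageToolCall
from openai.types.chat.chat_completion_message_tool_call import Function

# Assumed module path for the helper added in this diff.
from llama_stack.providers.utils.inference.openai_compat import convert_tool_call

# Well-formed arguments parse into a ToolCall carrying a dict payload.
ok = ChatCompletionMessageToolCall(
    id="call_1",
    type="function",
    function=Function(name="get_weather", arguments='{"city": "Tokyo"}'),
)

# Truncated/invalid JSON falls back to UnparseableToolCall, preserving
# the raw argument string instead of raising.
bad = ChatCompletionMessageToolCall(
    id="call_2",
    type="function",
    function=Function(name="get_weather", arguments='{"city": '),
)

print(convert_tool_call(ok))   # -> ToolCall(call_id='call_1', tool_name='get_weather', arguments={'city': 'Tokyo'})
print(convert_tool_call(bad))  # -> UnparseableToolCall(call_id='call_2', tool_name='get_weather', arguments='{"city": ')
```

Returning a sibling type rather than raising lets the streaming code path keep partial tool calls around (for example, a call whose argument deltas were cut off) without aborting the whole completion.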