From 5e97dd991932f7b625fe50d2a5dd6e830cec6346 Mon Sep 17 00:00:00 2001
From: Yuan Tang
Date: Wed, 12 Feb 2025 09:17:21 -0500
Subject: [PATCH] feat: Support tool calling for streaming chat completion in remote vLLM provider (#1063)

# What does this PR do?

Adds tool calling support to streaming chat completion in the remote vLLM provider. The Groq-specific tool call conversion helpers (`UnparseableToolCall` and `_convert_groq_tool_call`) are moved into the shared OpenAI-compat utilities as `convert_tool_call` so both providers reuse the same conversion logic.

Closes https://github.com/meta-llama/llama-stack/issues/1046.

## Test Plan

Run the client-sdk text inference tests against a running stack backed by the remote vLLM provider:

```
LLAMA_STACK_BASE_URL=http://localhost:5002 pytest -v tests/client-sdk/inference/test_text_inference.py
================================================================= test session starts =================================================================
platform linux -- Python 3.10.16, pytest-8.3.4, pluggy-1.5.0 -- /home/yutang/.conda/envs/distribution-myenv/bin/python3.10
cachedir: .pytest_cache
rootdir: /home/yutang/repos/llama-stack
configfile: pyproject.toml
plugins: anyio-4.8.0
collected 14 items

tests/client-sdk/inference/test_text_inference.py::test_text_completion_non_streaming[meta-llama/Llama-3.1-8B-Instruct] PASSED [  7%]
tests/client-sdk/inference/test_text_inference.py::test_text_completion_streaming[meta-llama/Llama-3.1-8B-Instruct] PASSED [ 14%]
tests/client-sdk/inference/test_text_inference.py::test_completion_log_probs_non_streaming[meta-llama/Llama-3.1-8B-Instruct] XFAIL (remote:...) [ 21%]
tests/client-sdk/inference/test_text_inference.py::test_completion_log_probs_streaming[meta-llama/Llama-3.1-8B-Instruct] XFAIL (remote::vll...) [ 28%]
tests/client-sdk/inference/test_text_inference.py::test_text_completion_structured_output[meta-llama/Llama-3.1-8B-Instruct] PASSED [ 35%]
tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_non_streaming[meta-llama/Llama-3.1-8B-Instruct-Which planet do humans live on?-Earth] PASSED [ 42%]
tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_non_streaming[meta-llama/Llama-3.1-8B-Instruct-Which planet has rings around it with a name starting with letter S?-Saturn] PASSED [ 50%]
tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_streaming[meta-llama/Llama-3.1-8B-Instruct-What's the name of the Sun in latin?-Sol] PASSED [ 57%]
tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_streaming[meta-llama/Llama-3.1-8B-Instruct-What is the name of the US captial?-Washington] PASSED [ 64%]
tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_with_tool_calling_and_non_streaming[meta-llama/Llama-3.1-8B-Instruct] PASSED [ 71%]
tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_with_tool_calling_and_streaming[meta-llama/Llama-3.1-8B-Instruct] PASSED [ 78%]
tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_structured_output[meta-llama/Llama-3.1-8B-Instruct] PASSED [ 85%]
tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_tool_calling_tools_not_in_request[meta-llama/Llama-3.1-8B-Instruct-True] PASSED [ 92%]
tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_tool_calling_tools_not_in_request[meta-llama/Llama-3.1-8B-Instruct-False] PASSED [100%]

=============================================== 12 passed, 2 xfailed, 1 warning in 366.56s (0:06:06) ================================================
```

---------

Signed-off-by: Yuan Tang
---
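Reviewer note (not part of the patch): the sketch below is a rough way to exercise the new streaming tool-calling path by hand, mirroring what `test_text_chat_completion_with_tool_calling_and_streaming` does. The base URL, model id, and `get_weather` tool definition are placeholders, and the argument shapes follow the client-sdk tests rather than a documented public API.

```python
# Hand-written sketch; assumes a stack is running at the URL below with the
# remote vLLM provider configured and the model registered.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5002")  # placeholder URL

response = client.inference.chat_completion(
    model_id="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model id
    messages=[{"role": "user", "content": "What's the weather like in San Francisco?"}],
    tools=[
        {
            "tool_name": "get_weather",  # hypothetical tool, not defined in this repo
            "description": "Get the current weather for a location",
            "parameters": {
                "location": {
                    "param_type": "string",
                    "description": "City and state, e.g. San Francisco, CA",
                },
            },
        }
    ],
    stream=True,
)

# With tools in the request, the vLLM adapter now takes the new
# _process_vllm_chat_completion_stream_response() path: text deltas stream as
# progress events and the buffered tool call is emitted once parsing succeeds.
for chunk in response:
    print(chunk.event.event_type, chunk.event.delta)
```

If vLLM streams the tool call back in fragments, the adapter buffers the pieces and only yields a parsed `ToolCall` once `finish_reason` arrives.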
 .../remote/inference/groq/groq_utils.py      | 46 ++------
 .../providers/remote/inference/vllm/vllm.py  | 61 ++++++++++++++++++-
 .../utils/inference/openai_compat.py         | 39 +++++++++++-
 3 files changed, 101 insertions(+), 45 deletions(-)

diff --git a/llama_stack/providers/remote/inference/groq/groq_utils.py b/llama_stack/providers/remote/inference/groq/groq_utils.py
index 537043d69..d00e5c5a9 100644
--- a/llama_stack/providers/remote/inference/groq/groq_utils.py
+++ b/llama_stack/providers/remote/inference/groq/groq_utils.py
@@ -6,7 +6,7 @@
 import json
 import warnings

-from typing import AsyncGenerator, Literal, Union
+from typing import AsyncGenerator, Literal

 from groq import Stream
 from groq.types.chat.chat_completion import ChatCompletion
@@ -15,9 +15,6 @@ from groq.types.chat.chat_completion_assistant_message_param import (
 )
 from groq.types.chat.chat_completion_chunk import ChatCompletionChunk
 from groq.types.chat.chat_completion_message_param import ChatCompletionMessageParam
-from groq.types.chat.chat_completion_message_tool_call import (
-    ChatCompletionMessageToolCall,
-)
 from groq.types.chat.chat_completion_system_message_param import (
     ChatCompletionSystemMessageParam,
 )
@@ -30,7 +27,6 @@ from groq.types.shared.function_definition import FunctionDefinition

 from llama_models.llama3.api.datatypes import ToolParamDefinition

-from pydantic import BaseModel

 from llama_stack.apis.common.content_types import (
     TextDelta,
@@ -52,6 +48,8 @@ from llama_stack.apis.inference import (
 )
 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_strategy_options,
+    convert_tool_call,
+    UnparseableToolCall,
 )


@@ -143,7 +141,7 @@ def convert_chat_completion_response(
     # groq only supports n=1 at time of writing, so there is only one choice
     choice = response.choices[0]
     if choice.finish_reason == "tool_calls":
-        tool_calls = [_convert_groq_tool_call(tool_call) for tool_call in choice.message.tool_calls]
+        tool_calls = [convert_tool_call(tool_call) for tool_call in choice.message.tool_calls]
         if any(isinstance(tool_call, UnparseableToolCall) for tool_call in tool_calls):
             # If we couldn't parse a tool call, jsonify the tool calls and return them
             return ChatCompletionResponse(
@@ -216,7 +214,7 @@ async def convert_chat_completion_response_stream(
                 warnings.warn("Groq returned multiple tool calls in one chunk. Using the first one, ignoring the rest.")

             # We assume Groq produces fully formed tool calls for each chunk
-            tool_call = _convert_groq_tool_call(choice.delta.tool_calls[0])
+            tool_call = convert_tool_call(choice.delta.tool_calls[0])
             if isinstance(tool_call, ToolCall):
                 yield ChatCompletionResponseStreamChunk(
                     event=ChatCompletionResponseEvent(
@@ -247,37 +245,3 @@ async def convert_chat_completion_response_stream(
                 )
             )
             event_type = ChatCompletionResponseEventType.progress
-
-
-class UnparseableToolCall(BaseModel):
-    """
-    A ToolCall with arguments that are not valid JSON.
-    Mirrors the ToolCall schema, but with arguments as a string.
-    """
-
-    call_id: str
-    tool_name: str
-    arguments: str
-
-
-def _convert_groq_tool_call(
-    tool_call: ChatCompletionMessageToolCall,
-) -> Union[ToolCall, UnparseableToolCall]:
-    """
-    Convert a Groq tool call to a ToolCall.
-    Returns an UnparseableToolCall if the tool call is not valid JSON.
- """ - try: - arguments = json.loads(tool_call.function.arguments) - except Exception as e: - return UnparseableToolCall( - call_id=tool_call.id, - tool_name=tool_call.function.name, - arguments=tool_call.function.arguments, - ) - - return ToolCall( - call_id=tool_call.id, - tool_name=tool_call.function.name, - arguments=arguments, - ) diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index 2e13a6262..02594891b 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -13,7 +13,7 @@ from llama_models.llama3.api.tokenizer import Tokenizer from llama_models.sku_list import all_registered_models from openai import OpenAI -from llama_stack.apis.common.content_types import InterleavedContent +from llama_stack.apis.common.content_types import InterleavedContent, ToolCallDelta, ToolCallParseStatus, TextDelta from llama_stack.apis.inference import ( ChatCompletionRequest, ChatCompletionResponse, @@ -32,6 +32,9 @@ from llama_stack.apis.inference import ( ToolDefinition, ToolPromptFormat, CompletionMessage, + ChatCompletionResponseEventType, + ChatCompletionResponseStreamChunk, + ChatCompletionResponseEvent, ) from llama_stack.apis.models import Model, ModelType from llama_stack.providers.datatypes import ModelsProtocolPrivate @@ -42,9 +45,12 @@ from llama_stack.providers.utils.inference.model_registry import ( from llama_stack.providers.utils.inference.openai_compat import ( convert_message_to_openai_dict, get_sampling_options, - process_chat_completion_stream_response, process_completion_response, process_completion_stream_response, + OpenAICompatCompletionResponse, + UnparseableToolCall, + convert_tool_call, + process_chat_completion_stream_response, ) from llama_stack.providers.utils.inference.prompt_adapter import ( completion_request_to_prompt, @@ -136,6 +142,51 @@ def _convert_to_vllm_finish_reason(finish_reason: str) -> StopReason: }.get(finish_reason, StopReason.end_of_turn) +async def _process_vllm_chat_completion_stream_response( + stream: AsyncGenerator[OpenAICompatCompletionResponse, None], +) -> AsyncGenerator: + event_type = ChatCompletionResponseEventType.start + tool_call_buf = UnparseableToolCall() + async for chunk in stream: + choice = chunk.choices[0] + if choice.finish_reason: + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=event_type, + delta=ToolCallDelta( + tool_call=ToolCall( + call_id=tool_call_buf.call_id, + tool_name=tool_call_buf.tool_name, + arguments=json.loads(tool_call_buf.arguments), + ), + parse_status=ToolCallParseStatus.succeeded, + ), + ) + ) + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.complete, + delta=TextDelta(text=choice.delta.content or ""), + logprobs=None, + stop_reason=_convert_to_vllm_finish_reason(choice.finish_reason), + ) + ) + elif choice.delta.tool_calls: + tool_call = convert_tool_call(choice.delta.tool_calls[0]) + tool_call_buf.tool_name += tool_call.tool_name + tool_call_buf.call_id += tool_call.call_id + tool_call_buf.arguments += tool_call.arguments + else: + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=event_type, + delta=TextDelta(text=choice.delta.content or ""), + logprobs=None, + ) + ) + event_type = ChatCompletionResponseEventType.progress + + class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): def __init__(self, config: 
         self.register_helper = ModelRegistryHelper(build_model_aliases())
@@ -232,7 +283,11 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             yield chunk

         stream = _to_async_generator()
-        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
+        if len(request.tools) > 0:
+            res = _process_vllm_chat_completion_stream_response(stream)
+        else:
+            res = process_chat_completion_stream_response(stream, self.formatter, request)
+        async for chunk in res:
             yield chunk

     async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index 1047c9a58..7480ff2c7 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -3,6 +3,7 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+import json
 import logging
 from typing import AsyncGenerator, Dict, List, Optional, Union

@@ -14,7 +15,8 @@ from llama_models.datatypes import (
 )
 from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.datatypes import StopReason
+from llama_models.llama3.api.datatypes import StopReason, ToolCall
+from openai.types.chat import ChatCompletionMessageToolCall

 from pydantic import BaseModel

 from llama_stack.apis.common.content_types import (
@@ -408,3 +410,38 @@ async def convert_message_to_openai_dict(message: Message, download: bool = Fals
         "role": message.role,
         "content": content,
     }
+
+
+class UnparseableToolCall(BaseModel):
+    """
+    A ToolCall with arguments that are not valid JSON.
+    Mirrors the ToolCall schema, but with arguments as a string.
+    """
+
+    call_id: str = ""
+    tool_name: str = ""
+    arguments: str = ""
+
+
+def convert_tool_call(
+    tool_call: ChatCompletionMessageToolCall,
+) -> Union[ToolCall, UnparseableToolCall]:
+    """
+    Convert a ChatCompletionMessageToolCall tool call to either a
+    ToolCall or UnparseableToolCall. Returns an UnparseableToolCall
+    if the tool call is not valid JSON.
+    """
+    try:
+        arguments = json.loads(tool_call.function.arguments)
+    except Exception as e:
+        return UnparseableToolCall(
+            call_id=tool_call.id or "",
+            tool_name=tool_call.function.name or "",
+            arguments=tool_call.function.arguments or "",
+        )
+
+    return ToolCall(
+        call_id=tool_call.id,
+        tool_name=tool_call.function.name,
+        arguments=arguments,
+    )
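As a reference for how the buffering in `_process_vllm_chat_completion_stream_response` behaves, the sketch below drives the helper with hand-built OpenAI-style chunks instead of a live vLLM server. It is illustrative only and not part of the patch: the `SimpleNamespace` fakes mimic just the fields the helper reads, and the `get_weather` call and its arguments are invented.

```python
# Illustrative sketch (not part of the patch). Requires the patched
# llama_stack tree on PYTHONPATH; fakes only the chunk fields the helper uses.
import asyncio
from types import SimpleNamespace

from llama_stack.providers.remote.inference.vllm.vllm import (
    _process_vllm_chat_completion_stream_response,
)


def chunk(content=None, tool_calls=None, finish_reason=None):
    # Shape-compatible stand-in for an OpenAI ChatCompletionChunk.
    delta = SimpleNamespace(content=content, tool_calls=tool_calls)
    return SimpleNamespace(choices=[SimpleNamespace(delta=delta, finish_reason=finish_reason)])


def tool_call(call_id, name, arguments):
    # Stand-in for a streamed tool-call delta with partial JSON arguments.
    return SimpleNamespace(id=call_id, function=SimpleNamespace(name=name, arguments=arguments))


async def main():
    async def stream():
        # Arguments arrive as JSON fragments; convert_tool_call() cannot parse
        # either piece alone, so the helper accumulates them in tool_call_buf.
        yield chunk(tool_calls=[tool_call("call-1", "get_weather", '{"location": ')])
        yield chunk(tool_calls=[tool_call("", "", '"San Francisco, CA"}')])
        # finish_reason triggers one ToolCallDelta carrying the parsed ToolCall,
        # followed by a final "complete" event with the stop reason.
        yield chunk(finish_reason="tool_calls")

    async for out in _process_vllm_chat_completion_stream_response(stream()):
        print(out.event.event_type, out.event.delta)


asyncio.run(main())
```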