feat: Support tool calling for streaming chat completion in remote vLLM provider (#1063)
# What does this PR do?

Adds tool-calling support to streaming chat completion in the remote vLLM provider. Closes https://github.com/meta-llama/llama-stack/issues/1046.

## Test Plan

Ran the client-sdk text inference tests against a local vLLM-backed stack:

```
LLAMA_STACK_BASE_URL=http://localhost:5002 pytest -v tests/client-sdk/inference/test_text_inference.py
================================================================= test session starts =================================================================
platform linux -- Python 3.10.16, pytest-8.3.4, pluggy-1.5.0 -- /home/yutang/.conda/envs/distribution-myenv/bin/python3.10
cachedir: .pytest_cache
rootdir: /home/yutang/repos/llama-stack
configfile: pyproject.toml
plugins: anyio-4.8.0
collected 14 items

tests/client-sdk/inference/test_text_inference.py::test_text_completion_non_streaming[meta-llama/Llama-3.1-8B-Instruct] PASSED [ 7%]
tests/client-sdk/inference/test_text_inference.py::test_text_completion_streaming[meta-llama/Llama-3.1-8B-Instruct] PASSED [ 14%]
tests/client-sdk/inference/test_text_inference.py::test_completion_log_probs_non_streaming[meta-llama/Llama-3.1-8B-Instruct] XFAIL (remote:...) [ 21%]
tests/client-sdk/inference/test_text_inference.py::test_completion_log_probs_streaming[meta-llama/Llama-3.1-8B-Instruct] XFAIL (remote::vll...) [ 28%]
tests/client-sdk/inference/test_text_inference.py::test_text_completion_structured_output[meta-llama/Llama-3.1-8B-Instruct] PASSED [ 35%]
tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_non_streaming[meta-llama/Llama-3.1-8B-Instruct-Which planet do humans live on?-Earth] PASSED [ 42%]
tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_non_streaming[meta-llama/Llama-3.1-8B-Instruct-Which planet has rings around it with a name starting with letter S?-Saturn] PASSED [ 50%]
tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_streaming[meta-llama/Llama-3.1-8B-Instruct-What's the name of the Sun in latin?-Sol] PASSED [ 57%]
tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_streaming[meta-llama/Llama-3.1-8B-Instruct-What is the name of the US captial?-Washington] PASSED [ 64%]
tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_with_tool_calling_and_non_streaming[meta-llama/Llama-3.1-8B-Instruct] PASSED [ 71%]
tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_with_tool_calling_and_streaming[meta-llama/Llama-3.1-8B-Instruct] PASSED [ 78%]
tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_structured_output[meta-llama/Llama-3.1-8B-Instruct] PASSED [ 85%]
tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_tool_calling_tools_not_in_request[meta-llama/Llama-3.1-8B-Instruct-True] PASSED [ 92%]
tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_tool_calling_tools_not_in_request[meta-llama/Llama-3.1-8B-Instruct-False] PASSED [100%]

=============================================== 12 passed, 2 xfailed, 1 warning in 366.56s (0:06:06) ================================================
```

---------

Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>
Commit 5e97dd9919, parent bf11cc0450.

3 changed files with 101 additions and 45 deletions. (File paths below are inferred from the module imports visible in the diff.)
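Before the diff, a minimal sketch of how a client exercises the new path, mirroring `test_text_chat_completion_with_tool_calling_and_streaming`. The `get_weather` tool and its parameter schema are illustrative assumptions, not copied from the test file:

```python
# Sketch: stream a chat completion with a tool attached, against a llama-stack
# server backed by the remote vLLM provider. The get_weather tool below is a
# hypothetical example; substitute your own tool definitions.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5002")

response = client.inference.chat_completion(
    model_id="meta-llama/Llama-3.1-8B-Instruct",
    messages=[{"role": "user", "content": "What's the weather like in San Francisco?"}],
    tools=[
        {
            "tool_name": "get_weather",
            "description": "Get the current weather for a city",
            "parameters": {
                "location": {"param_type": "string", "description": "City name"},
            },
        }
    ],
    stream=True,
)

# Each chunk carries a ChatCompletionResponseEvent; with this PR, tool-call
# deltas arrive once the buffered call has been fully assembled and parsed
# (see the vLLM diff below).
for chunk in response:
    print(chunk.event.event_type, chunk.event.delta)
```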
**`llama_stack/providers/remote/inference/groq/groq_utils.py`**

```diff
@@ -6,7 +6,7 @@
 import json
 import warnings
 
-from typing import AsyncGenerator, Literal, Union
+from typing import AsyncGenerator, Literal
 
 from groq import Stream
 from groq.types.chat.chat_completion import ChatCompletion
@@ -15,9 +15,6 @@ from groq.types.chat.chat_completion_assistant_message_param import (
 )
 from groq.types.chat.chat_completion_chunk import ChatCompletionChunk
 from groq.types.chat.chat_completion_message_param import ChatCompletionMessageParam
-from groq.types.chat.chat_completion_message_tool_call import (
-    ChatCompletionMessageToolCall,
-)
 from groq.types.chat.chat_completion_system_message_param import (
     ChatCompletionSystemMessageParam,
 )
@@ -30,7 +27,6 @@ from groq.types.shared.function_definition import FunctionDefinition
 
 from llama_models.llama3.api.datatypes import ToolParamDefinition
 
-from pydantic import BaseModel
 
 from llama_stack.apis.common.content_types import (
     TextDelta,
@@ -52,6 +48,8 @@ from llama_stack.apis.inference import (
 )
 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_strategy_options,
+    convert_tool_call,
+    UnparseableToolCall,
 )
 
 
@@ -143,7 +141,7 @@ def convert_chat_completion_response(
     # groq only supports n=1 at time of writing, so there is only one choice
     choice = response.choices[0]
     if choice.finish_reason == "tool_calls":
-        tool_calls = [_convert_groq_tool_call(tool_call) for tool_call in choice.message.tool_calls]
+        tool_calls = [convert_tool_call(tool_call) for tool_call in choice.message.tool_calls]
         if any(isinstance(tool_call, UnparseableToolCall) for tool_call in tool_calls):
             # If we couldn't parse a tool call, jsonify the tool calls and return them
             return ChatCompletionResponse(
@@ -216,7 +214,7 @@ async def convert_chat_completion_response_stream(
                 warnings.warn("Groq returned multiple tool calls in one chunk. Using the first one, ignoring the rest.")
 
             # We assume Groq produces fully formed tool calls for each chunk
            tool_call = convert_tool_call(choice.delta.tool_calls[0])
             if isinstance(tool_call, ToolCall):
                 yield ChatCompletionResponseStreamChunk(
                     event=ChatCompletionResponseEvent(
@@ -247,37 +245,3 @@ async def convert_chat_completion_response_stream(
                 )
             )
         event_type = ChatCompletionResponseEventType.progress
-
-
-class UnparseableToolCall(BaseModel):
-    """
-    A ToolCall with arguments that are not valid JSON.
-    Mirrors the ToolCall schema, but with arguments as a string.
-    """
-
-    call_id: str
-    tool_name: str
-    arguments: str
-
-
-def _convert_groq_tool_call(
-    tool_call: ChatCompletionMessageToolCall,
-) -> Union[ToolCall, UnparseableToolCall]:
-    """
-    Convert a Groq tool call to a ToolCall.
-    Returns an UnparseableToolCall if the tool call is not valid JSON.
-    """
-    try:
-        arguments = json.loads(tool_call.function.arguments)
-    except Exception as e:
-        return UnparseableToolCall(
-            call_id=tool_call.id,
-            tool_name=tool_call.function.name,
-            arguments=tool_call.function.arguments,
-        )
-
-    return ToolCall(
-        call_id=tool_call.id,
-        tool_name=tool_call.function.name,
-        arguments=arguments,
-    )
```
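The private `_convert_groq_tool_call` and `UnparseableToolCall` deleted above now live in `openai_compat` (third file below), so the Groq and vLLM adapters share one conversion path; Groq's behavior is unchanged. A quick sketch of the shared helper's happy path, using `SimpleNamespace` as a duck-typed stand-in for the OpenAI tool-call object (illustrative only; assumes this branch of llama-stack is installed):

```python
# Sketch: convert_tool_call returns a ToolCall when the arguments are valid
# JSON. SimpleNamespace stands in for ChatCompletionMessageToolCall here;
# the helper only reads .id, .function.name, and .function.arguments.
from types import SimpleNamespace

from llama_stack.providers.utils.inference.openai_compat import convert_tool_call

ok = SimpleNamespace(
    id="call_1",
    function=SimpleNamespace(name="get_weather", arguments='{"location": "San Francisco"}'),
)
result = convert_tool_call(ok)
print(type(result).__name__, result.arguments)
# ToolCall {'location': 'San Francisco'}
```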
**`llama_stack/providers/remote/inference/vllm/vllm.py`**

```diff
@@ -13,7 +13,7 @@ from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.sku_list import all_registered_models
 from openai import OpenAI
 
-from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.common.content_types import InterleavedContent, ToolCallDelta, ToolCallParseStatus, TextDelta
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -32,6 +32,9 @@ from llama_stack.apis.inference import (
     ToolDefinition,
     ToolPromptFormat,
     CompletionMessage,
+    ChatCompletionResponseEventType,
+    ChatCompletionResponseStreamChunk,
+    ChatCompletionResponseEvent,
 )
 from llama_stack.apis.models import Model, ModelType
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
@@ -42,9 +45,12 @@ from llama_stack.providers.utils.inference.model_registry import (
 from llama_stack.providers.utils.inference.openai_compat import (
     convert_message_to_openai_dict,
     get_sampling_options,
-    process_chat_completion_stream_response,
     process_completion_response,
     process_completion_stream_response,
+    OpenAICompatCompletionResponse,
+    UnparseableToolCall,
+    convert_tool_call,
+    process_chat_completion_stream_response,
 )
 from llama_stack.providers.utils.inference.prompt_adapter import (
     completion_request_to_prompt,
@@ -136,6 +142,51 @@ def _convert_to_vllm_finish_reason(finish_reason: str) -> StopReason:
     }.get(finish_reason, StopReason.end_of_turn)
 
 
+async def _process_vllm_chat_completion_stream_response(
+    stream: AsyncGenerator[OpenAICompatCompletionResponse, None],
+) -> AsyncGenerator:
+    event_type = ChatCompletionResponseEventType.start
+    tool_call_buf = UnparseableToolCall()
+    async for chunk in stream:
+        choice = chunk.choices[0]
+        if choice.finish_reason:
+            yield ChatCompletionResponseStreamChunk(
+                event=ChatCompletionResponseEvent(
+                    event_type=event_type,
+                    delta=ToolCallDelta(
+                        tool_call=ToolCall(
+                            call_id=tool_call_buf.call_id,
+                            tool_name=tool_call_buf.tool_name,
+                            arguments=json.loads(tool_call_buf.arguments),
+                        ),
+                        parse_status=ToolCallParseStatus.succeeded,
+                    ),
+                )
+            )
+            yield ChatCompletionResponseStreamChunk(
+                event=ChatCompletionResponseEvent(
+                    event_type=ChatCompletionResponseEventType.complete,
+                    delta=TextDelta(text=choice.delta.content or ""),
+                    logprobs=None,
+                    stop_reason=_convert_to_vllm_finish_reason(choice.finish_reason),
+                )
+            )
+        elif choice.delta.tool_calls:
+            tool_call = convert_tool_call(choice.delta.tool_calls[0])
+            tool_call_buf.tool_name += tool_call.tool_name
+            tool_call_buf.call_id += tool_call.call_id
+            tool_call_buf.arguments += tool_call.arguments
+        else:
+            yield ChatCompletionResponseStreamChunk(
+                event=ChatCompletionResponseEvent(
+                    event_type=event_type,
+                    delta=TextDelta(text=choice.delta.content or ""),
+                    logprobs=None,
+                )
+            )
+            event_type = ChatCompletionResponseEventType.progress
+
+
 class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
         self.register_helper = ModelRegistryHelper(build_model_aliases())
@@ -232,7 +283,11 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
                 yield chunk
 
         stream = _to_async_generator()
-        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
+        if len(request.tools) > 0:
+            res = _process_vllm_chat_completion_stream_response(stream)
+        else:
+            res = process_chat_completion_stream_response(stream, self.formatter, request)
+        async for chunk in res:
             yield chunk
 
     async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
```
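The key idea in `_process_vllm_chat_completion_stream_response`: vLLM streams a tool call as fragments spread across chunks, so the adapter concatenates them into an `UnparseableToolCall` buffer (its fields default to `""`, which is what makes `+=` work) and only runs `json.loads` once `finish_reason` arrives. A self-contained simulation of that accumulation, with stand-in chunk fragments rather than real vLLM output (the fragment values are invented for illustration):

```python
# Self-contained simulation of the buffering technique: tool-call fragments
# arrive across stream chunks and are parsed only once the stream finishes.
import json
from dataclasses import dataclass


@dataclass
class Buf:  # stand-in for UnparseableToolCall's role as a string accumulator
    call_id: str = ""
    tool_name: str = ""
    arguments: str = ""


# Fragments as vLLM might emit them: name and id in the first delta, the JSON
# argument string split across the following deltas (values illustrative).
fragments = [
    ("call_1", "get_weather", ""),
    ("", "", '{"location": '),
    ("", "", '"San Francisco"}'),
]

buf = Buf()
for call_id, name, args in fragments:
    buf.call_id += call_id
    buf.tool_name += name
    buf.arguments += args

# By the time finish_reason arrives, the buffered argument string is complete JSON.
print(buf.tool_name, json.loads(buf.arguments))
# get_weather {'location': 'San Francisco'}
```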
**`llama_stack/providers/utils/inference/openai_compat.py`**

```diff
@@ -3,6 +3,7 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+import json
 import logging
 from typing import AsyncGenerator, Dict, List, Optional, Union
 
@@ -14,7 +15,8 @@ from llama_models.datatypes import (
 )
 
 from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.datatypes import StopReason
+from llama_models.llama3.api.datatypes import StopReason, ToolCall
+from openai.types.chat import ChatCompletionMessageToolCall
 from pydantic import BaseModel
 
 from llama_stack.apis.common.content_types import (
@@ -408,3 +410,38 @@ async def convert_message_to_openai_dict(message: Message, download: bool = Fals
         "role": message.role,
         "content": content,
     }
+
+
+class UnparseableToolCall(BaseModel):
+    """
+    A ToolCall with arguments that are not valid JSON.
+    Mirrors the ToolCall schema, but with arguments as a string.
+    """
+
+    call_id: str = ""
+    tool_name: str = ""
+    arguments: str = ""
+
+
+def convert_tool_call(
+    tool_call: ChatCompletionMessageToolCall,
+) -> Union[ToolCall, UnparseableToolCall]:
+    """
+    Convert a ChatCompletionMessageToolCall tool call to either a
+    ToolCall or UnparseableToolCall. Returns an UnparseableToolCall
+    if the tool call is not valid JSON.
+    """
+    try:
+        arguments = json.loads(tool_call.function.arguments)
+    except Exception as e:
+        return UnparseableToolCall(
+            call_id=tool_call.id or "",
+            tool_name=tool_call.function.name or "",
+            arguments=tool_call.function.arguments or "",
+        )
+
+    return ToolCall(
+        call_id=tool_call.id,
+        tool_name=tool_call.function.name,
+        arguments=arguments,
+    )
```
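And the fallback path: when the accumulated (or single-chunk) arguments are not valid JSON, `convert_tool_call` hands back an `UnparseableToolCall` instead of raising, so callers can surface the raw string. A sketch using the OpenAI SDK's own types (import path as in openai>=1.x); the truncated arguments value is invented for illustration:

```python
# Sketch: malformed JSON arguments fall back to UnparseableToolCall.
from openai.types.chat.chat_completion_message_tool_call import (
    ChatCompletionMessageToolCall,
    Function,
)

from llama_stack.providers.utils.inference.openai_compat import (
    UnparseableToolCall,
    convert_tool_call,
)

bad = ChatCompletionMessageToolCall(
    id="call_2",
    type="function",
    function=Function(name="get_weather", arguments='{"location": "San Fra'),  # truncated JSON
)
result = convert_tool_call(bad)
assert isinstance(result, UnparseableToolCall)
print(result.arguments)  # the raw, unparsed string: {"location": "San Fra
```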