mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
feat: Support tool calling for non-streaming chat completion in remote vLLM provider (#1034)
# What does this PR do?

This PR adds support for tool calling in non-streaming chat completion. Prior to this, tool calls were not passed to chat completion requests, and the tools object needed to be restructured to be compatible with the vLLM provider.

## Test Plan

```
LLAMA_STACK_BASE_URL=http://localhost:5002 pytest -v tests/client-sdk/inference/test_text_inference.py
================================================================= test session starts =================================================================
platform linux -- Python 3.10.16, pytest-8.3.4, pluggy-1.5.0 -- /home/yutang/.conda/envs/distribution-myenv/bin/python3.10
cachedir: .pytest_cache
rootdir: /home/yutang/repos/llama-stack
configfile: pyproject.toml
plugins: anyio-4.8.0
collected 12 items

tests/client-sdk/inference/test_text_inference.py::test_text_completion_non_streaming[meta-llama/Llama-3.1-8B-Instruct] PASSED [ 8%]
tests/client-sdk/inference/test_text_inference.py::test_text_completion_streaming[meta-llama/Llama-3.1-8B-Instruct] PASSED [ 16%]
tests/client-sdk/inference/test_text_inference.py::test_completion_log_probs_non_streaming[meta-llama/Llama-3.1-8B-Instruct] XFAIL (remote:...) [ 25%]
tests/client-sdk/inference/test_text_inference.py::test_completion_log_probs_streaming[meta-llama/Llama-3.1-8B-Instruct] XFAIL (remote::vll...) [ 33%]
tests/client-sdk/inference/test_text_inference.py::test_text_completion_structured_output[meta-llama/Llama-3.1-8B-Instruct] PASSED [ 41%]
tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_non_streaming[meta-llama/Llama-3.1-8B-Instruct-Which planet do humans live on?-Earth] PASSED [ 50%]
tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_non_streaming[meta-llama/Llama-3.1-8B-Instruct-Which planet has rings around it with a name starting with letter S?-Saturn] PASSED [ 58%]
tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_streaming[meta-llama/Llama-3.1-8B-Instruct-What's the name of the Sun in latin?-Sol] PASSED [ 66%]
tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_streaming[meta-llama/Llama-3.1-8B-Instruct-What is the name of the US captial?-Washington] PASSED [ 75%]
tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_with_tool_calling_and_non_streaming[meta-llama/Llama-3.1-8B-Instruct] PASSED [ 83%]
tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_with_tool_calling_and_streaming[meta-llama/Llama-3.1-8B-Instruct] FAILED [ 91%]
tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_structured_output[meta-llama/Llama-3.1-8B-Instruct] PASSED [100%]
```

---------

Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>
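For orientation, here is a minimal sketch (not part of this PR) of the kind of non-streaming, tool-calling request that the remote vLLM adapter now forwards to a vLLM server's OpenAI-compatible endpoint. The server URL and the `get_weather` tool are hypothetical; the `tools` payload mirrors the structure produced by the `_convert_to_vllm_tools_in_request` helper added in the diff below.

```python
# Hypothetical illustration only: a non-streaming chat completion with one tool,
# sent to a vLLM OpenAI-compatible endpoint (URL, API key, and tool are made up).
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="none")

response = client.chat.completions.create(
    model="meta-llama/Llama-3.1-8B-Instruct",
    messages=[{"role": "user", "content": "What is the weather in Tokyo?"}],
    tools=[
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Get the current weather for a city",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
        }
    ],
    stream=False,
)

# The adapter converts choices[0].message.tool_calls back into llama-stack ToolCall objects.
print(response.choices[0].message.tool_calls)
```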
This commit is contained in:
parent 24385cfd03
commit dd37e58868
2 changed files with 84 additions and 3 deletions
```diff
@@ -3,10 +3,11 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import json
 import logging
 from typing import AsyncGenerator, List, Optional, Union
 
+from llama_models.llama3.api import StopReason, ToolCall
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.sku_list import all_registered_models
@@ -30,6 +31,7 @@ from llama_stack.apis.inference import (
     ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
+    CompletionMessage,
 )
 from llama_stack.apis.models import Model, ModelType
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
@@ -40,7 +42,6 @@ from llama_stack.providers.utils.inference.model_registry import (
 from llama_stack.providers.utils.inference.openai_compat import (
     convert_message_to_openai_dict,
     get_sampling_options,
-    process_chat_completion_response,
     process_chat_completion_stream_response,
     process_completion_response,
     process_completion_stream_response,
```
```diff
@@ -68,6 +69,73 @@ def build_model_aliases():
     ]
 
 
+def _convert_to_vllm_tool_calls_in_response(
+    tool_calls,
+) -> List[ToolCall]:
+    if not tool_calls:
+        return []
+
+    call_function_arguments = None
+    for call in tool_calls:
+        call_function_arguments = json.loads(call.function.arguments)
+
+    return [
+        ToolCall(
+            call_id=call.id,
+            tool_name=call.function.name,
+            arguments=call_function_arguments,
+        )
+        for call in tool_calls
+    ]
+
+
+def _convert_to_vllm_tools_in_request(tools: List[ToolDefinition]) -> List[dict]:
+    if tools is None:
+        return tools
+
+    compat_tools = []
+
+    for tool in tools:
+        properties = {}
+        compat_required = []
+        if tool.parameters:
+            for tool_key, tool_param in tool.parameters.items():
+                properties[tool_key] = {"type": tool_param.param_type}
+                if tool_param.description:
+                    properties[tool_key]["description"] = tool_param.description
+                if tool_param.default:
+                    properties[tool_key]["default"] = tool_param.default
+                if tool_param.required:
+                    compat_required.append(tool_key)
+
+        compat_tool = {
+            "type": "function",
+            "function": {
+                "name": tool.tool_name,
+                "description": tool.description,
+                "parameters": {
+                    "type": "object",
+                    "properties": properties,
+                    "required": compat_required,
+                },
+            },
+        }
+
+        compat_tools.append(compat_tool)
+
+    if len(compat_tools) > 0:
+        return compat_tools
+    return None
+
+
+def _convert_to_vllm_finish_reason(finish_reason: str) -> StopReason:
+    return {
+        "stop": StopReason.end_of_turn,
+        "length": StopReason.out_of_tokens,
+        "tool_calls": StopReason.end_of_message,
+    }.get(finish_reason, StopReason.end_of_turn)
+
+
 class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
         self.register_helper = ModelRegistryHelper(build_model_aliases())
```
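As a sanity check on the hunk above, the sketch below shows what `_convert_to_vllm_tools_in_request` emits for a single hypothetical tool. It assumes `ToolParamDefinition` is importable from `llama_stack.apis.inference` alongside `ToolDefinition`, and that the helper is in scope (e.g., run inside the module where it is defined).

```python
# Hypothetical usage of the helper added above; the weather tool is made up.
from llama_stack.apis.inference import ToolDefinition, ToolParamDefinition  # assumed import path

weather_tool = ToolDefinition(
    tool_name="get_weather",
    description="Get the current weather for a city",
    parameters={
        "city": ToolParamDefinition(
            param_type="string",
            description="Name of the city",
            required=True,
        ),
    },
)

print(_convert_to_vllm_tools_in_request([weather_tool]))
# Expected shape, per the helper's logic:
# [{'type': 'function',
#   'function': {'name': 'get_weather',
#                'description': 'Get the current weather for a city',
#                'parameters': {'type': 'object',
#                               'properties': {'city': {'type': 'string',
#                                                       'description': 'Name of the city'}},
#                               'required': ['city']}}}]
```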
```diff
@@ -142,7 +210,16 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     ) -> ChatCompletionResponse:
         params = await self._get_params(request)
         r = client.chat.completions.create(**params)
-        return process_chat_completion_response(r, self.formatter)
+        choice = r.choices[0]
+        result = ChatCompletionResponse(
+            completion_message=CompletionMessage(
+                content=choice.message.content or "",
+                stop_reason=_convert_to_vllm_finish_reason(choice.finish_reason),
+                tool_calls=_convert_to_vllm_tool_calls_in_response(choice.message.tool_calls),
+            ),
+            logprobs=None,
+        )
+        return result
 
     async def _stream_chat_completion(self, request: ChatCompletionRequest, client: OpenAI) -> AsyncGenerator:
         params = await self._get_params(request)
```
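Similarly, here is a small self-contained sketch of the response-side conversion exercised by the hunk above, using `SimpleNamespace` stand-ins for the OpenAI client's tool-call objects (the call id and arguments are invented).

```python
# Hypothetical stand-in for an OpenAI-style tool call returned by vLLM.
import json
from types import SimpleNamespace

fake_tool_call = SimpleNamespace(
    id="call_abc123",
    function=SimpleNamespace(
        name="get_weather",
        arguments=json.dumps({"city": "Tokyo"}),
    ),
)

# Assumes _convert_to_vllm_tool_calls_in_response (added above) is in scope.
converted = _convert_to_vllm_tool_calls_in_response([fake_tool_call])
print(converted[0].tool_name, converted[0].arguments)
# Expected: get_weather {'city': 'Tokyo'}
```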
```diff
@@ -193,6 +270,8 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             options["max_tokens"] = self.config.max_tokens
 
         input_dict = {}
+        if isinstance(request, ChatCompletionRequest) and request.tools is not None:
+            input_dict = {"tools": _convert_to_vllm_tools_in_request(request.tools)}
 
         if isinstance(request, ChatCompletionRequest):
             input_dict["messages"] = [await convert_message_to_openai_dict(m, download=True) for m in request.messages]
```
```diff
@@ -174,6 +174,8 @@ def process_chat_completion_response(
 ) -> ChatCompletionResponse:
     choice = response.choices[0]
 
+    # TODO: This does not work well with tool calls for vLLM remote provider
+    # Ref: https://github.com/meta-llama/llama-stack/issues/1058
     raw_message = formatter.decode_assistant_message_from_content(
         text_from_choice(choice), get_stop_reason(choice.finish_reason)
     )
```