diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index 482e6fa97..8618abccf 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -3,10 +3,11 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-
+import json
 import logging
 from typing import AsyncGenerator, List, Optional, Union
 
+from llama_models.llama3.api import StopReason, ToolCall
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.sku_list import all_registered_models
@@ -30,6 +31,7 @@ from llama_stack.apis.inference import (
     ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
+    CompletionMessage,
 )
 from llama_stack.apis.models import Model, ModelType
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
@@ -40,7 +42,6 @@ from llama_stack.providers.utils.inference.model_registry import (
 from llama_stack.providers.utils.inference.openai_compat import (
     convert_message_to_openai_dict,
     get_sampling_options,
-    process_chat_completion_response,
     process_chat_completion_stream_response,
     process_completion_response,
     process_completion_stream_response,
@@ -68,6 +69,73 @@ def build_model_aliases():
     ]
 
 
+def _convert_to_vllm_tool_calls_in_response(
+    tool_calls,
+) -> List[ToolCall]:
+    if not tool_calls:
+        return []
+
+    call_function_arguments = None
+    for call in tool_calls:
+        call_function_arguments = json.loads(call.function.arguments)
+
+    return [
+        ToolCall(
+            call_id=call.id,
+            tool_name=call.function.name,
+            arguments=call_function_arguments,
+        )
+        for call in tool_calls
+    ]
+
+
+def _convert_to_vllm_tools_in_request(tools: List[ToolDefinition]) -> List[dict]:
+    if tools is None:
+        return tools
+
+    compat_tools = []
+
+    for tool in tools:
+        properties = {}
+        compat_required = []
+        if tool.parameters:
+            for tool_key, tool_param in tool.parameters.items():
+                properties[tool_key] = {"type": tool_param.param_type}
+                if tool_param.description:
+                    properties[tool_key]["description"] = tool_param.description
+                if tool_param.default:
+                    properties[tool_key]["default"] = tool_param.default
+                if tool_param.required:
+                    compat_required.append(tool_key)
+
+        compat_tool = {
+            "type": "function",
+            "function": {
+                "name": tool.tool_name,
+                "description": tool.description,
+                "parameters": {
+                    "type": "object",
+                    "properties": properties,
+                    "required": compat_required,
+                },
+            },
+        }
+
+        compat_tools.append(compat_tool)
+
+    if len(compat_tools) > 0:
+        return compat_tools
+    return None
+
+
+def _convert_to_vllm_finish_reason(finish_reason: str) -> StopReason:
+    return {
+        "stop": StopReason.end_of_turn,
+        "length": StopReason.out_of_tokens,
+        "tool_calls": StopReason.end_of_message,
+    }.get(finish_reason, StopReason.end_of_turn)
+
+
 class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
         self.register_helper = ModelRegistryHelper(build_model_aliases())
@@ -142,7 +210,16 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     ) -> ChatCompletionResponse:
         params = await self._get_params(request)
         r = client.chat.completions.create(**params)
-        return process_chat_completion_response(r, self.formatter)
+        choice = r.choices[0]
+        result = ChatCompletionResponse(
+            completion_message=CompletionMessage(
+                content=choice.message.content or "",
+                stop_reason=_convert_to_vllm_finish_reason(choice.finish_reason),
+                tool_calls=_convert_to_vllm_tool_calls_in_response(choice.message.tool_calls),
+            ),
+            logprobs=None,
+        )
+        return result
 
     async def _stream_chat_completion(self, request: ChatCompletionRequest, client: OpenAI) -> AsyncGenerator:
         params = await self._get_params(request)
@@ -193,6 +270,8 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             options["max_tokens"] = self.config.max_tokens
 
         input_dict = {}
+        if isinstance(request, ChatCompletionRequest) and request.tools is not None:
+            input_dict = {"tools": _convert_to_vllm_tools_in_request(request.tools)}
         if isinstance(request, ChatCompletionRequest):
             input_dict["messages"] = [await convert_message_to_openai_dict(m, download=True) for m in request.messages]
diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index a3e893d8f..8ee838d84 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -174,6 +174,8 @@ def process_chat_completion_response(
 ) -> ChatCompletionResponse:
     choice = response.choices[0]
 
+    # TODO: This does not work well with tool calls for vLLM remote provider
+    # Ref: https://github.com/meta-llama/llama-stack/issues/1058
     raw_message = formatter.decode_assistant_message_from_content(
         text_from_choice(choice), get_stop_reason(choice.finish_reason)
     )