diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index 1124afc7f..924aa0e7d 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -3,10 +3,11 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-
+import json
 import logging
 from typing import AsyncGenerator, List, Optional, Union
 
+from llama_models.llama3.api import StopReason, ToolCall
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.sku_list import all_registered_models
@@ -30,6 +31,7 @@ from llama_stack.apis.inference import (
     ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
+    CompletionMessage,
 )
 from llama_stack.apis.models import Model, ModelType
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
@@ -40,7 +42,6 @@ from llama_stack.providers.utils.inference.model_registry import (
 from llama_stack.providers.utils.inference.openai_compat import (
     convert_message_to_openai_dict,
     get_sampling_options,
-    process_chat_completion_response,
     process_chat_completion_stream_response,
     process_completion_response,
     process_completion_stream_response,
@@ -68,6 +69,73 @@ def build_model_aliases():
     ]
 
 
+def convert_to_vllm_tool_calls_in_response(
+    tool_calls,
+) -> List[ToolCall]:
+    if not tool_calls:
+        return []
+
+    call_function_arguments = None
+    for call in tool_calls:
+        call_function_arguments = json.loads(call.function.arguments)
+
+    return [
+        ToolCall(
+            call_id=call.id,
+            tool_name=call.function.name,
+            arguments=call_function_arguments,
+        )
+        for call in tool_calls
+    ]
+
+
+def convert_to_vllm_tools_in_request(tools: List[ToolDefinition]) -> List[dict]:
+    if tools is None:
+        return tools
+
+    compat_tools = []
+
+    for tool in tools:
+        properties = {}
+        compat_required = []
+        if tool.parameters:
+            for tool_key, tool_param in tool.parameters.items():
+                properties[tool_key] = {"type": tool_param.param_type}
+                if tool_param.description:
+                    properties[tool_key]["description"] = tool_param.description
+                if tool_param.default:
+                    properties[tool_key]["default"] = tool_param.default
+                if tool_param.required:
+                    compat_required.append(tool_key)
+
+        compat_tool = {
+            "type": "function",
+            "function": {
+                "name": tool.tool_name,
+                "description": tool.description,
+                "parameters": {
+                    "type": "object",
+                    "properties": properties,
+                    "required": compat_required,
+                },
+            },
+        }
+
+        compat_tools.append(compat_tool)
+
+    if len(compat_tools) > 0:
+        return compat_tools
+    return None
+
+
+def convert_to_vllm_finish_reason(finish_reason: str) -> StopReason:
+    return {
+        "stop": StopReason.end_of_turn,
+        "length": StopReason.out_of_tokens,
+        "tool_calls": StopReason.end_of_message,
+    }.get(finish_reason, StopReason.end_of_turn)
+
+
 class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
         self.register_helper = ModelRegistryHelper(build_model_aliases())
@@ -142,7 +210,16 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     ) -> ChatCompletionResponse:
         params = await self._get_params(request)
         r = client.chat.completions.create(**params)
-        return process_chat_completion_response(r, self.formatter)
+        choice = r.choices[0]
+        result = ChatCompletionResponse(
+            completion_message=CompletionMessage(
+                content=choice.message.content or "",
+                stop_reason=convert_to_vllm_finish_reason(choice.finish_reason),
+                tool_calls=convert_to_vllm_tool_calls_in_response(choice.message.tool_calls),
+            ),
+            logprobs=None,
+        )
+        return result
 
     async def _stream_chat_completion(self, request: ChatCompletionRequest, client: OpenAI) -> AsyncGenerator:
         params = await self._get_params(request)
@@ -187,51 +264,12 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         )
         return model
 
-    def convert_to_vllm_tools(self, tools: List[ToolDefinition]) -> List[dict]:
-        if tools is None:
-            return tools
-
-        compat_tools = []
-
-        for tool in tools:
-            properties = {}
-            compat_required = []
-            if tool.parameters:
-                for tool_key, tool_param in tool.parameters.items():
-                    properties[tool_key] = {"type": tool_param.param_type}
-                    if tool_param.description:
-                        properties[tool_key]["description"] = tool_param.description
-                    if tool_param.default:
-                        properties[tool_key]["default"] = tool_param.default
-                    if tool_param.required:
-                        compat_required.append(tool_key)
-
-            compat_tool = {
-                "type": "function",
-                "function": {
-                    "name": tool.tool_name,
-                    "description": tool.description,
-                    "parameters": {
-                        "type": "object",
-                        "properties": properties,
-                        "required": compat_required,
-                    },
-                },
-            }
-
-            compat_tools.append(compat_tool)
-
-        if len(compat_tools) > 0:
-            return compat_tools
-        return None
-
     async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
         options = get_sampling_options(request.sampling_params)
         if "max_tokens" not in options:
             options["max_tokens"] = self.config.max_tokens
 
-        input_dict = {}
-        input_dict["tools"] = self.convert_to_vllm_tools(request.tools)
+        input_dict = {"tools": convert_to_vllm_tools_in_request(request.tools)}
 
         if isinstance(request, ChatCompletionRequest):
             input_dict["messages"] = [await convert_message_to_openai_dict(m, download=True) for m in request.messages]
diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index a3e893d8f..1388d14f2 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -174,6 +174,7 @@ def process_chat_completion_response(
 ) -> ChatCompletionResponse:
     choice = response.choices[0]
 
+    # TODO: This does not work well with tool calls (at least for vLLM remote)
     raw_message = formatter.decode_assistant_message_from_content(
         text_from_choice(choice), get_stop_reason(choice.finish_reason)
     )
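
For reference, a minimal sketch of how the new helpers map an OpenAI-style response into llama-stack types. The SimpleNamespace objects and the "get_weather" call are illustrative stand-ins for the objects returned by client.chat.completions.create(), not fixtures from this patch; only the helper names and the module path come from the diff above.

# Illustrative sketch only, assuming the helpers added in this patch are importable.
from types import SimpleNamespace

from llama_stack.providers.remote.inference.vllm.vllm import (
    convert_to_vllm_finish_reason,
    convert_to_vllm_tool_calls_in_response,
)

# Stand-in for one entry of choice.message.tool_calls in the OpenAI SDK response.
fake_tool_call = SimpleNamespace(
    id="call_123",
    function=SimpleNamespace(name="get_weather", arguments='{"city": "Paris"}'),
)

tool_calls = convert_to_vllm_tool_calls_in_response([fake_tool_call])
# -> [ToolCall(call_id="call_123", tool_name="get_weather", arguments={"city": "Paris"})]

stop_reason = convert_to_vllm_finish_reason("tool_calls")
# -> StopReason.end_of_message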