diff --git a/llama_stack/providers/inline/inference/vllm/openai_utils.py b/llama_stack/providers/inline/inference/vllm/openai_utils.py
new file mode 100644
index 000000000..57144b999
--- /dev/null
+++ b/llama_stack/providers/inline/inference/vllm/openai_utils.py
@@ -0,0 +1,186 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import List, Optional
+
+import vllm
+
+from llama_models.llama3.api.datatypes import BuiltinTool, ToolDefinition
+
+from llama_stack.apis.inference import (
+    ChatCompletionRequest,
+    GrammarResponseFormat,
+    JsonSchemaResponseFormat,
+    Message,
+    ToolChoice,
+    UserMessage,
+)
+from llama_stack.providers.utils.inference.openai_compat import (
+    convert_message_to_openai_dict,
+    get_sampling_options,
+)
+
+
+###############################################################################
+# This file contains OpenAI compatibility code that is currently only used
+# by the inline vLLM connector. Some or all of this code may be moved to a
+# central location at a later date.
+
+
+def _merge_context_into_content(message: Message) -> Message:  # type: ignore
+    """
+    Merge the ``context`` field of a Llama Stack ``Message`` object into
+    the content field for compatibility with OpenAI-style APIs.
+
+    Generates a content string that emulates the current behavior
+    of ``llama_models.llama3.api.chat_format.encode_message()``.
+
+    :param message: Message that may include a ``context`` field
+
+    :returns: A version of ``message`` with any context merged into the
+        ``content`` field.
+    """
+    if not isinstance(message, UserMessage):  # Separate type check for linter
+        return message
+    if message.context is None:
+        return message
+    return UserMessage(
+        role=message.role,
+        # Emulate llama_models.llama3.api.chat_format.encode_message()
+        content=message.content + "\n\n" + message.context,
+        context=None,
+    )
+
+
+def _llama_stack_tools_to_openai_tools(
+    tools: Optional[List[ToolDefinition]] = None,
+) -> List[vllm.entrypoints.openai.protocol.ChatCompletionToolsParam]:
+    """
+    Convert the list of available tools from Llama Stack's format to vLLM's
+    version of OpenAI's format.
+    """
+    if tools is None:
+        return []
+
+    result = []
+    for t in tools:
+        if isinstance(t.tool_name, BuiltinTool):
+            raise NotImplementedError("Built-in tools not yet implemented")
+        if t.parameters is None:
+            parameters = None
+        else:  # if t.parameters is not None
+            # Convert the "required" flags to a list of required params
+            required_params = [k for k, v in t.parameters.items() if v.required]
+            parameters = {
+                "type": "object",  # Mystery value that shows up in OpenAI docs
+                "properties": {
+                    k: {"type": v.param_type, "description": v.description}
+                    for k, v in t.parameters.items()
+                },
+                "required": required_params,
+            }
+
+        function_def = vllm.entrypoints.openai.protocol.FunctionDefinition(
+            name=t.tool_name, description=t.description, parameters=parameters
+        )
+
+        # Every tool definition is double-boxed in a ChatCompletionToolsParam
+        result.append(
+            vllm.entrypoints.openai.protocol.ChatCompletionToolsParam(
+                function=function_def
+            )
+        )
+    return result
+
+
+async def llama_stack_chat_completion_to_openai_chat_completion_dict(
+    request: ChatCompletionRequest,
+) -> dict:
+    """
+    Convert a chat completion request in Llama Stack format into an
+    equivalent set of arguments to pass to an OpenAI-compatible
+    chat completions API.
+
+    :param request: Bundled request parameters in Llama Stack format.
+
+    :returns: Dictionary of key-value pairs to use as an initializer
+        for a dataclass or to be converted directly to JSON and sent
+        over the wire.
+    """
+
+    converted_messages = [
+        # This call is async, which makes the parent function async as well
+        await convert_message_to_openai_dict(
+            _merge_context_into_content(m), download=True
+        )
+        for m in request.messages
+    ]
+    converted_tools = _llama_stack_tools_to_openai_tools(request.tools)
+
+    # Llama will try to use built-in tools with no tool catalog, so don't enable
+    # tool choice unless at least one tool is enabled.
+    converted_tool_choice = "none"
+    if (
+        request.tool_choice == ToolChoice.auto
+        and request.tools is not None
+        and len(request.tools) > 0
+    ):
+        converted_tool_choice = "auto"
+
+    # TODO: Figure out what to do with the tool_prompt_format argument.
+    # Other connectors appear to drop it quietly.
+
+    # Use Llama Stack shared code to translate sampling parameters.
+    sampling_options = get_sampling_options(request.sampling_params)
+
+    # get_sampling_options() translates repetition penalties to an option that
+    # OpenAI's APIs don't know about.
+    # vLLM's OpenAI-compatible API also handles repetition penalties wrong.
+    # For now, translate repetition penalties into a format that vLLM's broken
+    # API will handle correctly. Two wrongs make a right...
+    if "repeat_penalty" in sampling_options:
+        del sampling_options["repeat_penalty"]
+    if (
+        request.sampling_params.repetition_penalty is not None
+        and request.sampling_params.repetition_penalty != 1.0
+    ):
+        sampling_options["repetition_penalty"] = (
+            request.sampling_params.repetition_penalty
+        )
+
+    # Convert a single response format into four different parameters, per
+    # the OpenAI spec
+    guided_decoding_options = dict()
+    if request.response_format is None:
+        # Use defaults
+        pass
+    elif isinstance(request.response_format, JsonSchemaResponseFormat):
+        guided_decoding_options["guided_json"] = request.response_format.json_schema
+    elif isinstance(request.response_format, GrammarResponseFormat):
+        guided_decoding_options["guided_grammar"] = request.response_format.bnf
+    else:
+        raise TypeError(
+            f"ResponseFormat object is of unexpected "
+            f"subtype '{type(request.response_format)}'"
+        )
+
+    logprob_options = dict()
+    if request.logprobs is not None:
+        logprob_options["logprobs"] = request.logprobs.top_k
+
+    # Marshal together all the arguments for a ChatCompletionRequest
+    request_options = {
+        "model": request.model,
+        "messages": converted_messages,
+        "tools": converted_tools,
+        "tool_choice": converted_tool_choice,
+        "stream": request.stream,
+    }
+    request_options.update(sampling_options)
+    request_options.update(guided_decoding_options)
+    request_options.update(logprob_options)
+
+    return request_options
diff --git a/llama_stack/providers/inline/inference/vllm/vllm.py b/llama_stack/providers/inline/inference/vllm/vllm.py
index 265c2ab78..bf95b88a9 100644
--- a/llama_stack/providers/inline/inference/vllm/vllm.py
+++ b/llama_stack/providers/inline/inference/vllm/vllm.py
@@ -16,6 +16,19 @@ from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
 import vllm.entrypoints.openai.protocol
 import vllm.sampling_params

+############################################################################
+# llama_models imports go here
+from llama_models.llama3.api.chat_format import ChatFormat
+from llama_models.llama3.api.datatypes import (
+    SamplingParams,
+    StopReason,
+    ToolDefinition,
+    ToolPromptFormat,
+    TopKSamplingStrategy,
+    TopPSamplingStrategy,
+)
+from llama_models.llama3.api.tokenizer import Tokenizer
+
 ############################################################################
 # vLLM imports go here
 #
@@ -33,6 +46,7 @@ from llama_stack.apis.common.content_types import (
     ToolCallDelta,
 )
 from llama_stack.apis.inference import (
+    ChatCompletionRequest,
     ChatCompletionResponse,
     ChatCompletionResponseEvent,
     ChatCompletionResponseEventType,
@@ -50,9 +64,6 @@ from llama_stack.apis.inference import (
     TokenLogProbs,
     ToolCall,
     ToolChoice,
-    ToolConfig,
-    ToolDefinition,
-    ToolPromptFormat,
 )
 from llama_stack.apis.models import Model
 from llama_stack.models.llama.llama3.chat_format import ChatFormat
@@ -64,27 +75,12 @@ from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
     ModelsProtocolPrivate,
 )
-from llama_stack.providers.utils.inference.openai_compat import (
-    GrammarResponseFormat,
-    Inference,
-    JsonSchemaResponseFormat,
-    LogProbConfig,
-    Message,
-    OpenAICompatCompletionChoice,
-    OpenAICompatCompletionResponse,
-    ResponseFormat,
-    ToolCall,
-    ToolChoice,
-    UserMessage,
-    convert_message_to_openai_dict,
-    get_sampling_options,
-    process_chat_completion_response,
-    process_chat_completion_stream_response,
-)
+from llama_stack.providers.utils.inference.openai_compat import get_stop_reason

 ############################################################################
 # Package-local imports go here
 from .config import VLLMConfig
+from .openai_utils import llama_stack_chat_completion_to_openai_chat_completion_dict

 ############################################################################
 # Constants go here
@@ -119,54 +115,12 @@ def _random_uuid_str() -> str:
     return str(uuid.uuid4().hex)


-def _merge_context_into_content(message: Message) -> Message:  # type: ignore
-    """
-    Merge the ``context`` field of a Llama Stack ``Message`` object into
-    the content field for compabilitiy with OpenAI-style APIs.
-
-    Generates a content string that emulates the current behavior
-    of ``llama_models.llama3.api.chat_format.encode_message()``.
-
-    :param message: Message that may include ``context`` field
-
-    :returns: A version of ``message`` with any context merged into the
-        ``content`` field.
-    """
-    if not isinstance(message, UserMessage):  # Separate type check for linter
-        return message
-    if message.context is None:
-        return message
-    return UserMessage(
-        role=message.role,
-        # Emumate llama_models.llama3.api.chat_format.encode_message()
-        content=message.content + "\n\n" + message.context,
-        context=None,
-    )
-
-
-def _convert_finish_reason(finish_reason: str | None) -> str | None:
-    """Convert an OpenAI "finish_reason" result to the equivalent
-    Llama Stack result code.
-    """
-    # This conversion is currently a wild guess.
-    if finish_reason is None:
-        return None
-    elif finish_reason == "stop":
-        return StopReason.end_of_turn
-    else:
-        return StopReason.out_of_tokens
-
-
 def _response_format_to_guided_decoding_params(
     response_format: Optional[ResponseFormat],  # type: ignore
 ) -> vllm.sampling_params.GuidedDecodingParams:
     """
-    Like Llama Stack, vLLM's OpenAI-compatible API also uses the name
-    "ResponseFormat" to describe the object that is a wrapper around
-    another object that is a wrapper around another object inside
-    someone else's constrained decoding library.
-    Here we translate from Llama Stack's wrapper code to vLLM's code
-    that does the same.
+    Translate constrained decoding parameters from Llama Stack's
+    format to vLLM's format.

     :param response_format: Llama Stack version of constrained decoding
         info. Can be ``None``, indicating no constraints.
@@ -244,42 +198,6 @@ def _convert_sampling_params(
     return vllm_sampling_params


-def _convert_tools(
-    tools: Optional[List[ToolDefinition]] = None,
-) -> List[vllm.entrypoints.openai.protocol.ChatCompletionToolsParam]:
-    """
-    Convert the list of available tools from Llama Stack's format to vLLM's
-    version of OpenAI's format.
-    """
-    if tools is None:
-        return []
-
-    result = []
-    for t in tools:
-        if isinstance(t.tool_name, BuiltinTool):
-            raise NotImplementedError("Built-in tools not yet implemented")
-        if t.parameters is None:
-            parameters = None
-        else:  # if t.parameters is not None
-            # Convert the "required" flags to a list of required params
-            required_params = [k for k, v in t.parameters.items() if v.required]
-            parameters = {
-                "type": "object",  # Mystery value that shows up in OpenAI docs
-                "properties": {
-                    k: {"type": v.param_type, "description": v.description} for k, v in t.parameters.items()
-                },
-                "required": required_params,
-            }
-
-        function_def = vllm.entrypoints.openai.protocol.FunctionDefinition(
-            name=t.tool_name, description=t.description, parameters=parameters
-        )
-
-        # Every tool definition is double-boxed in a ChatCompletionToolsParam
-        result.append(vllm.entrypoints.openai.protocol.ChatCompletionToolsParam(function=function_def))
-    return result
-
-
 ############################################################################
 # Class definitions go here

@@ -582,51 +500,20 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         # dataclass.
         # Note that this dataclass has the same name as a similar dataclass in
         # Llama Stack.
-        converted_messages = [
-            await convert_message_to_openai_dict(_merge_context_into_content(m), download=True) for m in messages
-        ]
-        converted_sampling_params = _convert_sampling_params(sampling_params, response_format, logprobs)
-        converted_tools = _convert_tools(tools)
-
-        # Llama will try to use built-in tools with no tool catalog, so don't enable
-        # tool choice unless at least one tool is enabled.
-        converted_tool_choice = "none"
-        if tool_choice == ToolChoice.auto and tools is not None and len(tools) > 0:
-            converted_tool_choice = "auto"
-
-        # TODO: Figure out what to do with the tool_prompt_format argument.
-        # Other connectors appear to drop it quietly.
-
-        chat_completion_request = vllm.entrypoints.openai.protocol.ChatCompletionRequest(
-            model=self.resolved_model_id,
-            messages=converted_messages,
-            tools=converted_tools,
-            tool_choice=converted_tool_choice,
-            stream=stream,
-        )
-
-        # vLLM's OpenAI-compatible APIs take sampling parameters as multiple
-        # keyword args instead of a vLLM SamplingParams object. Copy over
-        # all the parts that we currently convert from Llama Stack format.
-        for param_name in [
-            "max_tokens",
-            "temperature",
-            "top_p",
-            "top_k",
-            "repetition_penalty",
-        ]:
-            setattr(
-                chat_completion_request,
-                param_name,
-                getattr(converted_sampling_params, param_name),
+        request_options = await llama_stack_chat_completion_to_openai_chat_completion_dict(
+            ChatCompletionRequest(
+                model=self.resolved_model_id,
+                messages=messages,
+                sampling_params=sampling_params,
+                response_format=response_format,
+                tools=tools,
+                tool_choice=tool_choice,
+                tool_prompt_format=tool_prompt_format,
+                stream=stream,
+                logprobs=logprobs,
             )
-
-        # Guided decoding parameters are further broken out
-        if converted_sampling_params.guided_decoding is not None:
-            g = converted_sampling_params.guided_decoding
-            chat_completion_request.guided_json = g.json
-            chat_completion_request.guided_regex = g.regex
-            chat_completion_request.guided_grammar = g.grammar
+        )
+        chat_completion_request = vllm.entrypoints.openai.protocol.ChatCompletionRequest(**request_options)

         _info(f"Converted request: {chat_completion_request}")

@@ -668,12 +555,13 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         if len(vllm_result.choices) == 0:
             raise ValueError("Don't know how to convert response object without any responses")
         vllm_message = vllm_result.choices[0].message
+        vllm_finish_reason = vllm_result.choices[0].finish_reason

         converted_message = CompletionMessage(
             role=vllm_message.role,
             # Llama Stack API won't accept None for content field.
             content=("" if vllm_message.content is None else vllm_message.content),
-            stop_reason=_convert_finish_reason(vllm_result.choices[0].finish_reason),
+            stop_reason=get_stop_reason(vllm_finish_reason),
             tool_calls=[
                 ToolCall(
                     call_id=t.id,
@@ -746,7 +634,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
             # The result may contain multiple completions, but Llama Stack APIs
             # only support returning one.
             first_choice = parsed_chunk["choices"][0]
-            converted_stop_reason = _convert_finish_reason(first_choice["finish_reason"])
+            converted_stop_reason = get_stop_reason(first_choice["finish_reason"])
             delta_record = first_choice["delta"]

             if "content" in delta_record:
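For reviewers: below is a minimal, hypothetical usage sketch of the new `llama_stack_chat_completion_to_openai_chat_completion_dict` helper; it is not part of the patch. It assumes the module path added in this diff and that the `ChatCompletionRequest` fields omitted here (sampling params, tools, logprobs) have usable defaults; the model id and prompt are placeholders.

```python
# Hypothetical usage sketch for review purposes only. The model id and prompt
# are placeholders; imports follow the module layout introduced in this diff.
import asyncio

from llama_stack.apis.inference import ChatCompletionRequest, UserMessage
from llama_stack.providers.inline.inference.vllm.openai_utils import (
    llama_stack_chat_completion_to_openai_chat_completion_dict,
)


async def main() -> None:
    request = ChatCompletionRequest(
        model="example-model-id",  # placeholder model identifier
        messages=[UserMessage(content="Say hello in one sentence.")],
        stream=False,
    )
    # Convert the Llama Stack request into the keyword arguments expected by
    # vLLM's OpenAI-compatible ChatCompletionRequest dataclass.
    request_options = await llama_stack_chat_completion_to_openai_chat_completion_dict(request)
    print(request_options)


if __name__ == "__main__":
    asyncio.run(main())
```

The resulting dict is then splatted into `vllm.entrypoints.openai.protocol.ChatCompletionRequest(**request_options)`, as the updated code in `vllm.py` does.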