diff --git a/llama_stack/providers/inline/inference/vllm/vllm.py b/llama_stack/providers/inline/inference/vllm/vllm.py index 93f5cb56b..986eb8068 100644 --- a/llama_stack/providers/inline/inference/vllm/vllm.py +++ b/llama_stack/providers/inline/inference/vllm/vllm.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import datetime import json import logging import re @@ -17,9 +16,6 @@ import llama_models.sku_list # fully-qualified names import vllm.entrypoints.openai.protocol import vllm.sampling_params - -############################################################################ -# llama_models imports go here from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.datatypes import ( SamplingParams, @@ -31,17 +27,12 @@ from llama_models.llama3.api.datatypes import ( ) from llama_models.llama3.api.tokenizer import Tokenizer -############################################################################ -# vLLM imports go here -# # We deep-import the names that don't conflict with Llama Stack names from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_engine import BaseModelPath -############################################################################ -# llama_stack imports go here from llama_stack.apis.common.content_types import ( InterleavedContent, TextDelta, @@ -78,15 +69,13 @@ from llama_stack.providers.utils.inference.model_registry import ( ModelsProtocolPrivate, ) from llama_stack.providers.utils.inference.openai_compat import get_stop_reason +from llama_stack.providers.utils.inference.prompt_adapter import ( + chat_completion_request_to_prompt, +) -############################################################################ -# Package-local imports go here from .config import VLLMConfig from .openai_utils import llama_stack_chat_completion_to_openai_chat_completion_dict -############################################################################ -# Constants go here - # Map from Hugging Face model architecture name to appropriate tool parser. # See vllm.entrypoints.openai.tool_parsers.ToolParserManager.tool_parsers for the full list of # available parsers. @@ -98,35 +87,15 @@ CONFIG_TYPE_TO_TOOL_PARSER = { } DEFAULT_TOOL_PARSER = "pythonic" -############################################################################ -# Package-global variables go here logger = logging.getLogger(__name__) -############################################################################ -# Local functions go here - -# For debugging stuff when the Llama Stack logger isn't cooperating -_BYPASS_LOGGING = False - - -def _log(msg: str, level: str): - if _BYPASS_LOGGING: - time_str = datetime.datetime.now().strftime("%H:%M:%S") - print(f"{time_str}: {msg}") - match level: - case "info": - logger.info(msg) - case "debug": - logger.debug(msg) - - -def _info(msg: str): - _log(msg, "info") - - -def _debug(msg: str): - _log(msg, "debug") +# Adjust logging parameters from Python code. This appears to be the standard way to control +# logging in Llama Stack. 
+logger.setLevel(logging.DEBUG)
+stderr_handler = logging.StreamHandler()
+stderr_handler.setFormatter(logging.Formatter("%(asctime)s: %(filename)s [%(levelname)s] %(message)s"))
+logger.addHandler(stderr_handler)
 
 
 def _random_uuid_str() -> str:
@@ -210,10 +179,6 @@ def _convert_sampling_params(
     return vllm_sampling_params
 
 
-############################################################################
-# Class definitions go here
-
-
 class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
     """
     vLLM-based inference model adapter for Llama Stack with support for multiple models.
@@ -227,12 +192,11 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
     resolved_model_id: str | None
     engine: AsyncLLMEngine | None
     chat: OpenAIServingChat | None
+    is_meta_llama_model: bool
 
     def __init__(self, config: VLLMConfig):
         self.config = config
-        self.engine = None
-
-        _info(f"Config is: {self.config}")
+        logger.info(f"Config is: {self.config}")
 
         self.register_helper = ModelRegistryHelper(build_model_aliases())
         self.formatter = ChatFormat(Tokenizer.get_instance())
@@ -242,6 +206,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         self.model_ids = set()
         self.engine = None
         self.chat = None
+        self.is_meta_llama_model = False
 
     ###########################################################################
     # METHODS INHERITED FROM IMPLICIT BASE CLASS.
@@ -264,7 +229,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         Callback that apparently is invoked when shutting down the Llama Stack server. Not sure how
         to shut down a Llama Stack server in such a way as to trigger this callback.
         """
-        _info(f"Shutting down inline vLLM inference provider {self}.")
+        logger.info(f"Shutting down inline vLLM inference provider {self}.")
         if self.engine is not None:
             self.engine.shutdown_background_loop()
             self.engine = None
@@ -287,18 +252,23 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         :returns: The input ``Model`` object. It may or may not be permissible to change fields
            before returning this object.
         """
-        _debug(f"In register_model({model})")
+        logger.debug(f"In register_model({model})")
 
         # First attempt to interpret the model coordinates as a Llama model name
         resolved_llama_model = resolve_model(model.provider_model_id)
         if resolved_llama_model is not None:
             # Load from Hugging Face repo into default local cache dir
             resolved_model_id = resolved_llama_model.huggingface_repo
+
+            # Detect a genuine Meta Llama model to trigger Meta-specific preprocessing.
+            # Don't set self.is_meta_llama_model until we actually load the model.
+            is_meta_llama_model = True
         else:  # if resolved_llama_model is None
             # Not a Llama model name. Pass the model id through to vLLM's loader
             resolved_model_id = model.provider_model_id
+            is_meta_llama_model = False
 
-        _info(f"Model id {model} resolved to {resolved_model_id}")
+        logger.info(f"Model id {model} resolved to {resolved_model_id}")
 
         if self.resolved_model_id is not None:
             if resolved_model_id != self.resolved_model_id:
@@ -312,7 +282,8 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
             self.model_ids.add(model.model_id)
             return model
 
-        _info(f"Preloading model: {resolved_model_id}")
+        self.is_meta_llama_model = is_meta_llama_model
+        logger.info(f"Preloading model: {resolved_model_id}")
 
         # If we get here, this is the first time registering a model.
         # Preload so that the first inference request won't time out.
@@ -327,10 +298,9 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate): ) self.engine = AsyncLLMEngine.from_engine_args(engine_args) - # vLLM currently requires the user to specify the tool parser - # manually. To choose a tool parser, we need to determine what - # model architecture is being used. For now, we infer that - # information from what config class the model uses. + # vLLM currently requires the user to specify the tool parser manually. To choose a tool + # parser, we need to determine what model architecture is being used. For now, we infer + # that information from what config class the model uses. low_level_model_config = self.engine.engine.get_model_config() hf_config = low_level_model_config.hf_config hf_config_class_name = hf_config.__class__.__name__ @@ -340,8 +310,8 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate): # No info -- choose a default so we can at least attempt tool # use. tool_parser = DEFAULT_TOOL_PARSER - _debug(f"{hf_config_class_name=}") - _debug(f"{tool_parser=}") + logger.debug(f"{hf_config_class_name=}") + logger.debug(f"{tool_parser=}") # Wrap the lower-level engine in an OpenAI-compatible chat API model_config = await self.engine.get_model_config() @@ -364,7 +334,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate): self.resolved_model_id = resolved_model_id self.model_ids.add(model.model_id) - _info(f"Finished preloading model: {resolved_model_id}") + logger.info(f"Finished preloading model: {resolved_model_id}") return model @@ -415,7 +385,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate): converted_sampling_params = _convert_sampling_params(sampling_params, response_format, logprobs) - _debug(f"{converted_sampling_params=}") + logger.debug(f"{converted_sampling_params=}") if stream: return self._streaming_completion(content, converted_sampling_params) @@ -429,6 +399,85 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate): logprobs=streaming_result.logprobs, ) + async def embeddings( + self, + model_id: str, + contents: List[InterleavedContent], # type: ignore + ) -> EmbeddingsResponse: + raise NotImplementedError() + + async def chat_completion( + self, + model_id: str, + messages: List[Message], # type: ignore + sampling_params: Optional[SamplingParams] = SamplingParams(), + response_format: Optional[ResponseFormat] = None, # type: ignore + tools: Optional[List[ToolDefinition]] = None, + tool_choice: Optional[ToolChoice] = ToolChoice.auto, + tool_prompt_format: Optional[ToolPromptFormat] = None, + stream: Optional[bool] = False, + logprobs: Optional[LogProbConfig] = None, + ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]: + if model_id not in self.model_ids: + raise ValueError( + f"This adapter is not registered to model id '{model_id}'. Registered IDs are: {self.model_ids}" + ) + + # Convert to Llama Stack internal format for consistency + request = ChatCompletionRequest( + model=self.resolved_model_id, + messages=messages, + sampling_params=sampling_params, + response_format=response_format, + tools=tools, + tool_choice=tool_choice, + tool_prompt_format=tool_prompt_format, + stream=stream, + logprobs=logprobs, + ) + + if self.is_meta_llama_model: + # Bypass vLLM chat templating layer for Meta Llama models, because the + # templating layer in Llama Stack currently produces better results. + logger.debug( + f"Routing {self.resolved_model_id} chat completion through " + f"Llama Stack's templating layer instead of vLLM's." 
+ ) + if stream: + # return self._chat_completion_for_meta_llama_streaming(request) + pass # Use vLLM until the above method is implemented. + else: + return await self._chat_completion_for_meta_llama_non_streaming(request) + + logger.debug(f"{self.resolved_model_id} is not a Meta Llama model") + + # Arguments to the vLLM call must be packaged as a ChatCompletionRequest dataclass. + # Note that this dataclass has the same name as a similar dataclass in Llama Stack. + request_options = await llama_stack_chat_completion_to_openai_chat_completion_dict(request) + chat_completion_request = vllm.entrypoints.openai.protocol.ChatCompletionRequest(**request_options) + + logger.debug(f"Converted request: {chat_completion_request}") + + vllm_result = await self.chat.create_chat_completion(chat_completion_request) + logger.debug(f"Result from vLLM: {vllm_result}") + if isinstance(vllm_result, vllm.entrypoints.openai.protocol.ErrorResponse): + raise ValueError(f"Error from vLLM layer: {vllm_result}") + + # Return type depends on "stream" argument + if stream: + if not isinstance(vllm_result, AsyncGenerator): + raise TypeError(f"Unexpected result type {type(vllm_result)} for streaming inference call") + # vLLM client returns a stream of strings, which need to be parsed. + # Stream comes in the form of an async generator. + return self._convert_streaming_results(vllm_result) + else: + if not isinstance(vllm_result, vllm.entrypoints.openai.protocol.ChatCompletionResponse): + raise TypeError(f"Unexpected result type {type(vllm_result)} for non-streaming inference call") + return self._convert_non_streaming_results(vllm_result) + + ########################################################################### + # INTERNAL METHODS + async def _streaming_completion( self, content: str, sampling_params: vllm.SamplingParams ) -> AsyncIterator[CompletionResponseStreamChunk]: @@ -500,68 +549,48 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate): logprobs=logprobs, ) - async def embeddings( - self, - model_id: str, - contents: List[InterleavedContent], # type: ignore - ) -> EmbeddingsResponse: - raise NotImplementedError() + async def _chat_completion_for_meta_llama_non_streaming( + self, request: ChatCompletionRequest + ) -> ChatCompletionResponse: + """ + Subroutine that routes chat completions for Meta Llama models through Llama Stack's + chat template instead of using vLLM's version of that template. The Llama Stack version + of the chat template currently produces more reliable outputs. - async def chat_completion( - self, - model_id: str, - messages: List[Message], # type: ignore - sampling_params: Optional[SamplingParams] = None, - response_format: Optional[ResponseFormat] = None, # type: ignore - tools: Optional[List[ToolDefinition]] = None, - tool_choice: Optional[ToolChoice] = ToolChoice.auto, - tool_prompt_format: Optional[ToolPromptFormat] = None, - stream: Optional[bool] = False, - logprobs: Optional[LogProbConfig] = None, - ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]: - if model_id not in self.model_ids: - raise ValueError( - f"This adapter is not registered to model id '{model_id}'. Registered IDs are: {self.model_ids}" - ) + Once vLLM's support for Meta Llama models has matured more, we should consider routing + Meta Llama requests through the vLLM chat completions API instead of using this method. + """ - # Arguments to the vLLM call must be packaged as a ChatCompletionRequest dataclass. 
- # Note that this dataclass has the same name as a similar dataclass in Llama Stack. - request_options = await llama_stack_chat_completion_to_openai_chat_completion_dict( - ChatCompletionRequest( - model=self.resolved_model_id, - messages=messages, - sampling_params=sampling_params, - response_format=response_format, - tools=tools, - tool_choice=tool_choice, - tool_prompt_format=tool_prompt_format, - stream=stream, - logprobs=logprobs, - ) + logger.debug("Routing request through Llama Stack templates.") + + formatter = ChatFormat(Tokenizer.get_instance()) + + # Note that this function call modifies `request` in place. + prompt = await chat_completion_request_to_prompt(request, self.resolved_model_id, formatter) + + model_id = list(self.model_ids)[0] # Any model ID will do here + completion_response = await self.completion( + model_id=model_id, + content=prompt, + sampling_params=request.sampling_params, + response_format=request.response_format, + stream=False, + logprobs=request.logprobs, ) - chat_completion_request = vllm.entrypoints.openai.protocol.ChatCompletionRequest(**request_options) + if not isinstance(completion_response, CompletionResponse): # Sanity check + raise TypeError(f"Unexpected type '{type(completion_response)}' for completion response.") - _debug(f"Converted request: {chat_completion_request}") - - vllm_result = await self.chat.create_chat_completion(chat_completion_request) - _debug(f"Result from vLLM: {vllm_result}") - if isinstance(vllm_result, vllm.entrypoints.openai.protocol.ErrorResponse): - raise ValueError(f"Error from vLLM layer: {vllm_result}") - - # Return type depends on "stream" argument - if stream: - if not isinstance(vllm_result, AsyncGenerator): - raise TypeError(f"Unexpected result type {type(vllm_result)} for streaming inference call") - # vLLM client returns a stream of strings, which need to be parsed. - # Stream comes in the form of an async generator. - return self._convert_streaming_results(vllm_result) - else: - if not isinstance(vllm_result, vllm.entrypoints.openai.protocol.ChatCompletionResponse): - raise TypeError(f"Unexpected result type {type(vllm_result)} for non-streaming inference call") - return self._convert_non_streaming_results(vllm_result) - - ########################################################################### - # INTERNAL METHODS + raw_message = formatter.decode_assistant_message_from_content( + completion_response.content, completion_response.stop_reason + ) + return ChatCompletionResponse( + completion_message=CompletionMessage( + content=raw_message.content, + stop_reason=raw_message.stop_reason, + tool_calls=raw_message.tool_calls, + ), + logprobs=completion_response.logprobs, + ) def _convert_non_streaming_results( self, vllm_result: vllm.entrypoints.openai.protocol.ChatCompletionResponse @@ -599,12 +628,25 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate): # TODO: Convert logprobs - _debug(f"Converted message: {converted_message}") + logger.debug(f"Converted message: {converted_message}") return ChatCompletionResponse( completion_message=converted_message, ) + def _chat_completion_for_meta_llama_streaming(self, request: ChatCompletionRequest) -> AsyncIterator: + """ + Subroutine that routes chat completions for Meta Llama models through Llama Stack's + chat template instead of using vLLM's version of that template. The Llama Stack version + of the chat template currently produces more reliable outputs. 
+ + Once vLLM's support for Meta Llama models has matured more, we should consider routing + Meta Llama requests through the vLLM chat completions API instead of using this method. + """ + logger.debug("Routing streaming request through Llama Stack templates.") + + raise NotImplementedError() + async def _convert_streaming_results(self, vllm_result: AsyncIterator) -> AsyncIterator: """ Subroutine that wraps the streaming outputs of vLLM's OpenAI-compatible @@ -653,7 +695,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate): # Anything that is not "[DONE]" should be a JSON record parsed_chunk = json.loads(data_str) - _debug(f"Parsed JSON event to:\n{json.dumps(parsed_chunk, indent=2)}") + logger.debug(f"Parsed JSON event to:\n{json.dumps(parsed_chunk, indent=2)}") # The result may contain multiple completions, but Llama Stack APIs only support # returning one.