Update logging and route Meta Llama requests differently

Fred Reiss 2025-01-31 14:12:35 -08:00 committed by Ashwin Bharambe
parent 24cc7a777c
commit ade413f1e3


@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-import datetime
 import json
 import logging
 import re
@@ -17,9 +16,6 @@ import llama_models.sku_list
 # fully-qualified names
 import vllm.entrypoints.openai.protocol
 import vllm.sampling_params
-############################################################################
-# llama_models imports go here
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import (
     SamplingParams,
@@ -31,17 +27,12 @@ from llama_models.llama3.api.datatypes import (
 )
 from llama_models.llama3.api.tokenizer import Tokenizer
-############################################################################
-# vLLM imports go here
-#
 # We deep-import the names that don't conflict with Llama Stack names
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
 from vllm.entrypoints.openai.serving_engine import BaseModelPath
-############################################################################
-# llama_stack imports go here
 from llama_stack.apis.common.content_types import (
     InterleavedContent,
     TextDelta,
@@ -78,15 +69,13 @@ from llama_stack.providers.utils.inference.model_registry import (
     ModelsProtocolPrivate,
 )
 from llama_stack.providers.utils.inference.openai_compat import get_stop_reason
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    chat_completion_request_to_prompt,
+)
-############################################################################
-# Package-local imports go here
 from .config import VLLMConfig
 from .openai_utils import llama_stack_chat_completion_to_openai_chat_completion_dict
-############################################################################
-# Constants go here
 # Map from Hugging Face model architecture name to appropriate tool parser.
 # See vllm.entrypoints.openai.tool_parsers.ToolParserManager.tool_parsers for the full list of
 # available parsers.
@@ -98,35 +87,15 @@ CONFIG_TYPE_TO_TOOL_PARSER = {
 }
 DEFAULT_TOOL_PARSER = "pythonic"
-############################################################################
-# Package-global variables go here
 logger = logging.getLogger(__name__)
-############################################################################
-# Local functions go here
+
+# Adjust logging parameters from Python code. This appears to be the standard way to control
+# logging in Llama Stack.
+logger.setLevel(logging.DEBUG)
+stderr_handler = logging.StreamHandler()
+stderr_handler.setFormatter(logging.Formatter("%(asctime)s: %(filename)s [%(levelname)s] %(message)s"))
+logger.addHandler(stderr_handler)
-
-# For debugging stuff when the Llama Stack logger isn't cooperating
-_BYPASS_LOGGING = False
-
-def _log(msg: str, level: str):
-    if _BYPASS_LOGGING:
-        time_str = datetime.datetime.now().strftime("%H:%M:%S")
-        print(f"{time_str}: {msg}")
-    match level:
-        case "info":
-            logger.info(msg)
-        case "debug":
-            logger.debug(msg)
-
-def _info(msg: str):
-    _log(msg, "info")
-
-def _debug(msg: str):
-    _log(msg, "debug")
 
 def _random_uuid_str() -> str:
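The handler wiring introduced above is plain standard-library logging. A minimal, self-contained sketch of the same pattern (the logger name and log message here are illustrative; the provider itself uses logging.getLogger(__name__)):

    import logging

    logger = logging.getLogger("inline_vllm_example")
    logger.setLevel(logging.DEBUG)

    # Route records to stderr with a timestamped format, mirroring the handler attached above.
    stderr_handler = logging.StreamHandler()
    stderr_handler.setFormatter(logging.Formatter("%(asctime)s: %(filename)s [%(levelname)s] %(message)s"))
    logger.addHandler(stderr_handler)

    logger.debug("Handler attached; debug output now reaches stderr.")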
@@ -210,10 +179,6 @@ def _convert_sampling_params(
     return vllm_sampling_params
-
-############################################################################
-# Class definitions go here
 
 class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
     """
     vLLM-based inference model adapter for Llama Stack with support for multiple models.
@@ -227,12 +192,11 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
     resolved_model_id: str | None
     engine: AsyncLLMEngine | None
     chat: OpenAIServingChat | None
+    is_meta_llama_model: bool
 
     def __init__(self, config: VLLMConfig):
         self.config = config
-        self.engine = None
-        lo
-        _info(f"Config is: {self.config}")
+        logger.info(f"Config is: {self.config}")
         self.register_helper = ModelRegistryHelper(build_model_aliases())
         self.formatter = ChatFormat(Tokenizer.get_instance())
@@ -242,6 +206,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         self.model_ids = set()
         self.engine = None
         self.chat = None
+        self.is_meta_llama_model = False
 
     ###########################################################################
     # METHODS INHERITED FROM IMPLICIT BASE CLASS.
@@ -264,7 +229,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         Callback that apparently is invoked when shutting down the Llama Stack server. Not sure how
         to shut down a Llama Stack server in such a way as to trigger this callback.
         """
-        _info(f"Shutting down inline vLLM inference provider {self}.")
+        logger.info(f"Shutting down inline vLLM inference provider {self}.")
         if self.engine is not None:
             self.engine.shutdown_background_loop()
             self.engine = None
@@ -287,18 +252,23 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         :returns: The input ``Model`` object. It may or may not be permissible to change fields
             before returning this object.
         """
-        _debug(f"In register_model({model})")
+        logger.debug(f"In register_model({model})")
 
         # First attempt to interpret the model coordinates as a Llama model name
         resolved_llama_model = resolve_model(model.provider_model_id)
         if resolved_llama_model is not None:
             # Load from Hugging Face repo into default local cache dir
             resolved_model_id = resolved_llama_model.huggingface_repo
+
+            # Detect a genuine Meta Llama model to trigger Meta-specific preprocessing.
+            # Don't set self.is_meta_llama_model until we actually load the model.
+            is_meta_llama_model = True
         else:  # if resolved_llama_model is None
             # Not a Llama model name. Pass the model id through to vLLM's loader
             resolved_model_id = model.provider_model_id
+            is_meta_llama_model = False
 
-        _info(f"Model id {model} resolved to {resolved_model_id}")
+        logger.info(f"Model id {model} resolved to {resolved_model_id}")
 
         if self.resolved_model_id is not None:
             if resolved_model_id != self.resolved_model_id:
@@ -312,7 +282,8 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
             self.model_ids.add(model.model_id)
             return model
 
-        _info(f"Preloading model: {resolved_model_id}")
+        self.is_meta_llama_model = is_meta_llama_model
+        logger.info(f"Preloading model: {resolved_model_id}")
 
         # If we get here, this is the first time registering a model.
         # Preload so that the first inference request won't time out.
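The Meta Llama detection above relies on resolve_model() from llama_models.sku_list returning None for identifiers it does not recognize. A rough standalone sketch of that check; the identifiers are examples only, and the actual registry contents depend on the installed llama_models version:

    from llama_models.sku_list import resolve_model

    def looks_like_meta_llama(provider_model_id: str) -> bool:
        # resolve_model() returns a model descriptor for known Meta Llama names and None otherwise.
        return resolve_model(provider_model_id) is not None

    print(looks_like_meta_llama("Llama3.2-3B-Instruct"))  # expected: True
    print(looks_like_meta_llama("my-org/custom-model"))   # expected: False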
@@ -327,10 +298,9 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         )
         self.engine = AsyncLLMEngine.from_engine_args(engine_args)
 
-        # vLLM currently requires the user to specify the tool parser
-        # manually. To choose a tool parser, we need to determine what
-        # model architecture is being used. For now, we infer that
-        # information from what config class the model uses.
+        # vLLM currently requires the user to specify the tool parser manually. To choose a tool
+        # parser, we need to determine what model architecture is being used. For now, we infer
+        # that information from what config class the model uses.
         low_level_model_config = self.engine.engine.get_model_config()
         hf_config = low_level_model_config.hf_config
         hf_config_class_name = hf_config.__class__.__name__
@@ -340,8 +310,8 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
             # No info -- choose a default so we can at least attempt tool
             # use.
             tool_parser = DEFAULT_TOOL_PARSER
-        _debug(f"{hf_config_class_name=}")
-        _debug(f"{tool_parser=}")
+        logger.debug(f"{hf_config_class_name=}")
+        logger.debug(f"{tool_parser=}")
 
         # Wrap the lower-level engine in an OpenAI-compatible chat API
         model_config = await self.engine.get_model_config()
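The tool-parser choice in this block amounts to a dictionary lookup with a fallback. A small illustrative sketch of that selection logic, using an assumed subset of the mapping defined at the top of this module:

    # Assumed, illustrative subset of the real CONFIG_TYPE_TO_TOOL_PARSER mapping.
    CONFIG_TYPE_TO_TOOL_PARSER = {"LlamaConfig": "llama3_json"}
    DEFAULT_TOOL_PARSER = "pythonic"

    def choose_tool_parser(hf_config_class_name: str) -> str:
        # Same lookup-with-fallback behavior that register_model() performs during preload.
        return CONFIG_TYPE_TO_TOOL_PARSER.get(hf_config_class_name, DEFAULT_TOOL_PARSER)

    print(choose_tool_parser("LlamaConfig"))    # -> "llama3_json"
    print(choose_tool_parser("MistralConfig"))  # -> "pythonic" (fallback)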
@@ -364,7 +334,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         self.resolved_model_id = resolved_model_id
         self.model_ids.add(model.model_id)
 
-        _info(f"Finished preloading model: {resolved_model_id}")
+        logger.info(f"Finished preloading model: {resolved_model_id}")
 
         return model
@@ -415,7 +385,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         converted_sampling_params = _convert_sampling_params(sampling_params, response_format, logprobs)
 
-        _debug(f"{converted_sampling_params=}")
+        logger.debug(f"{converted_sampling_params=}")
 
         if stream:
             return self._streaming_completion(content, converted_sampling_params)
@@ -429,6 +399,85 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
             logprobs=streaming_result.logprobs,
         )
 
+    async def embeddings(
+        self,
+        model_id: str,
+        contents: List[InterleavedContent],  # type: ignore
+    ) -> EmbeddingsResponse:
+        raise NotImplementedError()
+
+    async def chat_completion(
+        self,
+        model_id: str,
+        messages: List[Message],  # type: ignore
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,  # type: ignore
+        tools: Optional[List[ToolDefinition]] = None,
+        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
+        if model_id not in self.model_ids:
+            raise ValueError(
+                f"This adapter is not registered to model id '{model_id}'. Registered IDs are: {self.model_ids}"
+            )
+
+        # Convert to Llama Stack internal format for consistency
+        request = ChatCompletionRequest(
+            model=self.resolved_model_id,
+            messages=messages,
+            sampling_params=sampling_params,
+            response_format=response_format,
+            tools=tools,
+            tool_choice=tool_choice,
+            tool_prompt_format=tool_prompt_format,
+            stream=stream,
+            logprobs=logprobs,
+        )
+
+        if self.is_meta_llama_model:
+            # Bypass vLLM chat templating layer for Meta Llama models, because the
+            # templating layer in Llama Stack currently produces better results.
+            logger.debug(
+                f"Routing {self.resolved_model_id} chat completion through "
+                f"Llama Stack's templating layer instead of vLLM's."
+            )
+            if stream:
+                # return self._chat_completion_for_meta_llama_streaming(request)
+                pass  # Use vLLM until the above method is implemented.
+            else:
+                return await self._chat_completion_for_meta_llama_non_streaming(request)
+
+        logger.debug(f"{self.resolved_model_id} is not a Meta Llama model")
+
+        # Arguments to the vLLM call must be packaged as a ChatCompletionRequest dataclass.
+        # Note that this dataclass has the same name as a similar dataclass in Llama Stack.
+        request_options = await llama_stack_chat_completion_to_openai_chat_completion_dict(request)
+        chat_completion_request = vllm.entrypoints.openai.protocol.ChatCompletionRequest(**request_options)
+
+        logger.debug(f"Converted request: {chat_completion_request}")
+
+        vllm_result = await self.chat.create_chat_completion(chat_completion_request)
+        logger.debug(f"Result from vLLM: {vllm_result}")
+        if isinstance(vllm_result, vllm.entrypoints.openai.protocol.ErrorResponse):
+            raise ValueError(f"Error from vLLM layer: {vllm_result}")
+
+        # Return type depends on "stream" argument
+        if stream:
+            if not isinstance(vllm_result, AsyncGenerator):
+                raise TypeError(f"Unexpected result type {type(vllm_result)} for streaming inference call")
+            # vLLM client returns a stream of strings, which need to be parsed.
+            # Stream comes in the form of an async generator.
+            return self._convert_streaming_results(vllm_result)
+        else:
+            if not isinstance(vllm_result, vllm.entrypoints.openai.protocol.ChatCompletionResponse):
+                raise TypeError(f"Unexpected result type {type(vllm_result)} for non-streaming inference call")
+            return self._convert_non_streaming_results(vllm_result)
+
+    ###########################################################################
+    # INTERNAL METHODS
+
     async def _streaming_completion(
         self, content: str, sampling_params: vllm.SamplingParams
     ) -> AsyncIterator[CompletionResponseStreamChunk]:
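With this change, chat_completion() dispatches on is_meta_llama_model: Meta Llama requests are re-rendered with Llama Stack's own chat template (non-streaming only for now), and everything else goes through vLLM's OpenAI-compatible layer. A hedged usage sketch, assuming a provider instance with one registered model; the model alias is an example, and the UserMessage import path is assumed to match the datatypes module already used in this file:

    from llama_models.llama3.api.datatypes import UserMessage  # assumed import path

    async def ask(impl: "VLLMInferenceImpl", question: str) -> str:
        # Non-streaming call; for a registered Meta Llama model this is served by
        # _chat_completion_for_meta_llama_non_streaming(), otherwise by vLLM's chat layer.
        response = await impl.chat_completion(
            model_id="meta-llama/Llama-3.2-3B-Instruct",  # whatever alias was registered
            messages=[UserMessage(content=question)],
            stream=False,
        )
        return response.completion_message.content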
@@ -500,68 +549,48 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
             logprobs=logprobs,
         )
 
-    async def embeddings(
-        self,
-        model_id: str,
-        contents: List[InterleavedContent],  # type: ignore
-    ) -> EmbeddingsResponse:
-        raise NotImplementedError()
-
-    async def chat_completion(
-        self,
-        model_id: str,
-        messages: List[Message],  # type: ignore
-        sampling_params: Optional[SamplingParams] = None,
-        response_format: Optional[ResponseFormat] = None,  # type: ignore
-        tools: Optional[List[ToolDefinition]] = None,
-        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = None,
-        stream: Optional[bool] = False,
-        logprobs: Optional[LogProbConfig] = None,
-    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
-        if model_id not in self.model_ids:
-            raise ValueError(
-                f"This adapter is not registered to model id '{model_id}'. Registered IDs are: {self.model_ids}"
-            )
-
-        # Arguments to the vLLM call must be packaged as a ChatCompletionRequest dataclass.
-        # Note that this dataclass has the same name as a similar dataclass in Llama Stack.
-        request_options = await llama_stack_chat_completion_to_openai_chat_completion_dict(
-            ChatCompletionRequest(
-                model=self.resolved_model_id,
-                messages=messages,
-                sampling_params=sampling_params,
-                response_format=response_format,
-                tools=tools,
-                tool_choice=tool_choice,
-                tool_prompt_format=tool_prompt_format,
-                stream=stream,
-                logprobs=logprobs,
-            )
-        )
-        chat_completion_request = vllm.entrypoints.openai.protocol.ChatCompletionRequest(**request_options)
-
-        _debug(f"Converted request: {chat_completion_request}")
-
-        vllm_result = await self.chat.create_chat_completion(chat_completion_request)
-        _debug(f"Result from vLLM: {vllm_result}")
-        if isinstance(vllm_result, vllm.entrypoints.openai.protocol.ErrorResponse):
-            raise ValueError(f"Error from vLLM layer: {vllm_result}")
-
-        # Return type depends on "stream" argument
-        if stream:
-            if not isinstance(vllm_result, AsyncGenerator):
-                raise TypeError(f"Unexpected result type {type(vllm_result)} for streaming inference call")
-            # vLLM client returns a stream of strings, which need to be parsed.
-            # Stream comes in the form of an async generator.
-            return self._convert_streaming_results(vllm_result)
-        else:
-            if not isinstance(vllm_result, vllm.entrypoints.openai.protocol.ChatCompletionResponse):
-                raise TypeError(f"Unexpected result type {type(vllm_result)} for non-streaming inference call")
-            return self._convert_non_streaming_results(vllm_result)
-
-    ###########################################################################
-    # INTERNAL METHODS
+    async def _chat_completion_for_meta_llama_non_streaming(
+        self, request: ChatCompletionRequest
+    ) -> ChatCompletionResponse:
+        """
+        Subroutine that routes chat completions for Meta Llama models through Llama Stack's
+        chat template instead of using vLLM's version of that template. The Llama Stack version
+        of the chat template currently produces more reliable outputs.
+
+        Once vLLM's support for Meta Llama models has matured more, we should consider routing
+        Meta Llama requests through the vLLM chat completions API instead of using this method.
+        """
+        logger.debug("Routing request through Llama Stack templates.")
+
+        formatter = ChatFormat(Tokenizer.get_instance())
+
+        # Note that this function call modifies `request` in place.
+        prompt = await chat_completion_request_to_prompt(request, self.resolved_model_id, formatter)
+
+        model_id = list(self.model_ids)[0]  # Any model ID will do here
+        completion_response = await self.completion(
+            model_id=model_id,
+            content=prompt,
+            sampling_params=request.sampling_params,
+            response_format=request.response_format,
+            stream=False,
+            logprobs=request.logprobs,
+        )
+
+        if not isinstance(completion_response, CompletionResponse):  # Sanity check
+            raise TypeError(f"Unexpected type '{type(completion_response)}' for completion response.")
+
+        raw_message = formatter.decode_assistant_message_from_content(
+            completion_response.content, completion_response.stop_reason
+        )
+        return ChatCompletionResponse(
+            completion_message=CompletionMessage(
+                content=raw_message.content,
+                stop_reason=raw_message.stop_reason,
+                tool_calls=raw_message.tool_calls,
+            ),
+            logprobs=completion_response.logprobs,
+        )
 
     def _convert_non_streaming_results(
         self, vllm_result: vllm.entrypoints.openai.protocol.ChatCompletionResponse
@@ -599,12 +628,25 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
 
         # TODO: Convert logprobs
 
-        _debug(f"Converted message: {converted_message}")
+        logger.debug(f"Converted message: {converted_message}")
 
         return ChatCompletionResponse(
             completion_message=converted_message,
         )
 
+    def _chat_completion_for_meta_llama_streaming(self, request: ChatCompletionRequest) -> AsyncIterator:
+        """
+        Subroutine that routes chat completions for Meta Llama models through Llama Stack's
+        chat template instead of using vLLM's version of that template. The Llama Stack version
+        of the chat template currently produces more reliable outputs.
+
+        Once vLLM's support for Meta Llama models has matured more, we should consider routing
+        Meta Llama requests through the vLLM chat completions API instead of using this method.
+        """
+        logger.debug("Routing streaming request through Llama Stack templates.")
+        raise NotImplementedError()
+
     async def _convert_streaming_results(self, vllm_result: AsyncIterator) -> AsyncIterator:
         """
         Subroutine that wraps the streaming outputs of vLLM's OpenAI-compatible
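The streaming variant above is intentionally left unimplemented; streaming requests for Meta Llama models currently fall through to vLLM's chat layer. One possible shape for a future implementation, mirroring the non-streaming path: render the prompt with Llama Stack's chat template, then stream raw completion chunks from completion(). This is purely a hypothetical sketch of the method body, not the committed design:

    async def _chat_completion_for_meta_llama_streaming(self, request: ChatCompletionRequest) -> AsyncIterator:
        # Hypothetical sketch: render the prompt with Llama Stack's chat template ...
        formatter = ChatFormat(Tokenizer.get_instance())
        prompt = await chat_completion_request_to_prompt(request, self.resolved_model_id, formatter)
        # ... then stream raw completion chunks; converting them into
        # ChatCompletionResponseStreamChunk objects is the part still to be designed.
        async for chunk in await self.completion(
            model_id=list(self.model_ids)[0],
            content=prompt,
            sampling_params=request.sampling_params,
            stream=True,
        ):
            yield chunk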
@@ -653,7 +695,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
             # Anything that is not "[DONE]" should be a JSON record
             parsed_chunk = json.loads(data_str)
 
-            _debug(f"Parsed JSON event to:\n{json.dumps(parsed_chunk, indent=2)}")
+            logger.debug(f"Parsed JSON event to:\n{json.dumps(parsed_chunk, indent=2)}")
 
             # The result may contain multiple completions, but Llama Stack APIs only support
             # returning one.