Update logging and route Meta Llama requests differently

Fred Reiss 2025-01-31 14:12:35 -08:00 committed by Ashwin Bharambe
parent 24cc7a777c
commit ade413f1e3


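The logging half of this change drops the provider's private _log/_info/_debug bypass helpers and the hard-coded handler setup in favor of plain calls on the module-level logger. A minimal sketch of that convention, using only the standard library; the function name do_work and the example config are hypothetical stand-ins, not the provider's code:

import logging

logger = logging.getLogger(__name__)


def do_work(config: dict) -> None:
    # These calls stand in for the old _info()/_debug() helpers removed in the diff.
    logger.info(f"Config is: {config}")
    logger.debug(f"{config=}")


if __name__ == "__main__":
    # An application that wants to see the messages configures handlers itself;
    # the format string matches the one the removed hard-coded handler used.
    logging.basicConfig(
        level=logging.DEBUG,
        format="%(asctime)s: %(filename)s [%(levelname)s] %(message)s",
    )
    do_work({"model": "example"})

The only piece kept at module level is logging.getLogger(__name__); level and handler configuration are left to whatever application hosts the provider.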
@@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import datetime
import json
import logging
import re
@@ -17,9 +16,6 @@ import llama_models.sku_list
# fully-qualified names
import vllm.entrypoints.openai.protocol
import vllm.sampling_params
############################################################################
# llama_models imports go here
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import (
SamplingParams,
@@ -31,17 +27,12 @@ from llama_models.llama3.api.datatypes import (
)
from llama_models.llama3.api.tokenizer import Tokenizer
############################################################################
# vLLM imports go here
#
# We deep-import the names that don't conflict with Llama Stack names
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_engine import BaseModelPath
############################################################################
# llama_stack imports go here
from llama_stack.apis.common.content_types import (
InterleavedContent,
TextDelta,
@@ -78,15 +69,13 @@ from llama_stack.providers.utils.inference.model_registry import (
ModelsProtocolPrivate,
)
from llama_stack.providers.utils.inference.openai_compat import get_stop_reason
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
)
############################################################################
# Package-local imports go here
from .config import VLLMConfig
from .openai_utils import llama_stack_chat_completion_to_openai_chat_completion_dict
############################################################################
# Constants go here
# Map from Hugging Face model architecture name to appropriate tool parser.
# See vllm.entrypoints.openai.tool_parsers.ToolParserManager.tool_parsers for the full list of
# available parsers.
@@ -98,35 +87,15 @@ CONFIG_TYPE_TO_TOOL_PARSER = {
}
DEFAULT_TOOL_PARSER = "pythonic"
############################################################################
# Package-global variables go here
logger = logging.getLogger(__name__)
############################################################################
# Local functions go here
# For debugging stuff when the Llama Stack logger isn't cooperating
_BYPASS_LOGGING = False
def _log(msg: str, level: str):
if _BYPASS_LOGGING:
time_str = datetime.datetime.now().strftime("%H:%M:%S")
print(f"{time_str}: {msg}")
match level:
case "info":
logger.info(msg)
case "debug":
logger.debug(msg)
def _info(msg: str):
_log(msg, "info")
def _debug(msg: str):
_log(msg, "debug")
# Adjust logging parameters from Python code. This appears to be the standard way to control
# logging in Llama Stack.
logger.setLevel(logging.DEBUG)
stderr_handler = logging.StreamHandler()
stderr_handler.setFormatter(logging.Formatter("%(asctime)s: %(filename)s [%(levelname)s] %(message)s"))
logger.addHandler(stderr_handler)
def _random_uuid_str() -> str:
@@ -210,10 +179,6 @@ def _convert_sampling_params(
return vllm_sampling_params
############################################################################
# Class definitions go here
class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
"""
vLLM-based inference model adapter for Llama Stack with support for multiple models.
@@ -227,12 +192,11 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
resolved_model_id: str | None
engine: AsyncLLMEngine | None
chat: OpenAIServingChat | None
is_meta_llama_model: bool
def __init__(self, config: VLLMConfig):
self.config = config
self.engine = None
_info(f"Config is: {self.config}")
logger.info(f"Config is: {self.config}")
self.register_helper = ModelRegistryHelper(build_model_aliases())
self.formatter = ChatFormat(Tokenizer.get_instance())
@@ -242,6 +206,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
self.model_ids = set()
self.engine = None
self.chat = None
self.is_meta_llama_model = False
###########################################################################
# METHODS INHERITED FROM IMPLICIT BASE CLASS.
@@ -264,7 +229,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
Callback that apparently is invoked when shutting down the Llama Stack server. Not sure how
to shut down a Llama Stack server in such a way as to trigger this callback.
"""
_info(f"Shutting down inline vLLM inference provider {self}.")
logger.info(f"Shutting down inline vLLM inference provider {self}.")
if self.engine is not None:
self.engine.shutdown_background_loop()
self.engine = None
@@ -287,18 +252,23 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
:returns: The input ``Model`` object. It may or may not be permissible to change fields
before returning this object.
"""
_debug(f"In register_model({model})")
logger.debug(f"In register_model({model})")
# First attempt to interpret the model coordinates as a Llama model name
resolved_llama_model = resolve_model(model.provider_model_id)
if resolved_llama_model is not None:
# Load from Hugging Face repo into default local cache dir
resolved_model_id = resolved_llama_model.huggingface_repo
# Detect a genuine Meta Llama model to trigger Meta-specific preprocessing.
# Don't set self.is_meta_llama_model until we actually load the model.
is_meta_llama_model = True
else: # if resolved_llama_model is None
# Not a Llama model name. Pass the model id through to vLLM's loader
resolved_model_id = model.provider_model_id
is_meta_llama_model = False
_info(f"Model id {model} resolved to {resolved_model_id}")
logger.info(f"Model id {model} resolved to {resolved_model_id}")
if self.resolved_model_id is not None:
if resolved_model_id != self.resolved_model_id:
@@ -312,7 +282,8 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
self.model_ids.add(model.model_id)
return model
_info(f"Preloading model: {resolved_model_id}")
self.is_meta_llama_model = is_meta_llama_model
logger.info(f"Preloading model: {resolved_model_id}")
# If we get here, this is the first time registering a model.
# Preload so that the first inference request won't time out.
@@ -327,10 +298,9 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
)
self.engine = AsyncLLMEngine.from_engine_args(engine_args)
# vLLM currently requires the user to specify the tool parser
# manually. To choose a tool parser, we need to determine what
# model architecture is being used. For now, we infer that
# information from what config class the model uses.
# vLLM currently requires the user to specify the tool parser manually. To choose a tool
# parser, we need to determine what model architecture is being used. For now, we infer
# that information from what config class the model uses.
low_level_model_config = self.engine.engine.get_model_config()
hf_config = low_level_model_config.hf_config
hf_config_class_name = hf_config.__class__.__name__
@@ -340,8 +310,8 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
# No info -- choose a default so we can at least attempt tool
# use.
tool_parser = DEFAULT_TOOL_PARSER
_debug(f"{hf_config_class_name=}")
_debug(f"{tool_parser=}")
logger.debug(f"{hf_config_class_name=}")
logger.debug(f"{tool_parser=}")
# Wrap the lower-level engine in an OpenAI-compatible chat API
model_config = await self.engine.get_model_config()
@@ -364,7 +334,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
self.resolved_model_id = resolved_model_id
self.model_ids.add(model.model_id)
_info(f"Finished preloading model: {resolved_model_id}")
logger.info(f"Finished preloading model: {resolved_model_id}")
return model
@@ -415,7 +385,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
converted_sampling_params = _convert_sampling_params(sampling_params, response_format, logprobs)
_debug(f"{converted_sampling_params=}")
logger.debug(f"{converted_sampling_params=}")
if stream:
return self._streaming_completion(content, converted_sampling_params)
@@ -429,6 +399,85 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
logprobs=streaming_result.logprobs,
)
async def embeddings(
self,
model_id: str,
contents: List[InterleavedContent], # type: ignore
) -> EmbeddingsResponse:
raise NotImplementedError()
async def chat_completion(
self,
model_id: str,
messages: List[Message], # type: ignore
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None, # type: ignore
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
if model_id not in self.model_ids:
raise ValueError(
f"This adapter is not registered to model id '{model_id}'. Registered IDs are: {self.model_ids}"
)
# Convert to Llama Stack internal format for consistency
request = ChatCompletionRequest(
model=self.resolved_model_id,
messages=messages,
sampling_params=sampling_params,
response_format=response_format,
tools=tools,
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
)
if self.is_meta_llama_model:
# Bypass vLLM chat templating layer for Meta Llama models, because the
# templating layer in Llama Stack currently produces better results.
logger.debug(
f"Routing {self.resolved_model_id} chat completion through "
f"Llama Stack's templating layer instead of vLLM's."
)
if stream:
# return self._chat_completion_for_meta_llama_streaming(request)
pass # Use vLLM until the above method is implemented.
else:
return await self._chat_completion_for_meta_llama_non_streaming(request)
logger.debug(f"{self.resolved_model_id} is not a Meta Llama model")
# Arguments to the vLLM call must be packaged as a ChatCompletionRequest dataclass.
# Note that this dataclass has the same name as a similar dataclass in Llama Stack.
request_options = await llama_stack_chat_completion_to_openai_chat_completion_dict(request)
chat_completion_request = vllm.entrypoints.openai.protocol.ChatCompletionRequest(**request_options)
logger.debug(f"Converted request: {chat_completion_request}")
vllm_result = await self.chat.create_chat_completion(chat_completion_request)
logger.debug(f"Result from vLLM: {vllm_result}")
if isinstance(vllm_result, vllm.entrypoints.openai.protocol.ErrorResponse):
raise ValueError(f"Error from vLLM layer: {vllm_result}")
# Return type depends on "stream" argument
if stream:
if not isinstance(vllm_result, AsyncGenerator):
raise TypeError(f"Unexpected result type {type(vllm_result)} for streaming inference call")
# vLLM client returns a stream of strings, which need to be parsed.
# Stream comes in the form of an async generator.
return self._convert_streaming_results(vllm_result)
else:
if not isinstance(vllm_result, vllm.entrypoints.openai.protocol.ChatCompletionResponse):
raise TypeError(f"Unexpected result type {type(vllm_result)} for non-streaming inference call")
return self._convert_non_streaming_results(vllm_result)
###########################################################################
# INTERNAL METHODS
async def _streaming_completion(
self, content: str, sampling_params: vllm.SamplingParams
) -> AsyncIterator[CompletionResponseStreamChunk]:
@@ -500,68 +549,48 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
logprobs=logprobs,
)
async def embeddings(
self,
model_id: str,
contents: List[InterleavedContent], # type: ignore
) -> EmbeddingsResponse:
raise NotImplementedError()
async def _chat_completion_for_meta_llama_non_streaming(
self, request: ChatCompletionRequest
) -> ChatCompletionResponse:
"""
Subroutine that routes chat completions for Meta Llama models through Llama Stack's
chat template instead of using vLLM's version of that template. The Llama Stack version
of the chat template currently produces more reliable outputs.
async def chat_completion(
self,
model_id: str,
messages: List[Message], # type: ignore
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None, # type: ignore
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
if model_id not in self.model_ids:
raise ValueError(
f"This adapter is not registered to model id '{model_id}'. Registered IDs are: {self.model_ids}"
)
Once vLLM's support for Meta Llama models has matured more, we should consider routing
Meta Llama requests through the vLLM chat completions API instead of using this method.
"""
# Arguments to the vLLM call must be packaged as a ChatCompletionRequest dataclass.
# Note that this dataclass has the same name as a similar dataclass in Llama Stack.
request_options = await llama_stack_chat_completion_to_openai_chat_completion_dict(
ChatCompletionRequest(
model=self.resolved_model_id,
messages=messages,
sampling_params=sampling_params,
response_format=response_format,
tools=tools,
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
)
logger.debug("Routing request through Llama Stack templates.")
formatter = ChatFormat(Tokenizer.get_instance())
# Note that this function call modifies `request` in place.
prompt = await chat_completion_request_to_prompt(request, self.resolved_model_id, formatter)
model_id = list(self.model_ids)[0] # Any model ID will do here
completion_response = await self.completion(
model_id=model_id,
content=prompt,
sampling_params=request.sampling_params,
response_format=request.response_format,
stream=False,
logprobs=request.logprobs,
)
chat_completion_request = vllm.entrypoints.openai.protocol.ChatCompletionRequest(**request_options)
if not isinstance(completion_response, CompletionResponse): # Sanity check
raise TypeError(f"Unexpected type '{type(completion_response)}' for completion response.")
_debug(f"Converted request: {chat_completion_request}")
vllm_result = await self.chat.create_chat_completion(chat_completion_request)
_debug(f"Result from vLLM: {vllm_result}")
if isinstance(vllm_result, vllm.entrypoints.openai.protocol.ErrorResponse):
raise ValueError(f"Error from vLLM layer: {vllm_result}")
# Return type depends on "stream" argument
if stream:
if not isinstance(vllm_result, AsyncGenerator):
raise TypeError(f"Unexpected result type {type(vllm_result)} for streaming inference call")
# vLLM client returns a stream of strings, which need to be parsed.
# Stream comes in the form of an async generator.
return self._convert_streaming_results(vllm_result)
else:
if not isinstance(vllm_result, vllm.entrypoints.openai.protocol.ChatCompletionResponse):
raise TypeError(f"Unexpected result type {type(vllm_result)} for non-streaming inference call")
return self._convert_non_streaming_results(vllm_result)
###########################################################################
# INTERNAL METHODS
raw_message = formatter.decode_assistant_message_from_content(
completion_response.content, completion_response.stop_reason
)
return ChatCompletionResponse(
completion_message=CompletionMessage(
content=raw_message.content,
stop_reason=raw_message.stop_reason,
tool_calls=raw_message.tool_calls,
),
logprobs=completion_response.logprobs,
)
def _convert_non_streaming_results(
self, vllm_result: vllm.entrypoints.openai.protocol.ChatCompletionResponse
@@ -599,12 +628,25 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
# TODO: Convert logprobs
_debug(f"Converted message: {converted_message}")
logger.debug(f"Converted message: {converted_message}")
return ChatCompletionResponse(
completion_message=converted_message,
)
def _chat_completion_for_meta_llama_streaming(self, request: ChatCompletionRequest) -> AsyncIterator:
"""
Subroutine that routes chat completions for Meta Llama models through Llama Stack's
chat template instead of using vLLM's version of that template. The Llama Stack version
of the chat template currently produces more reliable outputs.
Once vLLM's support for Meta Llama models has matured more, we should consider routing
Meta Llama requests through the vLLM chat completions API instead of using this method.
"""
logger.debug("Routing streaming request through Llama Stack templates.")
raise NotImplementedError()
async def _convert_streaming_results(self, vllm_result: AsyncIterator) -> AsyncIterator:
"""
Subroutine that wraps the streaming outputs of vLLM's OpenAI-compatible
@@ -653,7 +695,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
# Anything that is not "[DONE]" should be a JSON record
parsed_chunk = json.loads(data_str)
_debug(f"Parsed JSON event to:\n{json.dumps(parsed_chunk, indent=2)}")
logger.debug(f"Parsed JSON event to:\n{json.dumps(parsed_chunk, indent=2)}")
# The result may contain multiple completions, but Llama Stack APIs only support
# returning one.
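For reference, the Meta Llama routing added above reduces to the pattern sketched below. This is a simplified, self-contained illustration rather than the provider's actual code: MetaLlamaRouter, via_llama_stack_template, and via_vllm_chat_api are hypothetical names, and the resolver argument stands in for resolve_model from llama_models, which returns a SKU record (with a huggingface_repo field) or None.

# Simplified sketch of the routing introduced by this commit (hypothetical
# names, not the provider's classes). A model id that resolves to a known
# Meta Llama SKU is flagged at registration time; non-streaming chat
# completions for flagged models are later rendered with Llama Stack's own
# chat template, while everything else goes through vLLM's chat API.
from typing import Awaitable, Callable, Optional


class MetaLlamaRouter:
    def __init__(self, resolve_model: Callable[[str], Optional[object]]):
        self._resolve_model = resolve_model
        self.is_meta_llama_model = False

    def register(self, provider_model_id: str) -> str:
        resolved = self._resolve_model(provider_model_id)
        if resolved is not None:
            # Genuine Meta Llama model: remember that so chat requests can
            # bypass vLLM's chat templating later.
            self.is_meta_llama_model = True
            return resolved.huggingface_repo  # load from the Hugging Face repo
        # Not a Llama model name: pass the id straight through to vLLM's loader.
        return provider_model_id

    async def chat(
        self,
        request,
        via_llama_stack_template: Callable[..., Awaitable],
        via_vllm_chat_api: Callable[..., Awaitable],
    ):
        if self.is_meta_llama_model and not getattr(request, "stream", False):
            # Llama Stack's template currently produces better results for
            # Meta Llama models; the streaming variant is still unimplemented
            # in the diff, so streaming requests fall back to vLLM for now.
            return await via_llama_stack_template(request)
        return await via_vllm_chat_api(request)

The flag lives on the provider instance because the inline provider preloads a single vLLM engine, so the decision made when the first model is registered applies to every subsequent chat completion against that engine.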