mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-08-12 04:50:39 +00:00

Update logging and route Meta Llama requests differently

This commit is contained in:
parent 24cc7a777c
commit ade413f1e3

1 changed file with 161 additions and 119 deletions
@@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import datetime
import json
import logging
import re
@@ -17,9 +16,6 @@ import llama_models.sku_list
# fully-qualified names
import vllm.entrypoints.openai.protocol
import vllm.sampling_params

############################################################################
# llama_models imports go here
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import (
    SamplingParams,
@@ -31,17 +27,12 @@ from llama_models.llama3.api.datatypes import (
)
from llama_models.llama3.api.tokenizer import Tokenizer

############################################################################
# vLLM imports go here
#
# We deep-import the names that don't conflict with Llama Stack names
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_engine import BaseModelPath

############################################################################
# llama_stack imports go here
from llama_stack.apis.common.content_types import (
    InterleavedContent,
    TextDelta,
@@ -78,15 +69,13 @@ from llama_stack.providers.utils.inference.model_registry import (
    ModelsProtocolPrivate,
)
from llama_stack.providers.utils.inference.openai_compat import get_stop_reason
from llama_stack.providers.utils.inference.prompt_adapter import (
    chat_completion_request_to_prompt,
)

############################################################################
# Package-local imports go here
from .config import VLLMConfig
from .openai_utils import llama_stack_chat_completion_to_openai_chat_completion_dict

############################################################################
# Constants go here

# Map from Hugging Face model architecture name to appropriate tool parser.
# See vllm.entrypoints.openai.tool_parsers.ToolParserManager.tool_parsers for the full list of
# available parsers.
@@ -98,35 +87,15 @@ CONFIG_TYPE_TO_TOOL_PARSER = {
}
DEFAULT_TOOL_PARSER = "pythonic"

############################################################################
# Package-global variables go here

logger = logging.getLogger(__name__)

############################################################################
# Local functions go here

# For debugging stuff when the Llama Stack logger isn't cooperating
_BYPASS_LOGGING = False


def _log(msg: str, level: str):
    if _BYPASS_LOGGING:
        time_str = datetime.datetime.now().strftime("%H:%M:%S")
        print(f"{time_str}: {msg}")
    match level:
        case "info":
            logger.info(msg)
        case "debug":
            logger.debug(msg)


def _info(msg: str):
    _log(msg, "info")


def _debug(msg: str):
    _log(msg, "debug")
# Adjust logging parameters from Python code. This appears to be the standard way to control
# logging in Llama Stack.
logger.setLevel(logging.DEBUG)
stderr_handler = logging.StreamHandler()
stderr_handler.setFormatter(logging.Formatter("%(asctime)s: %(filename)s [%(levelname)s] %(message)s"))
logger.addHandler(stderr_handler)


def _random_uuid_str() -> str:
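
For context, the handler setup added above can be exercised on its own. A minimal standard-library sketch (the logger name is illustrative, not the provider's):

import logging

example_logger = logging.getLogger("inline_vllm_logging_example")  # illustrative name
example_logger.setLevel(logging.DEBUG)

handler = logging.StreamHandler()  # writes to sys.stderr by default
handler.setFormatter(logging.Formatter("%(asctime)s: %(filename)s [%(levelname)s] %(message)s"))
example_logger.addHandler(handler)

example_logger.debug("handler configured")  # emitted to stderr with timestamp, filename, and level
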
@@ -210,10 +179,6 @@ def _convert_sampling_params(
    return vllm_sampling_params


############################################################################
# Class definitions go here


class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
    """
    vLLM-based inference model adapter for Llama Stack with support for multiple models.
@@ -227,12 +192,11 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
    resolved_model_id: str | None
    engine: AsyncLLMEngine | None
    chat: OpenAIServingChat | None
    is_meta_llama_model: bool

    def __init__(self, config: VLLMConfig):
        self.config = config
        self.engine = None
        _info(f"Config is: {self.config}")
        logger.info(f"Config is: {self.config}")

        self.register_helper = ModelRegistryHelper(build_model_aliases())
        self.formatter = ChatFormat(Tokenizer.get_instance())
@@ -242,6 +206,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
        self.model_ids = set()
        self.engine = None
        self.chat = None
        self.is_meta_llama_model = False

    ###########################################################################
    # METHODS INHERITED FROM IMPLICIT BASE CLASS.
@@ -264,7 +229,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
        Callback that apparently is invoked when shutting down the Llama Stack server. Not sure how
        to shut down a Llama Stack server in such a way as to trigger this callback.
        """
        _info(f"Shutting down inline vLLM inference provider {self}.")
        logger.info(f"Shutting down inline vLLM inference provider {self}.")
        if self.engine is not None:
            self.engine.shutdown_background_loop()
            self.engine = None
@@ -287,18 +252,23 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
        :returns: The input ``Model`` object. It may or may not be permissible to change fields
            before returning this object.
        """
        _debug(f"In register_model({model})")
        logger.debug(f"In register_model({model})")

        # First attempt to interpret the model coordinates as a Llama model name
        resolved_llama_model = resolve_model(model.provider_model_id)
        if resolved_llama_model is not None:
            # Load from Hugging Face repo into default local cache dir
            resolved_model_id = resolved_llama_model.huggingface_repo

            # Detect a genuine Meta Llama model to trigger Meta-specific preprocessing.
            # Don't set self.is_meta_llama_model until we actually load the model.
            is_meta_llama_model = True
        else:  # if resolved_llama_model is None
            # Not a Llama model name. Pass the model id through to vLLM's loader
            resolved_model_id = model.provider_model_id
            is_meta_llama_model = False

        _info(f"Model id {model} resolved to {resolved_model_id}")
        logger.info(f"Model id {model} resolved to {resolved_model_id}")

        if self.resolved_model_id is not None:
            if resolved_model_id != self.resolved_model_id:
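
The resolution logic above reduces to a small helper. A sketch assuming only that llama_models is installed (the function name is illustrative, not part of the diff):

from llama_models.sku_list import resolve_model


def resolve_provider_model_id(provider_model_id: str) -> tuple[str, bool]:
    """Return the model id to hand to vLLM and whether it is a genuine Meta Llama model."""
    resolved = resolve_model(provider_model_id)
    if resolved is not None:
        # Known Meta Llama SKU: load its weights from the associated Hugging Face repo.
        return resolved.huggingface_repo, True
    # Unknown name: pass the id straight through to vLLM's loader.
    return provider_model_id, False
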
@@ -312,7 +282,8 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
            self.model_ids.add(model.model_id)
            return model

        _info(f"Preloading model: {resolved_model_id}")
        self.is_meta_llama_model = is_meta_llama_model
        logger.info(f"Preloading model: {resolved_model_id}")

        # If we get here, this is the first time registering a model.
        # Preload so that the first inference request won't time out.
@@ -327,10 +298,9 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
        )
        self.engine = AsyncLLMEngine.from_engine_args(engine_args)

        # vLLM currently requires the user to specify the tool parser
        # manually. To choose a tool parser, we need to determine what
        # model architecture is being used. For now, we infer that
        # information from what config class the model uses.
        # vLLM currently requires the user to specify the tool parser manually. To choose a tool
        # parser, we need to determine what model architecture is being used. For now, we infer
        # that information from what config class the model uses.
        low_level_model_config = self.engine.engine.get_model_config()
        hf_config = low_level_model_config.hf_config
        hf_config_class_name = hf_config.__class__.__name__
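
Preloading amounts to constructing the async engine once at registration time. A standalone sketch of that step using vLLM's async engine API (the model id and options are placeholders, not the provider's config):

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine

# Placeholder values; the provider derives these from VLLMConfig and resolved_model_id.
engine_args = AsyncEngineArgs(
    model="meta-llama/Llama-3.1-8B-Instruct",  # illustrative
    max_model_len=4096,
    enforce_eager=True,
)
engine = AsyncLLMEngine.from_engine_args(engine_args)  # loads weights now, so later requests won't time out
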
@@ -340,8 +310,8 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
            # No info -- choose a default so we can at least attempt tool
            # use.
            tool_parser = DEFAULT_TOOL_PARSER
        _debug(f"{hf_config_class_name=}")
        _debug(f"{tool_parser=}")
        logger.debug(f"{hf_config_class_name=}")
        logger.debug(f"{tool_parser=}")

        # Wrap the lower-level engine in an OpenAI-compatible chat API
        model_config = await self.engine.get_model_config()
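
The parser choice described in the comments above is a dictionary lookup with a fallback. A sketch assuming a CONFIG_TYPE_TO_TOOL_PARSER table like the (truncated) one defined earlier in the file; the specific entries here are illustrative:

# Illustrative entries only; the real mapping is the module-level CONFIG_TYPE_TO_TOOL_PARSER above.
CONFIG_TYPE_TO_TOOL_PARSER = {
    "GraniteConfig": "granite",
    "MllamaConfig": "llama3_json",
}
DEFAULT_TOOL_PARSER = "pythonic"


def choose_tool_parser(hf_config_class_name: str) -> str:
    # Fall back to a default so tool calling can at least be attempted.
    return CONFIG_TYPE_TO_TOOL_PARSER.get(hf_config_class_name, DEFAULT_TOOL_PARSER)


# Usage: choose_tool_parser(hf_config.__class__.__name__)
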
@@ -364,7 +334,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
        self.resolved_model_id = resolved_model_id
        self.model_ids.add(model.model_id)

        _info(f"Finished preloading model: {resolved_model_id}")
        logger.info(f"Finished preloading model: {resolved_model_id}")

        return model

@@ -415,7 +385,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):

        converted_sampling_params = _convert_sampling_params(sampling_params, response_format, logprobs)

        _debug(f"{converted_sampling_params=}")
        logger.debug(f"{converted_sampling_params=}")

        if stream:
            return self._streaming_completion(content, converted_sampling_params)
@@ -429,6 +399,85 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
                logprobs=streaming_result.logprobs,
            )

    async def embeddings(
        self,
        model_id: str,
        contents: List[InterleavedContent],  # type: ignore
    ) -> EmbeddingsResponse:
        raise NotImplementedError()

    async def chat_completion(
        self,
        model_id: str,
        messages: List[Message],  # type: ignore
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,  # type: ignore
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
        if model_id not in self.model_ids:
            raise ValueError(
                f"This adapter is not registered to model id '{model_id}'. Registered IDs are: {self.model_ids}"
            )

        # Convert to Llama Stack internal format for consistency
        request = ChatCompletionRequest(
            model=self.resolved_model_id,
            messages=messages,
            sampling_params=sampling_params,
            response_format=response_format,
            tools=tools,
            tool_choice=tool_choice,
            tool_prompt_format=tool_prompt_format,
            stream=stream,
            logprobs=logprobs,
        )

        if self.is_meta_llama_model:
            # Bypass vLLM chat templating layer for Meta Llama models, because the
            # templating layer in Llama Stack currently produces better results.
            logger.debug(
                f"Routing {self.resolved_model_id} chat completion through "
                f"Llama Stack's templating layer instead of vLLM's."
            )
            if stream:
                # return self._chat_completion_for_meta_llama_streaming(request)
                pass  # Use vLLM until the above method is implemented.
            else:
                return await self._chat_completion_for_meta_llama_non_streaming(request)

        logger.debug(f"{self.resolved_model_id} is not a Meta Llama model")

        # Arguments to the vLLM call must be packaged as a ChatCompletionRequest dataclass.
        # Note that this dataclass has the same name as a similar dataclass in Llama Stack.
        request_options = await llama_stack_chat_completion_to_openai_chat_completion_dict(request)
        chat_completion_request = vllm.entrypoints.openai.protocol.ChatCompletionRequest(**request_options)

        logger.debug(f"Converted request: {chat_completion_request}")

        vllm_result = await self.chat.create_chat_completion(chat_completion_request)
        logger.debug(f"Result from vLLM: {vllm_result}")
        if isinstance(vllm_result, vllm.entrypoints.openai.protocol.ErrorResponse):
            raise ValueError(f"Error from vLLM layer: {vllm_result}")

        # Return type depends on "stream" argument
        if stream:
            if not isinstance(vllm_result, AsyncGenerator):
                raise TypeError(f"Unexpected result type {type(vllm_result)} for streaming inference call")
            # vLLM client returns a stream of strings, which need to be parsed.
            # Stream comes in the form of an async generator.
            return self._convert_streaming_results(vllm_result)
        else:
            if not isinstance(vllm_result, vllm.entrypoints.openai.protocol.ChatCompletionResponse):
                raise TypeError(f"Unexpected result type {type(vllm_result)} for non-streaming inference call")
            return self._convert_non_streaming_results(vllm_result)

    ###########################################################################
    # INTERNAL METHODS

    async def _streaming_completion(
        self, content: str, sampling_params: vllm.SamplingParams
    ) -> AsyncIterator[CompletionResponseStreamChunk]:
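
For the non-Meta-Llama path above, the options dict produced by llama_stack_chat_completion_to_openai_chat_completion_dict is unpacked into vLLM's OpenAI-protocol request model. A minimal sketch with a hand-written options dict (the field values are illustrative):

import vllm.entrypoints.openai.protocol as openai_protocol

# Hypothetical options dict of the shape returned by
# llama_stack_chat_completion_to_openai_chat_completion_dict().
request_options = {
    "model": "meta-llama/Llama-3.1-8B-Instruct",  # illustrative
    "messages": [{"role": "user", "content": "What is the capital of France?"}],
    "stream": False,
}
chat_completion_request = openai_protocol.ChatCompletionRequest(**request_options)

Genuine Meta Llama models skip this packaging entirely in the non-streaming case and take the Llama Stack template path instead.
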
@@ -500,68 +549,48 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
                logprobs=logprobs,
            )

    async def embeddings(
        self,
        model_id: str,
        contents: List[InterleavedContent],  # type: ignore
    ) -> EmbeddingsResponse:
        raise NotImplementedError()
    async def _chat_completion_for_meta_llama_non_streaming(
        self, request: ChatCompletionRequest
    ) -> ChatCompletionResponse:
        """
        Subroutine that routes chat completions for Meta Llama models through Llama Stack's
        chat template instead of using vLLM's version of that template. The Llama Stack version
        of the chat template currently produces more reliable outputs.

    async def chat_completion(
        self,
        model_id: str,
        messages: List[Message],  # type: ignore
        sampling_params: Optional[SamplingParams] = None,
        response_format: Optional[ResponseFormat] = None,  # type: ignore
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
        if model_id not in self.model_ids:
            raise ValueError(
                f"This adapter is not registered to model id '{model_id}'. Registered IDs are: {self.model_ids}"
        Once vLLM's support for Meta Llama models has matured more, we should consider routing
        Meta Llama requests through the vLLM chat completions API instead of using this method.
        """

        logger.debug("Routing request through Llama Stack templates.")

        formatter = ChatFormat(Tokenizer.get_instance())

        # Note that this function call modifies `request` in place.
        prompt = await chat_completion_request_to_prompt(request, self.resolved_model_id, formatter)

        model_id = list(self.model_ids)[0]  # Any model ID will do here
        completion_response = await self.completion(
            model_id=model_id,
            content=prompt,
            sampling_params=request.sampling_params,
            response_format=request.response_format,
            stream=False,
            logprobs=request.logprobs,
        )
        if not isinstance(completion_response, CompletionResponse):  # Sanity check
            raise TypeError(f"Unexpected type '{type(completion_response)}' for completion response.")

        # Arguments to the vLLM call must be packaged as a ChatCompletionRequest dataclass.
        # Note that this dataclass has the same name as a similar dataclass in Llama Stack.
        request_options = await llama_stack_chat_completion_to_openai_chat_completion_dict(
            ChatCompletionRequest(
                model=self.resolved_model_id,
                messages=messages,
                sampling_params=sampling_params,
                response_format=response_format,
                tools=tools,
                tool_choice=tool_choice,
                tool_prompt_format=tool_prompt_format,
                stream=stream,
                logprobs=logprobs,
        raw_message = formatter.decode_assistant_message_from_content(
            completion_response.content, completion_response.stop_reason
        )
        return ChatCompletionResponse(
            completion_message=CompletionMessage(
                content=raw_message.content,
                stop_reason=raw_message.stop_reason,
                tool_calls=raw_message.tool_calls,
            ),
            logprobs=completion_response.logprobs,
        )
        chat_completion_request = vllm.entrypoints.openai.protocol.ChatCompletionRequest(**request_options)

        _debug(f"Converted request: {chat_completion_request}")

        vllm_result = await self.chat.create_chat_completion(chat_completion_request)
        _debug(f"Result from vLLM: {vllm_result}")
        if isinstance(vllm_result, vllm.entrypoints.openai.protocol.ErrorResponse):
            raise ValueError(f"Error from vLLM layer: {vllm_result}")

        # Return type depends on "stream" argument
        if stream:
            if not isinstance(vllm_result, AsyncGenerator):
                raise TypeError(f"Unexpected result type {type(vllm_result)} for streaming inference call")
            # vLLM client returns a stream of strings, which need to be parsed.
            # Stream comes in the form of an async generator.
            return self._convert_streaming_results(vllm_result)
        else:
            if not isinstance(vllm_result, vllm.entrypoints.openai.protocol.ChatCompletionResponse):
                raise TypeError(f"Unexpected result type {type(vllm_result)} for non-streaming inference call")
            return self._convert_non_streaming_results(vllm_result)

    ###########################################################################
    # INTERNAL METHODS

    def _convert_non_streaming_results(
        self, vllm_result: vllm.entrypoints.openai.protocol.ChatCompletionResponse
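
The new Meta Llama path added above follows a prompt-then-decode pattern: render the chat with Llama Stack's template, run a plain completion, then decode the assistant message from the raw text. A schematic sketch, assuming `provider` is the VLLMInferenceImpl instance and `request` a Llama Stack ChatCompletionRequest (not the provider's exact code):

from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_stack.providers.utils.inference.prompt_adapter import chat_completion_request_to_prompt


async def meta_llama_chat_via_templates(provider, request):
    # 1. Render the chat history with Llama Stack's own chat template.
    formatter = ChatFormat(Tokenizer.get_instance())
    prompt = await chat_completion_request_to_prompt(request, provider.resolved_model_id, formatter)

    # 2. Run a plain (non-chat) completion over the rendered prompt.
    completion = await provider.completion(
        model_id=next(iter(provider.model_ids)),  # any registered id will do
        content=prompt,
        sampling_params=request.sampling_params,
        response_format=request.response_format,
        stream=False,
        logprobs=request.logprobs,
    )

    # 3. Decode the raw text back into an assistant message (content, stop reason, tool calls).
    return formatter.decode_assistant_message_from_content(completion.content, completion.stop_reason)
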
@@ -599,12 +628,25 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):

        # TODO: Convert logprobs

        _debug(f"Converted message: {converted_message}")
        logger.debug(f"Converted message: {converted_message}")

        return ChatCompletionResponse(
            completion_message=converted_message,
        )

    def _chat_completion_for_meta_llama_streaming(self, request: ChatCompletionRequest) -> AsyncIterator:
        """
        Subroutine that routes chat completions for Meta Llama models through Llama Stack's
        chat template instead of using vLLM's version of that template. The Llama Stack version
        of the chat template currently produces more reliable outputs.

        Once vLLM's support for Meta Llama models has matured more, we should consider routing
        Meta Llama requests through the vLLM chat completions API instead of using this method.
        """
        logger.debug("Routing streaming request through Llama Stack templates.")

        raise NotImplementedError()

    async def _convert_streaming_results(self, vllm_result: AsyncIterator) -> AsyncIterator:
        """
        Subroutine that wraps the streaming outputs of vLLM's OpenAI-compatible
@@ -653,7 +695,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
            # Anything that is not "[DONE]" should be a JSON record
            parsed_chunk = json.loads(data_str)

            _debug(f"Parsed JSON event to:\n{json.dumps(parsed_chunk, indent=2)}")
            logger.debug(f"Parsed JSON event to:\n{json.dumps(parsed_chunk, indent=2)}")

            # The result may contain multiple completions, but Llama Stack APIs only support
            # returning one.
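
The parsing sketched by the comments above treats each streamed chunk as a server-sent event line. A self-contained sketch of that loop, assuming OpenAI-style "data: ..." chunks and keeping only the first choice, as the comment notes:

import json
from typing import Any, AsyncIterator


async def parse_openai_sse_stream(chunks: AsyncIterator[str]) -> AsyncIterator[dict[str, Any]]:
    """Yield the first choice of each parsed streaming event; stop at the [DONE] sentinel."""
    async for chunk in chunks:
        if not chunk.startswith("data: "):
            continue  # skip keep-alives and blank lines
        data_str = chunk[len("data: "):].strip()
        if data_str == "[DONE]":
            return
        parsed_chunk = json.loads(data_str)  # anything that is not "[DONE]" is a JSON record
        # The event may carry several completions, but Llama Stack APIs return only one.
        yield parsed_chunk["choices"][0]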