# What does this PR do?

TL;DR: Changes needed to get 100% passing results on the OpenAI API verification tests when run against Llama Stack with the `together`, `fireworks`, and `openai` providers. The `groq` provider also improves, now passing at 88%.

This cleans up the OpenAI API support for image message types (specifically `image_url` types) and for the `response_format` chat completion parameter. Both required a few more Pydantic model definitions in our Inference API, to move from the not-quite-right stubs I had in place to something fleshed out that matches the actual OpenAI API specs.

As part of testing this, I also found and fixed a bug in the litellm implementation of openai_completion and openai_chat_completion, so the providers based on those should actually be working now.

The method `prepare_openai_completion_params` in `llama_stack/providers/utils/inference/openai_compat.py` was improved to recursively clean up input parameters, including handling of lists, dicts, and dumping of Pydantic models to dicts (a sketch of this idea follows the verification test instructions below). These changes were required to reach 100% passing tests on the OpenAI API verification against the `openai` provider.

With the above, the together.ai provider passed at the same rate through Llama Stack as it does when hit directly. But since we have Llama Stack in the middle, I took the opportunity to clean up the together.ai provider so that it now also passes our OpenAI API spec tests at 100%. That means together.ai now passes our verification tests better when an OpenAI client talks to it through Llama Stack than when it hits together.ai directly, without Llama Stack in the middle.

Another round of work on Fireworks, improving the translation of incoming OpenAI chat completion requests into Llama Stack chat completion requests, gets the fireworks provider passing at 100% as well. The server-side fireworks.ai tool calling support with OpenAI chat completions and Llama 4 models isn't great yet, but by pointing OpenAI clients at Llama Stack's API we can clean things up and get everything working as expected for Llama 4 models.

## Test Plan

### OpenAI API Verification Tests

I ran the OpenAI API verification tests as below and 100% of the tests passed.

First, start a Llama Stack server that runs the `openai` provider with the `gpt-4o` and `gpt-4o-mini` models deployed. There's no template set up to do this out of the box, so I added a `tests/verifications/openai-api-verification-run.yaml` to do it.

Ensure you have the necessary API key environment variables set:

```
export TOGETHER_API_KEY="..."
export FIREWORKS_API_KEY="..."
export OPENAI_API_KEY="..."
```

Then run a Llama Stack server that serves up all these providers:

```
llama stack run \
  --image-type venv \
  tests/verifications/openai-api-verification-run.yaml
```

Finally, generate a new verification report against all these providers, both with and without the Llama Stack server in the middle:

```
python tests/verifications/generate_report.py \
  --run-tests \
  --provider \
    together \
    fireworks \
    groq \
    openai \
    together-llama-stack \
    fireworks-llama-stack \
    groq-llama-stack \
    openai-llama-stack
```

You'll see that most of the configurations with Llama Stack in the middle now pass at 100%, even though some of them do not pass at 100% when hitting the backend provider's API directly with an OpenAI client.
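For reference, here is a minimal sketch of the recursive parameter cleanup described in the summary above. This is an illustration only, not the actual `prepare_openai_completion_params` implementation; the helper name `_clean_params` and its exact behavior are assumptions.

```
from typing import Any

from pydantic import BaseModel


def _clean_params(value: Any) -> Any:
    # Hypothetical helper sketching the cleanup idea; not the code in
    # llama_stack/providers/utils/inference/openai_compat.py.
    if isinstance(value, BaseModel):
        # Dump Pydantic models to plain dicts, dropping None fields, then recurse.
        return _clean_params(value.model_dump(exclude_none=True))
    if isinstance(value, dict):
        return {k: _clean_params(v) for k, v in value.items() if v is not None}
    if isinstance(value, list):
        return [_clean_params(v) for v in value if v is not None]
    return value
```

With something along these lines, nested structures such as `image_url` content parts serialize cleanly before being handed to the backend client.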
### OpenAI Completion Integration Tests with vLLM

I also ran the smaller `test_openai_completion.py` test suite (not yet merged with the verification tests) against several of the providers, since I had to adjust the method signature of openai_chat_completion a bit and thus had to touch lots of these providers to match. Here are the tests I ran, all passing:

```
VLLM_URL="http://localhost:8000/v1" INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct" llama stack build --template remote-vllm --image-type venv --run
```

In another terminal:

```
LLAMA_STACK_CONFIG=http://localhost:8321 INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct" python -m pytest -v tests/integration/inference/test_openai_completion.py --text-model "meta-llama/Llama-3.2-3B-Instruct"
```

### OpenAI Completion Integration Tests with ollama

```
INFERENCE_MODEL="llama3.2:3b-instruct-q8_0" llama stack build --template ollama --image-type venv --run
```

In another terminal:

```
LLAMA_STACK_CONFIG=http://localhost:8321 INFERENCE_MODEL="llama3.2:3b-instruct-q8_0" python -m pytest -v tests/integration/inference/test_openai_completion.py --text-model "llama3.2:3b-instruct-q8_0"
```

### OpenAI Completion Integration Tests with together.ai

```
INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct-Turbo" llama stack build --template together --image-type venv --run
```

In another terminal:

```
LLAMA_STACK_CONFIG=http://localhost:8321 INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct-Turbo" python -m pytest -v tests/integration/inference/test_openai_completion.py --text-model "meta-llama/Llama-3.2-3B-Instruct-Turbo"
```

### OpenAI Completion Integration Tests with fireworks.ai

```
INFERENCE_MODEL="meta-llama/Llama-3.1-8B-Instruct" llama stack build --template fireworks --image-type venv --run
```

In another terminal:

```
LLAMA_STACK_CONFIG=http://localhost:8321 INFERENCE_MODEL="meta-llama/Llama-3.1-8B-Instruct" python -m pytest -v tests/integration/inference/test_openai_completion.py --text-model "meta-llama/Llama-3.1-8B-Instruct"
```

---------

Signed-off-by: Ben Browning <bbrownin@redhat.com>
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
import re
import uuid
from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union

# These vLLM modules contain names that overlap with Llama Stack names, so we import
# fully-qualified names
import vllm.entrypoints.openai.protocol
import vllm.sampling_params
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels

from llama_stack.apis.common.content_types import (
    InterleavedContent,
    InterleavedContentItem,
    TextDelta,
    ToolCallDelta,
)
from llama_stack.apis.inference import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionResponseEvent,
    ChatCompletionResponseEventType,
    ChatCompletionResponseStreamChunk,
    CompletionMessage,
    CompletionResponse,
    CompletionResponseStreamChunk,
    EmbeddingsResponse,
    EmbeddingTaskType,
    GrammarResponseFormat,
    Inference,
    JsonSchemaResponseFormat,
    LogProbConfig,
    Message,
    ResponseFormat,
    SamplingParams,
    TextTruncation,
    TokenLogProbs,
    ToolChoice,
    ToolConfig,
    TopKSamplingStrategy,
    TopPSamplingStrategy,
)
from llama_stack.apis.models import Model
from llama_stack.log import get_logger
from llama_stack.models.llama import sku_list
from llama_stack.models.llama.datatypes import (
    StopReason,
    ToolCall,
    ToolDefinition,
    ToolPromptFormat,
)
from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.providers.remote.inference.vllm.vllm import build_hf_repo_model_entries
from llama_stack.providers.utils.inference.model_registry import (
    ModelRegistryHelper,
    ModelsProtocolPrivate,
)
from llama_stack.providers.utils.inference.openai_compat import (
    OpenAIChatCompletionToLlamaStackMixin,
    OpenAICompatCompletionChoice,
    OpenAICompatCompletionResponse,
    OpenAICompletionToLlamaStackMixin,
    get_stop_reason,
    process_chat_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
    chat_completion_request_to_prompt,
)

from .config import VLLMConfig
from .openai_utils import llama_stack_chat_completion_to_openai_chat_completion_dict

# Map from Hugging Face model architecture name to appropriate tool parser.
# See vllm.entrypoints.openai.tool_parsers.ToolParserManager.tool_parsers for the full list of
# available parsers.
# TODO: Expand this list
CONFIG_TYPE_TO_TOOL_PARSER = {
    "GraniteConfig": "granite",
    "MllamaConfig": "llama3_json",
    "LlamaConfig": "llama3_json",
}
DEFAULT_TOOL_PARSER = "pythonic"


logger = get_logger(__name__, category="inference")


def _random_uuid_str() -> str:
    return str(uuid.uuid4().hex)


def _response_format_to_guided_decoding_params(
    response_format: Optional[ResponseFormat],  # type: ignore
) -> Optional[vllm.sampling_params.GuidedDecodingParams]:
    """
    Translate constrained decoding parameters from Llama Stack's format to vLLM's format.

    :param response_format: Llama Stack version of constrained decoding info. Can be ``None``,
        indicating no constraints.
    :returns: The equivalent dataclass object for the low-level inference layer of vLLM, or
        ``None`` if no constraints were specified.
    """
    if response_format is None:
        # As of vLLM 0.6.3, the default constructor for GuidedDecodingParams() returns an invalid
        # value that crashes the executor on some code paths. Use ``None`` instead.
        return None

    # Llama Stack currently implements fewer types of constrained decoding than vLLM does.
    # Translate the types that exist and detect if Llama Stack adds new ones.
    if isinstance(response_format, JsonSchemaResponseFormat):
        return vllm.sampling_params.GuidedDecodingParams(json=response_format.json_schema)
    elif isinstance(response_format, GrammarResponseFormat):
        # BNF grammar.
        # Llama Stack uses the parse tree of the grammar, while vLLM uses the string
        # representation of the grammar.
        raise TypeError(
            "Constrained decoding with BNF grammars is not currently implemented, because the "
            "reference implementation does not implement it."
        )
    else:
        raise TypeError(f"ResponseFormat object is of unexpected subtype '{type(response_format)}'")


def _convert_sampling_params(
    sampling_params: Optional[SamplingParams],
    response_format: Optional[ResponseFormat],  # type: ignore
    log_prob_config: Optional[LogProbConfig],
) -> vllm.SamplingParams:
    """Convert sampling and constrained decoding configuration from Llama Stack's format to vLLM's
    format."""
    # In the absence of provided config values, use Llama Stack defaults as encoded in the Llama
    # Stack dataclasses. These defaults are different from vLLM's defaults.
    if sampling_params is None:
        sampling_params = SamplingParams()
    if log_prob_config is None:
        log_prob_config = LogProbConfig()

    if isinstance(sampling_params.strategy, TopKSamplingStrategy):
        if sampling_params.strategy.top_k == 0:
            # vLLM treats "k" differently for top-k sampling
            vllm_top_k = -1
        else:
            vllm_top_k = sampling_params.strategy.top_k
    else:
        vllm_top_k = -1

    if isinstance(sampling_params.strategy, TopPSamplingStrategy):
        vllm_top_p = sampling_params.strategy.top_p
        # Llama Stack only allows temperature with top-P.
        vllm_temperature = sampling_params.strategy.temperature
    else:
        vllm_top_p = 1.0
        vllm_temperature = 0.0

    # vLLM allows top-p and top-k at the same time.
    vllm_sampling_params = vllm.SamplingParams.from_optional(
        max_tokens=(None if sampling_params.max_tokens == 0 else sampling_params.max_tokens),
        temperature=vllm_temperature,
        top_p=vllm_top_p,
        top_k=vllm_top_k,
        repetition_penalty=sampling_params.repetition_penalty,
        guided_decoding=_response_format_to_guided_decoding_params(response_format),
        logprobs=log_prob_config.top_k,
    )
    return vllm_sampling_params


class VLLMInferenceImpl(
    Inference,
    OpenAIChatCompletionToLlamaStackMixin,
    OpenAICompletionToLlamaStackMixin,
    ModelsProtocolPrivate,
):
    """
    vLLM-based inference model adapter for Llama Stack with support for multiple models.

    Requires the configuration parameters documented in the :class:`VLLMConfig` class.
    """

    config: VLLMConfig
    register_helper: ModelRegistryHelper
    model_ids: set[str]
    resolved_model_id: str | None
    engine: AsyncLLMEngine | None
    chat: OpenAIServingChat | None
    is_meta_llama_model: bool

    def __init__(self, config: VLLMConfig):
        self.config = config
        logger.info(f"Config is: {self.config}")

        self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
        self.formatter = ChatFormat(Tokenizer.get_instance())

        # The following are initialized when paths are bound to this provider
        self.resolved_model_id = None
        self.model_ids = set()
        self.engine = None
        self.chat = None
        self.is_meta_llama_model = False

    ###########################################################################
    # METHODS INHERITED FROM IMPLICIT BASE CLASS.
    # TODO: Make this class inherit from the new base class ProviderBase once that class exists.

    async def initialize(self) -> None:
        """
        Callback that is invoked through many levels of indirection during provider class
        instantiation, sometime after __init__() is called and before any model registration
        methods or methods connected to a REST API are called.

        It's not clear what assumptions the class can make about the platform's initialization
        state here that can't be made during __init__(), and vLLM can't be started until we know
        what model it's supposed to be serving, so nothing happens here currently.
        """
        pass

    async def shutdown(self) -> None:
        logger.info(f"Shutting down inline vLLM inference provider {self}.")
        if self.engine is not None:
            self.engine.shutdown_background_loop()
            self.engine = None
            self.chat = None
            self.model_ids = set()
            self.resolved_model_id = None

    ###########################################################################
    # METHODS INHERITED FROM ModelsProtocolPrivate INTERFACE

    # Note that the return type of the superclass method is WRONG
    async def register_model(self, model: Model) -> Model:
        """
        Callback that is called when the server associates an inference endpoint with an
        inference provider.

        :param model: Object that encapsulates parameters necessary for identifying a specific
            LLM.

        :returns: The input ``Model`` object. It may or may not be permissible to change fields
            before returning this object.
        """
        logger.debug(f"In register_model({model})")

        # First attempt to interpret the model coordinates as a Llama model name
        resolved_llama_model = sku_list.resolve_model(model.provider_model_id)
        if resolved_llama_model is not None:
            # Load from Hugging Face repo into default local cache dir
            model_id_for_vllm = resolved_llama_model.huggingface_repo

            # Detect a genuine Meta Llama model to trigger Meta-specific preprocessing.
            # Don't set self.is_meta_llama_model until we actually load the model.
            is_meta_llama_model = True
        else:  # if resolved_llama_model is None
            # Not a Llama model name. Pass the model id through to vLLM's loader
            model_id_for_vllm = model.provider_model_id
            is_meta_llama_model = False

        if self.resolved_model_id is not None:
            if model_id_for_vllm != self.resolved_model_id:
                raise ValueError(
                    f"Attempted to serve two LLMs (ids '{self.resolved_model_id}' and "
                    f"'{model_id_for_vllm}') from one copy of provider '{self}'. Use multiple "
                    f"copies of the provider instead."
                )
            else:
                # Model already loaded
                logger.info(
                    f"Requested id {model} resolves to {model_id_for_vllm}, which is already loaded. Continuing."
                )
                self.model_ids.add(model.model_id)
                return model

        logger.info(f"Requested id {model} resolves to {model_id_for_vllm}. Loading {model_id_for_vllm}.")
        if is_meta_llama_model:
            logger.info(f"Model {model_id_for_vllm} is a Meta Llama model.")
        self.is_meta_llama_model = is_meta_llama_model

        # If we get here, this is the first time registering a model.
        # Preload so that the first inference request won't time out.
        engine_args = AsyncEngineArgs(
            model=model_id_for_vllm,
            tokenizer=model_id_for_vllm,
            tensor_parallel_size=self.config.tensor_parallel_size,
            enforce_eager=self.config.enforce_eager,
            gpu_memory_utilization=self.config.gpu_memory_utilization,
            max_num_seqs=self.config.max_num_seqs,
            max_model_len=self.config.max_model_len,
        )
        self.engine = AsyncLLMEngine.from_engine_args(engine_args)

        # vLLM currently requires the user to specify the tool parser manually. To choose a tool
        # parser, we need to determine what model architecture is being used. For now, we infer
        # that information from what config class the model uses.
        low_level_model_config = self.engine.engine.get_model_config()
        hf_config = low_level_model_config.hf_config
        hf_config_class_name = hf_config.__class__.__name__
        if hf_config_class_name in CONFIG_TYPE_TO_TOOL_PARSER:
            tool_parser = CONFIG_TYPE_TO_TOOL_PARSER[hf_config_class_name]
        else:
            # No info -- choose a default so we can at least attempt tool
            # use.
            tool_parser = DEFAULT_TOOL_PARSER
        logger.debug(f"{hf_config_class_name=}")
        logger.debug(f"{tool_parser=}")

        # Wrap the lower-level engine in an OpenAI-compatible chat API
        model_config = await self.engine.get_model_config()
        self.chat = OpenAIServingChat(
            engine_client=self.engine,
            model_config=model_config,
            models=OpenAIServingModels(
                engine_client=self.engine,
                model_config=model_config,
                base_model_paths=[
                    # The layer below us will only see resolved model IDs
                    BaseModelPath(model_id_for_vllm, model_id_for_vllm)
                ],
            ),
            response_role="assistant",
            request_logger=None,  # Use default logging
            chat_template=None,  # Use default template from model checkpoint
            enable_auto_tools=True,
            tool_parser=tool_parser,
            chat_template_content_format="auto",
        )
        self.resolved_model_id = model_id_for_vllm
        self.model_ids.add(model.model_id)

        logger.info(f"Finished preloading model: {model_id_for_vllm}")

        return model

    async def unregister_model(self, model_id: str) -> None:
        """
        Callback that is called when the server removes an inference endpoint from an inference
        provider.

        :param model_id: The same external ID that the higher layers of the stack previously passed
            to :func:`register_model()`
        """
        if model_id not in self.model_ids:
            raise ValueError(
                f"Attempted to unregister model ID '{model_id}', but that ID is not registered to this provider."
            )
        self.model_ids.remove(model_id)

        if len(self.model_ids) == 0:
            # Last model was just unregistered. Shut down the connection to vLLM and free up
            # resources.
            # Note that this operation may cause in-flight chat completion requests on the
            # now-unregistered model to return errors.
            self.resolved_model_id = None
            self.chat = None
            self.engine.shutdown_background_loop()
            self.engine = None

    ###########################################################################
    # METHODS INHERITED FROM Inference INTERFACE

    async def completion(
        self,
        model_id: str,
        content: InterleavedContent,
        sampling_params: Optional[SamplingParams] = None,
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
        if model_id not in self.model_ids:
            raise ValueError(
                f"This adapter is not registered to model id '{model_id}'. Registered IDs are: {self.model_ids}"
            )
        if not isinstance(content, str):
            raise NotImplementedError("Multimodal input not currently supported")
        if sampling_params is None:
            sampling_params = SamplingParams()

        converted_sampling_params = _convert_sampling_params(sampling_params, response_format, logprobs)

        logger.debug(f"{converted_sampling_params=}")

        if stream:
            return self._streaming_completion(content, converted_sampling_params)
        else:
            streaming_result = None
            # Drain the stream, keeping the last chunk; it carries the cumulative result.
            async for streaming_result in self._streaming_completion(content, converted_sampling_params):
                pass
            return CompletionResponse(
                content=streaming_result.delta,
                stop_reason=streaming_result.stop_reason,
                logprobs=streaming_result.logprobs,
            )

    async def embeddings(
        self,
        model_id: str,
        contents: List[str] | List[InterleavedContentItem],
        text_truncation: Optional[TextTruncation] = TextTruncation.none,
        output_dimension: Optional[int] = None,
        task_type: Optional[EmbeddingTaskType] = None,
    ) -> EmbeddingsResponse:
        raise NotImplementedError()

    async def chat_completion(
        self,
        model_id: str,
        messages: List[Message],  # type: ignore
        sampling_params: Optional[SamplingParams] = None,
        response_format: Optional[ResponseFormat] = None,  # type: ignore
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
        tool_config: Optional[ToolConfig] = None,
    ) -> ChatCompletionResponse | ChatCompletionResponseStreamChunk:
        sampling_params = sampling_params or SamplingParams()
        if model_id not in self.model_ids:
            raise ValueError(
                f"This adapter is not registered to model id '{model_id}'. Registered IDs are: {self.model_ids}"
            )

        # Convert to Llama Stack internal format for consistency
        request = ChatCompletionRequest(
            model=self.resolved_model_id,
            messages=messages,
            sampling_params=sampling_params,
            response_format=response_format,
            tools=tools,
            tool_choice=tool_choice,
            tool_prompt_format=tool_prompt_format,
            stream=stream,
            logprobs=logprobs,
        )

        if self.is_meta_llama_model:
            # Bypass vLLM chat templating layer for Meta Llama models, because the
            # templating layer in Llama Stack currently produces better results.
            logger.debug(
                f"Routing {self.resolved_model_id} chat completion through "
                f"Llama Stack's templating layer instead of vLLM's."
            )
            return await self._chat_completion_for_meta_llama(request)

        logger.debug(f"{self.resolved_model_id} is not a Meta Llama model")

        # Arguments to the vLLM call must be packaged as a ChatCompletionRequest dataclass.
        # Note that this dataclass has the same name as a similar dataclass in Llama Stack.
        request_options = await llama_stack_chat_completion_to_openai_chat_completion_dict(request)
        chat_completion_request = vllm.entrypoints.openai.protocol.ChatCompletionRequest(**request_options)

        logger.debug(f"Converted request: {chat_completion_request}")

        vllm_result = await self.chat.create_chat_completion(chat_completion_request)
        logger.debug(f"Result from vLLM: {vllm_result}")
        if isinstance(vllm_result, vllm.entrypoints.openai.protocol.ErrorResponse):
            raise ValueError(f"Error from vLLM layer: {vllm_result}")

        # Return type depends on "stream" argument
        if stream:
            if not isinstance(vllm_result, AsyncGenerator):
                raise TypeError(f"Unexpected result type {type(vllm_result)} for streaming inference call")
            # vLLM client returns a stream of strings, which need to be parsed.
            # Stream comes in the form of an async generator.
            return self._convert_streaming_results(vllm_result)
        else:
            if not isinstance(vllm_result, vllm.entrypoints.openai.protocol.ChatCompletionResponse):
                raise TypeError(f"Unexpected result type {type(vllm_result)} for non-streaming inference call")
            return self._convert_non_streaming_results(vllm_result)

    ###########################################################################
    # INTERNAL METHODS

    async def _streaming_completion(
        self, content: str, sampling_params: vllm.SamplingParams
    ) -> AsyncIterator[CompletionResponseStreamChunk]:
        """Internal implementation of :func:`completion()` API for the streaming case. Assumes
        that arguments have been validated upstream.

        :param content: Must be a string
        :param sampling_params: Parameters from public API's ``response_format``
            and ``sampling_params`` arguments, converted to vLLM format
        """
        # We run against the vLLM generate() call directly instead of using the OpenAI-compatible
        # layer, because doing so simplifies the code here.

        # The vLLM engine requires a unique identifier for each call to generate()
        request_id = _random_uuid_str()

        # The vLLM generate() API is streaming-only and returns an async generator.
        # The generator returns objects of type vllm.RequestOutput.
        results_generator = self.engine.generate(content, sampling_params, request_id)

        # Need to know the model's EOS token ID for the conversion code below.
        # AsyncLLMEngine is a wrapper around LLMEngine, and the tokenizer is only available if
        # we drill down to the LLMEngine inside the AsyncLLMEngine.
        # Similarly, the tokenizer in an LLMEngine is a wrapper around a BaseTokenizerGroup,
        # and we need to drill down to the Hugging Face tokenizer inside the BaseTokenizerGroup.
        llm_engine = self.engine.engine
        tokenizer_group = llm_engine.tokenizer
        eos_token_id = tokenizer_group.tokenizer.eos_token_id

        request_output: vllm.RequestOutput = None
        async for request_output in results_generator:
            # Check for weird inference failures
            if request_output.outputs is None or len(request_output.outputs) == 0:
                # This case should never happen
                raise ValueError("Inference produced empty result")

            # If we get here, then request_output contains the latest output of the generate() call.
            # The result may include multiple alternate outputs, but Llama Stack APIs only allow
            # us to return one.
            output: vllm.CompletionOutput = request_output.outputs[0]
            completion_string = output.text

            # Convert logprobs from vLLM's format to Llama Stack's format
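            # In vLLM, output.logprobs is a list with one entry per generated token; each entry
            # maps token ID -> Logprob, where Logprob carries the decoded token text and its log
            # probability. For example (illustrative values), an entry like
            # {9906: Logprob(logprob=-0.25, decoded_token="Hello")} becomes
            # TokenLogProbs(logprobs_by_token={"Hello": -0.25}).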
            logprobs = [
                TokenLogProbs(logprobs_by_token={v.decoded_token: v.logprob for _, v in logprob_dict.items()})
                for logprob_dict in output.logprobs
            ]

            # The final output chunk should be labeled with the reason that the overall generate()
            # call completed.
            logger.debug(f"{output.stop_reason=}; {type(output.stop_reason)=}")
            if output.stop_reason is None:
                stop_reason = None  # Still going
            elif output.stop_reason == "stop":
                stop_reason = StopReason.end_of_turn
            elif output.stop_reason == "length":
                stop_reason = StopReason.out_of_tokens
            elif isinstance(output.stop_reason, int):
                # If the model config specifies multiple end-of-sequence tokens, then vLLM
                # will return the token ID of the EOS token in the stop_reason field.
                stop_reason = StopReason.end_of_turn
            else:
                raise ValueError(f"Unrecognized stop reason '{output.stop_reason}'")

            # vLLM's protocol outputs the stop token, then sets end of message on the next step for
            # some reason.
            if request_output.outputs[-1].token_ids[-1] == eos_token_id:
                stop_reason = StopReason.end_of_message

            yield CompletionResponseStreamChunk(delta=completion_string, stop_reason=stop_reason, logprobs=logprobs)

        # Llama Stack requires that the last chunk have a stop reason, but vLLM doesn't always
        # provide one if it runs out of tokens.
        if stop_reason is None:
            yield CompletionResponseStreamChunk(
                delta=completion_string,
                stop_reason=StopReason.out_of_tokens,
                logprobs=logprobs,
            )

    def _convert_non_streaming_results(
        self, vllm_result: vllm.entrypoints.openai.protocol.ChatCompletionResponse
    ) -> ChatCompletionResponse:
        """
        Subroutine to convert the non-streaming output of vLLM's OpenAI-compatible API into an
        equivalent Llama Stack object.

        The result from vLLM's non-streaming API is a dataclass with the same name as the Llama
        Stack ChatCompletionResponse dataclass, but with more and different field names. We ignore
        the fields that aren't currently present in the Llama Stack dataclass.
        """

        # There may be multiple responses, but we can only pass through the first one.
        if len(vllm_result.choices) == 0:
            raise ValueError("Don't know how to convert response object without any responses")
        vllm_message = vllm_result.choices[0].message
        vllm_finish_reason = vllm_result.choices[0].finish_reason

        converted_message = CompletionMessage(
            role=vllm_message.role,
            # Llama Stack API won't accept None for content field.
            content=("" if vllm_message.content is None else vllm_message.content),
            stop_reason=get_stop_reason(vllm_finish_reason),
            tool_calls=[
                ToolCall(
                    call_id=t.id,
                    tool_name=t.function.name,
                    # vLLM function args come back as a string. Llama Stack expects JSON.
                    arguments=json.loads(t.function.arguments),
                    arguments_json=t.function.arguments,
                )
                for t in vllm_message.tool_calls
            ],
        )

        # TODO: Convert logprobs

        logger.debug(f"Converted message: {converted_message}")

        return ChatCompletionResponse(
            completion_message=converted_message,
        )

    async def _chat_completion_for_meta_llama(
        self, request: ChatCompletionRequest
    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
        """
        Subroutine that routes chat completions for Meta Llama models through Llama Stack's
        chat template instead of using vLLM's version of that template. The Llama Stack version
        of the chat template currently produces more reliable outputs.

        Once vLLM's support for Meta Llama models has matured more, we should consider routing
        Meta Llama requests through the vLLM chat completions API instead of using this method.
        """
        formatter = ChatFormat(Tokenizer.get_instance())

        # Note that this function call modifies `request` in place.
        prompt = await chat_completion_request_to_prompt(request, self.resolved_model_id)

        model_id = list(self.model_ids)[0]  # Any model ID will do here
        completion_response_or_iterator = await self.completion(
            model_id=model_id,
            content=prompt,
            sampling_params=request.sampling_params,
            response_format=request.response_format,
            stream=request.stream,
            logprobs=request.logprobs,
        )

        if request.stream:
            if not isinstance(completion_response_or_iterator, AsyncIterator):
                raise TypeError(
                    f"Received unexpected result type {type(completion_response_or_iterator)} for streaming request."
                )
            return self._chat_completion_for_meta_llama_streaming(completion_response_or_iterator, request)

        # elsif not request.stream:
        if not isinstance(completion_response_or_iterator, CompletionResponse):
            raise TypeError(
                f"Received unexpected result type {type(completion_response_or_iterator)} for non-streaming request."
            )
        completion_response: CompletionResponse = completion_response_or_iterator
        raw_message = formatter.decode_assistant_message_from_content(
            completion_response.content, completion_response.stop_reason
        )
        return ChatCompletionResponse(
            completion_message=CompletionMessage(
                content=raw_message.content,
                stop_reason=raw_message.stop_reason,
                tool_calls=raw_message.tool_calls,
            ),
            logprobs=completion_response.logprobs,
        )

    async def _chat_completion_for_meta_llama_streaming(
        self, results_iterator: AsyncIterator, request: ChatCompletionRequest
    ) -> AsyncIterator:
        """
        Code from :func:`_chat_completion_for_meta_llama()` that needs to be a separate
        method to keep asyncio happy.
        """

        # Convert to OpenAI format, then use shared code to convert to Llama Stack format.
        async def _generate_and_convert_to_openai_compat():
            chunk: CompletionResponseStreamChunk  # Make Pylance happy
            last_text_len = 0
            async for chunk in results_iterator:
                if chunk.stop_reason == StopReason.end_of_turn:
                    finish_reason = "stop"
                elif chunk.stop_reason == StopReason.end_of_message:
                    finish_reason = "eos"
                elif chunk.stop_reason == StopReason.out_of_tokens:
                    finish_reason = "length"
                else:
                    finish_reason = None

                # Convert delta back to an actual delta
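                # The chunks from _streaming_completion() carry the cumulative text generated so
                # far, so we subtract the previously seen prefix. For example, if the previous
                # chunk's delta was "Hello" and this chunk's delta is "Hello, wor", the actual
                # text delta emitted here is ", wor".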
                text_delta = chunk.delta[last_text_len:]
                last_text_len = len(chunk.delta)

                logger.debug(f"{text_delta=}; {finish_reason=}")

                yield OpenAICompatCompletionResponse(
                    choices=[OpenAICompatCompletionChoice(finish_reason=finish_reason, text=text_delta)]
                )

        stream = _generate_and_convert_to_openai_compat()
        async for chunk in process_chat_completion_stream_response(stream, request):
            logger.debug(f"Returning chunk: {chunk}")
            yield chunk

    async def _convert_streaming_results(self, vllm_result: AsyncIterator) -> AsyncIterator:
        """
        Subroutine that wraps the streaming outputs of vLLM's OpenAI-compatible
        API into a second async iterator that returns Llama Stack objects.

        :param vllm_result: Stream of strings that need to be parsed
        """
        # Tool calls come in pieces, but Llama Stack expects them in bigger chunks. We build up
        # those chunks and output them at the end.
        # This data structure holds the current set of partial tool calls.
        index_to_tool_call: Dict[int, Dict] = dict()

        # The Llama Stack event stream must always start with a start event. Use an empty one to
        # simplify logic below
        yield ChatCompletionResponseStreamChunk(
            event=ChatCompletionResponseEvent(
                event_type=ChatCompletionResponseEventType.start,
                delta=TextDelta(text=""),
                stop_reason=None,
            )
        )

        converted_stop_reason = None
        async for chunk_str in vllm_result:
            # Due to OpenAI compatibility, each event in the stream will start with "data: " and
            # end with "\n\n".
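            # A typical chunk (abbreviated, assumed OpenAI-style server-sent-event payload) looks
            # like:
            #   'data: {"choices": [{"delta": {"content": "Hello"}, "finish_reason": null}]}\n\n'
            # and the end of the stream is signaled by:
            #   'data: [DONE]\n\n'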
            _prefix = "data: "
            _suffix = "\n\n"
            if not chunk_str.startswith(_prefix) or not chunk_str.endswith(_suffix):
                raise ValueError(f"Can't parse result string from vLLM: '{re.escape(chunk_str)}'")

            # In between the "data: " and newlines is an event record
            data_str = chunk_str[len(_prefix) : -len(_suffix)]

            # The end of the stream is indicated with "[DONE]"
            if data_str == "[DONE]":
                yield ChatCompletionResponseStreamChunk(
                    event=ChatCompletionResponseEvent(
                        event_type=ChatCompletionResponseEventType.complete,
                        delta=TextDelta(text=""),
                        stop_reason=converted_stop_reason,
                    )
                )
                return

            # Anything that is not "[DONE]" should be a JSON record
            parsed_chunk = json.loads(data_str)

            logger.debug(f"Parsed JSON event to:\n{json.dumps(parsed_chunk, indent=2)}")

            # The result may contain multiple completions, but Llama Stack APIs only support
            # returning one.
            first_choice = parsed_chunk["choices"][0]
            converted_stop_reason = get_stop_reason(first_choice["finish_reason"])
            delta_record = first_choice["delta"]

            if "content" in delta_record:
                # Text delta
                yield ChatCompletionResponseStreamChunk(
                    event=ChatCompletionResponseEvent(
                        event_type=ChatCompletionResponseEventType.progress,
                        delta=TextDelta(text=delta_record["content"]),
                        stop_reason=converted_stop_reason,
                    )
                )
            elif "tool_calls" in delta_record:
                # Tool call(s). Llama Stack APIs do not have a clear way to return partial tool
                # calls, so buffer until we get a "tool calls" stop reason
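                # For example (assumed OpenAI-style deltas), the first delta for a given tool call
                # typically carries its "id" and "function.name", while subsequent deltas carry
                # fragments of "function.arguments" such as '{"ci' followed by 'ty": "Paris"}'.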
                for tc in delta_record["tool_calls"]:
                    index = tc["index"]
                    if index not in index_to_tool_call:
                        # First time this tool call is showing up
                        index_to_tool_call[index] = dict()
                    tool_call = index_to_tool_call[index]
                    if "id" in tc:
                        tool_call["call_id"] = tc["id"]
                    if "function" in tc:
                        if "name" in tc["function"]:
                            tool_call["tool_name"] = tc["function"]["name"]
                        if "arguments" in tc["function"]:
                            # Arguments come in as pieces of a string
                            if "arguments_str" not in tool_call:
                                tool_call["arguments_str"] = ""
                            tool_call["arguments_str"] += tc["function"]["arguments"]
            else:
                raise ValueError(f"Don't know how to parse event delta: {delta_record}")

            if first_choice["finish_reason"] == "tool_calls":
                # Special OpenAI code for "tool calls complete".
                # Output the buffered tool calls. Llama Stack requires a separate event per tool
                # call.
                for tool_call_record in index_to_tool_call.values():
                    # Arguments come in as a string. Parse the completed string.
                    tool_call_record["arguments"] = json.loads(tool_call_record["arguments_str"])
                    del tool_call_record["arguments_str"]

                    yield ChatCompletionResponseStreamChunk(
                        event=ChatCompletionResponseEvent(
                            event_type=ChatCompletionResponseEventType.progress,
                            delta=ToolCallDelta(tool_call=tool_call_record, parse_status="succeeded"),
                            stop_reason=converted_stop_reason,
                        )
                    )

        # If we get here, we've lost the connection with the vLLM event stream before it ended
        # normally.
        raise ValueError("vLLM event stream ended without [DONE] message.")