mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
# What does this PR do? This stubs in some OpenAI server-side compatibility with three new endpoints: /v1/openai/v1/models /v1/openai/v1/completions /v1/openai/v1/chat/completions This gives common inference apps using OpenAI clients the ability to talk to Llama Stack using an endpoint like http://localhost:8321/v1/openai/v1 . The two "v1" instances in there aren't awesome, but the thinking is that Llama Stack's API is v1 and then our OpenAI compatibility layer is compatible with OpenAI V1. And, some OpenAI clients implicitly assume the URL ends with "v1", so this gives maximum compatibility. The openai models endpoint is implemented in the routing layer, and just returns all the models Llama Stack knows about. The following providers should be working with the new OpenAI completions and chat/completions API: * remote::anthropic (untested) * remote::cerebras-openai-compat (untested) * remote::fireworks (tested) * remote::fireworks-openai-compat (untested) * remote::gemini (untested) * remote::groq-openai-compat (untested) * remote::nvidia (tested) * remote::ollama (tested) * remote::openai (untested) * remote::passthrough (untested) * remote::sambanova-openai-compat (untested) * remote::together (tested) * remote::together-openai-compat (untested) * remote::vllm (tested) The goal is to support this for every inference provider - proxying directly to the provider's OpenAI endpoint for OpenAI-compatible providers. For providers that don't have an OpenAI-compatible API, we'll add a mixin to translate incoming OpenAI requests to Llama Stack inference requests and translate the Llama Stack inference responses to OpenAI responses. This is related to #1817 but is a bit larger in scope than just chat completions, as I have real use-cases that need the older completions API as well.
## Test Plan ### vLLM ``` VLLM_URL="http://localhost:8000/v1" INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct" llama stack build --template remote-vllm --image-type venv --run LLAMA_STACK_CONFIG=http://localhost:8321 INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct" python -m pytest -v tests/integration/inference/test_openai_completion.py --text-model "meta-llama/Llama-3.2-3B-Instruct" ``` ### ollama ``` INFERENCE_MODEL="llama3.2:3b-instruct-q8_0" llama stack build --template ollama --image-type venv --run LLAMA_STACK_CONFIG=http://localhost:8321 INFERENCE_MODEL="llama3.2:3b-instruct-q8_0" python -m pytest -v tests/integration/inference/test_openai_completion.py --text-model "llama3.2:3b-instruct-q8_0" ``` ## Documentation Run a Llama Stack distribution that uses one of the providers mentioned in the list above. Then, use your favorite OpenAI client to send completion or chat completion requests with the base_url set to http://localhost:8321/v1/openai/v1 . Replace "localhost:8321" with the host and port of your Llama Stack server, if different. --------- Signed-off-by: Ben Browning <bbrownin@redhat.com>
800 lines
35 KiB
Python
800 lines
35 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import json
|
|
import re
|
|
import uuid
|
|
from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
|
|
|
|
# These vLLM modules contain names that overlap with Llama Stack names, so we import
|
|
# fully-qualified names
|
|
import vllm.entrypoints.openai.protocol
|
|
import vllm.sampling_params
|
|
from vllm.engine.arg_utils import AsyncEngineArgs
|
|
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
|
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
|
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
|
|
|
|
from llama_stack.apis.common.content_types import (
|
|
InterleavedContent,
|
|
InterleavedContentItem,
|
|
TextDelta,
|
|
ToolCallDelta,
|
|
)
|
|
from llama_stack.apis.inference import (
|
|
ChatCompletionRequest,
|
|
ChatCompletionResponse,
|
|
ChatCompletionResponseEvent,
|
|
ChatCompletionResponseEventType,
|
|
ChatCompletionResponseStreamChunk,
|
|
CompletionMessage,
|
|
CompletionResponse,
|
|
CompletionResponseStreamChunk,
|
|
EmbeddingsResponse,
|
|
EmbeddingTaskType,
|
|
GrammarResponseFormat,
|
|
Inference,
|
|
JsonSchemaResponseFormat,
|
|
LogProbConfig,
|
|
Message,
|
|
ResponseFormat,
|
|
SamplingParams,
|
|
TextTruncation,
|
|
TokenLogProbs,
|
|
ToolChoice,
|
|
ToolConfig,
|
|
TopKSamplingStrategy,
|
|
TopPSamplingStrategy,
|
|
)
|
|
from llama_stack.apis.models import Model
|
|
from llama_stack.log import get_logger
|
|
from llama_stack.models.llama import sku_list
|
|
from llama_stack.models.llama.datatypes import (
|
|
StopReason,
|
|
ToolCall,
|
|
ToolDefinition,
|
|
ToolPromptFormat,
|
|
)
|
|
from llama_stack.models.llama.llama3.chat_format import ChatFormat
|
|
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
|
from llama_stack.providers.remote.inference.vllm.vllm import build_hf_repo_model_entries
|
|
from llama_stack.providers.utils.inference.model_registry import (
|
|
ModelRegistryHelper,
|
|
ModelsProtocolPrivate,
|
|
)
|
|
from llama_stack.providers.utils.inference.openai_compat import (
|
|
OpenAIChatCompletionUnsupportedMixin,
|
|
OpenAICompatCompletionChoice,
|
|
OpenAICompatCompletionResponse,
|
|
OpenAICompletionUnsupportedMixin,
|
|
get_stop_reason,
|
|
process_chat_completion_stream_response,
|
|
)
|
|
from llama_stack.providers.utils.inference.prompt_adapter import (
|
|
chat_completion_request_to_prompt,
|
|
)
|
|
|
|
from .config import VLLMConfig
|
|
from .openai_utils import llama_stack_chat_completion_to_openai_chat_completion_dict
|
|
|
|
# Tool parser to hand to vLLM, keyed by the class name of the model's Hugging Face config.
# vLLM does not pick a tool parser on its own, so we choose one from the model architecture.
# The full set of parser names vLLM understands lives in
# vllm.entrypoints.openai.tool_parsers.ToolParserManager.tool_parsers.
# TODO: Expand this list
CONFIG_TYPE_TO_TOOL_PARSER = dict(
    GraniteConfig="granite",
    MllamaConfig="llama3_json",
    LlamaConfig="llama3_json",
)

# Parser used for any config class that has no entry in the mapping above.
DEFAULT_TOOL_PARSER = "pythonic"
|
|
|
|
|
|
# Module-level logger for this provider, filed under the "inference" log category.
logger = get_logger(__name__, category="inference")
|
|
|
|
|
|
def _random_uuid_str() -> str:
|
|
return str(uuid.uuid4().hex)
|
|
|
|
|
|
def _response_format_to_guided_decoding_params(
    response_format: Optional[ResponseFormat],  # type: ignore
) -> Optional[vllm.sampling_params.GuidedDecodingParams]:
    """
    Translate constrained decoding parameters from Llama Stack's format to vLLM's format.

    :param response_format: Llama Stack version of constrained decoding info. Can be ``None``,
        indicating no constraints.
    :returns: The equivalent dataclass object for the low-level inference layer of vLLM, or
        ``None`` when ``response_format`` is ``None`` (no constraints).
    :raises TypeError: If ``response_format`` is a BNF grammar (not yet supported here) or a
        ``ResponseFormat`` subtype this function does not know about.
    """
    if response_format is None:
        # As of vLLM 0.6.3, the default constructor for GuidedDecodingParams() returns an invalid
        # value that crashes the executor on some code paths. Use ``None`` instead.
        return None

    # Llama Stack currently implements fewer types of constrained decoding than vLLM does.
    # Translate the types that exist and detect if Llama Stack adds new ones.
    if isinstance(response_format, JsonSchemaResponseFormat):
        return vllm.sampling_params.GuidedDecodingParams(json=response_format.json_schema)
    elif isinstance(response_format, GrammarResponseFormat):
        # BNF grammar.
        # Llama Stack uses the parse tree of the grammar, while vLLM uses the string
        # representation of the grammar.
        raise TypeError(
            "Constrained decoding with BNF grammars is not currently implemented, because the "
            "reference implementation does not implement it."
        )
    else:
        raise TypeError(f"ResponseFormat object is of unexpected subtype '{type(response_format)}'")
|
|
|
|
|
|
def _convert_sampling_params(
    sampling_params: Optional[SamplingParams],
    response_format: Optional[ResponseFormat],  # type: ignore
    log_prob_config: Optional[LogProbConfig],
) -> vllm.SamplingParams:
    """Build vLLM sampling parameters from Llama Stack's sampling, constrained decoding and
    logprob settings.

    Missing (``None``) configuration objects are replaced by freshly constructed Llama Stack
    defaults, which differ from vLLM's own defaults.
    """
    params = SamplingParams() if sampling_params is None else sampling_params
    logprob_cfg = LogProbConfig() if log_prob_config is None else log_prob_config
    strategy = params.strategy

    # Start from the values used when the strategy does not control a knob; vLLM encodes
    # "no top-k cutoff" as -1, so a Llama Stack top_k of 0 also stays at -1.
    top_k, top_p, temperature = -1, 1.0, 0.0
    if isinstance(strategy, TopKSamplingStrategy) and strategy.top_k != 0:
        top_k = strategy.top_k
    if isinstance(strategy, TopPSamplingStrategy):
        # Llama Stack only carries a temperature on the top-p strategy.
        top_p, temperature = strategy.top_p, strategy.temperature

    max_tokens = params.max_tokens
    # vLLM accepts top-p and top-k together, so both are always forwarded.
    return vllm.SamplingParams.from_optional(
        max_tokens=None if max_tokens == 0 else max_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=params.repetition_penalty,
        guided_decoding=_response_format_to_guided_decoding_params(response_format),
        logprobs=logprob_cfg.top_k,
    )
|
|
|
|
|
|
class VLLMInferenceImpl(
    Inference,
    OpenAIChatCompletionUnsupportedMixin,
    OpenAICompletionUnsupportedMixin,
    ModelsProtocolPrivate,
):
    """
    vLLM-based inference model adapter for Llama Stack with support for multiple models.

    Requires the configuration parameters documented in the :class:`VLLMConfig` class.

    One instance drives at most one in-process vLLM engine. Several Llama Stack model IDs may be
    registered, but they must all resolve to the same underlying model; see
    :func:`register_model()`.
    """

    config: VLLMConfig
    register_helper: ModelRegistryHelper
    # Llama Stack-facing model IDs currently registered with this provider.
    model_ids: set[str]
    # Model ID/path handed to vLLM's loader; ``None`` until the first model is registered.
    resolved_model_id: str | None
    engine: AsyncLLMEngine | None
    # OpenAI-compatible chat wrapper around ``engine``; ``None`` while no model is loaded.
    chat: OpenAIServingChat | None
    # True when the loaded model resolved through Llama Stack's SKU list; chat completions for
    # such models are routed through Llama Stack's own prompt templating.
    is_meta_llama_model: bool

    def __init__(self, config: VLLMConfig):
        self.config = config
        logger.info(f"Config is: {self.config}")

        self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
        self.formatter = ChatFormat(Tokenizer.get_instance())

        # The following are initialized when paths are bound to this provider
        self.resolved_model_id = None
        self.model_ids = set()
        self.engine = None
        self.chat = None
        self.is_meta_llama_model = False

    ###########################################################################
    # METHODS INHERITED FROM IMPLICIT BASE CLASS.
    # TODO: Make this class inherit from the new base class ProviderBase once that class exists.

    async def initialize(self) -> None:
        """
        Callback that is invoked through many levels of indirection during provider class
        instantiation, some time after __init__() is called and before any model registration
        methods or methods connected to a REST API are called.

        It's not clear what assumptions the class can make about the platform's initialization
        state here that can't be made during __init__(), and vLLM can't be started until we know
        what model it's supposed to be serving, so nothing happens here currently.
        """
        pass

    async def shutdown(self) -> None:
        """Stop the vLLM engine (if one is running) and forget every registered model."""
        logger.info(f"Shutting down inline vLLM inference provider {self}.")
        if self.engine is not None:
            self.engine.shutdown_background_loop()
            self.engine = None
        self.chat = None
        self.model_ids = set()
        self.resolved_model_id = None

    ###########################################################################
    # METHODS INHERITED FROM ModelsProtocolPrivate INTERFACE

    # Note that the return type of the superclass method is WRONG
    async def register_model(self, model: Model) -> Model:
        """
        Callback that is called when the server associates an inference endpoint with an
        inference provider.

        The first registration loads the model into a new vLLM engine. Later registrations are
        accepted only if they resolve to the already-loaded model.

        :param model: Object that encapsulates parameters necessary for identifying a specific
            LLM.

        :returns: The input ``Model`` object. It may or may not be permissible to change fields
            before returning this object.

        :raises ValueError: If a different model is already loaded in this provider instance.
        """
        logger.debug(f"In register_model({model})")

        # First attempt to interpret the model coordinates as a Llama model name
        resolved_llama_model = sku_list.resolve_model(model.provider_model_id)
        if resolved_llama_model is not None:
            # Load from Hugging Face repo into default local cache dir
            model_id_for_vllm = resolved_llama_model.huggingface_repo

            # Detect a genuine Meta Llama model to trigger Meta-specific preprocessing.
            # Don't set self.is_meta_llama_model until we actually load the model.
            is_meta_llama_model = True
        else:  # if resolved_llama_model is None
            # Not a Llama model name. Pass the model id through to vLLM's loader
            model_id_for_vllm = model.provider_model_id
            is_meta_llama_model = False

        if self.resolved_model_id is not None:
            if model_id_for_vllm != self.resolved_model_id:
                raise ValueError(
                    f"Attempted to serve two LLMs (ids '{self.resolved_model_id}' and "
                    f"'{model_id_for_vllm}') from one copy of provider '{self}'. Use multiple "
                    f"copies of the provider instead."
                )
            # Model already loaded
            logger.info(
                f"Requested id {model} resolves to {model_id_for_vllm}, which is already loaded. Continuing."
            )
            self.model_ids.add(model.model_id)
            return model

        logger.info(f"Requested id {model} resolves to {model_id_for_vllm}. Loading {model_id_for_vllm}.")
        if is_meta_llama_model:
            logger.info(f"Model {model_id_for_vllm} is a Meta Llama model.")

        # If we get here, this is the first time registering a model.
        # Preload so that the first inference request won't time out.
        engine_args = AsyncEngineArgs(
            model=model_id_for_vllm,
            tokenizer=model_id_for_vllm,
            tensor_parallel_size=self.config.tensor_parallel_size,
            enforce_eager=self.config.enforce_eager,
            gpu_memory_utilization=self.config.gpu_memory_utilization,
            max_num_seqs=self.config.max_num_seqs,
            max_model_len=self.config.max_model_len,
        )
        self.engine = AsyncLLMEngine.from_engine_args(engine_args)

        # vLLM currently requires the user to specify the tool parser manually. To choose a tool
        # parser, we need to determine what model architecture is being used. For now, we infer
        # that information from what config class the model uses.
        low_level_model_config = self.engine.engine.get_model_config()
        hf_config = low_level_model_config.hf_config
        hf_config_class_name = hf_config.__class__.__name__
        # No info -- fall back to a default so we can at least attempt tool use.
        tool_parser = CONFIG_TYPE_TO_TOOL_PARSER.get(hf_config_class_name, DEFAULT_TOOL_PARSER)
        logger.debug(f"{hf_config_class_name=}")
        logger.debug(f"{tool_parser=}")

        # Wrap the lower-level engine in an OpenAI-compatible chat API
        model_config = await self.engine.get_model_config()
        self.chat = OpenAIServingChat(
            engine_client=self.engine,
            model_config=model_config,
            models=OpenAIServingModels(
                engine_client=self.engine,
                model_config=model_config,
                base_model_paths=[
                    # The layer below us will only see resolved model IDs
                    BaseModelPath(model_id_for_vllm, model_id_for_vllm)
                ],
            ),
            response_role="assistant",
            request_logger=None,  # Use default logging
            chat_template=None,  # Use default template from model checkpoint
            enable_auto_tools=True,
            tool_parser=tool_parser,
            chat_template_content_format="auto",
        )
        # Record the loaded model only now that the engine and chat wrapper exist, so a failed
        # load does not leave the Meta Llama routing flag set for a model that never loaded.
        self.is_meta_llama_model = is_meta_llama_model
        self.resolved_model_id = model_id_for_vllm
        self.model_ids.add(model.model_id)

        logger.info(f"Finished preloading model: {model_id_for_vllm}")

        return model

    async def unregister_model(self, model_id: str) -> None:
        """
        Callback that is called when the server removes an inference endpoint from an inference
        provider.

        When the last registered ID is removed, the vLLM engine is shut down.

        :param model_id: The same external ID that the higher layers of the stack previously passed
            to :func:`register_model()`

        :raises ValueError: If ``model_id`` is not registered with this provider.
        """
        if model_id not in self.model_ids:
            raise ValueError(
                f"Attempted to unregister model ID '{model_id}', but that ID is not registered to this provider."
            )
        self.model_ids.remove(model_id)

        if len(self.model_ids) == 0:
            # Last model was just unregistered. Shut down the connection to vLLM and free up
            # resources.
            # Note that this operation may cause in-flight chat completion requests on the
            # now-unregistered model to return errors.
            self.resolved_model_id = None
            self.chat = None
            self.engine.shutdown_background_loop()
            self.engine = None

    ###########################################################################
    # METHODS INHERITED FROM Inference INTERFACE

    async def completion(
        self,
        model_id: str,
        content: InterleavedContent,
        sampling_params: Optional[SamplingParams] = None,
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
        """
        Plain-text completion against the loaded model.

        :returns: An async iterator of chunks when ``stream`` is true, otherwise a single
            ``CompletionResponse`` built from the final chunk of the stream.
        :raises ValueError: If ``model_id`` is not registered with this provider, or if vLLM
            returns no output.
        :raises NotImplementedError: If ``content`` is not a plain string.
        """
        if model_id not in self.model_ids:
            raise ValueError(
                f"This adapter is not registered to model id '{model_id}'. Registered IDs are: {self.model_ids}"
            )
        if not isinstance(content, str):
            raise NotImplementedError("Multimodal input not currently supported")
        if sampling_params is None:
            sampling_params = SamplingParams()

        converted_sampling_params = _convert_sampling_params(sampling_params, response_format, logprobs)

        logger.debug(f"{converted_sampling_params=}")

        if stream:
            return self._streaming_completion(content, converted_sampling_params)

        # Non-streaming: drain the stream and keep its last chunk. Each chunk's ``delta`` holds
        # all text generated so far, so the last chunk carries the complete result.
        final_chunk: Optional[CompletionResponseStreamChunk] = None
        async for chunk in self._streaming_completion(content, converted_sampling_params):
            final_chunk = chunk
        if final_chunk is None:
            raise ValueError("Inference produced empty result")
        return CompletionResponse(
            content=final_chunk.delta,
            stop_reason=final_chunk.stop_reason,
            logprobs=final_chunk.logprobs,
        )

    async def embeddings(
        self,
        model_id: str,
        contents: List[str] | List[InterleavedContentItem],
        text_truncation: Optional[TextTruncation] = TextTruncation.none,
        output_dimension: Optional[int] = None,
        task_type: Optional[EmbeddingTaskType] = None,
    ) -> EmbeddingsResponse:
        """Embeddings are not supported by this provider."""
        raise NotImplementedError()

    async def chat_completion(
        self,
        model_id: str,
        messages: List[Message],  # type: ignore
        sampling_params: Optional[SamplingParams] = None,
        response_format: Optional[ResponseFormat] = None,  # type: ignore
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
        tool_config: Optional[ToolConfig] = None,
    ) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
        """
        Chat completion against the loaded model.

        Meta Llama models are routed through Llama Stack's prompt templating and
        :func:`completion()`; all other models go through vLLM's OpenAI-compatible chat layer.

        :returns: A ``ChatCompletionResponse``, or an async iterator of stream chunks when
            ``stream`` is true.
        :raises ValueError: If ``model_id`` is not registered, or vLLM reports an error.
        :raises TypeError: If vLLM returns a result of an unexpected type.
        """
        sampling_params = sampling_params or SamplingParams()
        if model_id not in self.model_ids:
            raise ValueError(
                f"This adapter is not registered to model id '{model_id}'. Registered IDs are: {self.model_ids}"
            )

        # Convert to Llama Stack internal format for consistency
        # NOTE(review): ``tool_config`` is accepted but not forwarded here -- confirm intent.
        request = ChatCompletionRequest(
            model=self.resolved_model_id,
            messages=messages,
            sampling_params=sampling_params,
            response_format=response_format,
            tools=tools,
            tool_choice=tool_choice,
            tool_prompt_format=tool_prompt_format,
            stream=stream,
            logprobs=logprobs,
        )

        if self.is_meta_llama_model:
            # Bypass vLLM chat templating layer for Meta Llama models, because the
            # templating layer in Llama Stack currently produces better results.
            logger.debug(
                f"Routing {self.resolved_model_id} chat completion through "
                f"Llama Stack's templating layer instead of vLLM's."
            )
            return await self._chat_completion_for_meta_llama(request)

        logger.debug(f"{self.resolved_model_id} is not a Meta Llama model")

        # Arguments to the vLLM call must be packaged as a ChatCompletionRequest dataclass.
        # Note that this dataclass has the same name as a similar dataclass in Llama Stack.
        request_options = await llama_stack_chat_completion_to_openai_chat_completion_dict(request)
        chat_completion_request = vllm.entrypoints.openai.protocol.ChatCompletionRequest(**request_options)

        logger.debug(f"Converted request: {chat_completion_request}")

        vllm_result = await self.chat.create_chat_completion(chat_completion_request)
        logger.debug(f"Result from vLLM: {vllm_result}")
        if isinstance(vllm_result, vllm.entrypoints.openai.protocol.ErrorResponse):
            raise ValueError(f"Error from vLLM layer: {vllm_result}")

        # Return type depends on "stream" argument
        if stream:
            if not isinstance(vllm_result, AsyncGenerator):
                raise TypeError(f"Unexpected result type {type(vllm_result)} for streaming inference call")
            # vLLM client returns a stream of strings, which need to be parsed.
            # Stream comes in the form of an async generator.
            return self._convert_streaming_results(vllm_result)
        else:
            if not isinstance(vllm_result, vllm.entrypoints.openai.protocol.ChatCompletionResponse):
                raise TypeError(f"Unexpected result type {type(vllm_result)} for non-streaming inference call")
            return self._convert_non_streaming_results(vllm_result)

    ###########################################################################
    # INTERNAL METHODS

    async def _streaming_completion(
        self, content: str, sampling_params: vllm.SamplingParams
    ) -> AsyncIterator[CompletionResponseStreamChunk]:
        """Internal implementation of :func:`completion()` API for the streaming case. Assumes
        that arguments have been validated upstream.

        Each yielded chunk's ``delta`` is the full text produced so far (vLLM's ``output.text``
        at that step), not only the newly generated piece; consumers that need true deltas must
        diff consecutive chunks.

        :param content: Must be a string
        :param sampling_params: Parameters from public API's ``response_format``
            and ``sampling_params`` arguments, converted to vLLM format
        :raises ValueError: If vLLM produces no output, or finishes for an unrecognized reason.
        """
        # We run against the vLLM generate() call directly instead of using the OpenAI-compatible
        # layer, because doing so simplifies the code here.

        # The vLLM engine requires a unique identifier for each call to generate()
        request_id = _random_uuid_str()

        # The vLLM generate() API is streaming-only and returns an async generator.
        # The generator returns objects of type vllm.RequestOutput.
        results_generator = self.engine.generate(content, sampling_params, request_id)

        # Need to know the model's EOS token ID for the conversion code below.
        # AsyncLLMEngine is a wrapper around LLMEngine, and the tokenizer is only available if
        # we drill down to the LLMEngine inside the AsyncLLMEngine.
        # Similarly, the tokenizer in an LLMEngine is a wrapper around a BaseTokenizerGroup,
        # and we need to drill down to the Hugging Face tokenizer inside the BaseTokenizerGroup.
        llm_engine = self.engine.engine
        tokenizer_group = llm_engine.tokenizer
        eos_token_id = tokenizer_group.tokenizer.eos_token_id

        request_output: Optional[vllm.RequestOutput] = None
        stop_reason = None
        completion_string = ""
        logprobs: List[TokenLogProbs] = []
        async for request_output in results_generator:
            # Check for weird inference failures
            if request_output.outputs is None or len(request_output.outputs) == 0:
                # This case also should never happen
                raise ValueError("Inference produced empty result")

            # The result may include multiple alternate outputs, but Llama Stack APIs only allow
            # us to return one.
            output: vllm.CompletionOutput = request_output.outputs[0]
            completion_string = output.text

            # Convert logprobs from vLLM's format to Llama Stack's format
            logprobs = [
                TokenLogProbs(logprobs_by_token={v.decoded_token: v.logprob for v in logprob_dict.values()})
                for logprob_dict in (output.logprobs or [])
            ]

            # The final output chunk should be labeled with the reason that the overall generate()
            # call completed. vLLM reports *why* it stopped in ``finish_reason`` ("stop" /
            # "length"); ``stop_reason`` only carries the stop string or token id that matched.
            logger.debug(f"{output.finish_reason=}; {output.stop_reason=}; {type(output.stop_reason)=}")
            if output.finish_reason is None:
                stop_reason = None  # Still going
            elif output.finish_reason == "length":
                stop_reason = StopReason.out_of_tokens
            elif output.finish_reason == "stop":
                # Covers EOS tokens as well as configured stop token ids, which vLLM reports as
                # an int in ``stop_reason`` when the model config lists several EOS tokens.
                stop_reason = StopReason.end_of_turn
            else:
                raise ValueError(
                    f"Unrecognized finish reason '{output.finish_reason}' (stop reason '{output.stop_reason}')"
                )

            # vLLM's protocol outputs the stop token, then sets end of message on the next step for
            # some reason.
            if output.token_ids and output.token_ids[-1] == eos_token_id:
                stop_reason = StopReason.end_of_message

            yield CompletionResponseStreamChunk(delta=completion_string, stop_reason=stop_reason, logprobs=logprobs)

        if request_output is None:
            # The generator finished without producing a single step.
            raise ValueError("Inference produced empty result")

        # Llama Stack requires that the last chunk have a stop reason, but vLLM doesn't always
        # provide one if it runs out of tokens.
        if stop_reason is None:
            yield CompletionResponseStreamChunk(
                delta=completion_string,
                stop_reason=StopReason.out_of_tokens,
                logprobs=logprobs,
            )

    def _convert_non_streaming_results(
        self, vllm_result: vllm.entrypoints.openai.protocol.ChatCompletionResponse
    ) -> ChatCompletionResponse:
        """
        Subroutine to convert the non-streaming output of vLLM's OpenAI-compatible API into an
        equivalent Llama Stack object.

        The result from vLLM's non-streaming API is a dataclass with the same name as the Llama
        Stack ChatCompletionResponse dataclass, but with more and different field names. We ignore
        the fields that aren't currently present in the Llama Stack dataclass.

        :raises ValueError: If the vLLM response contains no choices.
        """

        # There may be multiple responses, but we can only pass through the first one.
        if len(vllm_result.choices) == 0:
            raise ValueError("Don't know how to convert response object without any responses")
        vllm_message = vllm_result.choices[0].message
        vllm_finish_reason = vllm_result.choices[0].finish_reason

        converted_message = CompletionMessage(
            role=vllm_message.role,
            # Llama Stack API won't accept None for content field.
            content=("" if vllm_message.content is None else vllm_message.content),
            stop_reason=get_stop_reason(vllm_finish_reason),
            tool_calls=[
                ToolCall(
                    call_id=t.id,
                    tool_name=t.function.name,
                    # vLLM function args come back as a string. Llama Stack expects JSON.
                    arguments=json.loads(t.function.arguments),
                    arguments_json=t.function.arguments,
                )
                for t in (vllm_message.tool_calls or [])
            ],
        )

        # TODO: Convert logprobs

        logger.debug(f"Converted message: {converted_message}")

        return ChatCompletionResponse(
            completion_message=converted_message,
        )

    async def _chat_completion_for_meta_llama(
        self, request: ChatCompletionRequest
    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
        """
        Subroutine that routes chat completions for Meta Llama models through Llama Stack's
        chat template instead of using vLLM's version of that template. The Llama Stack version
        of the chat template currently produces more reliable outputs.

        Once vLLM's support for Meta Llama models has matured more, we should consider routing
        Meta Llama requests through the vLLM chat completions API instead of using this method.

        :raises TypeError: If :func:`completion()` returns a result that doesn't match
            ``request.stream``.
        """
        # Reuse the formatter built in __init__ rather than constructing a new one per request.
        formatter = self.formatter

        # Note that this function call modifies `request` in place.
        prompt = await chat_completion_request_to_prompt(request, self.resolved_model_id)

        model_id = list(self.model_ids)[0]  # Any model ID will do here
        completion_response_or_iterator = await self.completion(
            model_id=model_id,
            content=prompt,
            sampling_params=request.sampling_params,
            response_format=request.response_format,
            stream=request.stream,
            logprobs=request.logprobs,
        )

        if request.stream:
            if not isinstance(completion_response_or_iterator, AsyncIterator):
                raise TypeError(
                    f"Received unexpected result type {type(completion_response_or_iterator)} for streaming request."
                )
            return self._chat_completion_for_meta_llama_streaming(completion_response_or_iterator, request)

        # elif not request.stream:
        if not isinstance(completion_response_or_iterator, CompletionResponse):
            raise TypeError(
                f"Received unexpected result type {type(completion_response_or_iterator)} for non-streaming request."
            )
        completion_response: CompletionResponse = completion_response_or_iterator
        raw_message = formatter.decode_assistant_message_from_content(
            completion_response.content, completion_response.stop_reason
        )
        return ChatCompletionResponse(
            completion_message=CompletionMessage(
                content=raw_message.content,
                stop_reason=raw_message.stop_reason,
                tool_calls=raw_message.tool_calls,
            ),
            logprobs=completion_response.logprobs,
        )

    async def _chat_completion_for_meta_llama_streaming(
        self, results_iterator: AsyncIterator, request: ChatCompletionRequest
    ) -> AsyncIterator:
        """
        Code from :func:`_chat_completion_for_meta_llama()` that needs to be a separate
        method to keep asyncio happy.
        """

        # Convert to OpenAI format, then use shared code to convert to Llama Stack format.
        async def _generate_and_convert_to_openai_compat():
            chunk: CompletionResponseStreamChunk  # Make Pylance happy
            last_text_len = 0
            async for chunk in results_iterator:
                if chunk.stop_reason == StopReason.end_of_turn:
                    finish_reason = "stop"
                elif chunk.stop_reason == StopReason.end_of_message:
                    finish_reason = "eos"
                elif chunk.stop_reason == StopReason.out_of_tokens:
                    finish_reason = "length"
                else:
                    finish_reason = None

                # Convert delta back to an actual delta: each chunk holds the cumulative text.
                text_delta = chunk.delta[last_text_len:]
                last_text_len = len(chunk.delta)

                logger.debug(f"{text_delta=}; {finish_reason=}")

                yield OpenAICompatCompletionResponse(
                    choices=[OpenAICompatCompletionChoice(finish_reason=finish_reason, text=text_delta)]
                )

        stream = _generate_and_convert_to_openai_compat()
        async for chunk in process_chat_completion_stream_response(stream, request):
            logger.debug(f"Returning chunk: {chunk}")
            yield chunk

    async def _convert_streaming_results(self, vllm_result: AsyncIterator) -> AsyncIterator:
        """
        Subroutine that wraps the streaming outputs of vLLM's OpenAI-compatible
        API into a second async iterator that returns Llama Stack objects.

        :param vllm_result: Stream of strings that need to be parsed
        :raises ValueError: If a stream event cannot be parsed, or the stream ends without
            a ``[DONE]`` marker.
        """
        # Tool calls come in pieces, but Llama Stack expects them in bigger chunks. We build up
        # those chunks and output them at the end.
        # This data structure holds the current set of partial tool calls.
        index_to_tool_call: Dict[int, Dict] = {}

        # The Llama Stack event stream must always start with a start event. Use an empty one to
        # simplify logic below
        yield ChatCompletionResponseStreamChunk(
            event=ChatCompletionResponseEvent(
                event_type=ChatCompletionResponseEventType.start,
                delta=TextDelta(text=""),
                stop_reason=None,
            )
        )

        converted_stop_reason = None
        async for chunk_str in vllm_result:
            # Due to OpenAI compatibility, each event in the stream will start with "data: " and
            # end with "\n\n".
            _prefix = "data: "
            _suffix = "\n\n"
            if not chunk_str.startswith(_prefix) or not chunk_str.endswith(_suffix):
                raise ValueError(f"Can't parse result string from vLLM: '{re.escape(chunk_str)}'")

            # In between the "data: " and newlines is an event record
            data_str = chunk_str[len(_prefix) : -len(_suffix)]

            # The end of the stream is indicated with "[DONE]"
            if data_str == "[DONE]":
                yield ChatCompletionResponseStreamChunk(
                    event=ChatCompletionResponseEvent(
                        event_type=ChatCompletionResponseEventType.complete,
                        delta=TextDelta(text=""),
                        stop_reason=converted_stop_reason,
                    )
                )
                return

            # Anything that is not "[DONE]" should be a JSON record
            parsed_chunk = json.loads(data_str)

            logger.debug(f"Parsed JSON event to:\n{json.dumps(parsed_chunk, indent=2)}")

            # The result may contain multiple completions, but Llama Stack APIs only support
            # returning one.
            first_choice = parsed_chunk["choices"][0]
            converted_stop_reason = get_stop_reason(first_choice["finish_reason"])
            delta_record = first_choice["delta"]

            if "content" in delta_record:
                # Text delta
                yield ChatCompletionResponseStreamChunk(
                    event=ChatCompletionResponseEvent(
                        event_type=ChatCompletionResponseEventType.progress,
                        delta=TextDelta(text=delta_record["content"]),
                        stop_reason=converted_stop_reason,
                    )
                )
            elif "tool_calls" in delta_record:
                # Tool call(s). Llama Stack APIs do not have a clear way to return partial tool
                # calls, so buffer until we get a "tool calls" stop reason
                for tc in delta_record["tool_calls"]:
                    tool_call = index_to_tool_call.setdefault(tc["index"], {})
                    if "id" in tc:
                        tool_call["call_id"] = tc["id"]
                    if "function" in tc:
                        function_record = tc["function"]
                        if "name" in function_record:
                            tool_call["tool_name"] = function_record["name"]
                        if "arguments" in function_record:
                            # Arguments come in as pieces of a string
                            tool_call["arguments_str"] = (
                                tool_call.get("arguments_str", "") + function_record["arguments"]
                            )
            else:
                raise ValueError(f"Don't know how to parse event delta: {delta_record}")

            if first_choice["finish_reason"] == "tool_calls":
                # Special OpenAI code for "tool calls complete".
                # Output the buffered tool calls. Llama Stack requires a separate event per tool
                # call.
                for tool_call_record in index_to_tool_call.values():
                    # Arguments come in as a string. Parse the completed string; a call that
                    # streamed no arguments is treated as having an empty argument object.
                    arguments_str = tool_call_record.pop("arguments_str", "") or "{}"
                    tool_call_record["arguments"] = json.loads(arguments_str)
                    # Keep the raw string too, matching _convert_non_streaming_results().
                    tool_call_record["arguments_json"] = arguments_str

                    yield ChatCompletionResponseStreamChunk(
                        event=ChatCompletionResponseEvent(
                            event_type=ChatCompletionResponseEventType.progress,
                            delta=ToolCallDelta(tool_call=tool_call_record, parse_status="succeeded"),
                            stop_reason=converted_stop_reason,
                        )
                    )

        # If we get here, we've lost the connection with the vLLM event stream before it ended
        # normally.
        raise ValueError("vLLM event stream ended without [DONE] message.")
|