# What does this PR do?

Adds a new endpoint, `/openai/v1/embeddings`, that is compatible with the OpenAI embeddings API. Providers were added for OpenAI, LiteLLM, and SentenceTransformer.

## Test Plan

```
LLAMA_STACK_CONFIG=http://localhost:8321 pytest -sv tests/integration/inference/test_openai_embeddings.py --embedding-model all-MiniLM-L6-v2,text-embedding-3-small,gemini/text-embedding-004
```
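For reference, a minimal sketch of how a client might call the new endpoint once a Llama Stack server is running; the base URL path and the registered embedding model name below are assumptions and may differ in your deployment:

```python
# Minimal sketch: exercise the OpenAI-compatible embeddings endpoint with the
# official OpenAI Python client pointed at a local Llama Stack server.
# The base_url and model name are assumptions; adjust them for your deployment.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/openai/v1", api_key="none")

response = client.embeddings.create(
    model="all-MiniLM-L6-v2",
    input=["Hello, world!", "Llama Stack speaks OpenAI embeddings now."],
    encoding_format="float",
)

# Each item in response.data carries one embedding vector.
print(len(response.data), len(response.data[0].embedding))
```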
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
import re
import uuid
from collections.abc import AsyncGenerator, AsyncIterator

# These vLLM modules contain names that overlap with Llama Stack names, so we import
# fully-qualified names
import vllm.entrypoints.openai.protocol
import vllm.sampling_params
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels

from llama_stack.apis.common.content_types import (
    InterleavedContent,
    InterleavedContentItem,
    TextDelta,
    ToolCallDelta,
)
from llama_stack.apis.inference import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionResponseEvent,
    ChatCompletionResponseEventType,
    ChatCompletionResponseStreamChunk,
    CompletionMessage,
    CompletionResponse,
    CompletionResponseStreamChunk,
    EmbeddingsResponse,
    EmbeddingTaskType,
    GrammarResponseFormat,
    Inference,
    JsonSchemaResponseFormat,
    LogProbConfig,
    Message,
    OpenAIEmbeddingsResponse,
    ResponseFormat,
    SamplingParams,
    TextTruncation,
    TokenLogProbs,
    ToolChoice,
    ToolConfig,
    TopKSamplingStrategy,
    TopPSamplingStrategy,
)
from llama_stack.apis.models import Model
from llama_stack.log import get_logger
from llama_stack.models.llama import sku_list
from llama_stack.models.llama.datatypes import (
    StopReason,
    ToolCall,
    ToolDefinition,
    ToolPromptFormat,
)
from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.providers.remote.inference.vllm.vllm import build_hf_repo_model_entries
from llama_stack.providers.utils.inference.model_registry import (
    ModelRegistryHelper,
    ModelsProtocolPrivate,
)
from llama_stack.providers.utils.inference.openai_compat import (
    OpenAIChatCompletionToLlamaStackMixin,
    OpenAICompatCompletionChoice,
    OpenAICompatCompletionResponse,
    OpenAICompletionToLlamaStackMixin,
    get_stop_reason,
    process_chat_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
    chat_completion_request_to_prompt,
)

from .config import VLLMConfig
from .openai_utils import llama_stack_chat_completion_to_openai_chat_completion_dict

# Map from Hugging Face model architecture name to appropriate tool parser.
# See vllm.entrypoints.openai.tool_parsers.ToolParserManager.tool_parsers for the full list of
# available parsers.
# TODO: Expand this list
CONFIG_TYPE_TO_TOOL_PARSER = {
    "GraniteConfig": "granite",
    "MllamaConfig": "llama3_json",
    "LlamaConfig": "llama3_json",
}
DEFAULT_TOOL_PARSER = "pythonic"
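# For example, a model whose Hugging Face config class is "LlamaConfig" is served with the
# "llama3_json" tool parser, while a model with an unrecognized config class falls back to
# DEFAULT_TOOL_PARSER ("pythonic").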


logger = get_logger(__name__, category="inference")


def _random_uuid_str() -> str:
    return str(uuid.uuid4().hex)


def _response_format_to_guided_decoding_params(
    response_format: ResponseFormat | None,  # type: ignore
) -> vllm.sampling_params.GuidedDecodingParams:
    """
    Translate constrained decoding parameters from Llama Stack's format to vLLM's format.

    :param response_format: Llama Stack version of constrained decoding info. Can be ``None``,
        indicating no constraints.
    :returns: The equivalent dataclass object for the low-level inference layer of vLLM.
    """
    if response_format is None:
        # As of vLLM 0.6.3, the default constructor for GuidedDecodingParams() returns an invalid
        # value that crashes the executor on some code paths. Use ``None`` instead.
        return None

    # Llama Stack currently implements fewer types of constrained decoding than vLLM does.
    # Translate the types that exist and detect if Llama Stack adds new ones.
    if isinstance(response_format, JsonSchemaResponseFormat):
        return vllm.sampling_params.GuidedDecodingParams(json=response_format.json_schema)
    elif isinstance(response_format, GrammarResponseFormat):
        # BNF grammar.
        # Llama Stack uses the parse tree of the grammar, while vLLM uses the string
        # representation of the grammar.
        raise TypeError(
            "Constrained decoding with BNF grammars is not currently implemented, because the "
            "reference implementation does not implement it."
        )
    else:
        raise TypeError(f"ResponseFormat object is of unexpected subtype '{type(response_format)}'")


def _convert_sampling_params(
    sampling_params: SamplingParams | None,
    response_format: ResponseFormat | None,  # type: ignore
    log_prob_config: LogProbConfig | None,
) -> vllm.SamplingParams:
    """Convert sampling and constrained decoding configuration from Llama Stack's format to vLLM's
    format."""
    # In the absence of provided config values, use Llama Stack defaults as encoded in the Llama
    # Stack dataclasses. These defaults are different from vLLM's defaults.
    if sampling_params is None:
        sampling_params = SamplingParams()
    if log_prob_config is None:
        log_prob_config = LogProbConfig()

    if isinstance(sampling_params.strategy, TopKSamplingStrategy):
        if sampling_params.strategy.top_k == 0:
            # vLLM treats "k" differently for top-k sampling
            vllm_top_k = -1
        else:
            vllm_top_k = sampling_params.strategy.top_k
    else:
        vllm_top_k = -1

    if isinstance(sampling_params.strategy, TopPSamplingStrategy):
        vllm_top_p = sampling_params.strategy.top_p
        # Llama Stack only allows temperature with top-P.
        vllm_temperature = sampling_params.strategy.temperature
    else:
        vllm_top_p = 1.0
        vllm_temperature = 0.0

    # vLLM allows top-p and top-k at the same time.
    vllm_sampling_params = vllm.SamplingParams.from_optional(
        max_tokens=(None if sampling_params.max_tokens == 0 else sampling_params.max_tokens),
        temperature=vllm_temperature,
        top_p=vllm_top_p,
        top_k=vllm_top_k,
        repetition_penalty=sampling_params.repetition_penalty,
        guided_decoding=_response_format_to_guided_decoding_params(response_format),
        logprobs=log_prob_config.top_k,
    )
    return vllm_sampling_params
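
# For example, converting Llama Stack's default SamplingParams() (greedy strategy, max_tokens=0)
# with no response format yields vLLM parameters with temperature=0.0, top_p=1.0, top_k=-1,
# and max_tokens=None (i.e. no explicit token limit).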


class VLLMInferenceImpl(
    Inference,
    OpenAIChatCompletionToLlamaStackMixin,
    OpenAICompletionToLlamaStackMixin,
    ModelsProtocolPrivate,
):
    """
    vLLM-based inference model adapter for Llama Stack with support for multiple models.

    Requires the configuration parameters documented in the :class:`VLLMConfig` class.
    """

    config: VLLMConfig
    register_helper: ModelRegistryHelper
    model_ids: set[str]
    resolved_model_id: str | None
    engine: AsyncLLMEngine | None
    chat: OpenAIServingChat | None
    is_meta_llama_model: bool

    def __init__(self, config: VLLMConfig):
        self.config = config
        logger.info(f"Config is: {self.config}")

        self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
        self.formatter = ChatFormat(Tokenizer.get_instance())

        # The following are initialized when paths are bound to this provider
        self.resolved_model_id = None
        self.model_ids = set()
        self.engine = None
        self.chat = None
        self.is_meta_llama_model = False

    ###########################################################################
    # METHODS INHERITED FROM IMPLICIT BASE CLASS.
    # TODO: Make this class inherit from the new base class ProviderBase once that class exists.

    async def initialize(self) -> None:
        """
        Callback that is invoked through many levels of indirection during provider class
        instantiation, sometime after __init__() is called and before any model registration
        methods or methods connected to a REST API are called.

        It's not clear what assumptions the class can make about the platform's initialization
        state here that can't be made during __init__(), and vLLM can't be started until we know
        what model it's supposed to be serving, so nothing happens here currently.
        """
        pass

    async def shutdown(self) -> None:
        logger.info(f"Shutting down inline vLLM inference provider {self}.")
        if self.engine is not None:
            self.engine.shutdown_background_loop()
            self.engine = None
            self.chat = None
            self.model_ids = set()
            self.resolved_model_id = None

    ###########################################################################
    # METHODS INHERITED FROM ModelsProtocolPrivate INTERFACE

    # Note that the return type of the superclass method is WRONG
    async def register_model(self, model: Model) -> Model:
        """
        Callback that is called when the server associates an inference endpoint with an
        inference provider.

        :param model: Object that encapsulates parameters necessary for identifying a specific
            LLM.

        :returns: The input ``Model`` object. It may or may not be permissible to change fields
            before returning this object.
        """
        logger.debug(f"In register_model({model})")

        # First attempt to interpret the model coordinates as a Llama model name
        resolved_llama_model = sku_list.resolve_model(model.provider_model_id)
        if resolved_llama_model is not None:
            # Load from Hugging Face repo into default local cache dir
            model_id_for_vllm = resolved_llama_model.huggingface_repo

            # Detect a genuine Meta Llama model to trigger Meta-specific preprocessing.
            # Don't set self.is_meta_llama_model until we actually load the model.
            is_meta_llama_model = True
        else:  # if resolved_llama_model is None
            # Not a Llama model name. Pass the model id through to vLLM's loader
            model_id_for_vllm = model.provider_model_id
            is_meta_llama_model = False

        if self.resolved_model_id is not None:
            if model_id_for_vllm != self.resolved_model_id:
                raise ValueError(
                    f"Attempted to serve two LLMs (ids '{self.resolved_model_id}' and "
                    f"'{model_id_for_vllm}') from one copy of provider '{self}'. Use multiple "
                    f"copies of the provider instead."
                )
            else:
                # Model already loaded
                logger.info(
                    f"Requested id {model} resolves to {model_id_for_vllm}, which is already loaded. Continuing."
                )
                self.model_ids.add(model.model_id)
                return model

        logger.info(f"Requested id {model} resolves to {model_id_for_vllm}. Loading {model_id_for_vllm}.")
        if is_meta_llama_model:
            logger.info(f"Model {model_id_for_vllm} is a Meta Llama model.")
        self.is_meta_llama_model = is_meta_llama_model

        # If we get here, this is the first time registering a model.
        # Preload so that the first inference request won't time out.
        engine_args = AsyncEngineArgs(
            model=model_id_for_vllm,
            tokenizer=model_id_for_vllm,
            tensor_parallel_size=self.config.tensor_parallel_size,
            enforce_eager=self.config.enforce_eager,
            gpu_memory_utilization=self.config.gpu_memory_utilization,
            max_num_seqs=self.config.max_num_seqs,
            max_model_len=self.config.max_model_len,
        )
        self.engine = AsyncLLMEngine.from_engine_args(engine_args)

        # vLLM currently requires the user to specify the tool parser manually. To choose a tool
        # parser, we need to determine what model architecture is being used. For now, we infer
        # that information from what config class the model uses.
        low_level_model_config = self.engine.engine.get_model_config()
        hf_config = low_level_model_config.hf_config
        hf_config_class_name = hf_config.__class__.__name__
        if hf_config_class_name in CONFIG_TYPE_TO_TOOL_PARSER:
            tool_parser = CONFIG_TYPE_TO_TOOL_PARSER[hf_config_class_name]
        else:
            # No info -- choose a default so we can at least attempt tool
            # use.
            tool_parser = DEFAULT_TOOL_PARSER
        logger.debug(f"{hf_config_class_name=}")
        logger.debug(f"{tool_parser=}")

        # Wrap the lower-level engine in an OpenAI-compatible chat API
        model_config = await self.engine.get_model_config()
        self.chat = OpenAIServingChat(
            engine_client=self.engine,
            model_config=model_config,
            models=OpenAIServingModels(
                engine_client=self.engine,
                model_config=model_config,
                base_model_paths=[
                    # The layer below us will only see resolved model IDs
                    BaseModelPath(model_id_for_vllm, model_id_for_vllm)
                ],
            ),
            response_role="assistant",
            request_logger=None,  # Use default logging
            chat_template=None,  # Use default template from model checkpoint
            enable_auto_tools=True,
            tool_parser=tool_parser,
            chat_template_content_format="auto",
        )
        self.resolved_model_id = model_id_for_vllm
        self.model_ids.add(model.model_id)

        logger.info(f"Finished preloading model: {model_id_for_vllm}")

        return model

    async def unregister_model(self, model_id: str) -> None:
        """
        Callback that is called when the server removes an inference endpoint from an inference
        provider.

        :param model_id: The same external ID that the higher layers of the stack previously passed
            to :func:`register_model()`
        """
        if model_id not in self.model_ids:
            raise ValueError(
                f"Attempted to unregister model ID '{model_id}', but that ID is not registered to this provider."
            )
        self.model_ids.remove(model_id)

        if len(self.model_ids) == 0:
            # Last model was just unregistered. Shut down the connection to vLLM and free up
            # resources.
            # Note that this operation may cause in-flight chat completion requests on the
            # now-unregistered model to return errors.
            self.resolved_model_id = None
            self.chat = None
            self.engine.shutdown_background_loop()
            self.engine = None

    ###########################################################################
    # METHODS INHERITED FROM Inference INTERFACE

    async def completion(
        self,
        model_id: str,
        content: InterleavedContent,
        sampling_params: SamplingParams | None = None,
        response_format: ResponseFormat | None = None,
        stream: bool | None = False,
        logprobs: LogProbConfig | None = None,
    ) -> CompletionResponse | AsyncIterator[CompletionResponseStreamChunk]:
        if model_id not in self.model_ids:
            raise ValueError(
                f"This adapter is not registered to model id '{model_id}'. Registered IDs are: {self.model_ids}"
            )
        if not isinstance(content, str):
            raise NotImplementedError("Multimodal input not currently supported")
        if sampling_params is None:
            sampling_params = SamplingParams()

        converted_sampling_params = _convert_sampling_params(sampling_params, response_format, logprobs)

        logger.debug(f"{converted_sampling_params=}")

        if stream:
            return self._streaming_completion(content, converted_sampling_params)
        else:
            streaming_result = None
            # Drain the stream, keeping the final (cumulative) chunk in streaming_result.
            async for streaming_result in self._streaming_completion(content, converted_sampling_params):
                pass
            return CompletionResponse(
                content=streaming_result.delta,
                stop_reason=streaming_result.stop_reason,
                logprobs=streaming_result.logprobs,
            )

    async def embeddings(
        self,
        model_id: str,
        contents: list[str] | list[InterleavedContentItem],
        text_truncation: TextTruncation | None = TextTruncation.none,
        output_dimension: int | None = None,
        task_type: EmbeddingTaskType | None = None,
    ) -> EmbeddingsResponse:
        raise NotImplementedError()

    async def openai_embeddings(
        self,
        model: str,
        input: str | list[str],
        encoding_format: str | None = "float",
        dimensions: int | None = None,
        user: str | None = None,
    ) -> OpenAIEmbeddingsResponse:
        raise NotImplementedError()

    async def chat_completion(
        self,
        model_id: str,
        messages: list[Message],  # type: ignore
        sampling_params: SamplingParams | None = None,
        response_format: ResponseFormat | None = None,  # type: ignore
        tools: list[ToolDefinition] | None = None,
        tool_choice: ToolChoice | None = ToolChoice.auto,
        tool_prompt_format: ToolPromptFormat | None = None,
        stream: bool | None = False,
        logprobs: LogProbConfig | None = None,
        tool_config: ToolConfig | None = None,
    ) -> ChatCompletionResponse | ChatCompletionResponseStreamChunk:
        sampling_params = sampling_params or SamplingParams()
        if model_id not in self.model_ids:
            raise ValueError(
                f"This adapter is not registered to model id '{model_id}'. Registered IDs are: {self.model_ids}"
            )

        # Convert to Llama Stack internal format for consistency
        request = ChatCompletionRequest(
            model=self.resolved_model_id,
            messages=messages,
            sampling_params=sampling_params,
            response_format=response_format,
            tools=tools,
            tool_choice=tool_choice,
            tool_prompt_format=tool_prompt_format,
            stream=stream,
            logprobs=logprobs,
        )

        if self.is_meta_llama_model:
            # Bypass vLLM chat templating layer for Meta Llama models, because the
            # templating layer in Llama Stack currently produces better results.
            logger.debug(
                f"Routing {self.resolved_model_id} chat completion through "
                f"Llama Stack's templating layer instead of vLLM's."
            )
            return await self._chat_completion_for_meta_llama(request)

        logger.debug(f"{self.resolved_model_id} is not a Meta Llama model")

        # Arguments to the vLLM call must be packaged as a ChatCompletionRequest dataclass.
        # Note that this dataclass has the same name as a similar dataclass in Llama Stack.
        request_options = await llama_stack_chat_completion_to_openai_chat_completion_dict(request)
        chat_completion_request = vllm.entrypoints.openai.protocol.ChatCompletionRequest(**request_options)

        logger.debug(f"Converted request: {chat_completion_request}")

        vllm_result = await self.chat.create_chat_completion(chat_completion_request)
        logger.debug(f"Result from vLLM: {vllm_result}")
        if isinstance(vllm_result, vllm.entrypoints.openai.protocol.ErrorResponse):
            raise ValueError(f"Error from vLLM layer: {vllm_result}")

        # Return type depends on "stream" argument
        if stream:
            if not isinstance(vllm_result, AsyncGenerator):
                raise TypeError(f"Unexpected result type {type(vllm_result)} for streaming inference call")
            # vLLM client returns a stream of strings, which need to be parsed.
            # Stream comes in the form of an async generator.
            return self._convert_streaming_results(vllm_result)
        else:
            if not isinstance(vllm_result, vllm.entrypoints.openai.protocol.ChatCompletionResponse):
                raise TypeError(f"Unexpected result type {type(vllm_result)} for non-streaming inference call")
            return self._convert_non_streaming_results(vllm_result)

    ###########################################################################
    # INTERNAL METHODS

    async def _streaming_completion(
        self, content: str, sampling_params: vllm.SamplingParams
    ) -> AsyncIterator[CompletionResponseStreamChunk]:
        """Internal implementation of :func:`completion()` API for the streaming case. Assumes
        that arguments have been validated upstream.

        :param content: Must be a string
        :param sampling_params: Parameters from the public API's ``response_format``
            and ``sampling_params`` arguments, converted to vLLM format
        """
        # We run against the vLLM generate() call directly instead of using the OpenAI-compatible
        # layer, because doing so simplifies the code here.

        # The vLLM engine requires a unique identifier for each call to generate()
        request_id = _random_uuid_str()

        # The vLLM generate() API is streaming-only and returns an async generator.
        # The generator returns objects of type vllm.RequestOutput.
        results_generator = self.engine.generate(content, sampling_params, request_id)

        # Need to know the model's EOS token ID for the conversion code below.
        # AsyncLLMEngine is a wrapper around LLMEngine, and the tokenizer is only available if
        # we drill down to the LLMEngine inside the AsyncLLMEngine.
        # Similarly, the tokenizer in an LLMEngine is a wrapper around a BaseTokenizerGroup,
        # and we need to drill down to the Hugging Face tokenizer inside the BaseTokenizerGroup.
        llm_engine = self.engine.engine
        tokenizer_group = llm_engine.tokenizer
        eos_token_id = tokenizer_group.tokenizer.eos_token_id

        request_output: vllm.RequestOutput = None
        async for request_output in results_generator:
            # Check for weird inference failures
            if request_output.outputs is None or len(request_output.outputs) == 0:
                # This case should never happen
                raise ValueError("Inference produced empty result")

            # If we get here, then request_output contains the final output of the generate() call.
            # The result may include multiple alternate outputs, but Llama Stack APIs only allow
            # us to return one.
            output: vllm.CompletionOutput = request_output.outputs[0]
            completion_string = output.text

            # Convert logprobs from vLLM's format to Llama Stack's format
            logprobs = [
                TokenLogProbs(logprobs_by_token={v.decoded_token: v.logprob for _, v in logprob_dict.items()})
                for logprob_dict in output.logprobs
            ]

            # The final output chunk should be labeled with the reason that the overall generate()
            # call completed.
            logger.debug(f"{output.stop_reason=}; {type(output.stop_reason)=}")
            if output.stop_reason is None:
                stop_reason = None  # Still going
            elif output.stop_reason == "stop":
                stop_reason = StopReason.end_of_turn
            elif output.stop_reason == "length":
                stop_reason = StopReason.out_of_tokens
            elif isinstance(output.stop_reason, int):
                # If the model config specifies multiple end-of-sequence tokens, then vLLM
                # will return the token ID of the EOS token in the stop_reason field.
                stop_reason = StopReason.end_of_turn
            else:
                raise ValueError(f"Unrecognized stop reason '{output.stop_reason}'")

            # vLLM's protocol outputs the stop token, then sets end of message on the next step for
            # some reason.
            if request_output.outputs[-1].token_ids[-1] == eos_token_id:
                stop_reason = StopReason.end_of_message

            yield CompletionResponseStreamChunk(delta=completion_string, stop_reason=stop_reason, logprobs=logprobs)

        # Llama Stack requires that the last chunk have a stop reason, but vLLM doesn't always
        # provide one if it runs out of tokens.
        if stop_reason is None:
            yield CompletionResponseStreamChunk(
                delta=completion_string,
                stop_reason=StopReason.out_of_tokens,
                logprobs=logprobs,
            )

    def _convert_non_streaming_results(
        self, vllm_result: vllm.entrypoints.openai.protocol.ChatCompletionResponse
    ) -> ChatCompletionResponse:
        """
        Subroutine to convert the non-streaming output of vLLM's OpenAI-compatible API into an
        equivalent Llama Stack object.

        The result from vLLM's non-streaming API is a dataclass with the same name as the Llama
        Stack ChatCompletionResponse dataclass, but with more and different field names. We ignore
        the fields that aren't currently present in the Llama Stack dataclass.
        """

        # There may be multiple responses, but we can only pass through the first one.
        if len(vllm_result.choices) == 0:
            raise ValueError("Don't know how to convert response object without any responses")
        vllm_message = vllm_result.choices[0].message
        vllm_finish_reason = vllm_result.choices[0].finish_reason

        converted_message = CompletionMessage(
            role=vllm_message.role,
            # Llama Stack API won't accept None for content field.
            content=("" if vllm_message.content is None else vllm_message.content),
            stop_reason=get_stop_reason(vllm_finish_reason),
            tool_calls=[
                ToolCall(
                    call_id=t.id,
                    tool_name=t.function.name,
                    # vLLM function args come back as a string. Llama Stack expects JSON.
                    arguments=json.loads(t.function.arguments),
                    arguments_json=t.function.arguments,
                )
                for t in vllm_message.tool_calls
            ],
        )

        # TODO: Convert logprobs

        logger.debug(f"Converted message: {converted_message}")

        return ChatCompletionResponse(
            completion_message=converted_message,
        )

    async def _chat_completion_for_meta_llama(
        self, request: ChatCompletionRequest
    ) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
        """
        Subroutine that routes chat completions for Meta Llama models through Llama Stack's
        chat template instead of using vLLM's version of that template. The Llama Stack version
        of the chat template currently produces more reliable outputs.

        Once vLLM's support for Meta Llama models has matured more, we should consider routing
        Meta Llama requests through the vLLM chat completions API instead of using this method.
        """
        formatter = ChatFormat(Tokenizer.get_instance())

        # Note that this function call modifies `request` in place.
        prompt = await chat_completion_request_to_prompt(request, self.resolved_model_id)

        model_id = list(self.model_ids)[0]  # Any model ID will do here
        completion_response_or_iterator = await self.completion(
            model_id=model_id,
            content=prompt,
            sampling_params=request.sampling_params,
            response_format=request.response_format,
            stream=request.stream,
            logprobs=request.logprobs,
        )

        if request.stream:
            if not isinstance(completion_response_or_iterator, AsyncIterator):
                raise TypeError(
                    f"Received unexpected result type {type(completion_response_or_iterator)} for streaming request."
                )
            return self._chat_completion_for_meta_llama_streaming(completion_response_or_iterator, request)

        # else: not request.stream
        if not isinstance(completion_response_or_iterator, CompletionResponse):
            raise TypeError(
                f"Received unexpected result type {type(completion_response_or_iterator)} for non-streaming request."
            )
        completion_response: CompletionResponse = completion_response_or_iterator
        raw_message = formatter.decode_assistant_message_from_content(
            completion_response.content, completion_response.stop_reason
        )
        return ChatCompletionResponse(
            completion_message=CompletionMessage(
                content=raw_message.content,
                stop_reason=raw_message.stop_reason,
                tool_calls=raw_message.tool_calls,
            ),
            logprobs=completion_response.logprobs,
        )

    async def _chat_completion_for_meta_llama_streaming(
        self, results_iterator: AsyncIterator, request: ChatCompletionRequest
    ) -> AsyncIterator:
        """
        Code from :func:`_chat_completion_for_meta_llama()` that needs to be a separate
        method to keep asyncio happy.
        """

        # Convert to OpenAI format, then use shared code to convert to Llama Stack format.
        async def _generate_and_convert_to_openai_compat():
            chunk: CompletionResponseStreamChunk  # Make Pylance happy
            last_text_len = 0
            async for chunk in results_iterator:
                if chunk.stop_reason == StopReason.end_of_turn:
                    finish_reason = "stop"
                elif chunk.stop_reason == StopReason.end_of_message:
                    finish_reason = "eos"
                elif chunk.stop_reason == StopReason.out_of_tokens:
                    finish_reason = "length"
                else:
                    finish_reason = None

                # Convert delta back to an actual delta
                text_delta = chunk.delta[last_text_len:]
                last_text_len = len(chunk.delta)

                logger.debug(f"{text_delta=}; {finish_reason=}")

                yield OpenAICompatCompletionResponse(
                    choices=[OpenAICompatCompletionChoice(finish_reason=finish_reason, text=text_delta)]
                )

        stream = _generate_and_convert_to_openai_compat()
        async for chunk in process_chat_completion_stream_response(stream, request):
            logger.debug(f"Returning chunk: {chunk}")
            yield chunk

    async def _convert_streaming_results(self, vllm_result: AsyncIterator) -> AsyncIterator:
        """
        Subroutine that wraps the streaming outputs of vLLM's OpenAI-compatible
        API into a second async iterator that returns Llama Stack objects.

        :param vllm_result: Stream of strings that need to be parsed
        """
        # Tool calls come in pieces, but Llama Stack expects them in bigger chunks. We build up
        # those chunks and output them at the end.
        # This data structure holds the current set of partial tool calls.
        index_to_tool_call: dict[int, dict] = dict()

        # The Llama Stack event stream must always start with a start event. Use an empty one to
        # simplify logic below
        yield ChatCompletionResponseStreamChunk(
            event=ChatCompletionResponseEvent(
                event_type=ChatCompletionResponseEventType.start,
                delta=TextDelta(text=""),
                stop_reason=None,
            )
        )

        converted_stop_reason = None
        async for chunk_str in vllm_result:
            # Due to OpenAI compatibility, each event in the stream will start with "data: " and
            # end with "\n\n".
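            # e.g. chunk_str might look like:
            #   'data: {"choices": [{"delta": {"content": "Hello"}, "finish_reason": null}]}\n\n'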
            _prefix = "data: "
            _suffix = "\n\n"
            if not chunk_str.startswith(_prefix) or not chunk_str.endswith(_suffix):
                raise ValueError(f"Can't parse result string from vLLM: '{re.escape(chunk_str)}'")

            # In between the "data: " and newlines is an event record
            data_str = chunk_str[len(_prefix) : -len(_suffix)]

            # The end of the stream is indicated with "[DONE]"
            if data_str == "[DONE]":
                yield ChatCompletionResponseStreamChunk(
                    event=ChatCompletionResponseEvent(
                        event_type=ChatCompletionResponseEventType.complete,
                        delta=TextDelta(text=""),
                        stop_reason=converted_stop_reason,
                    )
                )
                return

            # Anything that is not "[DONE]" should be a JSON record
            parsed_chunk = json.loads(data_str)

            logger.debug(f"Parsed JSON event to:\n{json.dumps(parsed_chunk, indent=2)}")

            # The result may contain multiple completions, but Llama Stack APIs only support
            # returning one.
            first_choice = parsed_chunk["choices"][0]
            converted_stop_reason = get_stop_reason(first_choice["finish_reason"])
            delta_record = first_choice["delta"]

            if "content" in delta_record:
                # Text delta
                yield ChatCompletionResponseStreamChunk(
                    event=ChatCompletionResponseEvent(
                        event_type=ChatCompletionResponseEventType.progress,
                        delta=TextDelta(text=delta_record["content"]),
                        stop_reason=converted_stop_reason,
                    )
                )
            elif "tool_calls" in delta_record:
                # Tool call(s). Llama Stack APIs do not have a clear way to return partial tool
                # calls, so buffer until we get a "tool calls" stop reason
                for tc in delta_record["tool_calls"]:
                    index = tc["index"]
                    if index not in index_to_tool_call:
                        # First time this tool call is showing up
                        index_to_tool_call[index] = dict()
                    tool_call = index_to_tool_call[index]
                    if "id" in tc:
                        tool_call["call_id"] = tc["id"]
                    if "function" in tc:
                        if "name" in tc["function"]:
                            tool_call["tool_name"] = tc["function"]["name"]
                        if "arguments" in tc["function"]:
                            # Arguments comes in as pieces of a string
                            if "arguments_str" not in tool_call:
                                tool_call["arguments_str"] = ""
                            tool_call["arguments_str"] += tc["function"]["arguments"]
            else:
                raise ValueError(f"Don't know how to parse event delta: {delta_record}")

            if first_choice["finish_reason"] == "tool_calls":
                # Special OpenAI code for "tool calls complete".
                # Output the buffered tool calls. Llama Stack requires a separate event per tool
                # call.
                for tool_call_record in index_to_tool_call.values():
                    # Arguments come in as a string. Parse the completed string.
                    tool_call_record["arguments"] = json.loads(tool_call_record["arguments_str"])
                    del tool_call_record["arguments_str"]

                    yield ChatCompletionResponseStreamChunk(
                        event=ChatCompletionResponseEvent(
                            event_type=ChatCompletionResponseEventType.progress,
                            delta=ToolCallDelta(tool_call=tool_call_record, parse_status="succeeded"),
                            stop_reason=converted_stop_reason,
                        )
                    )

        # If we get here, we've lost the connection with the vLLM event stream before it ended
        # normally.
        raise ValueError("vLLM event stream ended without [DONE] message.")