Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-12 04:50:39 +00:00)

Import updated vLLM code and get tests working

parent a55aab5958, commit 80c357f434

2 changed files with 624 additions and 156 deletions
@@ -26,12 +26,22 @@ class VLLMConfig(BaseModel):
        default=4096,
        description="Maximum number of tokens to generate.",
    )
    max_model_len: int = Field(
        default=4096, description="Maximum context length to use during serving."
    )
    max_num_seqs: int = Field(
        default=4, description="Maximum parallel batch size for generation"
    )
    enforce_eager: bool = Field(
        default=False,
        description="Whether to use eager mode for inference (otherwise cuda graphs are used).",
    )
    gpu_memory_utilization: float = Field(
        default=0.3,
        description=(
            "How much GPU memory will be allocated when this provider has finished "
            "loading, including memory that was already allocated before loading."
        ),
    )

    @classmethod
@@ -40,8 +50,10 @@ class VLLMConfig(BaseModel):
            "model": "${env.INFERENCE_MODEL:Llama3.2-3B-Instruct}",
            "tensor_parallel_size": "${env.TENSOR_PARALLEL_SIZE:1}",
            "max_tokens": "${env.MAX_TOKENS:4096}",
            "max_model_len": "${env.MAX_MODEL_LEN:4096}",
            "max_num_seqs": "${env.MAX_NUM_SEQS:4}",
            "enforce_eager": "${env.ENFORCE_EAGER:False}",
-           "gpu_memory_utilization": "${env.GPU_MEMORY_UTILIZATION:0.7}",
            "gpu_memory_utilization": "${env.GPU_MEMORY_UTILIZATION:0.3}",
        }

    @field_validator("model")
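The `${env.VAR:default}` placeholders above are resolved by Llama Stack's run-config substitution before validation. A rough sketch of the equivalent explicit construction (the environment-variable handling here is an assumption for illustration):

import os

from .config import VLLMConfig  # the package-local config module shown above

# With no environment overrides set, this reproduces the sample defaults.
config = VLLMConfig(
    model=os.environ.get("INFERENCE_MODEL", "Llama3.2-3B-Instruct"),
    tensor_parallel_size=int(os.environ.get("TENSOR_PARALLEL_SIZE", "1")),
    max_tokens=int(os.environ.get("MAX_TOKENS", "4096")),
    max_model_len=int(os.environ.get("MAX_MODEL_LEN", "4096")),
    max_num_seqs=int(os.environ.get("MAX_NUM_SEQS", "4")),
    enforce_eager=os.environ.get("ENFORCE_EAGER", "False") == "True",
    gpu_memory_utilization=float(os.environ.get("GPU_MEMORY_UTILIZATION", "0.3")),
)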
@@ -4,20 +4,39 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import datetime
import json
import logging
import os
import uuid
-from typing import AsyncGenerator, List, Optional
import re
from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union

# These vLLM modules contain names that overlap with Llama Stack names,
# so we import fully-qualified names
import vllm.entrypoints.openai.protocol
import vllm.sampling_params

############################################################################
# vLLM imports go here
#
# We deep-import the names that don't conflict with Llama Stack names
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
-from vllm.sampling_params import SamplingParams as VLLMSamplingParams
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_engine import BaseModelPath

-from llama_stack.apis.common.content_types import InterleavedContent
############################################################################
# llama_stack imports go here
from llama_stack.apis.common.content_types import (
    InterleavedContent,
    TextDelta,
    ToolCallDelta,
)
from llama_stack.apis.inference import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionResponseEvent,
    ChatCompletionResponseEventType,
    ChatCompletionResponseStreamChunk,
    CompletionMessage,
    CompletionResponse,
    CompletionResponseStreamChunk,
    EmbeddingsResponse,
@@ -35,69 +54,273 @@ from llama_stack.apis.inference import (
    ToolPromptFormat,
)
from llama_stack.apis.models import Model
from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.remote.inference.vllm.vllm import build_model_aliases
from llama_stack.providers.utils.inference.model_registry import (
    ModelRegistryHelper,
    ModelsProtocolPrivate,
)
from llama_stack.providers.utils.inference.openai_compat import (
    GrammarResponseFormat,
    Inference,
    JsonSchemaResponseFormat,
    LogProbConfig,
    Message,
    OpenAICompatCompletionChoice,
    OpenAICompatCompletionResponse,
    ResponseFormat,
    ToolCall,
    ToolChoice,
    UserMessage,
    convert_message_to_openai_dict,
    get_sampling_options,
    process_chat_completion_response,
    process_chat_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
    chat_completion_request_to_prompt,
)

############################################################################
# Package-local imports go here
from .config import VLLMConfig
log = logging.getLogger(__name__)

############################################################################
# Constants go here

# Map from Hugging Face model architecture name to appropriate tool parser.
# See vllm.entrypoints.openai.tool_parsers.ToolParserManager.tool_parsers
# for the full list of available parsers.
# TODO: Expand this list
CONFIG_TYPE_TO_TOOL_PARSER = {
    "GraniteConfig": "granite",
    "MllamaConfig": "llama3_json",
    "LlamaConfig": "llama3_json",
}
DEFAULT_TOOL_PARSER = "pythonic"

############################################################################
# Package-global variables go here

logger = logging.getLogger(__name__)

############################################################################
# Local functions go here
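The parser map above is consulted with a plain dictionary lookup plus a fallback; for example:

# Hypothetical architecture name taken from a model's HF config class
hf_config_class_name = "MllamaConfig"
tool_parser = CONFIG_TYPE_TO_TOOL_PARSER.get(hf_config_class_name, DEFAULT_TOOL_PARSER)
# -> "llama3_json"; unknown architectures fall back to "pythonic"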
def _random_uuid() -> str:
    return str(uuid.uuid4().hex)


def _info(msg: str):
    time_str = datetime.datetime.now().strftime("%H:%M:%S")
    print(f"{time_str}: {msg}")
    # logger.info(msg)
def _merge_context_into_content(message: Message) -> Message:  # type: ignore
    """
    Merge the ``context`` field of a Llama Stack ``Message`` object into
    the content field for compatibility with OpenAI-style APIs.

    Generates a content string that emulates the current behavior
    of ``llama_models.llama3.api.chat_format.encode_message()``.

    :param message: Message that may include ``context`` field

    :returns: A version of ``message`` with any context merged into the
     ``content`` field.
    """
    if not isinstance(message, UserMessage):  # Separate type check for linter
        return message
    if message.context is None:
        return message
    return UserMessage(
        role=message.role,
        # Emulate llama_models.llama3.api.chat_format.encode_message()
        content=message.content + "\n\n" + message.context,
        context=None,
    )
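For example (field values illustrative):

msg = UserMessage(
    role="user",
    content="What is the capital of France?",
    context="France is a country in western Europe.",
)
merged = _merge_context_into_content(msg)
# merged.content == "What is the capital of France?\n\nFrance is a country in western Europe."
# merged.context is None; non-user messages pass through unchanged.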
def _convert_finish_reason(finish_reason: str | None) -> str | None:
    """Convert an OpenAI "finish_reason" result to the equivalent
    Llama Stack result code.
    """
    # This conversion is currently a wild guess.
    if finish_reason is None:
        return None
    elif finish_reason == "stop":
        return StopReason.end_of_turn
    else:
        return StopReason.out_of_tokens
def _response_format_to_guided_decoding_params(
    response_format: Optional[ResponseFormat],  # type: ignore
) -> vllm.sampling_params.GuidedDecodingParams:
    """
    Like Llama Stack, vLLM's OpenAI-compatible API also uses the name
    "ResponseFormat" to describe the object that is a wrapper around
    another object that is a wrapper around another object inside
    someone else's constrained decoding library.
    Here we translate from Llama Stack's wrapper code to vLLM's code
    that does the same.

    :param response_format: Llama Stack version of constrained decoding
     info. Can be ``None``, indicating no constraints.
    :returns: The equivalent dataclass object for the low-level inference
     layer of vLLM.
    """
    if response_format is None:
        return vllm.sampling_params.GuidedDecodingParams()

    # Llama Stack currently implements fewer types of constrained
    # decoding than vLLM does. Translate the types that exist and
    # detect if Llama Stack adds new ones.
    if isinstance(response_format, JsonSchemaResponseFormat):
        return vllm.sampling_params.GuidedDecodingParams(json=response_format.json_schema)
    elif isinstance(response_format, GrammarResponseFormat):
        # BNF grammar.
        # Llama Stack uses the parse tree of the grammar, while vLLM
        # uses the string representation of the grammar.
        raise TypeError(
            "Constrained decoding with BNF grammars is not "
            "currently implemented, because the reference "
            "implementation does not implement it."
        )
    else:
        raise TypeError(f"ResponseFormat object is of unexpected subtype '{type(response_format)}'")
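A sketch of the JSON-schema path, assuming `JsonSchemaResponseFormat` carries a plain-dict schema in the ``json_schema`` field the code above reads:

schema = {"type": "object", "properties": {"answer": {"type": "string"}}}
response_format = JsonSchemaResponseFormat(json_schema=schema)
guided = _response_format_to_guided_decoding_params(response_format)
# guided.json now carries the schema dict, which vLLM's guided decoding
# backend enforces during sampling.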
def _convert_sampling_params(
    sampling_params: Optional[SamplingParams],
    response_format: Optional[ResponseFormat],  # type: ignore
) -> vllm.SamplingParams:
    """Convert sampling and constrained decoding configuration from
    Llama Stack's format to vLLM's format."""
    if sampling_params is None:
        # In the absence of a user-provided sampling config, we use
        # Llama Stack defaults, which are different from vLLM defaults.
        sampling_params = SamplingParams()

    if isinstance(sampling_params.strategy, TopKSamplingStrategy):
        if sampling_params.strategy.top_k == 0:
            # vLLM treats "k" differently for top-k sampling
            vllm_top_k = -1
        else:
            vllm_top_k = sampling_params.strategy.top_k
    else:
        vllm_top_k = -1

    if isinstance(sampling_params.strategy, TopPSamplingStrategy):
        vllm_top_p = sampling_params.strategy.top_p
        # Llama Stack only allows temperature with top-P.
        vllm_temperature = sampling_params.strategy.temperature
    else:
        vllm_top_p = 1.0
        vllm_temperature = 0.0

    # vLLM allows top-p and top-k at the same time.
    vllm_sampling_params = vllm.SamplingParams.from_optional(
        max_tokens=(None if sampling_params.max_tokens == 0 else sampling_params.max_tokens),
        # Assume that vLLM's default stop token will work
        # stop_token_ids=[tokenizer.eos_token_id],
        temperature=vllm_temperature,
        top_p=vllm_top_p,
        top_k=vllm_top_k,
        repetition_penalty=sampling_params.repetition_penalty,
        guided_decoding=_response_format_to_guided_decoding_params(response_format),
    )
    return vllm_sampling_params
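For instance, a top-p strategy maps as follows (a sketch using the Llama Stack types referenced above):

ls_params = SamplingParams(
    strategy=TopPSamplingStrategy(temperature=0.7, top_p=0.9),
    max_tokens=256,
    repetition_penalty=1.1,
)
vllm_params = _convert_sampling_params(ls_params, response_format=None)
# Expected result: temperature=0.7, top_p=0.9, top_k=-1 (disabled),
# max_tokens=256, repetition_penalty=1.1, and empty guided decoding.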
def _convert_tools(
    tools: Optional[List[ToolDefinition]] = None,
) -> List[vllm.entrypoints.openai.protocol.ChatCompletionToolsParam]:
    """
    Convert the list of available tools from Llama Stack's format to vLLM's
    version of OpenAI's format.
    """
    if tools is None:
        return []

    result = []
    for t in tools:
        if isinstance(t.tool_name, BuiltinTool):
            raise NotImplementedError("Built-in tools not yet implemented")
        if t.parameters is None:
            parameters = None
        else:  # if t.parameters is not None
            # Convert the "required" flags to a list of required params
            required_params = [k for k, v in t.parameters.items() if v.required]
            parameters = {
                "type": "object",  # Mystery value that shows up in OpenAI docs
                "properties": {
                    k: {"type": v.param_type, "description": v.description} for k, v in t.parameters.items()
                },
                "required": required_params,
            }

        function_def = vllm.entrypoints.openai.protocol.FunctionDefinition(
            name=t.tool_name, description=t.description, parameters=parameters
        )

        # Every tool definition is double-boxed in a ChatCompletionToolsParam
        result.append(vllm.entrypoints.openai.protocol.ChatCompletionToolsParam(function=function_def))
    return result
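A sketch of the conversion for a single-parameter tool (``ToolParamDefinition`` is assumed to be the parameter type carrying the ``param_type``/``description``/``required`` fields read above):

tool = ToolDefinition(
    tool_name="get_weather",
    description="Look up the current weather for a city",
    parameters={
        "city": ToolParamDefinition(param_type="string", description="City name", required=True),
    },
)
[converted] = _convert_tools([tool])
# converted.function.name == "get_weather"
# converted.function.parameters["properties"] == {"city": {"type": "string", "description": "City name"}}
# converted.function.parameters["required"] == ["city"]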
############################################################################
# Class definitions go here


class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
-    """Inference implementation for vLLM."""
    """
    vLLM-based inference model adapter for Llama Stack with support for multiple
    models.

    Requires the configuration parameters documented in the
    :class:`VLLMConfig` class.
    """

    config: VLLMConfig
    register_helper: ModelRegistryHelper
    model_ids: set[str]
    resolved_model_id: str | None
    engine: AsyncLLMEngine | None
    chat: OpenAIServingChat | None
    def __init__(self, config: VLLMConfig):
        self.config = config
        self.engine = None
        _info(f"Config is: {self.config}")

    async def initialize(self):
        log.info("Initializing vLLM inference provider.")
        self.register_helper = ModelRegistryHelper(build_model_aliases())
        self.formatter = ChatFormat(Tokenizer.get_instance())

        # Disable usage stats reporting. This would be a surprising thing for most
        # people to find out was on by default.
        # https://docs.vllm.ai/en/latest/serving/usage_stats.html
        if "VLLM_NO_USAGE_STATS" not in os.environ:
            os.environ["VLLM_NO_USAGE_STATS"] = "1"

        # The following are initialized when paths are bound to this provider
        self.resolved_model_id = None
        self.model_ids = set()
        self.engine = None
        self.chat = None

        model = resolve_model(self.config.model)
        if model is None:
            raise ValueError(f"Unknown model {self.config.model}")
        if model.huggingface_repo is None:
            raise ValueError(f"Model {self.config.model} needs a huggingface repo")

    ###########################################################################
    # METHODS INHERITED FROM UNDOCUMENTED IMPLICIT MYSTERY BASE CLASS
    async def initialize(self) -> None:
        """
        Callback that is invoked through many levels of indirection during
        provider class instantiation, sometime after when __init__() is called
        and before any model registration methods or methods connected to a
        REST API are called.

        It's not clear what assumptions the class can make about the platform's
        initialization state here that can't be made during __init__(), and
        vLLM can't be started until we know what model it's supposed to be
        serving, so nothing happens here currently.
        """
        pass

-        # TODO -- there are a ton of options supported here ...
-        engine_args = AsyncEngineArgs(
-            model=model.huggingface_repo,
-            tokenizer=model.huggingface_repo,
-            tensor_parallel_size=self.config.tensor_parallel_size,
-            enforce_eager=self.config.enforce_eager,
-            gpu_memory_utilization=self.config.gpu_memory_utilization,
-            guided_decoding_backend="lm-format-enforcer",
-        )
-        self.engine = AsyncLLMEngine.from_engine_args(engine_args)
    async def shutdown(self):
        """Shut down the vLLM inference adapter."""
        log.info("Shutting down vLLM inference provider.")
        if self.engine:
            self.engine.shutdown_background_loop()

    ###########################################################################
    # METHODS INHERITED FROM ModelsProtocolPrivate INTERFACE

    # Note that the return type of the superclass method is WRONG
    async def register_model(self, model: Model) -> Model:
@@ -111,33 +334,102 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
        :returns: The input ``Model`` object. It may or may not be permissible
         to change fields before returning this object.
        """
-        log.info(f"Registering model {model.identifier} with vLLM inference provider.")
-        # The current version of this provider is hard-coded to serve only
-        # the model specified in the YAML config file.
-        configured_model = resolve_model(self.config.model)
-        registered_model = resolve_model(model.model_id)
        _info(f"In register_model({model})")

        # First attempt to interpret the model coordinates as a Llama model name
        resolved_llama_model = resolve_model(model.provider_model_id)
        if resolved_llama_model is not None:
            # Load from Hugging Face repo into default local cache dir
            resolved_model_id = resolved_llama_model.huggingface_repo
        else:  # if resolved_llama_model is None
            # Not a Llama model name. Pass the model id through to vLLM's loader
            resolved_model_id = model.provider_model_id

        _info(f"Resolved model id: {resolved_model_id}")

        if self.resolved_model_id is not None:
            if resolved_model_id != self.resolved_model_id:
                raise ValueError(
                    f"Attempted to serve two LLMs (ids "
                    f"'{self.resolved_model_id}' and "
                    f"'{resolved_model_id}') from one copy of "
                    f"provider '{self}'. Use multiple "
                    f"copies of the provider instead."
                )
            else:
                # Model already loaded
                return model
        # If we get here, this is the first time registering a model.
        # Preload so that the first inference request won't time out.
        engine_args = AsyncEngineArgs(
            model=resolved_model_id,
            tokenizer=resolved_model_id,
            tensor_parallel_size=self.config.tensor_parallel_size,
            enforce_eager=self.config.enforce_eager,
            gpu_memory_utilization=self.config.gpu_memory_utilization,
            max_num_seqs=self.config.max_num_seqs,
            max_model_len=self.config.max_model_len,
        )
        self.engine = AsyncLLMEngine.from_engine_args(engine_args)

        # vLLM currently requires the user to specify the tool parser
        # manually. To choose a tool parser, we need to determine what
        # model architecture is being used. For now, we infer that
        # information from what config class the model uses.
        low_level_model_config = self.engine.engine.get_model_config()
        hf_config = low_level_model_config.hf_config
        hf_config_class_name = hf_config.__class__.__name__
        if hf_config_class_name in CONFIG_TYPE_TO_TOOL_PARSER:
            tool_parser = CONFIG_TYPE_TO_TOOL_PARSER[hf_config_class_name]
        else:
            # No info -- choose a default so we can at least attempt tool
            # use.
            tool_parser = DEFAULT_TOOL_PARSER
        _info(f"{hf_config_class_name=}")
        _info(f"{tool_parser=}")

        # Wrap the lower-level engine in an OpenAI-compatible chat API
        model_config = await self.engine.get_model_config()
        self.chat = OpenAIServingChat(
            engine_client=self.engine,
            model_config=model_config,
            base_model_paths=[
                # The layer below us will only see resolved model IDs
                BaseModelPath(resolved_model_id, resolved_model_id)
            ],
            response_role="assistant",
            lora_modules=None,
            prompt_adapters=None,
            request_logger=None,
            chat_template=None,
            enable_auto_tools=True,
            tool_parser=tool_parser,
            chat_template_content_format="auto",
        )
        self.resolved_model_id = resolved_model_id
        self.model_ids.add(model.model_id)

        _info(f"Finished preloading model: {resolved_model_id}")
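The net effect: the first registration starts the engine, and later registrations are accepted only if they resolve to the same underlying model. A sketch (the ``Model`` constructor fields here are assumptions for illustration):

impl = VLLMInferenceImpl(config)
await impl.initialize()

# First registration resolves the name and preloads the engine.
await impl.register_model(Model(identifier="m1", provider_resource_id="Llama3.2-3B-Instruct", provider_id="vllm"))
# A second alias for the same checkpoint reuses the running engine.
await impl.register_model(Model(identifier="m2", provider_resource_id="Llama3.2-3B-Instruct", provider_id="vllm"))
# A model resolving to a different checkpoint raises ValueError ("Attempted to serve two LLMs ...").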
-        if configured_model.core_model_id != registered_model.core_model_id:
-            raise ValueError(
-                f"Requested model '{model.identifier}' is different from "
-                f"model '{self.config.model}' that this provider "
-                f"is configured to serve"
-            )
        return model
-    def _sampling_params(self, sampling_params: SamplingParams) -> VLLMSamplingParams:
-        if sampling_params is None:
-            return VLLMSamplingParams(max_tokens=self.config.max_tokens)
-
-        options = get_sampling_options(sampling_params)
-        if "repeat_penalty" in options:
-            options["repetition_penalty"] = options["repeat_penalty"]
-            del options["repeat_penalty"]
-
-        return VLLMSamplingParams(**options)
    async def unregister_model(self, model_id: str) -> None:
-        pass
        """
        Callback that is called when the server removes an inference endpoint
        from an inference provider.

        The semantics of this callback are not clear. How should model_id
        be interpreted? What happens to pending requests?

        :param model_id: Undocumented string parameter

        :returns: Nothing, at least according to the spec
        """
        raise NotImplementedError()
    ###########################################################################
    # METHODS INHERITED FROM Inference INTERFACE

    async def completion(
        self,
@@ -147,100 +439,264 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
-    ) -> CompletionResponse | CompletionResponseStreamChunk:
-        raise NotImplementedError("Completion not implemented for vLLM")
    ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
        raise NotImplementedError()

-    async def chat_completion(
-        self,
-        model_id: str,
-        messages: List[Message],
-        sampling_params: Optional[SamplingParams] = None,
-        tools: Optional[List[ToolDefinition]] = None,
-        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = None,
-        response_format: Optional[ResponseFormat] = None,
-        stream: Optional[bool] = False,
-        logprobs: Optional[LogProbConfig] = None,
-        tool_config: Optional[ToolConfig] = None,
-    ) -> ChatCompletionResponse | ChatCompletionResponseStreamChunk:
-        if sampling_params is None:
-            sampling_params = SamplingParams()
-        assert self.engine is not None
-
-        request = ChatCompletionRequest(
-            model=model_id,
-            messages=messages,
-            sampling_params=sampling_params,
-            tools=tools or [],
-            stream=stream,
-            logprobs=logprobs,
-            tool_config=tool_config,
-        )
-
-        log.info("Sampling params: %s", sampling_params)
-        request_id = _random_uuid()
-
-        prompt = await chat_completion_request_to_prompt(request, self.config.model)
-        vllm_sampling_params = self._sampling_params(request.sampling_params)
-        results_generator = self.engine.generate(prompt, vllm_sampling_params, request_id)
-        if stream:
-            return self._stream_chat_completion(request, results_generator)
-        else:
-            return await self._nonstream_chat_completion(request, results_generator)
-
-    async def _nonstream_chat_completion(
-        self, request: ChatCompletionRequest, results_generator: AsyncGenerator
-    ) -> ChatCompletionResponse:
-        outputs = [o async for o in results_generator]
-        final_output = outputs[-1]
-
-        assert final_output is not None
-        outputs = final_output.outputs
-        finish_reason = outputs[-1].stop_reason
-        choice = OpenAICompatCompletionChoice(
-            finish_reason=finish_reason,
-            text="".join([output.text for output in outputs]),
-        )
-        response = OpenAICompatCompletionResponse(
-            choices=[choice],
-        )
-        return process_chat_completion_response(response, request)
-
-    async def _stream_chat_completion(
-        self, request: ChatCompletionRequest, results_generator: AsyncGenerator
-    ) -> AsyncGenerator:
-        tokenizer = Tokenizer.get_instance()
-
-        async def _generate_and_convert_to_openai_compat():
-            cur = []
-            async for chunk in results_generator:
-                if not chunk.outputs:
-                    log.warning("Empty chunk received")
-                    continue
-
-                output = chunk.outputs[-1]
-
-                new_tokens = output.token_ids[len(cur) :]
-                text = tokenizer.decode(new_tokens)
-                cur.extend(new_tokens)
-                choice = OpenAICompatCompletionChoice(
-                    finish_reason=output.finish_reason,
-                    text=text,
-                )
-                yield OpenAICompatCompletionResponse(
-                    choices=[choice],
-                )
-
-        stream = _generate_and_convert_to_openai_compat()
-        async for chunk in process_chat_completion_stream_response(stream, request):
-            yield chunk
    async def embeddings(
        self,
        model_id: str,
-        contents: List[InterleavedContent],  # type: ignore
        contents: List[str] | List[InterleavedContentItem],
        text_truncation: Optional[TextTruncation] = TextTruncation.none,
        output_dimension: Optional[int] = None,
        task_type: Optional[EmbeddingTaskType] = None,
    ) -> EmbeddingsResponse:
        raise NotImplementedError()
    async def chat_completion(
        self,
        model_id: str,
        messages: List[Message],  # type: ignore
        sampling_params: Optional[SamplingParams] = None,
        response_format: Optional[ResponseFormat] = None,  # type: ignore
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
        # model_id: str,
        # messages: List[Message],  # type: ignore
        # sampling_params: Optional[SamplingParams] = SamplingParams(),
        # tools: Optional[List[ToolDefinition]] = None,
        # tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        # tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
        # response_format: Optional[ResponseFormat] = None,
        # stream: Optional[bool] = False,
        # logprobs: Optional[LogProbConfig] = None,
    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
        sampling_params = sampling_params or SamplingParams()
        if model_id not in self.model_ids:
            raise ValueError(
                f"This adapter is not registered to model id '{model_id}'. Registered IDs are: {self.model_ids}"
            )
        # Arguments to the vLLM call must be packaged as a ChatCompletionRequest
        # dataclass.
        # Note that this dataclass has the same name as a similar dataclass in
        # Llama Stack.
        converted_messages = [
            await convert_message_to_openai_dict(_merge_context_into_content(m), download=True) for m in messages
        ]
        converted_sampling_params = _convert_sampling_params(sampling_params, response_format)
        converted_tools = _convert_tools(tools)

        # Llama will try to use built-in tools with no tool catalog, so don't enable
        # tool choice unless at least one tool is enabled.
        converted_tool_choice = "none"
        if tool_choice == ToolChoice.auto and tools is not None and len(tools) > 0:
            converted_tool_choice = "auto"

        # TODO: Figure out what to do with the tool_prompt_format argument
        # TODO: Convert logprobs argument

        chat_completion_request = vllm.entrypoints.openai.protocol.ChatCompletionRequest(
            model=self.resolved_model_id,
            messages=converted_messages,
            tools=converted_tools,
            tool_choice=converted_tool_choice,
            stream=stream,
            # tool_prompt_format=tool_prompt_format,
            # logprobs=logprobs,
        )
        # vLLM's OpenAI-compatible APIs take sampling parameters as multiple
        # keyword args instead of a vLLM SamplingParams object. Copy over
        # all the parts that we currently convert from Llama Stack format.
        for param_name in [
            "max_tokens",
            "temperature",
            "top_p",
            "top_k",
            "repetition_penalty",
        ]:
            setattr(
                chat_completion_request,
                param_name,
                getattr(converted_sampling_params, param_name),
            )

        # Guided decoding parameters are further broken out
        if converted_sampling_params.guided_decoding is not None:
            g = converted_sampling_params.guided_decoding
            chat_completion_request.guided_json = g.json
            chat_completion_request.guided_regex = g.regex
            chat_completion_request.guided_grammar = g.grammar

        _info(f"Converted request: {chat_completion_request}")

        vllm_result = await self.chat.create_chat_completion(chat_completion_request)
        _info(f"Result from vLLM: {vllm_result}")
        if isinstance(vllm_result, vllm.entrypoints.openai.protocol.ErrorResponse):
            raise ValueError(f"Error from vLLM layer: {vllm_result}")
        # Return type depends on "stream" argument
        if stream:
            if not isinstance(vllm_result, AsyncGenerator):
                raise TypeError(f"Unexpected result type {type(vllm_result)} for streaming inference call")
            # vLLM client returns a stream of strings, which need to be parsed.
            # Stream comes in the form of an async generator
            return self._convert_streaming_results(vllm_result)
        else:
            if not isinstance(vllm_result, vllm.entrypoints.openai.protocol.ChatCompletionResponse):
                raise TypeError(f"Unexpected result type {type(vllm_result)} for non-streaming inference call")
            return self._convert_non_streaming_results(vllm_result)
    ###########################################################################
    # INTERNAL METHODS

    def _convert_non_streaming_results(
        self, vllm_result: vllm.entrypoints.openai.protocol.ChatCompletionResponse
    ) -> ChatCompletionResponse:
        """
        Subroutine to convert the non-streaming output of vLLM's OpenAI-compatible
        API into an equivalent Llama Stack object.

        The result from vLLM's non-streaming API is a dataclass with
        the same name as the Llama Stack ChatCompletionResponse dataclass,
        but with more and different field names. We ignore the fields that
        aren't currently present in the Llama Stack dataclass.
        """

        # There may be multiple responses, but we can only pass through the
        # first one.
        if len(vllm_result.choices) == 0:
            raise ValueError("Don't know how to convert response object without any responses")
        vllm_message = vllm_result.choices[0].message

        converted_message = CompletionMessage(
            role=vllm_message.role,
            # Llama Stack API won't accept None for content field.
            content=("" if vllm_message.content is None else vllm_message.content),
            stop_reason=_convert_finish_reason(vllm_result.choices[0].finish_reason),
            tool_calls=[
                ToolCall(
                    call_id=t.id,
                    tool_name=t.function.name,
                    # vLLM function args come back as a string. Llama Stack expects JSON.
                    arguments=json.loads(t.function.arguments),
                )
                for t in vllm_message.tool_calls
            ],
        )

        # TODO: Convert logprobs

        _info(f"Converted message: {converted_message}")

        return ChatCompletionResponse(
            completion_message=converted_message,
        )
    async def _convert_streaming_results(self, vllm_result: AsyncIterator) -> AsyncIterator:
        """
        Subroutine that wraps the streaming outputs of vLLM's OpenAI-compatible
        API into a second async iterator that returns Llama Stack objects.

        :param vllm_result: Stream of strings that need to be parsed
        """
        # Tool calls come in pieces, but Llama Stack expects them in bigger
        # chunks. We build up those chunks and output them at the end.
        # This data structure holds the current set of partial tool calls.
        index_to_tool_call: Dict[int, Dict] = dict()

        # The Llama Stack event stream must always start with a start event.
        # Use an empty one to simplify logic below
        yield ChatCompletionResponseStreamChunk(
            event=ChatCompletionResponseEvent(
                event_type=ChatCompletionResponseEventType.start,
                delta=TextDelta(text=""),
                stop_reason=None,
            )
        )
        converted_stop_reason = None
        async for chunk_str in vllm_result:
            # Due to OpenAI compatibility, each event in the stream
            # will start with "data: " and end with "\n\n".
            _prefix = "data: "
            _suffix = "\n\n"
            if not chunk_str.startswith(_prefix) or not chunk_str.endswith(_suffix):
                raise ValueError(f"Can't parse result string from vLLM: '{re.escape(chunk_str)}'")

            # In between the "data: " and newlines is an event record
            data_str = chunk_str[len(_prefix) : -len(_suffix)]

            # The end of the stream is indicated with "[DONE]"
            if data_str == "[DONE]":
                yield ChatCompletionResponseStreamChunk(
                    event=ChatCompletionResponseEvent(
                        event_type=ChatCompletionResponseEventType.complete,
                        delta=TextDelta(text=""),
                        stop_reason=converted_stop_reason,
                    )
                )
                return
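For reference, a sketch of the framing this loop expects from vLLM's OpenAI-compatible stream (payload abbreviated):

chunk_str = 'data: {"choices": [{"delta": {"content": "Hi"}, "finish_reason": null}]}\n\n'
assert chunk_str.startswith("data: ") and chunk_str.endswith("\n\n")
data_str = chunk_str[len("data: ") : -len("\n\n")]
event = json.loads(data_str)
# event["choices"][0]["delta"]["content"] == "Hi"
# The final frame is 'data: [DONE]\n\n', which terminates the loop above.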
            # Anything that is not "[DONE]" should be a JSON record
            parsed_chunk = json.loads(data_str)

            # print(f"Parsed JSON event to:\n{json.dumps(parsed_chunk, indent=2)}")

            # The result may contain multiple completions, but Llama Stack APIs
            # only support returning one.
            first_choice = parsed_chunk["choices"][0]
            converted_stop_reason = _convert_finish_reason(first_choice["finish_reason"])
            delta_record = first_choice["delta"]

            if "content" in delta_record:
                # Text delta
                yield ChatCompletionResponseStreamChunk(
                    event=ChatCompletionResponseEvent(
                        event_type=ChatCompletionResponseEventType.progress,
                        delta=TextDelta(text=delta_record["content"]),
                        stop_reason=converted_stop_reason,
                    )
                )
elif "tool_calls" in delta_record:
|
||||
# Tool call(s). Llama Stack APIs do not have a clear way to return
|
||||
# partial tool calls, so buffer until we get a "tool calls" stop reason
|
||||
for tc in delta_record["tool_calls"]:
|
||||
index = tc["index"]
|
||||
if index not in index_to_tool_call:
|
||||
# First time this tool call is showing up
|
||||
index_to_tool_call[index] = dict()
|
||||
tool_call = index_to_tool_call[index]
|
||||
if "id" in tc:
|
||||
tool_call["call_id"] = tc["id"]
|
||||
if "function" in tc:
|
||||
if "name" in tc["function"]:
|
||||
tool_call["tool_name"] = tc["function"]["name"]
|
||||
if "arguments" in tc["function"]:
|
||||
# Arguments comes in as pieces of a string
|
||||
if "arguments_str" not in tool_call:
|
||||
tool_call["arguments_str"] = ""
|
||||
tool_call["arguments_str"] += tc["function"]["arguments"]
|
||||
else:
|
||||
raise ValueError(f"Don't know how to parse event delta: {delta_record}")
|
||||
|
||||
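To illustrate the buffering, a self-contained sketch of how two partial deltas for one call accumulate (payloads hypothetical):

deltas = [
    {"index": 0, "id": "call_1", "function": {"name": "get_weather", "arguments": '{"ci'}},
    {"index": 0, "function": {"arguments": 'ty": "Paris"}'}},
]
buffered: Dict[int, Dict] = {}
for tc in deltas:
    rec = buffered.setdefault(tc["index"], {})
    if "id" in tc:
        rec["call_id"] = tc["id"]
    fn = tc.get("function", {})
    if "name" in fn:
        rec["tool_name"] = fn["name"]
    rec["arguments_str"] = rec.get("arguments_str", "") + fn.get("arguments", "")
assert json.loads(buffered[0]["arguments_str"]) == {"city": "Paris"}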
            if first_choice["finish_reason"] == "tool_calls":
                # Special OpenAI code for "tool calls complete".
                # Output the buffered tool calls. Llama Stack requires a separate
                # event per tool call.
                for tool_call_record in index_to_tool_call.values():
                    # Arguments come in as a string. Parse the completed string.
                    tool_call_record["arguments"] = json.loads(tool_call_record["arguments_str"])
                    del tool_call_record["arguments_str"]

                    yield ChatCompletionResponseStreamChunk(
                        event=ChatCompletionResponseEvent(
                            event_type=ChatCompletionResponseEventType.progress,
                            delta=ToolCallDelta(content=tool_call_record, parse_status="succeeded"),
                            stop_reason=converted_stop_reason,
                        )
                    )

        # If we get here, we've lost the connection with the vLLM event stream
        # before it ended normally.
        raise ValueError("vLLM event stream ended without [DONE] message.")
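Finally, a sketch of consuming the converted stream from the caller's side (model id and message are illustrative):

async def print_completion(impl: VLLMInferenceImpl) -> None:
    stream = await impl.chat_completion(
        model_id="m1",  # must already be registered with this provider
        messages=[UserMessage(role="user", content="Say hello")],
        stream=True,
    )
    async for chunk in stream:
        event = chunk.event
        if event.event_type == ChatCompletionResponseEventType.progress and isinstance(event.delta, TextDelta):
            print(event.delta.text, end="", flush=True)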