diff --git a/llama_stack/providers/inline/inference/vllm/config.py b/llama_stack/providers/inline/inference/vllm/config.py index 51ef2d273..5921b132d 100644 --- a/llama_stack/providers/inline/inference/vllm/config.py +++ b/llama_stack/providers/inline/inference/vllm/config.py @@ -26,12 +26,22 @@ class VLLMConfig(BaseModel): default=4096, description="Maximum number of tokens to generate.", ) + max_model_len: int = Field( + default=4096, description="Maximum context length to use during serving." + ) + max_num_seqs: int = Field( + default=4, description="Maximum parallel batch size for generation" + ) enforce_eager: bool = Field( default=False, description="Whether to use eager mode for inference (otherwise cuda graphs are used).", ) gpu_memory_utilization: float = Field( default=0.3, + description=( + "How much GPU memory will be allocated when this provider has finished " + "loading, including memory that was already allocated before loading." + ), ) @classmethod @@ -40,8 +50,10 @@ class VLLMConfig(BaseModel): "model": "${env.INFERENCE_MODEL:Llama3.2-3B-Instruct}", "tensor_parallel_size": "${env.TENSOR_PARALLEL_SIZE:1}", "max_tokens": "${env.MAX_TOKENS:4096}", + "max_model_len": "${env.MAX_MODEL_LEN:4096}", + "max_num_seqs": "${env.MAX_NUM_SEQS:4}", "enforce_eager": "${env.ENFORCE_EAGER:False}", - "gpu_memory_utilization": "${env.GPU_MEMORY_UTILIZATION:0.7}", + "gpu_memory_utilization": "${env.GPU_MEMORY_UTILIZATION:0.3}", } @field_validator("model") diff --git a/llama_stack/providers/inline/inference/vllm/vllm.py b/llama_stack/providers/inline/inference/vllm/vllm.py index b461bf44a..e40ab5fdc 100644 --- a/llama_stack/providers/inline/inference/vllm/vllm.py +++ b/llama_stack/providers/inline/inference/vllm/vllm.py @@ -4,20 +4,39 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
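+# Inline vLLM inference provider: runs a vLLM engine inside the Llama Stack
+# server process and adapts vLLM's OpenAI-compatible chat completion layer to
+# the Llama Stack Inference API.
+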
+import datetime
+import json
 import logging
-import os
-import uuid
-from typing import AsyncGenerator, List, Optional
+import re
+from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union

+# These vLLM modules contain names that overlap with Llama Stack names,
+# so we import fully-qualified names
+import vllm.entrypoints.openai.protocol
+import vllm.sampling_params
+
+############################################################################
+# vLLM imports go here
+#
+# We deep-import the names that don't conflict with Llama Stack names
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
-from vllm.sampling_params import SamplingParams as VLLMSamplingParams
+from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
+from vllm.entrypoints.openai.serving_engine import BaseModelPath

-from llama_stack.apis.common.content_types import InterleavedContent
+############################################################################
+# llama_stack imports go here
+from llama_stack.apis.common.content_types import (
+    InterleavedContent,
+    TextDelta,
+    ToolCallDelta,
+)
 from llama_stack.apis.inference import (
-    ChatCompletionRequest,
     ChatCompletionResponse,
+    ChatCompletionResponseEvent,
+    ChatCompletionResponseEventType,
     ChatCompletionResponseStreamChunk,
+    CompletionMessage,
     CompletionResponse,
     CompletionResponseStreamChunk,
     EmbeddingsResponse,
+    GrammarResponseFormat,
+    JsonSchemaResponseFormat,
@@ -35,69 +54,273 @@ from llama_stack.apis.inference import (
+    ToolCall,
     ToolPromptFormat,
+    UserMessage,
 )
 from llama_stack.apis.models import Model
+from llama_stack.models.llama.llama3.chat_format import ChatFormat
 from llama_stack.models.llama.llama3.tokenizer import Tokenizer
 from llama_stack.models.llama.sku_list import resolve_model
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
+from llama_stack.providers.remote.inference.vllm.vllm import build_model_aliases
+from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack.providers.utils.inference.openai_compat import (
     OpenAICompatCompletionChoice,
     OpenAICompatCompletionResponse,
+    convert_message_to_openai_dict,
     get_sampling_options,
     process_chat_completion_response,
     process_chat_completion_stream_response,
 )
-from llama_stack.providers.utils.inference.prompt_adapter import (
-    chat_completion_request_to_prompt,
-)

+############################################################################
+# Package-local imports go here
 from .config import VLLMConfig

-log = logging.getLogger(__name__)
+############################################################################
+# Constants go here
+
+# Map from Hugging Face model architecture name to appropriate tool parser.
+# See vllm.entrypoints.openai.tool_parsers.ToolParserManager.tool_parsers
+# for the full list of available parsers.
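+# (The keys are Hugging Face config class names, as reported by
+# hf_config.__class__.__name__; see the lookup in register_model() below.)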
+# TODO: Expand this list
+CONFIG_TYPE_TO_TOOL_PARSER = {
+    "GraniteConfig": "granite",
+    "MllamaConfig": "llama3_json",
+    "LlamaConfig": "llama3_json",
+}
+DEFAULT_TOOL_PARSER = "pythonic"
+
+############################################################################
+# Package-global variables go here
+
+logger = logging.getLogger(__name__)
+
+############################################################################
+# Local functions go here


-def _random_uuid() -> str:
-    return str(uuid.uuid4().hex)

+def _info(msg: str):
+    time_str = datetime.datetime.now().strftime("%H:%M:%S")
+    print(f"{time_str}: {msg}")
+    # logger.info(msg)
+
+
+def _merge_context_into_content(message: Message) -> Message:  # type: ignore
+    """
+    Merge the ``context`` field of a Llama Stack ``Message`` object into
+    the content field for compatibility with OpenAI-style APIs.
+
+    Generates a content string that emulates the current behavior
+    of ``llama_models.llama3.api.chat_format.encode_message()``.
+
+    :param message: Message that may include ``context`` field
+
+    :returns: A version of ``message`` with any context merged into the
+     ``content`` field.
+    """
+    if not isinstance(message, UserMessage):  # Separate type check for linter
+        return message
+    if message.context is None:
+        return message
+    return UserMessage(
+        role=message.role,
+        # Emulate llama_models.llama3.api.chat_format.encode_message()
+        content=message.content + "\n\n" + message.context,
+        context=None,
+    )
+
+
+def _convert_finish_reason(finish_reason: str | None) -> StopReason | None:
+    """Convert an OpenAI "finish_reason" result to the equivalent
+    Llama Stack result code.
+    """
+    # This conversion is currently a wild guess.
+    if finish_reason is None:
+        return None
+    elif finish_reason == "stop":
+        return StopReason.end_of_turn
+    else:
+        return StopReason.out_of_tokens
+
+
+def _response_format_to_guided_decoding_params(
+    response_format: Optional[ResponseFormat],  # type: ignore
+) -> vllm.sampling_params.GuidedDecodingParams:
+    """
+    Like Llama Stack, vLLM's OpenAI-compatible API also uses the name
+    "ResponseFormat" to describe the object that is a wrapper around
+    another object that is a wrapper around another object inside
+    someone else's constrained decoding library.
+    Here we translate from Llama Stack's wrapper code to vLLM's code
+    that does the same.
+
+    :param response_format: Llama Stack version of constrained decoding
+     info. Can be ``None``, indicating no constraints.
+    :returns: The equivalent dataclass object for the low-level inference
+     layer of vLLM.
+    """
+    if response_format is None:
+        return vllm.sampling_params.GuidedDecodingParams()
+
+    # Llama Stack currently implements fewer types of constrained
+    # decoding than vLLM does. Translate the types that exist and
+    # detect if Llama Stack adds new ones.
+    if isinstance(response_format, JsonSchemaResponseFormat):
+        return vllm.sampling_params.GuidedDecodingParams(json=response_format.json_schema)
+    elif isinstance(response_format, GrammarResponseFormat):
+        # BNF grammar.
+        # Llama Stack uses the parse tree of the grammar, while vLLM
+        # uses the string representation of the grammar.
+        raise TypeError(
+            "Constrained decoding with BNF grammars is not "
+            "currently implemented, because the reference "
+            "implementation does not implement it."
+        )
+    else:
+        raise TypeError(f"ResponseFormat object is of unexpected subtype '{type(response_format)}'")
+
+
+def _convert_sampling_params(
+    sampling_params: Optional[SamplingParams],
+    response_format: Optional[ResponseFormat],  # type: ignore
+) -> vllm.SamplingParams:
+    """Convert sampling and constrained decoding configuration from
+    Llama Stack's format to vLLM's format."""
+    if sampling_params is None:
+        # In the absence of a user-provided sampling config, we use
+        # Llama Stack defaults, which are different from vLLM defaults.
+        sampling_params = SamplingParams()
+
+    if isinstance(sampling_params.strategy, TopKSamplingStrategy):
+        if sampling_params.strategy.top_k == 0:
+            # Llama Stack uses 0 to disable top-k; vLLM uses -1
+            vllm_top_k = -1
+        else:
+            vllm_top_k = sampling_params.strategy.top_k
+    else:
+        vllm_top_k = -1
+
+    if isinstance(sampling_params.strategy, TopPSamplingStrategy):
+        vllm_top_p = sampling_params.strategy.top_p
+        # Llama Stack only allows temperature with top-P.
+        vllm_temperature = sampling_params.strategy.temperature
+    else:
+        vllm_top_p = 1.0
+        vllm_temperature = 0.0
+
+    # vLLM allows top-p and top-k at the same time.
+    vllm_sampling_params = vllm.SamplingParams.from_optional(
+        max_tokens=(None if sampling_params.max_tokens == 0 else sampling_params.max_tokens),
+        # Assume that vLLM's default stop token will work
+        # stop_token_ids=[tokenizer.eos_token_id],
+        temperature=vllm_temperature,
+        top_p=vllm_top_p,
+        top_k=vllm_top_k,
+        repetition_penalty=sampling_params.repetition_penalty,
+        guided_decoding=_response_format_to_guided_decoding_params(response_format),
+    )
+    return vllm_sampling_params
+
+
+def _convert_tools(
+    tools: Optional[List[ToolDefinition]] = None,
+) -> List[vllm.entrypoints.openai.protocol.ChatCompletionToolsParam]:
+    """
+    Convert the list of available tools from Llama Stack's format to vLLM's
+    version of OpenAI's format.
+    """
+    if tools is None:
+        return []
+
+    result = []
+    for t in tools:
+        if isinstance(t.tool_name, BuiltinTool):
+            raise NotImplementedError("Built-in tools not yet implemented")
+        if t.parameters is None:
+            parameters = None
+        else:  # if t.parameters is not None
+            # Convert the "required" flags to a list of required params
+            required_params = [k for k, v in t.parameters.items() if v.required]
+            parameters = {
+                "type": "object",  # OpenAI function parameters are always a JSON Schema object
+                "properties": {
+                    k: {"type": v.param_type, "description": v.description} for k, v in t.parameters.items()
+                },
+                "required": required_params,
+            }
+
+        function_def = vllm.entrypoints.openai.protocol.FunctionDefinition(
+            name=t.tool_name, description=t.description, parameters=parameters
+        )
+
+        # Every tool definition is double-boxed in a ChatCompletionToolsParam
+        result.append(vllm.entrypoints.openai.protocol.ChatCompletionToolsParam(function=function_def))
+    return result
+
+
+############################################################################
+# Class definitions go here


 class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
-    """Inference implementation for vLLM."""
+    """
+    vLLM-based inference model adapter for Llama Stack, with support for
+    multiple model identifiers that resolve to the same underlying model.
+
+    Requires the configuration parameters documented in the
+    :class:`VLLMConfig` class.
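+
+    The underlying vLLM engine is created lazily: nothing is loaded until the
+    first call to ``register_model()``, which also wraps the engine in vLLM's
+    OpenAI-compatible ``OpenAIServingChat`` layer.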
+    """
+
+    config: VLLMConfig
+    register_helper: ModelRegistryHelper
+    model_ids: set[str]
+    resolved_model_id: str | None
+    engine: AsyncLLMEngine | None
+    chat: OpenAIServingChat | None

     def __init__(self, config: VLLMConfig):
         self.config = config
         self.engine = None
+        _info(f"Config is: {self.config}")

-    async def initialize(self):
-        log.info("Initializing vLLM inference provider.")
+        self.register_helper = ModelRegistryHelper(build_model_aliases())
+        self.formatter = ChatFormat(Tokenizer.get_instance())

-        # Disable usage stats reporting. This would be a surprising thing for most
-        # people to find out was on by default.
-        # https://docs.vllm.ai/en/latest/serving/usage_stats.html
-        if "VLLM_NO_USAGE_STATS" not in os.environ:
-            os.environ["VLLM_NO_USAGE_STATS"] = "1"
+        # The following are initialized when paths are bound to this provider
+        self.resolved_model_id = None
+        self.model_ids = set()
+        self.engine = None
+        self.chat = None

-        model = resolve_model(self.config.model)
-        if model is None:
-            raise ValueError(f"Unknown model {self.config.model}")
+    ###########################################################################
+    # METHODS INHERITED FROM UNDOCUMENTED IMPLICIT MYSTERY BASE CLASS

-        if model.huggingface_repo is None:
-            raise ValueError(f"Model {self.config.model} needs a huggingface repo")
+    async def initialize(self) -> None:
+        """
+        Callback that is invoked through many levels of indirection during
+        provider class instantiation, sometime after __init__() is called
+        and before any model registration methods or methods connected to a
+        REST API are called.

-        # TODO -- there are a ton of options supported here ...
-        engine_args = AsyncEngineArgs(
-            model=model.huggingface_repo,
-            tokenizer=model.huggingface_repo,
-            tensor_parallel_size=self.config.tensor_parallel_size,
-            enforce_eager=self.config.enforce_eager,
-            gpu_memory_utilization=self.config.gpu_memory_utilization,
-            guided_decoding_backend="lm-format-enforcer",
-        )
+        It's not clear what assumptions the class can make about the platform's
+        initialization state here that can't be made during __init__(), and
+        vLLM can't be started until we know what model it's supposed to be
+        serving, so nothing happens here currently.
+        """
+        pass

-        self.engine = AsyncLLMEngine.from_engine_args(engine_args)
-
-    async def shutdown(self):
-        """Shut down the vLLM inference adapter."""
-        log.info("Shutting down vLLM inference provider.")
-        if self.engine:
-            self.engine.shutdown_background_loop()
+    ###########################################################################
+    # METHODS INHERITED FROM ModelsProtocolPrivate INTERFACE

     # Note that the return type of the superclass method is WRONG
     async def register_model(self, model: Model) -> Model:
@@ -111,33 +334,102 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         :returns: The input ``Model`` object. It may or may not be permissible to
         change fields before returning this object.
         """
-        log.info(f"Registering model {model.identifier} with vLLM inference provider.")
-        # The current version of this provided is hard-coded to serve only
-        # the model specified in the YAML config file.
-        configured_model = resolve_model(self.config.model)
-        registered_model = resolve_model(model.model_id)
+        _info(f"In register_model({model})")
+
+        # First attempt to interpret the model coordinates as a Llama model name
+        resolved_llama_model = resolve_model(model.provider_model_id)
+        if resolved_llama_model is not None:
+            # Load from Hugging Face repo into default local cache dir
+            resolved_model_id = resolved_llama_model.huggingface_repo
+        else:  # if resolved_llama_model is None
+            # Not a Llama model name. Pass the model id through to vLLM's loader
+            resolved_model_id = model.provider_model_id
+
+        _info(f"Resolved model id: {resolved_model_id}")
+
+        if self.resolved_model_id is not None:
+            if resolved_model_id != self.resolved_model_id:
+                raise ValueError(
+                    f"Attempted to serve two LLMs (ids "
+                    f"'{self.resolved_model_id}' and "
+                    f"'{resolved_model_id}') from one copy of "
+                    f"provider '{self}'. Use multiple "
+                    f"copies of the provider instead."
+                )
+            else:
+                # Model already loaded
+                return model
+
+        # If we get here, this is the first time registering a model.
+        # Preload so that the first inference request won't time out.
+        engine_args = AsyncEngineArgs(
+            model=resolved_model_id,
+            tokenizer=resolved_model_id,
+            tensor_parallel_size=self.config.tensor_parallel_size,
+            enforce_eager=self.config.enforce_eager,
+            gpu_memory_utilization=self.config.gpu_memory_utilization,
+            max_num_seqs=self.config.max_num_seqs,
+            max_model_len=self.config.max_model_len,
+        )
+        self.engine = AsyncLLMEngine.from_engine_args(engine_args)
+
+        # vLLM currently requires the user to specify the tool parser
+        # manually. To choose a tool parser, we need to determine what
+        # model architecture is being used. For now, we infer that
+        # information from what config class the model uses.
+        low_level_model_config = self.engine.engine.get_model_config()
+        hf_config = low_level_model_config.hf_config
+        hf_config_class_name = hf_config.__class__.__name__
+        if hf_config_class_name in CONFIG_TYPE_TO_TOOL_PARSER:
+            tool_parser = CONFIG_TYPE_TO_TOOL_PARSER[hf_config_class_name]
+        else:
+            # No info -- choose a default so we can at least attempt tool
+            # use.
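+            # (vLLM's "pythonic" parser expects tool calls emitted as
+            # Python-style function call expressions.)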
+ tool_parser = DEFAULT_TOOL_PARSER + _info(f"{hf_config_class_name=}") + _info(f"{tool_parser=}") + + # Wrap the lower-level engine in an OpenAI-compatible chat API + model_config = await self.engine.get_model_config() + self.chat = OpenAIServingChat( + engine_client=self.engine, + model_config=model_config, + base_model_paths=[ + # The layer below us will only see resolved model IDs + BaseModelPath(resolved_model_id, resolved_model_id) + ], + response_role="assistant", + lora_modules=None, + prompt_adapters=None, + request_logger=None, + chat_template=None, + enable_auto_tools=True, + tool_parser=tool_parser, + chat_template_content_format="auto", + ) + self.resolved_model_id = resolved_model_id + self.model_ids.add(model.model_id) + + _info(f"Finished preloading model: {resolved_model_id}") - if configured_model.core_model_id != registered_model.core_model_id: - raise ValueError( - f"Requested model '{model.identifier}' is different from " - f"model '{self.config.model}' that this provider " - f"is configured to serve" - ) return model - def _sampling_params(self, sampling_params: SamplingParams) -> VLLMSamplingParams: - if sampling_params is None: - return VLLMSamplingParams(max_tokens=self.config.max_tokens) - - options = get_sampling_options(sampling_params) - if "repeat_penalty" in options: - options["repetition_penalty"] = options["repeat_penalty"] - del options["repeat_penalty"] - - return VLLMSamplingParams(**options) - async def unregister_model(self, model_id: str) -> None: - pass + """ + Callback that is called when the server removes an inference endpoint + from an inference provider. + + The semantics of this callback are not clear. How should model_id + be interpreted? What happens to pending requests? + + :param model_id: Undocumented string parameter + + :returns: Nothing, at least according to the spec + """ + raise NotImplementedError() + + ########################################################################### + # METHODS INHERITED FROM Inference INTERFACE async def completion( self, @@ -147,100 +439,264 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate): response_format: Optional[ResponseFormat] = None, stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, - ) -> CompletionResponse | CompletionResponseStreamChunk: - raise NotImplementedError("Completion not implemented for vLLM") - - async def chat_completion( - self, - model_id: str, - messages: List[Message], - sampling_params: Optional[SamplingParams] = None, - tools: Optional[List[ToolDefinition]] = None, - tool_choice: Optional[ToolChoice] = ToolChoice.auto, - tool_prompt_format: Optional[ToolPromptFormat] = None, - response_format: Optional[ResponseFormat] = None, - stream: Optional[bool] = False, - logprobs: Optional[LogProbConfig] = None, - tool_config: Optional[ToolConfig] = None, - ) -> ChatCompletionResponse | ChatCompletionResponseStreamChunk: - if sampling_params is None: - sampling_params = SamplingParams() - assert self.engine is not None - - request = ChatCompletionRequest( - model=model_id, - messages=messages, - sampling_params=sampling_params, - tools=tools or [], - stream=stream, - logprobs=logprobs, - tool_config=tool_config, - ) - - log.info("Sampling params: %s", sampling_params) - request_id = _random_uuid() - - prompt = await chat_completion_request_to_prompt(request, self.config.model) - vllm_sampling_params = self._sampling_params(request.sampling_params) - results_generator = self.engine.generate(prompt, vllm_sampling_params, request_id) - if stream: - 
return self._stream_chat_completion(request, results_generator)
-        else:
-            return await self._nonstream_chat_completion(request, results_generator)
-
-    async def _nonstream_chat_completion(
-        self, request: ChatCompletionRequest, results_generator: AsyncGenerator
-    ) -> ChatCompletionResponse:
-        outputs = [o async for o in results_generator]
-        final_output = outputs[-1]
-
-        assert final_output is not None
-        outputs = final_output.outputs
-        finish_reason = outputs[-1].stop_reason
-        choice = OpenAICompatCompletionChoice(
-            finish_reason=finish_reason,
-            text="".join([output.text for output in outputs]),
-        )
-        response = OpenAICompatCompletionResponse(
-            choices=[choice],
-        )
-        return process_chat_completion_response(response, request)
-
-    async def _stream_chat_completion(
-        self, request: ChatCompletionRequest, results_generator: AsyncGenerator
-    ) -> AsyncGenerator:
-        tokenizer = Tokenizer.get_instance()
-
-        async def _generate_and_convert_to_openai_compat():
-            cur = []
-            async for chunk in results_generator:
-                if not chunk.outputs:
-                    log.warning("Empty chunk received")
-                    continue
-
-                output = chunk.outputs[-1]
-
-                new_tokens = output.token_ids[len(cur) :]
-                text = tokenizer.decode(new_tokens)
-                cur.extend(new_tokens)
-                choice = OpenAICompatCompletionChoice(
-                    finish_reason=output.finish_reason,
-                    text=text,
-                )
-                yield OpenAICompatCompletionResponse(
-                    choices=[choice],
-                )
-
-        stream = _generate_and_convert_to_openai_compat()
-        async for chunk in process_chat_completion_stream_response(stream, request):
-            yield chunk
+    ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
+        raise NotImplementedError()

     async def embeddings(
         self,
         model_id: str,
-        contents: List[str] | List[InterleavedContentItem],
-        text_truncation: Optional[TextTruncation] = TextTruncation.none,
-        output_dimension: Optional[int] = None,
-        task_type: Optional[EmbeddingTaskType] = None,
+        contents: List[InterleavedContent],  # type: ignore
     ) -> EmbeddingsResponse:
         raise NotImplementedError()
+
+    async def chat_completion(
+        self,
+        model_id: str,
+        messages: List[Message],  # type: ignore
+        sampling_params: Optional[SamplingParams] = None,
+        response_format: Optional[ResponseFormat] = None,  # type: ignore
+        tools: Optional[List[ToolDefinition]] = None,
+        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
+        sampling_params = sampling_params or SamplingParams()
+        if model_id not in self.model_ids:
+            raise ValueError(
+                f"This adapter is not registered to model id '{model_id}'. Registered IDs are: {self.model_ids}"
+            )
+
+        # Arguments to the vLLM call must be packaged as a ChatCompletionRequest
+        # dataclass.
+        # Note that this dataclass has the same name as a similar dataclass in
+        # Llama Stack.
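+        # Convert the Llama Stack messages to OpenAI-style dicts, folding any
+        # per-message "context" into the content field first.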
+ converted_messages = [ + await convert_message_to_openai_dict(_merge_context_into_content(m), download=True) for m in messages + ] + converted_sampling_params = _convert_sampling_params(sampling_params, response_format) + converted_tools = _convert_tools(tools) + + # Llama will try to use built-in tools with no tool catalog, so don't enable + # tool choice unless at least one tool is enabled. + converted_tool_choice = "none" + if tool_choice == ToolChoice.auto and tools is not None and len(tools) > 0: + converted_tool_choice = "auto" + + # TODO: Figure out what to do with the tool_prompt_format argument + # TODO: Convert logprobs argument + + chat_completion_request = vllm.entrypoints.openai.protocol.ChatCompletionRequest( + model=self.resolved_model_id, + messages=converted_messages, + tools=converted_tools, + tool_choice=converted_tool_choice, + stream=stream, + # tool_prompt_format=tool_prompt_format, + # logprobs=logprobs, + ) + + # vLLM's OpenAI-compatible APIs take sampling parameters as multiple + # keyword args instead of a vLLM SamplingParams object. Copy over + # all the parts that we currently convert from Llama Stack format. + for param_name in [ + "max_tokens", + "temperature", + "top_p", + "top_k", + "repetition_penalty", + ]: + setattr( + chat_completion_request, + param_name, + getattr(converted_sampling_params, param_name), + ) + + # Guided decoding parameters are further broken out + if converted_sampling_params.guided_decoding is not None: + g = converted_sampling_params.guided_decoding + chat_completion_request.guided_json = g.json + chat_completion_request.guided_regex = g.regex + chat_completion_request.guided_grammar = g.grammar + + _info(f"Converted request: {chat_completion_request}") + + vllm_result = await self.chat.create_chat_completion(chat_completion_request) + _info(f"Result from vLLM: {vllm_result}") + if isinstance(vllm_result, vllm.entrypoints.openai.protocol.ErrorResponse): + raise ValueError(f"Error from vLLM layer: {vllm_result}") + + # Return type depends on "stream" argument + if stream: + if not isinstance(vllm_result, AsyncGenerator): + raise TypeError(f"Unexpected result type {type(vllm_result)} for streaming inference call") + # vLLM client returns a stream of strings, which need to be parsed. + # Stream comes in the form of an async generator + return self._convert_streaming_results(vllm_result) + else: + if not isinstance(vllm_result, vllm.entrypoints.openai.protocol.ChatCompletionResponse): + raise TypeError(f"Unexpected result type {type(vllm_result)} for non-streaming inference call") + return self._convert_non_streaming_results(vllm_result) + + ########################################################################### + # INTERNAL METHODS + + def _convert_non_streaming_results( + self, vllm_result: vllm.entrypoints.openai.protocol.ChatCompletionResponse + ) -> ChatCompletionResponse: + """ + Subroutine to convert the non-streaming output of vLLM's OpenAI-compatible + API into an equivalent Llama Stack object. + + The result from vLLM's non-streaming API is a dataclass with + the same name as the Llama Stack ChatCompletionResponse dataclass, + but with more and different field names. We ignore the fields that + aren't currently present in the Llama Stack dataclass. + """ + + # There may be multiple responses, but we can only pass through the + # first one. 
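+        # (Llama Stack's ChatCompletionResponse carries a single
+        # completion_message, so any additional vLLM choices are dropped.)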
+ if len(vllm_result.choices) == 0: + raise ValueError("Don't know how to convert response object without any responses") + vllm_message = vllm_result.choices[0].message + + converted_message = CompletionMessage( + role=vllm_message.role, + # Llama Stack API won't accept None for content field. + content=("" if vllm_message.content is None else vllm_message.content), + stop_reason=_convert_finish_reason(vllm_result.choices[0].finish_reason), + tool_calls=[ + ToolCall( + call_id=t.id, + tool_name=t.function.name, + # vLLM function args come back as a string. Llama Stack expects JSON. + arguments=json.loads(t.function.arguments), + ) + for t in vllm_message.tool_calls + ], + ) + + # TODO: Convert logprobs + + _info(f"Converted message: {converted_message}") + + return ChatCompletionResponse( + completion_message=converted_message, + ) + + async def _convert_streaming_results(self, vllm_result: AsyncIterator) -> AsyncIterator: + """ + Subroutine that wraps the streaming outputs of vLLM's OpenAI-compatible + API into a second async iterator that returns Llama Stack objects. + + :param vllm_result: Stream of strings that need to be parsed + """ + # Tool calls come in pieces, but Llama Stack expects them in bigger + # chunks. We build up those chunks and output them at the end. + # This data structure holds the current set of partial tool calls. + index_to_tool_call: Dict[int, Dict] = dict() + + # The Llama Stack event stream must always start with a start event. + # Use an empty one to simplify logic below + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.start, + delta=TextDelta(text=""), + stop_reason=None, + ) + ) + + converted_stop_reason = None + async for chunk_str in vllm_result: + # Due to OpenAI compatibility, each event in the stream + # will start with "data: " and end with "\n\n". + _prefix = "data: " + _suffix = "\n\n" + if not chunk_str.startswith(_prefix) or not chunk_str.endswith(_suffix): + raise ValueError(f"Can't parse result string from vLLM: '{re.escape(chunk_str)}'") + + # In between the "data: " and newlines is an event record + data_str = chunk_str[len(_prefix) : -len(_suffix)] + + # The end of the stream is indicated with "[DONE]" + if data_str == "[DONE]": + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.complete, + delta=TextDelta(text=""), + stop_reason=converted_stop_reason, + ) + ) + return + + # Anything that is not "[DONE]" should be a JSON record + parsed_chunk = json.loads(data_str) + + # print(f"Parsed JSON event to:\n{json.dumps(parsed_chunk, indent=2)}") + + # The result may contain multiple completions, but Llama Stack APIs + # only support returning one. + first_choice = parsed_chunk["choices"][0] + converted_stop_reason = _convert_finish_reason(first_choice["finish_reason"]) + delta_record = first_choice["delta"] + + if "content" in delta_record: + # Text delta + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.progress, + delta=TextDelta(text=delta_record["content"]), + stop_reason=converted_stop_reason, + ) + ) + elif "tool_calls" in delta_record: + # Tool call(s). 
Llama Stack APIs do not have a clear way to return + # partial tool calls, so buffer until we get a "tool calls" stop reason + for tc in delta_record["tool_calls"]: + index = tc["index"] + if index not in index_to_tool_call: + # First time this tool call is showing up + index_to_tool_call[index] = dict() + tool_call = index_to_tool_call[index] + if "id" in tc: + tool_call["call_id"] = tc["id"] + if "function" in tc: + if "name" in tc["function"]: + tool_call["tool_name"] = tc["function"]["name"] + if "arguments" in tc["function"]: + # Arguments comes in as pieces of a string + if "arguments_str" not in tool_call: + tool_call["arguments_str"] = "" + tool_call["arguments_str"] += tc["function"]["arguments"] + else: + raise ValueError(f"Don't know how to parse event delta: {delta_record}") + + if first_choice["finish_reason"] == "tool_calls": + # Special OpenAI code for "tool calls complete". + # Output the buffered tool calls. Llama Stack requires a separate + # event per tool call. + for tool_call_record in index_to_tool_call.values(): + # Arguments come in as a string. Parse the completed string. + tool_call_record["arguments"] = json.loads(tool_call_record["arguments_str"]) + del tool_call_record["arguments_str"] + + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.progress, + delta=ToolCallDelta(content=tool_call_record, parse_status="succeeded"), + stop_reason=converted_stop_reason, + ) + ) + + # If we get here, we've lost the connection with the vLLM event stream + # before it ended normally. + raise ValueError("vLLM event stream ended without [DONE] message.")