From fcb87faa3697fd69e2d4cbb2daa616757500e660 Mon Sep 17 00:00:00 2001
From: Fred Reiss
Date: Fri, 24 Jan 2025 10:56:47 -0800
Subject: [PATCH] Add completion API

---
 .../providers/inline/inference/vllm/vllm.py | 110 ++++++++++++++++--
 1 file changed, 98 insertions(+), 12 deletions(-)

diff --git a/llama_stack/providers/inline/inference/vllm/vllm.py b/llama_stack/providers/inline/inference/vllm/vllm.py
index e40ab5fdc..4e1fc853d 100644
--- a/llama_stack/providers/inline/inference/vllm/vllm.py
+++ b/llama_stack/providers/inline/inference/vllm/vllm.py
@@ -8,6 +8,7 @@ import datetime
 import json
 import logging
 import re
+import uuid
 from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
 
 # These vLLM modules contain names that overlap with Llama Stack names,
@@ -114,6 +115,10 @@ def _info(msg: str):
     # logger.info(msg)
 
 
+def _random_uuid_str() -> str:
+    return str(uuid.uuid4().hex)
+
+
 def _merge_context_into_content(message: Message) -> Message:  # type: ignore
     """
     Merge the ``context`` field of a Llama Stack ``Message`` object into
@@ -169,7 +174,10 @@ def _response_format_to_guided_decoding_params(
     layer of vLLM.
     """
     if response_format is None:
-        return vllm.sampling_params.GuidedDecodingParams()
+        # As of vLLM 0.6.3, the default constructor for GuidedDecodingParams()
+        # returns an invalid value that crashes the executor on some code
+        # paths. Use ``None`` instead.
+        return None
 
     # Llama Stack currently implements fewer types of constrained
     # decoding than vLLM does. Translate the types that exist and
@@ -440,7 +448,93 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
-        raise NotImplementedError()
+        if model_id not in self.model_ids:
+            raise ValueError(
+                f"This adapter is not registered to model id '{model_id}'. Registered IDs are: {self.model_ids}"
+            )
+        if not isinstance(content, str):
+            raise NotImplementedError("Multimodal input not currently supported")
+        if sampling_params is None:
+            sampling_params = SamplingParams()
+        if logprobs is not None:
+            raise NotImplementedError("logprobs argument not currently implemented")
+
+        converted_sampling_params = _convert_sampling_params(sampling_params, response_format)
+
+        if stream:
+            return self._streaming_completion(content, converted_sampling_params)
+        else:
+            streaming_result = None
+            async for streaming_result in self._streaming_completion(content, converted_sampling_params):
+                pass
+            return CompletionResponse(
+                content=streaming_result.delta,
+                stop_reason=streaming_result.stop_reason,
+                logprobs=streaming_result.logprobs,
+            )
+
+    async def _streaming_completion(
+        self, content: str, sampling_params: vllm.SamplingParams
+    ) -> AsyncIterator[CompletionResponseStreamChunk]:
+        """Internal implementation of the :func:`completion()` API for the streaming
+        case. Assumes that arguments have been validated upstream.
+
+        :param content: Must be a string
+        :param sampling_params: Parameters from the public API's ``response_format``
+            and ``sampling_params`` arguments, converted to vLLM format
+        """
+        # We call the vLLM generate() API directly instead of going through the
+        # OpenAI-compatible layer, because doing so simplifies the code here.
+
+        # The vLLM engine requires a unique identifier for each call to generate()
+        request_id = _random_uuid_str()
+
+        # The vLLM generate() API is streaming-only and returns an async generator.
+        # The generator returns objects of type vllm.RequestOutput.
+        results_generator = self.engine.generate(content, sampling_params, request_id)
+
+        # Need to know the model's EOS token ID for the conversion code below.
+        # This information is buried pretty deeply.
+        eos_token_id = self.engine.engine.tokenizer.tokenizer.eos_token_id
+
+        request_output: vllm.RequestOutput = None
+        async for request_output in results_generator:
+            # Check for weird inference failures
+            if request_output.outputs is None or len(request_output.outputs) == 0:
+                # This case should never happen
+                raise ValueError("Inference produced empty result")
+
+            # If we get here, then request_output contains the latest output of the
+            # generate() call. There should be one or more output chunks.
+            completion_string = "".join([output.text for output in request_output.outputs])
+
+            # The final output chunk should be labeled with the reason that the
+            # overall generate() call completed.
+            stop_reason_str = request_output.outputs[-1].stop_reason
+            if stop_reason_str is None:
+                stop_reason = None  # Still going
+            elif stop_reason_str == "stop":
+                stop_reason = StopReason.end_of_turn
+            elif stop_reason_str == "length":
+                stop_reason = StopReason.out_of_tokens
+            else:
+                raise ValueError(f"Unrecognized stop reason '{stop_reason_str}'")
+
+            # _info(f"completion string: {completion_string}")
+            # _info(f"stop reason: {stop_reason_str}")
+            # _info(f"completion tokens: {completion_tokens}")
+
+            # vLLM's protocol outputs the stop token, then sets end of message
+            # on the next step for some reason.
+            if request_output.outputs[-1].token_ids[-1] == eos_token_id:
+                stop_reason = StopReason.end_of_message
+
+            yield CompletionResponseStreamChunk(delta=completion_string, stop_reason=stop_reason)
+
+        # Llama Stack requires that the last chunk have a stop reason, but
+        # vLLM doesn't always provide one if it runs out of tokens.
+        if stop_reason is None:
+            yield CompletionResponseStreamChunk(delta=completion_string, stop_reason=StopReason.out_of_tokens)
 
     async def embeddings(
         self,
@@ -460,21 +554,13 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
-        # model_id: str,
-        # messages: List[Message],  # type: ignore
-        # sampling_params: Optional[SamplingParams] = SamplingParams(),
-        # tools: Optional[List[ToolDefinition]] = None,
-        # tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        # tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
-        # response_format: Optional[ResponseFormat] = None,
-        # stream: Optional[bool] = False,
-        # logprobs: Optional[LogProbConfig] = None,
     ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
-        sampling_params = sampling_params or SamplingParams()
         if model_id not in self.model_ids:
             raise ValueError(
                 f"This adapter is not registered to model id '{model_id}'. Registered IDs are: {self.model_ids}"
            )
+        if logprobs is not None:
+            raise NotImplementedError("logprobs argument not currently implemented")
 
         # Arguments to the vLLM call must be packaged as a ChatCompletionRequest
         # dataclass.
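
For context, the sketch below shows one way a caller might exercise the completion() entry point that this patch implements, in both non-streaming and streaming modes. It is illustrative only and not part of the patch: `provider` stands for an already-initialized VLLMInferenceImpl with a model registered under MODEL_ID, and the MODEL_ID value and the SamplingParams import path are assumptions.

# Illustrative usage sketch (not part of the patch). Assumes `provider` is an
# already-initialized VLLMInferenceImpl and MODEL_ID is one of its registered
# model ids; the SamplingParams import path is an assumption as well.
from llama_stack.apis.inference import SamplingParams

MODEL_ID = "my-registered-model"  # placeholder


async def demo(provider) -> None:
    # Non-streaming: completion() drains its internal streaming generator and
    # returns a single CompletionResponse.
    response = await provider.completion(
        model_id=MODEL_ID,
        content="The capital of France is",
        sampling_params=SamplingParams(),
        stream=False,
    )
    print(response.content, response.stop_reason)

    # Streaming: completion() returns an async iterator of
    # CompletionResponseStreamChunk objects; per the comments in the patch,
    # the final chunk always carries a stop_reason.
    chunks = await provider.completion(
        model_id=MODEL_ID,
        content="The capital of France is",
        sampling_params=SamplingParams(),
        stream=True,
    )
    async for chunk in chunks:
        print(chunk.delta, chunk.stop_reason)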