From 5d54c2ee70a7a3031cdaf80922db2eed908aefdf Mon Sep 17 00:00:00 2001
From: Fred Reiss
Date: Tue, 4 Feb 2025 13:09:49 -0800
Subject: [PATCH] Use Llama Stack template when streaming

---
 .../providers/inline/inference/vllm/vllm.py | 157 +++++++++++-------
 1 file changed, 98 insertions(+), 59 deletions(-)

diff --git a/llama_stack/providers/inline/inference/vllm/vllm.py b/llama_stack/providers/inline/inference/vllm/vllm.py
index 986eb8068..06abd0290 100644
--- a/llama_stack/providers/inline/inference/vllm/vllm.py
+++ b/llama_stack/providers/inline/inference/vllm/vllm.py
@@ -68,7 +68,12 @@ from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
     ModelsProtocolPrivate,
 )
-from llama_stack.providers.utils.inference.openai_compat import get_stop_reason
+from llama_stack.providers.utils.inference.openai_compat import (
+    OpenAICompatCompletionChoice,
+    OpenAICompatCompletionResponse,
+    get_stop_reason,
+    process_chat_completion_stream_response,
+)
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
 )
@@ -92,7 +97,7 @@ logger = logging.getLogger(__name__)
 
 # Adjust logging parameters from Python code. This appears to be the standard way to control
 # logging in Llama Stack.
-logger.setLevel(logging.DEBUG)
+logger.setLevel(logging.INFO)
 stderr_handler = logging.StreamHandler()
 stderr_handler.setFormatter(logging.Formatter("%(asctime)s: %(filename)s [%(levelname)s] %(message)s"))
 logger.addHandler(stderr_handler)
@@ -266,7 +271,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         else:  # if resolved_llama_model is None
             # Not a Llama model name. Pass the model id through to vLLM's loader
             resolved_model_id = model.provider_model_id
-            is_meta_llama_model = True
+            is_meta_llama_model = False
 
         logger.info(f"Model id {model} resolved to {resolved_model_id}")
 
@@ -282,6 +287,8 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
             self.model_ids.add(model.model_id)
             return model
 
+        if is_meta_llama_model:
+            logger.info(f"Model {resolved_model_id} is a Meta Llama model.")
         self.is_meta_llama_model = is_meta_llama_model
 
         logger.info(f"Preloading model: {resolved_model_id}")
@@ -443,11 +450,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
                 f"Routing {self.resolved_model_id} chat completion through "
                 f"Llama Stack's templating layer instead of vLLM's."
             )
-            if stream:
-                # return self._chat_completion_for_meta_llama_streaming(request)
-                pass  # Use vLLM until the above method is implemented.
-            else:
-                return await self._chat_completion_for_meta_llama_non_streaming(request)
+            return await self._chat_completion_for_meta_llama(request)
 
         logger.debug(f"{self.resolved_model_id} is not a Meta Llama model")
 
@@ -523,15 +526,19 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
 
             # The final output chunk should be labeled with the reason that the overall generate()
             # call completed.
-            stop_reason_str = output.stop_reason
-            if stop_reason_str is None:
+            logger.debug(f"{output.stop_reason=}; {type(output.stop_reason)=}")
+            if output.stop_reason is None:
                 stop_reason = None  # Still going
-            elif stop_reason_str == "stop":
+            elif output.stop_reason == "stop":
                 stop_reason = StopReason.end_of_turn
-            elif stop_reason_str == "length":
+            elif output.stop_reason == "length":
                 stop_reason = StopReason.out_of_tokens
+            elif isinstance(output.stop_reason, int):
+                # If the model config specifies multiple end-of-sequence tokens, then vLLM
+                # will return the token ID of the EOS token in the stop_reason field.
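+                # Map any such integer token ID to an ordinary end-of-turn stop.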
+                stop_reason = StopReason.end_of_turn
             else:
-                raise ValueError(f"Unrecognized stop reason '{stop_reason_str}'")
+                raise ValueError(f"Unrecognized stop reason '{output.stop_reason}'")
 
             # vLLM's protocol outputs the stop token, then sets end of message on the next step for
             # some reason.
@@ -549,49 +556,6 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
             logprobs=logprobs,
         )
 
-    async def _chat_completion_for_meta_llama_non_streaming(
-        self, request: ChatCompletionRequest
-    ) -> ChatCompletionResponse:
-        """
-        Subroutine that routes chat completions for Meta Llama models through Llama Stack's
-        chat template instead of using vLLM's version of that template. The Llama Stack version
-        of the chat template currently produces more reliable outputs.
-
-        Once vLLM's support for Meta Llama models has matured more, we should consider routing
-        Meta Llama requests through the vLLM chat completions API instead of using this method.
-        """
-
-        logger.debug("Routing request through Llama Stack templates.")
-
-        formatter = ChatFormat(Tokenizer.get_instance())
-
-        # Note that this function call modifies `request` in place.
-        prompt = await chat_completion_request_to_prompt(request, self.resolved_model_id, formatter)
-
-        model_id = list(self.model_ids)[0]  # Any model ID will do here
-        completion_response = await self.completion(
-            model_id=model_id,
-            content=prompt,
-            sampling_params=request.sampling_params,
-            response_format=request.response_format,
-            stream=False,
-            logprobs=request.logprobs,
-        )
-        if not isinstance(completion_response, CompletionResponse):  # Sanity check
-            raise TypeError(f"Unexpected type '{type(completion_response)}' for completion response.")
-
-        raw_message = formatter.decode_assistant_message_from_content(
-            completion_response.content, completion_response.stop_reason
-        )
-        return ChatCompletionResponse(
-            completion_message=CompletionMessage(
-                content=raw_message.content,
-                stop_reason=raw_message.stop_reason,
-                tool_calls=raw_message.tool_calls,
-            ),
-            logprobs=completion_response.logprobs,
-        )
-
     def _convert_non_streaming_results(
         self, vllm_result: vllm.entrypoints.openai.protocol.ChatCompletionResponse
     ) -> ChatCompletionResponse:
@@ -634,7 +598,9 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
             completion_message=converted_message,
         )
 
-    def _chat_completion_for_meta_llama_streaming(self, request: ChatCompletionRequest) -> AsyncIterator:
+    async def _chat_completion_for_meta_llama(
+        self, request: ChatCompletionRequest
+    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
         """
         Subroutine that routes chat completions for Meta Llama models through Llama Stack's
         chat template instead of using vLLM's version of that template. The Llama Stack version
@@ -643,9 +609,82 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         of the chat template currently produces more reliable outputs.
 
         Once vLLM's support for Meta Llama models has matured more, we should consider routing
         Meta Llama requests through the vLLM chat completions API instead of using this method.
         """
-        logger.debug("Routing streaming request through Llama Stack templates.")
+        formatter = ChatFormat(Tokenizer.get_instance())
 
-        raise NotImplementedError()
+        # Note that this function call modifies `request` in place.
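+        # The rendered prompt then flows through the plain-text completion path below.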
+        prompt = await chat_completion_request_to_prompt(request, self.resolved_model_id, formatter)
+
+        model_id = list(self.model_ids)[0]  # Any model ID will do here
+        completion_response_or_iterator = await self.completion(
+            model_id=model_id,
+            content=prompt,
+            sampling_params=request.sampling_params,
+            response_format=request.response_format,
+            stream=request.stream,
+            logprobs=request.logprobs,
+        )
+
+        if request.stream:
+            if not isinstance(completion_response_or_iterator, AsyncIterator):
+                raise TypeError(
+                    f"Received unexpected result type {type(completion_response_or_iterator)} for streaming request."
+                )
+            return self._chat_completion_for_meta_llama_streaming(formatter, completion_response_or_iterator)
+
+        # else: request.stream is False, so this is a non-streaming request
+        if not isinstance(completion_response_or_iterator, CompletionResponse):
+            raise TypeError(
+                f"Received unexpected result type {type(completion_response_or_iterator)} for non-streaming request."
+            )
+        completion_response: CompletionResponse = completion_response_or_iterator
+        raw_message = formatter.decode_assistant_message_from_content(
+            completion_response.content, completion_response.stop_reason
+        )
+        return ChatCompletionResponse(
+            completion_message=CompletionMessage(
+                content=raw_message.content,
+                stop_reason=raw_message.stop_reason,
+                tool_calls=raw_message.tool_calls,
+            ),
+            logprobs=completion_response.logprobs,
+        )
+
+    async def _chat_completion_for_meta_llama_streaming(
+        self, formatter: ChatFormat, results_iterator: AsyncIterator
+    ) -> AsyncIterator:
+        """
+        Streaming half of :func:`_chat_completion_for_meta_llama()`. This must be a separate
+        method: because it contains a ``yield``, Python treats it as an async generator, while
+        the caller needs to remain an ordinary coroutine that can ``return`` a response object.
+        """
+
+        # Convert to OpenAI format, then use shared code to convert to Llama Stack format.
+        async def _generate_and_convert_to_openai_compat():
+            chunk: CompletionResponseStreamChunk  # Make Pylance happy
+            last_text_len = 0
+            async for chunk in results_iterator:
+                if chunk.stop_reason == StopReason.end_of_turn:
+                    finish_reason = "stop"
+                elif chunk.stop_reason == StopReason.end_of_message:
+                    finish_reason = "eos"
+                elif chunk.stop_reason == StopReason.out_of_tokens:
+                    finish_reason = "length"
+                else:
+                    finish_reason = None
+
+                # `chunk.delta` holds the cumulative text so far; convert it back to an actual delta
+                text_delta = chunk.delta[last_text_len:]
+                last_text_len = len(chunk.delta)
+
+                logger.debug(f"{text_delta=}; {finish_reason=}")
+
+                yield OpenAICompatCompletionResponse(
+                    choices=[OpenAICompatCompletionChoice(finish_reason=finish_reason, text=text_delta)]
+                )
+
+        stream = _generate_and_convert_to_openai_compat()
+        async for chunk in process_chat_completion_stream_response(stream, formatter):
+            logger.debug(f"Returning chunk: {chunk}")
+            yield chunk
 
     async def _convert_streaming_results(self, vllm_result: AsyncIterator) -> AsyncIterator:
         """