diff --git a/llama_stack/providers/impls/vllm/vllm.py b/llama_stack/providers/impls/vllm/vllm.py
index 0f8e8d38c..748871b4e 100644
--- a/llama_stack/providers/impls/vllm/vllm.py
+++ b/llama_stack/providers/impls/vllm/vllm.py
@@ -10,39 +10,26 @@
 import uuid
 
 from typing import Any
 
 from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.datatypes import (
-    CompletionMessage,
-    InterleavedTextMedia,
-    Message,
-    StopReason,
-    ToolChoice,
-    ToolDefinition,
-    ToolPromptFormat,
-)
+from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_models.llama3.api.tokenizer import Tokenizer
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.sampling_params import SamplingParams
 
-from llama_stack.apis.inference import ChatCompletionRequest, Inference
+from llama_stack.apis.inference import *  # noqa: F403
 
-from llama_stack.apis.inference.inference import (
-    ChatCompletionResponse,
-    ChatCompletionResponseEvent,
-    ChatCompletionResponseEventType,
-    ChatCompletionResponseStreamChunk,
-    CompletionResponse,
-    CompletionResponseStreamChunk,
-    EmbeddingsResponse,
-    LogProbConfig,
-    ToolCallDelta,
-    ToolCallParseStatus,
-)
 from llama_stack.providers.utils.inference.augment_messages import (
-    augment_messages_for_tools,
+    chat_completion_request_to_prompt,
 )
+
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
+from llama_stack.providers.utils.inference.openai_compat import (
+    OpenAICompatCompletionChoice,
+    OpenAICompatCompletionResponse,
+    process_chat_completion_response,
+    process_chat_completion_stream_response,
+)
 
 from .config import VLLMConfig
@@ -72,10 +59,10 @@ def _vllm_sampling_params(sampling_params: Any) -> SamplingParams:
     if sampling_params.repetition_penalty > 0:
         kwargs["repetition_penalty"] = sampling_params.repetition_penalty
 
-    return SamplingParams().from_optional(**kwargs)
+    return SamplingParams(**kwargs)
 
 
-class VLLMInferenceImpl(Inference, ModelRegistryHelper):
+class VLLMInferenceImpl(ModelRegistryHelper, Inference):
    """Inference implementation for vLLM."""
 
     HF_MODEL_MAPPINGS = {
@@ -148,7 +135,7 @@ class VLLMInferenceImpl(Inference, ModelRegistryHelper):
         if self.engine:
             self.engine.shutdown_background_loop()
 
-    async def completion(
+    def completion(
         self,
         model: str,
         content: InterleavedTextMedia,
@@ -157,17 +144,16 @@ class VLLMInferenceImpl(Inference, ModelRegistryHelper):
         logprobs: LogProbConfig | None = None,
     ) -> CompletionResponse | CompletionResponseStreamChunk:
         log.info("vLLM completion")
-        messages = [Message(role="user", content=content)]
-        async for result in self.chat_completion(
+        messages = [UserMessage(content=content)]
+        return self.chat_completion(
             model=model,
             messages=messages,
             sampling_params=sampling_params,
             stream=stream,
             logprobs=logprobs,
-        ):
-            yield result
+        )
 
-    async def chat_completion(
+    def chat_completion(
         self,
         model: str,
         messages: list[Message],
@@ -194,159 +180,59 @@ class VLLMInferenceImpl(Inference, ModelRegistryHelper):
         )
 
         log.info("Sampling params: %s", sampling_params)
-        vllm_sampling_params = _vllm_sampling_params(sampling_params)
-
-        messages = augment_messages_for_tools(request)
-        log.info("Augmented messages: %s", messages)
-        prompt = "".join([str(message.content) for message in messages])
-
         request_id = _random_uuid()
+
+        prompt = chat_completion_request_to_prompt(request, self.formatter)
+        vllm_sampling_params = _vllm_sampling_params(request.sampling_params)
         results_generator = self.engine.generate(
             prompt, vllm_sampling_params, request_id
         )
-
-        if not stream:
-            # Non-streaming case
-            final_output = None
-            stop_reason = None
-            async for request_output in results_generator:
-                final_output = request_output
-                if stop_reason is None and request_output.outputs:
-                    reason = request_output.outputs[-1].stop_reason
-                    if reason == "stop":
-                        stop_reason = StopReason.end_of_turn
-                    elif reason == "length":
-                        stop_reason = StopReason.out_of_tokens
-
-            if not stop_reason:
-                stop_reason = StopReason.end_of_message
-
-            if final_output:
-                response = "".join([output.text for output in final_output.outputs])
-                yield ChatCompletionResponse(
-                    completion_message=CompletionMessage(
-                        content=response,
-                        stop_reason=stop_reason,
-                    ),
-                    logprobs=None,
-                )
+        if stream:
+            return self._stream_chat_completion(request, results_generator)
         else:
-            # Streaming case
-            yield ChatCompletionResponseStreamChunk(
-                event=ChatCompletionResponseEvent(
-                    event_type=ChatCompletionResponseEventType.start,
-                    delta="",
-                )
-            )
+            return self._nonstream_chat_completion(request, results_generator)
 
-            buffer = ""
-            last_chunk = ""
-            ipython = False
-            stop_reason = None
+    async def _nonstream_chat_completion(
+        self, request: ChatCompletionRequest, results_generator: AsyncGenerator
+    ) -> ChatCompletionResponse:
+        outputs = [o async for o in results_generator]
+        final_output = outputs[-1]
+        assert final_output is not None
+        outputs = final_output.outputs
+        finish_reason = outputs[-1].stop_reason
+        choice = OpenAICompatCompletionChoice(
+            finish_reason=finish_reason,
+            text="".join([output.text for output in outputs]),
+        )
+        response = OpenAICompatCompletionResponse(
+            choices=[choice],
+        )
+        return process_chat_completion_response(request, response, self.formatter)
+
+    async def _stream_chat_completion(
+        self, request: ChatCompletionRequest, results_generator: AsyncGenerator
+    ) -> AsyncGenerator:
+        async def _generate_and_convert_to_openai_compat():
             async for chunk in results_generator:
                 if not chunk.outputs:
                     log.warning("Empty chunk received")
                     continue
 
-                if chunk.outputs[-1].stop_reason:
-                    reason = chunk.outputs[-1].stop_reason
-                    if stop_reason is None and reason == "stop":
-                        stop_reason = StopReason.end_of_turn
-                    elif stop_reason is None and reason == "length":
-                        stop_reason = StopReason.out_of_tokens
-                    break
-
                 text = "".join([output.text for output in chunk.outputs])
-
-                # check if its a tool call ( aka starts with <|python_tag|> )
-                if not ipython and text.startswith("<|python_tag|>"):
-                    ipython = True
-                    yield ChatCompletionResponseStreamChunk(
-                        event=ChatCompletionResponseEvent(
-                            event_type=ChatCompletionResponseEventType.progress,
-                            delta=ToolCallDelta(
-                                content="",
-                                parse_status=ToolCallParseStatus.started,
-                            ),
-                        )
-                    )
-                    buffer += text
-                    continue
-
-                if ipython:
-                    if text == "<|eot_id|>":
-                        stop_reason = StopReason.end_of_turn
-                        text = ""
-                        continue
-                    elif text == "<|eom_id|>":
-                        stop_reason = StopReason.end_of_message
-                        text = ""
-                        continue
-
-                    buffer += text
-                    delta = ToolCallDelta(
-                        content=text,
-                        parse_status=ToolCallParseStatus.in_progress,
-                    )
-
-                    yield ChatCompletionResponseStreamChunk(
-                        event=ChatCompletionResponseEvent(
-                            event_type=ChatCompletionResponseEventType.progress,
-                            delta=delta,
-                            stop_reason=stop_reason,
-                        )
-                    )
-                else:
-                    last_chunk_len = len(last_chunk)
-                    last_chunk = text
-                    yield ChatCompletionResponseStreamChunk(
-                        event=ChatCompletionResponseEvent(
-                            event_type=ChatCompletionResponseEventType.progress,
-                            delta=text[last_chunk_len:],
-                            stop_reason=stop_reason,
-                        )
-                    )
-
-            if not stop_reason:
-                stop_reason = StopReason.end_of_message
-
-            # parse tool calls and report errors
-            message = self.formatter.decode_assistant_message_from_content(
-                buffer, stop_reason
-            )
-            parsed_tool_calls = len(message.tool_calls) > 0
-            if ipython and not parsed_tool_calls:
-                yield ChatCompletionResponseStreamChunk(
-                    event=ChatCompletionResponseEvent(
-                        event_type=ChatCompletionResponseEventType.progress,
-                        delta=ToolCallDelta(
-                            content="",
-                            parse_status=ToolCallParseStatus.failure,
-                        ),
-                        stop_reason=stop_reason,
-                    )
+                choice = OpenAICompatCompletionChoice(
+                    finish_reason=chunk.outputs[-1].stop_reason,
+                    text=text,
+                )
+                yield OpenAICompatCompletionResponse(
+                    choices=[choice],
                 )
 
-            for tool_call in message.tool_calls:
-                yield ChatCompletionResponseStreamChunk(
-                    event=ChatCompletionResponseEvent(
-                        event_type=ChatCompletionResponseEventType.progress,
-                        delta=ToolCallDelta(
-                            content=tool_call,
-                            parse_status=ToolCallParseStatus.success,
-                        ),
-                        stop_reason=stop_reason,
-                    )
-                )
-
-            yield ChatCompletionResponseStreamChunk(
-                event=ChatCompletionResponseEvent(
-                    event_type=ChatCompletionResponseEventType.complete,
-                    delta="",
-                    stop_reason=stop_reason,
-                )
-            )
+        stream = _generate_and_convert_to_openai_compat()
+        async for chunk in process_chat_completion_stream_response(
+            request, stream, self.formatter
+        ):
+            yield chunk
 
     async def embeddings(
         self, model: str, contents: list[InterleavedTextMedia]
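
Note on the structural change above: `completion()` and `chat_completion()` become plain `def` methods that return either a coroutine (`_nonstream_chat_completion`) or an async generator (`_stream_chat_completion`), so callers can `await` or `async for` the result directly and `completion()` can simply `return self.chat_completion(...)` instead of re-yielding each chunk. The snippet below is a minimal standalone sketch of that dispatch pattern only; `ToyProvider`, `_stream`, and `_nonstream` are hypothetical names, not the actual llama-stack or vLLM classes.

```python
import asyncio
from collections.abc import AsyncGenerator


class ToyProvider:
    def chat_completion(self, prompt: str, stream: bool = False):
        # Plain method: choose the implementation and hand back its
        # coroutine (non-streaming) or async generator (streaming).
        if stream:
            return self._stream(prompt)
        return self._nonstream(prompt)

    async def _nonstream(self, prompt: str) -> str:
        # Coroutine: the caller awaits it.
        return f"echo: {prompt}"

    async def _stream(self, prompt: str) -> AsyncGenerator[str, None]:
        # Async generator: the caller iterates it with `async for`.
        for token in prompt.split():
            yield token


async def main() -> None:
    provider = ToyProvider()
    # Non-streaming: the returned coroutine is awaited.
    print(await provider.chat_completion("hello world"))
    # Streaming: the returned async generator is iterated.
    async for tok in provider.chat_completion("hello world", stream=True):
        print(tok)


asyncio.run(main())
```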