diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index bd3375baf..482e6fa97 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -46,7 +46,6 @@ from llama_stack.providers.utils.inference.openai_compat import (
     process_completion_stream_response,
 )
 from llama_stack.providers.utils.inference.prompt_adapter import (
-    chat_completion_request_to_prompt,
     completion_request_to_prompt,
     content_has_media,
     interleaved_content_as_str,
@@ -142,10 +141,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         self, request: ChatCompletionRequest, client: OpenAI
     ) -> ChatCompletionResponse:
         params = await self._get_params(request)
-        if "messages" in params:
-            r = client.chat.completions.create(**params)
-        else:
-            r = client.completions.create(**params)
+        r = client.chat.completions.create(**params)
         return process_chat_completion_response(r, self.formatter)
 
     async def _stream_chat_completion(self, request: ChatCompletionRequest, client: OpenAI) -> AsyncGenerator:
@@ -154,10 +150,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         # TODO: Can we use client.completions.acreate() or maybe there is another way to directly create an async
         # generator so this wrapper is not necessary?
         async def _to_async_generator():
-            if "messages" in params:
-                s = client.chat.completions.create(**params)
-            else:
-                s = client.completions.create(**params)
+            s = client.chat.completions.create(**params)
             for chunk in s:
                 yield chunk
 
@@ -200,20 +193,11 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             options["max_tokens"] = self.config.max_tokens
 
         input_dict = {}
-        media_present = request_has_media(request)
+
         if isinstance(request, ChatCompletionRequest):
-            if media_present:
-                input_dict["messages"] = [
-                    await convert_message_to_openai_dict(m, download=True) for m in request.messages
-                ]
-            else:
-                input_dict["prompt"] = await chat_completion_request_to_prompt(
-                    request,
-                    self.register_helper.get_llama_model(request.model),
-                    self.formatter,
-                )
+            input_dict["messages"] = [await convert_message_to_openai_dict(m, download=True) for m in request.messages]
         else:
-            assert not media_present, "vLLM does not support media for Completion requests"
+            assert not request_has_media(request), "vLLM does not support media for Completion requests"
             input_dict["prompt"] = await completion_request_to_prompt(
                 request,
                 self.formatter,