From 0a0ee5ca96ab30d893db5b629435b6e90fee39fe Mon Sep 17 00:00:00 2001
From: Yuan Tang
Date: Thu, 6 Feb 2025 13:45:19 -0500
Subject: [PATCH] Fix incorrect handling of chat completion endpoint in remote::vLLM (#951)

# What does this PR do?

Fixes https://github.com/meta-llama/llama-stack/issues/949.

## Test Plan

Verified that the correct chat completion endpoint is called after the change.

Llama Stack server:
```
INFO: ::1:32838 - "POST /v1/inference/chat-completion HTTP/1.1" 200 OK
18:36:28.187 [END] /v1/inference/chat-completion [StatusCode.OK] (1276.12ms)
```

vLLM server:
```
INFO: ::1:36866 - "POST /v1/chat/completions HTTP/1.1" 200 OK
```

```bash
LLAMA_STACK_BASE_URL=http://localhost:5002 pytest -s -v tests/client-sdk/inference/test_inference.py -k "test_image_chat_completion_base64 or test_image_chat_completion_non_streaming or test_image_chat_completion_streaming"
================================================================== test session starts ===================================================================
platform linux -- Python 3.10.16, pytest-8.3.4, pluggy-1.5.0 -- /home/yutang/.conda/envs/distribution-myenv/bin/python3.10
cachedir: .pytest_cache
rootdir: /home/yutang/repos/llama-stack
configfile: pyproject.toml
plugins: anyio-4.8.0
collected 16 items / 12 deselected / 4 selected

tests/client-sdk/inference/test_inference.py::test_image_chat_completion_non_streaming[meta-llama/Llama-3.2-11B-Vision-Instruct] PASSED
tests/client-sdk/inference/test_inference.py::test_image_chat_completion_streaming[meta-llama/Llama-3.2-11B-Vision-Instruct] PASSED
tests/client-sdk/inference/test_inference.py::test_image_chat_completion_base64[meta-llama/Llama-3.2-11B-Vision-Instruct-url] PASSED
tests/client-sdk/inference/test_inference.py::test_image_chat_completion_base64[meta-llama/Llama-3.2-11B-Vision-Instruct-data] PASSED
```
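For context, a minimal sketch of the endpoint routing this fix relies on, using the `openai` Python client directly against a vLLM server's OpenAI-compatible API. This sketch is not taken from the patch itself; the base URL, API key, and model name are illustrative placeholders.

```python
# Sketch only: shows why the adapter must use client.chat.completions.create()
# for ChatCompletionRequest. Base URL, API key, and model name are placeholders.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# Sent as POST /v1/chat/completions on the vLLM server, the endpoint the
# remote::vLLM adapter now always uses for chat completion requests.
chat = client.chat.completions.create(
    model="meta-llama/Llama-3.2-11B-Vision-Instruct",
    messages=[{"role": "user", "content": "Hello, what can you do?"}],
)

# Sent as POST /v1/completions, still used for plain CompletionRequest.
text = client.completions.create(
    model="meta-llama/Llama-3.2-11B-Vision-Instruct",
    prompt="Complete this sentence: the capital of France is",
)
```

Before this change, a chat completion request without media was rendered into a prompt string and sent through `client.completions.create()`, i.e. to `/v1/completions`, which is the mismatch reported in #949.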
Signed-off-by: Yuan Tang
---
 .../providers/remote/inference/vllm/vllm.py | 26 ++++-----------------
 1 file changed, 5 insertions(+), 21 deletions(-)

diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index bd3375baf..482e6fa97 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -46,7 +46,6 @@ from llama_stack.providers.utils.inference.openai_compat import (
     process_completion_stream_response,
 )
 from llama_stack.providers.utils.inference.prompt_adapter import (
-    chat_completion_request_to_prompt,
     completion_request_to_prompt,
     content_has_media,
     interleaved_content_as_str,
@@ -142,10 +141,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         self, request: ChatCompletionRequest, client: OpenAI
     ) -> ChatCompletionResponse:
         params = await self._get_params(request)
-        if "messages" in params:
-            r = client.chat.completions.create(**params)
-        else:
-            r = client.completions.create(**params)
+        r = client.chat.completions.create(**params)
         return process_chat_completion_response(r, self.formatter)
 
     async def _stream_chat_completion(self, request: ChatCompletionRequest, client: OpenAI) -> AsyncGenerator:
@@ -154,10 +150,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         # TODO: Can we use client.completions.acreate() or maybe there is another way to directly create an async
         # generator so this wrapper is not necessary?
         async def _to_async_generator():
-            if "messages" in params:
-                s = client.chat.completions.create(**params)
-            else:
-                s = client.completions.create(**params)
+            s = client.chat.completions.create(**params)
             for chunk in s:
                 yield chunk
 
@@ -200,20 +193,11 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             options["max_tokens"] = self.config.max_tokens
 
         input_dict = {}
-        media_present = request_has_media(request)
+
         if isinstance(request, ChatCompletionRequest):
-            if media_present:
-                input_dict["messages"] = [
-                    await convert_message_to_openai_dict(m, download=True) for m in request.messages
-                ]
-            else:
-                input_dict["prompt"] = await chat_completion_request_to_prompt(
-                    request,
-                    self.register_helper.get_llama_model(request.model),
-                    self.formatter,
-                )
+            input_dict["messages"] = [await convert_message_to_openai_dict(m, download=True) for m in request.messages]
         else:
-            assert not media_present, "vLLM does not support media for Completion requests"
+            assert not request_has_media(request), "vLLM does not support media for Completion requests"
             input_dict["prompt"] = await completion_request_to_prompt(
                 request,
                 self.formatter,