Fix incorrect handling of chat completion endpoint in remote::vLLM (#951)

# What does this PR do?

Fixes https://github.com/meta-llama/llama-stack/issues/949.


## Test Plan

Verified that the correct chat completion endpoint is called after the
change.

Llama Stack server:
```
INFO:     ::1:32838 - "POST /v1/inference/chat-completion HTTP/1.1" 200 OK
18:36:28.187 [END] /v1/inference/chat-completion [StatusCode.OK] (1276.12ms)

```

vLLM server:
```
INFO:     ::1:36866 - "POST /v1/chat/completions HTTP/1.1" 200 OK
```

```bash
LLAMA_STACK_BASE_URL=http://localhost:5002 pytest -s -v tests/client-sdk/inference/test_inference.py -k "test_image_chat_completion_base64 or test_image_chat_completion_non_streaming or test_image_chat_completion_streaming"
================================================================== test session starts ===================================================================
platform linux -- Python 3.10.16, pytest-8.3.4, pluggy-1.5.0 -- /home/yutang/.conda/envs/distribution-myenv/bin/python3.10
cachedir: .pytest_cache
rootdir: /home/yutang/repos/llama-stack
configfile: pyproject.toml
plugins: anyio-4.8.0
collected 16 items / 12 deselected / 4 selected                                                                                                          

tests/client-sdk/inference/test_inference.py::test_image_chat_completion_non_streaming[meta-llama/Llama-3.2-11B-Vision-Instruct] PASSED
tests/client-sdk/inference/test_inference.py::test_image_chat_completion_streaming[meta-llama/Llama-3.2-11B-Vision-Instruct] PASSED
tests/client-sdk/inference/test_inference.py::test_image_chat_completion_base64[meta-llama/Llama-3.2-11B-Vision-Instruct-url] PASSED
tests/client-sdk/inference/test_inference.py::test_image_chat_completion_base64[meta-llama/Llama-3.2-11B-Vision-Instruct-data] PASSED
```

Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>
This commit is contained in:
Yuan Tang 2025-02-06 13:45:19 -05:00 committed by GitHub
parent 09ed0e9c9f
commit 0a0ee5ca96
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -46,7 +46,6 @@ from llama_stack.providers.utils.inference.openai_compat import (
process_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
completion_request_to_prompt,
content_has_media,
interleaved_content_as_str,
@ -142,10 +141,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
self, request: ChatCompletionRequest, client: OpenAI
) -> ChatCompletionResponse:
params = await self._get_params(request)
if "messages" in params:
r = client.chat.completions.create(**params)
else:
r = client.completions.create(**params)
r = client.chat.completions.create(**params)
return process_chat_completion_response(r, self.formatter)
async def _stream_chat_completion(self, request: ChatCompletionRequest, client: OpenAI) -> AsyncGenerator:
@ -154,10 +150,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
# TODO: Can we use client.completions.acreate() or maybe there is another way to directly create an async
# generator so this wrapper is not necessary?
async def _to_async_generator():
if "messages" in params:
s = client.chat.completions.create(**params)
else:
s = client.completions.create(**params)
s = client.chat.completions.create(**params)
for chunk in s:
yield chunk
@ -200,20 +193,11 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
options["max_tokens"] = self.config.max_tokens
input_dict = {}
media_present = request_has_media(request)
if isinstance(request, ChatCompletionRequest):
if media_present:
input_dict["messages"] = [
await convert_message_to_openai_dict(m, download=True) for m in request.messages
]
else:
input_dict["prompt"] = await chat_completion_request_to_prompt(
request,
self.register_helper.get_llama_model(request.model),
self.formatter,
)
input_dict["messages"] = [await convert_message_to_openai_dict(m, download=True) for m in request.messages]
else:
assert not media_present, "vLLM does not support media for Completion requests"
assert not request_has_media(request), "vLLM does not support media for Completion requests"
input_dict["prompt"] = await completion_request_to_prompt(
request,
self.formatter,