Fix incorrect handling of chat completion endpoint in remote::vLLM (#951)
# What does this PR do?

Fixes https://github.com/meta-llama/llama-stack/issues/949.

## Test Plan

Verified that the correct chat completion endpoint is called after the change.

Llama Stack server:

```
INFO: ::1:32838 - "POST /v1/inference/chat-completion HTTP/1.1" 200 OK
18:36:28.187 [END] /v1/inference/chat-completion [StatusCode.OK] (1276.12ms)
```

vLLM server:

```
INFO: ::1:36866 - "POST /v1/chat/completions HTTP/1.1" 200 OK
```

```bash
LLAMA_STACK_BASE_URL=http://localhost:5002 pytest -s -v tests/client-sdk/inference/test_inference.py -k "test_image_chat_completion_base64 or test_image_chat_completion_non_streaming or test_image_chat_completion_streaming"
================================================================== test session starts ===================================================================
platform linux -- Python 3.10.16, pytest-8.3.4, pluggy-1.5.0 -- /home/yutang/.conda/envs/distribution-myenv/bin/python3.10
cachedir: .pytest_cache
rootdir: /home/yutang/repos/llama-stack
configfile: pyproject.toml
plugins: anyio-4.8.0
collected 16 items / 12 deselected / 4 selected

tests/client-sdk/inference/test_inference.py::test_image_chat_completion_non_streaming[meta-llama/Llama-3.2-11B-Vision-Instruct] PASSED
tests/client-sdk/inference/test_inference.py::test_image_chat_completion_streaming[meta-llama/Llama-3.2-11B-Vision-Instruct] PASSED
tests/client-sdk/inference/test_inference.py::test_image_chat_completion_base64[meta-llama/Llama-3.2-11B-Vision-Instruct-url] PASSED
tests/client-sdk/inference/test_inference.py::test_image_chat_completion_base64[meta-llama/Llama-3.2-11B-Vision-Instruct-data] PASSED
```

Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>
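For context, here is a minimal sketch (not taken from the adapter's code) of the OpenAI-compatible call that a chat completion request should now always map to; the base URL, API key, and model name are placeholder assumptions for a local vLLM server:

```python
# Minimal sketch, assuming a local vLLM OpenAI-compatible server; this is not
# the adapter's implementation. It shows the call that chat completion
# requests should now always be routed to (POST /v1/chat/completions).
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")  # assumed local vLLM endpoint
response = client.chat.completions.create(
    model="meta-llama/Llama-3.2-11B-Vision-Instruct",  # placeholder model name
    messages=[{"role": "user", "content": "Say hello."}],
)
print(response.choices[0].message.content)
```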
This commit is contained in:
parent 09ed0e9c9f
commit 0a0ee5ca96

1 changed file with 5 additions and 21 deletions
```diff
@@ -46,7 +46,6 @@ from llama_stack.providers.utils.inference.openai_compat import (
     process_completion_stream_response,
 )
 from llama_stack.providers.utils.inference.prompt_adapter import (
-    chat_completion_request_to_prompt,
     completion_request_to_prompt,
     content_has_media,
     interleaved_content_as_str,
@@ -142,10 +141,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         self, request: ChatCompletionRequest, client: OpenAI
     ) -> ChatCompletionResponse:
         params = await self._get_params(request)
-        if "messages" in params:
-            r = client.chat.completions.create(**params)
-        else:
-            r = client.completions.create(**params)
+        r = client.chat.completions.create(**params)
         return process_chat_completion_response(r, self.formatter)
 
     async def _stream_chat_completion(self, request: ChatCompletionRequest, client: OpenAI) -> AsyncGenerator:
@@ -154,10 +150,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         # TODO: Can we use client.completions.acreate() or maybe there is another way to directly create an async
         # generator so this wrapper is not necessary?
         async def _to_async_generator():
-            if "messages" in params:
-                s = client.chat.completions.create(**params)
-            else:
-                s = client.completions.create(**params)
+            s = client.chat.completions.create(**params)
             for chunk in s:
                 yield chunk
 
```
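The TODO in the hunk above asks whether the synchronous streaming call could be replaced with a native async one; until then, the blocking iterator is wrapped in an async generator. Below is a standalone sketch of that wrapper pattern, with a hypothetical `fake_stream` standing in for the blocking OpenAI stream:

```python
# Standalone sketch of the sync-to-async wrapper pattern referenced by the
# TODO above. `fake_stream` is a hypothetical stand-in for the blocking
# iterator returned by client.chat.completions.create(..., stream=True).
import asyncio
from typing import AsyncGenerator, Iterator


def fake_stream() -> Iterator[str]:
    yield from ["Hel", "lo", "!"]


async def to_async_generator(chunks: Iterator[str]) -> AsyncGenerator[str, None]:
    # Each chunk is pulled from a blocking iterator inside the event loop,
    # which is exactly what the TODO would like to avoid.
    for chunk in chunks:
        yield chunk


async def main() -> None:
    async for chunk in to_async_generator(fake_stream()):
        print(chunk, end="")


asyncio.run(main())
```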
```diff
@@ -200,20 +193,11 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             options["max_tokens"] = self.config.max_tokens
 
         input_dict = {}
-        media_present = request_has_media(request)
 
         if isinstance(request, ChatCompletionRequest):
-            if media_present:
-                input_dict["messages"] = [
-                    await convert_message_to_openai_dict(m, download=True) for m in request.messages
-                ]
-            else:
-                input_dict["prompt"] = await chat_completion_request_to_prompt(
-                    request,
-                    self.register_helper.get_llama_model(request.model),
-                    self.formatter,
-                )
+            input_dict["messages"] = [await convert_message_to_openai_dict(m, download=True) for m in request.messages]
         else:
-            assert not media_present, "vLLM does not support media for Completion requests"
+            assert not request_has_media(request), "vLLM does not support media for Completion requests"
             input_dict["prompt"] = await completion_request_to_prompt(
                 request,
                 self.formatter,
```
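In the hunk above, chat messages are now always converted with `convert_message_to_openai_dict(m, download=True)` instead of branching on `media_present`, because the OpenAI chat message format can carry images alongside text. The illustration below is hedged: the helper's exact output is not shown in the diff, the structure is simply the standard OpenAI content-parts layout, and the base64 inlining is an assumption about what `download=True` enables.

```python
# Illustrative only: a standard OpenAI-style chat message carrying both text
# and an image. This is an assumption about the shape produced by
# convert_message_to_openai_dict, not code taken from llama-stack.
message = {
    "role": "user",
    "content": [
        {"type": "text", "text": "What is in this image?"},
        {
            "type": "image_url",
            # Assumed: with download=True the image can be fetched and
            # inlined as a base64 data URL rather than passed as a remote URL.
            "image_url": {"url": "data:image/jpeg;base64,<...>"},
        },
    ],
}
```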
|
Loading…
Add table
Add a link
Reference in a new issue