mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 07:14:20 +00:00
fix: add missing extra_body to client.chat.completions.create() call
- test requires vLLM as provider, current is skipped in GH Action - test: >export VLLM_URL="http://localhost:8000" >pytest tests/integration/inference/test_openai_completion.py::test_openai_chat_completion_extra_body -v --stack-config="inference=remote::vllm" Signed-off-by: Wen Zhou <wenzhou@redhat.com>
This commit is contained in:
parent
d880c2df0e
commit
ea964a13ec
2 changed files with 46 additions and 0 deletions
|
@ -654,6 +654,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
top_logprobs: int | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
extra_body: dict[str, Any] | None = None,
|
||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
self._lazy_initialize_client()
|
||||
model_obj = await self._get_model(model)
|
||||
|
@ -681,6 +682,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
top_logprobs=top_logprobs,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
extra_body=extra_body,
|
||||
)
|
||||
return await self.client.chat.completions.create(**params) # type: ignore
|
||||
|
||||
|
|
|
@ -5,6 +5,8 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from openai import OpenAI
|
||||
|
||||
|
@ -209,6 +211,48 @@ def test_openai_completion_guided_choice(llama_stack_client, client_with_models,
|
|||
# Run the chat-completion tests with both the OpenAI client and the LlamaStack client
|
||||
|
||||
|
||||
def test_openai_chat_completion_extra_body(llama_stack_client, client_with_models, text_model_id):
|
||||
skip_if_provider_isnt_vllm(client_with_models, text_model_id)
|
||||
|
||||
adapter = client_with_models._routing_table.get_provider_impl(text_model_id)
|
||||
|
||||
captured_params = None
|
||||
|
||||
async def mock_create(**params):
|
||||
nonlocal captured_params
|
||||
captured_params = params
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
mock_response = AsyncMock()
|
||||
mock_response.choices = [AsyncMock()]
|
||||
mock_response.choices[0].message = AsyncMock()
|
||||
mock_response.choices[0].message.content = "Test response"
|
||||
return mock_response
|
||||
|
||||
# Need explicit test extra_body is passed
|
||||
with patch.object(adapter.client.chat.completions, "create", side_effect=mock_create):
|
||||
response = llama_stack_client.chat.completions.create(
|
||||
model=text_model_id,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hi how are you?",
|
||||
}
|
||||
],
|
||||
stream=False,
|
||||
extra_body={"chat_template_kwargs": {"thinking": True}},
|
||||
)
|
||||
assert captured_params is not None
|
||||
assert "extra_body" in captured_params
|
||||
assert captured_params["extra_body"]["chat_template_kwargs"]["thinking"] is True
|
||||
|
||||
assert len(response.choices) > 0
|
||||
choice = response.choices[0]
|
||||
assert choice.message.content is not None
|
||||
assert len(choice.message.content) > 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"test_case",
|
||||
[
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue