fix: chat completion with more than one choice (#2288)

# What does this PR do?
Fix a bug in openai_compat where choices are not indexed correctly.

## Test Plan
Added a new test.

Rerun the failed inference_store tests:
llama stack run fireworks --image-type conda
pytest -s -v tests/integration/ --stack-config http://localhost:8321 -k
'test_inference_store' --text-model meta-llama/Llama-3.3-70B-Instruct
--count 10
This commit is contained in:
ehhuang 2025-05-27 15:39:15 -07:00 committed by GitHub
parent 1d46f3102e
commit 0b695538af
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 53 additions and 11 deletions

View file

@ -1402,9 +1402,8 @@ class OpenAIChatCompletionToLlamaStackMixin:
outstanding_responses: list[Awaitable[AsyncIterator[ChatCompletionResponseStreamChunk]]],
):
id = f"chatcmpl-{uuid.uuid4()}"
for outstanding_response in outstanding_responses:
for i, outstanding_response in enumerate(outstanding_responses):
response = await outstanding_response
i = 0
async for chunk in response:
event = chunk.event
finish_reason = _convert_stop_reason_to_openai_finish_reason(event.stop_reason)
@ -1459,7 +1458,6 @@ class OpenAIChatCompletionToLlamaStackMixin:
model=model,
object="chat.completion.chunk",
)
i = i + 1
async def _process_non_stream_response(
self, model: str, outstanding_responses: list[Awaitable[ChatCompletionResponse]]

View file

@ -224,6 +224,43 @@ def test_openai_chat_completion_streaming(compat_client, client_with_models, tex
assert expected.lower() in "".join(streamed_content)
@pytest.mark.parametrize(
"test_case",
[
"inference:chat_completion:streaming_01",
"inference:chat_completion:streaming_02",
],
)
def test_openai_chat_completion_streaming_with_n(compat_client, client_with_models, text_model_id, test_case):
skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
provider = provider_from_model(client_with_models, text_model_id)
if provider.provider_type == "remote::ollama":
pytest.skip(f"Model {text_model_id} hosted by {provider.provider_type} doesn't support n > 1.")
tc = TestCase(test_case)
question = tc["question"]
expected = tc["expected"]
response = compat_client.chat.completions.create(
model=text_model_id,
messages=[{"role": "user", "content": question}],
stream=True,
timeout=120, # Increase timeout to 2 minutes for large conversation history,
n=2,
)
streamed_content = {}
for chunk in response:
for choice in chunk.choices:
if choice.delta.content:
streamed_content[choice.index] = (
streamed_content.get(choice.index, "") + choice.delta.content.lower().strip()
)
assert len(streamed_content) == 2
for i, content in streamed_content.items():
assert expected.lower() in content, f"Choice {i}: Expected {expected.lower()} in {content}"
@pytest.mark.parametrize(
"stream",
[
@ -231,7 +268,6 @@ def test_openai_chat_completion_streaming(compat_client, client_with_models, tex
False,
],
)
@pytest.mark.skip(reason="Very flaky, keeps failing on CI")
def test_inference_store(openai_client, client_with_models, text_model_id, stream):
skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
client = openai_client
@ -254,6 +290,7 @@ def test_inference_store(openai_client, client_with_models, text_model_id, strea
for chunk in response:
if response_id is None:
response_id = chunk.id
if chunk.choices[0].delta.content:
content += chunk.choices[0].delta.content
else:
response_id = response.id
@ -264,8 +301,8 @@ def test_inference_store(openai_client, client_with_models, text_model_id, strea
retrieved_response = client.chat.completions.retrieve(response_id)
assert retrieved_response.id == response_id
assert retrieved_response.input_messages[0]["content"] == message
assert retrieved_response.choices[0].message.content == content
assert retrieved_response.input_messages[0]["content"] == message, retrieved_response
assert retrieved_response.choices[0].message.content == content, retrieved_response
@pytest.mark.parametrize(
@ -275,7 +312,6 @@ def test_inference_store(openai_client, client_with_models, text_model_id, strea
False,
],
)
@pytest.mark.skip(reason="Very flaky, tool calling really wacky on CI")
def test_inference_store_tool_calls(openai_client, client_with_models, text_model_id, stream):
skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
client = openai_client
@ -313,7 +349,9 @@ def test_inference_store_tool_calls(openai_client, client_with_models, text_mode
for chunk in response:
if response_id is None:
response_id = chunk.id
content += chunk.choices[0].delta.content
if delta := chunk.choices[0].delta:
if delta.content:
content += delta.content
else:
response_id = response.id
content = response.choices[0].message.content
@ -324,5 +362,11 @@ def test_inference_store_tool_calls(openai_client, client_with_models, text_mode
retrieved_response = client.chat.completions.retrieve(response_id)
assert retrieved_response.id == response_id
assert retrieved_response.input_messages[0]["content"] == message
assert retrieved_response.choices[0].message.tool_calls[0].function.name == "get_weather"
assert retrieved_response.choices[0].message.tool_calls[0].function.arguments == '{"city":"Tokyo"}'
tool_calls = retrieved_response.choices[0].message.tool_calls
# sometimes model doesn't ouptut tool calls, but we still want to test that the tool was called
if tool_calls:
assert len(tool_calls) == 1
assert tool_calls[0].function.name == "get_weather"
assert "tokyo" in tool_calls[0].function.arguments.lower()
else:
assert retrieved_response.choices[0].message.content == content