fix: chat completion with more than one choice (#2288)

# What does this PR do?
Fix a bug in openai_compat where choices are not indexed correctly.

## Test Plan
Added a new test.

Rerun the failed inference_store tests:
llama stack run fireworks --image-type conda
pytest -s -v tests/integration/ --stack-config http://localhost:8321 -k
'test_inference_store' --text-model meta-llama/Llama-3.3-70B-Instruct
--count 10
This commit is contained in:
ehhuang 2025-05-27 15:39:15 -07:00 committed by GitHub
parent 1d46f3102e
commit 0b695538af
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 53 additions and 11 deletions

View file

@ -1402,9 +1402,8 @@ class OpenAIChatCompletionToLlamaStackMixin:
outstanding_responses: list[Awaitable[AsyncIterator[ChatCompletionResponseStreamChunk]]], outstanding_responses: list[Awaitable[AsyncIterator[ChatCompletionResponseStreamChunk]]],
): ):
id = f"chatcmpl-{uuid.uuid4()}" id = f"chatcmpl-{uuid.uuid4()}"
for outstanding_response in outstanding_responses: for i, outstanding_response in enumerate(outstanding_responses):
response = await outstanding_response response = await outstanding_response
i = 0
async for chunk in response: async for chunk in response:
event = chunk.event event = chunk.event
finish_reason = _convert_stop_reason_to_openai_finish_reason(event.stop_reason) finish_reason = _convert_stop_reason_to_openai_finish_reason(event.stop_reason)
@ -1459,7 +1458,6 @@ class OpenAIChatCompletionToLlamaStackMixin:
model=model, model=model,
object="chat.completion.chunk", object="chat.completion.chunk",
) )
i = i + 1
async def _process_non_stream_response( async def _process_non_stream_response(
self, model: str, outstanding_responses: list[Awaitable[ChatCompletionResponse]] self, model: str, outstanding_responses: list[Awaitable[ChatCompletionResponse]]

View file

@ -224,6 +224,43 @@ def test_openai_chat_completion_streaming(compat_client, client_with_models, tex
assert expected.lower() in "".join(streamed_content) assert expected.lower() in "".join(streamed_content)
@pytest.mark.parametrize(
"test_case",
[
"inference:chat_completion:streaming_01",
"inference:chat_completion:streaming_02",
],
)
def test_openai_chat_completion_streaming_with_n(compat_client, client_with_models, text_model_id, test_case):
skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
provider = provider_from_model(client_with_models, text_model_id)
if provider.provider_type == "remote::ollama":
pytest.skip(f"Model {text_model_id} hosted by {provider.provider_type} doesn't support n > 1.")
tc = TestCase(test_case)
question = tc["question"]
expected = tc["expected"]
response = compat_client.chat.completions.create(
model=text_model_id,
messages=[{"role": "user", "content": question}],
stream=True,
timeout=120, # Increase timeout to 2 minutes for large conversation history,
n=2,
)
streamed_content = {}
for chunk in response:
for choice in chunk.choices:
if choice.delta.content:
streamed_content[choice.index] = (
streamed_content.get(choice.index, "") + choice.delta.content.lower().strip()
)
assert len(streamed_content) == 2
for i, content in streamed_content.items():
assert expected.lower() in content, f"Choice {i}: Expected {expected.lower()} in {content}"
@pytest.mark.parametrize( @pytest.mark.parametrize(
"stream", "stream",
[ [
@ -231,7 +268,6 @@ def test_openai_chat_completion_streaming(compat_client, client_with_models, tex
False, False,
], ],
) )
@pytest.mark.skip(reason="Very flaky, keeps failing on CI")
def test_inference_store(openai_client, client_with_models, text_model_id, stream): def test_inference_store(openai_client, client_with_models, text_model_id, stream):
skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id) skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
client = openai_client client = openai_client
@ -254,6 +290,7 @@ def test_inference_store(openai_client, client_with_models, text_model_id, strea
for chunk in response: for chunk in response:
if response_id is None: if response_id is None:
response_id = chunk.id response_id = chunk.id
if chunk.choices[0].delta.content:
content += chunk.choices[0].delta.content content += chunk.choices[0].delta.content
else: else:
response_id = response.id response_id = response.id
@ -264,8 +301,8 @@ def test_inference_store(openai_client, client_with_models, text_model_id, strea
retrieved_response = client.chat.completions.retrieve(response_id) retrieved_response = client.chat.completions.retrieve(response_id)
assert retrieved_response.id == response_id assert retrieved_response.id == response_id
assert retrieved_response.input_messages[0]["content"] == message assert retrieved_response.input_messages[0]["content"] == message, retrieved_response
assert retrieved_response.choices[0].message.content == content assert retrieved_response.choices[0].message.content == content, retrieved_response
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -275,7 +312,6 @@ def test_inference_store(openai_client, client_with_models, text_model_id, strea
False, False,
], ],
) )
@pytest.mark.skip(reason="Very flaky, tool calling really wacky on CI")
def test_inference_store_tool_calls(openai_client, client_with_models, text_model_id, stream): def test_inference_store_tool_calls(openai_client, client_with_models, text_model_id, stream):
skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id) skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
client = openai_client client = openai_client
@ -313,7 +349,9 @@ def test_inference_store_tool_calls(openai_client, client_with_models, text_mode
for chunk in response: for chunk in response:
if response_id is None: if response_id is None:
response_id = chunk.id response_id = chunk.id
content += chunk.choices[0].delta.content if delta := chunk.choices[0].delta:
if delta.content:
content += delta.content
else: else:
response_id = response.id response_id = response.id
content = response.choices[0].message.content content = response.choices[0].message.content
@ -324,5 +362,11 @@ def test_inference_store_tool_calls(openai_client, client_with_models, text_mode
retrieved_response = client.chat.completions.retrieve(response_id) retrieved_response = client.chat.completions.retrieve(response_id)
assert retrieved_response.id == response_id assert retrieved_response.id == response_id
assert retrieved_response.input_messages[0]["content"] == message assert retrieved_response.input_messages[0]["content"] == message
assert retrieved_response.choices[0].message.tool_calls[0].function.name == "get_weather" tool_calls = retrieved_response.choices[0].message.tool_calls
assert retrieved_response.choices[0].message.tool_calls[0].function.arguments == '{"city":"Tokyo"}' # sometimes model doesn't ouptut tool calls, but we still want to test that the tool was called
if tool_calls:
assert len(tool_calls) == 1
assert tool_calls[0].function.name == "get_weather"
assert "tokyo" in tool_calls[0].function.arguments.lower()
else:
assert retrieved_response.choices[0].message.content == content