fix: chat completion with more than one choice (#2288)
Some checks failed
Integration Tests / test-matrix (http, inference) (push) Failing after 13s
Integration Tests / test-matrix (library, datasets) (push) Failing after 12s
Integration Tests / test-matrix (library, providers) (push) Failing after 9s
Unit Tests / unit-tests (3.10) (push) Failing after 9s
Unit Tests / unit-tests (3.12) (push) Failing after 1m33s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 8s
Integration Tests / test-matrix (library, agents) (push) Failing after 11s
Integration Tests / test-matrix (http, providers) (push) Failing after 13s
Integration Tests / test-matrix (library, scoring) (push) Failing after 10s
Unit Tests / unit-tests (3.13) (push) Failing after 9s
Integration Tests / test-matrix (http, datasets) (push) Failing after 10s
Integration Tests / test-matrix (http, post_training) (push) Failing after 13s
Integration Tests / test-matrix (library, inference) (push) Failing after 11s
Integration Tests / test-matrix (library, tool_runtime) (push) Failing after 9s
Update ReadTheDocs / update-readthedocs (push) Failing after 7s
Integration Tests / test-matrix (http, agents) (push) Failing after 11s
Integration Tests / test-matrix (http, inspect) (push) Failing after 10s
Integration Tests / test-matrix (http, scoring) (push) Failing after 10s
Integration Tests / test-matrix (library, inspect) (push) Failing after 10s
Integration Tests / test-matrix (library, post_training) (push) Failing after 10s
Test External Providers / test-external-providers (venv) (push) Failing after 8s
Integration Tests / test-matrix (http, tool_runtime) (push) Failing after 10s
Unit Tests / unit-tests (3.11) (push) Failing after 8s
Pre-commit / pre-commit (push) Successful in 3m18s

# What does this PR do?
Fix a bug in openai_compat where choices are not indexed correctly.

## Test Plan
Added a new test.

Rerun the failed inference_store tests:
llama stack run fireworks --image-type conda
pytest -s -v tests/integration/ --stack-config http://localhost:8321 -k
'test_inference_store' --text-model meta-llama/Llama-3.3-70B-Instruct
--count 10
This commit is contained in:
ehhuang 2025-05-27 15:39:15 -07:00 committed by GitHub
parent 1d46f3102e
commit 0b695538af
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 53 additions and 11 deletions

View file

@ -224,6 +224,43 @@ def test_openai_chat_completion_streaming(compat_client, client_with_models, tex
assert expected.lower() in "".join(streamed_content)
@pytest.mark.parametrize(
"test_case",
[
"inference:chat_completion:streaming_01",
"inference:chat_completion:streaming_02",
],
)
def test_openai_chat_completion_streaming_with_n(compat_client, client_with_models, text_model_id, test_case):
skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
provider = provider_from_model(client_with_models, text_model_id)
if provider.provider_type == "remote::ollama":
pytest.skip(f"Model {text_model_id} hosted by {provider.provider_type} doesn't support n > 1.")
tc = TestCase(test_case)
question = tc["question"]
expected = tc["expected"]
response = compat_client.chat.completions.create(
model=text_model_id,
messages=[{"role": "user", "content": question}],
stream=True,
timeout=120, # Increase timeout to 2 minutes for large conversation history,
n=2,
)
streamed_content = {}
for chunk in response:
for choice in chunk.choices:
if choice.delta.content:
streamed_content[choice.index] = (
streamed_content.get(choice.index, "") + choice.delta.content.lower().strip()
)
assert len(streamed_content) == 2
for i, content in streamed_content.items():
assert expected.lower() in content, f"Choice {i}: Expected {expected.lower()} in {content}"
@pytest.mark.parametrize(
"stream",
[
@ -231,7 +268,6 @@ def test_openai_chat_completion_streaming(compat_client, client_with_models, tex
False,
],
)
@pytest.mark.skip(reason="Very flaky, keeps failing on CI")
def test_inference_store(openai_client, client_with_models, text_model_id, stream):
skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
client = openai_client
@ -254,7 +290,8 @@ def test_inference_store(openai_client, client_with_models, text_model_id, strea
for chunk in response:
if response_id is None:
response_id = chunk.id
content += chunk.choices[0].delta.content
if chunk.choices[0].delta.content:
content += chunk.choices[0].delta.content
else:
response_id = response.id
content = response.choices[0].message.content
@ -264,8 +301,8 @@ def test_inference_store(openai_client, client_with_models, text_model_id, strea
retrieved_response = client.chat.completions.retrieve(response_id)
assert retrieved_response.id == response_id
assert retrieved_response.input_messages[0]["content"] == message
assert retrieved_response.choices[0].message.content == content
assert retrieved_response.input_messages[0]["content"] == message, retrieved_response
assert retrieved_response.choices[0].message.content == content, retrieved_response
@pytest.mark.parametrize(
@ -275,7 +312,6 @@ def test_inference_store(openai_client, client_with_models, text_model_id, strea
False,
],
)
@pytest.mark.skip(reason="Very flaky, tool calling really wacky on CI")
def test_inference_store_tool_calls(openai_client, client_with_models, text_model_id, stream):
skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
client = openai_client
@ -313,7 +349,9 @@ def test_inference_store_tool_calls(openai_client, client_with_models, text_mode
for chunk in response:
if response_id is None:
response_id = chunk.id
content += chunk.choices[0].delta.content
if delta := chunk.choices[0].delta:
if delta.content:
content += delta.content
else:
response_id = response.id
content = response.choices[0].message.content
@ -324,5 +362,11 @@ def test_inference_store_tool_calls(openai_client, client_with_models, text_mode
retrieved_response = client.chat.completions.retrieve(response_id)
assert retrieved_response.id == response_id
assert retrieved_response.input_messages[0]["content"] == message
assert retrieved_response.choices[0].message.tool_calls[0].function.name == "get_weather"
assert retrieved_response.choices[0].message.tool_calls[0].function.arguments == '{"city":"Tokyo"}'
tool_calls = retrieved_response.choices[0].message.tool_calls
# sometimes model doesn't ouptut tool calls, but we still want to test that the tool was called
if tool_calls:
assert len(tool_calls) == 1
assert tool_calls[0].function.name == "get_weather"
assert "tokyo" in tool_calls[0].function.arguments.lower()
else:
assert retrieved_response.choices[0].message.content == content