Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-27 18:50:41 +00:00)
fix: chat completion with more than one choice (#2288)
Some checks failed
Integration Tests / test-matrix (http, inference) (push) Failing after 13s
Integration Tests / test-matrix (library, datasets) (push) Failing after 12s
Integration Tests / test-matrix (library, providers) (push) Failing after 9s
Unit Tests / unit-tests (3.10) (push) Failing after 9s
Unit Tests / unit-tests (3.12) (push) Failing after 1m33s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 8s
Integration Tests / test-matrix (library, agents) (push) Failing after 11s
Integration Tests / test-matrix (http, providers) (push) Failing after 13s
Integration Tests / test-matrix (library, scoring) (push) Failing after 10s
Unit Tests / unit-tests (3.13) (push) Failing after 9s
Integration Tests / test-matrix (http, datasets) (push) Failing after 10s
Integration Tests / test-matrix (http, post_training) (push) Failing after 13s
Integration Tests / test-matrix (library, inference) (push) Failing after 11s
Integration Tests / test-matrix (library, tool_runtime) (push) Failing after 9s
Update ReadTheDocs / update-readthedocs (push) Failing after 7s
Integration Tests / test-matrix (http, agents) (push) Failing after 11s
Integration Tests / test-matrix (http, inspect) (push) Failing after 10s
Integration Tests / test-matrix (http, scoring) (push) Failing after 10s
Integration Tests / test-matrix (library, inspect) (push) Failing after 10s
Integration Tests / test-matrix (library, post_training) (push) Failing after 10s
Test External Providers / test-external-providers (venv) (push) Failing after 8s
Integration Tests / test-matrix (http, tool_runtime) (push) Failing after 10s
Unit Tests / unit-tests (3.11) (push) Failing after 8s
Pre-commit / pre-commit (push) Successful in 3m18s
# What does this PR do?

Fix a bug in openai_compat where choices are not indexed correctly.

## Test Plan

Added a new test. Rerun the failed inference_store tests:

    llama stack run fireworks --image-type conda
    pytest -s -v tests/integration/ --stack-config http://localhost:8321 -k 'test_inference_store' --text-model meta-llama/Llama-3.3-70B-Instruct --count 10
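The openai_compat change itself is not shown on this page (the diff below covers only the test file), but the failure mode is straightforward: when a request asks for n > 1 completions, each streamed chunk's choice must carry its own stable index, or a client that aggregates deltas by choice.index collapses all n choices into one. A minimal sketch of that bug class, using hypothetical stand-in types rather than the actual llama-stack code:

```python
# Hypothetical sketch of the bug class this PR fixes; stand-in types,
# NOT the actual llama-stack openai_compat code.
from dataclasses import dataclass


@dataclass
class Delta:
    content: str


@dataclass
class StreamChoice:
    index: int
    delta: Delta


def to_stream_choices_buggy(texts: list[str]) -> list[StreamChoice]:
    # Bug: every choice is emitted with index 0, so a client that keys
    # accumulated deltas by choice.index merges all choices into one.
    return [StreamChoice(index=0, delta=Delta(content=t)) for t in texts]


def to_stream_choices_fixed(texts: list[str]) -> list[StreamChoice]:
    # Fix: enumerate so each of the n choices keeps a distinct index.
    return [StreamChoice(index=i, delta=Delta(content=t)) for i, t in enumerate(texts)]


if __name__ == "__main__":
    choices = to_stream_choices_fixed(["answer A", "answer B"])
    merged: dict[int, str] = {}
    for choice in choices:
        merged[choice.index] = merged.get(choice.index, "") + choice.delta.content
    assert len(merged) == 2  # with the buggy version this would be 1
```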
parent 1d46f3102e
commit 0b695538af
2 changed files with 53 additions and 11 deletions
@@ -224,6 +224,43 @@ def test_openai_chat_completion_streaming(compat_client, client_with_models, text_model_id):
     assert expected.lower() in "".join(streamed_content)


+@pytest.mark.parametrize(
+    "test_case",
+    [
+        "inference:chat_completion:streaming_01",
+        "inference:chat_completion:streaming_02",
+    ],
+)
+def test_openai_chat_completion_streaming_with_n(compat_client, client_with_models, text_model_id, test_case):
+    skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
+
+    provider = provider_from_model(client_with_models, text_model_id)
+    if provider.provider_type == "remote::ollama":
+        pytest.skip(f"Model {text_model_id} hosted by {provider.provider_type} doesn't support n > 1.")
+
+    tc = TestCase(test_case)
+    question = tc["question"]
+    expected = tc["expected"]
+
+    response = compat_client.chat.completions.create(
+        model=text_model_id,
+        messages=[{"role": "user", "content": question}],
+        stream=True,
+        timeout=120,  # Increase timeout to 2 minutes for large conversation history
+        n=2,
+    )
+    streamed_content = {}
+    for chunk in response:
+        for choice in chunk.choices:
+            if choice.delta.content:
+                streamed_content[choice.index] = (
+                    streamed_content.get(choice.index, "") + choice.delta.content.lower().strip()
+                )
+    assert len(streamed_content) == 2
+    for i, content in streamed_content.items():
+        assert expected.lower() in content, f"Choice {i}: Expected {expected.lower()} in {content}"
+
+
 @pytest.mark.parametrize(
     "stream",
     [
@@ -231,7 +268,6 @@ def test_openai_chat_completion_streaming(compat_client, client_with_models, text_model_id):
         False,
     ],
 )
-@pytest.mark.skip(reason="Very flaky, keeps failing on CI")
 def test_inference_store(openai_client, client_with_models, text_model_id, stream):
     skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
     client = openai_client
@@ -254,7 +290,8 @@ def test_inference_store(openai_client, client_with_models, text_model_id, stream):
         for chunk in response:
             if response_id is None:
                 response_id = chunk.id
-            content += chunk.choices[0].delta.content
+            if chunk.choices[0].delta.content:
+                content += chunk.choices[0].delta.content
     else:
         response_id = response.id
         content = response.choices[0].message.content
@@ -264,8 +301,8 @@ def test_inference_store(openai_client, client_with_models, text_model_id, stream):

     retrieved_response = client.chat.completions.retrieve(response_id)
     assert retrieved_response.id == response_id
-    assert retrieved_response.input_messages[0]["content"] == message
-    assert retrieved_response.choices[0].message.content == content
+    assert retrieved_response.input_messages[0]["content"] == message, retrieved_response
+    assert retrieved_response.choices[0].message.content == content, retrieved_response


 @pytest.mark.parametrize(
@@ -275,7 +312,6 @@ def test_inference_store(openai_client, client_with_models, text_model_id, stream):
         False,
     ],
 )
-@pytest.mark.skip(reason="Very flaky, tool calling really wacky on CI")
 def test_inference_store_tool_calls(openai_client, client_with_models, text_model_id, stream):
     skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
     client = openai_client
@@ -313,7 +349,9 @@ def test_inference_store_tool_calls(openai_client, client_with_models, text_model_id, stream):
         for chunk in response:
             if response_id is None:
                 response_id = chunk.id
-            content += chunk.choices[0].delta.content
+            if delta := chunk.choices[0].delta:
+                if delta.content:
+                    content += delta.content
     else:
         response_id = response.id
         content = response.choices[0].message.content
@@ -324,5 +362,11 @@ def test_inference_store_tool_calls(openai_client, client_with_models, text_model_id, stream):
     retrieved_response = client.chat.completions.retrieve(response_id)
     assert retrieved_response.id == response_id
     assert retrieved_response.input_messages[0]["content"] == message
-    assert retrieved_response.choices[0].message.tool_calls[0].function.name == "get_weather"
-    assert retrieved_response.choices[0].message.tool_calls[0].function.arguments == '{"city":"Tokyo"}'
+    tool_calls = retrieved_response.choices[0].message.tool_calls
+    # sometimes the model doesn't output tool calls, but we still want to test that the tool was called
+    if tool_calls:
+        assert len(tool_calls) == 1
+        assert tool_calls[0].function.name == "get_weather"
+        assert "tokyo" in tool_calls[0].function.arguments.lower()
+    else:
+        assert retrieved_response.choices[0].message.content == content
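The inference_store hunks also harden the accumulation loop: in OpenAI-style streaming the final chunk typically carries only a finish_reason, with delta.content set to None, so the old unconditional content += chunk.choices[0].delta.content could raise a TypeError. A minimal illustration of that failure mode, using stand-in objects rather than the real client types:

```python
# Why the diff guards delta.content before concatenating; stand-in
# types, not the real OpenAI client objects.
from dataclasses import dataclass
from typing import Optional


@dataclass
class Delta:
    content: Optional[str]


# A typical stream: content deltas, then a final chunk with no content.
deltas = [Delta("Hello"), Delta(" world"), Delta(None)]

content = ""
for delta in deltas:
    if delta.content:  # the guard added in the diff
        content += delta.content
    # without the guard: content += delta.content -> TypeError on None

assert content == "Hello world"
```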
|
Loading…
Add table
Add a link
Reference in a new issue