forked from phoenix-oss/llama-stack-mirror
fix: chat completion with more than one choice (#2288)
# What does this PR do?

Fix a bug in openai_compat where choices are not indexed correctly.

## Test Plan

Added a new test. Rerun the failed inference_store tests:

    llama stack run fireworks --image-type conda
    pytest -s -v tests/integration/ --stack-config http://localhost:8321 \
      -k 'test_inference_store' --text-model meta-llama/Llama-3.3-70B-Instruct --count 10
This commit is contained in:
parent 1d46f3102e
commit 0b695538af
2 changed files with 53 additions and 11 deletions
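For context, here is a minimal sketch (not code from this repo) of why the choice index matters once `n > 1`: an OpenAI-compatible stream interleaves chunks from all choices, and a client can only reassemble the separate answers if every chunk reports the index of the choice it belongs to. The chunk dicts below are hypothetical stand-ins for real `ChatCompletionChunk` objects.

```python
# Hypothetical interleaved chunks from a stream with n=2.
chunks = [
    {"index": 0, "delta": "The capital of France"},
    {"index": 1, "delta": "Paris is the capital"},
    {"index": 0, "delta": " is Paris."},
    {"index": 1, "delta": " of France."},
]

# Reassemble each choice by its index.
answers: dict[int, str] = {}
for chunk in chunks:
    answers[chunk["index"]] = answers.get(chunk["index"], "") + chunk["delta"]

assert answers == {
    0: "The capital of France is Paris.",
    1: "Paris is the capital of France.",
}
# With wrong indices (the bug this PR fixes), the answers would be merged or
# attributed to choices that don't exist.
```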
@@ -1402,9 +1402,8 @@ class OpenAIChatCompletionToLlamaStackMixin:
         outstanding_responses: list[Awaitable[AsyncIterator[ChatCompletionResponseStreamChunk]]],
     ):
         id = f"chatcmpl-{uuid.uuid4()}"
-        for outstanding_response in outstanding_responses:
+        for i, outstanding_response in enumerate(outstanding_responses):
             response = await outstanding_response
-            i = 0
             async for chunk in response:
                 event = chunk.event
                 finish_reason = _convert_stop_reason_to_openai_finish_reason(event.stop_reason)
@@ -1459,7 +1458,6 @@ class OpenAIChatCompletionToLlamaStackMixin:
                     model=model,
                     object="chat.completion.chunk",
                 )
-                i = i + 1

     async def _process_non_stream_response(
         self, model: str, outstanding_responses: list[Awaitable[ChatCompletionResponse]]
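The hunks above are the heart of the fix: the old code kept a manual counter that was reset to 0 for every outstanding response and bumped once per chunk, so streamed choices ended up with wrong indices. The sketch below is simplified and uses hypothetical names and string "chunks" (the real mixin awaits each outstanding response and yields OpenAI-compatible chunk objects), but it shows the pattern the fix adopts: `enumerate()` pins one index per per-choice stream, and every chunk from that stream reuses it.

```python
import asyncio
from collections.abc import AsyncIterator


async def fake_choice_stream(text: str) -> AsyncIterator[str]:
    # Hypothetical stand-in for one per-choice stream of chunks.
    for token in text.split():
        yield token


async def merge(streams: list[AsyncIterator[str]]) -> list[tuple[int, str]]:
    merged: list[tuple[int, str]] = []
    for i, stream in enumerate(streams):  # index fixed per choice, as in the fix
        async for token in stream:        # every chunk from this stream keeps index i
            merged.append((i, token))
    return merged


chunks = asyncio.run(merge([fake_choice_stream("hello world"), fake_choice_stream("bonjour le monde")]))
print(chunks)  # [(0, 'hello'), (0, 'world'), (1, 'bonjour'), (1, 'le'), (1, 'monde')]
```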
@@ -224,6 +224,43 @@ def test_openai_chat_completion_streaming(compat_client, client_with_models, tex
     assert expected.lower() in "".join(streamed_content)


+@pytest.mark.parametrize(
+    "test_case",
+    [
+        "inference:chat_completion:streaming_01",
+        "inference:chat_completion:streaming_02",
+    ],
+)
+def test_openai_chat_completion_streaming_with_n(compat_client, client_with_models, text_model_id, test_case):
+    skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
+
+    provider = provider_from_model(client_with_models, text_model_id)
+    if provider.provider_type == "remote::ollama":
+        pytest.skip(f"Model {text_model_id} hosted by {provider.provider_type} doesn't support n > 1.")
+
+    tc = TestCase(test_case)
+    question = tc["question"]
+    expected = tc["expected"]
+
+    response = compat_client.chat.completions.create(
+        model=text_model_id,
+        messages=[{"role": "user", "content": question}],
+        stream=True,
+        timeout=120,  # Increase timeout to 2 minutes for large conversation history
+        n=2,
+    )
+    streamed_content = {}
+    for chunk in response:
+        for choice in chunk.choices:
+            if choice.delta.content:
+                streamed_content[choice.index] = (
+                    streamed_content.get(choice.index, "") + choice.delta.content.lower().strip()
+                )
+    assert len(streamed_content) == 2
+    for i, content in streamed_content.items():
+        assert expected.lower() in content, f"Choice {i}: Expected {expected.lower()} in {content}"
+
+
 @pytest.mark.parametrize(
     "stream",
     [
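For reference, this is roughly the call the new test exercises, issued directly with the `openai` client against a running stack. The base URL, API key, and model ID below are illustrative assumptions, not values from this PR; adjust them to your deployment.

```python
from openai import OpenAI

# Assumed endpoint and model; substitute your own stack URL and model ID.
client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

stream = client.chat.completions.create(
    model="meta-llama/Llama-3.3-70B-Instruct",
    messages=[{"role": "user", "content": "What is the capital of France?"}],
    stream=True,
    n=2,
)

# Aggregate deltas per choice index, exactly as the test does.
per_choice: dict[int, str] = {}
for chunk in stream:
    for choice in chunk.choices:
        if choice.delta.content:
            per_choice[choice.index] = per_choice.get(choice.index, "") + choice.delta.content

print(per_choice)  # expect two separate answers, keyed by choice index 0 and 1
```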
@@ -231,7 +268,6 @@ def test_openai_chat_completion_streaming(compat_client, client_with_models, tex
         False,
     ],
 )
-@pytest.mark.skip(reason="Very flaky, keeps failing on CI")
 def test_inference_store(openai_client, client_with_models, text_model_id, stream):
     skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
     client = openai_client
@@ -254,6 +290,7 @@ def test_inference_store(openai_client, client_with_models, text_model_id, strea
         for chunk in response:
             if response_id is None:
                 response_id = chunk.id
-            content += chunk.choices[0].delta.content
+            if chunk.choices[0].delta.content:
+                content += chunk.choices[0].delta.content
     else:
         response_id = response.id
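A note on the guard added above: in an OpenAI-style stream the final chunk typically carries only a `finish_reason`, with `delta.content` set to `None`, so concatenating unconditionally can fail. A tiny illustration with made-up delta values:

```python
content = ""
deltas = ["Hello", ", world", None]  # None mimics the finish chunk's empty delta.content
for delta_content in deltas:
    if delta_content:  # the same guard the test now applies
        content += delta_content
assert content == "Hello, world"
# Without the guard, `content += None` would raise TypeError.
```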
@@ -264,8 +301,8 @@ def test_inference_store(openai_client, client_with_models, text_model_id, strea

     retrieved_response = client.chat.completions.retrieve(response_id)
     assert retrieved_response.id == response_id
-    assert retrieved_response.input_messages[0]["content"] == message
-    assert retrieved_response.choices[0].message.content == content
+    assert retrieved_response.input_messages[0]["content"] == message, retrieved_response
+    assert retrieved_response.choices[0].message.content == content, retrieved_response


 @pytest.mark.parametrize(
@@ -275,7 +312,6 @@ def test_inference_store(openai_client, client_with_models, text_model_id, strea
         False,
     ],
 )
-@pytest.mark.skip(reason="Very flaky, tool calling really wacky on CI")
 def test_inference_store_tool_calls(openai_client, client_with_models, text_model_id, stream):
     skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
     client = openai_client
@@ -313,7 +349,9 @@ def test_inference_store_tool_calls(openai_client, client_with_models, text_mode
         for chunk in response:
             if response_id is None:
                 response_id = chunk.id
-            content += chunk.choices[0].delta.content
+            if delta := chunk.choices[0].delta:
+                if delta.content:
+                    content += delta.content
     else:
         response_id = response.id
         content = response.choices[0].message.content
@@ -324,5 +362,11 @@ def test_inference_store_tool_calls(openai_client, client_with_models, text_mode
     retrieved_response = client.chat.completions.retrieve(response_id)
     assert retrieved_response.id == response_id
     assert retrieved_response.input_messages[0]["content"] == message
-    assert retrieved_response.choices[0].message.tool_calls[0].function.name == "get_weather"
-    assert retrieved_response.choices[0].message.tool_calls[0].function.arguments == '{"city":"Tokyo"}'
+    tool_calls = retrieved_response.choices[0].message.tool_calls
+    # sometimes the model doesn't output tool calls, but we still want to test that the tool was called
+    if tool_calls:
+        assert len(tool_calls) == 1
+        assert tool_calls[0].function.name == "get_weather"
+        assert "tokyo" in tool_calls[0].function.arguments.lower()
+    else:
+        assert retrieved_response.choices[0].message.content == content