diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index cc0000528..049f06fdb 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -1402,9 +1402,8 @@ class OpenAIChatCompletionToLlamaStackMixin:
         outstanding_responses: list[Awaitable[AsyncIterator[ChatCompletionResponseStreamChunk]]],
     ):
         id = f"chatcmpl-{uuid.uuid4()}"
-        for outstanding_response in outstanding_responses:
+        for i, outstanding_response in enumerate(outstanding_responses):
             response = await outstanding_response
-            i = 0
             async for chunk in response:
                 event = chunk.event
                 finish_reason = _convert_stop_reason_to_openai_finish_reason(event.stop_reason)
@@ -1459,7 +1458,6 @@ class OpenAIChatCompletionToLlamaStackMixin:
                     model=model,
                     object="chat.completion.chunk",
                 )
-            i = i + 1
 
     async def _process_non_stream_response(
         self, model: str, outstanding_responses: list[Awaitable[ChatCompletionResponse]]
diff --git a/tests/integration/inference/test_openai_completion.py b/tests/integration/inference/test_openai_completion.py
index 0efcfda2e..2cd76a23d 100644
--- a/tests/integration/inference/test_openai_completion.py
+++ b/tests/integration/inference/test_openai_completion.py
@@ -224,6 +224,43 @@ def test_openai_chat_completion_streaming(compat_client, client_with_models, tex
     assert expected.lower() in "".join(streamed_content)
 
 
+@pytest.mark.parametrize(
+    "test_case",
+    [
+        "inference:chat_completion:streaming_01",
+        "inference:chat_completion:streaming_02",
+    ],
+)
+def test_openai_chat_completion_streaming_with_n(compat_client, client_with_models, text_model_id, test_case):
+    skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
+
+    provider = provider_from_model(client_with_models, text_model_id)
+    if provider.provider_type == "remote::ollama":
+        pytest.skip(f"Model {text_model_id} hosted by {provider.provider_type} doesn't support n > 1.")
+
+    tc = TestCase(test_case)
+    question = tc["question"]
+    expected = tc["expected"]
+
+    response = compat_client.chat.completions.create(
+        model=text_model_id,
+        messages=[{"role": "user", "content": question}],
+        stream=True,
+        timeout=120,  # Increase timeout to 2 minutes for large conversation history
+        n=2,
+    )
+    streamed_content = {}
+    for chunk in response:
+        for choice in chunk.choices:
+            if choice.delta.content:
+                streamed_content[choice.index] = (
+                    streamed_content.get(choice.index, "") + choice.delta.content.lower().strip()
+                )
+    assert len(streamed_content) == 2
+    for i, content in streamed_content.items():
+        assert expected.lower() in content, f"Choice {i}: Expected {expected.lower()} in {content}"
+
+
 @pytest.mark.parametrize(
     "stream",
     [
@@ -231,7 +268,6 @@ def test_openai_chat_completion_streaming(compat_client, client_with_models, tex
         False,
     ],
 )
-@pytest.mark.skip(reason="Very flaky, keeps failing on CI")
 def test_inference_store(openai_client, client_with_models, text_model_id, stream):
     skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
     client = openai_client
@@ -254,7 +290,8 @@ def test_inference_store(openai_client, client_with_models, text_model_id, strea
         for chunk in response:
             if response_id is None:
                 response_id = chunk.id
-            content += chunk.choices[0].delta.content
+            if chunk.choices[0].delta.content:
+                content += chunk.choices[0].delta.content
     else:
         response_id = response.id
         content = response.choices[0].message.content
@@ -264,8 +301,8 @@ def test_inference_store(openai_client, client_with_models, text_model_id, strea
 
     retrieved_response = client.chat.completions.retrieve(response_id)
     assert retrieved_response.id == response_id
-    assert retrieved_response.input_messages[0]["content"] == message
-    assert retrieved_response.choices[0].message.content == content
+    assert retrieved_response.input_messages[0]["content"] == message, retrieved_response
+    assert retrieved_response.choices[0].message.content == content, retrieved_response
 
 
 @pytest.mark.parametrize(
@@ -275,7 +312,6 @@ def test_inference_store(openai_client, client_with_models, text_model_id, strea
         False,
     ],
 )
-@pytest.mark.skip(reason="Very flaky, tool calling really wacky on CI")
 def test_inference_store_tool_calls(openai_client, client_with_models, text_model_id, stream):
     skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
     client = openai_client
@@ -313,7 +349,9 @@ def test_inference_store_tool_calls(openai_client, client_with_models, text_mode
         for chunk in response:
             if response_id is None:
                 response_id = chunk.id
-            content += chunk.choices[0].delta.content
+            if delta := chunk.choices[0].delta:
+                if delta.content:
+                    content += delta.content
     else:
         response_id = response.id
         content = response.choices[0].message.content
@@ -324,5 +362,11 @@ def test_inference_store_tool_calls(openai_client, client_with_models, text_mode
     retrieved_response = client.chat.completions.retrieve(response_id)
     assert retrieved_response.id == response_id
     assert retrieved_response.input_messages[0]["content"] == message
-    assert retrieved_response.choices[0].message.tool_calls[0].function.name == "get_weather"
-    assert retrieved_response.choices[0].message.tool_calls[0].function.arguments == '{"city":"Tokyo"}'
+    tool_calls = retrieved_response.choices[0].message.tool_calls
+    # sometimes the model doesn't output tool calls, but we still want to verify that the response was stored
+    if tool_calls:
+        assert len(tool_calls) == 1
+        assert tool_calls[0].function.name == "get_weather"
+        assert "tokyo" in tool_calls[0].function.arguments.lower()
+    else:
+        assert retrieved_response.choices[0].message.content == content
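
Below is a minimal sketch (not part of the patch) of the consumer-side pattern this change enables. It assumes an OpenAI-compatible `client` object and a placeholder model id, and mirrors the per-`choice.index` accumulation used in the new streaming test above.

# Sketch only: request two parallel completions and group streamed deltas by choice.index.
# `client` and "my-model" are placeholders, not names from this patch.
stream = client.chat.completions.create(
    model="my-model",
    messages=[{"role": "user", "content": "Name a primary color."}],
    n=2,
    stream=True,
)

per_choice: dict[int, str] = {}
for chunk in stream:
    for choice in chunk.choices:
        if choice.delta.content:  # deltas can be empty (role-only or final chunks)
            per_choice[choice.index] = per_choice.get(choice.index, "") + choice.delta.content

# With the enumerate() fix in OpenAIChatCompletionToLlamaStackMixin, each parallel
# completion carries its own choice.index, so the two completions accumulate separately
# instead of all collapsing onto index 0.
assert len(per_choice) == 2
for idx, text in sorted(per_choice.items()):
    print(f"choice {idx}: {text}")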