diff --git a/litellm/tests/test_router_batch_completion.py b/litellm/tests/test_router_batch_completion.py
index 82fe102e2..26329792e 100644
--- a/litellm/tests/test_router_batch_completion.py
+++ b/litellm/tests/test_router_batch_completion.py
@@ -150,3 +150,47 @@ async def test_batch_completion_fastest_response_streaming():
 
     async for chunk in response:
         ChatCompletionChunk.model_validate(chunk.model_dump(), strict=True)
+
+
+@pytest.mark.asyncio
+async def test_batch_completion_multiple_models_multiple_messages():
+    litellm.set_verbose = True
+
+    router = litellm.Router(
+        model_list=[
+            {
+                "model_name": "gpt-3.5-turbo",
+                "litellm_params": {
+                    "model": "gpt-3.5-turbo",
+                },
+            },
+            {
+                "model_name": "groq-llama",
+                "litellm_params": {
+                    "model": "groq/llama3-8b-8192",
+                },
+            },
+        ]
+    )
+
+    response = await router.abatch_completion(
+        models=["gpt-3.5-turbo", "groq-llama"],
+        messages=[
+            [{"role": "user", "content": "is litellm becoming a better product ?"}],
+            [{"role": "user", "content": "who is this"}],
+        ],
+        max_tokens=15,
+    )
+
+    print("response from batches =", response)
+    assert len(response) == 2
+    assert len(response[0]) == 2
+    assert isinstance(response[0][0], litellm.ModelResponse)
+
+    # models_in_responses = []
+    # for individual_response in response:
+    #     _model = individual_response["model"]
+    #     models_in_responses.append(_model)
+
+    # # assert both models are different
+    # assert models_in_responses[0] != models_in_responses[1]
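
Note (not part of the patch): a minimal standalone sketch of the same Router.abatch_completion call, based only on the call shape exercised in the test above. The test asserts only a 2x2 nested result of litellm.ModelResponse objects, so which axis corresponds to messages and which to models is an assumption here, and the .choices[0].message.content access reflects the usual ModelResponse layout rather than anything this test checks.

import asyncio

import litellm


async def main():
    router = litellm.Router(
        model_list=[
            {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "gpt-3.5-turbo"}},
            {"model_name": "groq-llama", "litellm_params": {"model": "groq/llama3-8b-8192"}},
        ]
    )

    # Two message threads fanned out across two deployments; max_tokens keeps the
    # completions short, mirroring the test above.
    response = await router.abatch_completion(
        models=["gpt-3.5-turbo", "groq-llama"],
        messages=[
            [{"role": "user", "content": "is litellm becoming a better product ?"}],
            [{"role": "user", "content": "who is this"}],
        ],
        max_tokens=15,
    )

    # Assumption: response is a 2x2 nested list of litellm.ModelResponse objects;
    # the test above only asserts the 2x2 shape, not the axis order.
    for group in response:
        for model_response in group:
            print(model_response.choices[0].message.content)


if __name__ == "__main__":
    asyncio.run(main())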