diff --git a/litellm/router.py b/litellm/router.py
index 1ed6854cd..3715ec26c 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -771,13 +771,13 @@ class Router:
         models = [m.strip() for m in model.split(",")]
         async def _async_completion_no_exceptions(
-            model: str, messages: List[Dict[str, str]], **kwargs: Any
+            model: str, messages: List[Dict[str, str]], stream: bool, **kwargs: Any
         ) -> Union[ModelResponse, CustomStreamWrapper, Exception]:
             """
             Wrapper around self.acompletion that catches exceptions and returns them as a result
             """
             try:
-                return await self.acompletion(model=model, messages=messages, **kwargs)
+                return await self.acompletion(model=model, messages=messages, stream=stream, **kwargs)  # type: ignore
             except asyncio.CancelledError:
                 verbose_router_logger.debug(
                     "Received 'task.cancel'. Cancelling call w/ model={}.".format(model)
                 )
@@ -813,7 +813,7 @@ class Router:
         for model in models:
             task = asyncio.create_task(
                 _async_completion_no_exceptions(
-                    model=model, messages=messages, **kwargs
+                    model=model, messages=messages, stream=stream, **kwargs
                 )
             )
             pending_tasks.append(task)
diff --git a/litellm/tests/test_router_batch_completion.py b/litellm/tests/test_router_batch_completion.py
index c74892814..82fe102e2 100644
--- a/litellm/tests/test_router_batch_completion.py
+++ b/litellm/tests/test_router_batch_completion.py
@@ -114,3 +114,39 @@ async def test_batch_completion_fastest_response_unit_test():
     assert response._hidden_params["model_id"] == "2"
     assert response.choices[0].message.content == "This is a fake response"
     print(f"response: {response}")
+
+
+@pytest.mark.asyncio
+async def test_batch_completion_fastest_response_streaming():
+    litellm.set_verbose = True
+
+    router = litellm.Router(
+        model_list=[
+            {
+                "model_name": "gpt-3.5-turbo",
+                "litellm_params": {
+                    "model": "gpt-3.5-turbo",
+                },
+            },
+            {
+                "model_name": "groq-llama",
+                "litellm_params": {
+                    "model": "groq/llama3-8b-8192",
+                },
+            },
+        ]
+    )
+
+    from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
+
+    response = await router.abatch_completion_fastest_response(
+        model="gpt-3.5-turbo, groq-llama",
+        messages=[
+            {"role": "user", "content": "is litellm becoming a better product ?"}
+        ],
+        max_tokens=15,
+        stream=True,
+    )
+
+    async for chunk in response:
+        ChatCompletionChunk.model_validate(chunk.model_dump(), strict=True)
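
A minimal usage sketch of the streaming path this diff enables (not part of the change itself); the deployment names mirror the test above, and the API keys in the environment are an assumption:

    import asyncio

    import litellm

    async def main() -> None:
        # Assumes OPENAI_API_KEY and GROQ_API_KEY are set in the environment.
        router = litellm.Router(
            model_list=[
                {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "gpt-3.5-turbo"}},
                {"model_name": "groq-llama", "litellm_params": {"model": "groq/llama3-8b-8192"}},
            ]
        )

        # With stream=True now forwarded through to self.acompletion, the
        # fastest deployment's response comes back as a stream of
        # OpenAI-style chunks rather than a single ModelResponse.
        response = await router.abatch_completion_fastest_response(
            model="gpt-3.5-turbo, groq-llama",
            messages=[{"role": "user", "content": "hello"}],
            stream=True,
        )
        async for chunk in response:
            print(chunk.choices[0].delta.content or "", end="")

    asyncio.run(main())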