From ecd182eb6aed5e59ca294c77c4b32ff1bcb9118f Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Tue, 28 May 2024 19:44:41 -0700
Subject: [PATCH] feat(router.py): support fastest response batch completion
 call

returns fastest response. cancels others.
---
 litellm/router.py                             | 73 ++++++++++++++++++-
 litellm/tests/test_router_batch_completion.py | 46 ++++++++----
 2 files changed, 102 insertions(+), 17 deletions(-)

diff --git a/litellm/router.py b/litellm/router.py
index e2ebea37f..631360da6 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -356,7 +356,8 @@ class Router:
                 raise ValueError(f"Item '{fallback_dict}' is not a dictionary.")
             if len(fallback_dict) != 1:
                 raise ValueError(
-                    f"Dictionary '{fallback_dict}' must have exactly one key, but has {len(fallback_dict)} keys.")
+                    f"Dictionary '{fallback_dict}' must have exactly one key, but has {len(fallback_dict)} keys."
+                )
 
     def routing_strategy_init(self, routing_strategy: str, routing_strategy_args: dict):
         if routing_strategy == "least-busy":
@@ -737,6 +738,76 @@ class Router:
         response = await asyncio.gather(*_tasks)
         return response
 
+    # fmt: off
+
+    @overload
+    async def abatch_completion_fastest_response(
+        self, models: List[str], messages: List[Dict[str, str]], stream: Literal[True], **kwargs
+    ) -> CustomStreamWrapper:
+        ...
+
+
+
+    @overload
+    async def abatch_completion_fastest_response(
+        self, models: List[str], messages: List[Dict[str, str]], stream: Literal[False] = False, **kwargs
+    ) -> ModelResponse:
+        ...
+
+    # fmt: on
+
+    async def abatch_completion_fastest_response(
+        self,
+        models: List[str],
+        messages: List[Dict[str, str]],
+        stream: bool = False,
+        **kwargs,
+    ):
+        """Send 1 completion call to many models: Return Fastest Response."""
+
+        async def _async_completion_no_exceptions(
+            model: str, messages: List[Dict[str, str]], **kwargs
+        ):
+            """
+            Wrapper around self.async_completion that catches exceptions and returns them as a result
+            """
+            try:
+                return await self.acompletion(model=model, messages=messages, **kwargs)
+            except Exception as e:
+                return e
+
+        _tasks = []
+        pending_tasks = []  # type: ignore
+
+        async def check_response(task):
+            nonlocal pending_tasks
+            result = await task
+            if isinstance(result, (ModelResponse, CustomStreamWrapper)):
+                # If a desired response is received, cancel all other pending tasks
+                for t in pending_tasks:
+                    t.cancel()
+                return result
+            else:
+                try:
+                    pending_tasks.remove(task)
+                except Exception as e:
+                    pass
+
+        for model in models:
+            task = asyncio.create_task(
+                _async_completion_no_exceptions(
+                    model=model, messages=messages, **kwargs
+                )
+            )
+            task.add_done_callback(check_response)
+            _tasks.append(task)
+            pending_tasks.append(task)
+
+        responses = await asyncio.gather(*_tasks, return_exceptions=True)
+        if isinstance(responses[0], Exception):
+            raise responses[0]
+        return responses[0]  # return first value from list
+
     def image_generation(self, prompt: str, model: str, **kwargs):
         try:
             kwargs["model"] = model
diff --git a/litellm/tests/test_router_batch_completion.py b/litellm/tests/test_router_batch_completion.py
index f2873b18d..219881dcb 100644
--- a/litellm/tests/test_router_batch_completion.py
+++ b/litellm/tests/test_router_batch_completion.py
@@ -19,8 +19,9 @@ import os, httpx
 load_dotenv()
 
 
+@pytest.mark.parametrize("mode", ["all_responses", "fastest_response"])
 @pytest.mark.asyncio
-async def test_batch_completion_multiple_models():
+async def test_batch_completion_multiple_models(mode):
     litellm.set_verbose = True
 
     router = litellm.Router(
@@ -40,21 +41,34 @@ async def test_batch_completion_multiple_models():
         ]
     )
 
-    response = await router.abatch_completion(
-        models=["gpt-3.5-turbo", "groq-llama"],
-        messages=[
-            {"role": "user", "content": "is litellm becoming a better product ?"}
-        ],
-        max_tokens=15,
-    )
+    if mode == "all_responses":
+        response = await router.abatch_completion(
+            models=["gpt-3.5-turbo", "groq-llama"],
+            messages=[
+                {"role": "user", "content": "is litellm becoming a better product ?"}
+            ],
+            max_tokens=15,
+        )
 
-    print(response)
-    assert len(response) == 2
+        print(response)
+        assert len(response) == 2
 
-    models_in_responses = []
-    for individual_response in response:
-        _model = individual_response["model"]
-        models_in_responses.append(_model)
+        models_in_responses = []
+        for individual_response in response:
+            _model = individual_response["model"]
+            models_in_responses.append(_model)
 
-    # assert both models are different
-    assert models_in_responses[0] != models_in_responses[1]
+        # assert both models are different
+        assert models_in_responses[0] != models_in_responses[1]
+    elif mode == "fastest_response":
+        from openai.types.chat.chat_completion import ChatCompletion
+
+        response = await router.abatch_completion_fastest_response(
+            models=["gpt-3.5-turbo", "groq-llama"],
+            messages=[
+                {"role": "user", "content": "is litellm becoming a better product ?"}
+            ],
+            max_tokens=15,
+        )
+
+        ChatCompletion.model_validate(response.model_dump(), strict=True)
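Reviewer note (not part of the commit above): a minimal sketch of how the new Router.abatch_completion_fastest_response API introduced in this patch could be called. The Router configuration below is illustrative only; the deployment names and litellm_params are assumptions, and provider API keys are expected to come from the environment as usual.

    import asyncio
    import litellm

    # Illustrative two-deployment router; model names/params are placeholders.
    router = litellm.Router(
        model_list=[
            {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "gpt-3.5-turbo"}},
            {"model_name": "groq-llama", "litellm_params": {"model": "groq/llama3-8b-8192"}},
        ]
    )

    async def main():
        # Per the commit message: fan one prompt out to every listed model,
        # return the fastest successful response, and cancel the rest.
        response = await router.abatch_completion_fastest_response(
            models=["gpt-3.5-turbo", "groq-llama"],
            messages=[{"role": "user", "content": "is litellm becoming a better product ?"}],
            max_tokens=15,
        )
        print(response)

    asyncio.run(main())

Internally, the per-model calls are wrapped so that exceptions are returned rather than raised, which keeps the asyncio.gather from being short-circuited by a single failing deployment.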