diff --git a/litellm/router.py b/litellm/router.py index f0d94908e9..7396dab208 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -605,6 +605,33 @@ class Router: self.fail_calls[model_name] += 1 raise e + async def abatch_completion( + self, models: List[str], messages: List[Dict[str, str]], **kwargs + ): + + async def _async_completion_no_exceptions( + model: str, messages: List[Dict[str, str]], **kwargs + ): + """ + Wrapper around self.async_completion that catches exceptions and returns them as a result + """ + try: + return await self.acompletion(model=model, messages=messages, **kwargs) + except Exception as e: + return e + + _tasks = [] + for model in models: + # add each task but if the task fails + _tasks.append( + _async_completion_no_exceptions( + model=model, messages=messages, **kwargs + ) + ) + + response = await asyncio.gather(*_tasks) + return response + def image_generation(self, prompt: str, model: str, **kwargs): try: kwargs["model"] = model