diff --git a/litellm/main.py b/litellm/main.py
index f1a745fccd..a04dba4eca 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -3313,6 +3313,7 @@ def image_generation(
 ##### Transcription #######################


+@client
 async def atranscription(*args, **kwargs):
     """
     Calls openai + azure whisper endpoints.
diff --git a/litellm/router.py b/litellm/router.py
index d4c0be8622..10f7058b3e 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -9,7 +9,7 @@

 import copy, httpx
 from datetime import datetime
-from typing import Dict, List, Optional, Union, Literal, Any
+from typing import Dict, List, Optional, Union, Literal, Any, BinaryIO
 import random, threading, time, traceback, uuid
 import litellm, openai
 from litellm.caching import RedisCache, InMemoryCache, DualCache
@@ -633,6 +633,84 @@ class Router:
                 self.fail_calls[model_name] += 1
             raise e

+    async def atranscription(self, file: BinaryIO, model: str, **kwargs):
+        try:
+            kwargs["model"] = model
+            kwargs["file"] = file
+            kwargs["original_function"] = self._atranscription
+            kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
+            timeout = kwargs.get("request_timeout", self.timeout)
+            kwargs.setdefault("metadata", {}).update({"model_group": model})
+            response = await self.async_function_with_fallbacks(**kwargs)
+
+            return response
+        except Exception as e:
+            raise e
+
+    async def _atranscription(self, file: BinaryIO, model: str, **kwargs):
+        try:
+            verbose_router_logger.debug(
+                f"Inside _atranscription()- model: {model}; kwargs: {kwargs}"
+            )
+            deployment = self.get_available_deployment(
+                model=model,
+                messages=[{"role": "user", "content": "prompt"}],
+                specific_deployment=kwargs.pop("specific_deployment", None),
+            )
+            kwargs.setdefault("metadata", {}).update(
+                {
+                    "deployment": deployment["litellm_params"]["model"],
+                    "model_info": deployment.get("model_info", {}),
+                }
+            )
+            kwargs["model_info"] = deployment.get("model_info", {})
+            data = deployment["litellm_params"].copy()
+            model_name = data["model"]
+            for k, v in self.default_litellm_params.items():
+                if (
+                    k not in kwargs
+                ):  # prioritize model-specific params > default router params
+                    kwargs[k] = v
+                elif k == "metadata":
+                    kwargs[k].update(v)
+
+            potential_model_client = self._get_client(
+                deployment=deployment, kwargs=kwargs, client_type="async"
+            )
+            # check if provided keys == client keys #
+            dynamic_api_key = kwargs.get("api_key", None)
+            if (
+                dynamic_api_key is not None
+                and potential_model_client is not None
+                and dynamic_api_key != potential_model_client.api_key
+            ):
+                model_client = None
+            else:
+                model_client = potential_model_client
+
+            self.total_calls[model_name] += 1
+            response = await litellm.atranscription(
+                **{
+                    **data,
+                    "file": file,
+                    "caching": self.cache_responses,
+                    "client": model_client,
+                    **kwargs,
+                }
+            )
+            self.success_calls[model_name] += 1
+            verbose_router_logger.info(
+                f"litellm.atranscription(model={model_name})\033[32m 200 OK\033[0m"
+            )
+            return response
+        except Exception as e:
+            verbose_router_logger.info(
+                f"litellm.atranscription(model={model_name})\033[31m Exception {str(e)}\033[0m"
+            )
+            if model_name is not None:
+                self.fail_calls[model_name] += 1
+            raise e
+
     async def amoderation(self, model: str, input: str, **kwargs):
         try:
             kwargs["model"] = model
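
Usage sketch (not part of the patch): the new `Router.atranscription()` entrypoint added above can be exercised as below. This is a minimal illustration only; the `"whisper"` model-group name, the `whisper-1` deployment, the `OPENAI_API_KEY` lookup, and the sample file path are assumptions for the example, not values taken from this diff.

```python
import asyncio
import os

from litellm import Router

# Hypothetical model list: a single "whisper" model group backed by OpenAI's
# whisper-1. The group name and key lookup are illustrative assumptions.
router = Router(
    model_list=[
        {
            "model_name": "whisper",
            "litellm_params": {
                "model": "whisper-1",
                "api_key": os.getenv("OPENAI_API_KEY"),
            },
        }
    ]
)


async def main():
    # Router.atranscription() takes a binary file handle (BinaryIO),
    # matching the new signature in this diff.
    with open("speech.mp3", "rb") as audio_file:
        transcript = await router.atranscription(model="whisper", file=audio_file)
    print(transcript.text)


asyncio.run(main())
```

Because `atranscription()` dispatches through `async_function_with_fallbacks()` and defaults `num_retries` to the router-level setting, transcription calls pick up the router's configured retries and fallback groups the same way chat-completion calls do.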