feat(router.py): add load balancing for async transcription calls

This commit is contained in:
Krrish Dholakia 2024-03-08 13:58:15 -08:00
parent 6b1049217e
commit ae54b398d2
2 changed files with 80 additions and 1 deletions

View file

@ -9,7 +9,7 @@
import copy, httpx
from datetime import datetime
from typing import Dict, List, Optional, Union, Literal, Any
from typing import Dict, List, Optional, Union, Literal, Any, BinaryIO
import random, threading, time, traceback, uuid
import litellm, openai
from litellm.caching import RedisCache, InMemoryCache, DualCache
@ -633,6 +633,84 @@ class Router:
self.fail_calls[model_name] += 1
raise e
async def atranscription(self, file: BinaryIO, model: str, **kwargs):
try:
kwargs["model"] = model
kwargs["file"] = file
kwargs["original_function"] = self._atranscription
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
timeout = kwargs.get("request_timeout", self.timeout)
kwargs.setdefault("metadata", {}).update({"model_group": model})
response = await self.async_function_with_fallbacks(**kwargs)
return response
except Exception as e:
raise e
async def _atranscription(self, file: BinaryIO, model: str, **kwargs):
try:
verbose_router_logger.debug(
f"Inside _atranscription()- model: {model}; kwargs: {kwargs}"
)
deployment = self.get_available_deployment(
model=model,
messages=[{"role": "user", "content": "prompt"}],
specific_deployment=kwargs.pop("specific_deployment", None),
)
kwargs.setdefault("metadata", {}).update(
{
"deployment": deployment["litellm_params"]["model"],
"model_info": deployment.get("model_info", {}),
}
)
kwargs["model_info"] = deployment.get("model_info", {})
data = deployment["litellm_params"].copy()
model_name = data["model"]
for k, v in self.default_litellm_params.items():
if (
k not in kwargs
): # prioritize model-specific params > default router params
kwargs[k] = v
elif k == "metadata":
kwargs[k].update(v)
potential_model_client = self._get_client(
deployment=deployment, kwargs=kwargs, client_type="async"
)
# check if provided keys == client keys #
dynamic_api_key = kwargs.get("api_key", None)
if (
dynamic_api_key is not None
and potential_model_client is not None
and dynamic_api_key != potential_model_client.api_key
):
model_client = None
else:
model_client = potential_model_client
self.total_calls[model_name] += 1
response = await litellm.atranscription(
**{
**data,
"file": file,
"caching": self.cache_responses,
"client": model_client,
**kwargs,
}
)
self.success_calls[model_name] += 1
verbose_router_logger.info(
f"litellm.atranscription(model={model_name})\033[32m 200 OK\033[0m"
)
return response
except Exception as e:
verbose_router_logger.info(
f"litellm.atranscription(model={model_name})\033[31m Exception {str(e)}\033[0m"
)
if model_name is not None:
self.fail_calls[model_name] += 1
raise e
async def amoderation(self, model: str, input: str, **kwargs):
try:
kwargs["model"] = model