# +-----------------------------------------------+
# | |
# | Give Feedback / Get Help |
# | https://github.com/BerriAI/litellm/issues/new |
# | |
# +-----------------------------------------------+
#
# Thank you ! We ❤️ you! - Krrish & Ishaan
import copy, httpx
from datetime import datetime
from typing import Dict, List, Optional, Union, Literal, Any, BinaryIO, Tuple, TypedDict
from typing_extensions import overload
import random, threading, time, traceback, uuid
import litellm, openai, hashlib, json
from litellm.caching import RedisCache, InMemoryCache, DualCache
import datetime as datetime_og
import logging, asyncio
import inspect, concurrent
from openai import AsyncOpenAI
from collections import defaultdict
from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
from litellm.router_strategy.lowest_tpm_rpm import LowestTPMLoggingHandler
from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler
from litellm.router_strategy.lowest_cost import LowestCostLoggingHandler
from litellm.router_strategy.lowest_tpm_rpm_v2 import LowestTPMLoggingHandler_v2
from litellm.llms.custom_httpx.azure_dall_e_2 import (
CustomHTTPTransport,
AsyncCustomHTTPTransport,
)
from litellm.utils import (
ModelResponse,
CustomStreamWrapper,
get_utc_datetime,
calculate_max_parallel_requests,
_is_region_eu,
)
import copy
from litellm._logging import verbose_router_logger
import logging
from litellm.types.utils import ModelInfo as ModelMapInfo
from litellm.types.router import (
Deployment,
ModelInfo,
LiteLLM_Params,
RouterErrors,
updateDeployment,
updateLiteLLMParams,
RetryPolicy,
AllowedFailsPolicy,
AlertingConfig,
DeploymentTypedDict,
ModelGroupInfo,
)
from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.azure import get_azure_ad_token_from_oidc
from litellm.types.llms.openai import (
AsyncCursorPage,
Assistant,
Thread,
Attachment,
OpenAIMessage,
Run,
AssistantToolParam,
)
from litellm.scheduler import Scheduler, FlowItem
from typing import Iterable
class Router:
model_names: List = []
cache_responses: Optional[bool] = False
default_cache_time_seconds: int = 1 * 60 * 60 # 1 hour
tenacity = None
leastbusy_logger: Optional[LeastBusyLoggingHandler] = None
lowesttpm_logger: Optional[LowestTPMLoggingHandler] = None
def __init__(
self,
model_list: Optional[List[Union[DeploymentTypedDict, Dict]]] = None,
## CACHING ##
redis_url: Optional[str] = None,
redis_host: Optional[str] = None,
redis_port: Optional[int] = None,
redis_password: Optional[str] = None,
cache_responses: Optional[bool] = False,
cache_kwargs: dict = {}, # additional kwargs to pass to RedisCache (see caching.py)
caching_groups: Optional[
List[tuple]
] = None, # if you want to cache across model groups
client_ttl: int = 3600, # ttl for cached clients - will re-initialize after this time in seconds
## SCHEDULER ##
polling_interval: Optional[float] = None,
## RELIABILITY ##
num_retries: Optional[int] = None,
timeout: Optional[float] = None,
default_litellm_params: Optional[
dict
] = None, # default params for Router.chat.completion.create
default_max_parallel_requests: Optional[int] = None,
set_verbose: bool = False,
debug_level: Literal["DEBUG", "INFO"] = "INFO",
default_fallbacks: Optional[
List[str]
] = None, # generic fallbacks, works across all deployments
fallbacks: List = [],
context_window_fallbacks: List = [],
model_group_alias: Optional[dict] = {},
enable_pre_call_checks: bool = False,
retry_after: int = 0, # min time to wait before retrying a failed request
retry_policy: Optional[
RetryPolicy
] = None, # set custom retries for different exceptions
model_group_retry_policy: Optional[
Dict[str, RetryPolicy]
] = {}, # set custom retry policies based on model group
allowed_fails: Optional[
int
        ] = None,  # Number of times a deployment can fail before being added to cooldown
allowed_fails_policy: Optional[
AllowedFailsPolicy
] = None, # set custom allowed fails policy
cooldown_time: Optional[
float
] = None, # (seconds) time to cooldown a deployment after failure
routing_strategy: Literal[
"simple-shuffle",
"least-busy",
"usage-based-routing",
"latency-based-routing",
"cost-based-routing",
"usage-based-routing-v2",
] = "simple-shuffle",
routing_strategy_args: dict = {}, # just for latency-based routing
semaphore: Optional[asyncio.Semaphore] = None,
alerting_config: Optional[AlertingConfig] = None,
) -> None:
"""
Initialize the Router class with the given parameters for caching, reliability, and routing strategy.
Args:
model_list (Optional[list]): List of models to be used. Defaults to None.
redis_url (Optional[str]): URL of the Redis server. Defaults to None.
redis_host (Optional[str]): Hostname of the Redis server. Defaults to None.
redis_port (Optional[int]): Port of the Redis server. Defaults to None.
redis_password (Optional[str]): Password of the Redis server. Defaults to None.
cache_responses (Optional[bool]): Flag to enable caching of responses. Defaults to False.
cache_kwargs (dict): Additional kwargs to pass to RedisCache. Defaults to {}.
caching_groups (Optional[List[tuple]]): List of model groups for caching across model groups. Defaults to None.
client_ttl (int): Time-to-live for cached clients in seconds. Defaults to 3600.
            polling_interval (Optional[float]): frequency of polling the queue. Only used by '.schedule_acompletion()'. Default is 3ms.
num_retries (Optional[int]): Number of retries for failed requests. Defaults to 2.
timeout (Optional[float]): Timeout for requests. Defaults to None.
default_litellm_params (dict): Default parameters for Router.chat.completion.create. Defaults to {}.
set_verbose (bool): Flag to set verbose mode. Defaults to False.
debug_level (Literal["DEBUG", "INFO"]): Debug level for logging. Defaults to "INFO".
fallbacks (List): List of fallback options. Defaults to [].
context_window_fallbacks (List): List of context window fallback options. Defaults to [].
enable_pre_call_checks (boolean): Filter out deployments which are outside context window limits for a given prompt
model_group_alias (Optional[dict]): Alias for model groups. Defaults to {}.
retry_after (int): Minimum time to wait before retrying a failed request. Defaults to 0.
allowed_fails (Optional[int]): Number of allowed fails before adding to cooldown. Defaults to None.
            cooldown_time (float): Time to cooldown a deployment after failure, in seconds. Defaults to 60.
            routing_strategy (Literal["simple-shuffle", "least-busy", "usage-based-routing", "latency-based-routing", "cost-based-routing", "usage-based-routing-v2"]): Routing strategy. Defaults to "simple-shuffle".
routing_strategy_args (dict): Additional args for latency-based routing. Defaults to {}.
alerting_config (AlertingConfig): Slack alerting configuration. Defaults to None.
Returns:
Router: An instance of the litellm.Router class.
Example Usage:
```python
from litellm import Router
model_list = [
{
"model_name": "azure-gpt-3.5-turbo", # model alias
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/<your-deployment-name-1>",
"api_key": <your-api-key>,
"api_version": <your-api-version>,
"api_base": <your-api-base>
},
},
{
"model_name": "azure-gpt-3.5-turbo", # model alias
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/<your-deployment-name-2>",
"api_key": <your-api-key>,
"api_version": <your-api-version>,
"api_base": <your-api-base>
},
},
{
"model_name": "openai-gpt-3.5-turbo", # model alias
"litellm_params": { # params for litellm completion/embedding call
"model": "gpt-3.5-turbo",
"api_key": <your-api-key>,
            },
        },
        ]
router = Router(model_list=model_list, fallbacks=[{"azure-gpt-3.5-turbo": "openai-gpt-3.5-turbo"}])
```
"""
if semaphore:
self.semaphore = semaphore
self.set_verbose = set_verbose
self.debug_level = debug_level
self.enable_pre_call_checks = enable_pre_call_checks
if self.set_verbose == True:
if debug_level == "INFO":
verbose_router_logger.setLevel(logging.INFO)
elif debug_level == "DEBUG":
verbose_router_logger.setLevel(logging.DEBUG)
self.deployment_names: List = (
[]
) # names of models under litellm_params. ex. azure/chatgpt-v-2
self.deployment_latency_map = {}
### SCHEDULER ###
self.scheduler = Scheduler(polling_interval=polling_interval)
### CACHING ###
cache_type: Literal["local", "redis"] = "local" # default to an in-memory cache
redis_cache = None
cache_config = {}
self.client_ttl = client_ttl
if redis_url is not None or (
redis_host is not None
and redis_port is not None
and redis_password is not None
):
cache_type = "redis"
if redis_url is not None:
cache_config["url"] = redis_url
if redis_host is not None:
cache_config["host"] = redis_host
if redis_port is not None:
cache_config["port"] = str(redis_port) # type: ignore
if redis_password is not None:
cache_config["password"] = redis_password
# Add additional key-value pairs from cache_kwargs
cache_config.update(cache_kwargs)
redis_cache = RedisCache(**cache_config)
if cache_responses:
if litellm.cache is None:
# the cache can be initialized on the proxy server. We should not overwrite it
litellm.cache = litellm.Cache(type=cache_type, **cache_config) # type: ignore
self.cache_responses = cache_responses
self.cache = DualCache(
redis_cache=redis_cache, in_memory_cache=InMemoryCache()
) # use a dual cache (Redis+In-Memory) for tracking cooldowns, usage, etc.
self.default_deployment = None # use this to track the users default deployment, when they want to use model = *
self.default_max_parallel_requests = default_max_parallel_requests
if model_list is not None:
model_list = copy.deepcopy(model_list)
self.set_model_list(model_list)
self.healthy_deployments: List = self.model_list # type: ignore
for m in model_list:
self.deployment_latency_map[m["litellm_params"]["model"]] = 0
else:
self.model_list: List = (
[]
) # initialize an empty list - to allow _add_deployment and delete_deployment to work
self.allowed_fails = allowed_fails or litellm.allowed_fails
self.cooldown_time = cooldown_time or 60
self.failed_calls = (
InMemoryCache()
        )  # cache to track failed calls per deployment; if num failed calls within 1 minute > allowed_fails, add the deployment to cooldown
if num_retries is not None:
self.num_retries = num_retries
elif litellm.num_retries is not None:
self.num_retries = litellm.num_retries
else:
self.num_retries = openai.DEFAULT_MAX_RETRIES
self.timeout = timeout or litellm.request_timeout
self.retry_after = retry_after
self.routing_strategy = routing_strategy
## SETTING FALLBACKS ##
### validate if it's set + in correct format
_fallbacks = fallbacks or litellm.fallbacks
self.validate_fallbacks(fallback_param=_fallbacks)
### set fallbacks
self.fallbacks = _fallbacks
if default_fallbacks is not None or litellm.default_fallbacks is not None:
_fallbacks = default_fallbacks or litellm.default_fallbacks
if self.fallbacks is not None:
self.fallbacks.append({"*": _fallbacks})
else:
self.fallbacks = [{"*": _fallbacks}]
self.context_window_fallbacks = (
context_window_fallbacks or litellm.context_window_fallbacks
)
self.total_calls: defaultdict = defaultdict(
int
) # dict to store total calls made to each model
self.fail_calls: defaultdict = defaultdict(
int
) # dict to store fail_calls made to each model
self.success_calls: defaultdict = defaultdict(
int
) # dict to store success_calls made to each model
self.previous_models: List = (
[]
) # list to store failed calls (passed in as metadata to next call)
self.model_group_alias: dict = (
model_group_alias or {}
) # dict to store aliases for router, ex. {"gpt-4": "gpt-3.5-turbo"}, all requests with gpt-4 -> get routed to gpt-3.5-turbo group
# make Router.chat.completions.create compatible for openai.chat.completions.create
default_litellm_params = default_litellm_params or {}
self.chat = litellm.Chat(params=default_litellm_params, router_obj=self)
# default litellm args
self.default_litellm_params = default_litellm_params
self.default_litellm_params.setdefault("timeout", timeout)
self.default_litellm_params.setdefault("max_retries", 0)
self.default_litellm_params.setdefault("metadata", {}).update(
{"caching_groups": caching_groups}
)
self.deployment_stats: dict = {} # used for debugging load balancing
"""
deployment_stats = {
"122999-2828282-277:
{
"model": "gpt-3",
"api_base": "http://localhost:4000",
"num_requests": 20,
"avg_latency": 0.001,
"num_failures": 0,
"num_successes": 20
}
}
"""
### ROUTING SETUP ###
self.routing_strategy_init(
routing_strategy=routing_strategy,
routing_strategy_args=routing_strategy_args,
)
## COOLDOWNS ##
if isinstance(litellm.failure_callback, list):
litellm.failure_callback.append(self.deployment_callback_on_failure)
else:
litellm.failure_callback = [self.deployment_callback_on_failure]
print( # noqa
f"Intialized router with Routing strategy: {self.routing_strategy}\n\nRouting fallbacks: {self.fallbacks}\n\nRouting context window fallbacks: {self.context_window_fallbacks}\n\nRouter Redis Caching={self.cache.redis_cache}"
) # noqa
self.routing_strategy_args = routing_strategy_args
self.retry_policy: Optional[RetryPolicy] = retry_policy
self.model_group_retry_policy: Optional[Dict[str, RetryPolicy]] = (
model_group_retry_policy
)
self.allowed_fails_policy: Optional[AllowedFailsPolicy] = allowed_fails_policy
self.alerting_config: Optional[AlertingConfig] = alerting_config
if self.alerting_config is not None:
self._initialize_alerting()
def validate_fallbacks(self, fallback_param: Optional[List]):
if fallback_param is None:
return
for fallback_dict in fallback_param:
if not isinstance(fallback_dict, dict):
raise ValueError(f"Item '{fallback_dict}' is not a dictionary.")
if len(fallback_dict) != 1:
raise ValueError(
f"Dictionary '{fallback_dict}' must have exactly one key, but has {len(fallback_dict)} keys."
)
def routing_strategy_init(self, routing_strategy: str, routing_strategy_args: dict):
if routing_strategy == "least-busy":
self.leastbusy_logger = LeastBusyLoggingHandler(
router_cache=self.cache, model_list=self.model_list
)
## add callback
if isinstance(litellm.input_callback, list):
litellm.input_callback.append(self.leastbusy_logger) # type: ignore
else:
litellm.input_callback = [self.leastbusy_logger] # type: ignore
if isinstance(litellm.callbacks, list):
litellm.callbacks.append(self.leastbusy_logger) # type: ignore
elif routing_strategy == "usage-based-routing":
self.lowesttpm_logger = LowestTPMLoggingHandler(
router_cache=self.cache,
model_list=self.model_list,
routing_args=routing_strategy_args,
)
if isinstance(litellm.callbacks, list):
litellm.callbacks.append(self.lowesttpm_logger) # type: ignore
elif routing_strategy == "usage-based-routing-v2":
self.lowesttpm_logger_v2 = LowestTPMLoggingHandler_v2(
router_cache=self.cache,
model_list=self.model_list,
routing_args=routing_strategy_args,
)
if isinstance(litellm.callbacks, list):
litellm.callbacks.append(self.lowesttpm_logger_v2) # type: ignore
elif routing_strategy == "latency-based-routing":
self.lowestlatency_logger = LowestLatencyLoggingHandler(
router_cache=self.cache,
model_list=self.model_list,
routing_args=routing_strategy_args,
)
if isinstance(litellm.callbacks, list):
litellm.callbacks.append(self.lowestlatency_logger) # type: ignore
elif routing_strategy == "cost-based-routing":
self.lowestcost_logger = LowestCostLoggingHandler(
router_cache=self.cache,
model_list=self.model_list,
routing_args={},
)
if isinstance(litellm.callbacks, list):
litellm.callbacks.append(self.lowestcost_logger) # type: ignore
def print_deployment(self, deployment: dict):
"""
returns a copy of the deployment with the api key masked
"""
try:
_deployment_copy = copy.deepcopy(deployment)
litellm_params: dict = _deployment_copy["litellm_params"]
if "api_key" in litellm_params:
litellm_params["api_key"] = litellm_params["api_key"][:2] + "*" * 10
return _deployment_copy
except Exception as e:
verbose_router_logger.debug(
f"Error occurred while printing deployment - {str(e)}"
)
raise e
### COMPLETION, EMBEDDING, IMG GENERATION FUNCTIONS
def completion(
self, model: str, messages: List[Dict[str, str]], **kwargs
) -> Union[ModelResponse, CustomStreamWrapper]:
"""
Example usage:
        response = router.completion(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hey, how's it going?"}])
"""
try:
verbose_router_logger.debug(f"router.completion(model={model},..)")
kwargs["model"] = model
kwargs["messages"] = messages
kwargs["original_function"] = self._completion
timeout = kwargs.get("request_timeout", self.timeout)
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
kwargs.setdefault("metadata", {}).update({"model_group": model})
response = self.function_with_fallbacks(**kwargs)
return response
except Exception as e:
raise e
def _completion(
self, model: str, messages: List[Dict[str, str]], **kwargs
) -> Union[ModelResponse, CustomStreamWrapper]:
model_name = None
try:
# pick the one that is available (lowest TPM/RPM)
deployment = self.get_available_deployment(
model=model,
messages=messages,
specific_deployment=kwargs.pop("specific_deployment", None),
)
kwargs.setdefault("metadata", {}).update(
{
"deployment": deployment["litellm_params"]["model"],
"api_base": deployment.get("litellm_params", {}).get("api_base"),
"model_info": deployment.get("model_info", {}),
}
)
data = deployment["litellm_params"].copy()
kwargs["model_info"] = deployment.get("model_info", {})
model_name = data["model"]
for k, v in self.default_litellm_params.items():
if (
k not in kwargs
): # prioritize model-specific params > default router params
kwargs[k] = v
elif k == "metadata":
kwargs[k].update(v)
potential_model_client = self._get_client(
deployment=deployment, kwargs=kwargs
)
# check if provided keys == client keys #
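            # if a per-request api_key was passed and it differs from the cached client's key, drop the cached client so litellm creates a fresh one with the provided key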
dynamic_api_key = kwargs.get("api_key", None)
if (
dynamic_api_key is not None
and potential_model_client is not None
and dynamic_api_key != potential_model_client.api_key
):
model_client = None
else:
model_client = potential_model_client
### DEPLOYMENT-SPECIFIC PRE-CALL CHECKS ### (e.g. update rpm pre-call. Raise error, if deployment over limit)
self.routing_strategy_pre_call_checks(deployment=deployment)
response = litellm.completion(
**{
**data,
"messages": messages,
"caching": self.cache_responses,
"client": model_client,
**kwargs,
}
)
verbose_router_logger.info(
f"litellm.completion(model={model_name})\033[32m 200 OK\033[0m"
)
return response
except Exception as e:
verbose_router_logger.info(
f"litellm.completion(model={model_name})\033[31m Exception {str(e)}\033[0m"
)
raise e
# fmt: off
@overload
async def acompletion(
self, model: str, messages: List[Dict[str, str]], stream: Literal[True], **kwargs
) -> CustomStreamWrapper:
...
@overload
async def acompletion(
self, model: str, messages: List[Dict[str, str]], stream: Literal[False] = False, **kwargs
) -> ModelResponse:
...
@overload
async def acompletion(
self, model: str, messages: List[Dict[str, str]], stream: Union[Literal[True], Literal[False]] = False, **kwargs
) -> Union[CustomStreamWrapper, ModelResponse]:
...
# fmt: on
# The actual implementation of the function
async def acompletion(
self, model: str, messages: List[Dict[str, str]], stream: bool = False, **kwargs
):
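        """
        Async completion - picks a deployment from the 'model' group, with retries and fallbacks.

        Example usage (illustrative; assumes a 'gpt-3.5-turbo' model group is configured on this router):
        ```
        response = await router.acompletion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hey, how's it going?"}],
        )
        ```
        """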
try:
kwargs["model"] = model
kwargs["messages"] = messages
kwargs["stream"] = stream
kwargs["original_function"] = self._acompletion
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
timeout = kwargs.get("request_timeout", self.timeout)
kwargs.setdefault("metadata", {}).update({"model_group": model})
response = await self.async_function_with_fallbacks(**kwargs)
return response
except Exception as e:
raise e
async def _acompletion(
self, model: str, messages: List[Dict[str, str]], **kwargs
) -> Union[ModelResponse, CustomStreamWrapper]:
"""
- Get an available deployment
- call it with a semaphore over the call
        - semaphore specific to its rpm
        - in the semaphore, make a check against its local rpm before running
"""
model_name = None
try:
verbose_router_logger.debug(
f"Inside _acompletion()- model: {model}; kwargs: {kwargs}"
)
deployment = await self.async_get_available_deployment(
model=model,
messages=messages,
specific_deployment=kwargs.pop("specific_deployment", None),
request_kwargs=kwargs,
)
            # debug how often this deployment is picked
self._track_deployment_metrics(deployment=deployment)
kwargs.setdefault("metadata", {}).update(
{
"deployment": deployment["litellm_params"]["model"],
"model_info": deployment.get("model_info", {}),
"api_base": deployment.get("litellm_params", {}).get("api_base"),
}
)
kwargs["model_info"] = deployment.get("model_info", {})
data = deployment["litellm_params"].copy()
model_name = data["model"]
for k, v in self.default_litellm_params.items():
if (
k not in kwargs and v is not None
): # prioritize model-specific params > default router params
kwargs[k] = v
elif k == "metadata":
kwargs[k].update(v)
potential_model_client = self._get_client(
deployment=deployment, kwargs=kwargs, client_type="async"
)
# check if provided keys == client keys #
dynamic_api_key = kwargs.get("api_key", None)
if (
dynamic_api_key is not None
and potential_model_client is not None
and dynamic_api_key != potential_model_client.api_key
):
model_client = None
else:
model_client = potential_model_client
self.total_calls[model_name] += 1
timeout = (
data.get(
"timeout", None
) # timeout set on litellm_params for this deployment
or self.timeout # timeout set on router
or kwargs.get(
"timeout", None
) # this uses default_litellm_params when nothing is set
)
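            # note: litellm.acompletion() is not awaited here - it returns a coroutine that is awaited below, inside the deployment's rpm semaphore when one is configured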
_response = litellm.acompletion(
**{
**data,
"messages": messages,
"caching": self.cache_responses,
"client": model_client,
"timeout": timeout,
**kwargs,
}
)
rpm_semaphore = self._get_client(
deployment=deployment,
kwargs=kwargs,
client_type="max_parallel_requests",
)
if rpm_semaphore is not None and isinstance(
rpm_semaphore, asyncio.Semaphore
):
async with rpm_semaphore:
"""
- Check rpm limits before making the call
- If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
"""
await self.async_routing_strategy_pre_call_checks(
deployment=deployment
)
response = await _response
else:
await self.async_routing_strategy_pre_call_checks(deployment=deployment)
response = await _response
self.success_calls[model_name] += 1
verbose_router_logger.info(
f"litellm.acompletion(model={model_name})\033[32m 200 OK\033[0m"
)
            # debug how often this deployment is picked
self._track_deployment_metrics(deployment=deployment, response=response)
return response
except Exception as e:
verbose_router_logger.info(
f"litellm.acompletion(model={model_name})\033[31m Exception {str(e)}\033[0m"
)
if model_name is not None:
self.fail_calls[model_name] += 1
raise e
async def abatch_completion(
self,
models: List[str],
messages: Union[List[Dict[str, str]], List[List[Dict[str, str]]]],
**kwargs,
):
"""
Async Batch Completion. Used for 2 scenarios:
1. Batch Process 1 request to N models on litellm.Router. Pass messages as List[Dict[str, str]] to use this
2. Batch Process N requests to M models on litellm.Router. Pass messages as List[List[Dict[str, str]]] to use this
Example Request for 1 request to N models:
```
response = await router.abatch_completion(
models=["gpt-3.5-turbo", "groq-llama"],
messages=[
{"role": "user", "content": "is litellm becoming a better product ?"}
],
max_tokens=15,
)
```
Example Request for N requests to M models:
```
response = await router.abatch_completion(
models=["gpt-3.5-turbo", "groq-llama"],
messages=[
[{"role": "user", "content": "is litellm becoming a better product ?"}],
[{"role": "user", "content": "who is this"}],
],
)
```
"""
############## Helpers for async completion ##################
async def _async_completion_no_exceptions(
model: str, messages: List[Dict[str, str]], **kwargs
):
"""
            Wrapper around self.acompletion that catches exceptions and returns them as a result
"""
try:
return await self.acompletion(model=model, messages=messages, **kwargs)
except Exception as e:
return e
async def _async_completion_no_exceptions_return_idx(
model: str,
messages: List[Dict[str, str]],
idx: int, # index of message this response corresponds to
**kwargs,
):
"""
            Wrapper around self.acompletion that catches exceptions and returns them as a result
"""
try:
return (
await self.acompletion(model=model, messages=messages, **kwargs),
idx,
)
except Exception as e:
return e, idx
############## Helpers for async completion ##################
if isinstance(messages, list) and all(isinstance(m, dict) for m in messages):
_tasks = []
for model in models:
                # add each task; if it fails, the exception is returned in the results instead of raising
_tasks.append(_async_completion_no_exceptions(model=model, messages=messages, **kwargs)) # type: ignore
response = await asyncio.gather(*_tasks)
return response
elif isinstance(messages, list) and all(isinstance(m, list) for m in messages):
_tasks = []
for idx, message in enumerate(messages):
for model in models:
# Request Number X, Model Number Y
_tasks.append(
_async_completion_no_exceptions_return_idx(
model=model, idx=idx, messages=message, **kwargs # type: ignore
)
)
responses = await asyncio.gather(*_tasks)
final_responses: List[List[Any]] = [[] for _ in range(len(messages))]
for response in responses:
if isinstance(response, tuple):
final_responses[response[1]].append(response[0])
else:
final_responses[0].append(response)
return final_responses
async def abatch_completion_one_model_multiple_requests(
self, model: str, messages: List[List[Dict[str, str]]], **kwargs
):
"""
Async Batch Completion - Batch Process multiple Messages to one model_group on litellm.Router
Use this for sending multiple requests to 1 model
Args:
            model (str): model group to send the requests to
messages (List[List[Dict[str, str]]]): list of messages. Each element in the list is one request
**kwargs: additional kwargs
Usage:
response = await self.abatch_completion_one_model_multiple_requests(
model="gpt-3.5-turbo",
messages=[
[{"role": "user", "content": "hello"}, {"role": "user", "content": "tell me something funny"}],
[{"role": "user", "content": "hello good mornign"}],
]
)
"""
async def _async_completion_no_exceptions(
model: str, messages: List[Dict[str, str]], **kwargs
):
"""
            Wrapper around self.acompletion that catches exceptions and returns them as a result
"""
try:
return await self.acompletion(model=model, messages=messages, **kwargs)
except Exception as e:
return e
_tasks = []
for message_request in messages:
            # add each task; if it fails, the exception is returned in the results instead of raising
_tasks.append(
_async_completion_no_exceptions(
model=model, messages=message_request, **kwargs
)
)
response = await asyncio.gather(*_tasks)
return response
# fmt: off
@overload
async def abatch_completion_fastest_response(
self, model: str, messages: List[Dict[str, str]], stream: Literal[True], **kwargs
) -> CustomStreamWrapper:
...
@overload
async def abatch_completion_fastest_response(
self, model: str, messages: List[Dict[str, str]], stream: Literal[False] = False, **kwargs
) -> ModelResponse:
...
# fmt: on
async def abatch_completion_fastest_response(
self,
model: str,
messages: List[Dict[str, str]],
stream: bool = False,
**kwargs,
):
"""
        model - comma-separated list of model names, e.g. model="gpt-4, gpt-3.5-turbo"
Returns fastest response from list of model names. OpenAI-compatible endpoint.
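        Example usage (illustrative; assumes 'gpt-4' and 'gpt-3.5-turbo' model groups are configured on this router):
        ```
        response = await router.abatch_completion_fastest_response(
            model="gpt-4, gpt-3.5-turbo",
            messages=[{"role": "user", "content": "ping"}],
        )
        ```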
"""
models = [m.strip() for m in model.split(",")]
async def _async_completion_no_exceptions(
model: str, messages: List[Dict[str, str]], stream: bool, **kwargs: Any
) -> Union[ModelResponse, CustomStreamWrapper, Exception]:
"""
Wrapper around self.acompletion that catches exceptions and returns them as a result
"""
try:
return await self.acompletion(model=model, messages=messages, stream=stream, **kwargs) # type: ignore
except asyncio.CancelledError:
verbose_router_logger.debug(
"Received 'task.cancel'. Cancelling call w/ model={}.".format(model)
)
raise
except Exception as e:
return e
pending_tasks = [] # type: ignore
async def check_response(task: asyncio.Task):
nonlocal pending_tasks
try:
result = await task
if isinstance(result, (ModelResponse, CustomStreamWrapper)):
verbose_router_logger.debug(
"Received successful response. Cancelling other LLM API calls."
)
# If a desired response is received, cancel all other pending tasks
for t in pending_tasks:
t.cancel()
return result
except Exception:
# Ignore exceptions, let the loop handle them
pass
finally:
# Remove the task from pending tasks if it finishes
try:
pending_tasks.remove(task)
                except (KeyError, ValueError):  # pending_tasks may be a list or a set, and the task may already have been pruned
pass
for model in models:
task = asyncio.create_task(
_async_completion_no_exceptions(
model=model, messages=messages, stream=stream, **kwargs
)
)
pending_tasks.append(task)
# Await the first task to complete successfully
while pending_tasks:
done, pending_tasks = await asyncio.wait( # type: ignore
pending_tasks, return_when=asyncio.FIRST_COMPLETED
)
for completed_task in done:
result = await check_response(completed_task)
if result is not None:
# Return the first successful result
result._hidden_params["fastest_response_batch_completion"] = True
return result
# If we exit the loop without returning, all tasks failed
raise Exception("All tasks failed")
### SCHEDULER ###
# fmt: off
@overload
async def schedule_acompletion(
self, model: str, messages: List[Dict[str, str]], priority: int, stream: Literal[False] = False, **kwargs
) -> ModelResponse:
...
@overload
async def schedule_acompletion(
self, model: str, messages: List[Dict[str, str]], priority: int, stream: Literal[True], **kwargs
) -> CustomStreamWrapper:
...
# fmt: on
async def schedule_acompletion(
self,
model: str,
messages: List[Dict[str, str]],
priority: int,
stream=False,
**kwargs,
):
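        """
        Prioritized completion - adds the request to the scheduler queue and only calls the
        deployment once the scheduler polls it off the queue (raises litellm.Timeout if it is
        never scheduled within self.timeout).

        Example usage (illustrative; assumes a 'gpt-3.5-turbo' model group is configured on this router):
        ```
        response = await router.schedule_acompletion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hey!"}],
            priority=0,  # lower number = higher priority
        )
        ```
        """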
### FLOW ITEM ###
_request_id = str(uuid.uuid4())
item = FlowItem(
priority=priority, # 👈 SET PRIORITY FOR REQUEST
request_id=_request_id, # 👈 SET REQUEST ID
model_name="gpt-3.5-turbo", # 👈 SAME as 'Router'
)
### [fin] ###
## ADDS REQUEST TO QUEUE ##
await self.scheduler.add_request(request=item)
## POLL QUEUE
end_time = time.time() + self.timeout
curr_time = time.time()
poll_interval = self.scheduler.polling_interval # poll every 3ms
make_request = False
while curr_time < end_time:
_healthy_deployments = await self._async_get_healthy_deployments(
model=model
)
make_request = await self.scheduler.poll( ## POLL QUEUE ## - returns 'True' if there's healthy deployments OR if request is at top of queue
id=item.request_id,
model_name=item.model_name,
health_deployments=_healthy_deployments,
)
if make_request: ## IF TRUE -> MAKE REQUEST
break
else: ## ELSE -> loop till default_timeout
await asyncio.sleep(poll_interval)
curr_time = time.time()
if make_request:
try:
_response = await self.acompletion(
model=model, messages=messages, stream=stream, **kwargs
)
return _response
except Exception as e:
setattr(e, "priority", priority)
raise e
else:
raise litellm.Timeout(
message="Request timed out while polling queue",
model=model,
llm_provider="openai",
)
def image_generation(self, prompt: str, model: str, **kwargs):
try:
kwargs["model"] = model
kwargs["prompt"] = prompt
kwargs["original_function"] = self._image_generation
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
timeout = kwargs.get("request_timeout", self.timeout)
kwargs.setdefault("metadata", {}).update({"model_group": model})
response = self.function_with_fallbacks(**kwargs)
return response
except Exception as e:
raise e
def _image_generation(self, prompt: str, model: str, **kwargs):
        model_name = None
        try:
verbose_router_logger.debug(
f"Inside _image_generation()- model: {model}; kwargs: {kwargs}"
)
deployment = self.get_available_deployment(
model=model,
messages=[{"role": "user", "content": "prompt"}],
specific_deployment=kwargs.pop("specific_deployment", None),
)
kwargs.setdefault("metadata", {}).update(
{
"deployment": deployment["litellm_params"]["model"],
"model_info": deployment.get("model_info", {}),
}
)
kwargs["model_info"] = deployment.get("model_info", {})
data = deployment["litellm_params"].copy()
model_name = data["model"]
for k, v in self.default_litellm_params.items():
if (
k not in kwargs
): # prioritize model-specific params > default router params
kwargs[k] = v
elif k == "metadata":
kwargs[k].update(v)
potential_model_client = self._get_client(
deployment=deployment, kwargs=kwargs, client_type="async"
)
# check if provided keys == client keys #
dynamic_api_key = kwargs.get("api_key", None)
if (
dynamic_api_key is not None
and potential_model_client is not None
and dynamic_api_key != potential_model_client.api_key
):
model_client = None
else:
model_client = potential_model_client
self.total_calls[model_name] += 1
### DEPLOYMENT-SPECIFIC PRE-CALL CHECKS ### (e.g. update rpm pre-call. Raise error, if deployment over limit)
self.routing_strategy_pre_call_checks(deployment=deployment)
response = litellm.image_generation(
**{
**data,
"prompt": prompt,
"caching": self.cache_responses,
"client": model_client,
**kwargs,
}
)
self.success_calls[model_name] += 1
verbose_router_logger.info(
f"litellm.image_generation(model={model_name})\033[32m 200 OK\033[0m"
)
return response
except Exception as e:
verbose_router_logger.info(
f"litellm.image_generation(model={model_name})\033[31m Exception {str(e)}\033[0m"
)
if model_name is not None:
self.fail_calls[model_name] += 1
raise e
async def aimage_generation(self, prompt: str, model: str, **kwargs):
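        """
        Async image generation with load balancing, retries and fallbacks.

        Example usage (illustrative; assumes a 'dall-e-3' model group is configured on this router):
        ```
        image_response = await router.aimage_generation(
            model="dall-e-3",
            prompt="a sunset over the ocean",
        )
        ```
        """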
try:
kwargs["model"] = model
kwargs["prompt"] = prompt
kwargs["original_function"] = self._aimage_generation
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
timeout = kwargs.get("request_timeout", self.timeout)
kwargs.setdefault("metadata", {}).update({"model_group": model})
response = await self.async_function_with_fallbacks(**kwargs)
return response
except Exception as e:
raise e
async def _aimage_generation(self, prompt: str, model: str, **kwargs):
        model_name = None
        try:
verbose_router_logger.debug(
f"Inside _image_generation()- model: {model}; kwargs: {kwargs}"
)
deployment = await self.async_get_available_deployment(
model=model,
messages=[{"role": "user", "content": "prompt"}],
specific_deployment=kwargs.pop("specific_deployment", None),
)
kwargs.setdefault("metadata", {}).update(
{
"deployment": deployment["litellm_params"]["model"],
"model_info": deployment.get("model_info", {}),
}
)
kwargs["model_info"] = deployment.get("model_info", {})
data = deployment["litellm_params"].copy()
model_name = data["model"]
for k, v in self.default_litellm_params.items():
if (
k not in kwargs
): # prioritize model-specific params > default router params
kwargs[k] = v
elif k == "metadata":
kwargs[k].update(v)
potential_model_client = self._get_client(
deployment=deployment, kwargs=kwargs, client_type="async"
)
# check if provided keys == client keys #
dynamic_api_key = kwargs.get("api_key", None)
if (
dynamic_api_key is not None
and potential_model_client is not None
and dynamic_api_key != potential_model_client.api_key
):
model_client = None
else:
model_client = potential_model_client
self.total_calls[model_name] += 1
response = litellm.aimage_generation(
**{
**data,
"prompt": prompt,
"caching": self.cache_responses,
"client": model_client,
**kwargs,
}
)
### CONCURRENCY-SAFE RPM CHECKS ###
rpm_semaphore = self._get_client(
deployment=deployment,
kwargs=kwargs,
client_type="max_parallel_requests",
)
if rpm_semaphore is not None and isinstance(
rpm_semaphore, asyncio.Semaphore
):
async with rpm_semaphore:
"""
- Check rpm limits before making the call
- If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
"""
await self.async_routing_strategy_pre_call_checks(
deployment=deployment
)
response = await response
else:
await self.async_routing_strategy_pre_call_checks(deployment=deployment)
response = await response
self.success_calls[model_name] += 1
verbose_router_logger.info(
f"litellm.aimage_generation(model={model_name})\033[32m 200 OK\033[0m"
)
return response
except Exception as e:
verbose_router_logger.info(
f"litellm.aimage_generation(model={model_name})\033[31m Exception {str(e)}\033[0m"
)
if model_name is not None:
self.fail_calls[model_name] += 1
raise e
async def atranscription(self, file: BinaryIO, model: str, **kwargs):
"""
Example Usage:
```
from litellm import Router
client = Router(model_list = [
{
"model_name": "whisper",
"litellm_params": {
"model": "whisper-1",
},
},
])
audio_file = open("speech.mp3", "rb")
transcript = await client.atranscription(
model="whisper",
file=audio_file
)
```
"""
try:
kwargs["model"] = model
kwargs["file"] = file
kwargs["original_function"] = self._atranscription
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
timeout = kwargs.get("request_timeout", self.timeout)
kwargs.setdefault("metadata", {}).update({"model_group": model})
response = await self.async_function_with_fallbacks(**kwargs)
return response
except Exception as e:
raise e
async def _atranscription(self, file: BinaryIO, model: str, **kwargs):
        model_name = None
        try:
verbose_router_logger.debug(
f"Inside _atranscription()- model: {model}; kwargs: {kwargs}"
)
deployment = await self.async_get_available_deployment(
model=model,
messages=[{"role": "user", "content": "prompt"}],
specific_deployment=kwargs.pop("specific_deployment", None),
)
kwargs.setdefault("metadata", {}).update(
{
"deployment": deployment["litellm_params"]["model"],
"model_info": deployment.get("model_info", {}),
}
)
kwargs["model_info"] = deployment.get("model_info", {})
data = deployment["litellm_params"].copy()
model_name = data["model"]
for k, v in self.default_litellm_params.items():
if (
k not in kwargs
): # prioritize model-specific params > default router params
kwargs[k] = v
elif k == "metadata":
kwargs[k].update(v)
potential_model_client = self._get_client(
deployment=deployment, kwargs=kwargs, client_type="async"
)
# check if provided keys == client keys #
dynamic_api_key = kwargs.get("api_key", None)
if (
dynamic_api_key is not None
and potential_model_client is not None
and dynamic_api_key != potential_model_client.api_key
):
model_client = None
else:
model_client = potential_model_client
self.total_calls[model_name] += 1
response = litellm.atranscription(
**{
**data,
"file": file,
"caching": self.cache_responses,
"client": model_client,
**kwargs,
}
)
### CONCURRENCY-SAFE RPM CHECKS ###
rpm_semaphore = self._get_client(
deployment=deployment,
kwargs=kwargs,
client_type="max_parallel_requests",
)
if rpm_semaphore is not None and isinstance(
rpm_semaphore, asyncio.Semaphore
):
async with rpm_semaphore:
"""
- Check rpm limits before making the call
- If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
"""
await self.async_routing_strategy_pre_call_checks(
deployment=deployment
)
response = await response
else:
await self.async_routing_strategy_pre_call_checks(deployment=deployment)
response = await response
self.success_calls[model_name] += 1
verbose_router_logger.info(
f"litellm.atranscription(model={model_name})\033[32m 200 OK\033[0m"
)
return response
except Exception as e:
verbose_router_logger.info(
f"litellm.atranscription(model={model_name})\033[31m Exception {str(e)}\033[0m"
)
if model_name is not None:
self.fail_calls[model_name] += 1
raise e
async def aspeech(self, model: str, input: str, voice: str, **kwargs):
"""
Example Usage:
```
from litellm import Router
client = Router(model_list = [
{
"model_name": "tts",
"litellm_params": {
"model": "tts-1",
},
},
])
async with client.aspeech(
model="tts",
voice="alloy",
input="the quick brown fox jumped over the lazy dogs",
api_base=None,
api_key=None,
organization=None,
project=None,
max_retries=1,
timeout=600,
client=None,
optional_params={},
) as response:
response.stream_to_file(speech_file_path)
```
"""
try:
kwargs["input"] = input
kwargs["voice"] = voice
deployment = await self.async_get_available_deployment(
model=model,
messages=[{"role": "user", "content": "prompt"}],
specific_deployment=kwargs.pop("specific_deployment", None),
)
kwargs.setdefault("metadata", {}).update(
{
"deployment": deployment["litellm_params"]["model"],
"model_info": deployment.get("model_info", {}),
}
)
kwargs["model_info"] = deployment.get("model_info", {})
data = deployment["litellm_params"].copy()
model_name = data["model"]
for k, v in self.default_litellm_params.items():
if (
k not in kwargs
): # prioritize model-specific params > default router params
kwargs[k] = v
elif k == "metadata":
kwargs[k].update(v)
potential_model_client = self._get_client(
deployment=deployment, kwargs=kwargs, client_type="async"
)
# check if provided keys == client keys #
dynamic_api_key = kwargs.get("api_key", None)
if (
dynamic_api_key is not None
and potential_model_client is not None
and dynamic_api_key != potential_model_client.api_key
):
model_client = None
else:
model_client = potential_model_client
response = await litellm.aspeech(**data, **kwargs)
return response
except Exception as e:
raise e
async def amoderation(self, model: str, input: str, **kwargs):
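        """
        Async moderation with load balancing, retries and fallbacks.

        Example usage (illustrative; assumes a moderation model group, e.g. 'text-moderation-stable', is configured on this router):
        ```
        response = await router.amoderation(
            model="text-moderation-stable",
            input="hello world",
        )
        ```
        """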
try:
kwargs["model"] = model
kwargs["input"] = input
kwargs["original_function"] = self._amoderation
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
timeout = kwargs.get("request_timeout", self.timeout)
kwargs.setdefault("metadata", {}).update({"model_group": model})
response = await self.async_function_with_fallbacks(**kwargs)
return response
except Exception as e:
raise e
async def _amoderation(self, model: str, input: str, **kwargs):
model_name = None
try:
verbose_router_logger.debug(
f"Inside _moderation()- model: {model}; kwargs: {kwargs}"
)
deployment = await self.async_get_available_deployment(
model=model,
input=input,
specific_deployment=kwargs.pop("specific_deployment", None),
)
kwargs.setdefault("metadata", {}).update(
{
"deployment": deployment["litellm_params"]["model"],
"model_info": deployment.get("model_info", {}),
}
)
kwargs["model_info"] = deployment.get("model_info", {})
data = deployment["litellm_params"].copy()
model_name = data["model"]
for k, v in self.default_litellm_params.items():
if (
k not in kwargs and v is not None
): # prioritize model-specific params > default router params
kwargs[k] = v
elif k == "metadata":
kwargs[k].update(v)
potential_model_client = self._get_client(
deployment=deployment, kwargs=kwargs, client_type="async"
)
# check if provided keys == client keys #
dynamic_api_key = kwargs.get("api_key", None)
if (
dynamic_api_key is not None
and potential_model_client is not None
and dynamic_api_key != potential_model_client.api_key
):
model_client = None
else:
model_client = potential_model_client
self.total_calls[model_name] += 1
timeout = (
data.get(
"timeout", None
) # timeout set on litellm_params for this deployment
or self.timeout # timeout set on router
or kwargs.get(
"timeout", None
) # this uses default_litellm_params when nothing is set
)
response = await litellm.amoderation(
**{
**data,
"input": input,
"caching": self.cache_responses,
"client": model_client,
"timeout": timeout,
**kwargs,
}
)
self.success_calls[model_name] += 1
verbose_router_logger.info(
f"litellm.amoderation(model={model_name})\033[32m 200 OK\033[0m"
)
return response
except Exception as e:
verbose_router_logger.info(
f"litellm.amoderation(model={model_name})\033[31m Exception {str(e)}\033[0m"
)
if model_name is not None:
self.fail_calls[model_name] += 1
raise e
def text_completion(
self,
model: str,
prompt: str,
is_retry: Optional[bool] = False,
is_fallback: Optional[bool] = False,
is_async: Optional[bool] = False,
**kwargs,
):
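        """
        Synchronous text-completion. The prompt is wrapped as a single user message when
        picking an available deployment, then litellm.text_completion is called with it.
        """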
try:
kwargs["model"] = model
kwargs["prompt"] = prompt
kwargs["original_function"] = self._acompletion
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
timeout = kwargs.get("request_timeout", self.timeout)
kwargs.setdefault("metadata", {}).update({"model_group": model})
messages = [{"role": "user", "content": prompt}]
# pick the one that is available (lowest TPM/RPM)
deployment = self.get_available_deployment(
model=model,
messages=messages,
specific_deployment=kwargs.pop("specific_deployment", None),
)
data = deployment["litellm_params"].copy()
for k, v in self.default_litellm_params.items():
if (
k not in kwargs
): # prioritize model-specific params > default router params
kwargs[k] = v
elif k == "metadata":
kwargs[k].update(v)
# call via litellm.completion()
return litellm.text_completion(**{**data, "prompt": prompt, "caching": self.cache_responses, **kwargs}) # type: ignore
except Exception as e:
if self.num_retries > 0:
kwargs["model"] = model
kwargs["messages"] = messages
kwargs["original_function"] = self.completion
return self.function_with_retries(**kwargs)
else:
raise e
async def atext_completion(
self,
model: str,
prompt: str,
is_retry: Optional[bool] = False,
is_fallback: Optional[bool] = False,
is_async: Optional[bool] = False,
**kwargs,
):
try:
kwargs["model"] = model
kwargs["prompt"] = prompt
kwargs["original_function"] = self._atext_completion
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
timeout = kwargs.get("request_timeout", self.timeout)
kwargs.setdefault("metadata", {}).update({"model_group": model})
response = await self.async_function_with_fallbacks(**kwargs)
return response
except Exception as e:
raise e
async def _atext_completion(self, model: str, prompt: str, **kwargs):
try:
verbose_router_logger.debug(
f"Inside _atext_completion()- model: {model}; kwargs: {kwargs}"
)
deployment = await self.async_get_available_deployment(
model=model,
messages=[{"role": "user", "content": prompt}],
specific_deployment=kwargs.pop("specific_deployment", None),
)
kwargs.setdefault("metadata", {}).update(
{
"deployment": deployment["litellm_params"]["model"],
"model_info": deployment.get("model_info", {}),
"api_base": deployment.get("litellm_params", {}).get("api_base"),
}
)
kwargs["model_info"] = deployment.get("model_info", {})
data = deployment["litellm_params"].copy()
model_name = data["model"]
for k, v in self.default_litellm_params.items():
if (
k not in kwargs
): # prioritize model-specific params > default router params
kwargs[k] = v
elif k == "metadata":
kwargs[k].update(v)
potential_model_client = self._get_client(
deployment=deployment, kwargs=kwargs, client_type="async"
)
# check if provided keys == client keys #
dynamic_api_key = kwargs.get("api_key", None)
if (
dynamic_api_key is not None
and potential_model_client is not None
and dynamic_api_key != potential_model_client.api_key
):
model_client = None
else:
model_client = potential_model_client
self.total_calls[model_name] += 1
response = litellm.atext_completion(
**{
**data,
"prompt": prompt,
"caching": self.cache_responses,
"client": model_client,
"timeout": self.timeout,
**kwargs,
}
)
rpm_semaphore = self._get_client(
deployment=deployment,
kwargs=kwargs,
client_type="max_parallel_requests",
)
if rpm_semaphore is not None and isinstance(
rpm_semaphore, asyncio.Semaphore
):
async with rpm_semaphore:
"""
- Check rpm limits before making the call
- If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
"""
await self.async_routing_strategy_pre_call_checks(
deployment=deployment
)
response = await response
else:
await self.async_routing_strategy_pre_call_checks(deployment=deployment)
response = await response
self.success_calls[model_name] += 1
verbose_router_logger.info(
f"litellm.atext_completion(model={model_name})\033[32m 200 OK\033[0m"
)
return response
except Exception as e:
verbose_router_logger.info(
f"litellm.atext_completion(model={model})\033[31m Exception {str(e)}\033[0m"
)
if model is not None:
self.fail_calls[model] += 1
raise e
def embedding(
self,
model: str,
input: Union[str, List],
is_async: Optional[bool] = False,
**kwargs,
) -> Union[List[float], None]:
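        """
        Synchronous embedding with load balancing, retries and fallbacks. See `aembedding` for the async version.
        """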
try:
kwargs["model"] = model
kwargs["input"] = input
kwargs["original_function"] = self._embedding
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
timeout = kwargs.get("request_timeout", self.timeout)
kwargs.setdefault("metadata", {}).update({"model_group": model})
response = self.function_with_fallbacks(**kwargs)
return response
except Exception as e:
raise e
def _embedding(self, input: Union[str, List], model: str, **kwargs):
        model_name = None
        try:
verbose_router_logger.debug(
f"Inside embedding()- model: {model}; kwargs: {kwargs}"
)
deployment = self.get_available_deployment(
model=model,
input=input,
specific_deployment=kwargs.pop("specific_deployment", None),
)
kwargs.setdefault("metadata", {}).update(
{
"deployment": deployment["litellm_params"]["model"],
"model_info": deployment.get("model_info", {}),
}
)
kwargs["model_info"] = deployment.get("model_info", {})
data = deployment["litellm_params"].copy()
model_name = data["model"]
for k, v in self.default_litellm_params.items():
if (
k not in kwargs
): # prioritize model-specific params > default router params
kwargs[k] = v
elif k == "metadata":
kwargs[k].update(v)
potential_model_client = self._get_client(
deployment=deployment, kwargs=kwargs, client_type="sync"
)
# check if provided keys == client keys #
dynamic_api_key = kwargs.get("api_key", None)
if (
dynamic_api_key is not None
and potential_model_client is not None
and dynamic_api_key != potential_model_client.api_key
):
model_client = None
else:
model_client = potential_model_client
self.total_calls[model_name] += 1
### DEPLOYMENT-SPECIFIC PRE-CALL CHECKS ### (e.g. update rpm pre-call. Raise error, if deployment over limit)
self.routing_strategy_pre_call_checks(deployment=deployment)
response = litellm.embedding(
**{
**data,
"input": input,
"caching": self.cache_responses,
"client": model_client,
**kwargs,
}
)
self.success_calls[model_name] += 1
verbose_router_logger.info(
f"litellm.embedding(model={model_name})\033[32m 200 OK\033[0m"
)
return response
except Exception as e:
verbose_router_logger.info(
f"litellm.embedding(model={model_name})\033[31m Exception {str(e)}\033[0m"
)
if model_name is not None:
self.fail_calls[model_name] += 1
raise e
async def aembedding(
self,
model: str,
input: Union[str, List],
is_async: Optional[bool] = True,
**kwargs,
) -> Union[List[float], None]:
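        """
        Async embedding with load balancing, retries and fallbacks.

        Example usage (illustrative; assumes a 'text-embedding-ada-002' model group is configured on this router):
        ```
        response = await router.aembedding(model="text-embedding-ada-002", input=["hello world"])
        ```
        """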
try:
kwargs["model"] = model
kwargs["input"] = input
kwargs["original_function"] = self._aembedding
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
timeout = kwargs.get("request_timeout", self.timeout)
kwargs.setdefault("metadata", {}).update({"model_group": model})
response = await self.async_function_with_fallbacks(**kwargs)
return response
except Exception as e:
raise e
async def _aembedding(self, input: Union[str, List], model: str, **kwargs):
model_name = None
try:
verbose_router_logger.debug(
f"Inside _aembedding()- model: {model}; kwargs: {kwargs}"
)
deployment = await self.async_get_available_deployment(
model=model,
input=input,
specific_deployment=kwargs.pop("specific_deployment", None),
)
kwargs.setdefault("metadata", {}).update(
{
"deployment": deployment["litellm_params"]["model"],
"model_info": deployment.get("model_info", {}),
"api_base": deployment.get("litellm_params", {}).get("api_base"),
}
)
kwargs["model_info"] = deployment.get("model_info", {})
data = deployment["litellm_params"].copy()
model_name = data["model"]
for k, v in self.default_litellm_params.items():
if (
k not in kwargs
): # prioritize model-specific params > default router params
kwargs[k] = v
elif k == "metadata":
kwargs[k].update(v)
potential_model_client = self._get_client(
deployment=deployment, kwargs=kwargs, client_type="async"
)
# check if provided keys == client keys #
dynamic_api_key = kwargs.get("api_key", None)
if (
dynamic_api_key is not None
and potential_model_client is not None
and dynamic_api_key != potential_model_client.api_key
):
model_client = None
else:
model_client = potential_model_client
self.total_calls[model_name] += 1
response = litellm.aembedding(
**{
**data,
"input": input,
"caching": self.cache_responses,
"client": model_client,
**kwargs,
}
)
### CONCURRENCY-SAFE RPM CHECKS ###
rpm_semaphore = self._get_client(
deployment=deployment,
kwargs=kwargs,
client_type="max_parallel_requests",
)
if rpm_semaphore is not None and isinstance(
rpm_semaphore, asyncio.Semaphore
):
async with rpm_semaphore:
"""
- Check rpm limits before making the call
- If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
"""
await self.async_routing_strategy_pre_call_checks(
deployment=deployment
)
response = await response
else:
await self.async_routing_strategy_pre_call_checks(deployment=deployment)
response = await response
self.success_calls[model_name] += 1
verbose_router_logger.info(
f"litellm.aembedding(model={model_name})\033[32m 200 OK\033[0m"
)
return response
except Exception as e:
verbose_router_logger.info(
f"litellm.aembedding(model={model_name})\033[31m Exception {str(e)}\033[0m"
)
if model_name is not None:
self.fail_calls[model_name] += 1
raise e
#### ASSISTANTS API ####
async def aget_assistants(
self,
custom_llm_provider: Literal["openai"],
client: Optional[AsyncOpenAI] = None,
**kwargs,
) -> AsyncCursorPage[Assistant]:
return await litellm.aget_assistants(
custom_llm_provider=custom_llm_provider, client=client, **kwargs
)
async def acreate_thread(
self,
custom_llm_provider: Literal["openai"],
client: Optional[AsyncOpenAI] = None,
**kwargs,
) -> Thread:
return await litellm.acreate_thread(
custom_llm_provider=custom_llm_provider, client=client, **kwargs
)
async def aget_thread(
self,
custom_llm_provider: Literal["openai"],
thread_id: str,
client: Optional[AsyncOpenAI] = None,
**kwargs,
) -> Thread:
return await litellm.aget_thread(
custom_llm_provider=custom_llm_provider,
thread_id=thread_id,
client=client,
**kwargs,
)
async def a_add_message(
self,
custom_llm_provider: Literal["openai"],
thread_id: str,
role: Literal["user", "assistant"],
content: str,
attachments: Optional[List[Attachment]] = None,
metadata: Optional[dict] = None,
client: Optional[AsyncOpenAI] = None,
**kwargs,
) -> OpenAIMessage:
return await litellm.a_add_message(
custom_llm_provider=custom_llm_provider,
thread_id=thread_id,
role=role,
content=content,
attachments=attachments,
metadata=metadata,
client=client,
**kwargs,
)
async def aget_messages(
self,
custom_llm_provider: Literal["openai"],
thread_id: str,
client: Optional[AsyncOpenAI] = None,
**kwargs,
) -> AsyncCursorPage[OpenAIMessage]:
return await litellm.aget_messages(
custom_llm_provider=custom_llm_provider,
thread_id=thread_id,
client=client,
**kwargs,
)
async def arun_thread(
self,
custom_llm_provider: Literal["openai"],
thread_id: str,
assistant_id: str,
additional_instructions: Optional[str] = None,
instructions: Optional[str] = None,
metadata: Optional[dict] = None,
model: Optional[str] = None,
stream: Optional[bool] = None,
tools: Optional[Iterable[AssistantToolParam]] = None,
client: Optional[AsyncOpenAI] = None,
**kwargs,
) -> Run:
return await litellm.arun_thread(
custom_llm_provider=custom_llm_provider,
thread_id=thread_id,
assistant_id=assistant_id,
additional_instructions=additional_instructions,
instructions=instructions,
metadata=metadata,
model=model,
stream=stream,
tools=tools,
client=client,
**kwargs,
)
#### [END] ASSISTANTS API ####
async def async_function_with_fallbacks(self, *args, **kwargs):
"""
        Try calling self.async_function_with_retries
If it fails after num_retries, fall back to another model group
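        Model-group-specific fallbacks (e.g. {"gpt-3.5-turbo": ["gpt-4"]}) are checked before the generic "*" fallbacks.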
"""
mock_testing_fallbacks = kwargs.pop("mock_testing_fallbacks", None)
model_group = kwargs.get("model")
fallbacks = kwargs.get("fallbacks", self.fallbacks)
context_window_fallbacks = kwargs.get(
"context_window_fallbacks", self.context_window_fallbacks
)
try:
if mock_testing_fallbacks is not None and mock_testing_fallbacks == True:
raise Exception(
f"This is a mock exception for model={model_group}, to trigger a fallback. Fallbacks={fallbacks}"
)
response = await self.async_function_with_retries(*args, **kwargs)
verbose_router_logger.debug(f"Async Response: {response}")
return response
except Exception as e:
verbose_router_logger.debug(f"Traceback{traceback.format_exc()}")
original_exception = e
fallback_model_group = None
try:
verbose_router_logger.debug(f"Trying to fallback b/w models")
if (
hasattr(e, "status_code")
and e.status_code == 400 # type: ignore
and not isinstance(e, litellm.ContextWindowExceededError)
): # don't retry a malformed request
raise e
if (
isinstance(e, litellm.ContextWindowExceededError)
and context_window_fallbacks is not None
):
fallback_model_group = None
for (
item
) in context_window_fallbacks: # [{"gpt-3.5-turbo": ["gpt-4"]}]
if list(item.keys())[0] == model_group:
fallback_model_group = item[model_group]
break
if fallback_model_group is None:
raise original_exception
for mg in fallback_model_group:
"""
Iterate through the model groups and try calling that deployment
"""
try:
kwargs["model"] = mg
kwargs.setdefault("metadata", {}).update(
{"model_group": mg}
) # update model_group used, if fallbacks are done
response = await self.async_function_with_retries(
*args, **kwargs
)
verbose_router_logger.info(
"Successful fallback b/w models."
)
return response
except Exception as e:
pass
elif fallbacks is not None:
verbose_router_logger.debug(f"inside model fallbacks: {fallbacks}")
generic_fallback_idx: Optional[int] = None
## check for specific model group-specific fallbacks
for idx, item in enumerate(fallbacks):
if list(item.keys())[0] == model_group:
fallback_model_group = item[model_group]
break
elif list(item.keys())[0] == "*":
generic_fallback_idx = idx
## if none, check for generic fallback
if (
fallback_model_group is None
and generic_fallback_idx is not None
):
fallback_model_group = fallbacks[generic_fallback_idx]["*"]
if fallback_model_group is None:
verbose_router_logger.info(
f"No fallback model group found for original model_group={model_group}. Fallbacks={fallbacks}"
)
raise original_exception
for mg in fallback_model_group:
"""
Iterate through the model groups and try calling that deployment
"""
try:
## LOGGING
kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
verbose_router_logger.info(
f"Falling back to model_group = {mg}"
)
kwargs["model"] = mg
kwargs.setdefault("metadata", {}).update(
{"model_group": mg}
) # update model_group used, if fallbacks are done
response = await self.async_function_with_fallbacks(
*args, **kwargs
)
verbose_router_logger.info(
"Successful fallback b/w models."
)
return response
except Exception as e:
raise e
except Exception as e:
verbose_router_logger.debug(f"An exception occurred - {str(e)}")
traceback.print_exc()
raise original_exception
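    # Illustrative sketch (not part of the router logic above): the fallback lists read
    # by async_function_with_fallbacks are plain dicts mapping a model group to the list
    # of groups to try next. The model-group names below are hypothetical.
    #
    #   router = Router(
    #       model_list=[...],  # your deployments
    #       fallbacks=[{"gpt-3.5-turbo": ["gpt-4"]}, {"*": ["gpt-4"]}],
    #       context_window_fallbacks=[{"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]}],
    #   )
    #
    # A "*" key acts as a generic fallback for any model group without a specific entry.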
async def async_function_with_retries(self, *args, **kwargs):
verbose_router_logger.debug(
f"Inside async function with retries: args - {args}; kwargs - {kwargs}"
)
original_function = kwargs.pop("original_function")
fallbacks = kwargs.pop("fallbacks", self.fallbacks)
context_window_fallbacks = kwargs.pop(
"context_window_fallbacks", self.context_window_fallbacks
)
num_retries = kwargs.pop("num_retries")
verbose_router_logger.debug(
f"async function w/ retries: original_function - {original_function}, num_retries - {num_retries}"
)
try:
# if the function call is successful, no exception will be raised and we'll break out of the loop
response = await original_function(*args, **kwargs)
return response
except Exception as e:
current_attempt = None
original_exception = e
"""
Retry Logic
"""
_healthy_deployments = await self._async_get_healthy_deployments(
model=kwargs.get("model") or "",
)
            # raises an exception if this error should not be retried
self.should_retry_this_error(
error=e,
healthy_deployments=_healthy_deployments,
context_window_fallbacks=context_window_fallbacks,
regular_fallbacks=fallbacks,
)
# decides how long to sleep before retry
_timeout = self._time_to_sleep_before_retry(
e=original_exception,
remaining_retries=num_retries,
num_retries=num_retries,
healthy_deployments=_healthy_deployments,
)
# sleeps for the length of the timeout
await asyncio.sleep(_timeout)
if (
self.retry_policy is not None
or self.model_group_retry_policy is not None
):
# get num_retries from retry policy
_retry_policy_retries = self.get_num_retries_from_retry_policy(
exception=original_exception, model_group=kwargs.get("model")
)
if _retry_policy_retries is not None:
num_retries = _retry_policy_retries
## LOGGING
if num_retries > 0:
kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
for current_attempt in range(num_retries):
verbose_router_logger.debug(
f"retrying request. Current attempt - {current_attempt}; num retries: {num_retries}"
)
try:
# if the function call is successful, no exception will be raised and we'll break out of the loop
response = await original_function(*args, **kwargs)
                    if inspect.iscoroutine(
                        response
                    ):  # async errors are sometimes returned as unawaited coroutines
response = await response
return response
except Exception as e:
## LOGGING
kwargs = self.log_retry(kwargs=kwargs, e=e)
remaining_retries = num_retries - current_attempt
_healthy_deployments = await self._async_get_healthy_deployments(
model=kwargs.get("model"),
)
_timeout = self._time_to_sleep_before_retry(
e=original_exception,
remaining_retries=remaining_retries,
num_retries=num_retries,
healthy_deployments=_healthy_deployments,
)
await asyncio.sleep(_timeout)
if type(original_exception) in litellm.LITELLM_EXCEPTION_TYPES:
original_exception.max_retries = num_retries
original_exception.num_retries = current_attempt
raise original_exception
def should_retry_this_error(
self,
error: Exception,
healthy_deployments: Optional[List] = None,
context_window_fallbacks: Optional[List] = None,
regular_fallbacks: Optional[List] = None,
):
"""
1. raise an exception for ContextWindowExceededError if context_window_fallbacks is not None
2. raise an exception for RateLimitError if
- there are no fallbacks
- there are no healthy deployments in the same model group
"""
_num_healthy_deployments = 0
if healthy_deployments is not None and isinstance(healthy_deployments, list):
_num_healthy_deployments = len(healthy_deployments)
### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR w/ fallbacks available / Bad Request Error
if (
isinstance(error, litellm.ContextWindowExceededError)
and context_window_fallbacks is not None
):
raise error
# Error we should only retry if there are other deployments
if isinstance(error, openai.RateLimitError) or isinstance(
error, openai.AuthenticationError
):
if (
_num_healthy_deployments <= 0
and regular_fallbacks is not None
and len(regular_fallbacks) > 0
):
raise error
return True
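    # Decision summary for should_retry_this_error (derived from the checks above):
    #   - ContextWindowExceededError + context_window_fallbacks configured -> re-raised,
    #     so the fallback logic (not the retry loop) handles it.
    #   - RateLimitError / AuthenticationError with no other healthy deployments, but
    #     with regular fallbacks configured -> re-raised, fall back instead of retrying.
    #   - anything else -> returns True and the caller proceeds with retries.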
def function_with_fallbacks(self, *args, **kwargs):
"""
Try calling the function_with_retries
If it fails after num_retries, fall back to another model group
"""
mock_testing_fallbacks = kwargs.pop("mock_testing_fallbacks", None)
model_group = kwargs.get("model")
fallbacks = kwargs.get("fallbacks", self.fallbacks)
context_window_fallbacks = kwargs.get(
"context_window_fallbacks", self.context_window_fallbacks
)
try:
if mock_testing_fallbacks is not None and mock_testing_fallbacks == True:
raise Exception(
f"This is a mock exception for model={model_group}, to trigger a fallback. Fallbacks={fallbacks}"
)
response = self.function_with_retries(*args, **kwargs)
return response
except Exception as e:
original_exception = e
verbose_router_logger.debug(f"An exception occurs {original_exception}")
try:
if (
hasattr(e, "status_code")
and e.status_code == 400 # type: ignore
and not isinstance(e, litellm.ContextWindowExceededError)
): # don't retry a malformed request
raise e
verbose_router_logger.debug(
f"Trying to fallback b/w models. Initial model group: {model_group}"
)
if (
isinstance(e, litellm.ContextWindowExceededError)
and context_window_fallbacks is not None
):
fallback_model_group = None
for (
item
) in context_window_fallbacks: # [{"gpt-3.5-turbo": ["gpt-4"]}]
if list(item.keys())[0] == model_group:
fallback_model_group = item[model_group]
break
if fallback_model_group is None:
raise original_exception
for mg in fallback_model_group:
"""
Iterate through the model groups and try calling that deployment
"""
try:
## LOGGING
kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
kwargs["model"] = mg
kwargs.setdefault("metadata", {}).update(
{"model_group": mg}
) # update model_group used, if fallbacks are done
response = self.function_with_fallbacks(*args, **kwargs)
return response
except Exception as e:
pass
elif fallbacks is not None:
verbose_router_logger.debug(f"inside model fallbacks: {fallbacks}")
fallback_model_group = None
generic_fallback_idx: Optional[int] = None
## check for specific model group-specific fallbacks
for idx, item in enumerate(fallbacks):
if list(item.keys())[0] == model_group:
fallback_model_group = item[model_group]
break
elif list(item.keys())[0] == "*":
generic_fallback_idx = idx
## if none, check for generic fallback
if (
fallback_model_group is None
and generic_fallback_idx is not None
):
fallback_model_group = fallbacks[generic_fallback_idx]["*"]
if fallback_model_group is None:
raise original_exception
for mg in fallback_model_group:
"""
Iterate through the model groups and try calling that deployment
"""
try:
## LOGGING
kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
kwargs["model"] = mg
kwargs.setdefault("metadata", {}).update(
{"model_group": mg}
) # update model_group used, if fallbacks are done
response = self.function_with_fallbacks(*args, **kwargs)
return response
except Exception as e:
raise e
except Exception as e:
raise e
raise original_exception
def _time_to_sleep_before_retry(
self,
e: Exception,
remaining_retries: int,
num_retries: int,
healthy_deployments: Optional[List] = None,
) -> Union[int, float]:
"""
Calculate back-off, then retry
It should instantly retry only when:
1. there are healthy deployments in the same model group
2. there are fallbacks for the completion call
"""
if (
healthy_deployments is not None
and isinstance(healthy_deployments, list)
and len(healthy_deployments) > 0
):
return 0
if hasattr(e, "response") and hasattr(e.response, "headers"):
timeout = litellm._calculate_retry_after(
remaining_retries=remaining_retries,
max_retries=num_retries,
response_headers=e.response.headers,
min_timeout=self.retry_after,
)
else:
timeout = litellm._calculate_retry_after(
remaining_retries=remaining_retries,
max_retries=num_retries,
min_timeout=self.retry_after,
)
return timeout
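    # Hedged usage sketch for _time_to_sleep_before_retry (values are illustrative):
    # it returns 0 when other healthy deployments exist for the model group, otherwise it
    # defers to litellm._calculate_retry_after, passing along any response headers on the
    # exception and using self.retry_after as the minimum sleep.
    #
    #   _timeout = router._time_to_sleep_before_retry(
    #       e=exception, remaining_retries=2, num_retries=3, healthy_deployments=[]
    #   )
    #   time.sleep(_timeout)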
def function_with_retries(self, *args, **kwargs):
"""
        Try calling the model num_retries times. Shuffle between available deployments.
"""
verbose_router_logger.debug(
f"Inside function with retries: args - {args}; kwargs - {kwargs}"
)
original_function = kwargs.pop("original_function")
num_retries = kwargs.pop("num_retries")
fallbacks = kwargs.pop("fallbacks", self.fallbacks)
context_window_fallbacks = kwargs.pop(
"context_window_fallbacks", self.context_window_fallbacks
)
try:
# if the function call is successful, no exception will be raised and we'll break out of the loop
response = original_function(*args, **kwargs)
return response
except Exception as e:
current_attempt = None
original_exception = e
### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR
_healthy_deployments = self._get_healthy_deployments(
model=kwargs.get("model"),
)
            # raises an exception if this error should not be retried
self.should_retry_this_error(
error=e,
healthy_deployments=_healthy_deployments,
context_window_fallbacks=context_window_fallbacks,
regular_fallbacks=fallbacks,
)
# decides how long to sleep before retry
_timeout = self._time_to_sleep_before_retry(
e=original_exception,
remaining_retries=num_retries,
num_retries=num_retries,
healthy_deployments=_healthy_deployments,
)
## LOGGING
if num_retries > 0:
kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
time.sleep(_timeout)
for current_attempt in range(num_retries):
verbose_router_logger.debug(
f"retrying request. Current attempt - {current_attempt}; retries left: {num_retries}"
)
try:
# if the function call is successful, no exception will be raised and we'll break out of the loop
response = original_function(*args, **kwargs)
return response
except Exception as e:
## LOGGING
kwargs = self.log_retry(kwargs=kwargs, e=e)
_healthy_deployments = self._get_healthy_deployments(
model=kwargs.get("model"),
)
remaining_retries = num_retries - current_attempt
_timeout = self._time_to_sleep_before_retry(
e=e,
remaining_retries=remaining_retries,
num_retries=num_retries,
healthy_deployments=_healthy_deployments,
)
time.sleep(_timeout)
if type(original_exception) in litellm.LITELLM_EXCEPTION_TYPES:
original_exception.max_retries = num_retries
original_exception.num_retries = current_attempt
raise original_exception
### HELPER FUNCTIONS
def deployment_callback_on_failure(
self,
kwargs, # kwargs to completion
completion_response, # response from completion
start_time,
end_time, # start/end time
):
try:
exception = kwargs.get("exception", None)
exception_type = type(exception)
exception_status = getattr(exception, "status_code", "")
exception_cause = getattr(exception, "__cause__", "")
exception_message = getattr(exception, "message", "")
            exception_str = (
                str(exception_type)
                + " Status: "
                + str(exception_status)
                + " Message: "
                + str(exception_cause)
                + str(exception_message)
                + " Full exception: "
                + str(exception)
            )
model_name = kwargs.get("model", None) # i.e. gpt35turbo
custom_llm_provider = kwargs.get("litellm_params", {}).get(
"custom_llm_provider", None
) # i.e. azure
metadata = kwargs.get("litellm_params", {}).get("metadata", None)
_model_info = kwargs.get("litellm_params", {}).get("model_info", {})
exception_response = getattr(exception, "response", {})
exception_headers = getattr(exception_response, "headers", None)
_time_to_cooldown = self.cooldown_time
if exception_headers is not None:
_time_to_cooldown = (
litellm.utils._get_retry_after_from_exception_header(
response_headers=exception_headers
)
)
if _time_to_cooldown is None or _time_to_cooldown < 0:
                    # if the cooldown time could not be read from the response headers -> fall back to the default cooldown time
_time_to_cooldown = self.cooldown_time
if isinstance(_model_info, dict):
deployment_id = _model_info.get("id", None)
self._set_cooldown_deployments(
exception_status=exception_status,
original_exception=exception,
deployment=deployment_id,
time_to_cooldown=_time_to_cooldown,
) # setting deployment_id in cooldown deployments
if custom_llm_provider:
model_name = f"{custom_llm_provider}/{model_name}"
except Exception as e:
raise e
def log_retry(self, kwargs: dict, e: Exception) -> dict:
"""
When a retry or fallback happens, log the details of the just failed model call - similar to Sentry breadcrumbing
"""
try:
# Log failed model as the previous model
previous_model = {
"exception_type": type(e).__name__,
"exception_string": str(e),
}
for (
k,
v,
) in (
kwargs.items()
): # log everything in kwargs except the old previous_models value - prevent nesting
if k not in ["metadata", "messages", "original_function"]:
previous_model[k] = v
elif k == "metadata" and isinstance(v, dict):
previous_model["metadata"] = {} # type: ignore
for metadata_k, metadata_v in kwargs["metadata"].items():
if metadata_k != "previous_models":
previous_model[k][metadata_k] = metadata_v # type: ignore
# check current size of self.previous_models, if it's larger than 3, remove the first element
if len(self.previous_models) > 3:
self.previous_models.pop(0)
self.previous_models.append(previous_model)
kwargs["metadata"]["previous_models"] = self.previous_models
return kwargs
except Exception as e:
raise e
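    # Illustrative shape of one `previous_models` breadcrumb written by log_retry
    # (all field values below are hypothetical):
    #
    #   {
    #       "exception_type": "RateLimitError",
    #       "exception_string": "Rate limit reached ...",
    #       "model": "gpt-3.5-turbo",
    #       "metadata": {"model_group": "gpt-3.5-turbo"},  # minus "previous_models"
    #   }
    #
    # Older breadcrumbs are popped off the front so the list stays small.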
def _update_usage(self, deployment_id: str):
"""
Update deployment rpm for that minute
"""
rpm_key = deployment_id
request_count = self.cache.get_cache(key=rpm_key, local_only=True)
if request_count is None:
request_count = 1
self.cache.set_cache(
key=rpm_key, value=request_count, local_only=True, ttl=60
) # only store for 60s
else:
request_count += 1
self.cache.set_cache(
key=rpm_key, value=request_count, local_only=True
) # don't change existing ttl
def _is_cooldown_required(self, exception_status: Union[str, int]):
"""
A function to determine if a cooldown is required based on the exception status.
Parameters:
exception_status (Union[str, int]): The status of the exception.
Returns:
bool: True if a cooldown is required, False otherwise.
"""
try:
if isinstance(exception_status, str):
exception_status = int(exception_status)
if exception_status >= 400 and exception_status < 500:
if exception_status == 429:
# Cool down 429 Rate Limit Errors
return True
elif exception_status == 401:
# Cool down 401 Auth Errors
return True
elif exception_status == 408:
return True
elif exception_status == 404:
return True
else:
# Do NOT cool down all other 4XX Errors
return False
else:
# should cool down for all other errors
return True
except:
# Catch all - if any exceptions default to cooling down
return True
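    # Cooldown decision summary for _is_cooldown_required:
    #   429 (rate limit), 401 (auth), 408 (timeout), 404 -> cooldown required
    #   any other 4XX                                    -> no cooldown
    #   5XX / anything else (including parse failures)   -> cooldown required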
def _set_cooldown_deployments(
self,
original_exception: Any,
exception_status: Union[str, int],
deployment: Optional[str] = None,
time_to_cooldown: Optional[float] = None,
):
"""
Add a model to the list of models being cooled down for that minute, if it exceeds the allowed fails / minute
or
the exception is not one that should be immediately retried (e.g. 401)
"""
if deployment is None:
return
if self._is_cooldown_required(exception_status=exception_status) == False:
return
_allowed_fails = self.get_allowed_fails_from_policy(
exception=original_exception,
)
allowed_fails = _allowed_fails or self.allowed_fails
dt = get_utc_datetime()
current_minute = dt.strftime("%H-%M")
# get current fails for deployment
# update the number of failed calls
# if it's > allowed fails
# cooldown deployment
current_fails = self.failed_calls.get_cache(key=deployment) or 0
updated_fails = current_fails + 1
verbose_router_logger.debug(
f"Attempting to add {deployment} to cooldown list. updated_fails: {updated_fails}; self.allowed_fails: {allowed_fails}"
)
cooldown_time = self.cooldown_time or 1
if time_to_cooldown is not None:
cooldown_time = time_to_cooldown
if isinstance(exception_status, str):
try:
exception_status = int(exception_status)
except Exception as e:
verbose_router_logger.debug(
"Unable to cast exception status to int {}. Defaulting to status=500.".format(
exception_status
)
)
exception_status = 500
_should_retry = litellm._should_retry(status_code=exception_status)
if updated_fails > allowed_fails or _should_retry == False:
# get the current cooldown list for that minute
cooldown_key = f"{current_minute}:cooldown_models" # group cooldown models by minute to reduce number of redis calls
cached_value = self.cache.get_cache(key=cooldown_key)
verbose_router_logger.debug(f"adding {deployment} to cooldown models")
# update value
try:
if deployment in cached_value:
pass
else:
cached_value = cached_value + [deployment]
# save updated value
self.cache.set_cache(
value=cached_value, key=cooldown_key, ttl=cooldown_time
)
except:
cached_value = [deployment]
# save updated value
self.cache.set_cache(
value=cached_value, key=cooldown_key, ttl=cooldown_time
)
self.send_deployment_cooldown_alert(
deployment_id=deployment,
exception_status=exception_status,
cooldown_time=cooldown_time,
)
else:
self.failed_calls.set_cache(
key=deployment, value=updated_fails, ttl=cooldown_time
)
async def _async_get_cooldown_deployments(self):
"""
Async implementation of '_get_cooldown_deployments'
"""
dt = get_utc_datetime()
current_minute = dt.strftime("%H-%M")
# get the current cooldown list for that minute
cooldown_key = f"{current_minute}:cooldown_models"
# ----------------------
# Return cooldown models
# ----------------------
cooldown_models = await self.cache.async_get_cache(key=cooldown_key) or []
verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}")
return cooldown_models
def _get_cooldown_deployments(self):
"""
Get the list of models being cooled down for this minute
"""
dt = get_utc_datetime()
current_minute = dt.strftime("%H-%M")
# get the current cooldown list for that minute
cooldown_key = f"{current_minute}:cooldown_models"
# ----------------------
# Return cooldown models
# ----------------------
cooldown_models = self.cache.get_cache(key=cooldown_key) or []
verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}")
return cooldown_models
def _get_healthy_deployments(self, model: str):
_all_deployments: list = []
try:
_, _all_deployments = self._common_checks_available_deployment( # type: ignore
model=model,
)
if type(_all_deployments) == dict:
return []
except:
pass
unhealthy_deployments = self._get_cooldown_deployments()
healthy_deployments: list = []
for deployment in _all_deployments:
if deployment["model_info"]["id"] in unhealthy_deployments:
continue
else:
healthy_deployments.append(deployment)
return healthy_deployments
async def _async_get_healthy_deployments(self, model: str):
_all_deployments: list = []
try:
_, _all_deployments = self._common_checks_available_deployment( # type: ignore
model=model,
)
if type(_all_deployments) == dict:
return []
except:
pass
unhealthy_deployments = await self._async_get_cooldown_deployments()
healthy_deployments: list = []
for deployment in _all_deployments:
if deployment["model_info"]["id"] in unhealthy_deployments:
continue
else:
healthy_deployments.append(deployment)
return healthy_deployments
def routing_strategy_pre_call_checks(self, deployment: dict):
"""
Mimics 'async_routing_strategy_pre_call_checks'
Ensures consistent update rpm implementation for 'usage-based-routing-v2'
Returns:
- None
Raises:
        - Rate Limit Exception - If the deployment is over its tpm/rpm limits
"""
for _callback in litellm.callbacks:
if isinstance(_callback, CustomLogger):
response = _callback.pre_call_check(deployment)
async def async_routing_strategy_pre_call_checks(self, deployment: dict):
"""
For usage-based-routing-v2, enables running rpm checks before the call is made, inside the semaphore.
-> makes the calls concurrency-safe, when rpm limits are set for a deployment
Returns:
- None
Raises:
        - Rate Limit Exception - If the deployment is over its tpm/rpm limits
"""
for _callback in litellm.callbacks:
if isinstance(_callback, CustomLogger):
try:
response = await _callback.async_pre_call_check(deployment)
except litellm.RateLimitError as e:
self._set_cooldown_deployments(
exception_status=e.status_code,
original_exception=e,
deployment=deployment["model_info"]["id"],
time_to_cooldown=self.cooldown_time,
)
raise e
except Exception as e:
raise e
def set_client(self, model: dict):
"""
- Initializes Azure/OpenAI clients. Stores them in cache, b/c of this - https://github.com/BerriAI/litellm/issues/1278
- Initializes Semaphore for client w/ rpm. Stores them in cache. b/c of this - https://github.com/BerriAI/litellm/issues/2994
"""
client_ttl = self.client_ttl
litellm_params = model.get("litellm_params", {})
model_name = litellm_params.get("model")
model_id = model["model_info"]["id"]
# ### IF RPM SET - initialize a semaphore ###
rpm = litellm_params.get("rpm", None)
tpm = litellm_params.get("tpm", None)
max_parallel_requests = litellm_params.get("max_parallel_requests", None)
calculated_max_parallel_requests = calculate_max_parallel_requests(
rpm=rpm,
max_parallel_requests=max_parallel_requests,
tpm=tpm,
default_max_parallel_requests=self.default_max_parallel_requests,
)
if calculated_max_parallel_requests:
semaphore = asyncio.Semaphore(calculated_max_parallel_requests)
cache_key = f"{model_id}_max_parallel_requests_client"
self.cache.set_cache(
key=cache_key,
value=semaphore,
local_only=True,
)
        #### for OpenAI / Azure we need to initialize the Client for High Traffic ########
custom_llm_provider = litellm_params.get("custom_llm_provider")
custom_llm_provider = custom_llm_provider or model_name.split("/", 1)[0] or ""
default_api_base = None
default_api_key = None
if custom_llm_provider in litellm.openai_compatible_providers:
_, custom_llm_provider, api_key, api_base = litellm.get_llm_provider(
model=model_name
)
default_api_base = api_base
default_api_key = api_key
if (
model_name in litellm.open_ai_chat_completion_models
or custom_llm_provider in litellm.openai_compatible_providers
or custom_llm_provider == "azure"
or custom_llm_provider == "azure_text"
or custom_llm_provider == "custom_openai"
or custom_llm_provider == "openai"
or custom_llm_provider == "text-completion-openai"
or "ft:gpt-3.5-turbo" in model_name
or model_name in litellm.open_ai_embedding_models
):
is_azure_ai_studio_model: bool = False
if custom_llm_provider == "azure":
if litellm.utils._is_non_openai_azure_model(model_name):
is_azure_ai_studio_model = True
custom_llm_provider = "openai"
                    # remove azure prefix from model_name
model_name = model_name.replace("azure/", "")
# glorified / complicated reading of configs
            # user can pass vars directly or they can pass os.environ/AZURE_API_KEY, in which case we will read the env var
# we do this here because we init clients for Azure, OpenAI and we need to set the right key
api_key = litellm_params.get("api_key") or default_api_key
if (
api_key
and isinstance(api_key, str)
and api_key.startswith("os.environ/")
):
api_key_env_name = api_key.replace("os.environ/", "")
api_key = litellm.get_secret(api_key_env_name)
litellm_params["api_key"] = api_key
api_base = litellm_params.get("api_base")
base_url = litellm_params.get("base_url")
api_base = (
api_base or base_url or default_api_base
) # allow users to pass in `api_base` or `base_url` for azure
if api_base and api_base.startswith("os.environ/"):
api_base_env_name = api_base.replace("os.environ/", "")
api_base = litellm.get_secret(api_base_env_name)
litellm_params["api_base"] = api_base
## AZURE AI STUDIO MISTRAL CHECK ##
"""
Make sure api base ends in /v1/
if not, add it - https://github.com/BerriAI/litellm/issues/2279
"""
if (
is_azure_ai_studio_model == True
and api_base is not None
and isinstance(api_base, str)
and not api_base.endswith("/v1/")
):
# check if it ends with a trailing slash
if api_base.endswith("/"):
api_base += "v1/"
elif api_base.endswith("/v1"):
api_base += "/"
else:
api_base += "/v1/"
api_version = litellm_params.get("api_version")
if api_version and api_version.startswith("os.environ/"):
api_version_env_name = api_version.replace("os.environ/", "")
api_version = litellm.get_secret(api_version_env_name)
litellm_params["api_version"] = api_version
timeout = litellm_params.pop("timeout", None)
if isinstance(timeout, str) and timeout.startswith("os.environ/"):
timeout_env_name = timeout.replace("os.environ/", "")
timeout = litellm.get_secret(timeout_env_name)
litellm_params["timeout"] = timeout
stream_timeout = litellm_params.pop(
"stream_timeout", timeout
) # if no stream_timeout is set, default to timeout
if isinstance(stream_timeout, str) and stream_timeout.startswith(
"os.environ/"
):
stream_timeout_env_name = stream_timeout.replace("os.environ/", "")
stream_timeout = litellm.get_secret(stream_timeout_env_name)
litellm_params["stream_timeout"] = stream_timeout
max_retries = litellm_params.pop(
"max_retries", 0
) # router handles retry logic
if isinstance(max_retries, str) and max_retries.startswith("os.environ/"):
max_retries_env_name = max_retries.replace("os.environ/", "")
max_retries = litellm.get_secret(max_retries_env_name)
litellm_params["max_retries"] = max_retries
# proxy support
import os
import httpx
# Check if the HTTP_PROXY and HTTPS_PROXY environment variables are set and use them accordingly.
http_proxy = os.getenv("HTTP_PROXY", None)
https_proxy = os.getenv("HTTPS_PROXY", None)
no_proxy = os.getenv("NO_PROXY", None)
# Create the proxies dictionary only if the environment variables are set.
sync_proxy_mounts = None
async_proxy_mounts = None
if http_proxy is not None and https_proxy is not None:
sync_proxy_mounts = {
"http://": httpx.HTTPTransport(proxy=httpx.Proxy(url=http_proxy)),
"https://": httpx.HTTPTransport(proxy=httpx.Proxy(url=https_proxy)),
}
async_proxy_mounts = {
"http://": httpx.AsyncHTTPTransport(
proxy=httpx.Proxy(url=http_proxy)
),
"https://": httpx.AsyncHTTPTransport(
proxy=httpx.Proxy(url=https_proxy)
),
}
# assume no_proxy is a list of comma separated urls
if no_proxy is not None and isinstance(no_proxy, str):
no_proxy_urls = no_proxy.split(",")
for url in no_proxy_urls: # set no-proxy support for specific urls
sync_proxy_mounts[url] = None # type: ignore
async_proxy_mounts[url] = None # type: ignore
organization = litellm_params.get("organization", None)
if isinstance(organization, str) and organization.startswith("os.environ/"):
organization_env_name = organization.replace("os.environ/", "")
organization = litellm.get_secret(organization_env_name)
litellm_params["organization"] = organization
if "azure" in model_name:
if api_base is None or not isinstance(api_base, str):
filtered_litellm_params = {
k: v
for k, v in model["litellm_params"].items()
if k != "api_key"
}
_filtered_model = {
"model_name": model["model_name"],
"litellm_params": filtered_litellm_params,
}
raise ValueError(
f"api_base is required for Azure OpenAI. Set it on your config. Model - {_filtered_model}"
)
azure_ad_token = litellm_params.get("azure_ad_token")
if azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
if api_version is None:
api_version = "2023-07-01-preview"
if "gateway.ai.cloudflare.com" in api_base:
if not api_base.endswith("/"):
api_base += "/"
azure_model = model_name.replace("azure/", "")
api_base += f"{azure_model}"
cache_key = f"{model_id}_async_client"
_client = openai.AsyncAzureOpenAI(
api_key=api_key,
azure_ad_token=azure_ad_token,
base_url=api_base,
api_version=api_version,
timeout=timeout,
max_retries=max_retries,
http_client=httpx.AsyncClient(
transport=AsyncCustomHTTPTransport(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
),
verify=litellm.ssl_verify,
),
mounts=async_proxy_mounts,
), # type: ignore
)
self.cache.set_cache(
key=cache_key,
value=_client,
ttl=client_ttl,
local_only=True,
) # cache for 1 hr
cache_key = f"{model_id}_client"
_client = openai.AzureOpenAI( # type: ignore
api_key=api_key,
azure_ad_token=azure_ad_token,
base_url=api_base,
api_version=api_version,
timeout=timeout,
max_retries=max_retries,
http_client=httpx.Client(
transport=CustomHTTPTransport(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
),
verify=litellm.ssl_verify,
),
mounts=sync_proxy_mounts,
), # type: ignore
)
self.cache.set_cache(
key=cache_key,
value=_client,
ttl=client_ttl,
local_only=True,
) # cache for 1 hr
# streaming clients can have diff timeouts
cache_key = f"{model_id}_stream_async_client"
_client = openai.AsyncAzureOpenAI( # type: ignore
api_key=api_key,
azure_ad_token=azure_ad_token,
base_url=api_base,
api_version=api_version,
timeout=stream_timeout,
max_retries=max_retries,
http_client=httpx.AsyncClient(
transport=AsyncCustomHTTPTransport(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
),
verify=litellm.ssl_verify,
),
mounts=async_proxy_mounts,
), # type: ignore
)
self.cache.set_cache(
key=cache_key,
value=_client,
ttl=client_ttl,
local_only=True,
) # cache for 1 hr
cache_key = f"{model_id}_stream_client"
_client = openai.AzureOpenAI( # type: ignore
api_key=api_key,
azure_ad_token=azure_ad_token,
base_url=api_base,
api_version=api_version,
timeout=stream_timeout,
max_retries=max_retries,
http_client=httpx.Client(
transport=CustomHTTPTransport(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
),
verify=litellm.ssl_verify,
),
mounts=sync_proxy_mounts,
), # type: ignore
)
self.cache.set_cache(
key=cache_key,
value=_client,
ttl=client_ttl,
local_only=True,
) # cache for 1 hr
else:
_api_key = api_key
if _api_key is not None and isinstance(_api_key, str):
                        # only show first 8 chars of api_key
_api_key = _api_key[:8] + "*" * 15
verbose_router_logger.debug(
f"Initializing Azure OpenAI Client for {model_name}, Api Base: {str(api_base)}, Api Key:{_api_key}"
)
azure_client_params = {
"api_key": api_key,
"azure_endpoint": api_base,
"api_version": api_version,
"azure_ad_token": azure_ad_token,
}
from litellm.llms.azure import select_azure_base_url_or_endpoint
# this decides if we should set azure_endpoint or base_url on Azure OpenAI Client
# required to support GPT-4 vision enhancements, since base_url needs to be set on Azure OpenAI Client
azure_client_params = select_azure_base_url_or_endpoint(
azure_client_params
)
cache_key = f"{model_id}_async_client"
_client = openai.AsyncAzureOpenAI( # type: ignore
**azure_client_params,
timeout=timeout,
max_retries=max_retries,
http_client=httpx.AsyncClient(
transport=AsyncCustomHTTPTransport(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
),
verify=litellm.ssl_verify,
),
mounts=async_proxy_mounts,
), # type: ignore
)
self.cache.set_cache(
key=cache_key,
value=_client,
ttl=client_ttl,
local_only=True,
) # cache for 1 hr
cache_key = f"{model_id}_client"
_client = openai.AzureOpenAI( # type: ignore
**azure_client_params,
timeout=timeout,
max_retries=max_retries,
http_client=httpx.Client(
transport=CustomHTTPTransport(
verify=litellm.ssl_verify,
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
),
),
mounts=sync_proxy_mounts,
), # type: ignore
)
self.cache.set_cache(
key=cache_key,
value=_client,
ttl=client_ttl,
local_only=True,
) # cache for 1 hr
# streaming clients should have diff timeouts
cache_key = f"{model_id}_stream_async_client"
_client = openai.AsyncAzureOpenAI( # type: ignore
**azure_client_params,
timeout=stream_timeout,
max_retries=max_retries,
http_client=httpx.AsyncClient(
transport=AsyncCustomHTTPTransport(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
),
verify=litellm.ssl_verify,
),
mounts=async_proxy_mounts,
),
)
self.cache.set_cache(
key=cache_key,
value=_client,
ttl=client_ttl,
local_only=True,
) # cache for 1 hr
cache_key = f"{model_id}_stream_client"
_client = openai.AzureOpenAI( # type: ignore
**azure_client_params,
timeout=stream_timeout,
max_retries=max_retries,
http_client=httpx.Client(
transport=CustomHTTPTransport(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
),
verify=litellm.ssl_verify,
),
mounts=sync_proxy_mounts,
),
)
self.cache.set_cache(
key=cache_key,
value=_client,
ttl=client_ttl,
local_only=True,
) # cache for 1 hr
else:
_api_key = api_key # type: ignore
if _api_key is not None and isinstance(_api_key, str):
                    # only show first 8 chars of api_key
_api_key = _api_key[:8] + "*" * 15
verbose_router_logger.debug(
f"Initializing OpenAI Client for {model_name}, Api Base:{str(api_base)}, Api Key:{_api_key}"
)
cache_key = f"{model_id}_async_client"
_client = openai.AsyncOpenAI( # type: ignore
api_key=api_key,
base_url=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
http_client=httpx.AsyncClient(
transport=AsyncCustomHTTPTransport(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
),
verify=litellm.ssl_verify,
),
mounts=async_proxy_mounts,
), # type: ignore
)
self.cache.set_cache(
key=cache_key,
value=_client,
ttl=client_ttl,
local_only=True,
) # cache for 1 hr
cache_key = f"{model_id}_client"
_client = openai.OpenAI( # type: ignore
api_key=api_key,
base_url=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
http_client=httpx.Client(
transport=CustomHTTPTransport(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
),
verify=litellm.ssl_verify,
),
mounts=sync_proxy_mounts,
), # type: ignore
)
self.cache.set_cache(
key=cache_key,
value=_client,
ttl=client_ttl,
local_only=True,
) # cache for 1 hr
# streaming clients should have diff timeouts
cache_key = f"{model_id}_stream_async_client"
_client = openai.AsyncOpenAI( # type: ignore
api_key=api_key,
base_url=api_base,
timeout=stream_timeout,
max_retries=max_retries,
organization=organization,
http_client=httpx.AsyncClient(
transport=AsyncCustomHTTPTransport(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
),
verify=litellm.ssl_verify,
),
mounts=async_proxy_mounts,
), # type: ignore
)
self.cache.set_cache(
key=cache_key,
value=_client,
ttl=client_ttl,
local_only=True,
) # cache for 1 hr
# streaming clients should have diff timeouts
cache_key = f"{model_id}_stream_client"
_client = openai.OpenAI( # type: ignore
api_key=api_key,
base_url=api_base,
timeout=stream_timeout,
max_retries=max_retries,
organization=organization,
http_client=httpx.Client(
transport=CustomHTTPTransport(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
),
verify=litellm.ssl_verify,
),
mounts=sync_proxy_mounts,
), # type: ignore
)
self.cache.set_cache(
key=cache_key,
value=_client,
ttl=client_ttl,
local_only=True,
) # cache for 1 hr
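    # For reference, set_client stores up to five cache entries per deployment id,
    # matching the cache keys used above:
    #   f"{model_id}_async_client"                 # async client, regular timeout
    #   f"{model_id}_client"                       # sync client, regular timeout
    #   f"{model_id}_stream_async_client"          # async client, stream_timeout
    #   f"{model_id}_stream_client"                # sync client, stream_timeout
    #   f"{model_id}_max_parallel_requests_client" # optional max-parallel-requests semaphore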
def _generate_model_id(self, model_group: str, litellm_params: dict):
"""
Helper function to consistently generate the same id for a deployment
- create a string from all the litellm params
- hash
- use hash as id
"""
concat_str = model_group
for k, v in litellm_params.items():
if isinstance(k, str):
concat_str += k
elif isinstance(k, dict):
concat_str += json.dumps(k)
else:
concat_str += str(k)
if isinstance(v, str):
concat_str += v
elif isinstance(v, dict):
concat_str += json.dumps(v)
else:
concat_str += str(v)
hash_object = hashlib.sha256(concat_str.encode())
return hash_object.hexdigest()
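    # Hedged example of the deterministic id generated above (param values are hypothetical):
    #
    #   id_a = router._generate_model_id(
    #       "gpt-3.5-turbo", {"model": "azure/chatgpt-v-2", "api_base": "https://example.azure.com"}
    #   )
    #   id_b = router._generate_model_id(
    #       "gpt-3.5-turbo", {"model": "azure/chatgpt-v-2", "api_base": "https://example.azure.com"}
    #   )
    #   assert id_a == id_b  # identical model group + litellm_params always hash to the same id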
def set_model_list(self, model_list: list):
original_model_list = copy.deepcopy(model_list)
self.model_list = []
        # we add api_base/api_key to each model, so load balancing between azure/gpt deployments on api_base1 and api_base2 works
import os
for model in original_model_list:
_model_name = model.pop("model_name")
_litellm_params = model.pop("litellm_params")
## check if litellm params in os.environ
if isinstance(_litellm_params, dict):
for k, v in _litellm_params.items():
if isinstance(v, str) and v.startswith("os.environ/"):
_litellm_params[k] = litellm.get_secret(v)
_model_info: dict = model.pop("model_info", {})
# check if model info has id
if "id" not in _model_info:
_id = self._generate_model_id(_model_name, _litellm_params)
_model_info["id"] = _id
deployment = Deployment(
**model,
model_name=_model_name,
litellm_params=_litellm_params,
model_info=_model_info,
)
deployment = self._add_deployment(deployment=deployment)
model = deployment.to_json(exclude_none=True)
self.model_list.append(model)
verbose_router_logger.debug(f"\nInitialized Model List {self.model_list}")
self.model_names = [m["model_name"] for m in model_list]
def _add_deployment(self, deployment: Deployment) -> Deployment:
import os
#### DEPLOYMENT NAMES INIT ########
self.deployment_names.append(deployment.litellm_params.model)
############ Users can either pass tpm/rpm as a litellm_param or a router param ###########
# for get_available_deployment, we use the litellm_param["rpm"]
# in this snippet we also set rpm to be a litellm_param
if (
deployment.litellm_params.rpm is None
and getattr(deployment, "rpm", None) is not None
):
deployment.litellm_params.rpm = getattr(deployment, "rpm")
if (
deployment.litellm_params.tpm is None
and getattr(deployment, "tpm", None) is not None
):
deployment.litellm_params.tpm = getattr(deployment, "tpm")
#### VALIDATE MODEL ########
# check if model provider in supported providers
(
_model,
custom_llm_provider,
dynamic_api_key,
api_base,
) = litellm.get_llm_provider(
model=deployment.litellm_params.model,
custom_llm_provider=deployment.litellm_params.get(
"custom_llm_provider", None
),
)
# Check if user is trying to use model_name == "*"
# this is a catch all model for their specific api key
if deployment.model_name == "*":
self.default_deployment = deployment.to_json(exclude_none=True)
# Azure GPT-Vision Enhancements, users can pass os.environ/
data_sources = deployment.litellm_params.get("dataSources", []) or []
for data_source in data_sources:
params = data_source.get("parameters", {})
for param_key in ["endpoint", "key"]:
# if endpoint or key set for Azure GPT Vision Enhancements, check if it's an env var
if param_key in params and params[param_key].startswith("os.environ/"):
env_name = params[param_key].replace("os.environ/", "")
params[param_key] = os.environ.get(env_name, "")
# done reading model["litellm_params"]
if custom_llm_provider not in litellm.provider_list:
raise Exception(f"Unsupported provider - {custom_llm_provider}")
# init OpenAI, Azure clients
self.set_client(model=deployment.to_json(exclude_none=True))
# set region (if azure model) ## PREVIEW FEATURE ##
if litellm.enable_preview_features == True:
print("Auto inferring region") # noqa
"""
Hiding behind a feature flag
When there is a large amount of LLM deployments this makes startup times blow up
"""
try:
if (
"azure" in deployment.litellm_params.model
and deployment.litellm_params.region_name is None
):
region = litellm.utils.get_model_region(
litellm_params=deployment.litellm_params, mode=None
)
deployment.litellm_params.region_name = region
except Exception as e:
verbose_router_logger.debug(
"Unable to get the region for azure model - {}, {}".format(
deployment.litellm_params.model, str(e)
)
)
pass # [NON-BLOCKING]
return deployment
def add_deployment(self, deployment: Deployment) -> Optional[Deployment]:
"""
Parameters:
- deployment: Deployment - the deployment to be added to the Router
Returns:
- The added deployment
- OR None (if deployment already exists)
"""
# check if deployment already exists
if deployment.model_info.id in self.get_model_ids():
return None
# add to model list
_deployment = deployment.to_json(exclude_none=True)
self.model_list.append(_deployment)
# initialize client
self._add_deployment(deployment=deployment)
# add to model names
self.model_names.append(deployment.model_name)
return deployment
def upsert_deployment(self, deployment: Deployment) -> Optional[Deployment]:
"""
Add or update deployment
Parameters:
- deployment: Deployment - the deployment to be added to the Router
Returns:
- The added/updated deployment
"""
# check if deployment already exists
_deployment_model_id = deployment.model_info.id or ""
_deployment_on_router: Optional[Deployment] = self.get_deployment(
model_id=_deployment_model_id
)
if _deployment_on_router is not None:
# deployment with this model_id exists on the router
if deployment.litellm_params == _deployment_on_router.litellm_params:
# No need to update
return None
# if there is a new litellm param -> then update the deployment
# remove the previous deployment
removal_idx: Optional[int] = None
for idx, model in enumerate(self.model_list):
if model["model_info"]["id"] == deployment.model_info.id:
removal_idx = idx
if removal_idx is not None:
self.model_list.pop(removal_idx)
else:
# if the model_id is not in router
self.add_deployment(deployment=deployment)
return deployment
def delete_deployment(self, id: str) -> Optional[Deployment]:
"""
Parameters:
- id: str - the id of the deployment to be deleted
Returns:
- The deleted deployment
- OR None (if deleted deployment not found)
"""
deployment_idx = None
for idx, m in enumerate(self.model_list):
if m["model_info"]["id"] == id:
deployment_idx = idx
try:
if deployment_idx is not None:
item = self.model_list.pop(deployment_idx)
return item
else:
return None
except:
return None
def get_deployment(self, model_id: str) -> Optional[Deployment]:
"""
Returns -> Deployment or None
        Raises Exception -> if model is found in an invalid format
"""
for model in self.model_list:
if "model_info" in model and "id" in model["model_info"]:
if model_id == model["model_info"]["id"]:
if isinstance(model, dict):
return Deployment(**model)
elif isinstance(model, Deployment):
return model
else:
raise Exception("Model invalid format - {}".format(type(model)))
return None
def get_model_info(self, id: str) -> Optional[dict]:
"""
For a given model id, return the model info
"""
for model in self.model_list:
if "model_info" in model and "id" in model["model_info"]:
if id == model["model_info"]["id"]:
return model
return None
def get_model_group_info(self, model_group: str) -> Optional[ModelGroupInfo]:
"""
For a given model group name, return the combined model info
Returns:
- ModelGroupInfo if able to construct a model group
- None if error constructing model group info
"""
model_group_info: Optional[ModelGroupInfo] = None
for model in self.model_list:
if "model_name" in model and model["model_name"] == model_group:
# model in model group found #
litellm_params = LiteLLM_Params(**model["litellm_params"])
# get model info
try:
model_info = litellm.get_model_info(model=litellm_params.model)
except Exception as e:
model_info = None
# get llm provider
try:
model, llm_provider, _, _ = litellm.get_llm_provider(
model=litellm_params.model,
custom_llm_provider=litellm_params.custom_llm_provider,
)
except litellm.exceptions.BadRequestError as e:
continue
if model_info is None:
supported_openai_params = litellm.get_supported_openai_params(
model=model, custom_llm_provider=llm_provider
)
if supported_openai_params is None:
supported_openai_params = []
model_info = ModelMapInfo(
max_tokens=None,
max_input_tokens=None,
max_output_tokens=None,
input_cost_per_token=0,
output_cost_per_token=0,
litellm_provider=llm_provider,
mode="chat",
supported_openai_params=supported_openai_params,
)
if model_group_info is None:
model_group_info = ModelGroupInfo(
model_group=model_group, providers=[llm_provider], **model_info # type: ignore
)
else:
# if max_input_tokens > curr
# if max_output_tokens > curr
# if input_cost_per_token > curr
# if output_cost_per_token > curr
# supports_parallel_function_calling == True
# supports_vision == True
# supports_function_calling == True
if llm_provider not in model_group_info.providers:
model_group_info.providers.append(llm_provider)
if (
model_info.get("max_input_tokens", None) is not None
and model_info["max_input_tokens"] is not None
and (
model_group_info.max_input_tokens is None
or model_info["max_input_tokens"]
> model_group_info.max_input_tokens
)
):
model_group_info.max_input_tokens = model_info[
"max_input_tokens"
]
if (
model_info.get("max_output_tokens", None) is not None
and model_info["max_output_tokens"] is not None
and (
model_group_info.max_output_tokens is None
or model_info["max_output_tokens"]
> model_group_info.max_output_tokens
)
):
model_group_info.max_output_tokens = model_info[
"max_output_tokens"
]
if model_info.get("input_cost_per_token", None) is not None and (
model_group_info.input_cost_per_token is None
or model_info["input_cost_per_token"]
> model_group_info.input_cost_per_token
):
model_group_info.input_cost_per_token = model_info[
"input_cost_per_token"
]
if model_info.get("output_cost_per_token", None) is not None and (
model_group_info.output_cost_per_token is None
or model_info["output_cost_per_token"]
> model_group_info.output_cost_per_token
):
model_group_info.output_cost_per_token = model_info[
"output_cost_per_token"
]
if (
model_info.get("supports_parallel_function_calling", None)
is not None
and model_info["supports_parallel_function_calling"] is True # type: ignore
):
model_group_info.supports_parallel_function_calling = True
if (
model_info.get("supports_vision", None) is not None
and model_info["supports_vision"] is True # type: ignore
):
model_group_info.supports_vision = True
if (
model_info.get("supports_function_calling", None) is not None
and model_info["supports_function_calling"] is True # type: ignore
):
model_group_info.supports_function_calling = True
if (
model_info.get("supported_openai_params", None) is not None
and model_info["supported_openai_params"] is not None
):
model_group_info.supported_openai_params = model_info[
"supported_openai_params"
]
return model_group_info
def get_model_ids(self) -> List[str]:
"""
        Returns a list of model ids.
"""
ids = []
for model in self.model_list:
if "model_info" in model and "id" in model["model_info"]:
id = model["model_info"]["id"]
ids.append(id)
return ids
def get_model_names(self) -> List[str]:
return self.model_names
def get_model_list(self):
if hasattr(self, "model_list"):
return self.model_list
return None
def get_settings(self):
"""
Get router settings method, returns a dictionary of the settings and their values.
For example get the set values for routing_strategy_args, routing_strategy, allowed_fails, cooldown_time, num_retries, timeout, max_retries, retry_after
"""
_all_vars = vars(self)
_settings_to_return = {}
vars_to_include = [
"routing_strategy_args",
"routing_strategy",
"allowed_fails",
"cooldown_time",
"num_retries",
"timeout",
"max_retries",
"retry_after",
"fallbacks",
"context_window_fallbacks",
"model_group_retry_policy",
]
for var in vars_to_include:
if var in _all_vars:
_settings_to_return[var] = _all_vars[var]
if (
var == "routing_strategy_args"
and self.routing_strategy == "latency-based-routing"
):
_settings_to_return[var] = self.lowestlatency_logger.routing_args.json()
return _settings_to_return
def update_settings(self, **kwargs):
# only the following settings are allowed to be configured
_allowed_settings = [
"routing_strategy_args",
"routing_strategy",
"allowed_fails",
"cooldown_time",
"num_retries",
"timeout",
"max_retries",
"retry_after",
"fallbacks",
"context_window_fallbacks",
"model_group_retry_policy",
]
_int_settings = [
"timeout",
"num_retries",
"retry_after",
"allowed_fails",
"cooldown_time",
]
_existing_router_settings = self.get_settings()
for var in kwargs:
if var in _allowed_settings:
if var in _int_settings:
_casted_value = int(kwargs[var])
setattr(self, var, _casted_value)
else:
# only run routing strategy init if it has changed
if (
var == "routing_strategy"
and _existing_router_settings["routing_strategy"] != kwargs[var]
):
self.routing_strategy_init(
routing_strategy=kwargs[var],
routing_strategy_args=kwargs.get(
"routing_strategy_args", {}
),
)
setattr(self, var, kwargs[var])
else:
verbose_router_logger.debug("Setting {} is not allowed".format(var))
verbose_router_logger.debug(f"Updated Router settings: {self.get_settings()}")
def _get_client(self, deployment, kwargs, client_type=None):
"""
Returns the appropriate client based on the given deployment, kwargs, and client_type.
Parameters:
deployment (dict): The deployment dictionary containing the clients.
kwargs (dict): The keyword arguments passed to the function.
client_type (str): The type of client to return.
Returns:
The appropriate client based on the given client_type and kwargs.
"""
model_id = deployment["model_info"]["id"]
if client_type == "max_parallel_requests":
cache_key = "{}_max_parallel_requests_client".format(model_id)
client = self.cache.get_cache(key=cache_key, local_only=True)
return client
elif client_type == "async":
if kwargs.get("stream") == True:
cache_key = f"{model_id}_stream_async_client"
client = self.cache.get_cache(key=cache_key, local_only=True)
if client is None:
"""
Re-initialize the client
"""
self.set_client(model=deployment)
client = self.cache.get_cache(key=cache_key, local_only=True)
return client
else:
cache_key = f"{model_id}_async_client"
client = self.cache.get_cache(key=cache_key, local_only=True)
if client is None:
"""
Re-initialize the client
"""
self.set_client(model=deployment)
client = self.cache.get_cache(key=cache_key, local_only=True)
return client
else:
if kwargs.get("stream") == True:
cache_key = f"{model_id}_stream_client"
client = self.cache.get_cache(key=cache_key)
if client is None:
"""
Re-initialize the client
"""
self.set_client(model=deployment)
client = self.cache.get_cache(key=cache_key)
return client
else:
cache_key = f"{model_id}_client"
client = self.cache.get_cache(key=cache_key)
if client is None:
"""
Re-initialize the client
"""
self.set_client(model=deployment)
client = self.cache.get_cache(key=cache_key)
return client
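    # Lookup summary for _get_client: the cache key mirrors how set_client stored the
    # client, and a cache miss lazily re-initializes it via set_client:
    #   client_type == "max_parallel_requests" -> f"{model_id}_max_parallel_requests_client"
    #   client_type == "async", stream=True    -> f"{model_id}_stream_async_client"
    #   client_type == "async"                 -> f"{model_id}_async_client"
    #   otherwise,              stream=True    -> f"{model_id}_stream_client"
    #   otherwise                              -> f"{model_id}_client"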
def _pre_call_checks(
self,
model: str,
healthy_deployments: List,
messages: List[Dict[str, str]],
request_kwargs: Optional[dict] = None,
):
"""
Filter out model in model group, if:
- model context window < message length
- filter models above rpm limits
- if region given, filter out models not in that region / unknown region
- [TODO] function call and model doesn't support function calling
"""
verbose_router_logger.debug(
f"Starting Pre-call checks for deployments in model={model}"
)
_returned_deployments = copy.deepcopy(healthy_deployments)
invalid_model_indices = []
try:
input_tokens = litellm.token_counter(messages=messages)
except Exception as e:
return _returned_deployments
_context_window_error = False
_rate_limit_error = False
## get model group RPM ##
dt = get_utc_datetime()
current_minute = dt.strftime("%H-%M")
rpm_key = f"{model}:rpm:{current_minute}"
model_group_cache = (
self.cache.get_cache(key=rpm_key, local_only=True) or {}
) # check the in-memory cache used by lowest_latency and usage-based routing. Only check the local cache.
for idx, deployment in enumerate(_returned_deployments):
# see if we have the info for this model
try:
base_model = deployment.get("model_info", {}).get("base_model", None)
if base_model is None:
base_model = deployment.get("litellm_params", {}).get(
"base_model", None
)
model = base_model or deployment.get("litellm_params", {}).get(
"model", None
)
model_info = litellm.get_model_info(model=model)
if (
isinstance(model_info, dict)
and model_info.get("max_input_tokens", None) is not None
):
if (
isinstance(model_info["max_input_tokens"], int)
and input_tokens > model_info["max_input_tokens"]
):
invalid_model_indices.append(idx)
_context_window_error = True
continue
except Exception as e:
verbose_router_logger.debug("An error occurs - {}".format(str(e)))
_litellm_params = deployment.get("litellm_params", {})
model_id = deployment.get("model_info", {}).get("id", "")
## RPM CHECK ##
### get local router cache ###
current_request_cache_local = (
self.cache.get_cache(key=model_id, local_only=True) or 0
)
### get usage based cache ###
if (
isinstance(model_group_cache, dict)
and self.routing_strategy != "usage-based-routing-v2"
):
model_group_cache[model_id] = model_group_cache.get(model_id, 0)
current_request = max(
current_request_cache_local, model_group_cache[model_id]
)
if (
isinstance(_litellm_params, dict)
and _litellm_params.get("rpm", None) is not None
):
if (
isinstance(_litellm_params["rpm"], int)
and _litellm_params["rpm"] <= current_request
):
invalid_model_indices.append(idx)
_rate_limit_error = True
continue
## REGION CHECK ##
if (
request_kwargs is not None
and request_kwargs.get("allowed_model_region") is not None
and request_kwargs["allowed_model_region"] == "eu"
):
if _litellm_params.get("region_name") is not None and isinstance(
_litellm_params["region_name"], str
):
# check if in allowed_model_region
if (
_is_region_eu(litellm_params=LiteLLM_Params(**_litellm_params))
== False
):
invalid_model_indices.append(idx)
continue
else:
verbose_router_logger.debug(
"Filtering out model - {}, as model_region=None, and allowed_model_region={}".format(
model_id, request_kwargs.get("allowed_model_region")
)
)
# filter out since region unknown, and user wants to filter for specific region
invalid_model_indices.append(idx)
continue
## INVALID PARAMS ## -> catch 'gpt-3.5-turbo-16k' not supporting 'response_format' param
if request_kwargs is not None and litellm.drop_params == False:
# get supported params
model, custom_llm_provider, _, _ = litellm.get_llm_provider(
model=model, litellm_params=LiteLLM_Params(**_litellm_params)
)
supported_openai_params = litellm.get_supported_openai_params(
model=model, custom_llm_provider=custom_llm_provider
)
if supported_openai_params is None:
continue
else:
# check the non-default openai params in request kwargs
non_default_params = litellm.utils.get_non_default_params(
passed_params=request_kwargs
)
special_params = ["response_format"]
# check if all params are supported
for k, v in non_default_params.items():
if k not in supported_openai_params and k in special_params:
# if not -> invalid model
verbose_router_logger.debug(
f"INVALID MODEL INDEX @ REQUEST KWARG FILTERING, k={k}"
)
invalid_model_indices.append(idx)
if len(invalid_model_indices) == len(_returned_deployments):
"""
- no healthy deployments available b/c context window checks or rate limit error
- First check for rate limit errors (if this is true, it means the model passed the context window check but failed the rate limit check)
"""
if _rate_limit_error == True: # allow generic fallback logic to take place
raise ValueError(
f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}. Try again in {self.cooldown_time} seconds."
)
elif _context_window_error == True:
raise litellm.ContextWindowExceededError(
message="Context Window exceeded for given call",
model=model,
llm_provider="",
response=httpx.Response(
status_code=400,
request=httpx.Request("GET", "https://example.com"),
),
)
if len(invalid_model_indices) > 0:
for idx in reversed(invalid_model_indices):
_returned_deployments.pop(idx)
return _returned_deployments
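    # Illustrative setup for the pre-call checks above (model names are hypothetical):
    # with enable_pre_call_checks=True, deployments whose max_input_tokens is below the
    # prompt length, deployments at/over their configured rpm, and deployments outside an
    # allowed_model_region are filtered out before routing.
    #
    #   router = Router(
    #       model_list=[...],
    #       enable_pre_call_checks=True,
    #   )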
def _common_checks_available_deployment(
self,
model: str,
messages: Optional[List[Dict[str, str]]] = None,
input: Optional[Union[str, List]] = None,
specific_deployment: Optional[bool] = False,
) -> Tuple[str, Union[list, dict]]:
"""
Common checks for 'get_available_deployment' across sync + async call.
If 'healthy_deployments' returned is None, this means the user chose a specific deployment
Returns
- Dict, if specific model chosen
- List, if multiple models chosen
"""
# check if aliases set on litellm model alias map
if specific_deployment == True:
# users can also specify a specific deployment name. At this point we should check if they are just trying to call a specific deployment
for deployment in self.model_list:
deployment_model = deployment.get("litellm_params").get("model")
if deployment_model == model:
# User Passed a specific deployment name on their config.yaml, example azure/chat-gpt-v-2
                    # return the first deployment where the `model` matches the specified deployment name
return deployment_model, deployment
raise ValueError(
f"LiteLLM Router: Trying to call specific deployment, but Model:{model} does not exist in Model List: {self.model_list}"
)
if model in self.model_group_alias:
verbose_router_logger.debug(
f"Using a model alias. Got Request for {model}, sending requests to {self.model_group_alias.get(model)}"
)
model = self.model_group_alias[model]
if model not in self.model_names and self.default_deployment is not None:
updated_deployment = copy.deepcopy(
self.default_deployment
) # self.default_deployment
updated_deployment["litellm_params"]["model"] = model
return model, updated_deployment
## get healthy deployments
### get all deployments
healthy_deployments = [m for m in self.model_list if m["model_name"] == model]
if len(healthy_deployments) == 0:
# check if the user sent in a deployment name instead
healthy_deployments = [
m for m in self.model_list if m["litellm_params"]["model"] == model
]
litellm.print_verbose(f"initial list of deployments: {healthy_deployments}")
if len(healthy_deployments) == 0:
raise ValueError(
f"No healthy deployment available, passed model={model}. Try again in {self.cooldown_time} seconds"
)
if litellm.model_alias_map and model in litellm.model_alias_map:
model = litellm.model_alias_map[
model
] # update the model to the actual value if an alias has been passed in
return model, healthy_deployments
async def async_get_available_deployment(
self,
model: str,
messages: Optional[List[Dict[str, str]]] = None,
input: Optional[Union[str, List]] = None,
specific_deployment: Optional[bool] = False,
request_kwargs: Optional[Dict] = None,
):
"""
Async implementation of 'get_available_deployments'.
        Allows all cache calls to be made async, significantly improving throughput (observed ~8 RPS -> ~100 RPS).
"""
if (
self.routing_strategy != "usage-based-routing-v2"
and self.routing_strategy != "simple-shuffle"
and self.routing_strategy != "cost-based-routing"
        ):  # prevent regressions for routing strategies that don't have an async get_available_deployments implementation
return self.get_available_deployment(
model=model,
messages=messages,
input=input,
specific_deployment=specific_deployment,
request_kwargs=request_kwargs,
)
model, healthy_deployments = self._common_checks_available_deployment(
model=model,
messages=messages,
input=input,
specific_deployment=specific_deployment,
) # type: ignore
if isinstance(healthy_deployments, dict):
return healthy_deployments
# filter out the deployments currently cooling down
deployments_to_remove = []
# cooldown_deployments is a list of model_id's cooling down, cooldown_deployments = ["16700539-b3cd-42f4-b426-6a12a1bb706a", "16700539-b3cd-42f4-b426-7899"]
cooldown_deployments = await self._async_get_cooldown_deployments()
verbose_router_logger.debug(
f"async cooldown deployments: {cooldown_deployments}"
)
# Find deployments in model_list whose model_id is cooling down
for deployment in healthy_deployments:
deployment_id = deployment["model_info"]["id"]
if deployment_id in cooldown_deployments:
deployments_to_remove.append(deployment)
# remove unhealthy deployments from healthy deployments
for deployment in deployments_to_remove:
healthy_deployments.remove(deployment)
# filter pre-call checks
_allowed_model_region = (
request_kwargs.get("allowed_model_region")
if request_kwargs is not None
else None
)
if self.enable_pre_call_checks and messages is not None:
healthy_deployments = self._pre_call_checks(
model=model,
healthy_deployments=healthy_deployments,
messages=messages,
request_kwargs=request_kwargs,
)
if len(healthy_deployments) == 0:
if _allowed_model_region is None:
_allowed_model_region = "n/a"
raise ValueError(
f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}. pre-call-checks={self.enable_pre_call_checks}, allowed_model_region={_allowed_model_region}"
)
        deployment = None  # ensure 'deployment' is defined even if no routing-strategy branch below assigns it
        if (
self.routing_strategy == "usage-based-routing-v2"
and self.lowesttpm_logger_v2 is not None
):
deployment = await self.lowesttpm_logger_v2.async_get_available_deployments(
model_group=model,
healthy_deployments=healthy_deployments, # type: ignore
messages=messages,
input=input,
)
        elif (
self.routing_strategy == "cost-based-routing"
and self.lowestcost_logger is not None
):
deployment = await self.lowestcost_logger.async_get_available_deployments(
model_group=model,
healthy_deployments=healthy_deployments, # type: ignore
messages=messages,
input=input,
)
elif self.routing_strategy == "simple-shuffle":
# if users pass rpm or tpm, we do a random weighted pick - based on rpm/tpm
############## Check if we can do a RPM/TPM based weighted pick #################
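            # e.g. (illustrative) rpms = [60, 30, 10] -> weights = [0.6, 0.3, 0.1],
            # so the first deployment is picked ~60% of the time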
rpm = healthy_deployments[0].get("litellm_params").get("rpm", None)
if rpm is not None:
# use weight-random pick if rpms provided
rpms = [m["litellm_params"].get("rpm", 0) for m in healthy_deployments]
verbose_router_logger.debug(f"\nrpms {rpms}")
total_rpm = sum(rpms)
weights = [rpm / total_rpm for rpm in rpms]
verbose_router_logger.debug(f"\n weights {weights}")
# Perform weighted random pick
selected_index = random.choices(range(len(rpms)), weights=weights)[0]
verbose_router_logger.debug(f"\n selected index, {selected_index}")
deployment = healthy_deployments[selected_index]
verbose_router_logger.info(
f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment) or deployment[0]} for model: {model}"
)
return deployment or deployment[0]
############## Check if we can do a RPM/TPM based weighted pick #################
tpm = healthy_deployments[0].get("litellm_params").get("tpm", None)
if tpm is not None:
                # use weight-random pick if tpms provided
tpms = [m["litellm_params"].get("tpm", 0) for m in healthy_deployments]
verbose_router_logger.debug(f"\ntpms {tpms}")
total_tpm = sum(tpms)
weights = [tpm / total_tpm for tpm in tpms]
verbose_router_logger.debug(f"\n weights {weights}")
# Perform weighted random pick
selected_index = random.choices(range(len(tpms)), weights=weights)[0]
verbose_router_logger.debug(f"\n selected index, {selected_index}")
deployment = healthy_deployments[selected_index]
verbose_router_logger.info(
f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment) or deployment[0]} for model: {model}"
)
return deployment or deployment[0]
############## No RPM/TPM passed, we do a random pick #################
item = random.choice(healthy_deployments)
return item or item[0]
if deployment is None:
verbose_router_logger.info(
f"get_available_deployment for model: {model}, No deployment available"
)
raise ValueError(
f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}"
)
verbose_router_logger.info(
f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment)} for model: {model}"
)
return deployment
def get_available_deployment(
self,
model: str,
messages: Optional[List[Dict[str, str]]] = None,
input: Optional[Union[str, List]] = None,
specific_deployment: Optional[bool] = False,
request_kwargs: Optional[Dict] = None,
):
"""
Returns the deployment based on routing strategy
"""
# users need to explicitly call a specific deployment, by setting `specific_deployment = True` as completion()/embedding() kwarg
        # When this was not explicit, we had several issues with fallbacks timing out
model, healthy_deployments = self._common_checks_available_deployment(
model=model,
messages=messages,
input=input,
specific_deployment=specific_deployment,
)
if isinstance(healthy_deployments, dict):
return healthy_deployments
# filter out the deployments currently cooling down
deployments_to_remove = []
# cooldown_deployments is a list of model_id's cooling down, cooldown_deployments = ["16700539-b3cd-42f4-b426-6a12a1bb706a", "16700539-b3cd-42f4-b426-7899"]
cooldown_deployments = self._get_cooldown_deployments()
verbose_router_logger.debug(f"cooldown deployments: {cooldown_deployments}")
# Find deployments in model_list whose model_id is cooling down
for deployment in healthy_deployments:
deployment_id = deployment["model_info"]["id"]
if deployment_id in cooldown_deployments:
deployments_to_remove.append(deployment)
# remove unhealthy deployments from healthy deployments
for deployment in deployments_to_remove:
healthy_deployments.remove(deployment)
# filter pre-call checks
if self.enable_pre_call_checks and messages is not None:
            healthy_deployments = self._pre_call_checks(
                model=model,
                healthy_deployments=healthy_deployments,
                messages=messages,
                request_kwargs=request_kwargs,
            )
if self.routing_strategy == "least-busy" and self.leastbusy_logger is not None:
deployment = self.leastbusy_logger.get_available_deployments(
model_group=model, healthy_deployments=healthy_deployments # type: ignore
)
elif self.routing_strategy == "simple-shuffle":
# if users pass rpm or tpm, we do a random weighted pick - based on rpm/tpm
############## Check if we can do a RPM/TPM based weighted pick #################
rpm = healthy_deployments[0].get("litellm_params").get("rpm", None)
if rpm is not None:
# use weight-random pick if rpms provided
rpms = [m["litellm_params"].get("rpm", 0) for m in healthy_deployments]
verbose_router_logger.debug(f"\nrpms {rpms}")
total_rpm = sum(rpms)
weights = [rpm / total_rpm for rpm in rpms]
verbose_router_logger.debug(f"\n weights {weights}")
# Perform weighted random pick
selected_index = random.choices(range(len(rpms)), weights=weights)[0]
verbose_router_logger.debug(f"\n selected index, {selected_index}")
deployment = healthy_deployments[selected_index]
verbose_router_logger.info(
f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment) or deployment[0]} for model: {model}"
)
return deployment or deployment[0]
############## Check if we can do a RPM/TPM based weighted pick #################
tpm = healthy_deployments[0].get("litellm_params").get("tpm", None)
if tpm is not None:
                # use weight-random pick if tpms provided
tpms = [m["litellm_params"].get("tpm", 0) for m in healthy_deployments]
verbose_router_logger.debug(f"\ntpms {tpms}")
total_tpm = sum(tpms)
weights = [tpm / total_tpm for tpm in tpms]
verbose_router_logger.debug(f"\n weights {weights}")
# Perform weighted random pick
selected_index = random.choices(range(len(tpms)), weights=weights)[0]
verbose_router_logger.debug(f"\n selected index, {selected_index}")
deployment = healthy_deployments[selected_index]
verbose_router_logger.info(
f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment) or deployment[0]} for model: {model}"
)
return deployment or deployment[0]
############## No RPM/TPM passed, we do a random pick #################
item = random.choice(healthy_deployments)
return item or item[0]
elif (
self.routing_strategy == "latency-based-routing"
and self.lowestlatency_logger is not None
):
deployment = self.lowestlatency_logger.get_available_deployments(
model_group=model,
healthy_deployments=healthy_deployments, # type: ignore
request_kwargs=request_kwargs,
)
elif (
self.routing_strategy == "usage-based-routing"
and self.lowesttpm_logger is not None
):
deployment = self.lowesttpm_logger.get_available_deployments(
model_group=model,
healthy_deployments=healthy_deployments, # type: ignore
messages=messages,
input=input,
)
elif (
self.routing_strategy == "usage-based-routing-v2"
and self.lowesttpm_logger_v2 is not None
):
deployment = self.lowesttpm_logger_v2.get_available_deployments(
model_group=model,
healthy_deployments=healthy_deployments, # type: ignore
messages=messages,
input=input,
)
if deployment is None:
verbose_router_logger.info(
f"get_available_deployment for model: {model}, No deployment available"
)
raise ValueError(
f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}"
)
verbose_router_logger.info(
f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment)} for model: {model}"
)
return deployment
def _track_deployment_metrics(self, deployment, response=None):
try:
litellm_params = deployment["litellm_params"]
api_base = litellm_params.get("api_base", "")
model = litellm_params.get("model", "")
model_id = deployment.get("model_info", {}).get("id", None)
if response is None:
# update self.deployment_stats
if model_id is not None:
self._update_usage(model_id) # update in-memory cache for tracking
if model_id in self.deployment_stats:
# only update num_requests
self.deployment_stats[model_id]["num_requests"] += 1
else:
self.deployment_stats[model_id] = {
"api_base": api_base,
"model": model,
"num_requests": 1,
}
else:
# check response_ms and update num_successes
if isinstance(response, dict):
response_ms = response.get("_response_ms", 0)
else:
response_ms = 0
if model_id is not None:
if model_id in self.deployment_stats:
# check if avg_latency exists
if "avg_latency" in self.deployment_stats[model_id]:
                            # update avg_latency as an incremental mean over prior successes
                            _prior_successes = self.deployment_stats[model_id].get(
                                "num_successes", 0
                            )
                            self.deployment_stats[model_id]["avg_latency"] = (
                                self.deployment_stats[model_id]["avg_latency"]
                                * _prior_successes
                                + response_ms
                            ) / (_prior_successes + 1)
else:
self.deployment_stats[model_id]["avg_latency"] = response_ms
# check if num_successes exists
if "num_successes" in self.deployment_stats[model_id]:
self.deployment_stats[model_id]["num_successes"] += 1
else:
self.deployment_stats[model_id]["num_successes"] = 1
else:
self.deployment_stats[model_id] = {
"api_base": api_base,
"model": model,
"num_successes": 1,
"avg_latency": response_ms,
}
if self.set_verbose == True and self.debug_level == "DEBUG":
                from pprint import pformat
                # pretty-print the in-memory deployment stats for debugging
                formatted_stats = pformat(self.deployment_stats)
                verbose_router_logger.info(
                    "self.deployment_stats: \n%s", formatted_stats
                )
except Exception as e:
verbose_router_logger.error(f"Error in _track_deployment_metrics: {str(e)}")
def get_num_retries_from_retry_policy(
self, exception: Exception, model_group: Optional[str] = None
):
"""
BadRequestErrorRetries: Optional[int] = None
AuthenticationErrorRetries: Optional[int] = None
TimeoutErrorRetries: Optional[int] = None
RateLimitErrorRetries: Optional[int] = None
ContentPolicyViolationErrorRetries: Optional[int] = None
"""
        # if the exception type is configured in the retry policy -> return that number of retries
retry_policy = self.retry_policy
if (
self.model_group_retry_policy is not None
and model_group is not None
and model_group in self.model_group_retry_policy
):
retry_policy = self.model_group_retry_policy.get(model_group, None)
if retry_policy is None:
return None
if isinstance(retry_policy, dict):
retry_policy = RetryPolicy(**retry_policy)
if (
isinstance(exception, litellm.BadRequestError)
and retry_policy.BadRequestErrorRetries is not None
):
return retry_policy.BadRequestErrorRetries
if (
isinstance(exception, litellm.AuthenticationError)
and retry_policy.AuthenticationErrorRetries is not None
):
return retry_policy.AuthenticationErrorRetries
if (
isinstance(exception, litellm.Timeout)
and retry_policy.TimeoutErrorRetries is not None
):
return retry_policy.TimeoutErrorRetries
if (
isinstance(exception, litellm.RateLimitError)
and retry_policy.RateLimitErrorRetries is not None
):
return retry_policy.RateLimitErrorRetries
if (
isinstance(exception, litellm.ContentPolicyViolationError)
and retry_policy.ContentPolicyViolationErrorRetries is not None
):
return retry_policy.ContentPolicyViolationErrorRetries
def get_allowed_fails_from_policy(self, exception: Exception):
"""
BadRequestErrorRetries: Optional[int] = None
AuthenticationErrorRetries: Optional[int] = None
TimeoutErrorRetries: Optional[int] = None
RateLimitErrorRetries: Optional[int] = None
ContentPolicyViolationErrorRetries: Optional[int] = None
"""
        # if the exception type is configured in the allowed fails policy -> return that allowed-fails count
allowed_fails_policy: Optional[AllowedFailsPolicy] = self.allowed_fails_policy
if allowed_fails_policy is None:
return None
if (
isinstance(exception, litellm.BadRequestError)
and allowed_fails_policy.BadRequestErrorAllowedFails is not None
):
return allowed_fails_policy.BadRequestErrorAllowedFails
if (
isinstance(exception, litellm.AuthenticationError)
and allowed_fails_policy.AuthenticationErrorAllowedFails is not None
):
return allowed_fails_policy.AuthenticationErrorAllowedFails
if (
isinstance(exception, litellm.Timeout)
and allowed_fails_policy.TimeoutErrorAllowedFails is not None
):
return allowed_fails_policy.TimeoutErrorAllowedFails
if (
isinstance(exception, litellm.RateLimitError)
and allowed_fails_policy.RateLimitErrorAllowedFails is not None
):
return allowed_fails_policy.RateLimitErrorAllowedFails
if (
isinstance(exception, litellm.ContentPolicyViolationError)
and allowed_fails_policy.ContentPolicyViolationErrorAllowedFails is not None
):
return allowed_fails_policy.ContentPolicyViolationErrorAllowedFails
def _initialize_alerting(self):
from litellm.integrations.slack_alerting import SlackAlerting
router_alerting_config: AlertingConfig = self.alerting_config
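        # Example (illustrative) of an alerting config consumed below:
        #   AlertingConfig(alerting_threshold=300, webhook_url="https://hooks.slack.com/services/...")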
_slack_alerting_logger = SlackAlerting(
alerting_threshold=router_alerting_config.alerting_threshold,
alerting=["slack"],
default_webhook_url=router_alerting_config.webhook_url,
)
litellm.callbacks.append(_slack_alerting_logger)
litellm.success_callback.append(
_slack_alerting_logger.response_taking_too_long_callback
)
print("\033[94m\nInitialized Alerting for litellm.Router\033[0m\n") # noqa
def send_deployment_cooldown_alert(
self,
deployment_id: str,
exception_status: Union[str, int],
cooldown_time: float,
):
try:
from litellm.proxy.proxy_server import proxy_logging_obj
# trigger slack alert saying deployment is in cooldown
if (
proxy_logging_obj is not None
and proxy_logging_obj.alerting is not None
and "slack" in proxy_logging_obj.alerting
):
_deployment = self.get_deployment(model_id=deployment_id)
if _deployment is None:
return
_litellm_params = _deployment["litellm_params"]
temp_litellm_params = copy.deepcopy(_litellm_params)
temp_litellm_params = dict(temp_litellm_params)
_model_name = _deployment.get("model_name", None)
_api_base = litellm.get_api_base(
model=_model_name, optional_params=temp_litellm_params
)
# asyncio.create_task(
# proxy_logging_obj.slack_alerting_instance.send_alert(
# message=f"Router: Cooling down Deployment:\nModel Name: `{_model_name}`\nAPI Base: `{_api_base}`\nCooldown Time: `{cooldown_time} seconds`\nException Status Code: `{str(exception_status)}`\n\nChange 'cooldown_time' + 'allowed_fails' under 'Router Settings' on proxy UI, or via config - https://docs.litellm.ai/docs/proxy/reliability#fallbacks--retries--timeouts--cooldowns",
# alert_type="cooldown_deployment",
# level="Low",
# )
# )
        except Exception as e:
            # don't let alerting failures break the cooldown path; just log for debugging
            verbose_router_logger.debug(
                f"Error sending deployment cooldown alert: {str(e)}"
            )
def flush_cache(self):
litellm.cache = None
self.cache.flush_cache()
def reset(self):
## clean up on close
litellm.success_callback = []
litellm.__async_success_callback = []
litellm.failure_callback = []
litellm._async_failure_callback = []
self.retry_policy = None
self.flush_cache()