# +-----------------------------------------------+
# |                                               |
# |           Give Feedback / Get Help            |
# | https://github.com/BerriAI/litellm/issues/new |
# |                                               |
# +-----------------------------------------------+
#
#  Thank you ! We ❤️ you! - Krrish & Ishaan

import copy, httpx
from datetime import datetime
from typing import (
    Dict,
    List,
    Optional,
    Union,
    Literal,
    Any,
    BinaryIO,
    Tuple,
    TypedDict,
    Iterable,
)
from typing_extensions import overload
import random, threading, time, traceback, uuid
import litellm, openai, hashlib, json
from litellm.caching import RedisCache, InMemoryCache, DualCache
import datetime as datetime_og
import logging, asyncio
import inspect, concurrent
from openai import AsyncOpenAI
from collections import defaultdict
from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
from litellm.router_strategy.lowest_tpm_rpm import LowestTPMLoggingHandler
from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler
from litellm.router_strategy.lowest_cost import LowestCostLoggingHandler
from litellm.router_strategy.lowest_tpm_rpm_v2 import LowestTPMLoggingHandler_v2
from litellm.llms.custom_httpx.azure_dall_e_2 import (
    CustomHTTPTransport,
    AsyncCustomHTTPTransport,
)
from litellm.utils import (
    ModelResponse,
    CustomStreamWrapper,
    get_utc_datetime,
    calculate_max_parallel_requests,
    _is_region_eu,
)
from litellm._logging import verbose_router_logger
from litellm.types.utils import ModelInfo as ModelMapInfo
from litellm.types.router import (
    Deployment,
    ModelInfo,
    LiteLLM_Params,
    RouterErrors,
    updateDeployment,
    updateLiteLLMParams,
    RetryPolicy,
    AllowedFailsPolicy,
    AlertingConfig,
    DeploymentTypedDict,
    ModelGroupInfo,
    AssistantsTypedDict,
)
from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.azure import get_azure_ad_token_from_oidc
from litellm.types.llms.openai import (
    AsyncCursorPage,
    Assistant,
    Thread,
    Attachment,
    OpenAIMessage,
    Run,
    AssistantToolParam,
)
from litellm.scheduler import Scheduler, FlowItem
from litellm.router_utils.handle_error import send_llm_exception_alert


class Router:
    model_names: List = []
    cache_responses: Optional[bool] = False
    default_cache_time_seconds: int = 1 * 60 * 60  # 1 hour
    tenacity = None
    leastbusy_logger: Optional[LeastBusyLoggingHandler] = None
    lowesttpm_logger: Optional[LowestTPMLoggingHandler] = None

    def __init__(
        self,
        model_list: Optional[List[Union[DeploymentTypedDict, Dict]]] = None,
        ## ASSISTANTS API ##
        assistants_config: Optional[AssistantsTypedDict] = None,
        ## CACHING ##
        redis_url: Optional[str] = None,
        redis_host: Optional[str] = None,
        redis_port: Optional[int] = None,
        redis_password: Optional[str] = None,
        cache_responses: Optional[bool] = False,
        cache_kwargs: dict = {},  # additional kwargs to pass to RedisCache (see caching.py)
        caching_groups: Optional[
            List[tuple]
        ] = None,  # if you want to cache across model groups
        client_ttl: int = 3600,  # ttl for cached clients - will re-initialize after this time in seconds
        ## SCHEDULER ##
        polling_interval: Optional[float] = None,
        ## RELIABILITY ##
        num_retries: Optional[int] = None,
        timeout: Optional[float] = None,
        default_litellm_params: Optional[
            dict
        ] = None,  # default params for Router.chat.completion.create
        default_max_parallel_requests: Optional[int] = None,
        set_verbose: bool = False,
        debug_level: Literal["DEBUG", "INFO"] = "INFO",
        default_fallbacks: Optional[
            List[str]
        ] = None,  # generic fallbacks, works across all deployments
        fallbacks: List = [],
        context_window_fallbacks: List = [],
        model_group_alias: Optional[dict] = {},
        enable_pre_call_checks: bool = False,
        retry_after: int = 0,  # min time to wait before retrying a failed request
        retry_policy: Optional[
            RetryPolicy
        ] = None,  # set custom retries for different exceptions
        model_group_retry_policy: Optional[
            Dict[str, RetryPolicy]
        ] = {},  # set custom retry policies based on model group
        allowed_fails: Optional[
            int
        ] = None,  # number of times a deployment can fail before being added to cooldown
        allowed_fails_policy: Optional[
            AllowedFailsPolicy
        ] = None,  # set custom allowed fails policy
        cooldown_time: Optional[
            float
        ] = None,  # (seconds) time to cooldown a deployment after failure
        routing_strategy: Literal[
            "simple-shuffle",
            "least-busy",
            "usage-based-routing",
            "latency-based-routing",
            "cost-based-routing",
            "usage-based-routing-v2",
        ] = "simple-shuffle",
        routing_strategy_args: dict = {},  # just for latency-based routing
        semaphore: Optional[asyncio.Semaphore] = None,
        alerting_config: Optional[AlertingConfig] = None,
    ) -> None:
        """
        Initialize the Router class with the given parameters for caching, reliability, and routing strategy.

        Args:
            model_list (Optional[list]): List of models to be used. Defaults to None.
            redis_url (Optional[str]): URL of the Redis server. Defaults to None.
            redis_host (Optional[str]): Hostname of the Redis server. Defaults to None.
            redis_port (Optional[int]): Port of the Redis server. Defaults to None.
            redis_password (Optional[str]): Password of the Redis server. Defaults to None.
            cache_responses (Optional[bool]): Flag to enable caching of responses. Defaults to False.
            cache_kwargs (dict): Additional kwargs to pass to RedisCache. Defaults to {}.
            caching_groups (Optional[List[tuple]]): List of model groups for caching across model groups. Defaults to None.
            client_ttl (int): Time-to-live for cached clients in seconds. Defaults to 3600.
            polling_interval (Optional[float]): Frequency of polling the queue. Only used by '.scheduler_acompletion()'. Default is 3ms.
            num_retries (Optional[int]): Number of retries for failed requests. Defaults to 2.
            timeout (Optional[float]): Timeout for requests. Defaults to None.
            default_litellm_params (dict): Default parameters for Router.chat.completion.create. Defaults to {}.
            set_verbose (bool): Flag to set verbose mode. Defaults to False.
            debug_level (Literal["DEBUG", "INFO"]): Debug level for logging. Defaults to "INFO".
            fallbacks (List): List of fallback options. Defaults to [].
            context_window_fallbacks (List): List of context window fallback options. Defaults to [].
            enable_pre_call_checks (bool): Filter out deployments which are outside the context window limits for a given prompt.
            model_group_alias (Optional[dict]): Alias for model groups. Defaults to {}.
            retry_after (int): Minimum time to wait before retrying a failed request. Defaults to 0.
            allowed_fails (Optional[int]): Number of allowed fails before adding to cooldown. Defaults to None.
            cooldown_time (float): Time in seconds to cooldown a deployment after failure. Defaults to 60.
            routing_strategy (Literal["simple-shuffle", "least-busy", "usage-based-routing", "latency-based-routing", "cost-based-routing", "usage-based-routing-v2"]): Routing strategy. Defaults to "simple-shuffle".
            routing_strategy_args (dict): Additional args for latency-based routing. Defaults to {}.
            alerting_config (AlertingConfig): Slack alerting configuration. Defaults to None.

        Returns:
            Router: An instance of the litellm.Router class.

        Example Usage:
        ```python
        from litellm import Router

        model_list = [
            {
                "model_name": "azure-gpt-3.5-turbo",  # model alias
                "litellm_params": {  # params for litellm completion/embedding call
                    "model": "azure/<your-deployment-name-1>",
                    "api_key": "<your-api-key>",
                    "api_version": "<your-api-version>",
                    "api_base": "<your-api-base>",
                },
            },
            {
                "model_name": "azure-gpt-3.5-turbo",  # model alias
                "litellm_params": {  # params for litellm completion/embedding call
                    "model": "azure/<your-deployment-name-2>",
                    "api_key": "<your-api-key>",
                    "api_version": "<your-api-version>",
                    "api_base": "<your-api-base>",
                },
            },
            {
                "model_name": "openai-gpt-3.5-turbo",  # model alias
                "litellm_params": {  # params for litellm completion/embedding call
                    "model": "gpt-3.5-turbo",
                    "api_key": "<your-api-key>",
                },
            },
        ]

        router = Router(model_list=model_list, fallbacks=[{"azure-gpt-3.5-turbo": "openai-gpt-3.5-turbo"}])
        ```
        """
        if semaphore:
            self.semaphore = semaphore
        self.set_verbose = set_verbose
        self.debug_level = debug_level
        self.enable_pre_call_checks = enable_pre_call_checks
        if self.set_verbose is True:
            if debug_level == "INFO":
                verbose_router_logger.setLevel(logging.INFO)
            elif debug_level == "DEBUG":
                verbose_router_logger.setLevel(logging.DEBUG)
        self.assistants_config = assistants_config
        self.deployment_names: List = (
            []
        )  # names of models under litellm_params. ex. azure/chatgpt-v-2
        self.deployment_latency_map = {}
        ### CACHING ###
        cache_type: Literal["local", "redis"] = "local"  # default to an in-memory cache
        redis_cache = None
        cache_config = {}
        self.client_ttl = client_ttl
        if redis_url is not None or (
            redis_host is not None
            and redis_port is not None
            and redis_password is not None
        ):
            cache_type = "redis"
            if redis_url is not None:
                cache_config["url"] = redis_url
            if redis_host is not None:
                cache_config["host"] = redis_host
            if redis_port is not None:
                cache_config["port"] = str(redis_port)  # type: ignore
            if redis_password is not None:
                cache_config["password"] = redis_password
            # Add additional key-value pairs from cache_kwargs
            cache_config.update(cache_kwargs)
            redis_cache = RedisCache(**cache_config)
        if cache_responses:
            if litellm.cache is None:
                # the cache can be initialized on the proxy server. We should not overwrite it
                litellm.cache = litellm.Cache(type=cache_type, **cache_config)  # type: ignore
            self.cache_responses = cache_responses
        self.cache = DualCache(
            redis_cache=redis_cache, in_memory_cache=InMemoryCache()
        )  # use a dual cache (Redis+In-Memory) for tracking cooldowns, usage, etc.
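
        # Illustrative sketch (commented out, not executed here) of wiring the
        # router to Redis for cross-instance cooldown/usage tracking and response
        # caching. Host, port, and password are placeholders, not real values.
        #
        #   router = Router(
        #       model_list=model_list,
        #       redis_host="my-redis.internal",  # placeholder hostname
        #       redis_port=6379,
        #       redis_password="<password>",
        #       cache_responses=True,  # also cache LLM responses via litellm.Cache
        #   )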
        ### SCHEDULER ###
        self.scheduler = Scheduler(
            polling_interval=polling_interval, redis_cache=redis_cache
        )
        self.default_deployment = None  # use this to track the users default deployment, when they want to use model = *
        self.default_max_parallel_requests = default_max_parallel_requests

        if model_list is not None:
            model_list = copy.deepcopy(model_list)
            self.set_model_list(model_list)
            self.healthy_deployments: List = self.model_list  # type: ignore
            for m in model_list:
                self.deployment_latency_map[m["litellm_params"]["model"]] = 0
        else:
            self.model_list: List = (
                []
            )  # initialize an empty list - to allow _add_deployment and delete_deployment to work

        self.allowed_fails = allowed_fails or litellm.allowed_fails
        self.cooldown_time = cooldown_time or 60
        self.failed_calls = (
            InMemoryCache()
        )  # cache to track failed call per deployment, if num failed calls within 1 minute > allowed fails, then add it to cooldown

        if num_retries is not None:
            self.num_retries = num_retries
        elif litellm.num_retries is not None:
            self.num_retries = litellm.num_retries
        else:
            self.num_retries = openai.DEFAULT_MAX_RETRIES

        self.timeout = timeout or litellm.request_timeout
        self.retry_after = retry_after
        self.routing_strategy = routing_strategy

        ## SETTING FALLBACKS ##
        ### validate if it's set + in correct format
        _fallbacks = fallbacks or litellm.fallbacks

        self.validate_fallbacks(fallback_param=_fallbacks)
        ### set fallbacks
        self.fallbacks = _fallbacks

        if default_fallbacks is not None or litellm.default_fallbacks is not None:
            _fallbacks = default_fallbacks or litellm.default_fallbacks
            if self.fallbacks is not None:
                self.fallbacks.append({"*": _fallbacks})
            else:
                self.fallbacks = [{"*": _fallbacks}]

        self.context_window_fallbacks = (
            context_window_fallbacks or litellm.context_window_fallbacks
        )
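
        # Fallback config shape (illustrative): each entry maps one model group to
        # the ordered list of groups to try when it fails, e.g.
        #   fallbacks=[{"azure-gpt-3.5-turbo": ["openai-gpt-3.5-turbo"]}]
        # while default_fallbacks=["openai-gpt-3.5-turbo"] is stored internally as
        #   {"*": ["openai-gpt-3.5-turbo"]}
        # and applies to every model group.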
        self.total_calls: defaultdict = defaultdict(
            int
        )  # dict to store total calls made to each model
        self.fail_calls: defaultdict = defaultdict(
            int
        )  # dict to store fail_calls made to each model
        self.success_calls: defaultdict = defaultdict(
            int
        )  # dict to store success_calls made to each model
        self.previous_models: List = (
            []
        )  # list to store failed calls (passed in as metadata to next call)
        self.model_group_alias: dict = (
            model_group_alias or {}
        )  # dict to store aliases for router, ex. {"gpt-4": "gpt-3.5-turbo"}; all requests with gpt-4 get routed to the gpt-3.5-turbo group

        # make Router.chat.completions.create compatible for openai.chat.completions.create
        default_litellm_params = default_litellm_params or {}
        self.chat = litellm.Chat(params=default_litellm_params, router_obj=self)

        # default litellm args
        self.default_litellm_params = default_litellm_params
        self.default_litellm_params.setdefault("timeout", timeout)
        self.default_litellm_params.setdefault("max_retries", 0)
        self.default_litellm_params.setdefault("metadata", {}).update(
            {"caching_groups": caching_groups}
        )

        self.deployment_stats: dict = {}  # used for debugging load balancing
        """
        deployment_stats = {
            "122999-2828282-277": {
                "model": "gpt-3",
                "api_base": "http://localhost:4000",
                "num_requests": 20,
                "avg_latency": 0.001,
                "num_failures": 0,
                "num_successes": 20
            }
        }
        """
        ### ROUTING SETUP ###
        self.routing_strategy_init(
            routing_strategy=routing_strategy,
            routing_strategy_args=routing_strategy_args,
        )
        ## COOLDOWNS ##
        if isinstance(litellm.failure_callback, list):
            litellm.failure_callback.append(self.deployment_callback_on_failure)
        else:
            litellm.failure_callback = [self.deployment_callback_on_failure]
        print(  # noqa
            f"Initialized router with Routing strategy: {self.routing_strategy}\n\nRouting fallbacks: {self.fallbacks}\n\nRouting context window fallbacks: {self.context_window_fallbacks}\n\nRouter Redis Caching={self.cache.redis_cache}"
        )  # noqa
        self.routing_strategy_args = routing_strategy_args
        self.retry_policy: Optional[RetryPolicy] = retry_policy
        self.model_group_retry_policy: Optional[Dict[str, RetryPolicy]] = (
            model_group_retry_policy
        )
        self.allowed_fails_policy: Optional[AllowedFailsPolicy] = allowed_fails_policy
        self.alerting_config: Optional[AlertingConfig] = alerting_config
        if self.alerting_config is not None:
            self._initialize_alerting()

    def validate_fallbacks(self, fallback_param: Optional[List]):
        if fallback_param is None:
            return

        for fallback_dict in fallback_param:
            if not isinstance(fallback_dict, dict):
                raise ValueError(f"Item '{fallback_dict}' is not a dictionary.")
            if len(fallback_dict) != 1:
                raise ValueError(
                    f"Dictionary '{fallback_dict}' must have exactly one key, but has {len(fallback_dict)} keys."
                )
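
    # Illustrative sketch (commented out): selecting a routing strategy at init
    # time. `routing_strategy_init` below wires the matching handler into the
    # litellm callbacks so usage/latency data gets recorded per deployment.
    #
    #   router = Router(
    #       model_list=model_list,
    #       routing_strategy="usage-based-routing-v2",
    #   )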
    def routing_strategy_init(self, routing_strategy: str, routing_strategy_args: dict):
        if routing_strategy == "least-busy":
            self.leastbusy_logger = LeastBusyLoggingHandler(
                router_cache=self.cache, model_list=self.model_list
            )
            ## add callback
            if isinstance(litellm.input_callback, list):
                litellm.input_callback.append(self.leastbusy_logger)  # type: ignore
            else:
                litellm.input_callback = [self.leastbusy_logger]  # type: ignore
            if isinstance(litellm.callbacks, list):
                litellm.callbacks.append(self.leastbusy_logger)  # type: ignore
        elif routing_strategy == "usage-based-routing":
            self.lowesttpm_logger = LowestTPMLoggingHandler(
                router_cache=self.cache,
                model_list=self.model_list,
                routing_args=routing_strategy_args,
            )
            if isinstance(litellm.callbacks, list):
                litellm.callbacks.append(self.lowesttpm_logger)  # type: ignore
        elif routing_strategy == "usage-based-routing-v2":
            self.lowesttpm_logger_v2 = LowestTPMLoggingHandler_v2(
                router_cache=self.cache,
                model_list=self.model_list,
                routing_args=routing_strategy_args,
            )
            if isinstance(litellm.callbacks, list):
                litellm.callbacks.append(self.lowesttpm_logger_v2)  # type: ignore
        elif routing_strategy == "latency-based-routing":
            self.lowestlatency_logger = LowestLatencyLoggingHandler(
                router_cache=self.cache,
                model_list=self.model_list,
                routing_args=routing_strategy_args,
            )
            if isinstance(litellm.callbacks, list):
                litellm.callbacks.append(self.lowestlatency_logger)  # type: ignore
        elif routing_strategy == "cost-based-routing":
            self.lowestcost_logger = LowestCostLoggingHandler(
                router_cache=self.cache,
                model_list=self.model_list,
                routing_args={},
            )
            if isinstance(litellm.callbacks, list):
                litellm.callbacks.append(self.lowestcost_logger)  # type: ignore

    def print_deployment(self, deployment: dict):
        """
        Returns a copy of the deployment with the api key masked.
        """
        try:
            _deployment_copy = copy.deepcopy(deployment)
            litellm_params: dict = _deployment_copy["litellm_params"]
            if "api_key" in litellm_params:
                litellm_params["api_key"] = litellm_params["api_key"][:2] + "*" * 10
            return _deployment_copy
        except Exception as e:
            verbose_router_logger.debug(
                f"Error occurred while printing deployment - {str(e)}"
            )
            raise e

    ### COMPLETION, EMBEDDING, IMG GENERATION FUNCTIONS

    def completion(
        self, model: str, messages: List[Dict[str, str]], **kwargs
    ) -> Union[ModelResponse, CustomStreamWrapper]:
        """
        Example usage:
        response = router.completion(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hey, how's it going?"}])
        """
        try:
            verbose_router_logger.debug(f"router.completion(model={model},..)")
            kwargs["model"] = model
            kwargs["messages"] = messages
            kwargs["original_function"] = self._completion
            timeout = kwargs.get("request_timeout", self.timeout)
            kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
            kwargs.setdefault("metadata", {}).update({"model_group": model})
            response = self.function_with_fallbacks(**kwargs)
            return response
        except Exception as e:
            raise e
    def _completion(
        self, model: str, messages: List[Dict[str, str]], **kwargs
    ) -> Union[ModelResponse, CustomStreamWrapper]:
        model_name = None
        try:
            # pick the one that is available (lowest TPM/RPM)
            deployment = self.get_available_deployment(
                model=model,
                messages=messages,
                specific_deployment=kwargs.pop("specific_deployment", None),
            )
            kwargs.setdefault("metadata", {}).update(
                {
                    "deployment": deployment["litellm_params"]["model"],
                    "api_base": deployment.get("litellm_params", {}).get("api_base"),
                    "model_info": deployment.get("model_info", {}),
                }
            )
            data = deployment["litellm_params"].copy()
            kwargs["model_info"] = deployment.get("model_info", {})
            model_name = data["model"]
            for k, v in self.default_litellm_params.items():
                if (
                    k not in kwargs
                ):  # prioritize model-specific params > default router params
                    kwargs[k] = v
                elif k == "metadata":
                    kwargs[k].update(v)

            potential_model_client = self._get_client(
                deployment=deployment, kwargs=kwargs
            )
            # check if provided keys == client keys #
            dynamic_api_key = kwargs.get("api_key", None)
            if (
                dynamic_api_key is not None
                and potential_model_client is not None
                and dynamic_api_key != potential_model_client.api_key
            ):
                model_client = None
            else:
                model_client = potential_model_client

            ### DEPLOYMENT-SPECIFIC PRE-CALL CHECKS ### (e.g. update rpm pre-call. Raise error, if deployment over limit)
            self.routing_strategy_pre_call_checks(deployment=deployment)

            response = litellm.completion(
                **{
                    **data,
                    "messages": messages,
                    "caching": self.cache_responses,
                    "client": model_client,
                    **kwargs,
                }
            )
            verbose_router_logger.info(
                f"litellm.completion(model={model_name})\033[32m 200 OK\033[0m"
            )
            return response
        except Exception as e:
            verbose_router_logger.info(
                f"litellm.completion(model={model_name})\033[31m Exception {str(e)}\033[0m"
            )
            raise e

    # fmt: off

    @overload
    async def acompletion(
        self, model: str, messages: List[Dict[str, str]], stream: Literal[True], **kwargs
    ) -> CustomStreamWrapper:
        ...

    @overload
    async def acompletion(
        self, model: str, messages: List[Dict[str, str]], stream: Literal[False] = False, **kwargs
    ) -> ModelResponse:
        ...

    @overload
    async def acompletion(
        self, model: str, messages: List[Dict[str, str]], stream: Union[Literal[True], Literal[False]] = False, **kwargs
    ) -> Union[CustomStreamWrapper, ModelResponse]:
        ...

    # fmt: on

    # The actual implementation of the function
    async def acompletion(
        self, model: str, messages: List[Dict[str, str]], stream: bool = False, **kwargs
    ):
        try:
            kwargs["model"] = model
            kwargs["messages"] = messages
            kwargs["stream"] = stream
            kwargs["original_function"] = self._acompletion
            kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
            timeout = kwargs.get("request_timeout", self.timeout)
            kwargs.setdefault("metadata", {}).update({"model_group": model})

            response = await self.async_function_with_fallbacks(**kwargs)
            return response
        except Exception as e:
            asyncio.create_task(
                send_llm_exception_alert(
                    litellm_router_instance=self,
                    request_kwargs=kwargs,
                    error_traceback_str=traceback.format_exc(),
                    original_exception=e,
                )
            )
            raise e
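    # Illustrative usage sketch (commented out; model group name is a placeholder):
    #
    #   import asyncio
    #
    #   async def main():
    #       response = await router.acompletion(
    #           model="azure-gpt-3.5-turbo",
    #           messages=[{"role": "user", "content": "Hey, how's it going?"}],
    #       )
    #       print(response)
    #
    #   asyncio.run(main())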
    async def _acompletion(
        self, model: str, messages: List[Dict[str, str]], **kwargs
    ) -> Union[ModelResponse, CustomStreamWrapper]:
        """
        - Get an available deployment
        - call it with a semaphore over the call
        - semaphore specific to its rpm
        - in the semaphore, make a check against its local rpm before running
        """
        model_name = None
        try:
            verbose_router_logger.debug(
                f"Inside _acompletion()- model: {model}; kwargs: {kwargs}"
            )
            deployment = await self.async_get_available_deployment(
                model=model,
                messages=messages,
                specific_deployment=kwargs.pop("specific_deployment", None),
                request_kwargs=kwargs,
            )

            # debug how often this deployment picked
            self._track_deployment_metrics(deployment=deployment)

            kwargs.setdefault("metadata", {}).update(
                {
                    "deployment": deployment["litellm_params"]["model"],
                    "model_info": deployment.get("model_info", {}),
                    "api_base": deployment.get("litellm_params", {}).get("api_base"),
                }
            )
            kwargs["model_info"] = deployment.get("model_info", {})
            data = deployment["litellm_params"].copy()
            model_name = data["model"]
            for k, v in self.default_litellm_params.items():
                if (
                    k not in kwargs and v is not None
                ):  # prioritize model-specific params > default router params
                    kwargs[k] = v
                elif k == "metadata":
                    kwargs[k].update(v)

            potential_model_client = self._get_client(
                deployment=deployment, kwargs=kwargs, client_type="async"
            )

            # check if provided keys == client keys #
            dynamic_api_key = kwargs.get("api_key", None)
            if (
                dynamic_api_key is not None
                and potential_model_client is not None
                and dynamic_api_key != potential_model_client.api_key
            ):
                model_client = None
            else:
                model_client = potential_model_client
            self.total_calls[model_name] += 1

            timeout = (
                data.get(
                    "timeout", None
                )  # timeout set on litellm_params for this deployment
                or self.timeout  # timeout set on router
                or kwargs.get(
                    "timeout", None
                )  # this uses default_litellm_params when nothing is set
            )

            _response = litellm.acompletion(
                **{
                    **data,
                    "messages": messages,
                    "caching": self.cache_responses,
                    "client": model_client,
                    "timeout": timeout,
                    **kwargs,
                }
            )

            rpm_semaphore = self._get_client(
                deployment=deployment,
                kwargs=kwargs,
                client_type="max_parallel_requests",
            )

            if rpm_semaphore is not None and isinstance(
                rpm_semaphore, asyncio.Semaphore
            ):
                async with rpm_semaphore:
                    """
                    - Check rpm limits before making the call
                    - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
                    """
                    await self.async_routing_strategy_pre_call_checks(
                        deployment=deployment
                    )
                    response = await _response
            else:
                await self.async_routing_strategy_pre_call_checks(deployment=deployment)
                response = await _response

            self.success_calls[model_name] += 1
            verbose_router_logger.info(
                f"litellm.acompletion(model={model_name})\033[32m 200 OK\033[0m"
            )
            # debug how often this deployment picked
            self._track_deployment_metrics(deployment=deployment, response=response)
            return response
        except Exception as e:
            verbose_router_logger.info(
                f"litellm.acompletion(model={model_name})\033[31m Exception {str(e)}\033[0m"
            )
            if model_name is not None:
                self.fail_calls[model_name] += 1
            raise e
    async def abatch_completion(
        self,
        models: List[str],
        messages: Union[List[Dict[str, str]], List[List[Dict[str, str]]]],
        **kwargs,
    ):
        """
        Async Batch Completion. Used for 2 scenarios:
        1. Batch Process 1 request to N models on litellm.Router. Pass messages as List[Dict[str, str]] to use this
        2. Batch Process N requests to M models on litellm.Router. Pass messages as List[List[Dict[str, str]]] to use this

        Example Request for 1 request to N models:
        ```
            response = await router.abatch_completion(
                models=["gpt-3.5-turbo", "groq-llama"],
                messages=[
                    {"role": "user", "content": "is litellm becoming a better product ?"}
                ],
                max_tokens=15,
            )
        ```

        Example Request for N requests to M models:
        ```
            response = await router.abatch_completion(
                models=["gpt-3.5-turbo", "groq-llama"],
                messages=[
                    [{"role": "user", "content": "is litellm becoming a better product ?"}],
                    [{"role": "user", "content": "who is this"}],
                ],
            )
        ```
        """
        ############## Helpers for async completion ##################

        async def _async_completion_no_exceptions(
            model: str, messages: List[Dict[str, str]], **kwargs
        ):
            """
            Wrapper around self.acompletion that catches exceptions and returns them as a result
            """
            try:
                return await self.acompletion(model=model, messages=messages, **kwargs)
            except Exception as e:
                return e

        async def _async_completion_no_exceptions_return_idx(
            model: str,
            messages: List[Dict[str, str]],
            idx: int,  # index of message this response corresponds to
            **kwargs,
        ):
            """
            Wrapper around self.acompletion that catches exceptions and returns them as a result
            """
            try:
                return (
                    await self.acompletion(model=model, messages=messages, **kwargs),
                    idx,
                )
            except Exception as e:
                return e, idx

        ############## Helpers for async completion ##################

        if isinstance(messages, list) and all(isinstance(m, dict) for m in messages):
            _tasks = []
            for model in models:
                # add each task but if the task fails
                _tasks.append(_async_completion_no_exceptions(model=model, messages=messages, **kwargs))  # type: ignore

            response = await asyncio.gather(*_tasks)
            return response
        elif isinstance(messages, list) and all(isinstance(m, list) for m in messages):
            _tasks = []
            for idx, message in enumerate(messages):
                for model in models:
                    # Request Number X, Model Number Y
                    _tasks.append(
                        _async_completion_no_exceptions_return_idx(
                            model=model, idx=idx, messages=message, **kwargs  # type: ignore
                        )
                    )

            responses = await asyncio.gather(*_tasks)
            final_responses: List[List[Any]] = [[] for _ in range(len(messages))]
            for response in responses:
                if isinstance(response, tuple):
                    final_responses[response[1]].append(response[0])
                else:
                    final_responses[0].append(response)
            return final_responses
    async def abatch_completion_one_model_multiple_requests(
        self, model: str, messages: List[List[Dict[str, str]]], **kwargs
    ):
        """
        Async Batch Completion - Batch Process multiple Messages to one model_group on litellm.Router

        Use this for sending multiple requests to 1 model

        Args:
            model (str): model group
            messages (List[List[Dict[str, str]]]): list of messages. Each element in the list is one request.
            **kwargs: additional kwargs

        Usage:
            response = await self.abatch_completion_one_model_multiple_requests(
                model="gpt-3.5-turbo",
                messages=[
                    [{"role": "user", "content": "hello"}, {"role": "user", "content": "tell me something funny"}],
                    [{"role": "user", "content": "hello good morning"}],
                ]
            )
        """

        async def _async_completion_no_exceptions(
            model: str, messages: List[Dict[str, str]], **kwargs
        ):
            """
            Wrapper around self.acompletion that catches exceptions and returns them as a result
            """
            try:
                return await self.acompletion(model=model, messages=messages, **kwargs)
            except Exception as e:
                return e

        _tasks = []
        for message_request in messages:
            # add each task but if the task fails
            _tasks.append(
                _async_completion_no_exceptions(
                    model=model, messages=message_request, **kwargs
                )
            )

        response = await asyncio.gather(*_tasks)
        return response

    # fmt: off

    @overload
    async def abatch_completion_fastest_response(
        self, model: str, messages: List[Dict[str, str]], stream: Literal[True], **kwargs
    ) -> CustomStreamWrapper:
        ...

    @overload
    async def abatch_completion_fastest_response(
        self, model: str, messages: List[Dict[str, str]], stream: Literal[False] = False, **kwargs
    ) -> ModelResponse:
        ...

    # fmt: on

    async def abatch_completion_fastest_response(
        self,
        model: str,
        messages: List[Dict[str, str]],
        stream: bool = False,
        **kwargs,
    ):
        """
        model - List of comma-separated model names. E.g. model="gpt-4, gpt-3.5-turbo"

        Returns fastest response from list of model names. OpenAI-compatible endpoint.
        """
        models = [m.strip() for m in model.split(",")]

        async def _async_completion_no_exceptions(
            model: str, messages: List[Dict[str, str]], stream: bool, **kwargs: Any
        ) -> Union[ModelResponse, CustomStreamWrapper, Exception]:
            """
            Wrapper around self.acompletion that catches exceptions and returns them as a result
            """
            try:
                return await self.acompletion(model=model, messages=messages, stream=stream, **kwargs)  # type: ignore
            except asyncio.CancelledError:
                verbose_router_logger.debug(
                    "Received 'task.cancel'. Cancelling call w/ model={}.".format(model)
                )
                raise
            except Exception as e:
                return e

        pending_tasks = []  # type: ignore

        async def check_response(task: asyncio.Task):
            nonlocal pending_tasks
            try:
                result = await task
                if isinstance(result, (ModelResponse, CustomStreamWrapper)):
                    verbose_router_logger.debug(
                        "Received successful response. Cancelling other LLM API calls."
                    )
                    # If a desired response is received, cancel all other pending tasks
                    for t in pending_tasks:
                        t.cancel()
                    return result
            except Exception:
                # Ignore exceptions, let the loop handle them
                pass
            finally:
                # Remove the task from pending tasks if it finishes
                try:
                    pending_tasks.remove(task)
                except (KeyError, ValueError):
                    # pending_tasks is a list before the first asyncio.wait and a set after it
                    pass

        for model in models:
            task = asyncio.create_task(
                _async_completion_no_exceptions(
                    model=model, messages=messages, stream=stream, **kwargs
                )
            )
            pending_tasks.append(task)

        # Await the first task to complete successfully
        while pending_tasks:
            done, pending_tasks = await asyncio.wait(  # type: ignore
                pending_tasks, return_when=asyncio.FIRST_COMPLETED
            )
            for completed_task in done:
                result = await check_response(completed_task)
                if result is not None:
                    # Return the first successful result
                    result._hidden_params["fastest_response_batch_completion"] = True
                    return result

        # If we exit the loop without returning, all tasks failed
        raise Exception("All tasks failed")
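    # Illustrative sketch (commented out): racing two model groups and keeping
    # the first successful response. Model names are placeholders.
    #
    #   response = await router.abatch_completion_fastest_response(
    #       model="gpt-4, gpt-3.5-turbo",  # comma-separated model groups
    #       messages=[{"role": "user", "content": "ping"}],
    #   )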
    ### SCHEDULER ###

    # fmt: off

    @overload
    async def schedule_acompletion(
        self, model: str, messages: List[Dict[str, str]], priority: int, stream: Literal[False] = False, **kwargs
    ) -> ModelResponse:
        ...

    @overload
    async def schedule_acompletion(
        self, model: str, messages: List[Dict[str, str]], priority: int, stream: Literal[True], **kwargs
    ) -> CustomStreamWrapper:
        ...

    # fmt: on

    async def schedule_acompletion(
        self,
        model: str,
        messages: List[Dict[str, str]],
        priority: int,
        stream=False,
        **kwargs,
    ):
        ### FLOW ITEM ###
        _request_id = str(uuid.uuid4())
        item = FlowItem(
            priority=priority,  # 👈 SET PRIORITY FOR REQUEST
            request_id=_request_id,  # 👈 SET REQUEST ID
            model_name=model,  # 👈 SAME as 'Router' model group
        )
        ### [fin] ###

        ## ADDS REQUEST TO QUEUE ##
        await self.scheduler.add_request(request=item)

        ## POLL QUEUE
        end_time = time.time() + self.timeout
        curr_time = time.time()
        poll_interval = self.scheduler.polling_interval  # poll every 3ms
        make_request = False

        while curr_time < end_time:
            _healthy_deployments = await self._async_get_healthy_deployments(
                model=model
            )
            make_request = await self.scheduler.poll(  ## POLL QUEUE ## - returns 'True' if there's healthy deployments OR if request is at top of queue
                id=item.request_id,
                model_name=item.model_name,
                health_deployments=_healthy_deployments,
            )
            if make_request:  ## IF TRUE -> MAKE REQUEST
                break
            else:  ## ELSE -> loop till default_timeout
                await asyncio.sleep(poll_interval)
                curr_time = time.time()

        if make_request:
            try:
                _response = await self.acompletion(
                    model=model, messages=messages, stream=stream, **kwargs
                )
                return _response
            except Exception as e:
                setattr(e, "priority", priority)
                raise e
        else:
            raise litellm.Timeout(
                message="Request timed out while polling queue",
                model=model,
                llm_provider="openai",
            )
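    # Illustrative sketch (commented out): prioritized request scheduling through
    # the queue above. Assumption: lower `priority` values are served first; the
    # model group name is a placeholder.
    #
    #   response = await router.schedule_acompletion(
    #       model="gpt-3.5-turbo",
    #       messages=[{"role": "user", "content": "Hey!"}],
    #       priority=0,  # assumed highest priority
    #   )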
    def image_generation(self, prompt: str, model: str, **kwargs):
        try:
            kwargs["model"] = model
            kwargs["prompt"] = prompt
            kwargs["original_function"] = self._image_generation
            kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
            timeout = kwargs.get("request_timeout", self.timeout)
            kwargs.setdefault("metadata", {}).update({"model_group": model})
            response = self.function_with_fallbacks(**kwargs)
            return response
        except Exception as e:
            raise e

    def _image_generation(self, prompt: str, model: str, **kwargs):
        model_name = None
        try:
            verbose_router_logger.debug(
                f"Inside _image_generation()- model: {model}; kwargs: {kwargs}"
            )
            deployment = self.get_available_deployment(
                model=model,
                messages=[{"role": "user", "content": "prompt"}],
                specific_deployment=kwargs.pop("specific_deployment", None),
            )
            kwargs.setdefault("metadata", {}).update(
                {
                    "deployment": deployment["litellm_params"]["model"],
                    "model_info": deployment.get("model_info", {}),
                }
            )
            kwargs["model_info"] = deployment.get("model_info", {})
            data = deployment["litellm_params"].copy()
            model_name = data["model"]
            for k, v in self.default_litellm_params.items():
                if (
                    k not in kwargs
                ):  # prioritize model-specific params > default router params
                    kwargs[k] = v
                elif k == "metadata":
                    kwargs[k].update(v)

            potential_model_client = self._get_client(
                deployment=deployment, kwargs=kwargs, client_type="async"
            )
            # check if provided keys == client keys #
            dynamic_api_key = kwargs.get("api_key", None)
            if (
                dynamic_api_key is not None
                and potential_model_client is not None
                and dynamic_api_key != potential_model_client.api_key
            ):
                model_client = None
            else:
                model_client = potential_model_client

            self.total_calls[model_name] += 1

            ### DEPLOYMENT-SPECIFIC PRE-CALL CHECKS ### (e.g. update rpm pre-call. Raise error, if deployment over limit)
            self.routing_strategy_pre_call_checks(deployment=deployment)

            response = litellm.image_generation(
                **{
                    **data,
                    "prompt": prompt,
                    "caching": self.cache_responses,
                    "client": model_client,
                    **kwargs,
                }
            )
            self.success_calls[model_name] += 1
            verbose_router_logger.info(
                f"litellm.image_generation(model={model_name})\033[32m 200 OK\033[0m"
            )
            return response
        except Exception as e:
            verbose_router_logger.info(
                f"litellm.image_generation(model={model_name})\033[31m Exception {str(e)}\033[0m"
            )
            if model_name is not None:
                self.fail_calls[model_name] += 1
            raise e

    async def aimage_generation(self, prompt: str, model: str, **kwargs):
        try:
            kwargs["model"] = model
            kwargs["prompt"] = prompt
            kwargs["original_function"] = self._aimage_generation
            kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
            timeout = kwargs.get("request_timeout", self.timeout)
            kwargs.setdefault("metadata", {}).update({"model_group": model})
            response = await self.async_function_with_fallbacks(**kwargs)
            return response
        except Exception as e:
            asyncio.create_task(
                send_llm_exception_alert(
                    litellm_router_instance=self,
                    request_kwargs=kwargs,
                    error_traceback_str=traceback.format_exc(),
                    original_exception=e,
                )
            )
            raise e
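    # Illustrative sketch (commented out): async image generation through the
    # router; "dall-e-3" is a placeholder model group name.
    #
    #   image_response = await router.aimage_generation(
    #       model="dall-e-3",
    #       prompt="A cute baby sea otter",
    #   )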
    async def _aimage_generation(self, prompt: str, model: str, **kwargs):
        model_name = None
        try:
            verbose_router_logger.debug(
                f"Inside _aimage_generation()- model: {model}; kwargs: {kwargs}"
            )
            deployment = await self.async_get_available_deployment(
                model=model,
                messages=[{"role": "user", "content": "prompt"}],
                specific_deployment=kwargs.pop("specific_deployment", None),
            )
            kwargs.setdefault("metadata", {}).update(
                {
                    "deployment": deployment["litellm_params"]["model"],
                    "model_info": deployment.get("model_info", {}),
                }
            )
            kwargs["model_info"] = deployment.get("model_info", {})
            data = deployment["litellm_params"].copy()
            model_name = data["model"]
            for k, v in self.default_litellm_params.items():
                if (
                    k not in kwargs
                ):  # prioritize model-specific params > default router params
                    kwargs[k] = v
                elif k == "metadata":
                    kwargs[k].update(v)

            potential_model_client = self._get_client(
                deployment=deployment, kwargs=kwargs, client_type="async"
            )
            # check if provided keys == client keys #
            dynamic_api_key = kwargs.get("api_key", None)
            if (
                dynamic_api_key is not None
                and potential_model_client is not None
                and dynamic_api_key != potential_model_client.api_key
            ):
                model_client = None
            else:
                model_client = potential_model_client

            self.total_calls[model_name] += 1
            response = litellm.aimage_generation(
                **{
                    **data,
                    "prompt": prompt,
                    "caching": self.cache_responses,
                    "client": model_client,
                    **kwargs,
                }
            )

            ### CONCURRENCY-SAFE RPM CHECKS ###
            rpm_semaphore = self._get_client(
                deployment=deployment,
                kwargs=kwargs,
                client_type="max_parallel_requests",
            )

            if rpm_semaphore is not None and isinstance(
                rpm_semaphore, asyncio.Semaphore
            ):
                async with rpm_semaphore:
                    """
                    - Check rpm limits before making the call
                    - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
                    """
                    await self.async_routing_strategy_pre_call_checks(
                        deployment=deployment
                    )
                    response = await response
            else:
                await self.async_routing_strategy_pre_call_checks(deployment=deployment)
                response = await response

            self.success_calls[model_name] += 1
            verbose_router_logger.info(
                f"litellm.aimage_generation(model={model_name})\033[32m 200 OK\033[0m"
            )
            return response
        except Exception as e:
            verbose_router_logger.info(
                f"litellm.aimage_generation(model={model_name})\033[31m Exception {str(e)}\033[0m"
            )
            if model_name is not None:
                self.fail_calls[model_name] += 1
            raise e

    async def atranscription(self, file: BinaryIO, model: str, **kwargs):
        """
        Example Usage:

        ```
        from litellm import Router
        client = Router(model_list = [
            {
                "model_name": "whisper",
                "litellm_params": {
                    "model": "whisper-1",
                },
            },
        ])

        audio_file = open("speech.mp3", "rb")
        transcript = await client.atranscription(
            model="whisper",
            file=audio_file
        )
        ```
        """
        try:
            kwargs["model"] = model
            kwargs["file"] = file
            kwargs["original_function"] = self._atranscription
            kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
            timeout = kwargs.get("request_timeout", self.timeout)
            kwargs.setdefault("metadata", {}).update({"model_group": model})
            response = await self.async_function_with_fallbacks(**kwargs)
            return response
        except Exception as e:
            asyncio.create_task(
                send_llm_exception_alert(
                    litellm_router_instance=self,
                    request_kwargs=kwargs,
                    error_traceback_str=traceback.format_exc(),
                    original_exception=e,
                )
            )
            raise e
    async def _atranscription(self, file: BinaryIO, model: str, **kwargs):
        model_name = None
        try:
            verbose_router_logger.debug(
                f"Inside _atranscription()- model: {model}; kwargs: {kwargs}"
            )
            deployment = await self.async_get_available_deployment(
                model=model,
                messages=[{"role": "user", "content": "prompt"}],
                specific_deployment=kwargs.pop("specific_deployment", None),
            )
            kwargs.setdefault("metadata", {}).update(
                {
                    "deployment": deployment["litellm_params"]["model"],
                    "model_info": deployment.get("model_info", {}),
                }
            )
            kwargs["model_info"] = deployment.get("model_info", {})
            data = deployment["litellm_params"].copy()
            model_name = data["model"]
            for k, v in self.default_litellm_params.items():
                if (
                    k not in kwargs
                ):  # prioritize model-specific params > default router params
                    kwargs[k] = v
                elif k == "metadata":
                    kwargs[k].update(v)

            potential_model_client = self._get_client(
                deployment=deployment, kwargs=kwargs, client_type="async"
            )
            # check if provided keys == client keys #
            dynamic_api_key = kwargs.get("api_key", None)
            if (
                dynamic_api_key is not None
                and potential_model_client is not None
                and dynamic_api_key != potential_model_client.api_key
            ):
                model_client = None
            else:
                model_client = potential_model_client

            self.total_calls[model_name] += 1
            response = litellm.atranscription(
                **{
                    **data,
                    "file": file,
                    "caching": self.cache_responses,
                    "client": model_client,
                    **kwargs,
                }
            )

            ### CONCURRENCY-SAFE RPM CHECKS ###
            rpm_semaphore = self._get_client(
                deployment=deployment,
                kwargs=kwargs,
                client_type="max_parallel_requests",
            )

            if rpm_semaphore is not None and isinstance(
                rpm_semaphore, asyncio.Semaphore
            ):
                async with rpm_semaphore:
                    """
                    - Check rpm limits before making the call
                    - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
                    """
                    await self.async_routing_strategy_pre_call_checks(
                        deployment=deployment
                    )
                    response = await response
            else:
                await self.async_routing_strategy_pre_call_checks(deployment=deployment)
                response = await response

            self.success_calls[model_name] += 1
            verbose_router_logger.info(
                f"litellm.atranscription(model={model_name})\033[32m 200 OK\033[0m"
            )
            return response
        except Exception as e:
            verbose_router_logger.info(
                f"litellm.atranscription(model={model_name})\033[31m Exception {str(e)}\033[0m"
            )
            if model_name is not None:
                self.fail_calls[model_name] += 1
            raise e

    async def aspeech(self, model: str, input: str, voice: str, **kwargs):
        """
        Example Usage:

        ```
        from litellm import Router
        client = Router(model_list = [
            {
                "model_name": "tts",
                "litellm_params": {
                    "model": "tts-1",
                },
            },
        ])

        async with client.aspeech(
            model="tts",
            voice="alloy",
            input="the quick brown fox jumped over the lazy dogs",
            api_base=None,
            api_key=None,
            organization=None,
            project=None,
            max_retries=1,
            timeout=600,
            client=None,
            optional_params={},
        ) as response:
            response.stream_to_file(speech_file_path)
        ```
        """
        try:
            kwargs["input"] = input
            kwargs["voice"] = voice

            deployment = await self.async_get_available_deployment(
                model=model,
                messages=[{"role": "user", "content": "prompt"}],
                specific_deployment=kwargs.pop("specific_deployment", None),
            )
            kwargs.setdefault("metadata", {}).update(
                {
                    "deployment": deployment["litellm_params"]["model"],
                    "model_info": deployment.get("model_info", {}),
                }
            )
            kwargs["model_info"] = deployment.get("model_info", {})
            data = deployment["litellm_params"].copy()
            model_name = data["model"]
            for k, v in self.default_litellm_params.items():
                if (
                    k not in kwargs
                ):  # prioritize model-specific params > default router params
                    kwargs[k] = v
                elif k == "metadata":
                    kwargs[k].update(v)

            potential_model_client = self._get_client(
                deployment=deployment, kwargs=kwargs, client_type="async"
            )
            # check if provided keys == client keys #
            dynamic_api_key = kwargs.get("api_key", None)
            if (
                dynamic_api_key is not None
                and potential_model_client is not None
                and dynamic_api_key != potential_model_client.api_key
            ):
                model_client = None
            else:
                model_client = potential_model_client

            response = await litellm.aspeech(**data, **kwargs)
            return response
        except Exception as e:
            asyncio.create_task(
                send_llm_exception_alert(
                    litellm_router_instance=self,
                    request_kwargs=kwargs,
                    error_traceback_str=traceback.format_exc(),
                    original_exception=e,
                )
            )
            raise e

    async def amoderation(self, model: str, input: str, **kwargs):
        try:
            kwargs["model"] = model
            kwargs["input"] = input
            kwargs["original_function"] = self._amoderation
            kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
            timeout = kwargs.get("request_timeout", self.timeout)
            kwargs.setdefault("metadata", {}).update({"model_group": model})
            response = await self.async_function_with_fallbacks(**kwargs)
            return response
        except Exception as e:
            asyncio.create_task(
                send_llm_exception_alert(
                    litellm_router_instance=self,
                    request_kwargs=kwargs,
                    error_traceback_str=traceback.format_exc(),
                    original_exception=e,
                )
            )
            raise e
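    # Illustrative sketch (commented out): moderation through the router; the
    # model group name below is a placeholder and would need a matching entry
    # in model_list.
    #
    #   moderation_response = await router.amoderation(
    #       model="text-moderation-stable",
    #       input="some text to check",
    #   )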
    async def _amoderation(self, model: str, input: str, **kwargs):
        model_name = None
        try:
            verbose_router_logger.debug(
                f"Inside _amoderation()- model: {model}; kwargs: {kwargs}"
            )
            deployment = await self.async_get_available_deployment(
                model=model,
                input=input,
                specific_deployment=kwargs.pop("specific_deployment", None),
            )
            kwargs.setdefault("metadata", {}).update(
                {
                    "deployment": deployment["litellm_params"]["model"],
                    "model_info": deployment.get("model_info", {}),
                }
            )
            kwargs["model_info"] = deployment.get("model_info", {})
            data = deployment["litellm_params"].copy()
            model_name = data["model"]
            for k, v in self.default_litellm_params.items():
                if (
                    k not in kwargs and v is not None
                ):  # prioritize model-specific params > default router params
                    kwargs[k] = v
                elif k == "metadata":
                    kwargs[k].update(v)

            potential_model_client = self._get_client(
                deployment=deployment, kwargs=kwargs, client_type="async"
            )
            # check if provided keys == client keys #
            dynamic_api_key = kwargs.get("api_key", None)
            if (
                dynamic_api_key is not None
                and potential_model_client is not None
                and dynamic_api_key != potential_model_client.api_key
            ):
                model_client = None
            else:
                model_client = potential_model_client
            self.total_calls[model_name] += 1

            timeout = (
                data.get(
                    "timeout", None
                )  # timeout set on litellm_params for this deployment
                or self.timeout  # timeout set on router
                or kwargs.get(
                    "timeout", None
                )  # this uses default_litellm_params when nothing is set
            )

            response = await litellm.amoderation(
                **{
                    **data,
                    "input": input,
                    "caching": self.cache_responses,
                    "client": model_client,
                    "timeout": timeout,
                    **kwargs,
                }
            )

            self.success_calls[model_name] += 1
            verbose_router_logger.info(
                f"litellm.amoderation(model={model_name})\033[32m 200 OK\033[0m"
            )
            return response
        except Exception as e:
            verbose_router_logger.info(
                f"litellm.amoderation(model={model_name})\033[31m Exception {str(e)}\033[0m"
            )
            if model_name is not None:
                self.fail_calls[model_name] += 1
            raise e

    def text_completion(
        self,
        model: str,
        prompt: str,
        is_retry: Optional[bool] = False,
        is_fallback: Optional[bool] = False,
        is_async: Optional[bool] = False,
        **kwargs,
    ):
        messages = [{"role": "user", "content": prompt}]
        try:
            kwargs["model"] = model
            kwargs["prompt"] = prompt
            kwargs["original_function"] = self._acompletion
            kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
            timeout = kwargs.get("request_timeout", self.timeout)
            kwargs.setdefault("metadata", {}).update({"model_group": model})

            # pick the one that is available (lowest TPM/RPM)
            deployment = self.get_available_deployment(
                model=model,
                messages=messages,
                specific_deployment=kwargs.pop("specific_deployment", None),
            )

            data = deployment["litellm_params"].copy()
            for k, v in self.default_litellm_params.items():
                if (
                    k not in kwargs
                ):  # prioritize model-specific params > default router params
                    kwargs[k] = v
                elif k == "metadata":
                    kwargs[k].update(v)

            # call via litellm.text_completion()
            return litellm.text_completion(**{**data, "prompt": prompt, "caching": self.cache_responses, **kwargs})  # type: ignore
        except Exception as e:
            if self.num_retries > 0:
                kwargs["model"] = model
                kwargs["messages"] = messages
                kwargs["original_function"] = self.completion
                return self.function_with_retries(**kwargs)
            else:
                raise e

    async def atext_completion(
        self,
        model: str,
        prompt: str,
        is_retry: Optional[bool] = False,
        is_fallback: Optional[bool] = False,
        is_async: Optional[bool] = False,
        **kwargs,
    ):
        try:
            kwargs["model"] = model
            kwargs["prompt"] = prompt
            kwargs["original_function"] = self._atext_completion
            kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
            timeout = kwargs.get("request_timeout", self.timeout)
            kwargs.setdefault("metadata", {}).update({"model_group": model})
            response = await self.async_function_with_fallbacks(**kwargs)
            return response
        except Exception as e:
            asyncio.create_task(
                send_llm_exception_alert(
                    litellm_router_instance=self,
                    request_kwargs=kwargs,
                    error_traceback_str=traceback.format_exc(),
                    original_exception=e,
                )
            )
            raise e
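    # Illustrative sketch (commented out): async text completion through the
    # router, using a placeholder model group name.
    #
    #   response = await router.atext_completion(
    #       model="gpt-3.5-turbo-instruct",
    #       prompt="Say this is a test",
    #   )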
    async def _atext_completion(self, model: str, prompt: str, **kwargs):
        try:
            verbose_router_logger.debug(
                f"Inside _atext_completion()- model: {model}; kwargs: {kwargs}"
            )
            deployment = await self.async_get_available_deployment(
                model=model,
                messages=[{"role": "user", "content": prompt}],
                specific_deployment=kwargs.pop("specific_deployment", None),
            )
            kwargs.setdefault("metadata", {}).update(
                {
                    "deployment": deployment["litellm_params"]["model"],
                    "model_info": deployment.get("model_info", {}),
                    "api_base": deployment.get("litellm_params", {}).get("api_base"),
                }
            )
            kwargs["model_info"] = deployment.get("model_info", {})
            data = deployment["litellm_params"].copy()
            model_name = data["model"]
            for k, v in self.default_litellm_params.items():
                if (
                    k not in kwargs
                ):  # prioritize model-specific params > default router params
                    kwargs[k] = v
                elif k == "metadata":
                    kwargs[k].update(v)

            potential_model_client = self._get_client(
                deployment=deployment, kwargs=kwargs, client_type="async"
            )
            # check if provided keys == client keys #
            dynamic_api_key = kwargs.get("api_key", None)
            if (
                dynamic_api_key is not None
                and potential_model_client is not None
                and dynamic_api_key != potential_model_client.api_key
            ):
                model_client = None
            else:
                model_client = potential_model_client
            self.total_calls[model_name] += 1

            response = litellm.atext_completion(
                **{
                    **data,
                    "prompt": prompt,
                    "caching": self.cache_responses,
                    "client": model_client,
                    "timeout": self.timeout,
                    **kwargs,
                }
            )

            rpm_semaphore = self._get_client(
                deployment=deployment,
                kwargs=kwargs,
                client_type="max_parallel_requests",
            )

            if rpm_semaphore is not None and isinstance(
                rpm_semaphore, asyncio.Semaphore
            ):
                async with rpm_semaphore:
                    """
                    - Check rpm limits before making the call
                    - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
                    """
                    await self.async_routing_strategy_pre_call_checks(
                        deployment=deployment
                    )
                    response = await response
            else:
                await self.async_routing_strategy_pre_call_checks(deployment=deployment)
                response = await response

            self.success_calls[model_name] += 1
            verbose_router_logger.info(
                f"litellm.atext_completion(model={model_name})\033[32m 200 OK\033[0m"
            )
            return response
        except Exception as e:
            verbose_router_logger.info(
                f"litellm.atext_completion(model={model})\033[31m Exception {str(e)}\033[0m"
            )
            if model is not None:
                self.fail_calls[model] += 1
            raise e

    def embedding(
        self,
        model: str,
        input: Union[str, List],
        is_async: Optional[bool] = False,
        **kwargs,
    ) -> Union[List[float], None]:
        try:
            kwargs["model"] = model
            kwargs["input"] = input
            kwargs["original_function"] = self._embedding
            kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
            timeout = kwargs.get("request_timeout", self.timeout)
            kwargs.setdefault("metadata", {}).update({"model_group": model})
            response = self.function_with_fallbacks(**kwargs)
            return response
        except Exception as e:
            raise e
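    # Illustrative sketch (commented out): embeddings through the router; the
    # model group name is a placeholder.
    #
    #   embedding_response = router.embedding(
    #       model="text-embedding-ada-002",
    #       input=["good morning from litellm"],
    #   )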
    def _embedding(self, input: Union[str, List], model: str, **kwargs):
        model_name = None
        try:
            verbose_router_logger.debug(
                f"Inside _embedding()- model: {model}; kwargs: {kwargs}"
            )
            deployment = self.get_available_deployment(
                model=model,
                input=input,
                specific_deployment=kwargs.pop("specific_deployment", None),
            )
            kwargs.setdefault("metadata", {}).update(
                {
                    "deployment": deployment["litellm_params"]["model"],
                    "model_info": deployment.get("model_info", {}),
                }
            )
            kwargs["model_info"] = deployment.get("model_info", {})
            data = deployment["litellm_params"].copy()
            model_name = data["model"]
            for k, v in self.default_litellm_params.items():
                if (
                    k not in kwargs
                ):  # prioritize model-specific params > default router params
                    kwargs[k] = v
                elif k == "metadata":
                    kwargs[k].update(v)

            potential_model_client = self._get_client(
                deployment=deployment, kwargs=kwargs, client_type="sync"
            )
            # check if provided keys == client keys #
            dynamic_api_key = kwargs.get("api_key", None)
            if (
                dynamic_api_key is not None
                and potential_model_client is not None
                and dynamic_api_key != potential_model_client.api_key
            ):
                model_client = None
            else:
                model_client = potential_model_client

            self.total_calls[model_name] += 1

            ### DEPLOYMENT-SPECIFIC PRE-CALL CHECKS ### (e.g. update rpm pre-call. Raise error, if deployment over limit)
            self.routing_strategy_pre_call_checks(deployment=deployment)

            response = litellm.embedding(
                **{
                    **data,
                    "input": input,
                    "caching": self.cache_responses,
                    "client": model_client,
                    **kwargs,
                }
            )
            self.success_calls[model_name] += 1
            verbose_router_logger.info(
                f"litellm.embedding(model={model_name})\033[32m 200 OK\033[0m"
            )
            return response
        except Exception as e:
            verbose_router_logger.info(
                f"litellm.embedding(model={model_name})\033[31m Exception {str(e)}\033[0m"
            )
            if model_name is not None:
                self.fail_calls[model_name] += 1
            raise e

    async def aembedding(
        self,
        model: str,
        input: Union[str, List],
        is_async: Optional[bool] = True,
        **kwargs,
    ) -> Union[List[float], None]:
        try:
            kwargs["model"] = model
            kwargs["input"] = input
            kwargs["original_function"] = self._aembedding
            kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
            timeout = kwargs.get("request_timeout", self.timeout)
            kwargs.setdefault("metadata", {}).update({"model_group": model})
            response = await self.async_function_with_fallbacks(**kwargs)
            return response
        except Exception as e:
            asyncio.create_task(
                send_llm_exception_alert(
                    litellm_router_instance=self,
                    request_kwargs=kwargs,
                    error_traceback_str=traceback.format_exc(),
                    original_exception=e,
                )
            )
            raise e
    async def _aembedding(self, input: Union[str, List], model: str, **kwargs):
        model_name = None
        try:
            verbose_router_logger.debug(
                f"Inside _aembedding()- model: {model}; kwargs: {kwargs}"
            )
            deployment = await self.async_get_available_deployment(
                model=model,
                input=input,
                specific_deployment=kwargs.pop("specific_deployment", None),
            )
            kwargs.setdefault("metadata", {}).update(
                {
                    "deployment": deployment["litellm_params"]["model"],
                    "model_info": deployment.get("model_info", {}),
                    "api_base": deployment.get("litellm_params", {}).get("api_base"),
                }
            )
            kwargs["model_info"] = deployment.get("model_info", {})
            data = deployment["litellm_params"].copy()
            model_name = data["model"]
            for k, v in self.default_litellm_params.items():
                if (
                    k not in kwargs
                ):  # prioritize model-specific params > default router params
                    kwargs[k] = v
                elif k == "metadata":
                    kwargs[k].update(v)

            potential_model_client = self._get_client(
                deployment=deployment, kwargs=kwargs, client_type="async"
            )
            # check if provided keys == client keys #
            dynamic_api_key = kwargs.get("api_key", None)
            if (
                dynamic_api_key is not None
                and potential_model_client is not None
                and dynamic_api_key != potential_model_client.api_key
            ):
                model_client = None
            else:
                model_client = potential_model_client

            self.total_calls[model_name] += 1
            response = litellm.aembedding(
                **{
                    **data,
                    "input": input,
                    "caching": self.cache_responses,
                    "client": model_client,
                    **kwargs,
                }
            )

            ### CONCURRENCY-SAFE RPM CHECKS ###
            rpm_semaphore = self._get_client(
                deployment=deployment,
                kwargs=kwargs,
                client_type="max_parallel_requests",
            )

            if rpm_semaphore is not None and isinstance(
                rpm_semaphore, asyncio.Semaphore
            ):
                async with rpm_semaphore:
                    """
                    - Check rpm limits before making the call
                    - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
                    """
                    await self.async_routing_strategy_pre_call_checks(
                        deployment=deployment
                    )
                    response = await response
            else:
                await self.async_routing_strategy_pre_call_checks(deployment=deployment)
                response = await response

            self.success_calls[model_name] += 1
            verbose_router_logger.info(
                f"litellm.aembedding(model={model_name})\033[32m 200 OK\033[0m"
            )
            return response
        except Exception as e:
            verbose_router_logger.info(
                f"litellm.aembedding(model={model_name})\033[31m Exception {str(e)}\033[0m"
            )
            if model_name is not None:
                self.fail_calls[model_name] += 1
            raise e

    #### ASSISTANTS API ####

    async def aget_assistants(
        self,
        custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
        client: Optional[AsyncOpenAI] = None,
        **kwargs,
    ) -> AsyncCursorPage[Assistant]:
        if custom_llm_provider is None:
            if self.assistants_config is not None:
                custom_llm_provider = self.assistants_config["custom_llm_provider"]
                kwargs.update(self.assistants_config["litellm_params"])
            else:
                raise Exception(
                    "'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.aget_assistants(custom_llm_provider=..)`"
                )

        return await litellm.aget_assistants(
            custom_llm_provider=custom_llm_provider, client=client, **kwargs
        )

    async def acreate_thread(
        self,
        custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
        client: Optional[AsyncOpenAI] = None,
        **kwargs,
    ) -> Thread:
        if custom_llm_provider is None:
            if self.assistants_config is not None:
                custom_llm_provider = self.assistants_config["custom_llm_provider"]
                kwargs.update(self.assistants_config["litellm_params"])
            else:
                raise Exception(
                    "'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.acreate_thread(custom_llm_provider=..)`"
                )
        return await litellm.acreate_thread(
            custom_llm_provider=custom_llm_provider, client=client, **kwargs
        )

    async def aget_thread(
        self,
        thread_id: str,
        custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
        client: Optional[AsyncOpenAI] = None,
        **kwargs,
    ) -> Thread:
        if custom_llm_provider is None:
            if self.assistants_config is not None:
                custom_llm_provider = self.assistants_config["custom_llm_provider"]
                kwargs.update(self.assistants_config["litellm_params"])
            else:
                raise Exception(
                    "'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.aget_thread(custom_llm_provider=..)`"
                )
        return await litellm.aget_thread(
            custom_llm_provider=custom_llm_provider,
            thread_id=thread_id,
            client=client,
            **kwargs,
        )

    async def a_add_message(
        self,
        thread_id: str,
        role: Literal["user", "assistant"],
        content: str,
        attachments: Optional[List[Attachment]] = None,
        metadata: Optional[dict] = None,
        custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
        client: Optional[AsyncOpenAI] = None,
        **kwargs,
    ) -> OpenAIMessage:
        if custom_llm_provider is None:
            if self.assistants_config is not None:
                custom_llm_provider = self.assistants_config["custom_llm_provider"]
                kwargs.update(self.assistants_config["litellm_params"])
            else:
                raise Exception(
                    "'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.a_add_message(custom_llm_provider=..)`"
                )
        return await litellm.a_add_message(
            custom_llm_provider=custom_llm_provider,
            thread_id=thread_id,
            role=role,
            content=content,
            attachments=attachments,
            metadata=metadata,
            client=client,
            **kwargs,
        )
    async def aget_messages(
        self,
        thread_id: str,
        custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
        client: Optional[AsyncOpenAI] = None,
        **kwargs,
    ) -> AsyncCursorPage[OpenAIMessage]:
        if custom_llm_provider is None:
            if self.assistants_config is not None:
                custom_llm_provider = self.assistants_config["custom_llm_provider"]
                kwargs.update(self.assistants_config["litellm_params"])
            else:
                raise Exception(
                    "'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.aget_messages(custom_llm_provider=..)`"
                )
        return await litellm.aget_messages(
            custom_llm_provider=custom_llm_provider,
            thread_id=thread_id,
            client=client,
            **kwargs,
        )

    async def arun_thread(
        self,
        thread_id: str,
        assistant_id: str,
        custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
        additional_instructions: Optional[str] = None,
        instructions: Optional[str] = None,
        metadata: Optional[dict] = None,
        model: Optional[str] = None,
        stream: Optional[bool] = None,
        tools: Optional[Iterable[AssistantToolParam]] = None,
        client: Optional[Any] = None,
        **kwargs,
    ) -> Run:
        if custom_llm_provider is None:
            if self.assistants_config is not None:
                custom_llm_provider = self.assistants_config["custom_llm_provider"]
                kwargs.update(self.assistants_config["litellm_params"])
            else:
                raise Exception(
                    "'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
                )

        return await litellm.arun_thread(
            custom_llm_provider=custom_llm_provider,
            thread_id=thread_id,
            assistant_id=assistant_id,
            additional_instructions=additional_instructions,
            instructions=instructions,
            metadata=metadata,
            model=model,
            stream=stream,
            tools=tools,
            client=client,
            **kwargs,
        )

    #### [END] ASSISTANTS API ####
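    # Illustrative end-to-end sketch (commented out) of the Assistants API
    # helpers above. The api key and assistant id are placeholders; the
    # assistants_config shape follows how it is read in the methods above.
    #
    #   router = Router(
    #       model_list=model_list,
    #       assistants_config={
    #           "custom_llm_provider": "openai",
    #           "litellm_params": {"api_key": "<your-api-key>"},
    #       },
    #   )
    #   thread = await router.acreate_thread()
    #   await router.a_add_message(thread_id=thread.id, role="user", content="hi")
    #   run = await router.arun_thread(thread_id=thread.id, assistant_id="<assistant-id>")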
Fallbacks={fallbacks}" ) response = await self.async_function_with_retries(*args, **kwargs) verbose_router_logger.debug(f"Async Response: {response}") return response except Exception as e: verbose_router_logger.debug(f"Traceback{traceback.format_exc()}") original_exception = e fallback_model_group = None try: verbose_router_logger.debug(f"Trying to fallback b/w models") if ( hasattr(e, "status_code") and e.status_code == 400 # type: ignore and not isinstance(e, litellm.ContextWindowExceededError) ): # don't retry a malformed request raise e if ( isinstance(e, litellm.ContextWindowExceededError) and context_window_fallbacks is not None ): fallback_model_group = None for ( item ) in context_window_fallbacks: # [{"gpt-3.5-turbo": ["gpt-4"]}] if list(item.keys())[0] == model_group: fallback_model_group = item[model_group] break if fallback_model_group is None: raise original_exception for mg in fallback_model_group: """ Iterate through the model groups and try calling that deployment """ try: kwargs["model"] = mg kwargs.setdefault("metadata", {}).update( {"model_group": mg} ) # update model_group used, if fallbacks are done response = await self.async_function_with_retries( *args, **kwargs ) verbose_router_logger.info( "Successful fallback b/w models." ) return response except Exception as e: pass elif fallbacks is not None: verbose_router_logger.debug(f"inside model fallbacks: {fallbacks}") generic_fallback_idx: Optional[int] = None ## check for specific model group-specific fallbacks for idx, item in enumerate(fallbacks): if isinstance(item, dict): if list(item.keys())[0] == model_group: fallback_model_group = item[model_group] break elif list(item.keys())[0] == "*": generic_fallback_idx = idx elif isinstance(item, str): fallback_model_group = [fallbacks.pop(idx)] ## if none, check for generic fallback if ( fallback_model_group is None and generic_fallback_idx is not None ): fallback_model_group = fallbacks[generic_fallback_idx]["*"] if fallback_model_group is None: verbose_router_logger.info( f"No fallback model group found for original model_group={model_group}. Fallbacks={fallbacks}" ) raise original_exception for mg in fallback_model_group: """ Iterate through the model groups and try calling that deployment """ try: ## LOGGING kwargs = self.log_retry(kwargs=kwargs, e=original_exception) verbose_router_logger.info( f"Falling back to model_group = {mg}" ) kwargs["model"] = mg kwargs.setdefault("metadata", {}).update( {"model_group": mg} ) # update model_group used, if fallbacks are done response = await self.async_function_with_fallbacks( *args, **kwargs ) verbose_router_logger.info( "Successful fallback b/w models." 
                            )
                            return response
                        except Exception as e:
                            raise e
            except Exception as e:
                verbose_router_logger.error(f"An exception occurred - {str(e)}")
                verbose_router_logger.debug(traceback.format_exc())
            raise original_exception

    async def async_function_with_retries(self, *args, **kwargs):
        verbose_router_logger.debug(
            f"Inside async function with retries: args - {args}; kwargs - {kwargs}"
        )
        original_function = kwargs.pop("original_function")
        fallbacks = kwargs.pop("fallbacks", self.fallbacks)
        context_window_fallbacks = kwargs.pop(
            "context_window_fallbacks", self.context_window_fallbacks
        )
        num_retries = kwargs.pop("num_retries")
        verbose_router_logger.debug(
            f"async function w/ retries: original_function - {original_function}, num_retries - {num_retries}"
        )
        try:
            # if the call succeeds, no exception is raised and we return immediately
            response = await original_function(*args, **kwargs)
            return response
        except Exception as e:
            current_attempt = None
            original_exception = e

            """
            Retry Logic
            """
            _healthy_deployments = await self._async_get_healthy_deployments(
                model=kwargs.get("model") or "",
            )

            # raises an exception if this error should not be retried
            self.should_retry_this_error(
                error=e,
                healthy_deployments=_healthy_deployments,
                context_window_fallbacks=context_window_fallbacks,
                regular_fallbacks=fallbacks,
            )

            # decides how long to sleep before retry
            _timeout = self._time_to_sleep_before_retry(
                e=original_exception,
                remaining_retries=num_retries,
                num_retries=num_retries,
                healthy_deployments=_healthy_deployments,
            )

            # sleeps for the length of the timeout
            await asyncio.sleep(_timeout)

            if (
                self.retry_policy is not None
                or self.model_group_retry_policy is not None
            ):
                # get num_retries from retry policy
                _retry_policy_retries = self.get_num_retries_from_retry_policy(
                    exception=original_exception, model_group=kwargs.get("model")
                )
                if _retry_policy_retries is not None:
                    num_retries = _retry_policy_retries

            ## LOGGING
            if num_retries > 0:
                kwargs = self.log_retry(kwargs=kwargs, e=original_exception)

            for current_attempt in range(num_retries):
                verbose_router_logger.debug(
                    f"retrying request. Current attempt - {current_attempt}; num retries: {num_retries}"
                )
                try:
                    # if the call succeeds, no exception is raised and we break out of the loop
                    response = await original_function(*args, **kwargs)
                    if inspect.iscoroutine(
                        response
                    ):  # async errors are often returned as coroutines
                        response = await response
                    return response

                except Exception as e:
                    ## LOGGING
                    kwargs = self.log_retry(kwargs=kwargs, e=e)
                    remaining_retries = num_retries - current_attempt
                    _healthy_deployments = await self._async_get_healthy_deployments(
                        model=kwargs.get("model"),
                    )
                    _timeout = self._time_to_sleep_before_retry(
                        e=original_exception,
                        remaining_retries=remaining_retries,
                        num_retries=num_retries,
                        healthy_deployments=_healthy_deployments,
                    )
                    await asyncio.sleep(_timeout)

            if type(original_exception) in litellm.LITELLM_EXCEPTION_TYPES:
                original_exception.max_retries = num_retries
                original_exception.num_retries = current_attempt

            raise original_exception

    def should_retry_this_error(
        self,
        error: Exception,
        healthy_deployments: Optional[List] = None,
        context_window_fallbacks: Optional[List] = None,
        regular_fallbacks: Optional[List] = None,
    ):
        """
        1. raise an exception for ContextWindowExceededError if context_window_fallbacks is not None
        2.
           raise an exception for RateLimitError if
            - there are no healthy deployments left in the same model group, and
            - fallbacks are available to take over
        3. raise an exception for AuthenticationError if there are no healthy deployments left
        """
        _num_healthy_deployments = 0
        if healthy_deployments is not None and isinstance(healthy_deployments, list):
            _num_healthy_deployments = len(healthy_deployments)

        ### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR w/ fallbacks available / Bad Request Error
        if (
            isinstance(error, litellm.ContextWindowExceededError)
            and context_window_fallbacks is not None
        ):
            raise error

        # Rate-limit errors: only retry here if other deployments are available;
        # if none are healthy and fallbacks exist, raise so the fallback logic runs
        if isinstance(error, openai.RateLimitError):
            if (
                _num_healthy_deployments <= 0  # if no healthy deployments
                and regular_fallbacks is not None  # and fallbacks available
                and len(regular_fallbacks) > 0
            ):
                raise error  # then raise the error

        if isinstance(error, openai.AuthenticationError):
            """
            - if other deployments available -> retry
            - else -> raise error
            """
            if _num_healthy_deployments <= 0:  # if no healthy deployments
                raise error  # then raise error

        return True

    def function_with_fallbacks(self, *args, **kwargs):
        """
        Try calling the function_with_retries
        If it fails after num_retries, fall back to another model group
        """
        mock_testing_fallbacks = kwargs.pop("mock_testing_fallbacks", None)
        model_group = kwargs.get("model")
        fallbacks = kwargs.get("fallbacks", self.fallbacks)
        context_window_fallbacks = kwargs.get(
            "context_window_fallbacks", self.context_window_fallbacks
        )
        try:
            if mock_testing_fallbacks is not None and mock_testing_fallbacks is True:
                raise Exception(
                    f"This is a mock exception for model={model_group}, to trigger a fallback. Fallbacks={fallbacks}"
                )
            response = self.function_with_retries(*args, **kwargs)
            return response
        except Exception as e:
            original_exception = e
            verbose_router_logger.debug(f"An exception occurred - {original_exception}")
            try:
                if (
                    hasattr(e, "status_code")
                    and e.status_code == 400  # type: ignore
                    and not isinstance(e, litellm.ContextWindowExceededError)
                ):  # don't retry a malformed request
                    raise e
                verbose_router_logger.debug(
                    f"Trying to fallback b/w models.
Initial model group: {model_group}" ) if ( isinstance(e, litellm.ContextWindowExceededError) and context_window_fallbacks is not None ): fallback_model_group = None for ( item ) in context_window_fallbacks: # [{"gpt-3.5-turbo": ["gpt-4"]}] if list(item.keys())[0] == model_group: fallback_model_group = item[model_group] break if fallback_model_group is None: raise original_exception for mg in fallback_model_group: """ Iterate through the model groups and try calling that deployment """ try: ## LOGGING kwargs = self.log_retry(kwargs=kwargs, e=original_exception) kwargs["model"] = mg kwargs.setdefault("metadata", {}).update( {"model_group": mg} ) # update model_group used, if fallbacks are done response = self.function_with_fallbacks(*args, **kwargs) return response except Exception as e: pass elif fallbacks is not None: verbose_router_logger.debug(f"inside model fallbacks: {fallbacks}") fallback_model_group = None generic_fallback_idx: Optional[int] = None for idx, item in enumerate(fallbacks): if isinstance(item, dict): if list(item.keys())[0] == model_group: fallback_model_group = item[model_group] break elif list(item.keys())[0] == "*": generic_fallback_idx = idx elif isinstance(item, str): fallback_model_group = [fallbacks.pop(idx)] ## if none, check for generic fallback if ( fallback_model_group is None and generic_fallback_idx is not None ): fallback_model_group = fallbacks[generic_fallback_idx]["*"] if fallback_model_group is None: raise original_exception for mg in fallback_model_group: """ Iterate through the model groups and try calling that deployment """ try: ## LOGGING kwargs = self.log_retry(kwargs=kwargs, e=original_exception) kwargs["model"] = mg kwargs.setdefault("metadata", {}).update( {"model_group": mg} ) # update model_group used, if fallbacks are done response = self.function_with_fallbacks(*args, **kwargs) return response except Exception as e: raise e except Exception as e: raise e raise original_exception def _time_to_sleep_before_retry( self, e: Exception, remaining_retries: int, num_retries: int, healthy_deployments: Optional[List] = None, ) -> Union[int, float]: """ Calculate back-off, then retry It should instantly retry only when: 1. there are healthy deployments in the same model group 2. there are fallbacks for the completion call """ if ( healthy_deployments is not None and isinstance(healthy_deployments, list) and len(healthy_deployments) > 0 ): return 0 if hasattr(e, "response") and hasattr(e.response, "headers"): timeout = litellm._calculate_retry_after( remaining_retries=remaining_retries, max_retries=num_retries, response_headers=e.response.headers, min_timeout=self.retry_after, ) else: timeout = litellm._calculate_retry_after( remaining_retries=remaining_retries, max_retries=num_retries, min_timeout=self.retry_after, ) return timeout def function_with_retries(self, *args, **kwargs): """ Try calling the model 3 times. Shuffle-between available deployments. 
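
        Note: the retry count is not fixed at 3 - it comes from the `num_retries`
        kwarg popped below (set via e.g. `Router(model_list=[...], num_retries=2)`,
        an illustrative, assumed configuration), and deployments are shuffled by
        the configured routing strategy.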
""" verbose_router_logger.debug( f"Inside function with retries: args - {args}; kwargs - {kwargs}" ) original_function = kwargs.pop("original_function") num_retries = kwargs.pop("num_retries") fallbacks = kwargs.pop("fallbacks", self.fallbacks) context_window_fallbacks = kwargs.pop( "context_window_fallbacks", self.context_window_fallbacks ) try: # if the function call is successful, no exception will be raised and we'll break out of the loop response = original_function(*args, **kwargs) return response except Exception as e: current_attempt = None original_exception = e ### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR _healthy_deployments = self._get_healthy_deployments( model=kwargs.get("model"), ) # raises an exception if this error should not be retries self.should_retry_this_error( error=e, healthy_deployments=_healthy_deployments, context_window_fallbacks=context_window_fallbacks, regular_fallbacks=fallbacks, ) # decides how long to sleep before retry _timeout = self._time_to_sleep_before_retry( e=original_exception, remaining_retries=num_retries, num_retries=num_retries, healthy_deployments=_healthy_deployments, ) ## LOGGING if num_retries > 0: kwargs = self.log_retry(kwargs=kwargs, e=original_exception) time.sleep(_timeout) for current_attempt in range(num_retries): verbose_router_logger.debug( f"retrying request. Current attempt - {current_attempt}; retries left: {num_retries}" ) try: # if the function call is successful, no exception will be raised and we'll break out of the loop response = original_function(*args, **kwargs) return response except Exception as e: ## LOGGING kwargs = self.log_retry(kwargs=kwargs, e=e) _healthy_deployments = self._get_healthy_deployments( model=kwargs.get("model"), ) remaining_retries = num_retries - current_attempt _timeout = self._time_to_sleep_before_retry( e=e, remaining_retries=remaining_retries, num_retries=num_retries, healthy_deployments=_healthy_deployments, ) time.sleep(_timeout) if type(original_exception) in litellm.LITELLM_EXCEPTION_TYPES: original_exception.max_retries = num_retries original_exception.num_retries = current_attempt raise original_exception ### HELPER FUNCTIONS def deployment_callback_on_failure( self, kwargs, # kwargs to completion completion_response, # response from completion start_time, end_time, # start/end time ): try: exception = kwargs.get("exception", None) exception_type = type(exception) exception_status = getattr(exception, "status_code", "") exception_cause = getattr(exception, "__cause__", "") exception_message = getattr(exception, "message", "") exception_str = ( str(exception_type) + "Status: " + str(exception_status) + "Message: " + str(exception_cause) + str(exception_message) + "Full exception" + str(exception) ) model_name = kwargs.get("model", None) # i.e. gpt35turbo custom_llm_provider = kwargs.get("litellm_params", {}).get( "custom_llm_provider", None ) # i.e. 
azure metadata = kwargs.get("litellm_params", {}).get("metadata", None) _model_info = kwargs.get("litellm_params", {}).get("model_info", {}) exception_response = getattr(exception, "response", {}) exception_headers = getattr(exception_response, "headers", None) _time_to_cooldown = self.cooldown_time if exception_headers is not None: _time_to_cooldown = ( litellm.utils._get_retry_after_from_exception_header( response_headers=exception_headers ) ) if _time_to_cooldown is None or _time_to_cooldown < 0: # if the response headers did not read it -> set to default cooldown time _time_to_cooldown = self.cooldown_time if isinstance(_model_info, dict): deployment_id = _model_info.get("id", None) self._set_cooldown_deployments( exception_status=exception_status, original_exception=exception, deployment=deployment_id, time_to_cooldown=_time_to_cooldown, ) # setting deployment_id in cooldown deployments if custom_llm_provider: model_name = f"{custom_llm_provider}/{model_name}" except Exception as e: raise e def log_retry(self, kwargs: dict, e: Exception) -> dict: """ When a retry or fallback happens, log the details of the just failed model call - similar to Sentry breadcrumbing """ try: # Log failed model as the previous model previous_model = { "exception_type": type(e).__name__, "exception_string": str(e), } for ( k, v, ) in ( kwargs.items() ): # log everything in kwargs except the old previous_models value - prevent nesting if k not in ["metadata", "messages", "original_function"]: previous_model[k] = v elif k == "metadata" and isinstance(v, dict): previous_model["metadata"] = {} # type: ignore for metadata_k, metadata_v in kwargs["metadata"].items(): if metadata_k != "previous_models": previous_model[k][metadata_k] = metadata_v # type: ignore # check current size of self.previous_models, if it's larger than 3, remove the first element if len(self.previous_models) > 3: self.previous_models.pop(0) self.previous_models.append(previous_model) kwargs["metadata"]["previous_models"] = self.previous_models return kwargs except Exception as e: raise e def _update_usage(self, deployment_id: str): """ Update deployment rpm for that minute """ rpm_key = deployment_id request_count = self.cache.get_cache(key=rpm_key, local_only=True) if request_count is None: request_count = 1 self.cache.set_cache( key=rpm_key, value=request_count, local_only=True, ttl=60 ) # only store for 60s else: request_count += 1 self.cache.set_cache( key=rpm_key, value=request_count, local_only=True ) # don't change existing ttl def _is_cooldown_required(self, exception_status: Union[str, int]): """ A function to determine if a cooldown is required based on the exception status. Parameters: exception_status (Union[str, int]): The status of the exception. Returns: bool: True if a cooldown is required, False otherwise. 
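
        Illustrative examples (based on the status-code mapping below):
            _is_cooldown_required(429) -> True   # rate limit errors cool down
            _is_cooldown_required(404) -> True
            _is_cooldown_required(422) -> False  # other 4XX errors do not
            _is_cooldown_required(500) -> True   # all non-4XX errors cool down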
""" try: if isinstance(exception_status, str): exception_status = int(exception_status) if exception_status >= 400 and exception_status < 500: if exception_status == 429: # Cool down 429 Rate Limit Errors return True elif exception_status == 401: # Cool down 401 Auth Errors return True elif exception_status == 408: return True elif exception_status == 404: return True else: # Do NOT cool down all other 4XX Errors return False else: # should cool down for all other errors return True except: # Catch all - if any exceptions default to cooling down return True def _set_cooldown_deployments( self, original_exception: Any, exception_status: Union[str, int], deployment: Optional[str] = None, time_to_cooldown: Optional[float] = None, ): """ Add a model to the list of models being cooled down for that minute, if it exceeds the allowed fails / minute or the exception is not one that should be immediately retried (e.g. 401) """ if deployment is None: return if self._is_cooldown_required(exception_status=exception_status) == False: return _allowed_fails = self.get_allowed_fails_from_policy( exception=original_exception, ) allowed_fails = _allowed_fails or self.allowed_fails dt = get_utc_datetime() current_minute = dt.strftime("%H-%M") # get current fails for deployment # update the number of failed calls # if it's > allowed fails # cooldown deployment current_fails = self.failed_calls.get_cache(key=deployment) or 0 updated_fails = current_fails + 1 verbose_router_logger.debug( f"Attempting to add {deployment} to cooldown list. updated_fails: {updated_fails}; self.allowed_fails: {allowed_fails}" ) cooldown_time = self.cooldown_time or 1 if time_to_cooldown is not None: cooldown_time = time_to_cooldown if isinstance(exception_status, str): try: exception_status = int(exception_status) except Exception as e: verbose_router_logger.debug( "Unable to cast exception status to int {}. 
Defaulting to status=500.".format( exception_status ) ) exception_status = 500 _should_retry = litellm._should_retry(status_code=exception_status) if updated_fails > allowed_fails or _should_retry == False: # get the current cooldown list for that minute cooldown_key = f"{current_minute}:cooldown_models" # group cooldown models by minute to reduce number of redis calls cached_value = self.cache.get_cache(key=cooldown_key) verbose_router_logger.debug(f"adding {deployment} to cooldown models") # update value try: if deployment in cached_value: pass else: cached_value = cached_value + [deployment] # save updated value self.cache.set_cache( value=cached_value, key=cooldown_key, ttl=cooldown_time ) except: cached_value = [deployment] # save updated value self.cache.set_cache( value=cached_value, key=cooldown_key, ttl=cooldown_time ) self.send_deployment_cooldown_alert( deployment_id=deployment, exception_status=exception_status, cooldown_time=cooldown_time, ) else: self.failed_calls.set_cache( key=deployment, value=updated_fails, ttl=cooldown_time ) async def _async_get_cooldown_deployments(self): """ Async implementation of '_get_cooldown_deployments' """ dt = get_utc_datetime() current_minute = dt.strftime("%H-%M") # get the current cooldown list for that minute cooldown_key = f"{current_minute}:cooldown_models" # ---------------------- # Return cooldown models # ---------------------- cooldown_models = await self.cache.async_get_cache(key=cooldown_key) or [] verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}") return cooldown_models def _get_cooldown_deployments(self): """ Get the list of models being cooled down for this minute """ dt = get_utc_datetime() current_minute = dt.strftime("%H-%M") # get the current cooldown list for that minute cooldown_key = f"{current_minute}:cooldown_models" # ---------------------- # Return cooldown models # ---------------------- cooldown_models = self.cache.get_cache(key=cooldown_key) or [] verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}") return cooldown_models def _get_healthy_deployments(self, model: str): _all_deployments: list = [] try: _, _all_deployments = self._common_checks_available_deployment( # type: ignore model=model, ) if type(_all_deployments) == dict: return [] except: pass unhealthy_deployments = self._get_cooldown_deployments() healthy_deployments: list = [] for deployment in _all_deployments: if deployment["model_info"]["id"] in unhealthy_deployments: continue else: healthy_deployments.append(deployment) return healthy_deployments async def _async_get_healthy_deployments(self, model: str): _all_deployments: list = [] try: _, _all_deployments = self._common_checks_available_deployment( # type: ignore model=model, ) if type(_all_deployments) == dict: return [] except: pass unhealthy_deployments = await self._async_get_cooldown_deployments() healthy_deployments: list = [] for deployment in _all_deployments: if deployment["model_info"]["id"] in unhealthy_deployments: continue else: healthy_deployments.append(deployment) return healthy_deployments def routing_strategy_pre_call_checks(self, deployment: dict): """ Mimics 'async_routing_strategy_pre_call_checks' Ensures consistent update rpm implementation for 'usage-based-routing-v2' Returns: - None Raises: - Rate Limit Exception - If the deployment is over it's tpm/rpm limits """ for _callback in litellm.callbacks: if isinstance(_callback, CustomLogger): response = _callback.pre_call_check(deployment) async def 
async_routing_strategy_pre_call_checks(self, deployment: dict): """ For usage-based-routing-v2, enables running rpm checks before the call is made, inside the semaphore. -> makes the calls concurrency-safe, when rpm limits are set for a deployment Returns: - None Raises: - Rate Limit Exception - If the deployment is over it's tpm/rpm limits """ for _callback in litellm.callbacks: if isinstance(_callback, CustomLogger): try: response = await _callback.async_pre_call_check(deployment) except litellm.RateLimitError as e: self._set_cooldown_deployments( exception_status=e.status_code, original_exception=e, deployment=deployment["model_info"]["id"], time_to_cooldown=self.cooldown_time, ) raise e except Exception as e: raise e def set_client(self, model: dict): """ - Initializes Azure/OpenAI clients. Stores them in cache, b/c of this - https://github.com/BerriAI/litellm/issues/1278 - Initializes Semaphore for client w/ rpm. Stores them in cache. b/c of this - https://github.com/BerriAI/litellm/issues/2994 """ client_ttl = self.client_ttl litellm_params = model.get("litellm_params", {}) model_name = litellm_params.get("model") model_id = model["model_info"]["id"] # ### IF RPM SET - initialize a semaphore ### rpm = litellm_params.get("rpm", None) tpm = litellm_params.get("tpm", None) max_parallel_requests = litellm_params.get("max_parallel_requests", None) calculated_max_parallel_requests = calculate_max_parallel_requests( rpm=rpm, max_parallel_requests=max_parallel_requests, tpm=tpm, default_max_parallel_requests=self.default_max_parallel_requests, ) if calculated_max_parallel_requests: semaphore = asyncio.Semaphore(calculated_max_parallel_requests) cache_key = f"{model_id}_max_parallel_requests_client" self.cache.set_cache( key=cache_key, value=semaphore, local_only=True, ) #### for OpenAI / Azure we need to initalize the Client for High Traffic ######## custom_llm_provider = litellm_params.get("custom_llm_provider") custom_llm_provider = custom_llm_provider or model_name.split("/", 1)[0] or "" default_api_base = None default_api_key = None if custom_llm_provider in litellm.openai_compatible_providers: _, custom_llm_provider, api_key, api_base = litellm.get_llm_provider( model=model_name ) default_api_base = api_base default_api_key = api_key if ( model_name in litellm.open_ai_chat_completion_models or custom_llm_provider in litellm.openai_compatible_providers or custom_llm_provider == "azure" or custom_llm_provider == "azure_text" or custom_llm_provider == "custom_openai" or custom_llm_provider == "openai" or custom_llm_provider == "text-completion-openai" or "ft:gpt-3.5-turbo" in model_name or model_name in litellm.open_ai_embedding_models ): is_azure_ai_studio_model: bool = False if custom_llm_provider == "azure": if litellm.utils._is_non_openai_azure_model(model_name): is_azure_ai_studio_model = True custom_llm_provider = "openai" # remove azure prefx from model_name model_name = model_name.replace("azure/", "") # glorified / complicated reading of configs # user can pass vars directly or they can pas os.environ/AZURE_API_KEY, in which case we will read the env # we do this here because we init clients for Azure, OpenAI and we need to set the right key api_key = litellm_params.get("api_key") or default_api_key if ( api_key and isinstance(api_key, str) and api_key.startswith("os.environ/") ): api_key_env_name = api_key.replace("os.environ/", "") api_key = litellm.get_secret(api_key_env_name) litellm_params["api_key"] = api_key api_base = litellm_params.get("api_base") base_url = 
litellm_params.get("base_url") api_base = ( api_base or base_url or default_api_base ) # allow users to pass in `api_base` or `base_url` for azure if api_base and api_base.startswith("os.environ/"): api_base_env_name = api_base.replace("os.environ/", "") api_base = litellm.get_secret(api_base_env_name) litellm_params["api_base"] = api_base ## AZURE AI STUDIO MISTRAL CHECK ## """ Make sure api base ends in /v1/ if not, add it - https://github.com/BerriAI/litellm/issues/2279 """ if ( is_azure_ai_studio_model == True and api_base is not None and isinstance(api_base, str) and not api_base.endswith("/v1/") ): # check if it ends with a trailing slash if api_base.endswith("/"): api_base += "v1/" elif api_base.endswith("/v1"): api_base += "/" else: api_base += "/v1/" api_version = litellm_params.get("api_version") if api_version and api_version.startswith("os.environ/"): api_version_env_name = api_version.replace("os.environ/", "") api_version = litellm.get_secret(api_version_env_name) litellm_params["api_version"] = api_version timeout = litellm_params.pop("timeout", None) or litellm.request_timeout if isinstance(timeout, str) and timeout.startswith("os.environ/"): timeout_env_name = timeout.replace("os.environ/", "") timeout = litellm.get_secret(timeout_env_name) litellm_params["timeout"] = timeout stream_timeout = litellm_params.pop( "stream_timeout", timeout ) # if no stream_timeout is set, default to timeout if isinstance(stream_timeout, str) and stream_timeout.startswith( "os.environ/" ): stream_timeout_env_name = stream_timeout.replace("os.environ/", "") stream_timeout = litellm.get_secret(stream_timeout_env_name) litellm_params["stream_timeout"] = stream_timeout max_retries = litellm_params.pop( "max_retries", 0 ) # router handles retry logic if isinstance(max_retries, str) and max_retries.startswith("os.environ/"): max_retries_env_name = max_retries.replace("os.environ/", "") max_retries = litellm.get_secret(max_retries_env_name) litellm_params["max_retries"] = max_retries # proxy support import os import httpx # Check if the HTTP_PROXY and HTTPS_PROXY environment variables are set and use them accordingly. http_proxy = os.getenv("HTTP_PROXY", None) https_proxy = os.getenv("HTTPS_PROXY", None) no_proxy = os.getenv("NO_PROXY", None) # Create the proxies dictionary only if the environment variables are set. 
sync_proxy_mounts = None async_proxy_mounts = None if http_proxy is not None and https_proxy is not None: sync_proxy_mounts = { "http://": httpx.HTTPTransport(proxy=httpx.Proxy(url=http_proxy)), "https://": httpx.HTTPTransport(proxy=httpx.Proxy(url=https_proxy)), } async_proxy_mounts = { "http://": httpx.AsyncHTTPTransport( proxy=httpx.Proxy(url=http_proxy) ), "https://": httpx.AsyncHTTPTransport( proxy=httpx.Proxy(url=https_proxy) ), } # assume no_proxy is a list of comma separated urls if no_proxy is not None and isinstance(no_proxy, str): no_proxy_urls = no_proxy.split(",") for url in no_proxy_urls: # set no-proxy support for specific urls sync_proxy_mounts[url] = None # type: ignore async_proxy_mounts[url] = None # type: ignore organization = litellm_params.get("organization", None) if isinstance(organization, str) and organization.startswith("os.environ/"): organization_env_name = organization.replace("os.environ/", "") organization = litellm.get_secret(organization_env_name) litellm_params["organization"] = organization if "azure" in model_name: if api_base is None or not isinstance(api_base, str): filtered_litellm_params = { k: v for k, v in model["litellm_params"].items() if k != "api_key" } _filtered_model = { "model_name": model["model_name"], "litellm_params": filtered_litellm_params, } raise ValueError( f"api_base is required for Azure OpenAI. Set it on your config. Model - {_filtered_model}" ) azure_ad_token = litellm_params.get("azure_ad_token") if azure_ad_token is not None: if azure_ad_token.startswith("oidc/"): azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token) if api_version is None: api_version = "2023-07-01-preview" if "gateway.ai.cloudflare.com" in api_base: if not api_base.endswith("/"): api_base += "/" azure_model = model_name.replace("azure/", "") api_base += f"{azure_model}" cache_key = f"{model_id}_async_client" _client = openai.AsyncAzureOpenAI( api_key=api_key, azure_ad_token=azure_ad_token, base_url=api_base, api_version=api_version, timeout=timeout, max_retries=max_retries, http_client=httpx.AsyncClient( transport=AsyncCustomHTTPTransport( limits=httpx.Limits( max_connections=1000, max_keepalive_connections=100 ), verify=litellm.ssl_verify, ), mounts=async_proxy_mounts, ), # type: ignore ) self.cache.set_cache( key=cache_key, value=_client, ttl=client_ttl, local_only=True, ) # cache for 1 hr cache_key = f"{model_id}_client" _client = openai.AzureOpenAI( # type: ignore api_key=api_key, azure_ad_token=azure_ad_token, base_url=api_base, api_version=api_version, timeout=timeout, max_retries=max_retries, http_client=httpx.Client( transport=CustomHTTPTransport( limits=httpx.Limits( max_connections=1000, max_keepalive_connections=100 ), verify=litellm.ssl_verify, ), mounts=sync_proxy_mounts, ), # type: ignore ) self.cache.set_cache( key=cache_key, value=_client, ttl=client_ttl, local_only=True, ) # cache for 1 hr # streaming clients can have diff timeouts cache_key = f"{model_id}_stream_async_client" _client = openai.AsyncAzureOpenAI( # type: ignore api_key=api_key, azure_ad_token=azure_ad_token, base_url=api_base, api_version=api_version, timeout=stream_timeout, max_retries=max_retries, http_client=httpx.AsyncClient( transport=AsyncCustomHTTPTransport( limits=httpx.Limits( max_connections=1000, max_keepalive_connections=100 ), verify=litellm.ssl_verify, ), mounts=async_proxy_mounts, ), # type: ignore ) self.cache.set_cache( key=cache_key, value=_client, ttl=client_ttl, local_only=True, ) # cache for 1 hr cache_key = f"{model_id}_stream_client" _client = 
openai.AzureOpenAI( # type: ignore api_key=api_key, azure_ad_token=azure_ad_token, base_url=api_base, api_version=api_version, timeout=stream_timeout, max_retries=max_retries, http_client=httpx.Client( transport=CustomHTTPTransport( limits=httpx.Limits( max_connections=1000, max_keepalive_connections=100 ), verify=litellm.ssl_verify, ), mounts=sync_proxy_mounts, ), # type: ignore ) self.cache.set_cache( key=cache_key, value=_client, ttl=client_ttl, local_only=True, ) # cache for 1 hr else: _api_key = api_key if _api_key is not None and isinstance(_api_key, str): # only show first 5 chars of api_key _api_key = _api_key[:8] + "*" * 15 verbose_router_logger.debug( f"Initializing Azure OpenAI Client for {model_name}, Api Base: {str(api_base)}, Api Key:{_api_key}" ) azure_client_params = { "api_key": api_key, "azure_endpoint": api_base, "api_version": api_version, "azure_ad_token": azure_ad_token, } from litellm.llms.azure import select_azure_base_url_or_endpoint # this decides if we should set azure_endpoint or base_url on Azure OpenAI Client # required to support GPT-4 vision enhancements, since base_url needs to be set on Azure OpenAI Client azure_client_params = select_azure_base_url_or_endpoint( azure_client_params ) cache_key = f"{model_id}_async_client" _client = openai.AsyncAzureOpenAI( # type: ignore **azure_client_params, timeout=timeout, max_retries=max_retries, http_client=httpx.AsyncClient( transport=AsyncCustomHTTPTransport( limits=httpx.Limits( max_connections=1000, max_keepalive_connections=100 ), verify=litellm.ssl_verify, ), mounts=async_proxy_mounts, ), # type: ignore ) self.cache.set_cache( key=cache_key, value=_client, ttl=client_ttl, local_only=True, ) # cache for 1 hr cache_key = f"{model_id}_client" _client = openai.AzureOpenAI( # type: ignore **azure_client_params, timeout=timeout, max_retries=max_retries, http_client=httpx.Client( transport=CustomHTTPTransport( verify=litellm.ssl_verify, limits=httpx.Limits( max_connections=1000, max_keepalive_connections=100 ), ), mounts=sync_proxy_mounts, ), # type: ignore ) self.cache.set_cache( key=cache_key, value=_client, ttl=client_ttl, local_only=True, ) # cache for 1 hr # streaming clients should have diff timeouts cache_key = f"{model_id}_stream_async_client" _client = openai.AsyncAzureOpenAI( # type: ignore **azure_client_params, timeout=stream_timeout, max_retries=max_retries, http_client=httpx.AsyncClient( transport=AsyncCustomHTTPTransport( limits=httpx.Limits( max_connections=1000, max_keepalive_connections=100 ), verify=litellm.ssl_verify, ), mounts=async_proxy_mounts, ), ) self.cache.set_cache( key=cache_key, value=_client, ttl=client_ttl, local_only=True, ) # cache for 1 hr cache_key = f"{model_id}_stream_client" _client = openai.AzureOpenAI( # type: ignore **azure_client_params, timeout=stream_timeout, max_retries=max_retries, http_client=httpx.Client( transport=CustomHTTPTransport( limits=httpx.Limits( max_connections=1000, max_keepalive_connections=100 ), verify=litellm.ssl_verify, ), mounts=sync_proxy_mounts, ), ) self.cache.set_cache( key=cache_key, value=_client, ttl=client_ttl, local_only=True, ) # cache for 1 hr else: _api_key = api_key # type: ignore if _api_key is not None and isinstance(_api_key, str): # only show first 5 chars of api_key _api_key = _api_key[:8] + "*" * 15 verbose_router_logger.debug( f"Initializing OpenAI Client for {model_name}, Api Base:{str(api_base)}, Api Key:{_api_key}" ) cache_key = f"{model_id}_async_client" _client = openai.AsyncOpenAI( # type: ignore api_key=api_key, 
base_url=api_base, timeout=timeout, max_retries=max_retries, organization=organization, http_client=httpx.AsyncClient( transport=AsyncCustomHTTPTransport( limits=httpx.Limits( max_connections=1000, max_keepalive_connections=100 ), verify=litellm.ssl_verify, ), mounts=async_proxy_mounts, ), # type: ignore ) self.cache.set_cache( key=cache_key, value=_client, ttl=client_ttl, local_only=True, ) # cache for 1 hr cache_key = f"{model_id}_client" _client = openai.OpenAI( # type: ignore api_key=api_key, base_url=api_base, timeout=timeout, max_retries=max_retries, organization=organization, http_client=httpx.Client( transport=CustomHTTPTransport( limits=httpx.Limits( max_connections=1000, max_keepalive_connections=100 ), verify=litellm.ssl_verify, ), mounts=sync_proxy_mounts, ), # type: ignore ) self.cache.set_cache( key=cache_key, value=_client, ttl=client_ttl, local_only=True, ) # cache for 1 hr # streaming clients should have diff timeouts cache_key = f"{model_id}_stream_async_client" _client = openai.AsyncOpenAI( # type: ignore api_key=api_key, base_url=api_base, timeout=stream_timeout, max_retries=max_retries, organization=organization, http_client=httpx.AsyncClient( transport=AsyncCustomHTTPTransport( limits=httpx.Limits( max_connections=1000, max_keepalive_connections=100 ), verify=litellm.ssl_verify, ), mounts=async_proxy_mounts, ), # type: ignore ) self.cache.set_cache( key=cache_key, value=_client, ttl=client_ttl, local_only=True, ) # cache for 1 hr # streaming clients should have diff timeouts cache_key = f"{model_id}_stream_client" _client = openai.OpenAI( # type: ignore api_key=api_key, base_url=api_base, timeout=stream_timeout, max_retries=max_retries, organization=organization, http_client=httpx.Client( transport=CustomHTTPTransport( limits=httpx.Limits( max_connections=1000, max_keepalive_connections=100 ), verify=litellm.ssl_verify, ), mounts=sync_proxy_mounts, ), # type: ignore ) self.cache.set_cache( key=cache_key, value=_client, ttl=client_ttl, local_only=True, ) # cache for 1 hr def _generate_model_id(self, model_group: str, litellm_params: dict): """ Helper function to consistently generate the same id for a deployment - create a string from all the litellm params - hash - use hash as id """ concat_str = model_group for k, v in litellm_params.items(): if isinstance(k, str): concat_str += k elif isinstance(k, dict): concat_str += json.dumps(k) else: concat_str += str(k) if isinstance(v, str): concat_str += v elif isinstance(v, dict): concat_str += json.dumps(v) else: concat_str += str(v) hash_object = hashlib.sha256(concat_str.encode()) return hash_object.hexdigest() def set_model_list(self, model_list: list): original_model_list = copy.deepcopy(model_list) self.model_list = [] # we add api_base/api_key each model so load balancing between azure/gpt on api_base1 and api_base2 works import os for model in original_model_list: _model_name = model.pop("model_name") _litellm_params = model.pop("litellm_params") ## check if litellm params in os.environ if isinstance(_litellm_params, dict): for k, v in _litellm_params.items(): if isinstance(v, str) and v.startswith("os.environ/"): _litellm_params[k] = litellm.get_secret(v) _model_info: dict = model.pop("model_info", {}) # check if model info has id if "id" not in _model_info: _id = self._generate_model_id(_model_name, _litellm_params) _model_info["id"] = _id deployment = Deployment( **model, model_name=_model_name, litellm_params=_litellm_params, model_info=_model_info, ) deployment = self._add_deployment(deployment=deployment) 
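            # Example entry handled by this loop (sketch; values assumed):
            #   {
            #       "model_name": "gpt-3.5-turbo",  # model group name
            #       "litellm_params": {
            #           "model": "azure/chat-gpt-v-2",
            #           "api_key": "os.environ/AZURE_API_KEY",  # resolved via litellm.get_secret above
            #           "rpm": 600,
            #       },
            #   }
            # If "model_info" has no "id", a deterministic hash id is generated by
            # _generate_model_id and attached before the deployment is appended.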
model = deployment.to_json(exclude_none=True) self.model_list.append(model) verbose_router_logger.debug(f"\nInitialized Model List {self.model_list}") self.model_names = [m["model_name"] for m in model_list] def _add_deployment(self, deployment: Deployment) -> Deployment: import os #### DEPLOYMENT NAMES INIT ######## self.deployment_names.append(deployment.litellm_params.model) ############ Users can either pass tpm/rpm as a litellm_param or a router param ########### # for get_available_deployment, we use the litellm_param["rpm"] # in this snippet we also set rpm to be a litellm_param if ( deployment.litellm_params.rpm is None and getattr(deployment, "rpm", None) is not None ): deployment.litellm_params.rpm = getattr(deployment, "rpm") if ( deployment.litellm_params.tpm is None and getattr(deployment, "tpm", None) is not None ): deployment.litellm_params.tpm = getattr(deployment, "tpm") #### VALIDATE MODEL ######## # check if model provider in supported providers ( _model, custom_llm_provider, dynamic_api_key, api_base, ) = litellm.get_llm_provider( model=deployment.litellm_params.model, custom_llm_provider=deployment.litellm_params.get( "custom_llm_provider", None ), ) # Check if user is trying to use model_name == "*" # this is a catch all model for their specific api key if deployment.model_name == "*": self.default_deployment = deployment.to_json(exclude_none=True) # Azure GPT-Vision Enhancements, users can pass os.environ/ data_sources = deployment.litellm_params.get("dataSources", []) or [] for data_source in data_sources: params = data_source.get("parameters", {}) for param_key in ["endpoint", "key"]: # if endpoint or key set for Azure GPT Vision Enhancements, check if it's an env var if param_key in params and params[param_key].startswith("os.environ/"): env_name = params[param_key].replace("os.environ/", "") params[param_key] = os.environ.get(env_name, "") # done reading model["litellm_params"] if custom_llm_provider not in litellm.provider_list: raise Exception(f"Unsupported provider - {custom_llm_provider}") # init OpenAI, Azure clients self.set_client(model=deployment.to_json(exclude_none=True)) # set region (if azure model) ## PREVIEW FEATURE ## if litellm.enable_preview_features == True: print("Auto inferring region") # noqa """ Hiding behind a feature flag When there is a large amount of LLM deployments this makes startup times blow up """ try: if ( "azure" in deployment.litellm_params.model and deployment.litellm_params.region_name is None ): region = litellm.utils.get_model_region( litellm_params=deployment.litellm_params, mode=None ) deployment.litellm_params.region_name = region except Exception as e: verbose_router_logger.debug( "Unable to get the region for azure model - {}, {}".format( deployment.litellm_params.model, str(e) ) ) pass # [NON-BLOCKING] return deployment def add_deployment(self, deployment: Deployment) -> Optional[Deployment]: """ Parameters: - deployment: Deployment - the deployment to be added to the Router Returns: - The added deployment - OR None (if deployment already exists) """ # check if deployment already exists if deployment.model_info.id in self.get_model_ids(): return None # add to model list _deployment = deployment.to_json(exclude_none=True) self.model_list.append(_deployment) # initialize client self._add_deployment(deployment=deployment) # add to model names self.model_names.append(deployment.model_name) return deployment def upsert_deployment(self, deployment: Deployment) -> Optional[Deployment]: """ Add or update deployment Parameters: 
- deployment: Deployment - the deployment to be added to the Router Returns: - The added/updated deployment """ # check if deployment already exists _deployment_model_id = deployment.model_info.id or "" _deployment_on_router: Optional[Deployment] = self.get_deployment( model_id=_deployment_model_id ) if _deployment_on_router is not None: # deployment with this model_id exists on the router if deployment.litellm_params == _deployment_on_router.litellm_params: # No need to update return None # if there is a new litellm param -> then update the deployment # remove the previous deployment removal_idx: Optional[int] = None for idx, model in enumerate(self.model_list): if model["model_info"]["id"] == deployment.model_info.id: removal_idx = idx if removal_idx is not None: self.model_list.pop(removal_idx) else: # if the model_id is not in router self.add_deployment(deployment=deployment) return deployment def delete_deployment(self, id: str) -> Optional[Deployment]: """ Parameters: - id: str - the id of the deployment to be deleted Returns: - The deleted deployment - OR None (if deleted deployment not found) """ deployment_idx = None for idx, m in enumerate(self.model_list): if m["model_info"]["id"] == id: deployment_idx = idx try: if deployment_idx is not None: item = self.model_list.pop(deployment_idx) return item else: return None except: return None def get_deployment(self, model_id: str) -> Optional[Deployment]: """ Returns -> Deployment or None Raise Exception -> if model found in invalid format """ for model in self.model_list: if "model_info" in model and "id" in model["model_info"]: if model_id == model["model_info"]["id"]: if isinstance(model, dict): return Deployment(**model) elif isinstance(model, Deployment): return model else: raise Exception("Model invalid format - {}".format(type(model))) return None def get_model_info(self, id: str) -> Optional[dict]: """ For a given model id, return the model info """ for model in self.model_list: if "model_info" in model and "id" in model["model_info"]: if id == model["model_info"]["id"]: return model return None def get_model_group_info(self, model_group: str) -> Optional[ModelGroupInfo]: """ For a given model group name, return the combined model info Returns: - ModelGroupInfo if able to construct a model group - None if error constructing model group info """ model_group_info: Optional[ModelGroupInfo] = None for model in self.model_list: if "model_name" in model and model["model_name"] == model_group: # model in model group found # litellm_params = LiteLLM_Params(**model["litellm_params"]) # get model info try: model_info = litellm.get_model_info(model=litellm_params.model) except Exception: model_info = None # get llm provider model, llm_provider = "", "" try: model, llm_provider, _, _ = litellm.get_llm_provider( model=litellm_params.model, custom_llm_provider=litellm_params.custom_llm_provider, ) except litellm.exceptions.BadRequestError as e: verbose_router_logger.error( "litellm.router.py::get_model_group_info() - {}".format(str(e)) ) if model_info is None: supported_openai_params = litellm.get_supported_openai_params( model=model, custom_llm_provider=llm_provider ) if supported_openai_params is None: supported_openai_params = [] model_info = ModelMapInfo( max_tokens=None, max_input_tokens=None, max_output_tokens=None, input_cost_per_token=0, output_cost_per_token=0, litellm_provider=llm_provider, mode="chat", supported_openai_params=supported_openai_params, ) if model_group_info is None: model_group_info = ModelGroupInfo( 
model_group=model_group, providers=[llm_provider], **model_info # type: ignore ) else: # if max_input_tokens > curr # if max_output_tokens > curr # if input_cost_per_token > curr # if output_cost_per_token > curr # supports_parallel_function_calling == True # supports_vision == True # supports_function_calling == True if llm_provider not in model_group_info.providers: model_group_info.providers.append(llm_provider) if ( model_info.get("max_input_tokens", None) is not None and model_info["max_input_tokens"] is not None and ( model_group_info.max_input_tokens is None or model_info["max_input_tokens"] > model_group_info.max_input_tokens ) ): model_group_info.max_input_tokens = model_info[ "max_input_tokens" ] if ( model_info.get("max_output_tokens", None) is not None and model_info["max_output_tokens"] is not None and ( model_group_info.max_output_tokens is None or model_info["max_output_tokens"] > model_group_info.max_output_tokens ) ): model_group_info.max_output_tokens = model_info[ "max_output_tokens" ] if model_info.get("input_cost_per_token", None) is not None and ( model_group_info.input_cost_per_token is None or model_info["input_cost_per_token"] > model_group_info.input_cost_per_token ): model_group_info.input_cost_per_token = model_info[ "input_cost_per_token" ] if model_info.get("output_cost_per_token", None) is not None and ( model_group_info.output_cost_per_token is None or model_info["output_cost_per_token"] > model_group_info.output_cost_per_token ): model_group_info.output_cost_per_token = model_info[ "output_cost_per_token" ] if ( model_info.get("supports_parallel_function_calling", None) is not None and model_info["supports_parallel_function_calling"] is True # type: ignore ): model_group_info.supports_parallel_function_calling = True if ( model_info.get("supports_vision", None) is not None and model_info["supports_vision"] is True # type: ignore ): model_group_info.supports_vision = True if ( model_info.get("supports_function_calling", None) is not None and model_info["supports_function_calling"] is True # type: ignore ): model_group_info.supports_function_calling = True if ( model_info.get("supported_openai_params", None) is not None and model_info["supported_openai_params"] is not None ): model_group_info.supported_openai_params = model_info[ "supported_openai_params" ] return model_group_info def get_model_ids(self) -> List[str]: """ Returns list of model id's. """ ids = [] for model in self.model_list: if "model_info" in model and "id" in model["model_info"]: id = model["model_info"]["id"] ids.append(id) return ids def get_model_names(self) -> List[str]: return self.model_names def get_model_list(self): if hasattr(self, "model_list"): return self.model_list return None def get_settings(self): """ Get router settings method, returns a dictionary of the settings and their values. 
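
        A sketch of the returned shape (values assumed):
            {"routing_strategy": "simple-shuffle", "num_retries": 2, "allowed_fails": 3, ...}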
For example get the set values for routing_strategy_args, routing_strategy, allowed_fails, cooldown_time, num_retries, timeout, max_retries, retry_after """ _all_vars = vars(self) _settings_to_return = {} vars_to_include = [ "routing_strategy_args", "routing_strategy", "allowed_fails", "cooldown_time", "num_retries", "timeout", "max_retries", "retry_after", "fallbacks", "context_window_fallbacks", "model_group_retry_policy", ] for var in vars_to_include: if var in _all_vars: _settings_to_return[var] = _all_vars[var] if ( var == "routing_strategy_args" and self.routing_strategy == "latency-based-routing" ): _settings_to_return[var] = self.lowestlatency_logger.routing_args.json() return _settings_to_return def update_settings(self, **kwargs): # only the following settings are allowed to be configured _allowed_settings = [ "routing_strategy_args", "routing_strategy", "allowed_fails", "cooldown_time", "num_retries", "timeout", "max_retries", "retry_after", "fallbacks", "context_window_fallbacks", "model_group_retry_policy", ] _int_settings = [ "timeout", "num_retries", "retry_after", "allowed_fails", "cooldown_time", ] _existing_router_settings = self.get_settings() for var in kwargs: if var in _allowed_settings: if var in _int_settings: _casted_value = int(kwargs[var]) setattr(self, var, _casted_value) else: # only run routing strategy init if it has changed if ( var == "routing_strategy" and _existing_router_settings["routing_strategy"] != kwargs[var] ): self.routing_strategy_init( routing_strategy=kwargs[var], routing_strategy_args=kwargs.get( "routing_strategy_args", {} ), ) setattr(self, var, kwargs[var]) else: verbose_router_logger.debug("Setting {} is not allowed".format(var)) verbose_router_logger.debug(f"Updated Router settings: {self.get_settings()}") def _get_client(self, deployment, kwargs, client_type=None): """ Returns the appropriate client based on the given deployment, kwargs, and client_type. Parameters: deployment (dict): The deployment dictionary containing the clients. kwargs (dict): The keyword arguments passed to the function. client_type (str): The type of client to return. Returns: The appropriate client based on the given client_type and kwargs. 
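
        Example (illustrative sketch):
            client = router._get_client(deployment, kwargs, client_type="async")
            # returns the cached async client for deployment["model_info"]["id"],
            # re-initializing it via set_client() on a cache miss; with
            # kwargs={"stream": True} the stream-specific client is returned instead.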
""" model_id = deployment["model_info"]["id"] if client_type == "max_parallel_requests": cache_key = "{}_max_parallel_requests_client".format(model_id) client = self.cache.get_cache(key=cache_key, local_only=True) return client elif client_type == "async": if kwargs.get("stream") == True: cache_key = f"{model_id}_stream_async_client" client = self.cache.get_cache(key=cache_key, local_only=True) if client is None: """ Re-initialize the client """ self.set_client(model=deployment) client = self.cache.get_cache(key=cache_key, local_only=True) return client else: cache_key = f"{model_id}_async_client" client = self.cache.get_cache(key=cache_key, local_only=True) if client is None: """ Re-initialize the client """ self.set_client(model=deployment) client = self.cache.get_cache(key=cache_key, local_only=True) return client else: if kwargs.get("stream") == True: cache_key = f"{model_id}_stream_client" client = self.cache.get_cache(key=cache_key) if client is None: """ Re-initialize the client """ self.set_client(model=deployment) client = self.cache.get_cache(key=cache_key) return client else: cache_key = f"{model_id}_client" client = self.cache.get_cache(key=cache_key) if client is None: """ Re-initialize the client """ self.set_client(model=deployment) client = self.cache.get_cache(key=cache_key) return client def _pre_call_checks( self, model: str, healthy_deployments: List, messages: List[Dict[str, str]], request_kwargs: Optional[dict] = None, ): """ Filter out model in model group, if: - model context window < message length - filter models above rpm limits - if region given, filter out models not in that region / unknown region - [TODO] function call and model doesn't support function calling """ verbose_router_logger.debug( f"Starting Pre-call checks for deployments in model={model}" ) _returned_deployments = copy.deepcopy(healthy_deployments) invalid_model_indices = [] try: input_tokens = litellm.token_counter(messages=messages) except Exception as e: return _returned_deployments _context_window_error = False _rate_limit_error = False ## get model group RPM ## dt = get_utc_datetime() current_minute = dt.strftime("%H-%M") rpm_key = f"{model}:rpm:{current_minute}" model_group_cache = ( self.cache.get_cache(key=rpm_key, local_only=True) or {} ) # check the in-memory cache used by lowest_latency and usage-based routing. Only check the local cache. 
for idx, deployment in enumerate(_returned_deployments): # see if we have the info for this model try: base_model = deployment.get("model_info", {}).get("base_model", None) if base_model is None: base_model = deployment.get("litellm_params", {}).get( "base_model", None ) model = base_model or deployment.get("litellm_params", {}).get( "model", None ) model_info = litellm.get_model_info(model=model) if ( isinstance(model_info, dict) and model_info.get("max_input_tokens", None) is not None ): if ( isinstance(model_info["max_input_tokens"], int) and input_tokens > model_info["max_input_tokens"] ): invalid_model_indices.append(idx) _context_window_error = True continue except Exception as e: verbose_router_logger.debug("An error occurs - {}".format(str(e))) _litellm_params = deployment.get("litellm_params", {}) model_id = deployment.get("model_info", {}).get("id", "") ## RPM CHECK ## ### get local router cache ### current_request_cache_local = ( self.cache.get_cache(key=model_id, local_only=True) or 0 ) ### get usage based cache ### if ( isinstance(model_group_cache, dict) and self.routing_strategy != "usage-based-routing-v2" ): model_group_cache[model_id] = model_group_cache.get(model_id, 0) current_request = max( current_request_cache_local, model_group_cache[model_id] ) if ( isinstance(_litellm_params, dict) and _litellm_params.get("rpm", None) is not None ): if ( isinstance(_litellm_params["rpm"], int) and _litellm_params["rpm"] <= current_request ): invalid_model_indices.append(idx) _rate_limit_error = True continue ## REGION CHECK ## if ( request_kwargs is not None and request_kwargs.get("allowed_model_region") is not None and request_kwargs["allowed_model_region"] == "eu" ): if _litellm_params.get("region_name") is not None and isinstance( _litellm_params["region_name"], str ): # check if in allowed_model_region if ( _is_region_eu(litellm_params=LiteLLM_Params(**_litellm_params)) == False ): invalid_model_indices.append(idx) continue else: verbose_router_logger.debug( "Filtering out model - {}, as model_region=None, and allowed_model_region={}".format( model_id, request_kwargs.get("allowed_model_region") ) ) # filter out since region unknown, and user wants to filter for specific region invalid_model_indices.append(idx) continue ## INVALID PARAMS ## -> catch 'gpt-3.5-turbo-16k' not supporting 'response_format' param if request_kwargs is not None and litellm.drop_params == False: # get supported params model, custom_llm_provider, _, _ = litellm.get_llm_provider( model=model, litellm_params=LiteLLM_Params(**_litellm_params) ) supported_openai_params = litellm.get_supported_openai_params( model=model, custom_llm_provider=custom_llm_provider ) if supported_openai_params is None: continue else: # check the non-default openai params in request kwargs non_default_params = litellm.utils.get_non_default_params( passed_params=request_kwargs ) special_params = ["response_format"] # check if all params are supported for k, v in non_default_params.items(): if k not in supported_openai_params and k in special_params: # if not -> invalid model verbose_router_logger.debug( f"INVALID MODEL INDEX @ REQUEST KWARG FILTERING, k={k}" ) invalid_model_indices.append(idx) if len(invalid_model_indices) == len(_returned_deployments): """ - no healthy deployments available b/c context window checks or rate limit error - First check for rate limit errors (if this is true, it means the model passed the context window check but failed the rate limit check) """ if _rate_limit_error == True: # allow generic fallback 
                # logic to take place
                    raise ValueError(
                        f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}."
                    )
                elif _context_window_error is True:
                    raise litellm.ContextWindowExceededError(
                        message="Context Window exceeded for given call",
                        model=model,
                        llm_provider="",
                        response=httpx.Response(
                            status_code=400,
                            request=httpx.Request("GET", "https://example.com"),
                        ),
                    )
            if len(invalid_model_indices) > 0:
                for idx in reversed(invalid_model_indices):
                    _returned_deployments.pop(idx)

            ## ORDER FILTERING ## -> if user set 'order' in deployments, return deployments with lowest order (e.g. order=1 > order=2)
            if len(_returned_deployments) > 0:
                _returned_deployments = litellm.utils._get_order_filtered_deployments(
                    _returned_deployments
                )

            return _returned_deployments

    def _common_checks_available_deployment(
        self,
        model: str,
        messages: Optional[List[Dict[str, str]]] = None,
        input: Optional[Union[str, List]] = None,
        specific_deployment: Optional[bool] = False,
    ) -> Tuple[str, Union[list, dict]]:
        """
        Common checks for 'get_available_deployment' across sync + async call.

        If a Dict is returned, the user chose a specific deployment.

        Returns
        - Dict, if specific model chosen
        - List, if multiple models chosen
        """
        # check if aliases set on litellm model alias map
        if specific_deployment is True:
            # users can also specify a specific deployment name. At this point we should check if they are just trying to call a specific deployment
            for deployment in self.model_list:
                deployment_model = deployment.get("litellm_params").get("model")
                if deployment_model == model:
                    # User Passed a specific deployment name on their config.yaml, example azure/chat-gpt-v-2
                    # return the first deployment where the `model` matches the specified deployment name
                    return deployment_model, deployment
            raise ValueError(
                f"LiteLLM Router: Trying to call specific deployment, but Model:{model} does not exist in Model List: {self.model_list}"
            )
        if model in self.model_group_alias:
            verbose_router_logger.debug(
                f"Using a model alias. Got Request for {model}, sending requests to {self.model_group_alias.get(model)}"
            )
            model = self.model_group_alias[model]

        if model not in self.model_names and self.default_deployment is not None:
            updated_deployment = copy.deepcopy(
                self.default_deployment
            )  # self.default_deployment
            updated_deployment["litellm_params"]["model"] = model
            return model, updated_deployment

        ## get healthy deployments
        ### get all deployments
        healthy_deployments = [m for m in self.model_list if m["model_name"] == model]
        if len(healthy_deployments) == 0:
            # check if the user sent in a deployment name instead
            healthy_deployments = [
                m for m in self.model_list if m["litellm_params"]["model"] == model
            ]

        litellm.print_verbose(f"initial list of deployments: {healthy_deployments}")

        if len(healthy_deployments) == 0:
            raise ValueError(
                f"No healthy deployment available, passed model={model}. Try again in {self.cooldown_time} seconds"
            )
        if litellm.model_alias_map and model in litellm.model_alias_map:
            model = litellm.model_alias_map[
                model
            ]  # update the model to the actual value if an alias has been passed in

        return model, healthy_deployments

    async def async_get_available_deployment(
        self,
        model: str,
        messages: Optional[List[Dict[str, str]]] = None,
        input: Optional[Union[str, List]] = None,
        specific_deployment: Optional[bool] = False,
        request_kwargs: Optional[Dict] = None,
    ):
        """
        Async implementation of 'get_available_deployment'.
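
        Example (illustrative sketch, assuming a configured router):
            deployment = await router.async_get_available_deployment(
                model="gpt-3.5-turbo",
                messages=[{"role": "user", "content": "hi"}],
            )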
    async def async_get_available_deployment(
        self,
        model: str,
        messages: Optional[List[Dict[str, str]]] = None,
        input: Optional[Union[str, List]] = None,
        specific_deployment: Optional[bool] = False,
        request_kwargs: Optional[Dict] = None,
    ):
        """
        Async implementation of 'get_available_deployment'.

        Allows all cache calls to be made async => 10x perf impact (8 rps -> 100 rps).
        """
        if (
            self.routing_strategy != "usage-based-routing-v2"
            and self.routing_strategy != "simple-shuffle"
            and self.routing_strategy != "cost-based-routing"
        ):  # prevent regressions for other routing strategies that don't have an async 'get_available_deployments' implemented.
            return self.get_available_deployment(
                model=model,
                messages=messages,
                input=input,
                specific_deployment=specific_deployment,
                request_kwargs=request_kwargs,
            )

        model, healthy_deployments = self._common_checks_available_deployment(
            model=model,
            messages=messages,
            input=input,
            specific_deployment=specific_deployment,
        )  # type: ignore

        if isinstance(healthy_deployments, dict):
            return healthy_deployments

        # filter out the deployments currently cooling down
        deployments_to_remove = []
        # cooldown_deployments is a list of model_id's cooling down,
        # e.g. cooldown_deployments = ["16700539-b3cd-42f4-b426-6a12a1bb706a", "16700539-b3cd-42f4-b426-7899"]
        cooldown_deployments = await self._async_get_cooldown_deployments()
        verbose_router_logger.debug(
            f"async cooldown deployments: {cooldown_deployments}"
        )
        # Find deployments in model_list whose model_id is cooling down
        for deployment in healthy_deployments:
            deployment_id = deployment["model_info"]["id"]
            if deployment_id in cooldown_deployments:
                deployments_to_remove.append(deployment)
        # remove unhealthy deployments from healthy deployments
        for deployment in deployments_to_remove:
            healthy_deployments.remove(deployment)

        # filter pre-call checks
        _allowed_model_region = (
            request_kwargs.get("allowed_model_region")
            if request_kwargs is not None
            else None
        )

        if self.enable_pre_call_checks and messages is not None:
            healthy_deployments = self._pre_call_checks(
                model=model,
                healthy_deployments=healthy_deployments,
                messages=messages,
                request_kwargs=request_kwargs,
            )

        if len(healthy_deployments) == 0:
            if _allowed_model_region is None:
                _allowed_model_region = "n/a"
            raise ValueError(
                f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}. pre-call-checks={self.enable_pre_call_checks}, allowed_model_region={_allowed_model_region}"
            )
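        # Illustrative example for the cooldown filter above (assumed ids): if
        # healthy_deployments carries the model_info ids ["a1", "b2", "c3"] and
        # _async_get_cooldown_deployments() returns ["b2"], the "b2" deployment
        # is removed and the routing strategies below only see ["a1", "c3"].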
        # ensure 'deployment' is always bound for the None-check at the end
        deployment = None
        if (
            self.routing_strategy == "usage-based-routing-v2"
            and self.lowesttpm_logger_v2 is not None
        ):
            deployment = (
                await self.lowesttpm_logger_v2.async_get_available_deployments(
                    model_group=model,
                    healthy_deployments=healthy_deployments,  # type: ignore
                    messages=messages,
                    input=input,
                )
            )
        elif (
            self.routing_strategy == "cost-based-routing"
            and self.lowestcost_logger is not None
        ):
            deployment = await self.lowestcost_logger.async_get_available_deployments(
                model_group=model,
                healthy_deployments=healthy_deployments,  # type: ignore
                messages=messages,
                input=input,
            )
        elif self.routing_strategy == "simple-shuffle":
            # if users pass rpm or tpm, we do a random weighted pick - based on rpm/tpm
            ############## Check if we can do a RPM/TPM based weighted pick #################
            rpm = healthy_deployments[0].get("litellm_params").get("rpm", None)
            if rpm is not None:
                # use weight-random pick if rpms provided
                rpms = [m["litellm_params"].get("rpm", 0) for m in healthy_deployments]
                verbose_router_logger.debug(f"\nrpms {rpms}")
                total_rpm = sum(rpms)
                weights = [rpm / total_rpm for rpm in rpms]
                verbose_router_logger.debug(f"\n weights {weights}")
                # Perform weighted random pick
                selected_index = random.choices(range(len(rpms)), weights=weights)[0]
                verbose_router_logger.debug(f"\n selected index, {selected_index}")
                deployment = healthy_deployments[selected_index]
                verbose_router_logger.info(
                    f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment)} for model: {model}"
                )
                return deployment

            ############## Check if we can do a RPM/TPM based weighted pick #################
            tpm = healthy_deployments[0].get("litellm_params").get("tpm", None)
            if tpm is not None:
                # use weight-random pick if tpms provided
                tpms = [m["litellm_params"].get("tpm", 0) for m in healthy_deployments]
                verbose_router_logger.debug(f"\ntpms {tpms}")
                total_tpm = sum(tpms)
                weights = [tpm / total_tpm for tpm in tpms]
                verbose_router_logger.debug(f"\n weights {weights}")
                # Perform weighted random pick
                selected_index = random.choices(range(len(tpms)), weights=weights)[0]
                verbose_router_logger.debug(f"\n selected index, {selected_index}")
                deployment = healthy_deployments[selected_index]
                verbose_router_logger.info(
                    f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment)} for model: {model}"
                )
                return deployment

            ############## No RPM/TPM passed, we do a random pick #################
            item = random.choice(healthy_deployments)
            return item

        if deployment is None:
            verbose_router_logger.info(
                f"get_available_deployment for model: {model}, No deployment available"
            )
            raise ValueError(
                f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}"
            )
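        # Worked example for the weighted pick above (assumed values): with
        # rpms = [100, 300, 600], total_rpm = 1000 and weights = [0.1, 0.3, 0.6],
        # so random.choices(range(3), weights=weights) selects the third
        # deployment ~60% of the time; deployments without an 'rpm' value
        # default to weight 0 and are effectively never picked.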
        verbose_router_logger.info(
            f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment)} for model: {model}"
        )
        return deployment

    def get_available_deployment(
        self,
        model: str,
        messages: Optional[List[Dict[str, str]]] = None,
        input: Optional[Union[str, List]] = None,
        specific_deployment: Optional[bool] = False,
        request_kwargs: Optional[Dict] = None,
    ):
        """
        Returns the deployment based on routing strategy
        """
        # users need to explicitly call a specific deployment, by setting `specific_deployment = True` as a completion()/embedding() kwarg
        # When this wasn't explicit, we had several issues with fallbacks timing out

        model, healthy_deployments = self._common_checks_available_deployment(
            model=model,
            messages=messages,
            input=input,
            specific_deployment=specific_deployment,
        )

        if isinstance(healthy_deployments, dict):
            return healthy_deployments

        # filter out the deployments currently cooling down
        deployments_to_remove = []
        # cooldown_deployments is a list of model_id's cooling down,
        # e.g. cooldown_deployments = ["16700539-b3cd-42f4-b426-6a12a1bb706a", "16700539-b3cd-42f4-b426-7899"]
        cooldown_deployments = self._get_cooldown_deployments()
        verbose_router_logger.debug(f"cooldown deployments: {cooldown_deployments}")
        # Find deployments in model_list whose model_id is cooling down
        for deployment in healthy_deployments:
            deployment_id = deployment["model_info"]["id"]
            if deployment_id in cooldown_deployments:
                deployments_to_remove.append(deployment)
        # remove unhealthy deployments from healthy deployments
        for deployment in deployments_to_remove:
            healthy_deployments.remove(deployment)

        # filter pre-call checks
        if self.enable_pre_call_checks and messages is not None:
            healthy_deployments = self._pre_call_checks(
                model=model, healthy_deployments=healthy_deployments, messages=messages
            )

        # ensure 'deployment' is always bound for the None-check at the end
        deployment = None
        if self.routing_strategy == "least-busy" and self.leastbusy_logger is not None:
            deployment = self.leastbusy_logger.get_available_deployments(
                model_group=model, healthy_deployments=healthy_deployments  # type: ignore
            )
        elif self.routing_strategy == "simple-shuffle":
            # if users pass rpm or tpm, we do a random weighted pick - based on rpm/tpm
            ############## Check if we can do a RPM/TPM based weighted pick #################
            rpm = healthy_deployments[0].get("litellm_params").get("rpm", None)
            if rpm is not None:
                # use weight-random pick if rpms provided
                rpms = [m["litellm_params"].get("rpm", 0) for m in healthy_deployments]
                verbose_router_logger.debug(f"\nrpms {rpms}")
                total_rpm = sum(rpms)
                weights = [rpm / total_rpm for rpm in rpms]
                verbose_router_logger.debug(f"\n weights {weights}")
                # Perform weighted random pick
                selected_index = random.choices(range(len(rpms)), weights=weights)[0]
                verbose_router_logger.debug(f"\n selected index, {selected_index}")
                deployment = healthy_deployments[selected_index]
                verbose_router_logger.info(
                    f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment)} for model: {model}"
                )
                return deployment
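            # Illustrative note (assumed config): with routing_strategy set to
            # "simple-shuffle" and deployments that only set "tpm" in their
            # litellm_params, the rpm branch above is skipped and selection
            # falls through to the TPM-based weighted pick below.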
            ############## Check if we can do a RPM/TPM based weighted pick #################
            tpm = healthy_deployments[0].get("litellm_params").get("tpm", None)
            if tpm is not None:
                # use weight-random pick if tpms provided
                tpms = [m["litellm_params"].get("tpm", 0) for m in healthy_deployments]
                verbose_router_logger.debug(f"\ntpms {tpms}")
                total_tpm = sum(tpms)
                weights = [tpm / total_tpm for tpm in tpms]
                verbose_router_logger.debug(f"\n weights {weights}")
                # Perform weighted random pick
                selected_index = random.choices(range(len(tpms)), weights=weights)[0]
                verbose_router_logger.debug(f"\n selected index, {selected_index}")
                deployment = healthy_deployments[selected_index]
                verbose_router_logger.info(
                    f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment)} for model: {model}"
                )
                return deployment

            ############## No RPM/TPM passed, we do a random pick #################
            item = random.choice(healthy_deployments)
            return item
        elif (
            self.routing_strategy == "latency-based-routing"
            and self.lowestlatency_logger is not None
        ):
            deployment = self.lowestlatency_logger.get_available_deployments(
                model_group=model,
                healthy_deployments=healthy_deployments,  # type: ignore
                request_kwargs=request_kwargs,
            )
        elif (
            self.routing_strategy == "usage-based-routing"
            and self.lowesttpm_logger is not None
        ):
            deployment = self.lowesttpm_logger.get_available_deployments(
                model_group=model,
                healthy_deployments=healthy_deployments,  # type: ignore
                messages=messages,
                input=input,
            )
        elif (
            self.routing_strategy == "usage-based-routing-v2"
            and self.lowesttpm_logger_v2 is not None
        ):
            deployment = self.lowesttpm_logger_v2.get_available_deployments(
                model_group=model,
                healthy_deployments=healthy_deployments,  # type: ignore
                messages=messages,
                input=input,
            )

        if deployment is None:
            verbose_router_logger.info(
                f"get_available_deployment for model: {model}, No deployment available"
            )
            raise ValueError(
                f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}"
            )
        verbose_router_logger.info(
            f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment)} for model: {model}"
        )
        return deployment

    def _track_deployment_metrics(self, deployment, response=None):
        try:
            litellm_params = deployment["litellm_params"]
            api_base = litellm_params.get("api_base", "")
            model = litellm_params.get("model", "")

            model_id = deployment.get("model_info", {}).get("id", None)
            if response is None:
                # update self.deployment_stats
                if model_id is not None:
                    self._update_usage(
                        model_id
                    )  # update in-memory cache for tracking
                    if model_id in self.deployment_stats:
                        # only update num_requests
                        self.deployment_stats[model_id]["num_requests"] += 1
                    else:
                        self.deployment_stats[model_id] = {
                            "api_base": api_base,
                            "model": model,
                            "num_requests": 1,
                        }
            else:
                # check response_ms and update num_successes
                if isinstance(response, dict):
                    response_ms = response.get("_response_ms", 0)
                else:
                    response_ms = 0
                if model_id is not None:
                    if model_id in self.deployment_stats:
                        # maintain a running average of latency across successes:
                        # new_avg = (old_avg * n + response_ms) / (n + 1)
                        n = self.deployment_stats[model_id].get("num_successes", 0)
                        old_avg = self.deployment_stats[model_id].get("avg_latency", 0)
                        self.deployment_stats[model_id]["avg_latency"] = (
                            old_avg * n + response_ms
                        ) / (n + 1)
                        self.deployment_stats[model_id]["num_successes"] = n + 1
                    else:
                        self.deployment_stats[model_id] = {
                            "api_base": api_base,
                            "model": model,
                            "num_successes": 1,
                            "avg_latency": response_ms,
                        }
            if self.set_verbose == True and self.debug_level == "DEBUG":
                from pprint import pformat

                # log the current per-deployment stats in a readable form
                formatted_stats = pformat(self.deployment_stats)
                verbose_router_logger.info(
                    "self.deployment_stats: \n%s", formatted_stats
                )
        except Exception as e:
            verbose_router_logger.error(f"Error in _track_deployment_metrics: {str(e)}")
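    # Illustrative shape of self.deployment_stats after a few tracked calls
    # (assumed values, for exposition only):
    #   {"model-id-1": {"api_base": "https://...", "model": "azure/gpt-4",
    #                   "num_requests": 3, "num_successes": 2, "avg_latency": 412.5}}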
"self.deployment_stats: \n%s", formatted_stats ) except Exception as e: verbose_router_logger.error(f"Error in _track_deployment_metrics: {str(e)}") def get_num_retries_from_retry_policy( self, exception: Exception, model_group: Optional[str] = None ): """ BadRequestErrorRetries: Optional[int] = None AuthenticationErrorRetries: Optional[int] = None TimeoutErrorRetries: Optional[int] = None RateLimitErrorRetries: Optional[int] = None ContentPolicyViolationErrorRetries: Optional[int] = None """ # if we can find the exception then in the retry policy -> return the number of retries retry_policy = self.retry_policy if ( self.model_group_retry_policy is not None and model_group is not None and model_group in self.model_group_retry_policy ): retry_policy = self.model_group_retry_policy.get(model_group, None) if retry_policy is None: return None if isinstance(retry_policy, dict): retry_policy = RetryPolicy(**retry_policy) if ( isinstance(exception, litellm.BadRequestError) and retry_policy.BadRequestErrorRetries is not None ): return retry_policy.BadRequestErrorRetries if ( isinstance(exception, litellm.AuthenticationError) and retry_policy.AuthenticationErrorRetries is not None ): return retry_policy.AuthenticationErrorRetries if ( isinstance(exception, litellm.Timeout) and retry_policy.TimeoutErrorRetries is not None ): return retry_policy.TimeoutErrorRetries if ( isinstance(exception, litellm.RateLimitError) and retry_policy.RateLimitErrorRetries is not None ): return retry_policy.RateLimitErrorRetries if ( isinstance(exception, litellm.ContentPolicyViolationError) and retry_policy.ContentPolicyViolationErrorRetries is not None ): return retry_policy.ContentPolicyViolationErrorRetries def get_allowed_fails_from_policy(self, exception: Exception): """ BadRequestErrorRetries: Optional[int] = None AuthenticationErrorRetries: Optional[int] = None TimeoutErrorRetries: Optional[int] = None RateLimitErrorRetries: Optional[int] = None ContentPolicyViolationErrorRetries: Optional[int] = None """ # if we can find the exception then in the retry policy -> return the number of retries allowed_fails_policy: Optional[AllowedFailsPolicy] = self.allowed_fails_policy if allowed_fails_policy is None: return None if ( isinstance(exception, litellm.BadRequestError) and allowed_fails_policy.BadRequestErrorAllowedFails is not None ): return allowed_fails_policy.BadRequestErrorAllowedFails if ( isinstance(exception, litellm.AuthenticationError) and allowed_fails_policy.AuthenticationErrorAllowedFails is not None ): return allowed_fails_policy.AuthenticationErrorAllowedFails if ( isinstance(exception, litellm.Timeout) and allowed_fails_policy.TimeoutErrorAllowedFails is not None ): return allowed_fails_policy.TimeoutErrorAllowedFails if ( isinstance(exception, litellm.RateLimitError) and allowed_fails_policy.RateLimitErrorAllowedFails is not None ): return allowed_fails_policy.RateLimitErrorAllowedFails if ( isinstance(exception, litellm.ContentPolicyViolationError) and allowed_fails_policy.ContentPolicyViolationErrorAllowedFails is not None ): return allowed_fails_policy.ContentPolicyViolationErrorAllowedFails def _initialize_alerting(self): from litellm.integrations.slack_alerting import SlackAlerting router_alerting_config: AlertingConfig = self.alerting_config _slack_alerting_logger = SlackAlerting( alerting_threshold=router_alerting_config.alerting_threshold, alerting=["slack"], default_webhook_url=router_alerting_config.webhook_url, ) self.slack_alerting_logger = _slack_alerting_logger 
    def _initialize_alerting(self):
        from litellm.integrations.slack_alerting import SlackAlerting

        router_alerting_config: AlertingConfig = self.alerting_config

        _slack_alerting_logger = SlackAlerting(
            alerting_threshold=router_alerting_config.alerting_threshold,
            alerting=["slack"],
            default_webhook_url=router_alerting_config.webhook_url,
        )

        self.slack_alerting_logger = _slack_alerting_logger

        litellm.callbacks.append(_slack_alerting_logger)
        litellm.success_callback.append(
            _slack_alerting_logger.response_taking_too_long_callback
        )
        print("\033[94m\nInitialized Alerting for litellm.Router\033[0m\n")  # noqa

    def send_deployment_cooldown_alert(
        self,
        deployment_id: str,
        exception_status: Union[str, int],
        cooldown_time: float,
    ):
        try:
            from litellm.proxy.proxy_server import proxy_logging_obj

            # trigger slack alert saying deployment is in cooldown
            if (
                proxy_logging_obj is not None
                and proxy_logging_obj.alerting is not None
                and "slack" in proxy_logging_obj.alerting
            ):
                _deployment = self.get_deployment(model_id=deployment_id)
                if _deployment is None:
                    return

                _litellm_params = _deployment["litellm_params"]
                temp_litellm_params = copy.deepcopy(_litellm_params)
                temp_litellm_params = dict(temp_litellm_params)
                _model_name = _deployment.get("model_name", None)
                _api_base = litellm.get_api_base(
                    model=_model_name, optional_params=temp_litellm_params
                )
                # asyncio.create_task(
                #     proxy_logging_obj.slack_alerting_instance.send_alert(
                #         message=f"Router: Cooling down Deployment:\nModel Name: `{_model_name}`\nAPI Base: `{_api_base}`\nCooldown Time: `{cooldown_time} seconds`\nException Status Code: `{str(exception_status)}`\n\nChange 'cooldown_time' + 'allowed_fails' under 'Router Settings' on proxy UI, or via config - https://docs.litellm.ai/docs/proxy/reliability#fallbacks--retries--timeouts--cooldowns",
                #         alert_type="cooldown_deployment",
                #         level="Low",
                #     )
                # )
        except Exception:
            # alerting is best-effort; never raise from here
            pass

    def flush_cache(self):
        litellm.cache = None
        self.cache.flush_cache()

    def reset(self):
        ## clean up on close
        litellm.success_callback = []
        litellm._async_success_callback = []
        litellm.failure_callback = []
        litellm._async_failure_callback = []
        self.retry_policy = None
        self.flush_cache()
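    # Illustrative teardown (assumed usage, for exposition only):
    #   router = Router(model_list=[...])
    #   ...serve traffic...
    #   router.reset()  # clears litellm callbacks + flushes the router cache
    # so a fresh Router can be constructed in the same process without stale state.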