From 22e26fcc4b156d4e89730e2363743be97cdfa33d Mon Sep 17 00:00:00 2001
From: ishaan-jaff
Date: Tue, 23 Jan 2024 08:03:29 -0800
Subject: [PATCH] (fix) revert router.py to stable version

---
 litellm/router.py | 99 +++++++++++------------------------------------
 1 file changed, 23 insertions(+), 76 deletions(-)

diff --git a/litellm/router.py b/litellm/router.py
index f506ff832f..dd6303a948 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -96,13 +96,10 @@ class Router:
         set_verbose: bool = False,
         debug_level: Literal["DEBUG", "INFO"] = "INFO",
         fallbacks: List = [],
+        allowed_fails: Optional[int] = None,
         context_window_fallbacks: List = [],
         model_group_alias: Optional[dict] = {},
         retry_after: int = 0,  # min time to wait before retrying a failed request
-        allowed_fails: Optional[
-            int
-        ] = None,  # Number of times a deployment can failbefore being added to cooldown
-        cooldown_time: float = 1,  # (seconds) time to cooldown a deployment after failure
         routing_strategy: Literal[
             "simple-shuffle",
             "least-busy",
@@ -111,36 +108,6 @@ class Router:
         ] = "simple-shuffle",
         routing_strategy_args: dict = {},  # just for latency-based routing
     ) -> None:
-        """
-        Initialize the Router class with the given parameters for caching, reliability, and routing strategy.
-
-        Args:
-            model_list (Optional[list]): List of models to be used. Defaults to None.
-            redis_url (Optional[str]): URL of the Redis server. Defaults to None.
-            redis_host (Optional[str]): Hostname of the Redis server. Defaults to None.
-            redis_port (Optional[int]): Port of the Redis server. Defaults to None.
-            redis_password (Optional[str]): Password of the Redis server. Defaults to None.
-            cache_responses (Optional[bool]): Flag to enable caching of responses. Defaults to False.
-            cache_kwargs (dict): Additional kwargs to pass to RedisCache. Defaults to {}.
-            caching_groups (Optional[List[tuple]]): List of model groups for caching across model groups. Defaults to None.
-            client_ttl (int): Time-to-live for cached clients in seconds. Defaults to 3600.
-            num_retries (int): Number of retries for failed requests. Defaults to 0.
-            timeout (Optional[float]): Timeout for requests. Defaults to None.
-            default_litellm_params (dict): Default parameters for Router.chat.completion.create. Defaults to {}.
-            set_verbose (bool): Flag to set verbose mode. Defaults to False.
-            debug_level (Literal["DEBUG", "INFO"]): Debug level for logging. Defaults to "INFO".
-            fallbacks (List): List of fallback options. Defaults to [].
-            context_window_fallbacks (List): List of context window fallback options. Defaults to [].
-            model_group_alias (Optional[dict]): Alias for model groups. Defaults to {}.
-            retry_after (int): Minimum time to wait before retrying a failed request. Defaults to 0.
-            allowed_fails (Optional[int]): Number of allowed fails before adding to cooldown. Defaults to None.
-            cooldown_time (float): Time to cooldown a deployment after failure in seconds. Defaults to 1.
-            routing_strategy (Literal["simple-shuffle", "least-busy", "usage-based-routing", "latency-based-routing"]): Routing strategy. Defaults to "simple-shuffle".
-            routing_strategy_args (dict): Additional args for latency-based routing. Defaults to {}.
-
-        Returns:
-            Router: An instance of the litellm.Router class.
-        """
         self.set_verbose = set_verbose
         if self.set_verbose:
             if debug_level == "INFO":
@@ -196,7 +163,6 @@ class Router:
             self.deployment_latency_map[m["litellm_params"]["model"]] = 0
 
         self.allowed_fails = allowed_fails or litellm.allowed_fails
-        self.cooldown_time = cooldown_time or 1
         self.failed_calls = (
             InMemoryCache()
         )  # cache to track failed call per deployment, if num failed calls within 1 minute > allowed fails, then add it to cooldown
@@ -336,11 +302,11 @@ class Router:
 
             response = litellm.completion(
                 **{
+                    **data,
                     "messages": messages,
                     "caching": self.cache_responses,
                     "client": model_client,
                     **kwargs,
-                    **data,
                 }
             )
             verbose_router_logger.info(
@@ -359,6 +325,7 @@ class Router:
             kwargs["messages"] = messages
             kwargs["original_function"] = self._acompletion
             kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
+            timeout = kwargs.get("request_timeout", self.timeout)
             kwargs.setdefault("metadata", {}).update({"model_group": model})
             response = await self.async_function_with_fallbacks(**kwargs)
 
@@ -406,15 +373,16 @@ class Router:
             else:
                 model_client = potential_model_client
             self.total_calls[model_name] += 1
-            final_data = {
-                "messages": messages,
-                "caching": self.cache_responses,
-                "client": model_client,
-                "timeout": self.timeout,
-                **kwargs,
-                **data,
-            }
-            response = await litellm.acompletion(**final_data)
+            response = await litellm.acompletion(
+                **{
+                    **data,
+                    "messages": messages,
+                    "caching": self.cache_responses,
+                    "client": model_client,
+                    "timeout": self.timeout,
+                    **kwargs,
+                }
+            )
             self.success_calls[model_name] += 1
             verbose_router_logger.info(
                 f"litellm.acompletion(model={model_name})\033[32m 200 OK\033[0m"
@@ -877,9 +845,6 @@ class Router:
                         """
                         try:
                             kwargs["model"] = mg
-                            kwargs.setdefault("metadata", {}).update(
-                                {"model_group": mg}
-                            )  # update model_group used, if fallbacks are done
                             response = await self.async_function_with_retries(
                                 *args, **kwargs
                             )
@@ -908,10 +873,8 @@ class Router:
                                 f"Falling back to model_group = {mg}"
                             )
                             kwargs["model"] = mg
-                            kwargs.setdefault("metadata", {}).update(
-                                {"model_group": mg}
-                            )  # update model_group used, if fallbacks are done
-                            response = await self.async_function_with_fallbacks(
+                            kwargs["metadata"]["model_group"] = mg
+                            response = await self.async_function_with_retries(
                                 *args, **kwargs
                             )
                             return response
@@ -1076,9 +1039,6 @@ class Router:
                         ## LOGGING
                         kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
                         kwargs["model"] = mg
-                        kwargs.setdefault("metadata", {}).update(
-                            {"model_group": mg}
-                        )  # update model_group used, if fallbacks are done
                        response = self.function_with_fallbacks(*args, **kwargs)
                         return response
                     except Exception as e:
@@ -1102,9 +1062,6 @@ class Router:
                         ## LOGGING
                         kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
                         kwargs["model"] = mg
-                        kwargs.setdefault("metadata", {}).update(
-                            {"model_group": mg}
-                        )  # update model_group used, if fallbacks are done
                         response = self.function_with_fallbacks(*args, **kwargs)
                         return response
                 except Exception as e:
@@ -1290,7 +1247,6 @@ class Router:
             verbose_router_logger.debug(
                 f"Attempting to add {deployment} to cooldown list. updated_fails: {updated_fails}; self.allowed_fails: {self.allowed_fails}"
             )
-            cooldown_time = self.cooldown_time or 1
             if updated_fails > self.allowed_fails:
                 # get the current cooldown list for that minute
                 cooldown_key = f"{current_minute}:cooldown_models"  # group cooldown models by minute to reduce number of redis calls
@@ -1304,19 +1260,13 @@ class Router:
                     else:
                         cached_value = cached_value + [deployment]
                     # save updated value
-                    self.cache.set_cache(
-                        value=cached_value, key=cooldown_key, ttl=cooldown_time
-                    )
+                    self.cache.set_cache(value=cached_value, key=cooldown_key, ttl=1)
                 except:
                     cached_value = [deployment]
                     # save updated value
-                    self.cache.set_cache(
-                        value=cached_value, key=cooldown_key, ttl=cooldown_time
-                    )
+                    self.cache.set_cache(value=cached_value, key=cooldown_key, ttl=1)
         else:
-            self.failed_calls.set_cache(
-                key=deployment, value=updated_fails, ttl=cooldown_time
-            )
+            self.failed_calls.set_cache(key=deployment, value=updated_fails, ttl=1)
 
     def _get_cooldown_deployments(self):
         """
@@ -1401,16 +1351,13 @@ class Router:
             ):
                 stream_timeout_env_name = stream_timeout.replace("os.environ/", "")
                 stream_timeout = litellm.get_secret(stream_timeout_env_name)
+                litellm_params["stream_timeout"] = stream_timeout
 
             max_retries = litellm_params.pop("max_retries", 2)
-            if isinstance(max_retries, str):
-                if max_retries.startswith("os.environ/"):
-                    max_retries_env_name = max_retries.replace("os.environ/", "")
-                    max_retries = litellm.get_secret(max_retries_env_name)
-                max_retries = int(max_retries)
-            litellm_params[
-                "max_retries"
-            ] = max_retries  # do this for testing purposes
+            if isinstance(max_retries, str) and max_retries.startswith("os.environ/"):
+                max_retries_env_name = max_retries.replace("os.environ/", "")
+                max_retries = litellm.get_secret(max_retries_env_name)
+            litellm_params["max_retries"] = max_retries
 
             if "azure" in model_name:
                 if api_base is None: