From 56e9047818df2b99328b915caccfffda6cb6865d Mon Sep 17 00:00:00 2001
From: Krish Dholakia
Date: Tue, 29 Oct 2024 22:05:41 -0700
Subject: [PATCH] Litellm router max depth (#6501)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* feat(router.py): add check for max fallback depth

Prevent infinite loop for fallbacks

Closes https://github.com/BerriAI/litellm/issues/6498

* test: update test

* (fix) Prometheus - Log Postgres DB latency, status on prometheus (#6484)

* fix logging DB fails on prometheus

* unit testing log to otel wrapper

* unit testing for service logger + prometheus

* use LATENCY buckets for service logging

* fix service logging

* docs clarify vertex vs gemini

* (router_strategy/) ensure all async functions use async cache methods (#6489)

* fix router strat

* use async set / get cache in router_strategy

* add coverage for router strategy

* fix imports

* fix batch_get_cache

* use async methods for least busy

* fix least busy use async methods

* fix test_dual_cache_increment

* test async_get_available_deployment when routing_strategy="least-busy"

* (fix) proxy - fix when `STORE_MODEL_IN_DB` should be set (#6492)

* set store_model_in_db at the top

* correctly use store_model_in_db global

* (fix) `PrometheusServicesLogger` `_get_metric` should return metric in Registry (#6486)

* fix logging DB fails on prometheus

* unit testing log to otel wrapper

* unit testing for service logger + prometheus

* use LATENCY buckets for service logging

* fix service logging

* fix _get_metric in prom services logger

* add clear doc string

* unit testing for prom service logger

* bump: version 1.51.0 → 1.51.1

* Add `azure/gpt-4o-mini-2024-07-18` to model_prices_and_context_window.json (#6477)

* Update utils.py (#6468)

Fixed missing keys

* (perf) Litellm redis router fix - ~100ms improvement (#6483)

* docs(exception_mapping.md): add missing exception types

Fixes https://github.com/Aider-AI/aider/issues/2120#issuecomment-2438971183

* fix(main.py): register custom model pricing with specific key

Ensure custom model pricing is registered to the specific model+provider key combination

* test: make testing more robust for custom pricing

* fix(redis_cache.py): instrument otel logging for sync redis calls

ensures complete coverage for all redis cache calls

* refactor: pass parent_otel_span for redis caching calls in router

allows for more observability into which calls are causing latency issues

* test: update tests with new params

* refactor: ensure e2e otel tracing for router

* refactor(router.py): add more otel tracing across router

catch all latency issues for router requests

* fix: fix linting error

* fix(router.py): fix linting error

* fix: fix test

* test: fix tests

* fix(dual_cache.py): pass ttl to redis cache

* fix: fix param

* perf(cooldown_cache.py): improve cooldown cache to store cache results in memory for 5s, preventing a redis call from being made on each request

reduces 100ms latency per call with caching enabled on the router

* fix: fix test

* fix(cooldown_cache.py): handle if a result is None

* fix(cooldown_cache.py): add debug statements

* refactor(dual_cache.py): move to using an in-memory check for batch get cache, to prevent redis from being hit for every call

* fix(cooldown_cache.py): fix linting error

* build: merge main

---------

Co-authored-by: Ishaan Jaff
Co-authored-by: Xingyao Wang
Co-authored-by: vibhanshu-ob <115142120+vibhanshu-ob@users.noreply.github.com>
---
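A minimal usage sketch of the new fallback-depth cap (notes only, not part of the diff below).
The model names and keys are the placeholder values used in the new test; when max_fallbacks is
not passed, the Router falls back to litellm.max_fallbacks and then to
litellm.ROUTER_MAX_FALLBACKS (5):

    import litellm
    from litellm import Router

    litellm.max_fallbacks = 2  # optional process-wide default; the Router argument takes precedence

    router = Router(
        model_list=[
            {
                "model_name": "bad-model",
                "litellm_params": {"model": "openai/my-bad-model", "api_key": "my-bad-api-key"},
            },
            {
                "model_name": "my-bad-model-2",
                "litellm_params": {"model": "gpt-4o", "api_key": "bad-key"},
            },
        ],
        # "bad-model" -> "my-bad-model-2" -> default fallback "bad-model" forms a cycle;
        # before this patch such a config could recurse indefinitely.
        fallbacks=[{"bad-model": ["my-bad-model-2"]}],
        default_fallbacks=["bad-model"],
        max_fallbacks=2,  # per-router cap on fallback depth
    )

    # Once fallback_depth reaches max_fallbacks, run_async_fallback re-raises the original
    # exception, so a call like the one below surfaces the underlying AuthenticationError
    # instead of looping between the two model groups.
    # router.completion(model="bad-model", messages=[{"role": "user", "content": "Hi"}])
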
 litellm/__init__.py                           |   3 +-
 litellm/constants.py                          |   1 +
 litellm/main.py                               | 104 +--------
 litellm/proxy/_new_secret_config.yaml         |   4 +
 litellm/proxy/auth/user_api_key_auth.py       |   7 +-
 .../example_config_yaml/custom_handler.py     |   6 +-
 litellm/router.py                             | 202 +++++++-----------
 .../router_utils/fallback_event_handlers.py   |  11 +-
 litellm/types/utils.py                        |   3 +
 .../test_router_fallback_handlers.py          |  12 +-
 tests/local_testing/test_router_fallbacks.py  |  47 +++-
 11 files changed, 165 insertions(+), 235 deletions(-)
 create mode 100644 litellm/constants.py

diff --git a/litellm/__init__.py b/litellm/__init__.py
index a42a8f90d..b7a0bad8e 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -17,7 +17,7 @@ from litellm._logging import (
     _turn_on_json,
     log_level,
 )
-
+from litellm.constants import ROUTER_MAX_FALLBACKS
 from litellm.types.guardrails import GuardrailItem
 from litellm.proxy._types import (
     KeyManagementSystem,
@@ -284,6 +284,7 @@ request_timeout: float = 6000  # time in seconds
 module_level_aclient = AsyncHTTPHandler(timeout=request_timeout)
 module_level_client = HTTPHandler(timeout=request_timeout)
 num_retries: Optional[int] = None  # per model endpoint
+max_fallbacks: Optional[int] = None
 default_fallbacks: Optional[List] = None
 fallbacks: Optional[List] = None
 context_window_fallbacks: Optional[List] = None
diff --git a/litellm/constants.py b/litellm/constants.py
new file mode 100644
index 000000000..8d27cf564
--- /dev/null
+++ b/litellm/constants.py
@@ -0,0 +1 @@
+ROUTER_MAX_FALLBACKS = 5
diff --git a/litellm/main.py b/litellm/main.py
index 30ff47e88..34e4ae5bb 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -3236,62 +3236,10 @@ def embedding(  # noqa: PLR0915
         "encoding_format",
     ]
     litellm_params = [
-        "metadata",
         "aembedding",
-        "caching",
-        "mock_response",
-        "api_key",
-        "api_version",
-        "api_base",
-        "force_timeout",
-        "logger_fn",
-        "verbose",
-        "custom_llm_provider",
-        "litellm_logging_obj",
-        "litellm_call_id",
-        "use_client",
-        "id",
-        "fallbacks",
-        "azure",
-        "headers",
-        "model_list",
-        "num_retries",
-        "context_window_fallback_dict",
-        "retry_policy",
-        "roles",
-        "final_prompt_value",
-        "bos_token",
-        "eos_token",
-        "request_timeout",
-        "complete_response",
-        "self",
-        "client",
-        "rpm",
-        "tpm",
-        "max_parallel_requests",
-        "input_cost_per_token",
-        "output_cost_per_token",
-        "input_cost_per_second",
-        "output_cost_per_second",
-        "hf_model_name",
-        "proxy_server_request",
-        "model_info",
-        "preset_cache_key",
-        "caching_groups",
-        "ttl",
-        "cache",
-        "no-log",
-        "region_name",
-        "allowed_model_region",
-        "model_config",
-        "cooldown_time",
-        "tags",
-        "azure_ad_token_provider",
-        "tenant_id",
-        "client_id",
-        "client_secret",
         "extra_headers",
-    ]
+    ] + all_litellm_params
+
     default_params = openai_params + litellm_params
     non_default_params = {
         k: v for k, v in kwargs.items() if k not in default_params
@@ -4489,53 +4437,7 @@ def image_generation(  # noqa: PLR0915
         "size",
         "style",
     ]
-    litellm_params = [
-        "metadata",
-        "aimg_generation",
-        "caching",
-        "mock_response",
-        "api_key",
-        "api_version",
-        "api_base",
-        "force_timeout",
-        "logger_fn",
-        "verbose",
-        "custom_llm_provider",
-        "litellm_logging_obj",
-        "litellm_call_id",
-        "use_client",
-        "id",
-        "fallbacks",
-        "azure",
-        "headers",
-        "model_list",
-        "num_retries",
-        "context_window_fallback_dict",
-        "retry_policy",
-        "roles",
-        "final_prompt_value",
-        "bos_token",
-        "eos_token",
-        "request_timeout",
-        "complete_response",
-        "self",
-        "client",
-        "rpm",
-        "tpm",
-        "max_parallel_requests",
-        "input_cost_per_token",
-        "output_cost_per_token",
-        "hf_model_name",
-        "proxy_server_request",
-        "model_info",
-        "preset_cache_key",
-        "caching_groups",
-        "ttl",
-        "cache",
-        "region_name",
-        "allowed_model_region",
-        "model_config",
-    ]
+    litellm_params = all_litellm_params
     default_params = openai_params + litellm_params
     non_default_params = {
         k: v for k, v in kwargs.items() if k not in default_params
diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml
index 69a1119cc..ca198d7a3 100644
--- a/litellm/proxy/_new_secret_config.yaml
+++ b/litellm/proxy/_new_secret_config.yaml
@@ -10,6 +10,10 @@ model_list:
       output_cost_per_token: 0.000015 # 15$/M
       api_base: "https://exampleopenaiendpoint-production.up.railway.app"
       api_key: my-fake-key
+  - model_name: my-custom-model
+    litellm_params:
+      model: my-custom-llm/my-custom-model
+      api_key: my-fake-key

 litellm_settings:
   fallbacks: [{ "claude-3-5-sonnet-20240620": ["claude-3-5-sonnet-aihubmix"] }]
diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py
index a8cc9193e..f6c3de22c 100644
--- a/litellm/proxy/auth/user_api_key_auth.py
+++ b/litellm/proxy/auth/user_api_key_auth.py
@@ -152,7 +152,7 @@ def _is_api_route_allowed(
     _user_role = _get_user_role(user_obj=user_obj)

     if valid_token is None:
-        raise Exception("Invalid proxy server token passed")
+        raise Exception("Invalid proxy server token passed. valid_token=None.")

     if not _is_user_proxy_admin(user_obj=user_obj):  # if non-admin
         RouteChecks.non_proxy_admin_allowed_routes_check(
@@ -769,6 +769,11 @@ async def user_api_key_auth(  # noqa: PLR0915
             )

         except Exception:
+            verbose_logger.info(
+                "litellm.proxy.auth.user_api_key_auth.py::user_api_key_auth() - Unable to find token={} in cache or `LiteLLM_VerificationTokenTable`. Defaulting 'valid_token' to None'".format(
+                    api_key
+                )
+            )
             valid_token = None

         user_obj: Optional[LiteLLM_UserTable] = None
diff --git a/litellm/proxy/example_config_yaml/custom_handler.py b/litellm/proxy/example_config_yaml/custom_handler.py
index 56943c34d..fdde975d6 100644
--- a/litellm/proxy/example_config_yaml/custom_handler.py
+++ b/litellm/proxy/example_config_yaml/custom_handler.py
@@ -1,5 +1,9 @@
+import time
+from typing import Any, Optional
+
 import litellm
-from litellm import CustomLLM, completion, get_llm_provider
+from litellm import CustomLLM, ImageObject, ImageResponse, completion, get_llm_provider
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler


 class MyCustomLLM(CustomLLM):
diff --git a/litellm/router.py b/litellm/router.py
index e53c1a8a9..e60f05d84 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -189,6 +189,9 @@ class Router:
         default_priority: Optional[int] = None,
         ## RELIABILITY ##
         num_retries: Optional[int] = None,
+        max_fallbacks: Optional[
+            int
+        ] = None,  # max fallbacks to try before exiting the call. Defaults to 5.
         timeout: Optional[float] = None,
         default_litellm_params: Optional[
             dict
@@ -410,6 +413,13 @@ class Router:
         else:
             self.num_retries = openai.DEFAULT_MAX_RETRIES

+        if max_fallbacks is not None:
+            self.max_fallbacks = max_fallbacks
+        elif litellm.max_fallbacks is not None:
+            self.max_fallbacks = litellm.max_fallbacks
+        else:
+            self.max_fallbacks = litellm.ROUTER_MAX_FALLBACKS
+
         self.timeout = timeout or litellm.request_timeout

         self.retry_after = retry_after
@@ -2672,8 +2682,19 @@ class Router:
             if original_model_group is None:
                 raise e

+            input_kwargs = {
+                "litellm_router": self,
+                "original_exception": original_exception,
+                **kwargs,
+            }
+
+            if "max_fallbacks" not in input_kwargs:
+                input_kwargs["max_fallbacks"] = self.max_fallbacks
+            if "fallback_depth" not in input_kwargs:
+                input_kwargs["fallback_depth"] = 0
+
             try:
-                verbose_router_logger.debug("Trying to fallback b/w models")
+                verbose_router_logger.info("Trying to fallback b/w models")
                 if isinstance(e, litellm.ContextWindowExceededError):
                     if context_window_fallbacks is not None:
                         fallback_model_group: Optional[List[str]] = (
@@ -2685,13 +2706,16 @@ class Router:

                         if fallback_model_group is None:
                             raise original_exception
+                        input_kwargs.update(
+                            {
+                                "fallback_model_group": fallback_model_group,
+                                "original_model_group": original_model_group,
+                            }
+                        )
+
                         response = await run_async_fallback(
                             *args,
-                            litellm_router=self,
-                            fallback_model_group=fallback_model_group,
-                            original_model_group=original_model_group,
-                            original_exception=original_exception,
-                            **kwargs,
+                            **input_kwargs,
                         )
                         return response
@@ -2718,13 +2742,16 @@ class Router:

                         if fallback_model_group is None:
                             raise original_exception
+                        input_kwargs.update(
+                            {
+                                "fallback_model_group": fallback_model_group,
+                                "original_model_group": original_model_group,
+                            }
+                        )
+
                         response = await run_async_fallback(
                             *args,
-                            litellm_router=self,
-                            fallback_model_group=fallback_model_group,
-                            original_model_group=original_model_group,
-                            original_exception=original_exception,
-                            **kwargs,
+                            **input_kwargs,
                         )
                         return response
                     else:
@@ -2767,13 +2794,16 @@ class Router:
                        original_exception.message += f"No fallback model group found for original model_group={model_group}. Fallbacks={fallbacks}"  # type: ignore
                        raise original_exception
+                    input_kwargs.update(
+                        {
+                            "fallback_model_group": fallback_model_group,
+                            "original_model_group": original_model_group,
+                        }
+                    )
+
                     response = await run_async_fallback(
                         *args,
-                        litellm_router=self,
-                        fallback_model_group=fallback_model_group,
-                        original_model_group=original_model_group,
-                        original_exception=original_exception,
-                        **kwargs,
+                        **input_kwargs,
                     )
                     return response
         except Exception as new_exception:
@@ -2982,7 +3012,9 @@ class Router:
         Handler for making a call to the .completion()/.embeddings()/etc. functions.
         """
         model_group = kwargs.get("model")
-        response = await original_function(*args, **kwargs)
+        response = original_function(*args, **kwargs)
+        if inspect.iscoroutinefunction(response) or inspect.isawaitable(response):
+            response = await response

         ## PROCESS RESPONSE HEADERS
         await self.set_response_headers(response=response, model_group=model_group)
@@ -3080,120 +3112,38 @@ class Router:

     def function_with_fallbacks(self, *args, **kwargs):
         """
-        Try calling the function_with_retries
-        If it fails after num_retries, fall back to another model group
+        Sync wrapper for async_function_with_fallbacks
+
+        Wrapped to reduce code duplication and prevent bugs.
         """
-        model_group: Optional[str] = kwargs.get("model")
-        fallbacks = kwargs.get("fallbacks", self.fallbacks)
-        context_window_fallbacks = kwargs.get(
-            "context_window_fallbacks", self.context_window_fallbacks
-        )
-        content_policy_fallbacks = kwargs.get(
-            "content_policy_fallbacks", self.content_policy_fallbacks
-        )
+        import threading
+        from concurrent.futures import ThreadPoolExecutor
+
+        def run_in_new_loop():
+            """Run the coroutine in a new event loop within this thread."""
+            new_loop = asyncio.new_event_loop()
+            try:
+                asyncio.set_event_loop(new_loop)
+                return new_loop.run_until_complete(
+                    self.async_function_with_fallbacks(*args, **kwargs)
+                )
+            finally:
+                new_loop.close()
+                asyncio.set_event_loop(None)

         try:
-            self._handle_mock_testing_fallbacks(
-                kwargs=kwargs,
-                model_group=model_group,
-                fallbacks=fallbacks,
-                context_window_fallbacks=context_window_fallbacks,
-                content_policy_fallbacks=content_policy_fallbacks,
-            )
-            response = self.function_with_retries(*args, **kwargs)
-            return response
-        except Exception as e:
-            original_exception = e
-            original_model_group: Optional[str] = kwargs.get("model")
-            verbose_router_logger.debug(f"An exception occurs {original_exception}")
+            # First, try to get the current event loop
+            loop = asyncio.get_running_loop()

-            if original_model_group is None:
-                raise e
+            # If we're already in an event loop, run in a separate thread
+            # to avoid nested event loop issues
+            with ThreadPoolExecutor(max_workers=1) as executor:
+                future = executor.submit(run_in_new_loop)
+                return future.result()

-            try:
-                verbose_router_logger.debug(
-                    f"Trying to fallback b/w models. Initial model group: {model_group}"
-                )
-                if (
-                    isinstance(e, litellm.ContextWindowExceededError)
-                    and context_window_fallbacks is not None
-                ):
-                    fallback_model_group = None
-
-                    fallback_model_group: Optional[List[str]] = (
-                        self._get_fallback_model_group_from_fallbacks(
-                            fallbacks=context_window_fallbacks,
-                            model_group=model_group,
-                        )
-                    )
-
-                    if fallback_model_group is None:
-                        raise original_exception
-
-                    return run_sync_fallback(
-                        *args,
-                        litellm_router=self,
-                        fallback_model_group=fallback_model_group,
-                        original_model_group=original_model_group,
-                        original_exception=original_exception,
-                        **kwargs,
-                    )
-                elif (
-                    isinstance(e, litellm.ContentPolicyViolationError)
-                    and content_policy_fallbacks is not None
-                ):
-                    fallback_model_group: Optional[List[str]] = (
-                        self._get_fallback_model_group_from_fallbacks(
-                            fallbacks=content_policy_fallbacks,
-                            model_group=model_group,
-                        )
-                    )
-
-                    if fallback_model_group is None:
-                        raise original_exception
-
-                    return run_sync_fallback(
-                        *args,
-                        litellm_router=self,
-                        fallback_model_group=fallback_model_group,
-                        original_model_group=original_model_group,
-                        original_exception=original_exception,
-                        **kwargs,
-                    )
-                elif fallbacks is not None:
-                    verbose_router_logger.debug(f"inside model fallbacks: {fallbacks}")
-                    fallback_model_group = None
-                    generic_fallback_idx: Optional[int] = None
-                    for idx, item in enumerate(fallbacks):
-                        if isinstance(item, dict):
-                            if list(item.keys())[0] == model_group:
-                                fallback_model_group = item[model_group]
-                                break
-                            elif list(item.keys())[0] == "*":
-                                generic_fallback_idx = idx
-                        elif isinstance(item, str):
-                            fallback_model_group = [fallbacks.pop(idx)]
-                    ## if none, check for generic fallback
-                    if (
-                        fallback_model_group is None
-                        and generic_fallback_idx is not None
-                    ):
-                        fallback_model_group = fallbacks[generic_fallback_idx]["*"]
-
-                    if fallback_model_group is None:
-                        raise original_exception
-
-                    return run_sync_fallback(
-                        *args,
-                        litellm_router=self,
-                        fallback_model_group=fallback_model_group,
-                        original_model_group=original_model_group,
-                        original_exception=original_exception,
-                        **kwargs,
-                    )
-            except Exception as e:
-                raise e
-            raise original_exception
+        except RuntimeError:
+            # No running event loop, we can safely run in this thread
+            return run_in_new_loop()

     def _get_fallback_model_group_from_fallbacks(
         self,
diff --git a/litellm/router_utils/fallback_event_handlers.py b/litellm/router_utils/fallback_event_handlers.py
index 60cc150a2..5d027e597 100644
--- a/litellm/router_utils/fallback_event_handlers.py
+++ b/litellm/router_utils/fallback_event_handlers.py
@@ -14,11 +14,13 @@ else:


 async def run_async_fallback(
-    litellm_router: LitellmRouter,
     *args: Tuple[Any],
+    litellm_router: LitellmRouter,
     fallback_model_group: List[str],
     original_model_group: str,
     original_exception: Exception,
+    max_fallbacks: int,
+    fallback_depth: int,
     **kwargs,
 ) -> Any:
     """
@@ -41,6 +43,11 @@ async def run_async_fallback(
     Raises:
         The most recent exception if all fallback model groups fail.
     """
+
+    ### BASE CASE ### MAX FALLBACK DEPTH REACHED
+    if fallback_depth >= max_fallbacks:
+        raise original_exception
+
     error_from_fallbacks = original_exception
     for mg in fallback_model_group:
         if mg == original_model_group:
@@ -53,6 +60,8 @@ async def run_async_fallback(
             kwargs.setdefault("metadata", {}).update(
                 {"model_group": mg}
             )  # update model_group used, if fallbacks are done
+            kwargs["fallback_depth"] = fallback_depth + 1
+            kwargs["max_fallbacks"] = max_fallbacks
             response = await litellm_router.async_function_with_fallbacks(
                 *args, **kwargs
             )
diff --git a/litellm/types/utils.py b/litellm/types/utils.py
index 6658eb330..09f70e9b3 100644
--- a/litellm/types/utils.py
+++ b/litellm/types/utils.py
@@ -1292,6 +1292,7 @@ all_litellm_params = [
     "metadata",
     "tags",
     "acompletion",
+    "aimg_generation",
     "atext_completion",
     "text_completion",
     "caching",
@@ -1357,6 +1358,8 @@ all_litellm_params = [
     "ensure_alternating_roles",
     "assistant_continue_message",
     "user_continue_message",
+    "fallback_depth",
+    "max_fallbacks",
 ]


diff --git a/tests/local_testing/test_router_fallback_handlers.py b/tests/local_testing/test_router_fallback_handlers.py
index 18fb6ba3f..bd021cd3f 100644
--- a/tests/local_testing/test_router_fallback_handlers.py
+++ b/tests/local_testing/test_router_fallback_handlers.py
@@ -88,12 +88,14 @@ async def test_run_async_fallback(original_function):
     request_kwargs["messages"] = [{"role": "user", "content": "Hello, world!"}]

     result = await run_async_fallback(
-        router,
+        litellm_router=router,
         original_function=original_function,
         num_retries=1,
         fallback_model_group=fallback_model_group,
         original_model_group=original_model_group,
         original_exception=original_exception,
+        max_fallbacks=5,
+        fallback_depth=0,
         **request_kwargs
     )

@@ -264,13 +266,15 @@ async def test_failed_fallbacks_raise_most_recent_exception(original_function):

     with pytest.raises(litellm.exceptions.RateLimitError):
         await run_async_fallback(
-            router,
+            litellm_router=router,
             original_function=original_function,
             num_retries=1,
             fallback_model_group=fallback_model_group,
             original_model_group=original_model_group,
             original_exception=original_exception,
             mock_response="litellm.RateLimitError",
+            max_fallbacks=5,
+            fallback_depth=0,
             **request_kwargs
         )

@@ -332,12 +336,14 @@ async def test_multiple_fallbacks(original_function):
     request_kwargs["messages"] = [{"role": "user", "content": "Hello, world!"}]

     result = await run_async_fallback(
-        router_2,
+        litellm_router=router_2,
         original_function=original_function,
         num_retries=1,
         fallback_model_group=fallback_model_group,
         original_model_group=original_model_group,
         original_exception=original_exception,
+        max_fallbacks=5,
+        fallback_depth=0,
         **request_kwargs
     )

diff --git a/tests/local_testing/test_router_fallbacks.py b/tests/local_testing/test_router_fallbacks.py
index 32bf0f92f..96983003a 100644
--- a/tests/local_testing/test_router_fallbacks.py
+++ b/tests/local_testing/test_router_fallbacks.py
@@ -1045,7 +1045,7 @@ async def test_default_model_fallbacks(sync_mode, litellm_module_fallbacks):
             },
         ],
         default_fallbacks=(
-            ["my-good-model"] if litellm_module_fallbacks == False else None
+            ["my-good-model"] if litellm_module_fallbacks is False else None
         ),
     )

@@ -1398,3 +1398,48 @@ def test_router_fallbacks_with_custom_model_costs():

     assert model_info["input_cost_per_token"] == 30
     assert model_info["output_cost_per_token"] == 60
+
+
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_router_fallbacks_default_and_model_specific_fallbacks(sync_mode):
+    """
+    Tests to ensure there is not an infinite fallback loop when there is a default fallback and model specific fallback.
+    """
+    router = Router(
+        model_list=[
+            {
+                "model_name": "bad-model",
+                "litellm_params": {
+                    "model": "openai/my-bad-model",
+                    "api_key": "my-bad-api-key",
+                },
+            },
+            {
+                "model_name": "my-bad-model-2",
+                "litellm_params": {
+                    "model": "gpt-4o",
+                    "api_key": "bad-key",
+                },
+            },
+        ],
+        fallbacks=[{"bad-model": ["my-bad-model-2"]}],
+        default_fallbacks=["bad-model"],
+    )
+
+    with pytest.raises(Exception) as exc_info:
+        if sync_mode:
+            resp = router.completion(
+                model="bad-model",
+                messages=[{"role": "user", "content": "Hey, how's it going?"}],
+            )
+
+            print(f"resp: {resp}")
+        else:
+            await router.acompletion(
+                model="bad-model",
+                messages=[{"role": "user", "content": "Hey, how's it going?"}],
+            )
+    assert isinstance(
+        exc_info.value, litellm.AuthenticationError
+    ), f"Expected AuthenticationError, but got {type(exc_info.value).__name__}"