From 59ba1560e5cee5a2502947129083f47c0cbf4d7e Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Sat, 25 Nov 2023 19:34:20 -0800
Subject: [PATCH] fix(router.py): fix fallbacks

---
 docs/my-website/docs/simple_proxy.md | 21 ++++++++++++++++++
 litellm/router.py                    | 33 ++++++++++++++++++++--------
 litellm/utils.py                     |  1 +
 3 files changed, 46 insertions(+), 9 deletions(-)

diff --git a/docs/my-website/docs/simple_proxy.md b/docs/my-website/docs/simple_proxy.md
index d5f98a374c..b55f3cfafc 100644
--- a/docs/my-website/docs/simple_proxy.md
+++ b/docs/my-website/docs/simple_proxy.md
@@ -839,6 +839,7 @@ If the error is a context window exceeded error, fall back to a larger model gro
 
 [**See Code**](https://github.com/BerriAI/litellm/blob/main/litellm/router.py)
 
+**Set via config**
 ```yaml
 model_list:
   - model_name: zephyr-beta
@@ -870,6 +871,26 @@ litellm_settings:
   allowed_fails: 3 # cooldown model if it fails > 1 call in a minute.
 ```
 
+**Set dynamically**
+
+```bash
+curl --location 'http://0.0.0.0:8000/chat/completions' \
+--header 'Content-Type: application/json' \
+--data ' {
+  "model": "zephyr-beta",
+  "messages": [
+      {
+          "role": "user",
+          "content": "what llm are you"
+      }
+  ],
+  "fallbacks": [{"zephyr-beta": ["gpt-3.5-turbo"]}],
+  "context_window_fallbacks": [{"zephyr-beta": ["gpt-3.5-turbo"]}],
+  "num_retries": 2,
+  "request_timeout": 10
+  }
+'
+```
 
 ### Config for Embedding Models - xorbitsai/inference
 
diff --git a/litellm/router.py b/litellm/router.py
index c17fcf8774..471c4f857e 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -144,12 +144,13 @@ class Router:
             kwargs["model"] = model
             kwargs["messages"] = messages
             kwargs["original_function"] = self._completion
-            kwargs["num_retries"] = self.num_retries
+            timeout = kwargs.get("request_timeout", self.timeout)
+            kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
             kwargs.setdefault("metadata", {}).update({"model_group": model})
             with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
                 # Submit the function to the executor with a timeout
                 future = executor.submit(self.function_with_fallbacks, **kwargs)
-                response = future.result(timeout=self.timeout) # type: ignore
+                response = future.result(timeout=timeout) # type: ignore
 
             return response
         except Exception as e:
@@ -164,6 +165,7 @@ class Router:
         try:
             # pick the one that is available (lowest TPM/RPM)
             deployment = self.get_available_deployment(model=model, messages=messages)
+            kwargs.setdefault("metadata", {}).update({"deployment": deployment["litellm_params"]["model"]})
             data = deployment["litellm_params"]
             for k, v in self.default_litellm_params.items():
                 if k not in data: # prioritize model-specific params > default router params
@@ -182,9 +184,10 @@ class Router:
             kwargs["model"] = model
             kwargs["messages"] = messages
             kwargs["original_function"] = self._acompletion
-            kwargs["num_retries"] = self.num_retries
+            kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
+            timeout = kwargs.get("request_timeout", self.timeout)
             kwargs.setdefault("metadata", {}).update({"model_group": model})
-            response = await asyncio.wait_for(self.async_function_with_fallbacks(**kwargs), timeout=self.timeout)
+            response = await asyncio.wait_for(self.async_function_with_fallbacks(**kwargs), timeout=timeout)
 
             return response
         except Exception as e:
@@ -196,7 +199,9 @@ class Router:
                            messages: List[Dict[str, str]],
                            **kwargs):
         try:
+            self.print_verbose(f"Inside _acompletion()- model: {model}; kwargs: {kwargs}")
             deployment = self.get_available_deployment(model=model, messages=messages)
+            kwargs.setdefault("metadata", {}).update({"deployment": deployment["litellm_params"]["model"]})
             data = deployment["litellm_params"]
             for k, v in self.default_litellm_params.items():
                 if k not in data: # prioritize model-specific params > default router params
@@ -244,7 +249,7 @@ class Router:
                   **kwargs) -> Union[List[float], None]:
         # pick the one that is available (lowest TPM/RPM)
         deployment = self.get_available_deployment(model=model, input=input)
-
+        kwargs.setdefault("metadata", {}).update({"deployment": deployment["litellm_params"]["model"]})
         data = deployment["litellm_params"]
         for k, v in self.default_litellm_params.items():
             if k not in data: # prioritize model-specific params > default router params
@@ -259,7 +264,7 @@ class Router:
                          **kwargs) -> Union[List[float], None]:
         # pick the one that is available (lowest TPM/RPM)
         deployment = self.get_available_deployment(model=model, input=input)
-
+        kwargs.setdefault("metadata", {}).update({"deployment": deployment["litellm_params"]["model"]})
         data = deployment["litellm_params"]
         for k, v in self.default_litellm_params.items():
             if k not in data: # prioritize model-specific params > default router params
@@ -315,10 +320,11 @@ class Router:
                         """
                         try:
                             kwargs["model"] = mg
+                            kwargs["metadata"]["model_group"] = mg
                             response = await self.async_function_with_retries(*args, **kwargs)
                             return response
                         except Exception as e:
-                            pass
+                            raise e
             except Exception as e:
                 self.print_verbose(f"An exception occurred - {str(e)}")
                 traceback.print_exc()
@@ -328,12 +334,14 @@ class Router:
         self.print_verbose(f"Inside async function with retries: args - {args}; kwargs - {kwargs}")
         backoff_factor = 1
         original_function = kwargs.pop("original_function")
+        self.print_verbose(f"async function w/ retries: original_function - {original_function}")
         num_retries = kwargs.pop("num_retries")
         try:
             # if the function call is successful, no exception will be raised and we'll break out of the loop
             response = await original_function(*args, **kwargs)
             return response
         except Exception as e:
+            original_exception = e
             for current_attempt in range(num_retries):
                 self.print_verbose(f"retrying request. Current attempt - {current_attempt}; num retries: {num_retries}")
                 try:
@@ -359,7 +367,7 @@ class Router:
                         pass
                     else:
                         raise e
-        raise e
+        raise original_exception
 
     def function_with_fallbacks(self, *args, **kwargs):
         """
@@ -499,10 +507,13 @@ class Router:
         try:
             model_name = kwargs.get('model', None) # i.e. gpt35turbo
             custom_llm_provider = kwargs.get("litellm_params", {}).get('custom_llm_provider', None) # i.e. azure
+            metadata = kwargs.get("litellm_params", {}).get('metadata', None)
+            if metadata:
+                deployment = metadata.get("deployment", None)
+                self._set_cooldown_deployments(deployment)
             if custom_llm_provider:
                 model_name = f"{custom_llm_provider}/{model_name}"
-            self._set_cooldown_deployments(model_name)
         except Exception as e:
             raise e
@@ -731,6 +742,10 @@ class Router:
         ### get all deployments
         ### filter out the deployments currently cooling down
         healthy_deployments = [m for m in self.model_list if m["model_name"] == model]
+        if len(healthy_deployments) == 0:
+            # check if the user sent in a deployment name instead
+            healthy_deployments = [m for m in self.model_list if m["litellm_params"]["model"] == model]
+        self.print_verbose(f"initial list of deployments: {healthy_deployments}")
         deployments_to_remove = []
         cooldown_deployments = self._get_cooldown_deployments()
         self.print_verbose(f"cooldown deployments: {cooldown_deployments}")
diff --git a/litellm/utils.py b/litellm/utils.py
index 04b8c4c5fe..45e8f59d75 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -680,6 +680,7 @@ class Logging:
             print_verbose(
                 f"Logging Details Post-API Call: logger_fn - {self.logger_fn} | callable(logger_fn) - {callable(self.logger_fn)}"
             )
+            print_verbose(f"Logging Details Post-API Call: LiteLLM Params: {self.model_call_details}")
             if self.logger_fn and callable(self.logger_fn):
                 try:
                     self.logger_fn(
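Note (not part of the upstream patch): a minimal Python sketch of the per-request overrides this change reads out of kwargs in Router.completion() / Router.acompletion(). The model names, api_base, and the router-level fallbacks argument below are illustrative placeholders under the assumption that the Router constructor accepts fallbacks as the proxy config above does; this is not code taken from the commit.

```python
# Illustrative sketch only - placeholder model names / endpoints, not from this commit.
from litellm import Router

model_list = [
    {
        "model_name": "zephyr-beta",  # model group name used by the router
        "litellm_params": {
            "model": "huggingface/HuggingFaceH4/zephyr-7b-beta",
            "api_base": "http://0.0.0.0:8001",  # assumed local endpoint
        },
    },
    {
        "model_name": "gpt-3.5-turbo",  # fallback model group
        "litellm_params": {"model": "gpt-3.5-turbo"},
    },
]

router = Router(
    model_list=model_list,
    fallbacks=[{"zephyr-beta": ["gpt-3.5-turbo"]}],  # router-level fallbacks
)

# Per-call num_retries / request_timeout take precedence over the Router defaults,
# mirroring the kwargs.get(...) lookups this patch adds in completion()/acompletion().
response = router.completion(
    model="zephyr-beta",
    messages=[{"role": "user", "content": "what llm are you"}],
    num_retries=2,
    request_timeout=10,
)
print(response)
```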