From 8ac03e492f0e1aaae3170ce464a493d2e5bd571e Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Thu, 23 Nov 2023 16:06:39 -0800
Subject: [PATCH] fix(router.py): enable fallbacks for sync completions

---
 litellm/router.py                      | 588 ++++++++++++++++---------
 litellm/tests/test_router_fallbacks.py |  76 ++++
 2 files changed, 457 insertions(+), 207 deletions(-)
 create mode 100644 litellm/tests/test_router_fallbacks.py

diff --git a/litellm/router.py b/litellm/router.py
index 72533e06c2..00e4cc7b2e 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -9,28 +9,47 @@
 from datetime import datetime
 from typing import Dict, List, Optional, Union, Literal
-import random, threading, time
+import random, threading, time, traceback
 import litellm, openai
 from litellm.caching import RedisCache, InMemoryCache, DualCache
 import logging, asyncio
-import inspect
+import inspect, concurrent.futures
 from openai import AsyncOpenAI

 class Router:
     """
     Example usage:
+    ```python
     from litellm import Router
-    model_list = [{
-        "model_name": "gpt-3.5-turbo", # model alias
+    model_list = [
+    {
+        "model_name": "azure-gpt-3.5-turbo", # model alias
         "litellm_params": { # params for litellm completion/embedding call
-            "model": "azure/<your-azure-deployment>",
+            "model": "azure/<your-azure-deployment>",
             "api_key": <your-azure-api-key>,
             "api_version": <your-azure-api-version>,
             "api_base": <your-azure-api-base>
         },
-    }]
+    },
+    {
+        "model_name": "azure-gpt-3.5-turbo", # model alias
+        "litellm_params": { # params for litellm completion/embedding call
+            "model": "azure/<your-azure-deployment>",
+            "api_key": <your-azure-api-key>,
+            "api_version": <your-azure-api-version>,
+            "api_base": <your-azure-api-base>
+        },
+    },
+    {
+        "model_name": "openai-gpt-3.5-turbo", # model alias
+        "litellm_params": { # params for litellm completion/embedding call
+            "model": "gpt-3.5-turbo",
+            "api_key": <your-openai-api-key>,
+        },
+    },
+    ]

-    router = Router(model_list=model_list)
+    router = Router(model_list=model_list, fallbacks=[{"azure-gpt-3.5-turbo": ["openai-gpt-3.5-turbo"]}])
+    ```
     """
     model_names: List = []
     cache_responses: bool = False
@@ -48,6 +67,8 @@ class Router:
                  timeout: float = 600,
                  default_litellm_params = {}, # default params for Router.chat.completion.create
                  set_verbose: bool = False,
+                 fallbacks: List = [],
+                 context_window_fallbacks: List = [],
                  routing_strategy: Literal["simple-shuffle", "least-busy", "usage-based-routing", "latency-based-routing"] = "simple-shuffle") -> None:

         if model_list:
@@ -60,12 +81,19 @@ class Router:
         self.num_retries = num_retries
         self.set_verbose = set_verbose
+        self.timeout = timeout
+        self.routing_strategy = routing_strategy
+        self.fallbacks = fallbacks
+        self.context_window_fallbacks = context_window_fallbacks
+
+        # make Router.chat.completions.create compatible for openai.chat.completions.create
         self.chat = litellm.Chat(params=default_litellm_params)
+        # default litellm args
         self.default_litellm_params = default_litellm_params
         self.default_litellm_params["timeout"] = timeout
-        self.routing_strategy = routing_strategy
+
         ### HEALTH CHECK THREAD ###
         if self.routing_strategy == "least-busy":
             self._start_health_check_thread()
@@ -99,192 +127,7 @@ class Router:
         else:
             litellm.failure_callback = [self.deployment_callback_on_failure]

-    def _start_health_check_thread(self):
-        """
-        Starts a separate thread to perform health checks periodically.
-        """
-        health_check_thread = threading.Thread(target=self._perform_health_checks, daemon=True)
-        health_check_thread.start()
-
-    def _perform_health_checks(self):
-        """
-        Periodically performs health checks on the servers.
-        Updates the list of healthy servers accordingly.
- """ - while True: - self.healthy_deployments = self._health_check() - # Adjust the time interval based on your needs - time.sleep(15) - - def _health_check(self): - """ - Performs a health check on the deployments - Returns the list of healthy deployments - """ - healthy_deployments = [] - for deployment in self.model_list: - litellm_args = deployment["litellm_params"] - try: - start_time = time.time() - litellm.completion(messages=[{"role": "user", "content": ""}], max_tokens=1, **litellm_args) # hit the server with a blank message to see how long it takes to respond - end_time = time.time() - response_time = end_time - start_time - logging.debug(f"response_time: {response_time}") - healthy_deployments.append((deployment, response_time)) - healthy_deployments.sort(key=lambda x: x[1]) - except Exception as e: - pass - return healthy_deployments - - def weighted_shuffle_by_latency(self, items): - # Sort the items by latency - sorted_items = sorted(items, key=lambda x: x[1]) - # Get only the latencies - latencies = [i[1] for i in sorted_items] - # Calculate the sum of all latencies - total_latency = sum(latencies) - # Calculate the weight for each latency (lower latency = higher weight) - weights = [total_latency-latency for latency in latencies] - # Get a weighted random item - if sum(weights) == 0: - chosen_item = random.choice(sorted_items)[0] - else: - chosen_item = random.choices(sorted_items, weights=weights, k=1)[0][0] - return chosen_item - - def set_model_list(self, model_list: list): - self.model_list = model_list - self.model_names = [m["model_name"] for m in model_list] - - def get_model_names(self): - return self.model_names - - def print_verbose(self, print_statement): - if self.set_verbose: - print(f"LiteLLM.Router: {print_statement}") # noqa - - def get_available_deployment(self, - model: str, - messages: Optional[List[Dict[str, str]]] = None, - input: Optional[Union[str, List]] = None): - """ - Returns the deployment based on routing strategy - """ - ## get healthy deployments - ### get all deployments - ### filter out the deployments currently cooling down - healthy_deployments = [m for m in self.model_list if m["model_name"] == model] - deployments_to_remove = [] - cooldown_deployments = self._get_cooldown_deployments() - self.print_verbose(f"cooldown deployments: {cooldown_deployments}") - ### FIND UNHEALTHY DEPLOYMENTS - for deployment in healthy_deployments: - deployment_name = deployment["litellm_params"]["model"] - if deployment_name in cooldown_deployments: - deployments_to_remove.append(deployment) - ### FILTER OUT UNHEALTHY DEPLOYMENTS - for deployment in deployments_to_remove: - healthy_deployments.remove(deployment) - self.print_verbose(f"healthy deployments: {healthy_deployments}") - if litellm.model_alias_map and model in litellm.model_alias_map: - model = litellm.model_alias_map[ - model - ] # update the model to the actual value if an alias has been passed in - if self.routing_strategy == "least-busy": - if len(self.healthy_deployments) > 0: - for item in self.healthy_deployments: - if item[0]["model_name"] == model: # first one in queue will be the one with the most availability - return item[0] - else: - raise ValueError("No models available.") - elif self.routing_strategy == "simple-shuffle": - item = random.choice(healthy_deployments) - return item or item[0] - elif self.routing_strategy == "latency-based-routing": - returned_item = None - lowest_latency = float('inf') - ### shuffles with priority for lowest latency - # items_with_latencies = [('A', 
10), ('B', 20), ('C', 30), ('D', 40)] - items_with_latencies = [] - for item in healthy_deployments: - items_with_latencies.append((item, self.deployment_latency_map[item["litellm_params"]["model"]])) - returned_item = self.weighted_shuffle_by_latency(items_with_latencies) - return returned_item - elif self.routing_strategy == "usage-based-routing": - return self.get_usage_based_available_deployment(model=model, messages=messages, input=input) - - raise ValueError("No models available.") - - def retry_if_rate_limit_error(self, exception): - return isinstance(exception, openai.RateLimitError) - - def retry_if_api_error(self, exception): - return isinstance(exception, openai.APIError) - - async def async_function_with_retries(self, *args, **kwargs): - # we'll backoff exponentially with each retry - backoff_factor = 1 - original_exception = kwargs.pop("original_exception") - original_function = kwargs.pop("original_function") - for current_attempt in range(self.num_retries): - try: - # if the function call is successful, no exception will be raised and we'll break out of the loop - response = await original_function(*args, **kwargs) - if inspect.iscoroutinefunction(response): # async errors are often returned as coroutines - response = await response - return response - - except openai.RateLimitError as e: - # on RateLimitError we'll wait for an exponential time before trying again - await asyncio.sleep(backoff_factor) - - # increase backoff factor for next run - backoff_factor *= 2 - - except openai.APIError as e: - # on APIError we immediately retry without any wait, change this if necessary - pass - - except Exception as e: - # for any other exception types, don't retry - raise e - - def function_with_retries(self, *args, **kwargs): - # we'll backoff exponentially with each retry - self.print_verbose(f"Inside function with retries: args - {args}; kwargs - {kwargs}") - backoff_factor = 1 - original_function = kwargs.pop("original_function") - num_retries = kwargs.pop("num_retries") - try: - # if the function call is successful, no exception will be raised and we'll break out of the loop - response = original_function(*args, **kwargs) - return response - except Exception as e: - for current_attempt in range(num_retries): - num_retries -= 1 # decrement the number of retries - self.print_verbose(f"retrying request. 
-                try:
-                    # if the function call is successful, no exception will be raised and we'll break out of the loop
-                    response = original_function(*args, **kwargs)
-                    return response
-
-                except openai.RateLimitError as e:
-                    if num_retries > 0:
-                        # on RateLimitError we'll wait for an exponential time before trying again
-                        time.sleep(backoff_factor)
-
-                        # increase backoff factor for next run
-                        backoff_factor *= 2
-                    else:
-                        raise e
-
-                except Exception as e:
-                    # for any other exception types, immediately retry
-                    if num_retries > 0:
-                        pass
-                    else:
-                        raise e
-
     ### COMPLETION + EMBEDDING FUNCTIONS

     def completion(self,
@@ -300,7 +143,12 @@ class Router:
             kwargs["messages"] = messages
             kwargs["original_function"] = self._completion
             kwargs["num_retries"] = self.num_retries
-            return self.function_with_retries(**kwargs)
+            with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
+                # Submit the function to the executor with a timeout
+                future = executor.submit(self.function_with_fallbacks, **kwargs)
+                response = future.result(timeout=self.timeout)
+
+            return response
         except Exception as e:
             raise e
@@ -322,18 +170,38 @@ class Router:
             return litellm.completion(**{**data, "messages": messages, "caching": self.cache_responses, **kwargs})
         except Exception as e:
             raise e
+
     async def acompletion(self,
-                    model: str,
-                    messages: List[Dict[str, str]],
-                    is_retry: Optional[bool] = False,
-                    is_fallback: Optional[bool] = False,
-                    **kwargs):
+                          model: str,
+                          messages: List[Dict[str, str]],
+                          **kwargs):
+        try:
+            kwargs["model"] = model
+            kwargs["messages"] = messages
+            kwargs["original_function"] = self._acompletion
+            kwargs["num_retries"] = self.num_retries
+
+            # Use asyncio.timeout to enforce the timeout
+            async with asyncio.timeout(self.timeout):
+                response = await self.async_function_with_fallbacks(**kwargs)
+
+            return response
+        except Exception as e:
+            raise e
+
+    async def _acompletion(
+            self,
+            model: str,
+            messages: List[Dict[str, str]],
+            **kwargs):
         try:
             deployment = self.get_available_deployment(model=model, messages=messages)
             data = deployment["litellm_params"]
             for k, v in self.default_litellm_params.items():
                 if k not in data: # prioritize model-specific params > default router params
                     data[k] = v
+            self.print_verbose(f"acompletion model: {data['model']}")
+
             response = await litellm.acompletion(**{**data, "messages": messages, "caching": self.cache_responses, **kwargs})
             return response
         except Exception as e:
@@ -345,7 +213,7 @@ class Router:
                 return await self.async_function_with_retries(**kwargs)
             else:
                 raise e
-
+
     def text_completion(self,
                         model: str,
                         prompt: str,
@@ -403,6 +271,190 @@ class Router:
                 data[k] = v
         return await litellm.aembedding(**{**data, "input": input, "caching": self.cache_responses, **kwargs})
+
+    async def async_function_with_fallbacks(self, *args, **kwargs):
+        """
+        Try calling the function_with_retries
+        If it fails after num_retries, fall back to another model group
+        """
+        model_group = kwargs.get("model")
+        try:
+            response = await self.async_function_with_retries(*args, **kwargs)
+            self.print_verbose(f'Response: {response}')
+            return response
+        except Exception as e:
+            self.print_verbose(f"An exception occurred")
+            original_exception = e
+            try:
+                self.print_verbose(f"Trying to fall back between models")
+                if isinstance(e, litellm.ContextWindowExceededError):
+                    for item in self.context_window_fallbacks: # [{"gpt-3.5-turbo": ["gpt-4"]}]
+                        if list(item.keys())[0] == model_group:
+                            fallback_model_group = item[model_group]
+                            break
+                    for mg in fallback_model_group:
+                        """
+                        Iterate through the model groups and try calling that deployment
+                        """
+                        try:
+                            kwargs["model"] = mg
+                            response = await self.async_function_with_retries(*args, **kwargs)
+                            return response
+                        except Exception as e:
+                            pass
+                else:
+                    self.print_verbose(f"inside model fallbacks: {self.fallbacks}")
+                    for item in self.fallbacks:
+                        if list(item.keys())[0] == model_group:
+                            fallback_model_group = item[model_group]
+                            break
+                    for mg in fallback_model_group:
+                        """
+                        Iterate through the model groups and try calling that deployment
+                        """
+                        try:
+                            kwargs["model"] = mg
+                            response = await self.async_function_with_retries(*args, **kwargs)
+                            return response
+                        except Exception as e:
+                            pass
+            except Exception as e:
+                self.print_verbose(f"An exception occurred - {str(e)}")
+                traceback.print_exc()
+            raise original_exception
+
+    async def async_function_with_retries(self, *args, **kwargs):
+        self.print_verbose(f"Inside async function with retries: args - {args}; kwargs - {kwargs}")
+        backoff_factor = 1
+        original_function = kwargs.pop("original_function")
+        num_retries = kwargs.pop("num_retries")
+        try:
+            # if the function call is successful, no exception will be raised and we'll break out of the loop
+            response = await original_function(*args, **kwargs)
+            return response
+        except Exception as e:
+            for current_attempt in range(num_retries):
+                num_retries -= 1 # decrement the number of retries
+                self.print_verbose(f"retrying request. Current attempt - {current_attempt}; num retries: {num_retries}")
+                try:
+                    # if the function call is successful, no exception will be raised and we'll break out of the loop
+                    response = await original_function(*args, **kwargs)
+                    if inspect.iscoroutinefunction(response): # async errors are often returned as coroutines
+                        response = await response
+                    return response
+
+                except openai.RateLimitError as e:
+                    if num_retries > 0:
+                        # on RateLimitError we'll wait for an exponential time before trying again
+                        await asyncio.sleep(backoff_factor)
+
+                        # increase backoff factor for next run
+                        backoff_factor *= 2
+                    else:
+                        raise e
+
+                except Exception as e:
+                    # for any other exception types, immediately retry
+                    if num_retries > 0:
+                        pass
+                    else:
+                        raise e
+
+    def function_with_fallbacks(self, *args, **kwargs):
+        """
+        Try calling the function_with_retries
+        If it fails after num_retries, fall back to another model group
+        """
+        model_group = kwargs.get("model")
+        try:
+            response = self.function_with_retries(*args, **kwargs)
+            self.print_verbose(f'Response: {response}')
+            return response
+        except Exception as e:
+            self.print_verbose(f"An exception occurred")
+            original_exception = e
+            try:
+                self.print_verbose(f"Trying to fall back between models")
+                if isinstance(e, litellm.ContextWindowExceededError):
+                    for item in self.context_window_fallbacks: # [{"gpt-3.5-turbo": ["gpt-4"]}]
+                        if list(item.keys())[0] == model_group:
+                            fallback_model_group = item[model_group]
+                            break
+                    for mg in fallback_model_group:
+                        """
+                        Iterate through the model groups and try calling that deployment
+                        """
+                        try:
+                            kwargs["model"] = mg
+                            response = self.function_with_retries(*args, **kwargs)
+                            return response
+                        except Exception as e:
+                            pass
+                else:
+                    self.print_verbose(f"inside model fallbacks: {self.fallbacks}")
+                    for item in self.fallbacks:
+                        if list(item.keys())[0] == model_group:
+                            fallback_model_group = item[model_group]
+                            break
+                    for mg in fallback_model_group:
+                        """
+                        Iterate through the model groups and try calling that deployment
+                        """
+                        try:
kwargs["model"] = mg + response = self.function_with_retries(*args, **kwargs) + return response + except Exception as e: + pass + except Exception as e: + self.print_verbose(f"An exception occurred - {str(e)}") + traceback.print_exc() + raise original_exception + + + + def function_with_retries(self, *args, **kwargs): + """ + Try calling the model 3 times. Shuffle between available deployments. + """ + self.print_verbose(f"Inside function with retries: args - {args}; kwargs - {kwargs}") + backoff_factor = 1 + original_function = kwargs.pop("original_function") + num_retries = kwargs.pop("num_retries") + try: + # if the function call is successful, no exception will be raised and we'll break out of the loop + response = original_function(*args, **kwargs) + return response + except Exception as e: + self.print_verbose(f"num retries in function with retries: {num_retries}") + for current_attempt in range(num_retries): + num_retries -= 1 # decrement the number of retries + self.print_verbose(f"retrying request. Current attempt - {current_attempt}; num retries: {num_retries}") + try: + # if the function call is successful, no exception will be raised and we'll break out of the loop + response = original_function(*args, **kwargs) + return response + + except openai.RateLimitError as e: + if num_retries > 0: + # on RateLimitError we'll wait for an exponential time before trying again + time.sleep(backoff_factor) + + # increase backoff factor for next run + backoff_factor *= 2 + else: + raise e + + except Exception as e: + # for any other exception types, immediately retry + if num_retries > 0: + pass + else: + raise e + if self.num_retries == 0: + raise e + + ### HELPER FUNCTIONS + def deployment_callback( self, kwargs, # kwargs to completion @@ -433,12 +485,15 @@ class Router: completion_response, # response from completion start_time, end_time # start/end time ): - model_name = kwargs.get('model', None) # i.e. gpt35turbo - custom_llm_provider = kwargs.get("litellm_params", {}).get('custom_llm_provider', None) # i.e. azure - if custom_llm_provider: - model_name = f"{custom_llm_provider}/{model_name}" - - self._set_cooldown_deployments(model_name) + try: + model_name = kwargs.get('model', None) # i.e. gpt35turbo + custom_llm_provider = kwargs.get("litellm_params", {}).get('custom_llm_provider', None) # i.e. azure + if custom_llm_provider: + model_name = f"{custom_llm_provider}/{model_name}" + + self._set_cooldown_deployments(model_name) + except Exception as e: + raise e def _set_cooldown_deployments(self, deployment: str): @@ -577,4 +632,123 @@ class Router: # Update usage # ------------ self.increment(tpm_key, total_tokens) - self.increment(rpm_key, 1) \ No newline at end of file + self.increment(rpm_key, 1) + + def _start_health_check_thread(self): + """ + Starts a separate thread to perform health checks periodically. + """ + health_check_thread = threading.Thread(target=self._perform_health_checks, daemon=True) + health_check_thread.start() + + def _perform_health_checks(self): + """ + Periodically performs health checks on the servers. + Updates the list of healthy servers accordingly. 
+ """ + while True: + self.healthy_deployments = self._health_check() + # Adjust the time interval based on your needs + time.sleep(15) + + def _health_check(self): + """ + Performs a health check on the deployments + Returns the list of healthy deployments + """ + healthy_deployments = [] + for deployment in self.model_list: + litellm_args = deployment["litellm_params"] + try: + start_time = time.time() + litellm.completion(messages=[{"role": "user", "content": ""}], max_tokens=1, **litellm_args) # hit the server with a blank message to see how long it takes to respond + end_time = time.time() + response_time = end_time - start_time + logging.debug(f"response_time: {response_time}") + healthy_deployments.append((deployment, response_time)) + healthy_deployments.sort(key=lambda x: x[1]) + except Exception as e: + pass + return healthy_deployments + + def weighted_shuffle_by_latency(self, items): + # Sort the items by latency + sorted_items = sorted(items, key=lambda x: x[1]) + # Get only the latencies + latencies = [i[1] for i in sorted_items] + # Calculate the sum of all latencies + total_latency = sum(latencies) + # Calculate the weight for each latency (lower latency = higher weight) + weights = [total_latency-latency for latency in latencies] + # Get a weighted random item + if sum(weights) == 0: + chosen_item = random.choice(sorted_items)[0] + else: + chosen_item = random.choices(sorted_items, weights=weights, k=1)[0][0] + return chosen_item + + def set_model_list(self, model_list: list): + self.model_list = model_list + self.model_names = [m["model_name"] for m in model_list] + + def get_model_names(self): + return self.model_names + + def print_verbose(self, print_statement): + if self.set_verbose: + print(f"LiteLLM.Router: {print_statement}") # noqa + + def get_available_deployment(self, + model: str, + messages: Optional[List[Dict[str, str]]] = None, + input: Optional[Union[str, List]] = None): + """ + Returns the deployment based on routing strategy + """ + ## get healthy deployments + ### get all deployments + ### filter out the deployments currently cooling down + healthy_deployments = [m for m in self.model_list if m["model_name"] == model] + deployments_to_remove = [] + cooldown_deployments = self._get_cooldown_deployments() + self.print_verbose(f"cooldown deployments: {cooldown_deployments}") + ### FIND UNHEALTHY DEPLOYMENTS + for deployment in healthy_deployments: + deployment_name = deployment["litellm_params"]["model"] + if deployment_name in cooldown_deployments: + deployments_to_remove.append(deployment) + ### FILTER OUT UNHEALTHY DEPLOYMENTS + for deployment in deployments_to_remove: + healthy_deployments.remove(deployment) + self.print_verbose(f"healthy deployments: {healthy_deployments}") + if len(healthy_deployments) == 0: + raise ValueError("No models available") + if litellm.model_alias_map and model in litellm.model_alias_map: + model = litellm.model_alias_map[ + model + ] # update the model to the actual value if an alias has been passed in + if self.routing_strategy == "least-busy": + if len(self.healthy_deployments) > 0: + for item in self.healthy_deployments: + if item[0]["model_name"] == model: # first one in queue will be the one with the most availability + return item[0] + else: + raise ValueError("No models available.") + elif self.routing_strategy == "simple-shuffle": + item = random.choice(healthy_deployments) + return item or item[0] + elif self.routing_strategy == "latency-based-routing": + returned_item = None + lowest_latency = float('inf') + ### 
+            # items_with_latencies = [('A', 10), ('B', 20), ('C', 30), ('D', 40)]
+            items_with_latencies = []
+            for item in healthy_deployments:
+                items_with_latencies.append((item, self.deployment_latency_map[item["litellm_params"]["model"]]))
+            returned_item = self.weighted_shuffle_by_latency(items_with_latencies)
+            return returned_item
+        elif self.routing_strategy == "usage-based-routing":
+            return self.get_usage_based_available_deployment(model=model, messages=messages, input=input)
+
+        raise ValueError("No models available.")
+
\ No newline at end of file
diff --git a/litellm/tests/test_router_fallbacks.py b/litellm/tests/test_router_fallbacks.py
new file mode 100644
index 0000000000..87ac9d271a
--- /dev/null
+++ b/litellm/tests/test_router_fallbacks.py
@@ -0,0 +1,76 @@
+#### What this tests ####
+# This tests calling router with fallback models
+
+import sys, os, time
+import traceback, asyncio
+import pytest
+sys.path.insert(
+    0, os.path.abspath("../..")
+) # Adds the parent directory to the system path
+
+import litellm
+from litellm import Router
+
+model_list = [
+    { # list of model deployments
+        "model_name": "azure/gpt-3.5-turbo", # openai model name
+        "litellm_params": { # params for litellm completion/embedding call
+            "model": "azure/chatgpt-v-2",
+            "api_key": "bad-key",
+            "api_version": os.getenv("AZURE_API_VERSION"),
+            "api_base": os.getenv("AZURE_API_BASE")
+        },
+        "tpm": 240000,
+        "rpm": 1800
+    },
+    {
+        "model_name": "azure/gpt-3.5-turbo", # openai model name
+        "litellm_params": { # params for litellm completion/embedding call
+            "model": "azure/chatgpt-functioncalling",
+            "api_key": "bad-key",
+            "api_version": os.getenv("AZURE_API_VERSION"),
+            "api_base": os.getenv("AZURE_API_BASE")
+        },
+        "tpm": 240000,
+        "rpm": 1800
+    },
+    {
+        "model_name": "gpt-3.5-turbo", # openai model name
+        "litellm_params": { # params for litellm completion/embedding call
+            "model": "gpt-3.5-turbo",
+            "api_key": os.getenv("OPENAI_API_KEY"),
+        },
+        "tpm": 1000000,
+        "rpm": 9000
+    }
+]
+
+
+router = Router(model_list=model_list, fallbacks=[{"azure/gpt-3.5-turbo": ["gpt-3.5-turbo"]}])
+
+kwargs = {"model": "azure/gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hey, how's it going?"}]}
+
+def test_sync_fallbacks():
+    try:
+        response = router.completion(**kwargs)
+        print(f"response: {response}")
+    except Exception as e:
+        print(e)
+
+def test_async_fallbacks():
+    async def test_get_response():
+        user_message = "Hello, how are you?"
+        messages = [{"content": user_message, "role": "user"}]
+        try:
+            response = await router.acompletion(**kwargs)
+            # response = await response
+            print(f"response: {response}")
+        except litellm.Timeout as e:
+            pass
+        except Exception as e:
+            pytest.fail(f"An exception occurred: {e}")
+
+    asyncio.run(test_get_response())
+
+test_async_fallbacks()
\ No newline at end of file
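
For reviewers who want to try the change locally, below is a minimal usage sketch of the behaviour this patch adds: a synchronous `Router.completion()` call that falls back from one model group to another. The deployment name and environment variables are illustrative placeholders rather than values from the patch; the shape of the `fallbacks` and `context_window_fallbacks` arguments (a list of `{failing_group: [fallback_groups]}` dicts) follows `function_with_fallbacks` in the diff above.

```python
# Hedged usage sketch for the sync fallback path added in this patch.
# Deployment names and environment variables are placeholders.
import os
from litellm import Router

model_list = [
    {
        # Primary group: an Azure deployment behind a shared alias.
        "model_name": "azure-gpt-3.5-turbo",
        "litellm_params": {
            "model": "azure/my-azure-deployment",           # placeholder deployment name
            "api_key": os.getenv("AZURE_API_KEY"),
            "api_version": os.getenv("AZURE_API_VERSION"),
            "api_base": os.getenv("AZURE_API_BASE"),
        },
    },
    {
        # Fallback group: plain OpenAI.
        "model_name": "openai-gpt-3.5-turbo",
        "litellm_params": {
            "model": "gpt-3.5-turbo",
            "api_key": os.getenv("OPENAI_API_KEY"),
        },
    },
]

# Each fallback entry maps a failing model group to a *list* of groups to try next.
router = Router(
    model_list=model_list,
    fallbacks=[{"azure-gpt-3.5-turbo": ["openai-gpt-3.5-turbo"]}],
    context_window_fallbacks=[{"azure-gpt-3.5-turbo": ["openai-gpt-3.5-turbo"]}],
)

# The sync path now routes through function_with_fallbacks under the hood.
response = router.completion(
    model="azure-gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
)
print(response)
```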
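The sync and async paths enforce the router timeout differently: `completion()` submits the work to a single-worker `ThreadPoolExecutor` and bounds the wait with `future.result(timeout=...)`, while `acompletion()` wraps the call in `asyncio.timeout`, which is only available on Python 3.11+ (older interpreters would need `asyncio.wait_for`). The standalone sketch below illustrates the sync pattern in isolation; the helper name is hypothetical, not part of the patch.

```python
# Standalone sketch of the sync timeout pattern used in Router.completion().
import concurrent.futures
import time


def run_with_timeout(fn, timeout_s, *args, **kwargs):
    """Run fn in a worker thread and wait at most timeout_s for its result."""
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
        future = executor.submit(fn, *args, **kwargs)
        # Raises concurrent.futures.TimeoutError if fn is still running after
        # timeout_s seconds. Note: the worker thread is not cancelled, and the
        # executor's shutdown on exiting the with-block still waits for it, so
        # the caller sees the timeout error only once fn actually returns.
        return future.result(timeout=timeout_s)


def slow_call():
    time.sleep(2)
    return "done"


if __name__ == "__main__":
    try:
        print(run_with_timeout(slow_call, timeout_s=5))    # prints "done"
        print(run_with_timeout(slow_call, timeout_s=0.5))  # raises TimeoutError
    except concurrent.futures.TimeoutError:
        print("timed out")
```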