From c273d6f0d6b5c7215cab56e89a5bffe73c16e077 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Thu, 23 Nov 2023 16:41:45 -0800
Subject: [PATCH] fix(router.py): add support for context window fallbacks on
 router

---
 litellm/proxy/utils.py                 | 99 +-------------------------
 litellm/router.py                      | 29 ++++++--
 litellm/tests/test_router_fallbacks.py | 41 ++++++++++-
 3 files changed, 65 insertions(+), 104 deletions(-)

diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py
index d1d9554996..855fc6b1c4 100644
--- a/litellm/proxy/utils.py
+++ b/litellm/proxy/utils.py
@@ -1,10 +1,11 @@
 import litellm
 from litellm import ModelResponse
 from proxy_server import llm_model_list
+from typing import Optional
 
 def track_cost_callback(
     kwargs,                                       # kwargs to completion
-    completion_response: ModelResponse = None,    # response from completion
+    completion_response: ModelResponse,           # response from completion
     start_time = None,
     end_time = None,                              # start/end time for completion
 ):
@@ -34,98 +35,4 @@ def track_cost_callback(
             response_cost = litellm.completion_cost(completion_response=completion_response)
             print("regular response_cost", response_cost)
     except:
-        pass
-
-# 1. `--experimental_async` starts 2 background threads:
-#     - 1. to check the redis queue:
-#         - if job available
-#             - it dequeues as many jobs as healthy endpoints
-#             - calls llm api -> saves response in redis cache
-#     - 2. to check the llm apis:
-#         - check if endpoints are healthy (unhealthy = 4xx / 5xx call or >1min. queue)
-#         - which one is least busy
-# 2. /router/chat/completions: receives request -> adds to redis queue -> returns {run_id, started_at, request_obj}
-# 3. /router/chat/completions/runs/{run_id}: returns {status: _, [optional] response_obj: _}
-# """
-
-# def _start_health_check_thread():
-#     """
-#     Starts a separate thread to perform health checks periodically.
-#     """
-#     health_check_thread = threading.Thread(target=_perform_health_checks, daemon=True)
-#     health_check_thread.start()
-#     llm_call_thread = threading.Thread(target=_llm_call_thread, daemon=True)
-#     llm_call_thread.start()
-
-
-# def _llm_call_thread():
-#     """
-#     Periodically performs job checks on the redis queue.
-#     If available, make llm api calls.
-#     Write result to redis cache (1 min ttl)
-#     """
-#     with concurrent.futures.ThreadPoolExecutor() as executor:
-#         while True:
-#             job_checks = _job_check()
-#             future_to_job = {executor.submit(_llm_api_call, job): job for job in job_checks}
-#             for future in concurrent.futures.as_completed(future_to_job):
-#                 job = future_to_job[future]
-#                 try:
-#                     result = future.result()
-#                 except Exception as exc:
-#                     print(f'{job} generated an exception: {exc}')
-#                 else:
-#                     _write_to_cache(job, result, ttl=1*60)
-#             time.sleep(1) # sleep 1 second to avoid overloading the server
-
-
-
-# def _perform_health_checks():
-#     """
-#     Periodically performs health checks on the servers.
-#     Updates the list of healthy servers accordingly.
-#     """
-#     while True:
-#         healthy_deployments = _health_check()
-#         # Adjust the time interval based on your needs
-#         time.sleep(15)
-
-# def _job_check():
-#     """
-#     Periodically performs job checks on the redis queue.
-#     Returns the list of available jobs - len(available_jobs) == len(healthy_endpoints),
-#     e.g. don't dequeue a gpt-3.5-turbo job if there's no healthy deployments left
-#     """
-#     pass
-
-# def _llm_api_call(**data):
-#     """
-#     Makes the litellm.completion() call with 3 retries
-#     """
-#     return litellm.completion(num_retries=3, **data)
-
-# def _write_to_cache():
-#     """
-#     Writes the result to a redis cache in the form (key:job_id, value: )
-#     """
-#     pass
-
-# def _health_check():
-#     """
-#     Performs a health check on the deployments
-#     Returns the list of healthy deployments
-#     """
-#     healthy_deployments = []
-#     for deployment in model_list:
-#         litellm_args = deployment["litellm_params"]
-#         try:
-#             start_time = time.time()
-#             litellm.completion(messages=[{"role": "user", "content": ""}], max_tokens=1, **litellm_args) # hit the server with a blank message to see how long it takes to respond
-#             end_time = time.time()
-#             response_time = end_time - start_time
-#             logging.debug(f"response_time: {response_time}")
-#             healthy_deployments.append((deployment, response_time))
-#             healthy_deployments.sort(key=lambda x: x[1])
-#         except Exception as e:
-#             pass
-#     return healthy_deployments
+        pass
\ No newline at end of file
diff --git a/litellm/router.py b/litellm/router.py
index 9341f2b76b..f93897e886 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -280,10 +280,15 @@ class Router:
             try:
                 self.print_verbose(f"Trying to fallback b/w models")
                 if isinstance(e, litellm.ContextWindowExceededError):
-                    for item in self.context_window_fallback_model_group: # [{"gpt-3.5-turbo": ["gpt-4"]}]
+                    fallback_model_group = None
+                    for item in self.context_window_fallbacks: # [{"gpt-3.5-turbo": ["gpt-4"]}]
                         if list(item.keys())[0] == model_group:
                             fallback_model_group = item[model_group]
                             break
+
+                    if fallback_model_group is None:
+                        raise original_exception
+
                     for mg in fallback_model_group:
                         """
                         Iterate through the model groups and try calling that deployment
@@ -360,6 +365,7 @@ class Router:
         If it fails after num_retries, fall back to another model group
         """
         model_group = kwargs.get("model")
+
         try:
             response = self.function_with_retries(*args, **kwargs)
             self.print_verbose(f'Response: {response}')
@@ -368,36 +374,47 @@ class Router:
                 original_exception = e
                 self.print_verbose(f"An exception occurs{original_exception}")
                 try:
-                    self.print_verbose(f"Trying to fallback b/w models")
-                    fallback_model_group = []
+                    self.print_verbose(f"Trying to fallback b/w models. Initial model group: {model_group}")
+                    self.print_verbose(f"Type of exception: {type(e)}; error_message: {str(e)}")
                     if isinstance(e, litellm.ContextWindowExceededError):
-                        for item in self.context_window_fallback_model_group: # [{"gpt-3.5-turbo": ["gpt-4"]}]
+                        self.print_verbose(f"inside context window fallbacks: {self.context_window_fallbacks}")
+                        fallback_model_group = None
+                        for item in self.context_window_fallbacks: # [{"gpt-3.5-turbo": ["gpt-4"]}]
                             if list(item.keys())[0] == model_group:
                                 fallback_model_group = item[model_group]
                                 break
+
+                        if fallback_model_group is None:
+                            raise original_exception
+
                         for mg in fallback_model_group:
                             """
                             Iterate through the model groups and try calling that deployment
                             """
                             try:
                                 kwargs["model"] = mg
-                                response = self.function_with_retries(*args, **kwargs)
+                                response = self.function_with_fallbacks(*args, **kwargs)
                                 return response
                             except Exception as e:
                                 pass
                     else:
                         self.print_verbose(f"inside model fallbacks: {self.fallbacks}")
+                        fallback_model_group = None
                         for item in self.fallbacks:
                             if list(item.keys())[0] == model_group:
                                 fallback_model_group = item[model_group]
                                 break
+
+                        if fallback_model_group is None:
+                            raise original_exception
+
                         for mg in fallback_model_group:
                             """
                             Iterate through the model groups and try calling that deployment
                             """
                             try:
                                 kwargs["model"] = mg
-                                response = self.function_with_retries(*args, **kwargs)
+                                response = self.function_with_fallbacks(*args, **kwargs)
                                 return response
                             except Exception as e:
                                 pass
diff --git a/litellm/tests/test_router_fallbacks.py b/litellm/tests/test_router_fallbacks.py
index 5f4cffab84..abb7e5eb87 100644
--- a/litellm/tests/test_router_fallbacks.py
+++ b/litellm/tests/test_router_fallbacks.py
@@ -23,6 +23,17 @@ model_list = [
         "tpm": 240000,
         "rpm": 1800
     },
+    { # list of model deployments
+        "model_name": "azure/gpt-3.5-turbo-context-fallback", # openai model name
+        "litellm_params": { # params for litellm completion/embedding call
+            "model": "azure/chatgpt-v-2",
+            "api_key": "bad-key",
+            "api_version": os.getenv("AZURE_API_VERSION"),
+            "api_base": os.getenv("AZURE_API_BASE")
+        },
+        "tpm": 240000,
+        "rpm": 1800
+    },
     {
         "model_name": "azure/gpt-3.5-turbo", # openai model name
         "litellm_params": { # params for litellm completion/embedding call
@@ -42,21 +53,35 @@ model_list = [
         },
         "tpm": 1000000,
         "rpm": 9000
+    },
+    {
+        "model_name": "gpt-3.5-turbo-16k", # openai model name
+        "litellm_params": { # params for litellm completion/embedding call
+            "model": "gpt-3.5-turbo-16k",
+            "api_key": os.getenv("OPENAI_API_KEY"),
+        },
+        "tpm": 1000000,
+        "rpm": 9000
     }
 ]
 
-router = Router(model_list=model_list, fallbacks=[{"azure/gpt-3.5-turbo": ["gpt-3.5-turbo"]}])
+router = Router(model_list=model_list,
+                fallbacks=[{"azure/gpt-3.5-turbo": ["gpt-3.5-turbo"]}],
+                context_window_fallbacks=[{"azure/gpt-3.5-turbo-context-fallback": ["gpt-3.5-turbo-16k"]}, {"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]}],
+                set_verbose=True)
 
 kwargs = {"model": "azure/gpt-3.5-turbo", "messages": [{"role": "user", "content":"Hey, how's it going?"}]}
 
 
 def test_sync_fallbacks():
     try:
+        litellm.set_verbose = True
         response = router.completion(**kwargs)
         print(f"response: {response}")
     except Exception as e:
         print(e)
 
+test_sync_fallbacks()
 
 def test_async_fallbacks():
     litellm.set_verbose = False
@@ -74,4 +99,16 @@ def test_async_fallbacks():
 
     asyncio.run(test_get_response())
 
-# test_async_fallbacks()
\ No newline at end of file
+# test_async_fallbacks()
+
+def test_sync_context_window_fallbacks():
+    try:
+        sample_text = "Say error 50 times" * 10000
+        kwargs["model"] = "azure/gpt-3.5-turbo-context-fallback"
+        kwargs["messages"] = [{"role": "user", "content": sample_text}]
+        response = router.completion(**kwargs)
+        print(f"response: {response}")
+    except Exception as e:
+        print(e)
+
+# test_sync_context_window_fallbacks()
\ No newline at end of file