From 04f745e314556f8e21b1270cd60fa918b9a01d28 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Mon, 27 Nov 2023 17:35:02 -0800
Subject: [PATCH] fix(router.py): speed improvements to the router

---
 litellm/main.py                        |  2 +-
 litellm/router.py                      | 14 ++++-
 litellm/tests/test_profiling_router.py | 78 ++++++++++++++++++++++++++
 litellm/utils.py                       |  3 +-
 4 files changed, 92 insertions(+), 5 deletions(-)
 create mode 100644 litellm/tests/test_profiling_router.py

diff --git a/litellm/main.py b/litellm/main.py
index 85091fec99..64fe498d97 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -178,7 +178,7 @@ async def acompletion(*args, **kwargs):
             response = completion(*args, **kwargs)
         else:
             # Await normally
-            init_response = completion(*args, **kwargs)
+            init_response = await loop.run_in_executor(None, func_with_context)
             if isinstance(init_response, dict) or isinstance(init_response, ModelResponse): ## CACHING SCENARIO
                 response = init_response
             elif asyncio.iscoroutine(init_response):
diff --git a/litellm/router.py b/litellm/router.py
index 3d43197cb6..7710f77875 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -94,6 +94,7 @@ class Router:
         # default litellm args
         self.default_litellm_params = default_litellm_params
         self.default_litellm_params["timeout"] = timeout
+        self.default_litellm_params["max_retries"] = 0
 
 
         ### HEALTH CHECK THREAD ###
@@ -278,8 +279,8 @@ class Router:
         If it fails after num_retries, fall back to another model group
         """
         model_group = kwargs.get("model")
-        fallbacks = kwargs.pop("fallbacks", self.fallbacks)
-        context_window_fallbacks = kwargs.pop("context_window_fallbacks", self.context_window_fallbacks)
+        fallbacks = kwargs.get("fallbacks", self.fallbacks)
+        context_window_fallbacks = kwargs.get("context_window_fallbacks", self.context_window_fallbacks)
         try:
             response = await self.async_function_with_retries(*args, **kwargs)
             self.print_verbose(f'Async Response: {response}')
@@ -335,6 +336,8 @@ class Router:
         self.print_verbose(f"Inside async function with retries: args - {args}; kwargs - {kwargs}")
         backoff_factor = 1
         original_function = kwargs.pop("original_function")
+        fallbacks = kwargs.pop("fallbacks", self.fallbacks)
+        context_window_fallbacks = kwargs.get("context_window_fallbacks", self.context_window_fallbacks)
         self.print_verbose(f"async function w/ retries: original_function - {original_function}")
         num_retries = kwargs.pop("num_retries")
         try:
@@ -343,6 +346,11 @@ class Router:
             return response
         except Exception as e:
             original_exception = e
+            ### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR
+            if ((isinstance(original_exception, litellm.ContextWindowExceededError) and context_window_fallbacks is None)
+                or (isinstance(original_exception, openai.RateLimitError) and fallbacks is not None)):
+                raise original_exception
+            ### RETRY
             for current_attempt in range(num_retries):
                 self.print_verbose(f"retrying request. Current attempt - {current_attempt}; num retries: {num_retries}")
                 try:
@@ -353,7 +361,7 @@
                     return response
 
                 except openai.RateLimitError as e:
-                    if num_retries > 0:
+                    if num_retries > 0 and fallbacks is None:
                         # on RateLimitError we'll wait for an exponential time before trying again
                         await asyncio.sleep(backoff_factor)
 
diff --git a/litellm/tests/test_profiling_router.py b/litellm/tests/test_profiling_router.py
new file mode 100644
index 0000000000..01edbd85d9
--- /dev/null
+++ b/litellm/tests/test_profiling_router.py
@@ -0,0 +1,78 @@
+# #### What this tests ####
+# # This profiles a router call to find where calls are taking the most time.
+
+# import sys, os, time, logging
+# import traceback, asyncio, uuid
+# import pytest
+# import cProfile
+# from pstats import Stats
+# sys.path.insert(
+#     0, os.path.abspath("../..")
+# )  # Adds the parent directory to the system path
+# import litellm
+# from litellm import Router
+# from concurrent.futures import ThreadPoolExecutor
+# from dotenv import load_dotenv
+# from aiodebug import log_slow_callbacks  # Import the aiodebug utility for logging slow callbacks
+
+# load_dotenv()
+
+# logging.basicConfig(
+#     level=logging.DEBUG,
+#     format='%(asctime)s %(levelname)s: %(message)s',
+#     datefmt='%I:%M:%S %p',
+#     filename='aiologs.log',  # Name of the log file where logs will be written
+#     filemode='w'  # 'w' to overwrite the log file on each run, use 'a' to append
+# )
+
+
+# model_list = [{
+#     "model_name": "azure-model",
+#     "litellm_params": {
+#         "model": "azure/gpt-turbo",
+#         "api_key": "6a5ae4c5b2bd4e8088248067799c6899",
+#         "api_base": "https://openai-france-1234.openai.azure.com"
+#     }
+# }, {
+#     "model_name": "azure-model",
+#     "litellm_params": {
+#         "model": "azure/gpt-35-turbo",
+#         "api_key": "fe5b390e8990407e8d913f40833b19f7",
+#         "api_base": "https://my-endpoint-europe-berri-992.openai.azure.com"
+#     }
+# }, {
+#     "model_name": "azure-model",
+#     "litellm_params": {
+#         "model": "azure/gpt-35-turbo",
+#         "api_key": "6a0f46e99d554e8caad9c2b7c0ba7319",
+#         "api_base": "https://my-endpoint-canada-berri992.openai.azure.com"
+#     }
+# }]
+
+# router = Router(model_list=model_list, set_verbose=False, num_retries=3)
+
+# async def router_completion():
+#     try:
+#         messages=[{"role": "user", "content": f"This is a test: {uuid.uuid4()}"}]
+#         response = await router.acompletion(model="azure-model", messages=messages)
+#         # response = await litellm.acompletion(model="azure/gpt-35-turbo", messages=messages, api_key="6a0f46e99d554e8caad9c2b7c0ba7319", api_base="https://my-endpoint-canada-berri992.openai.azure.com")
+#         return response
+#     except Exception as e:
+#         print(e, file=sys.stderr)
+#         traceback.print_exc()
+#         return None
+
+# async def loadtest_fn():
+#     start = time.time()
+#     n = 10
+#     tasks = [router_completion() for _ in range(n)]
+#     chat_completions = await asyncio.gather(*tasks)
+#     successful_completions = [c for c in chat_completions if c is not None]
+#     print(n, time.time() - start, len(successful_completions))
+
+# loop = asyncio.get_event_loop()
+# loop.set_debug(True)
+# log_slow_callbacks.enable(0.05)  # Log callbacks slower than 0.05 seconds
+
+# # Execute the load testing function within the asyncio event loop
+# loop.run_until_complete(loadtest_fn())
\ No newline at end of file
diff --git a/litellm/utils.py b/litellm/utils.py
index e1e9b22a42..573ea737cb 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -1198,7 +1198,8 @@ def client(original_function):
             elif kwargs.get("messages", None):
                 messages = kwargs["messages"]
             ### PRE-CALL RULES ###
-            rules_obj.pre_call_rules(input="".join(m["content"] for m in messages if isinstance(m["content"], str)), model=model)
+            if isinstance(messages, list) and len(messages) > 0 and isinstance(messages[0], dict) and "content" in messages[0]:
+                rules_obj.pre_call_rules(input="".join(m["content"] for m in messages if isinstance(m["content"], str)), model=model)
         elif call_type == CallTypes.embedding.value:
             messages = args[1] if len(args) > 1 else kwargs["input"]
             stream = True if "stream" in kwargs and kwargs["stream"] == True else False
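
Usage note (not part of the patch): the retry/fallback behavior changed in router.py above is driven by the fallbacks and context_window_fallbacks arguments to Router. Below is a minimal sketch of how a caller might wire them up; the model-group names, fallback mapping, and environment-variable names are placeholders, not values taken from this diff.

    # Sketch (assumption): a Router configured with the fallbacks that the
    # patched retry logic checks. All names and env vars below are illustrative.
    import asyncio
    import os

    from litellm import Router

    model_list = [
        {
            "model_name": "azure-model",  # primary model group
            "litellm_params": {
                "model": "azure/gpt-35-turbo",
                "api_key": os.environ.get("AZURE_API_KEY"),
                "api_base": os.environ.get("AZURE_API_BASE"),
            },
        },
        {
            "model_name": "openai-model",  # fallback model group
            "litellm_params": {"model": "gpt-3.5-turbo"},
        },
    ]

    router = Router(
        model_list=model_list,
        num_retries=3,
        # With fallbacks configured, the patched code re-raises RateLimitError
        # instead of retrying, so the fallback group is tried sooner.
        fallbacks=[{"azure-model": ["openai-model"]}],
        context_window_fallbacks=[{"azure-model": ["openai-model"]}],
    )

    async def main():
        response = await router.acompletion(
            model="azure-model",
            messages=[{"role": "user", "content": "hello"}],
        )
        print(response)

    asyncio.run(main())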