Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-24 10:14:26 +00:00)
fix(router.py): speed improvements to the router

commit 04f745e314 (parent 8560794963)
4 changed files with 92 additions and 5 deletions
litellm/main.py

@@ -178,7 +178,7 @@ async def acompletion(*args, **kwargs):
         response = completion(*args, **kwargs)
     else:
         # Await normally
-        init_response = completion(*args, **kwargs)
+        init_response = await loop.run_in_executor(None, func_with_context)
         if isinstance(init_response, dict) or isinstance(init_response, ModelResponse): ## CACHING SCENARIO
             response = init_response
         elif asyncio.iscoroutine(init_response):
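The replaced line above is the heart of the speed fix: the synchronous completion() call no longer blocks the event loop; it is handed off to a worker thread via run_in_executor. A minimal sketch of that pattern, with illustrative names (blocking_completion is a stand-in, not litellm's API):

import asyncio
import functools
import time

def blocking_completion(model: str, prompt: str) -> str:
    time.sleep(0.5)  # stands in for a blocking network call
    return f"{model} -> {prompt}"

async def acompletion_sketch(model: str, prompt: str) -> str:
    loop = asyncio.get_running_loop()
    # run_in_executor only forwards positional args, so bind them up front,
    # as litellm's func_with_context presumably does
    func_with_context = functools.partial(blocking_completion, model, prompt)
    # None = default ThreadPoolExecutor; the event loop stays free while
    # the blocking call runs in a worker thread
    return await loop.run_in_executor(None, func_with_context)

print(asyncio.run(acompletion_sketch("gpt-x", "hi")))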
litellm/router.py

@@ -94,6 +94,7 @@ class Router:
         # default litellm args
         self.default_litellm_params = default_litellm_params
         self.default_litellm_params["timeout"] = timeout
+        self.default_litellm_params["max_retries"] = 0


         ### HEALTH CHECK THREAD ###
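Pinning max_retries to 0 in the default litellm params presumably disables retries at the per-call layer, leaving the Router's own retry logic as the only one running. A quick illustration of why stacked retry layers are costly (hypothetical numbers, not from the diff):

# Worst case for one persistently failing request when two layers both retry:
inner_retries = 2        # hypothetical per-call retries (now forced to 0)
router_retries = 3       # hypothetical Router-level num_retries
attempts_stacked = (1 + inner_retries) * (1 + router_retries)
attempts_router_only = 1 + router_retries
print(attempts_stacked, attempts_router_only)  # 12 vs. 4 network calls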
@@ -278,8 +279,8 @@ class Router:
         If it fails after num_retries, fall back to another model group
         """
         model_group = kwargs.get("model")
-        fallbacks = kwargs.pop("fallbacks", self.fallbacks)
-        context_window_fallbacks = kwargs.pop("context_window_fallbacks", self.context_window_fallbacks)
+        fallbacks = kwargs.get("fallbacks", self.fallbacks)
+        context_window_fallbacks = kwargs.get("context_window_fallbacks", self.context_window_fallbacks)
         try:
             response = await self.async_function_with_retries(*args, **kwargs)
             self.print_verbose(f'Async Response: {response}')
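The pop-to-get switch matters because async_function_with_retries (next hunks) now reads the same keys: dict.pop would strip them from kwargs before the inner call ever saw them. The difference in one self-contained snippet:

kwargs = {"model": "azure-model", "fallbacks": [{"azure-model": ["backup-model"]}]}

value = kwargs.pop("fallbacks", None)  # returns the value and removes the key
assert "fallbacks" not in kwargs       # downstream readers now miss it

kwargs["fallbacks"] = value            # restore for the second case
value = kwargs.get("fallbacks", None)  # returns the value, key stays put
assert "fallbacks" in kwargs           # downstream readers still see it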
@@ -335,6 +336,8 @@ class Router:
         self.print_verbose(f"Inside async function with retries: args - {args}; kwargs - {kwargs}")
         backoff_factor = 1
         original_function = kwargs.pop("original_function")
+        fallbacks = kwargs.pop("fallbacks", self.fallbacks)
+        context_window_fallbacks = kwargs.get("context_window_fallbacks", self.context_window_fallbacks)
         self.print_verbose(f"async function w/ retries: original_function - {original_function}")
         num_retries = kwargs.pop("num_retries")
         try:
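The retry helper receives its target and budget through kwargs: original_function and num_retries are popped so they do not leak into the model call, while the fallback config is now read here too. A generic sketch of that contract (assumed shape, not litellm's exact code):

import asyncio

async def function_with_retries_sketch(*args, **kwargs):
    original_function = kwargs.pop("original_function")  # coroutine function to call
    num_retries = kwargs.pop("num_retries")              # retry budget
    fallbacks = kwargs.get("fallbacks", None)            # read, not popped; the
    # target is assumed to tolerate the remaining kwargs
    try:
        return await original_function(*args, **kwargs)
    except Exception:
        if fallbacks is not None:
            raise  # let the caller's fallback logic take over instead of retrying
        for _ in range(num_retries):
            try:
                return await original_function(*args, **kwargs)
            except Exception:
                await asyncio.sleep(1)  # backoff: see the sketch after @@ -353,7
        raise  # re-raises the original exception still being handled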
@@ -343,6 +346,11 @@ class Router:
             return response
         except Exception as e:
             original_exception = e
+            ### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR
+            if ((isinstance(original_exception, litellm.ContextWindowExceededError) and context_window_fallbacks is None)
+                or (openai.RateLimitError and fallbacks is not None)):
+                raise original_exception
+            ### RETRY
             for current_attempt in range(num_retries):
                 self.print_verbose(f"retrying request. Current attempt - {current_attempt}; num retries: {num_retries}")
                 try:
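One reading of the added guard: raise immediately, handing control to the fallback path, rather than spend retries when a context-window error has no dedicated fallback or when general fallbacks exist. Note that openai.RateLimitError appears as a bare class reference, which is always truthy, so the second clause effectively tests only fallbacks is not None; an isinstance check was presumably intended, sketched here:

import litellm   # both exception types are referenced by the diff itself
import openai

def should_raise_for_fallback(exc: Exception, fallbacks, context_window_fallbacks) -> bool:
    # Presumed intent of the guard above, with an isinstance check instead of
    # the always-truthy class reference on openai.RateLimitError
    return ((isinstance(exc, litellm.ContextWindowExceededError)
             and context_window_fallbacks is None)
            or (isinstance(exc, openai.RateLimitError)
                and fallbacks is not None))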
@@ -353,7 +361,7 @@ class Router:
                     return response

                 except openai.RateLimitError as e:
-                    if num_retries > 0:
+                    if num_retries > 0 and fallbacks is None:
                         # on RateLimitError we'll wait for an exponential time before trying again
                         await asyncio.sleep(backoff_factor)
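backoff_factor starts at 1 earlier in this function and the comment promises an exponential wait; the doubling step is not visible in these hunks, so the sketch below assumes the conventional backoff_factor *= 2 after each sleep:

import asyncio

async def retry_with_backoff(attempt, num_retries: int = 3):
    # attempt: a zero-argument coroutine function to retry
    backoff_factor = 1
    for i in range(num_retries):
        try:
            return await attempt()
        except Exception:
            if i == num_retries - 1:
                raise
            await asyncio.sleep(backoff_factor)  # waits 1s, 2s, 4s, ...
            backoff_factor *= 2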
litellm/tests/test_profiling_router.py (new file, 78 lines)

@@ -0,0 +1,78 @@
+# #### What this tests ####
+# # This profiles a router call to find where calls are taking the most time.
+
+# import sys, os, time, logging
+# import traceback, asyncio, uuid
+# import pytest
+# import cProfile
+# from pstats import Stats
+# sys.path.insert(
+#     0, os.path.abspath("../..")
+# )  # Adds the parent directory to the system path
+# import litellm
+# from litellm import Router
+# from concurrent.futures import ThreadPoolExecutor
+# from dotenv import load_dotenv
+# from aiodebug import log_slow_callbacks  # Import the aiodebug utility for logging slow callbacks
+
+# load_dotenv()
+
+# logging.basicConfig(
+#     level=logging.DEBUG,
+#     format='%(asctime)s %(levelname)s: %(message)s',
+#     datefmt='%I:%M:%S %p',
+#     filename='aiologs.log',  # Name of the log file where logs will be written
+#     filemode='w'  # 'w' to overwrite the log file on each run, use 'a' to append
+# )
+
+
+# model_list = [{
+#     "model_name": "azure-model",
+#     "litellm_params": {
+#         "model": "azure/gpt-turbo",
+#         "api_key": "6a5ae4c5b2bd4e8088248067799c6899",
+#         "api_base": "https://openai-france-1234.openai.azure.com"
+#     }
+# }, {
+#     "model_name": "azure-model",
+#     "litellm_params": {
+#         "model": "azure/gpt-35-turbo",
+#         "api_key": "fe5b390e8990407e8d913f40833b19f7",
+#         "api_base": "https://my-endpoint-europe-berri-992.openai.azure.com"
+#     }
+# }, {
+#     "model_name": "azure-model",
+#     "litellm_params": {
+#         "model": "azure/gpt-35-turbo",
+#         "api_key": "6a0f46e99d554e8caad9c2b7c0ba7319",
+#         "api_base": "https://my-endpoint-canada-berri992.openai.azure.com"
+#     }
+# }]
+
+# router = Router(model_list=model_list, set_verbose=False, num_retries=3)
+
+# async def router_completion():
+#     try:
+#         messages=[{"role": "user", "content": f"This is a test: {uuid.uuid4()}"}]
+#         response = await router.acompletion(model="azure-model", messages=messages)
+#         # response = await litellm.acompletion(model="azure/gpt-35-turbo", messages=messages, api_key="6a0f46e99d554e8caad9c2b7c0ba7319", api_base="https://my-endpoint-canada-berri992.openai.azure.com")
+#         return response
+#     except Exception as e:
+#         print(e, file=sys.stderr)
+#         traceback.print_exc()
+#         return None
+
+# async def loadtest_fn():
+#     start = time.time()
+#     n = 10
+#     tasks = [router_completion() for _ in range(n)]
+#     chat_completions = await asyncio.gather(*tasks)
+#     successful_completions = [c for c in chat_completions if c is not None]
+#     print(n, time.time() - start, len(successful_completions))
+
+# loop = asyncio.get_event_loop()
+# loop.set_debug(True)
+# log_slow_callbacks.enable(0.05)  # Log callbacks slower than 0.05 seconds
+
+# # Execute the load testing function within the asyncio event loop
+# loop.run_until_complete(loadtest_fn())
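The script imports cProfile and pstats but, as committed (and fully commented out), never invokes them. If uncommented, one way to wire them in — loadtest_fn is the function defined above; everything else is standard library:

import asyncio
import cProfile
from pstats import Stats

profiler = cProfile.Profile()
profiler.enable()
asyncio.run(loadtest_fn())  # drive the load test under the profiler
profiler.disable()

# Print the 25 most expensive call sites by cumulative time
Stats(profiler).sort_stats("cumulative").print_stats(25)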
litellm/utils.py

@@ -1198,7 +1198,8 @@ def client(original_function):
             elif kwargs.get("messages", None):
                 messages = kwargs["messages"]
                 ### PRE-CALL RULES ###
-                rules_obj.pre_call_rules(input="".join(m["content"] for m in messages if isinstance(m["content"], str)), model=model)
+                if isinstance(messages, list) and len(messages) > 0 and isinstance(messages[0], dict) and "content" in messages[0]:
+                    rules_obj.pre_call_rules(input="".join(m["content"] for m in messages if isinstance(m["content"], str)), model=model)
         elif call_type == CallTypes.embedding.value:
             messages = args[1] if len(args) > 1 else kwargs["input"]
         stream = True if "stream" in kwargs and kwargs["stream"] == True else False
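The added isinstance guard protects the pre-call rules from message payloads whose "content" is not a plain string, such as multimodal content lists. An illustrative input (not from the diff) showing what the join would otherwise trip over:

messages = [
    {"role": "user", "content": "describe this image"},
    {"role": "user", "content": [  # multimodal content: a list, not a str
        {"type": "image_url", "image_url": {"url": "https://example.com/x.png"}},
    ]},
]

if (isinstance(messages, list) and len(messages) > 0
        and isinstance(messages[0], dict) and "content" in messages[0]):
    # only the string parts are joined and fed to the pre-call rules
    text_input = "".join(m["content"] for m in messages
                         if isinstance(m["content"], str))
    print(repr(text_input))  # 'describe this image'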