mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 19:24:27 +00:00
fix(router.py): enable fallbacks for sync completions
This commit is contained in:
parent
288e3e962a
commit
59d084342d
2 changed files with 457 additions and 207 deletions
|
@ -9,28 +9,47 @@
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Dict, List, Optional, Union, Literal
|
from typing import Dict, List, Optional, Union, Literal
|
||||||
import random, threading, time
|
import random, threading, time, traceback
|
||||||
import litellm, openai
|
import litellm, openai
|
||||||
from litellm.caching import RedisCache, InMemoryCache, DualCache
|
from litellm.caching import RedisCache, InMemoryCache, DualCache
|
||||||
import logging, asyncio
|
import logging, asyncio
|
||||||
import inspect
|
import inspect, concurrent
|
||||||
from openai import AsyncOpenAI
|
from openai import AsyncOpenAI
|
||||||
|
|
||||||
class Router:
|
class Router:
|
||||||
"""
|
"""
|
||||||
Example usage:
|
Example usage:
|
||||||
|
```python
|
||||||
from litellm import Router
|
from litellm import Router
|
||||||
model_list = [{
|
model_list = [
|
||||||
"model_name": "gpt-3.5-turbo", # model alias
|
{
|
||||||
|
"model_name": "azure-gpt-3.5-turbo", # model alias
|
||||||
"litellm_params": { # params for litellm completion/embedding call
|
"litellm_params": { # params for litellm completion/embedding call
|
||||||
"model": "azure/<your-deployment-name>",
|
"model": "azure/<your-deployment-name-1>",
|
||||||
"api_key": <your-api-key>,
|
"api_key": <your-api-key>,
|
||||||
"api_version": <your-api-version>,
|
"api_version": <your-api-version>,
|
||||||
"api_base": <your-api-base>
|
"api_base": <your-api-base>
|
||||||
},
|
},
|
||||||
}]
|
},
|
||||||
|
{
|
||||||
|
"model_name": "azure-gpt-3.5-turbo", # model alias
|
||||||
|
"litellm_params": { # params for litellm completion/embedding call
|
||||||
|
"model": "azure/<your-deployment-name-2>",
|
||||||
|
"api_key": <your-api-key>,
|
||||||
|
"api_version": <your-api-version>,
|
||||||
|
"api_base": <your-api-base>
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"model_name": "openai-gpt-3.5-turbo", # model alias
|
||||||
|
"litellm_params": { # params for litellm completion/embedding call
|
||||||
|
"model": "gpt-3.5-turbo",
|
||||||
|
"api_key": <your-api-key>,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
router = Router(model_list=model_list)
|
router = Router(model_list=model_list, fallbacks=[{"azure-gpt-3.5-turbo": "openai-gpt-3.5-turbo"}])
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
model_names: List = []
|
model_names: List = []
|
||||||
cache_responses: bool = False
|
cache_responses: bool = False
|
||||||
|
@ -48,6 +67,8 @@ class Router:
|
||||||
timeout: float = 600,
|
timeout: float = 600,
|
||||||
default_litellm_params = {}, # default params for Router.chat.completion.create
|
default_litellm_params = {}, # default params for Router.chat.completion.create
|
||||||
set_verbose: bool = False,
|
set_verbose: bool = False,
|
||||||
|
fallbacks: List = [],
|
||||||
|
context_window_fallbacks: List = [],
|
||||||
routing_strategy: Literal["simple-shuffle", "least-busy", "usage-based-routing", "latency-based-routing"] = "simple-shuffle") -> None:
|
routing_strategy: Literal["simple-shuffle", "least-busy", "usage-based-routing", "latency-based-routing"] = "simple-shuffle") -> None:
|
||||||
|
|
||||||
if model_list:
|
if model_list:
|
||||||
|
@ -60,12 +81,19 @@ class Router:
|
||||||
|
|
||||||
self.num_retries = num_retries
|
self.num_retries = num_retries
|
||||||
self.set_verbose = set_verbose
|
self.set_verbose = set_verbose
|
||||||
|
self.timeout = timeout
|
||||||
|
self.routing_strategy = routing_strategy
|
||||||
|
self.fallbacks = fallbacks
|
||||||
|
self.context_window_fallbacks = context_window_fallbacks
|
||||||
|
|
||||||
|
# make Router.chat.completions.create compatible for openai.chat.completions.create
|
||||||
self.chat = litellm.Chat(params=default_litellm_params)
|
self.chat = litellm.Chat(params=default_litellm_params)
|
||||||
|
|
||||||
|
# default litellm args
|
||||||
self.default_litellm_params = default_litellm_params
|
self.default_litellm_params = default_litellm_params
|
||||||
self.default_litellm_params["timeout"] = timeout
|
self.default_litellm_params["timeout"] = timeout
|
||||||
|
|
||||||
self.routing_strategy = routing_strategy
|
|
||||||
### HEALTH CHECK THREAD ###
|
### HEALTH CHECK THREAD ###
|
||||||
if self.routing_strategy == "least-busy":
|
if self.routing_strategy == "least-busy":
|
||||||
self._start_health_check_thread()
|
self._start_health_check_thread()
|
||||||
|
@ -99,192 +127,7 @@ class Router:
|
||||||
else:
|
else:
|
||||||
litellm.failure_callback = [self.deployment_callback_on_failure]
|
litellm.failure_callback = [self.deployment_callback_on_failure]
|
||||||
|
|
||||||
def _start_health_check_thread(self):
|
|
||||||
"""
|
|
||||||
Starts a separate thread to perform health checks periodically.
|
|
||||||
"""
|
|
||||||
health_check_thread = threading.Thread(target=self._perform_health_checks, daemon=True)
|
|
||||||
health_check_thread.start()
|
|
||||||
|
|
||||||
def _perform_health_checks(self):
|
|
||||||
"""
|
|
||||||
Periodically performs health checks on the servers.
|
|
||||||
Updates the list of healthy servers accordingly.
|
|
||||||
"""
|
|
||||||
while True:
|
|
||||||
self.healthy_deployments = self._health_check()
|
|
||||||
# Adjust the time interval based on your needs
|
|
||||||
time.sleep(15)
|
|
||||||
|
|
||||||
def _health_check(self):
|
|
||||||
"""
|
|
||||||
Performs a health check on the deployments
|
|
||||||
Returns the list of healthy deployments
|
|
||||||
"""
|
|
||||||
healthy_deployments = []
|
|
||||||
for deployment in self.model_list:
|
|
||||||
litellm_args = deployment["litellm_params"]
|
|
||||||
try:
|
|
||||||
start_time = time.time()
|
|
||||||
litellm.completion(messages=[{"role": "user", "content": ""}], max_tokens=1, **litellm_args) # hit the server with a blank message to see how long it takes to respond
|
|
||||||
end_time = time.time()
|
|
||||||
response_time = end_time - start_time
|
|
||||||
logging.debug(f"response_time: {response_time}")
|
|
||||||
healthy_deployments.append((deployment, response_time))
|
|
||||||
healthy_deployments.sort(key=lambda x: x[1])
|
|
||||||
except Exception as e:
|
|
||||||
pass
|
|
||||||
return healthy_deployments
|
|
||||||
|
|
||||||
def weighted_shuffle_by_latency(self, items):
|
|
||||||
# Sort the items by latency
|
|
||||||
sorted_items = sorted(items, key=lambda x: x[1])
|
|
||||||
# Get only the latencies
|
|
||||||
latencies = [i[1] for i in sorted_items]
|
|
||||||
# Calculate the sum of all latencies
|
|
||||||
total_latency = sum(latencies)
|
|
||||||
# Calculate the weight for each latency (lower latency = higher weight)
|
|
||||||
weights = [total_latency-latency for latency in latencies]
|
|
||||||
# Get a weighted random item
|
|
||||||
if sum(weights) == 0:
|
|
||||||
chosen_item = random.choice(sorted_items)[0]
|
|
||||||
else:
|
|
||||||
chosen_item = random.choices(sorted_items, weights=weights, k=1)[0][0]
|
|
||||||
return chosen_item
|
|
||||||
|
|
||||||
def set_model_list(self, model_list: list):
|
|
||||||
self.model_list = model_list
|
|
||||||
self.model_names = [m["model_name"] for m in model_list]
|
|
||||||
|
|
||||||
def get_model_names(self):
|
|
||||||
return self.model_names
|
|
||||||
|
|
||||||
def print_verbose(self, print_statement):
|
|
||||||
if self.set_verbose:
|
|
||||||
print(f"LiteLLM.Router: {print_statement}") # noqa
|
|
||||||
|
|
||||||
def get_available_deployment(self,
|
|
||||||
model: str,
|
|
||||||
messages: Optional[List[Dict[str, str]]] = None,
|
|
||||||
input: Optional[Union[str, List]] = None):
|
|
||||||
"""
|
|
||||||
Returns the deployment based on routing strategy
|
|
||||||
"""
|
|
||||||
## get healthy deployments
|
|
||||||
### get all deployments
|
|
||||||
### filter out the deployments currently cooling down
|
|
||||||
healthy_deployments = [m for m in self.model_list if m["model_name"] == model]
|
|
||||||
deployments_to_remove = []
|
|
||||||
cooldown_deployments = self._get_cooldown_deployments()
|
|
||||||
self.print_verbose(f"cooldown deployments: {cooldown_deployments}")
|
|
||||||
### FIND UNHEALTHY DEPLOYMENTS
|
|
||||||
for deployment in healthy_deployments:
|
|
||||||
deployment_name = deployment["litellm_params"]["model"]
|
|
||||||
if deployment_name in cooldown_deployments:
|
|
||||||
deployments_to_remove.append(deployment)
|
|
||||||
### FILTER OUT UNHEALTHY DEPLOYMENTS
|
|
||||||
for deployment in deployments_to_remove:
|
|
||||||
healthy_deployments.remove(deployment)
|
|
||||||
self.print_verbose(f"healthy deployments: {healthy_deployments}")
|
|
||||||
if litellm.model_alias_map and model in litellm.model_alias_map:
|
|
||||||
model = litellm.model_alias_map[
|
|
||||||
model
|
|
||||||
] # update the model to the actual value if an alias has been passed in
|
|
||||||
if self.routing_strategy == "least-busy":
|
|
||||||
if len(self.healthy_deployments) > 0:
|
|
||||||
for item in self.healthy_deployments:
|
|
||||||
if item[0]["model_name"] == model: # first one in queue will be the one with the most availability
|
|
||||||
return item[0]
|
|
||||||
else:
|
|
||||||
raise ValueError("No models available.")
|
|
||||||
elif self.routing_strategy == "simple-shuffle":
|
|
||||||
item = random.choice(healthy_deployments)
|
|
||||||
return item or item[0]
|
|
||||||
elif self.routing_strategy == "latency-based-routing":
|
|
||||||
returned_item = None
|
|
||||||
lowest_latency = float('inf')
|
|
||||||
### shuffles with priority for lowest latency
|
|
||||||
# items_with_latencies = [('A', 10), ('B', 20), ('C', 30), ('D', 40)]
|
|
||||||
items_with_latencies = []
|
|
||||||
for item in healthy_deployments:
|
|
||||||
items_with_latencies.append((item, self.deployment_latency_map[item["litellm_params"]["model"]]))
|
|
||||||
returned_item = self.weighted_shuffle_by_latency(items_with_latencies)
|
|
||||||
return returned_item
|
|
||||||
elif self.routing_strategy == "usage-based-routing":
|
|
||||||
return self.get_usage_based_available_deployment(model=model, messages=messages, input=input)
|
|
||||||
|
|
||||||
raise ValueError("No models available.")
|
|
||||||
|
|
||||||
def retry_if_rate_limit_error(self, exception):
|
|
||||||
return isinstance(exception, openai.RateLimitError)
|
|
||||||
|
|
||||||
def retry_if_api_error(self, exception):
|
|
||||||
return isinstance(exception, openai.APIError)
|
|
||||||
|
|
||||||
async def async_function_with_retries(self, *args, **kwargs):
|
|
||||||
# we'll backoff exponentially with each retry
|
|
||||||
backoff_factor = 1
|
|
||||||
original_exception = kwargs.pop("original_exception")
|
|
||||||
original_function = kwargs.pop("original_function")
|
|
||||||
for current_attempt in range(self.num_retries):
|
|
||||||
try:
|
|
||||||
# if the function call is successful, no exception will be raised and we'll break out of the loop
|
|
||||||
response = await original_function(*args, **kwargs)
|
|
||||||
if inspect.iscoroutinefunction(response): # async errors are often returned as coroutines
|
|
||||||
response = await response
|
|
||||||
return response
|
|
||||||
|
|
||||||
except openai.RateLimitError as e:
|
|
||||||
# on RateLimitError we'll wait for an exponential time before trying again
|
|
||||||
await asyncio.sleep(backoff_factor)
|
|
||||||
|
|
||||||
# increase backoff factor for next run
|
|
||||||
backoff_factor *= 2
|
|
||||||
|
|
||||||
except openai.APIError as e:
|
|
||||||
# on APIError we immediately retry without any wait, change this if necessary
|
|
||||||
pass
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
# for any other exception types, don't retry
|
|
||||||
raise e
|
|
||||||
|
|
||||||
def function_with_retries(self, *args, **kwargs):
|
|
||||||
# we'll backoff exponentially with each retry
|
|
||||||
self.print_verbose(f"Inside function with retries: args - {args}; kwargs - {kwargs}")
|
|
||||||
backoff_factor = 1
|
|
||||||
original_function = kwargs.pop("original_function")
|
|
||||||
num_retries = kwargs.pop("num_retries")
|
|
||||||
try:
|
|
||||||
# if the function call is successful, no exception will be raised and we'll break out of the loop
|
|
||||||
response = original_function(*args, **kwargs)
|
|
||||||
return response
|
|
||||||
except Exception as e:
|
|
||||||
for current_attempt in range(num_retries):
|
|
||||||
num_retries -= 1 # decrement the number of retries
|
|
||||||
self.print_verbose(f"retrying request. Current attempt - {current_attempt}; num retries: {num_retries}")
|
|
||||||
try:
|
|
||||||
# if the function call is successful, no exception will be raised and we'll break out of the loop
|
|
||||||
response = original_function(*args, **kwargs)
|
|
||||||
return response
|
|
||||||
|
|
||||||
except openai.RateLimitError as e:
|
|
||||||
if num_retries > 0:
|
|
||||||
# on RateLimitError we'll wait for an exponential time before trying again
|
|
||||||
time.sleep(backoff_factor)
|
|
||||||
|
|
||||||
# increase backoff factor for next run
|
|
||||||
backoff_factor *= 2
|
|
||||||
else:
|
|
||||||
raise e
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
# for any other exception types, immediately retry
|
|
||||||
if num_retries > 0:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
raise e
|
|
||||||
|
|
||||||
### COMPLETION + EMBEDDING FUNCTIONS
|
### COMPLETION + EMBEDDING FUNCTIONS
|
||||||
|
|
||||||
def completion(self,
|
def completion(self,
|
||||||
|
@ -300,7 +143,12 @@ class Router:
|
||||||
kwargs["messages"] = messages
|
kwargs["messages"] = messages
|
||||||
kwargs["original_function"] = self._completion
|
kwargs["original_function"] = self._completion
|
||||||
kwargs["num_retries"] = self.num_retries
|
kwargs["num_retries"] = self.num_retries
|
||||||
return self.function_with_retries(**kwargs)
|
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
|
||||||
|
# Submit the function to the executor with a timeout
|
||||||
|
future = executor.submit(self.function_with_fallbacks, **kwargs)
|
||||||
|
response = future.result(timeout=self.timeout)
|
||||||
|
|
||||||
|
return response
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
@ -322,18 +170,38 @@ class Router:
|
||||||
return litellm.completion(**{**data, "messages": messages, "caching": self.cache_responses, **kwargs})
|
return litellm.completion(**{**data, "messages": messages, "caching": self.cache_responses, **kwargs})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
async def acompletion(self,
|
async def acompletion(self,
|
||||||
model: str,
|
model: str,
|
||||||
messages: List[Dict[str, str]],
|
messages: List[Dict[str, str]],
|
||||||
is_retry: Optional[bool] = False,
|
**kwargs):
|
||||||
is_fallback: Optional[bool] = False,
|
try:
|
||||||
**kwargs):
|
kwargs["model"] = model
|
||||||
|
kwargs["messages"] = messages
|
||||||
|
kwargs["original_function"] = self._completion
|
||||||
|
kwargs["num_retries"] = self.num_retries
|
||||||
|
|
||||||
|
# Use asyncio.timeout to enforce the timeout
|
||||||
|
async with asyncio.timeout(self.timeout):
|
||||||
|
response = await self.async_function_with_fallbacks(**kwargs)
|
||||||
|
|
||||||
|
return response
|
||||||
|
except Exception as e:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
async def _acompletion(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
messages: List[Dict[str, str]],
|
||||||
|
**kwargs):
|
||||||
try:
|
try:
|
||||||
deployment = self.get_available_deployment(model=model, messages=messages)
|
deployment = self.get_available_deployment(model=model, messages=messages)
|
||||||
data = deployment["litellm_params"]
|
data = deployment["litellm_params"]
|
||||||
for k, v in self.default_litellm_params.items():
|
for k, v in self.default_litellm_params.items():
|
||||||
if k not in data: # prioritize model-specific params > default router params
|
if k not in data: # prioritize model-specific params > default router params
|
||||||
data[k] = v
|
data[k] = v
|
||||||
|
self.print_verbose(f"acompletion model: {data['model']}")
|
||||||
|
|
||||||
response = await litellm.acompletion(**{**data, "messages": messages, "caching": self.cache_responses, **kwargs})
|
response = await litellm.acompletion(**{**data, "messages": messages, "caching": self.cache_responses, **kwargs})
|
||||||
return response
|
return response
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -345,7 +213,7 @@ class Router:
|
||||||
return await self.async_function_with_retries(**kwargs)
|
return await self.async_function_with_retries(**kwargs)
|
||||||
else:
|
else:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def text_completion(self,
|
def text_completion(self,
|
||||||
model: str,
|
model: str,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
|
@ -403,6 +271,190 @@ class Router:
|
||||||
data[k] = v
|
data[k] = v
|
||||||
return await litellm.aembedding(**{**data, "input": input, "caching": self.cache_responses, **kwargs})
|
return await litellm.aembedding(**{**data, "input": input, "caching": self.cache_responses, **kwargs})
|
||||||
|
|
||||||
|
async def async_function_with_fallbacks(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Try calling the function_with_retries
|
||||||
|
If it fails after num_retries, fall back to another model group
|
||||||
|
"""
|
||||||
|
model_group = kwargs.get("model")
|
||||||
|
try:
|
||||||
|
response = await self.async_function_with_retries(*args, **kwargs)
|
||||||
|
self.print_verbose(f'Response: {response}')
|
||||||
|
return response
|
||||||
|
except Exception as e:
|
||||||
|
self.print_verbose(f"An exception occurs")
|
||||||
|
original_exception = e
|
||||||
|
try:
|
||||||
|
self.print_verbose(f"Trying to fallback b/w models")
|
||||||
|
if isinstance(e, litellm.ContextWindowExceededError):
|
||||||
|
for item in self.context_window_fallback_model_group: # [{"gpt-3.5-turbo": ["gpt-4"]}]
|
||||||
|
if list(item.keys())[0] == model_group:
|
||||||
|
fallback_model_group = item[model_group]
|
||||||
|
break
|
||||||
|
for mg in fallback_model_group:
|
||||||
|
"""
|
||||||
|
Iterate through the model groups and try calling that deployment
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
kwargs["model"] = mg
|
||||||
|
response = await self.async_function_with_retries(*args, **kwargs)
|
||||||
|
return response
|
||||||
|
except Exception as e:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
self.print_verbose(f"inside model fallbacks: {self.fallbacks}")
|
||||||
|
for item in self.fallbacks:
|
||||||
|
if list(item.keys())[0] == model_group:
|
||||||
|
fallback_model_group = item[model_group]
|
||||||
|
break
|
||||||
|
for mg in fallback_model_group:
|
||||||
|
"""
|
||||||
|
Iterate through the model groups and try calling that deployment
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
kwargs["model"] = mg
|
||||||
|
response = await self.async_function_with_retries(*args, **kwargs)
|
||||||
|
return response
|
||||||
|
except Exception as e:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
self.print_verbose(f"An exception occurred - {str(e)}")
|
||||||
|
traceback.print_exc()
|
||||||
|
raise original_exception
|
||||||
|
|
||||||
|
async def async_function_with_retries(self, *args, **kwargs):
|
||||||
|
self.print_verbose(f"Inside async function with retries: args - {args}; kwargs - {kwargs}")
|
||||||
|
backoff_factor = 1
|
||||||
|
original_function = kwargs.pop("original_function")
|
||||||
|
num_retries = kwargs.pop("num_retries")
|
||||||
|
try:
|
||||||
|
# if the function call is successful, no exception will be raised and we'll break out of the loop
|
||||||
|
response = await original_function(*args, **kwargs)
|
||||||
|
return response
|
||||||
|
except Exception as e:
|
||||||
|
for current_attempt in range(num_retries):
|
||||||
|
num_retries -= 1 # decrement the number of retries
|
||||||
|
self.print_verbose(f"retrying request. Current attempt - {current_attempt}; num retries: {num_retries}")
|
||||||
|
try:
|
||||||
|
# if the function call is successful, no exception will be raised and we'll break out of the loop
|
||||||
|
response = await original_function(*args, **kwargs)
|
||||||
|
if inspect.iscoroutinefunction(response): # async errors are often returned as coroutines
|
||||||
|
response = await response
|
||||||
|
return response
|
||||||
|
|
||||||
|
except openai.RateLimitError as e:
|
||||||
|
if num_retries > 0:
|
||||||
|
# on RateLimitError we'll wait for an exponential time before trying again
|
||||||
|
await asyncio.sleep(backoff_factor)
|
||||||
|
|
||||||
|
# increase backoff factor for next run
|
||||||
|
backoff_factor *= 2
|
||||||
|
else:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# for any other exception types, immediately retry
|
||||||
|
if num_retries > 0:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def function_with_fallbacks(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Try calling the function_with_retries
|
||||||
|
If it fails after num_retries, fall back to another model group
|
||||||
|
"""
|
||||||
|
model_group = kwargs.get("model")
|
||||||
|
try:
|
||||||
|
response = self.function_with_retries(*args, **kwargs)
|
||||||
|
self.print_verbose(f'Response: {response}')
|
||||||
|
return response
|
||||||
|
except Exception as e:
|
||||||
|
self.print_verbose(f"An exception occurs")
|
||||||
|
original_exception = e
|
||||||
|
try:
|
||||||
|
self.print_verbose(f"Trying to fallback b/w models")
|
||||||
|
if isinstance(e, litellm.ContextWindowExceededError):
|
||||||
|
for item in self.context_window_fallback_model_group: # [{"gpt-3.5-turbo": ["gpt-4"]}]
|
||||||
|
if list(item.keys())[0] == model_group:
|
||||||
|
fallback_model_group = item[model_group]
|
||||||
|
break
|
||||||
|
for mg in fallback_model_group:
|
||||||
|
"""
|
||||||
|
Iterate through the model groups and try calling that deployment
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
kwargs["model"] = mg
|
||||||
|
response = self.function_with_retries(*args, **kwargs)
|
||||||
|
return response
|
||||||
|
except Exception as e:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
self.print_verbose(f"inside model fallbacks: {self.fallbacks}")
|
||||||
|
for item in self.fallbacks:
|
||||||
|
if list(item.keys())[0] == model_group:
|
||||||
|
fallback_model_group = item[model_group]
|
||||||
|
break
|
||||||
|
for mg in fallback_model_group:
|
||||||
|
"""
|
||||||
|
Iterate through the model groups and try calling that deployment
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
kwargs["model"] = mg
|
||||||
|
response = self.function_with_retries(*args, **kwargs)
|
||||||
|
return response
|
||||||
|
except Exception as e:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
self.print_verbose(f"An exception occurred - {str(e)}")
|
||||||
|
traceback.print_exc()
|
||||||
|
raise original_exception
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def function_with_retries(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Try calling the model 3 times. Shuffle between available deployments.
|
||||||
|
"""
|
||||||
|
self.print_verbose(f"Inside function with retries: args - {args}; kwargs - {kwargs}")
|
||||||
|
backoff_factor = 1
|
||||||
|
original_function = kwargs.pop("original_function")
|
||||||
|
num_retries = kwargs.pop("num_retries")
|
||||||
|
try:
|
||||||
|
# if the function call is successful, no exception will be raised and we'll break out of the loop
|
||||||
|
response = original_function(*args, **kwargs)
|
||||||
|
return response
|
||||||
|
except Exception as e:
|
||||||
|
self.print_verbose(f"num retries in function with retries: {num_retries}")
|
||||||
|
for current_attempt in range(num_retries):
|
||||||
|
num_retries -= 1 # decrement the number of retries
|
||||||
|
self.print_verbose(f"retrying request. Current attempt - {current_attempt}; num retries: {num_retries}")
|
||||||
|
try:
|
||||||
|
# if the function call is successful, no exception will be raised and we'll break out of the loop
|
||||||
|
response = original_function(*args, **kwargs)
|
||||||
|
return response
|
||||||
|
|
||||||
|
except openai.RateLimitError as e:
|
||||||
|
if num_retries > 0:
|
||||||
|
# on RateLimitError we'll wait for an exponential time before trying again
|
||||||
|
time.sleep(backoff_factor)
|
||||||
|
|
||||||
|
# increase backoff factor for next run
|
||||||
|
backoff_factor *= 2
|
||||||
|
else:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# for any other exception types, immediately retry
|
||||||
|
if num_retries > 0:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
raise e
|
||||||
|
if self.num_retries == 0:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
### HELPER FUNCTIONS
|
||||||
|
|
||||||
def deployment_callback(
|
def deployment_callback(
|
||||||
self,
|
self,
|
||||||
kwargs, # kwargs to completion
|
kwargs, # kwargs to completion
|
||||||
|
@ -433,12 +485,15 @@ class Router:
|
||||||
completion_response, # response from completion
|
completion_response, # response from completion
|
||||||
start_time, end_time # start/end time
|
start_time, end_time # start/end time
|
||||||
):
|
):
|
||||||
model_name = kwargs.get('model', None) # i.e. gpt35turbo
|
try:
|
||||||
custom_llm_provider = kwargs.get("litellm_params", {}).get('custom_llm_provider', None) # i.e. azure
|
model_name = kwargs.get('model', None) # i.e. gpt35turbo
|
||||||
if custom_llm_provider:
|
custom_llm_provider = kwargs.get("litellm_params", {}).get('custom_llm_provider', None) # i.e. azure
|
||||||
model_name = f"{custom_llm_provider}/{model_name}"
|
if custom_llm_provider:
|
||||||
|
model_name = f"{custom_llm_provider}/{model_name}"
|
||||||
self._set_cooldown_deployments(model_name)
|
|
||||||
|
self._set_cooldown_deployments(model_name)
|
||||||
|
except Exception as e:
|
||||||
|
raise e
|
||||||
|
|
||||||
def _set_cooldown_deployments(self,
|
def _set_cooldown_deployments(self,
|
||||||
deployment: str):
|
deployment: str):
|
||||||
|
@ -577,4 +632,123 @@ class Router:
|
||||||
# Update usage
|
# Update usage
|
||||||
# ------------
|
# ------------
|
||||||
self.increment(tpm_key, total_tokens)
|
self.increment(tpm_key, total_tokens)
|
||||||
self.increment(rpm_key, 1)
|
self.increment(rpm_key, 1)
|
||||||
|
|
||||||
|
def _start_health_check_thread(self):
|
||||||
|
"""
|
||||||
|
Starts a separate thread to perform health checks periodically.
|
||||||
|
"""
|
||||||
|
health_check_thread = threading.Thread(target=self._perform_health_checks, daemon=True)
|
||||||
|
health_check_thread.start()
|
||||||
|
|
||||||
|
def _perform_health_checks(self):
|
||||||
|
"""
|
||||||
|
Periodically performs health checks on the servers.
|
||||||
|
Updates the list of healthy servers accordingly.
|
||||||
|
"""
|
||||||
|
while True:
|
||||||
|
self.healthy_deployments = self._health_check()
|
||||||
|
# Adjust the time interval based on your needs
|
||||||
|
time.sleep(15)
|
||||||
|
|
||||||
|
def _health_check(self):
|
||||||
|
"""
|
||||||
|
Performs a health check on the deployments
|
||||||
|
Returns the list of healthy deployments
|
||||||
|
"""
|
||||||
|
healthy_deployments = []
|
||||||
|
for deployment in self.model_list:
|
||||||
|
litellm_args = deployment["litellm_params"]
|
||||||
|
try:
|
||||||
|
start_time = time.time()
|
||||||
|
litellm.completion(messages=[{"role": "user", "content": ""}], max_tokens=1, **litellm_args) # hit the server with a blank message to see how long it takes to respond
|
||||||
|
end_time = time.time()
|
||||||
|
response_time = end_time - start_time
|
||||||
|
logging.debug(f"response_time: {response_time}")
|
||||||
|
healthy_deployments.append((deployment, response_time))
|
||||||
|
healthy_deployments.sort(key=lambda x: x[1])
|
||||||
|
except Exception as e:
|
||||||
|
pass
|
||||||
|
return healthy_deployments
|
||||||
|
|
||||||
|
def weighted_shuffle_by_latency(self, items):
|
||||||
|
# Sort the items by latency
|
||||||
|
sorted_items = sorted(items, key=lambda x: x[1])
|
||||||
|
# Get only the latencies
|
||||||
|
latencies = [i[1] for i in sorted_items]
|
||||||
|
# Calculate the sum of all latencies
|
||||||
|
total_latency = sum(latencies)
|
||||||
|
# Calculate the weight for each latency (lower latency = higher weight)
|
||||||
|
weights = [total_latency-latency for latency in latencies]
|
||||||
|
# Get a weighted random item
|
||||||
|
if sum(weights) == 0:
|
||||||
|
chosen_item = random.choice(sorted_items)[0]
|
||||||
|
else:
|
||||||
|
chosen_item = random.choices(sorted_items, weights=weights, k=1)[0][0]
|
||||||
|
return chosen_item
|
||||||
|
|
||||||
|
def set_model_list(self, model_list: list):
|
||||||
|
self.model_list = model_list
|
||||||
|
self.model_names = [m["model_name"] for m in model_list]
|
||||||
|
|
||||||
|
def get_model_names(self):
|
||||||
|
return self.model_names
|
||||||
|
|
||||||
|
def print_verbose(self, print_statement):
|
||||||
|
if self.set_verbose:
|
||||||
|
print(f"LiteLLM.Router: {print_statement}") # noqa
|
||||||
|
|
||||||
|
def get_available_deployment(self,
|
||||||
|
model: str,
|
||||||
|
messages: Optional[List[Dict[str, str]]] = None,
|
||||||
|
input: Optional[Union[str, List]] = None):
|
||||||
|
"""
|
||||||
|
Returns the deployment based on routing strategy
|
||||||
|
"""
|
||||||
|
## get healthy deployments
|
||||||
|
### get all deployments
|
||||||
|
### filter out the deployments currently cooling down
|
||||||
|
healthy_deployments = [m for m in self.model_list if m["model_name"] == model]
|
||||||
|
deployments_to_remove = []
|
||||||
|
cooldown_deployments = self._get_cooldown_deployments()
|
||||||
|
self.print_verbose(f"cooldown deployments: {cooldown_deployments}")
|
||||||
|
### FIND UNHEALTHY DEPLOYMENTS
|
||||||
|
for deployment in healthy_deployments:
|
||||||
|
deployment_name = deployment["litellm_params"]["model"]
|
||||||
|
if deployment_name in cooldown_deployments:
|
||||||
|
deployments_to_remove.append(deployment)
|
||||||
|
### FILTER OUT UNHEALTHY DEPLOYMENTS
|
||||||
|
for deployment in deployments_to_remove:
|
||||||
|
healthy_deployments.remove(deployment)
|
||||||
|
self.print_verbose(f"healthy deployments: {healthy_deployments}")
|
||||||
|
if len(healthy_deployments) == 0:
|
||||||
|
raise ValueError("No models available")
|
||||||
|
if litellm.model_alias_map and model in litellm.model_alias_map:
|
||||||
|
model = litellm.model_alias_map[
|
||||||
|
model
|
||||||
|
] # update the model to the actual value if an alias has been passed in
|
||||||
|
if self.routing_strategy == "least-busy":
|
||||||
|
if len(self.healthy_deployments) > 0:
|
||||||
|
for item in self.healthy_deployments:
|
||||||
|
if item[0]["model_name"] == model: # first one in queue will be the one with the most availability
|
||||||
|
return item[0]
|
||||||
|
else:
|
||||||
|
raise ValueError("No models available.")
|
||||||
|
elif self.routing_strategy == "simple-shuffle":
|
||||||
|
item = random.choice(healthy_deployments)
|
||||||
|
return item or item[0]
|
||||||
|
elif self.routing_strategy == "latency-based-routing":
|
||||||
|
returned_item = None
|
||||||
|
lowest_latency = float('inf')
|
||||||
|
### shuffles with priority for lowest latency
|
||||||
|
# items_with_latencies = [('A', 10), ('B', 20), ('C', 30), ('D', 40)]
|
||||||
|
items_with_latencies = []
|
||||||
|
for item in healthy_deployments:
|
||||||
|
items_with_latencies.append((item, self.deployment_latency_map[item["litellm_params"]["model"]]))
|
||||||
|
returned_item = self.weighted_shuffle_by_latency(items_with_latencies)
|
||||||
|
return returned_item
|
||||||
|
elif self.routing_strategy == "usage-based-routing":
|
||||||
|
return self.get_usage_based_available_deployment(model=model, messages=messages, input=input)
|
||||||
|
|
||||||
|
raise ValueError("No models available.")
|
||||||
|
|
76
litellm/tests/test_router_fallbacks.py
Normal file
76
litellm/tests/test_router_fallbacks.py
Normal file
|
@ -0,0 +1,76 @@
|
||||||
|
#### What this tests ####
|
||||||
|
# This tests calling router with fallback models
|
||||||
|
|
||||||
|
import sys, os, time
|
||||||
|
import traceback, asyncio
|
||||||
|
import pytest
|
||||||
|
sys.path.insert(
|
||||||
|
0, os.path.abspath("../..")
|
||||||
|
) # Adds the parent directory to the system path
|
||||||
|
|
||||||
|
import litellm
|
||||||
|
from litellm import Router
|
||||||
|
|
||||||
|
model_list = [
|
||||||
|
{ # list of model deployments
|
||||||
|
"model_name": "azure/gpt-3.5-turbo", # openai model name
|
||||||
|
"litellm_params": { # params for litellm completion/embedding call
|
||||||
|
"model": "azure/chatgpt-v-2",
|
||||||
|
"api_key": "bad-key",
|
||||||
|
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||||
|
"api_base": os.getenv("AZURE_API_BASE")
|
||||||
|
},
|
||||||
|
"tpm": 240000,
|
||||||
|
"rpm": 1800
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"model_name": "azure/gpt-3.5-turbo", # openai model name
|
||||||
|
"litellm_params": { # params for litellm completion/embedding call
|
||||||
|
"model": "azure/chatgpt-functioncalling",
|
||||||
|
"api_key": "bad-key",
|
||||||
|
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||||
|
"api_base": os.getenv("AZURE_API_BASE")
|
||||||
|
},
|
||||||
|
"tpm": 240000,
|
||||||
|
"rpm": 1800
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"model_name": "gpt-3.5-turbo", # openai model name
|
||||||
|
"litellm_params": { # params for litellm completion/embedding call
|
||||||
|
"model": "gpt-3.5-turbo",
|
||||||
|
"api_key": os.getenv("OPENAI_API_KEY"),
|
||||||
|
},
|
||||||
|
"tpm": 1000000,
|
||||||
|
"rpm": 9000
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
router = Router(model_list=model_list, fallbacks=[{"azure/gpt-3.5-turbo": ["gpt-3.5-turbo"]}])
|
||||||
|
|
||||||
|
kwargs = {"model": "azure/gpt-3.5-turbo", "messages": [{"role": "user", "content":"Hey, how's it going?"}]}
|
||||||
|
|
||||||
|
def test_sync_fallbacks():
|
||||||
|
try:
|
||||||
|
response = router.completion(**kwargs)
|
||||||
|
print(f"response: {response}")
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
|
||||||
|
def test_async_fallbacks():
|
||||||
|
async def test_get_response():
|
||||||
|
user_message = "Hello, how are you?"
|
||||||
|
messages = [{"content": user_message, "role": "user"}]
|
||||||
|
try:
|
||||||
|
response = await router.acompletion(**kwargs)
|
||||||
|
# response = await response
|
||||||
|
print(f"response: {response}")
|
||||||
|
except litellm.Timeout as e:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(f"An exception occurred: {e}")
|
||||||
|
|
||||||
|
asyncio.run(test_get_response())
|
||||||
|
|
||||||
|
test_async_fallbacks()
|
Loading…
Add table
Add a link
Reference in a new issue