mirror of https://github.com/BerriAI/litellm.git
fix(router.py): add support for context window fallbacks on router
parent 49a6ebfa30
commit e4deb09eb6

3 changed files with 65 additions and 104 deletions
File 1 of 3:

@@ -1,10 +1,11 @@
 import litellm
 from litellm import ModelResponse
 from proxy_server import llm_model_list
+from typing import Optional

 def track_cost_callback(
     kwargs,  # kwargs to completion
-    completion_response: ModelResponse = None,  # response from completion
+    completion_response: ModelResponse,  # response from completion
     start_time = None,
     end_time = None,  # start/end time for completion
 ):
@@ -34,98 +35,4 @@ def track_cost_callback(
             response_cost = litellm.completion_cost(completion_response=completion_response)
             print("regular response_cost", response_cost)
     except:
         pass
-
-# 1. `--experimental_async` starts 2 background threads:
-#    - 1. to check the redis queue:
-#      - if job available
-#      - it dequeues as many jobs as healthy endpoints
-#      - calls llm api -> saves response in redis cache
-#    - 2. to check the llm apis:
-#      - check if endpoints are healthy (unhealthy = 4xx / 5xx call or >1min. queue)
-#      - which one is least busy
-# 2. /router/chat/completions: receives request -> adds to redis queue -> returns {run_id, started_at, request_obj}
-# 3. /router/chat/completions/runs/{run_id}: returns {status: _, [optional] response_obj: _}
-# """
-
-# def _start_health_check_thread():
-#     """
-#     Starts a separate thread to perform health checks periodically.
-#     """
-#     health_check_thread = threading.Thread(target=_perform_health_checks, daemon=True)
-#     health_check_thread.start()
-#     llm_call_thread = threading.Thread(target=_llm_call_thread, daemon=True)
-#     llm_call_thread.start()
-
-
-# def _llm_call_thread():
-#     """
-#     Periodically performs job checks on the redis queue.
-#     If available, make llm api calls.
-#     Write result to redis cache (1 min ttl)
-#     """
-#     with concurrent.futures.ThreadPoolExecutor() as executor:
-#         while True:
-#             job_checks = _job_check()
-#             future_to_job = {executor.submit(_llm_api_call, job): job for job in job_checks}
-#             for future in concurrent.futures.as_completed(future_to_job):
-#                 job = future_to_job[future]
-#                 try:
-#                     result = future.result()
-#                 except Exception as exc:
-#                     print(f'{job} generated an exception: {exc}')
-#                 else:
-#                     _write_to_cache(job, result, ttl=1*60)
-#             time.sleep(1)  # sleep 1 second to avoid overloading the server
-
-
-
-# def _perform_health_checks():
-#     """
-#     Periodically performs health checks on the servers.
-#     Updates the list of healthy servers accordingly.
-#     """
-#     while True:
-#         healthy_deployments = _health_check()
-#         # Adjust the time interval based on your needs
-#         time.sleep(15)
-
-# def _job_check():
-#     """
-#     Periodically performs job checks on the redis queue.
-#     Returns the list of available jobs - len(available_jobs) == len(healthy_endpoints),
-#     e.g. don't dequeue a gpt-3.5-turbo job if there's no healthy deployments left
-#     """
-#     pass
-
-# def _llm_api_call(**data):
-#     """
-#     Makes the litellm.completion() call with 3 retries
-#     """
-#     return litellm.completion(num_retries=3, **data)
-
-# def _write_to_cache():
-#     """
-#     Writes the result to a redis cache in the form (key:job_id, value: <response_object>)
-#     """
-#     pass
-
-# def _health_check():
-#     """
-#     Performs a health check on the deployments
-#     Returns the list of healthy deployments
-#     """
-#     healthy_deployments = []
-#     for deployment in model_list:
-#         litellm_args = deployment["litellm_params"]
-#         try:
-#             start_time = time.time()
-#             litellm.completion(messages=[{"role": "user", "content": ""}], max_tokens=1, **litellm_args)  # hit the server with a blank message to see how long it takes to respond
-#             end_time = time.time()
-#             response_time = end_time - start_time
-#             logging.debug(f"response_time: {response_time}")
-#             healthy_deployments.append((deployment, response_time))
-#             healthy_deployments.sort(key=lambda x: x[1])
-#         except Exception as e:
-#             pass
-#     return healthy_deployments
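For orientation: the signature touched in the first hunk follows litellm's custom success-callback interface (a function taking the call kwargs, the completion response, and start/end times). The sketch below is a minimal, hypothetical wiring of such a callback outside the proxy; it is not part of this commit, and it assumes an OPENAI_API_KEY is available.

    import litellm

    def track_cost_callback(kwargs, completion_response, start_time=None, end_time=None):
        # compute the spend for the completed call from the response object
        cost = litellm.completion_cost(completion_response=completion_response)
        print("response_cost", cost)

    # register the function so litellm runs it after each successful completion
    litellm.success_callback = [track_cost_callback]

    litellm.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hi"}],
    )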
File 2 of 3 (router.py):

@@ -280,10 +280,15 @@ class Router:
             try:
                 self.print_verbose(f"Trying to fallback b/w models")
                 if isinstance(e, litellm.ContextWindowExceededError):
-                    for item in self.context_window_fallback_model_group: # [{"gpt-3.5-turbo": ["gpt-4"]}]
+                    fallback_model_group = None
+                    for item in self.context_window_fallbacks: # [{"gpt-3.5-turbo": ["gpt-4"]}]
                         if list(item.keys())[0] == model_group:
                             fallback_model_group = item[model_group]
                             break
+
+                    if fallback_model_group is None:
+                        raise original_exception
+
                     for mg in fallback_model_group:
                         """
                         Iterate through the model groups and try calling that deployment
@@ -360,6 +365,7 @@ class Router:
         If it fails after num_retries, fall back to another model group
         """
         model_group = kwargs.get("model")
+
         try:
             response = self.function_with_retries(*args, **kwargs)
             self.print_verbose(f'Response: {response}')
@@ -368,36 +374,47 @@ class Router:
             original_exception = e
             self.print_verbose(f"An exception occurs{original_exception}")
             try:
-                self.print_verbose(f"Trying to fallback b/w models")
-                fallback_model_group = []
+                self.print_verbose(f"Trying to fallback b/w models. Initial model group: {model_group}")
+                self.print_verbose(f"Type of exception: {type(e)}; error_message: {str(e)}")
                 if isinstance(e, litellm.ContextWindowExceededError):
-                    for item in self.context_window_fallback_model_group: # [{"gpt-3.5-turbo": ["gpt-4"]}]
+                    self.print_verbose(f"inside context window fallbacks: {self.context_window_fallbacks}")
+                    fallback_model_group = None
+                    for item in self.context_window_fallbacks: # [{"gpt-3.5-turbo": ["gpt-4"]}]
                         if list(item.keys())[0] == model_group:
                             fallback_model_group = item[model_group]
                             break
+
+                    if fallback_model_group is None:
+                        raise original_exception
+
                     for mg in fallback_model_group:
                         """
                         Iterate through the model groups and try calling that deployment
                         """
                         try:
                             kwargs["model"] = mg
-                            response = self.function_with_retries(*args, **kwargs)
+                            response = self.function_with_fallbacks(*args, **kwargs)
                             return response
                         except Exception as e:
                             pass
                 else:
                     self.print_verbose(f"inside model fallbacks: {self.fallbacks}")
+                    fallback_model_group = None
                     for item in self.fallbacks:
                         if list(item.keys())[0] == model_group:
                             fallback_model_group = item[model_group]
                             break
+
+                    if fallback_model_group is None:
+                        raise original_exception
+
                     for mg in fallback_model_group:
                         """
                         Iterate through the model groups and try calling that deployment
                         """
                         try:
                             kwargs["model"] = mg
-                            response = self.function_with_retries(*args, **kwargs)
+                            response = self.function_with_fallbacks(*args, **kwargs)
                             return response
                         except Exception as e:
                             pass
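Both fallbacks and the new context_window_fallbacks are lists of single-key dicts that map a model group to the groups to try next, and the router resolves them with a first-match scan, re-raising the original exception when nothing matches. A standalone sketch of that lookup (the helper name is illustrative; router.py inlines this logic rather than calling a helper):

    from typing import Dict, List, Optional

    def find_fallback_group(model_group: str, fallbacks: List[Dict[str, List[str]]]) -> Optional[List[str]]:
        # fallbacks look like [{"gpt-3.5-turbo": ["gpt-4"]}]; return the first matching group list, else None
        for item in fallbacks:
            if list(item.keys())[0] == model_group:
                return item[model_group]
        return None

    context_window_fallbacks = [{"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]}]
    print(find_fallback_group("gpt-3.5-turbo", context_window_fallbacks))  # ['gpt-3.5-turbo-16k']
    print(find_fallback_group("claude-2", context_window_fallbacks))       # None -> the router re-raises the original exception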
File 3 of 3 (router fallback tests):

@@ -23,6 +23,17 @@ model_list = [
         "tpm": 240000,
         "rpm": 1800
     },
+    { # list of model deployments
+        "model_name": "azure/gpt-3.5-turbo-context-fallback", # openai model name
+        "litellm_params": { # params for litellm completion/embedding call
+            "model": "azure/chatgpt-v-2",
+            "api_key": "bad-key",
+            "api_version": os.getenv("AZURE_API_VERSION"),
+            "api_base": os.getenv("AZURE_API_BASE")
+        },
+        "tpm": 240000,
+        "rpm": 1800
+    },
     {
         "model_name": "azure/gpt-3.5-turbo", # openai model name
         "litellm_params": { # params for litellm completion/embedding call
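The deployment added above relies on the Router's grouping scheme: model_name is the group callers address (and the key used in the fallback maps), while litellm_params describes the concrete endpoint behind it, so several deployments can share one group. A minimal, hypothetical illustration; the keys and endpoint below are placeholders, not values from this commit:

    from litellm import Router

    model_list = [
        # two deployments in the same "gpt-3.5-turbo" group; the router balances traffic between them
        {"model_name": "gpt-3.5-turbo",
         "litellm_params": {"model": "gpt-3.5-turbo", "api_key": "sk-..."}},
        {"model_name": "gpt-3.5-turbo",
         "litellm_params": {"model": "azure/chatgpt-v-2", "api_key": "...", "api_base": "https://example.openai.azure.com"}},
    ]

    router = Router(model_list=model_list)
    # callers address the group name; the router picks a deployment:
    # response = router.completion(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "hi"}])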
@@ -42,21 +53,35 @@ model_list = [
         },
         "tpm": 1000000,
         "rpm": 9000
+    },
+    {
+        "model_name": "gpt-3.5-turbo-16k", # openai model name
+        "litellm_params": { # params for litellm completion/embedding call
+            "model": "gpt-3.5-turbo-16k",
+            "api_key": os.getenv("OPENAI_API_KEY"),
+        },
+        "tpm": 1000000,
+        "rpm": 9000
     }
 ]



-router = Router(model_list=model_list, fallbacks=[{"azure/gpt-3.5-turbo": ["gpt-3.5-turbo"]}])
+router = Router(model_list=model_list,
+                fallbacks=[{"azure/gpt-3.5-turbo": ["gpt-3.5-turbo"]}],
+                context_window_fallbacks=[{"azure/gpt-3.5-turbo-context-fallback": ["gpt-3.5-turbo-16k"]}, {"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]}],
+                set_verbose=True)

 kwargs = {"model": "azure/gpt-3.5-turbo", "messages": [{"role": "user", "content":"Hey, how's it going?"}]}

 def test_sync_fallbacks():
     try:
+        litellm.set_verbose = True
         response = router.completion(**kwargs)
         print(f"response: {response}")
     except Exception as e:
         print(e)
+test_sync_fallbacks()

 def test_async_fallbacks():
     litellm.set_verbose = False
@@ -74,4 +99,16 @@ def test_async_fallbacks():

     asyncio.run(test_get_response())

 # test_async_fallbacks()
+
+def test_sync_context_window_fallbacks():
+    try:
+        sample_text = "Say error 50 times" * 10000
+        kwargs["model"] = "azure/gpt-3.5-turbo-context-fallback"
+        kwargs["messages"] = [{"role": "user", "content": sample_text}]
+        response = router.completion(**kwargs)
+        print(f"response: {response}")
+    except Exception as e:
+        print(e)
+
+# test_sync_context_window_fallbacks()
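Outside the test file, the behavior the new test exercises would look roughly like the sketch below, which mirrors the test's configuration: a prompt that overflows the smaller deployment raises litellm.ContextWindowExceededError, and the router retries the call on the mapped fallback group. The model names are examples and an OPENAI_API_KEY is assumed.

    import os
    from litellm import Router

    model_list = [
        {"model_name": "gpt-3.5-turbo",
         "litellm_params": {"model": "gpt-3.5-turbo", "api_key": os.getenv("OPENAI_API_KEY")}},
        {"model_name": "gpt-3.5-turbo-16k",
         "litellm_params": {"model": "gpt-3.5-turbo-16k", "api_key": os.getenv("OPENAI_API_KEY")}},
    ]

    router = Router(
        model_list=model_list,
        # on a context-window overflow for gpt-3.5-turbo, retry the call on gpt-3.5-turbo-16k
        context_window_fallbacks=[{"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]}],
    )

    # a prompt long enough to overflow a 4k-context deployment but fit in a 16k one
    response = router.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Say error 50 times " * 2000}],
    )
    print(response)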