Merge pull request #3954 from BerriAI/litellm_simple_request_prioritization

feat(scheduler.py): add request prioritization scheduler
Krish Dholakia 2024-05-31 23:29:09 -07:00 committed by GitHub
commit 8375e9621c
12 changed files with 612 additions and 149 deletions


@@ -142,6 +142,7 @@ from litellm.proxy.auth.auth_checks import (
from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
from litellm.exceptions import RejectedRequestError
from litellm.integrations.slack_alerting import SlackAlertingArgs, SlackAlerting
from litellm.scheduler import Scheduler, FlowItem, DefaultPriorities
try:
from litellm._version import version
@@ -397,6 +398,8 @@ proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache)
async_result = None
celery_app_conn = None
celery_fn = None # Redis Queue for handling requests
### SIMPLE QUEUE ###
simple_scheduler = Scheduler()
### DB WRITER ###
db_writer_client: Optional[HTTPHandler] = None
### logger ###
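For orientation, everything this commit does with the scheduler goes through a small surface of `litellm.scheduler`: `Scheduler()`, `update_variables(llm_router=...)`, `add_request(request=FlowItem(...))`, `poll(id=..., model_name=...)`, and `get_queue_status()`. The sketch below is reconstructed purely from those call sites in this diff; the priority values, field types, and queueing logic are illustrative assumptions, not the shipped implementation.

```python
# Illustrative-only sketch of the litellm.scheduler surface this commit calls.
# Reconstructed from the call sites in this diff; priority values, field types,
# and the queue logic are assumptions, not the actual implementation.
from dataclasses import dataclass, asdict
from enum import Enum
from typing import Any, List, Optional


class DefaultPriorities(Enum):
    # assumption: lower value = served sooner
    High = 0
    Medium = 128
    Low = 255


@dataclass
class FlowItem:
    priority: int      # 0 (most urgent) .. 255 (least), per the assumption above
    request_id: str
    model_name: str


class Scheduler:
    def __init__(self):
        self.queue: List[FlowItem] = []
        self.llm_router: Optional[Any] = None

    def update_variables(self, llm_router=None):
        # called once at startup; the shipped scheduler presumably consults the
        # bound router for capacity/health (why else bind it here)
        self.llm_router = llm_router

    async def add_request(self, request: FlowItem):
        self.queue.append(request)

    async def poll(self, id: str, model_name: str) -> bool:
        # True once `id` is the highest-priority waiter for `model_name`
        waiting = sorted(
            (item for item in self.queue if item.model_name == model_name),
            key=lambda item: item.priority,
        )
        if waiting and waiting[0].request_id == id:
            self.queue.remove(waiting[0])
            return True
        return False

    def get_queue_status(self) -> List:
        return [asdict(item) for item in self.queue]
```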
@@ -3702,7 +3705,7 @@ def on_backoff(details):
@router.on_event("startup")
async def startup_event():
global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings, proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time, litellm_proxy_admin_name, db_writer_client, store_model_in_db
global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings, proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time, litellm_proxy_admin_name, db_writer_client, store_model_in_db, simple_scheduler
import json
### LOAD MASTER KEY ###
@@ -3738,6 +3741,10 @@ async def startup_event():
## Error Tracking ##
error_tracking()
## Priority Workload Scheduler ##
if llm_router is not None:
simple_scheduler.update_variables(llm_router=llm_router)
## UPDATE SLACK ALERTING ##
proxy_logging_obj.slack_alerting_instance.update_values(llm_router=llm_router)
@@ -12076,118 +12083,7 @@ async def alerting_settings(
return return_val
# @router.post(
# "/alerting/update",
# description="Update the slack alerting settings. Persist value in db.",
# tags=["alerting"],
# dependencies=[Depends(user_api_key_auth)],
# include_in_schema=False,
# )
# async def alerting_update(
# data: SlackAlertingArgs,
# user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
# ):
# """Allows updating slack alerting values. Used by UI."""
# global prisma_client
# if prisma_client is None:
# raise HTTPException(
# status_code=400,
# detail={"error": CommonProxyErrors.db_not_connected_error.value},
# )
# if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
# raise HTTPException(
# status_code=400,
# detail={"error": CommonProxyErrors.not_allowed_access.value},
# )
# ## get general settings from db
# db_general_settings = await prisma_client.db.litellm_config.find_first(
# where={"param_name": "general_settings"}
# )
# ### update value
# alerting_args_dict = {}
# if db_general_settings is None or db_general_settings.param_value is None:
# general_settings = {}
# alerting_args_dict = {}
# else:
# general_settings = dict(db_general_settings.param_value)
# _alerting_args_dict = general_settings.get("alerting_args", None)
# if _alerting_args_dict is not None and isinstance(_alerting_args_dict, dict):
# alerting_args_dict = _alerting_args_dict
# alerting_args_dict = data.model
# response = await prisma_client.db.litellm_config.upsert(
# where={"param_name": "general_settings"},
# data={
# "create": {"param_name": "general_settings", "param_value": json.dumps(general_settings)}, # type: ignore
# "update": {"param_value": json.dumps(general_settings)}, # type: ignore
# },
# )
# return response
#### EXPERIMENTAL QUEUING ####
async def _litellm_chat_completions_worker(data, user_api_key_dict):
"""
worker to make litellm completions calls
"""
while True:
try:
### CALL HOOKS ### - modify incoming data before calling the model
data = await proxy_logging_obj.pre_call_hook(
user_api_key_dict=user_api_key_dict, data=data, call_type="completion"
)
verbose_proxy_logger.debug("_litellm_chat_completions_worker started")
### ROUTE THE REQUEST ###
router_model_names = (
[m["model_name"] for m in llm_model_list]
if llm_model_list is not None
else []
)
if (
llm_router is not None and data["model"] in router_model_names
): # model in router model list
response = await llm_router.acompletion(**data)
elif (
llm_router is not None and data["model"] in llm_router.deployment_names
): # model in router deployments, calling a specific deployment on the router
response = await llm_router.acompletion(
**data, specific_deployment=True
)
elif (
llm_router is not None
and llm_router.model_group_alias is not None
and data["model"] in llm_router.model_group_alias
): # model set in model_group_alias
response = await llm_router.acompletion(**data)
else: # router is not set
response = await litellm.acompletion(**data)
verbose_proxy_logger.debug("final response: {response}")
return response
except HTTPException as e:
verbose_proxy_logger.debug(
f"EXCEPTION RAISED IN _litellm_chat_completions_worker - {e.status_code}; {e.detail}"
)
if (
e.status_code == 429
and "Max parallel request limit reached" in e.detail
):
verbose_proxy_logger.debug("Max parallel request limit reached!")
timeout = litellm._calculate_retry_after(
remaining_retries=3, max_retries=3, min_timeout=1
)
await asyncio.sleep(timeout)
else:
raise e
@router.post(
"/queue/chat/completions",
tags=["experimental"],
@@ -12195,6 +12091,7 @@ async def _litellm_chat_completions_worker(data, user_api_key_dict):
)
async def async_queue_request(
request: Request,
fastapi_response: Response,
model: Optional[str] = None,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
@@ -12260,12 +12157,47 @@ async def async_queue_request(
if user_api_base:
data["api_base"] = user_api_base
response = await asyncio.wait_for(
_litellm_chat_completions_worker(
data=data, user_api_key_dict=user_api_key_dict
),
timeout=litellm.request_timeout,
## FLOW ITEM ##
request_id = str(uuid.uuid4())
flow_item = FlowItem(
priority=data.pop("priority", DefaultPriorities.Medium.value),
request_id=request_id,
model_name=data["model"],
)
# [TODO] only allow premium users to set non-default priorities
## ADD REQUEST TO QUEUE
response = await simple_scheduler.add_request(request=flow_item)
if llm_router is None:
raise HTTPException(
status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
)
## POLL QUEUE
default_timeout = llm_router.timeout
end_time = time.time() + default_timeout
poll_interval = 0.03 # poll every 30ms
curr_time = time.time()
make_request = False
if llm_router is None:
raise HTTPException(
status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
)
while curr_time < end_time:
make_request = await simple_scheduler.poll(
id=request_id, model_name=data["model"]
)
if make_request: ## IF TRUE -> MAKE REQUEST
break
else: ## ELSE -> loop till default_timeout
await asyncio.sleep(poll_interval)
curr_time = time.time()
if make_request:
response = await llm_router.acompletion(**data)
if (
"stream" in data and data["stream"] == True
@@ -12279,6 +12211,7 @@ async def async_queue_request(
media_type="text/event-stream",
)
fastapi_response.headers.update({"x-litellm-priority": str(flow_item.priority)})
return response
except Exception as e:
await proxy_logging_obj.post_call_failure_hook(
@@ -12301,6 +12234,19 @@ async def async_queue_request(
)
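From the caller's side, the reworked endpoint expects an ordinary chat-completions payload plus an optional `priority` field, which the handler pops before forwarding the rest to the router and echoes back in the `x-litellm-priority` response header (the commit leaves a TODO to restrict non-default priorities to premium users). A minimal client-side sketch, with the proxy address, API key, and priority value as placeholders:

```python
# Hypothetical client call against the new /queue/chat/completions endpoint.
# Base URL, key, and priority value are placeholders, not taken from this diff.
import httpx

resp = httpx.post(
    "http://localhost:4000/queue/chat/completions",   # assumed proxy address
    headers={"Authorization": "Bearer sk-1234"},       # placeholder key
    json={
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "hello"}],
        "priority": 0,  # assumption: lower = more urgent (see DefaultPriorities)
    },
    timeout=600,  # the request may sit in the queue up to the router timeout
)
print(resp.headers.get("x-litellm-priority"))  # priority echoed back, per this diff
print(resp.json())
```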
@router.get(
"/queue/info",
tags=["experimental"],
dependencies=[Depends(user_api_key_auth)],
)
async def queue_info(
request: Request,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
) -> List:
"""Help user know the status of an item in the queue"""
return simple_scheduler.get_queue_status()
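The status endpoint simply returns whatever `Scheduler.get_queue_status()` yields (a list), so the item shape is not asserted here; a sketch of checking it, again with a placeholder address and key:

```python
# Hypothetical check of queue state via the new /queue/info endpoint.
import httpx

status = httpx.get(
    "http://localhost:4000/queue/info",                # assumed proxy address
    headers={"Authorization": "Bearer sk-1234"},        # placeholder key
)
print(status.json())  # list of queued items, shape defined by the Scheduler
```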
@router.get(
"/ollama_logs", dependencies=[Depends(user_api_key_auth)], tags=["experimental"]
)