From 79287a7584d89e2511959f3e9ab38425b4673c10 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Fri, 31 May 2024 18:51:13 -0700 Subject: [PATCH] feat(scheduler.py): add request prioritization scheduler allow user to set priority for a request --- .../out/{404.html => 404/index.html} | 0 .../{model_hub.html => model_hub/index.html} | 0 litellm/proxy/_super_secret_config.yaml | 12 +- litellm/proxy/proxy_server.py | 180 ++++++------------ litellm/proxy/queue/scheduler.py | 129 +++++++++++++ litellm/tests/test_scheduler.py | 164 ++++++++++++++++ litellm/types/router.py | 2 + tests/test_openai_endpoints.py | 30 +++ 8 files changed, 394 insertions(+), 123 deletions(-) rename litellm/proxy/_experimental/out/{404.html => 404/index.html} (100%) rename litellm/proxy/_experimental/out/{model_hub.html => model_hub/index.html} (100%) create mode 100644 litellm/proxy/queue/scheduler.py create mode 100644 litellm/tests/test_scheduler.py diff --git a/litellm/proxy/_experimental/out/404.html b/litellm/proxy/_experimental/out/404/index.html similarity index 100% rename from litellm/proxy/_experimental/out/404.html rename to litellm/proxy/_experimental/out/404/index.html diff --git a/litellm/proxy/_experimental/out/model_hub.html b/litellm/proxy/_experimental/out/model_hub/index.html similarity index 100% rename from litellm/proxy/_experimental/out/model_hub.html rename to litellm/proxy/_experimental/out/model_hub/index.html diff --git a/litellm/proxy/_super_secret_config.yaml b/litellm/proxy/_super_secret_config.yaml index 00efca522..97408262c 100644 --- a/litellm/proxy/_super_secret_config.yaml +++ b/litellm/proxy/_super_secret_config.yaml @@ -5,12 +5,12 @@ model_list: model: openai/my-fake-model rpm: 800 model_name: gpt-3.5-turbo-fake-model -- litellm_params: - api_base: https://my-endpoint-europe-berri-992.openai.azure.com/ - api_key: os.environ/AZURE_EUROPE_API_KEY - model: azure/gpt-35-turbo - rpm: 10 - model_name: gpt-3.5-turbo-fake-model +# - litellm_params: +# api_base: https://my-endpoint-europe-berri-992.openai.azure.com/ +# api_key: os.environ/AZURE_EUROPE_API_KEY +# model: azure/gpt-35-turbo +# rpm: 10 +# model_name: gpt-3.5-turbo-fake-model - litellm_params: api_base: https://openai-gpt-4-test-v-1.openai.azure.com/ api_key: os.environ/AZURE_API_KEY diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index b6e27cba6..3a0efd653 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -141,6 +141,7 @@ from litellm.proxy.auth.auth_checks import ( from litellm.llms.custom_httpx.httpx_handler import HTTPHandler from litellm.exceptions import RejectedRequestError from litellm.integrations.slack_alerting import SlackAlertingArgs, SlackAlerting +from litellm.proxy.queue.scheduler import Scheduler, FlowItem, DefaultPriorities try: from litellm._version import version @@ -395,6 +396,8 @@ proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache) async_result = None celery_app_conn = None celery_fn = None # Redis Queue for handling requests +### SIMPLE QUEUE ### +scheduler = Scheduler() ### DB WRITER ### db_writer_client: Optional[HTTPHandler] = None ### logger ### @@ -3655,7 +3658,7 @@ def on_backoff(details): @router.on_event("startup") async def startup_event(): - global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings, proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time, litellm_proxy_admin_name, db_writer_client, store_model_in_db + global prisma_client, master_key, 
use_background_health_checks, llm_router, llm_model_list, general_settings, proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time, litellm_proxy_admin_name, db_writer_client, store_model_in_db, scheduler import json ### LOAD MASTER KEY ### @@ -3691,6 +3694,10 @@ async def startup_event(): ## Error Tracking ## error_tracking() + ## Priority Workload Scheduler ## + if llm_router is not None: + scheduler.update_variables(llm_router=llm_router) + ## UPDATE SLACK ALERTING ## proxy_logging_obj.slack_alerting_instance.update_values(llm_router=llm_router) @@ -11219,118 +11226,7 @@ async def alerting_settings( return return_val -# @router.post( -# "/alerting/update", -# description="Update the slack alerting settings. Persist value in db.", -# tags=["alerting"], -# dependencies=[Depends(user_api_key_auth)], -# include_in_schema=False, -# ) -# async def alerting_update( -# data: SlackAlertingArgs, -# user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), -# ): -# """Allows updating slack alerting values. Used by UI.""" -# global prisma_client -# if prisma_client is None: -# raise HTTPException( -# status_code=400, -# detail={"error": CommonProxyErrors.db_not_connected_error.value}, -# ) - -# if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN: -# raise HTTPException( -# status_code=400, -# detail={"error": CommonProxyErrors.not_allowed_access.value}, -# ) - -# ## get general settings from db -# db_general_settings = await prisma_client.db.litellm_config.find_first( -# where={"param_name": "general_settings"} -# ) -# ### update value - -# alerting_args_dict = {} -# if db_general_settings is None or db_general_settings.param_value is None: -# general_settings = {} -# alerting_args_dict = {} -# else: -# general_settings = dict(db_general_settings.param_value) -# _alerting_args_dict = general_settings.get("alerting_args", None) -# if _alerting_args_dict is not None and isinstance(_alerting_args_dict, dict): -# alerting_args_dict = _alerting_args_dict - - -# alerting_args_dict = data.model - -# response = await prisma_client.db.litellm_config.upsert( -# where={"param_name": "general_settings"}, -# data={ -# "create": {"param_name": "general_settings", "param_value": json.dumps(general_settings)}, # type: ignore -# "update": {"param_value": json.dumps(general_settings)}, # type: ignore -# }, -# ) - -# return response - - #### EXPERIMENTAL QUEUING #### -async def _litellm_chat_completions_worker(data, user_api_key_dict): - """ - worker to make litellm completions calls - """ - while True: - try: - ### CALL HOOKS ### - modify incoming data before calling the model - data = await proxy_logging_obj.pre_call_hook( - user_api_key_dict=user_api_key_dict, data=data, call_type="completion" - ) - - verbose_proxy_logger.debug("_litellm_chat_completions_worker started") - ### ROUTE THE REQUEST ### - router_model_names = ( - [m["model_name"] for m in llm_model_list] - if llm_model_list is not None - else [] - ) - if ( - llm_router is not None and data["model"] in router_model_names - ): # model in router model list - response = await llm_router.acompletion(**data) - elif ( - llm_router is not None and data["model"] in llm_router.deployment_names - ): # model in router deployments, calling a specific deployment on the router - response = await llm_router.acompletion( - **data, specific_deployment=True - ) - elif ( - llm_router is not None - and llm_router.model_group_alias is not None - and data["model"] in llm_router.model_group_alias - ): # model set in model_group_alias - response 
= await llm_router.acompletion(**data)
-            else:  # router is not set
-                response = await litellm.acompletion(**data)
-
-            verbose_proxy_logger.debug("final response: {response}")
-            return response
-        except HTTPException as e:
-            verbose_proxy_logger.debug(
-                f"EXCEPTION RAISED IN _litellm_chat_completions_worker - {e.status_code}; {e.detail}"
-            )
-            if (
-                e.status_code == 429
-                and "Max parallel request limit reached" in e.detail
-            ):
-                verbose_proxy_logger.debug("Max parallel request limit reached!")
-                timeout = litellm._calculate_retry_after(
-                    remaining_retries=3, max_retries=3, min_timeout=1
-                )
-                await asyncio.sleep(timeout)
-            else:
-                raise e
-
-
 @router.post(
     "/queue/chat/completions",
     tags=["experimental"],
@@ -11338,6 +11234,7 @@ async def _litellm_chat_completions_worker(data, user_api_key_dict):
 )
 async def async_queue_request(
     request: Request,
+    fastapi_response: Response,
     model: Optional[str] = None,
     user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
 ):
@@ -11403,12 +11300,47 @@ async def async_queue_request(
         if user_api_base:
             data["api_base"] = user_api_base

-        response = await asyncio.wait_for(
-            _litellm_chat_completions_worker(
-                data=data, user_api_key_dict=user_api_key_dict
-            ),
-            timeout=litellm.request_timeout,
+        ## FLOW ITEM ##
+        request_id = str(uuid.uuid4())
+        flow_item = FlowItem(
+            priority=data.pop("priority", DefaultPriorities.Medium.value),
+            request_id=request_id,
+            model_group=data["model"],
         )
+        # [TODO] only allow premium users to set non default priorities
+
+        ## ADD REQUEST TO QUEUE
+        response = await scheduler.add_request(request=flow_item)
+
+        if llm_router is None:
+            raise HTTPException(
+                status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
+            )
+        ## POLL QUEUE
+        default_timeout = llm_router.timeout
+        end_time = time.time() + default_timeout
+        poll_interval = 0.03  # poll every 30ms
+        curr_time = time.time()
+
+        make_request = False
+
+        if llm_router is None:
+            raise HTTPException(
+                status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
+            )
+
+        while curr_time < end_time:
+            make_request = await scheduler.poll(
+                id=request_id, model_group=data["model"]
+            )
+            if make_request:  ## IF TRUE -> MAKE REQUEST
+                break
+            else:  ## ELSE -> loop till default_timeout
+                await asyncio.sleep(poll_interval)
+                curr_time = time.time()
+
+        if make_request:
+            response = await llm_router.acompletion(**data)

         if (
             "stream" in data and data["stream"] == True
@@ -11422,6 +11354,7 @@ async def async_queue_request(
                 media_type="text/event-stream",
             )

+        fastapi_response.headers.update({"x-litellm-priority": str(flow_item.priority)})
         return response
     except Exception as e:
         await proxy_logging_obj.post_call_failure_hook(
@@ -11444,6 +11377,19 @@ async def async_queue_request(
     )


+@router.get(
+    "/queue/info",
+    tags=["experimental"],
+    dependencies=[Depends(user_api_key_auth)],
+)
+async def queue_info(
+    request: Request,
+    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+) -> List:
+    """Help user know the status of an item in the queue"""
+    return scheduler.get_queue_status()
+
+
 @router.get(
     "/ollama_logs", dependencies=[Depends(user_api_key_auth)], tags=["experimental"]
 )
diff --git a/litellm/proxy/queue/scheduler.py b/litellm/proxy/queue/scheduler.py
new file mode 100644
index 000000000..79a282d96
--- /dev/null
+++ b/litellm/proxy/queue/scheduler.py
@@ -0,0 +1,129 @@
+import heapq, time
+from pydantic import BaseModel
+from typing import Optional
+import enum
+from litellm.caching import DualCache
+from litellm import Router
+from litellm import print_verbose + + +class SchedulerCacheKeys(enum.Enum): + queue = "scheduler:queue" + + +class DefaultPriorities(enum.Enum): + High = 0 + Medium = 128 + Low = 255 + + +class FlowItem(BaseModel): + priority: int # Priority between 0 and 255 + request_id: str + model_group: str + + +class Scheduler: + cache: DualCache + llm_router: Optional[Router] = None + + def __init__(self): + self.queue = [] + self.cache = DualCache() + + def update_variables(self, llm_router: Router, cache: Optional[DualCache] = None): + self.llm_router = llm_router + if cache is not None: + self.cache = cache + + async def add_request(self, request: FlowItem): + # We use the priority directly, as lower values indicate higher priority + # get the queue + queue = await self.get_queue(model_group=request.model_group) + # update the queue + heapq.heappush(queue, (request.priority, request.request_id)) + + # save the queue + await self.save_queue(queue=queue, model_group=request.model_group) + + async def poll(self, id: str, model_group: str) -> bool: + """Return if the id is at the top of the queue and if the token bucket allows processing""" + queue = await self.get_queue(model_group=model_group) + if not queue or not self.llm_router: + raise Exception( + "Incorrectly setup. Queue or Router is invalid. Queue={}, Router={}".format( + queue, self.llm_router + ) + ) + + # ------------ + # Setup values + # ------------ + _healthy_deployments = await self.llm_router._async_get_healthy_deployments( + model=model_group + ) + + print_verbose(f"len(_healthy_deployments): {len(_healthy_deployments)}") + if len(_healthy_deployments) == 0: + return False + + print_verbose(f"queue: {queue}, seeking id={id}") + # Check if the id is at the top of the heap + if queue[0][1] == id: + # Remove the item from the queue + heapq.heappop(queue) + print_verbose(f"Popped id: {id}") + return True + + return False + + async def peek(self, id: str, model_group: str) -> bool: + """Return if the id is at the top of the queue. Don't pop the value from heap.""" + queue = await self.get_queue(model_group=model_group) + if not queue or not self.llm_router: + raise Exception( + "Incorrectly setup. Queue or Router is invalid. 
Queue={}, Router={}".format( + queue, self.llm_router + ) + ) + + # ------------ + # Setup values + # ------------ + _healthy_deployments = await self.llm_router._async_get_healthy_deployments( + model=model_group + ) + if len(_healthy_deployments) == 0: + return False + + # Check if the id is at the top of the heap + if queue[0][1] == id: + return True + + return False + + def get_queue_status(self): + """Get the status of items in the queue""" + return self.queue + + async def get_queue(self, model_group: str) -> list: + """ + Return a queue for that specific model group + """ + if self.cache is not None: + _cache_key = "{}:{}".format(SchedulerCacheKeys.queue.value, model_group) + response = await self.cache.async_get_cache(key=_cache_key) + if response is None or not isinstance(response, list): + return [] + elif isinstance(response, list): + return response + return self.queue + + async def save_queue(self, queue: list, model_group: str) -> None: + """ + Save the updated queue of the model group + """ + if self.cache is not None: + _cache_key = "{}:{}".format(SchedulerCacheKeys.queue.value, model_group) + await self.cache.async_set_cache(key=_cache_key, value=queue) + return None diff --git a/litellm/tests/test_scheduler.py b/litellm/tests/test_scheduler.py new file mode 100644 index 000000000..1177bc12c --- /dev/null +++ b/litellm/tests/test_scheduler.py @@ -0,0 +1,164 @@ +# What is this? +## Unit tests for the Scheduler.py (workload prioritization scheduler) + +import sys, os, time, openai, uuid +import traceback, asyncio +import pytest + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path +from litellm import Router +from litellm.proxy.queue.scheduler import FlowItem, Scheduler + + +@pytest.mark.asyncio +async def test_scheduler_diff_model_groups(): + """ + Assert 2 requests to 2 diff model groups are top of their respective queue's + """ + scheduler = Scheduler() + + router = Router( + model_list=[ + { + "model_name": "gpt-3.5-turbo", + "litellm_params": { + "model": "gpt-3.5-turbo", + }, + }, + {"model_name": "gpt-4", "litellm_params": {"model": "gpt-4"}}, + ] + ) + + scheduler.update_variables(llm_router=router) + + item1 = FlowItem(priority=0, request_id="10", model_group="gpt-3.5-turbo") + item2 = FlowItem(priority=0, request_id="11", model_group="gpt-4") + await scheduler.add_request(item1) + await scheduler.add_request(item2) + + assert await scheduler.poll(id="10", model_group="gpt-3.5-turbo") == True + assert await scheduler.poll(id="11", model_group="gpt-4") == True + + +@pytest.mark.parametrize("p0, p1", [(0, 0), (0, 1), (1, 0)]) +@pytest.mark.asyncio +async def test_scheduler_prioritized_requests(p0, p1): + """ + 2 requests for same model group + """ + scheduler = Scheduler() + + router = Router( + model_list=[ + { + "model_name": "gpt-3.5-turbo", + "litellm_params": { + "model": "gpt-3.5-turbo", + }, + }, + {"model_name": "gpt-4", "litellm_params": {"model": "gpt-4"}}, + ] + ) + + scheduler.update_variables(llm_router=router) + + item1 = FlowItem(priority=p0, request_id="10", model_group="gpt-3.5-turbo") + item2 = FlowItem(priority=p1, request_id="11", model_group="gpt-3.5-turbo") + await scheduler.add_request(item1) + await scheduler.add_request(item2) + + if p0 == 0: + assert await scheduler.peek(id="10", model_group="gpt-3.5-turbo") == True + assert await scheduler.peek(id="11", model_group="gpt-3.5-turbo") == False + else: + assert await scheduler.peek(id="11", model_group="gpt-3.5-turbo") == True + assert await 
scheduler.peek(id="10", model_group="gpt-3.5-turbo") == False
+
+
+@pytest.mark.parametrize("p0, p1", [(0, 0), (0, 1), (1, 0)])
+@pytest.mark.asyncio
+async def test_scheduler_prioritized_requests_mock_response(p0, p1):
+    """
+    2 requests for same model group
+    """
+    scheduler = Scheduler()
+
+    router = Router(
+        model_list=[
+            {
+                "model_name": "gpt-3.5-turbo",
+                "litellm_params": {
+                    "model": "gpt-3.5-turbo",
+                    "mock_response": "Hello world this is Macintosh!",
+                },
+            },
+        ],
+        timeout=2,
+    )
+
+    scheduler.update_variables(llm_router=router)
+
+    async def _make_prioritized_call(flow_item: FlowItem):
+        ## POLL QUEUE
+        default_timeout = router.timeout
+        end_time = time.time() + default_timeout
+        poll_interval = 0.03  # poll every 30ms
+        curr_time = time.time()
+
+        make_request = False
+
+        if router is None:
+            raise Exception("No llm router value")
+
+        while curr_time < end_time:
+            make_request = await scheduler.poll(
+                id=flow_item.request_id, model_group=flow_item.model_group
+            )
+            if make_request:  ## IF TRUE -> MAKE REQUEST
+                break
+            else:  ## ELSE -> loop till default_timeout
+                await asyncio.sleep(poll_interval)
+                curr_time = time.time()
+
+        if make_request:
+            _response = await router.acompletion(
+                model=flow_item.model_group,
+                messages=[{"role": "user", "content": "Hey!"}],
+            )
+
+            return flow_item.priority, flow_item.request_id, time.time()
+
+        raise Exception("didn't make request")
+
+    tasks = []
+
+    item = FlowItem(
+        priority=p0, request_id=str(uuid.uuid4()), model_group="gpt-3.5-turbo"
+    )
+    await scheduler.add_request(request=item)
+    tasks.append(_make_prioritized_call(flow_item=item))
+
+    item = FlowItem(
+        priority=p1, request_id=str(uuid.uuid4()), model_group="gpt-3.5-turbo"
+    )
+    await scheduler.add_request(request=item)
+    tasks.append(_make_prioritized_call(flow_item=item))
+
+    # Running the tasks and getting responses in order of completion
+    completed_responses = []
+    for task in asyncio.as_completed(tasks):
+        result = await task
+        completed_responses.append(result)
+        print(f"Received response: {result}")
+
+    print(f"responses: {completed_responses}")
+    assert (
+        completed_responses[0][0] == 0
+    )  # assert higher priority request got done first
+    assert (
+        completed_responses[0][2] < completed_responses[1][2]
+    ), "1st response time={}, 2nd response time={}".format(
+        completed_responses[0][2], completed_responses[1][2]
+    )  # assert higher priority request got done first
diff --git a/litellm/types/router.py b/litellm/types/router.py
index 75e792f4c..f131c2929 100644
--- a/litellm/types/router.py
+++ b/litellm/types/router.py
@@ -326,6 +326,8 @@ class LiteLLMParamsTypedDict(TypedDict, total=False):
     output_cost_per_token: Optional[float]
     input_cost_per_second: Optional[float]
     output_cost_per_second: Optional[float]
+    ## MOCK RESPONSES ##
+    mock_response: Optional[str]


 class DeploymentTypedDict(TypedDict):
diff --git a/tests/test_openai_endpoints.py b/tests/test_openai_endpoints.py
index 43dcae3cd..83d387ffb 100644
--- a/tests/test_openai_endpoints.py
+++ b/tests/test_openai_endpoints.py
@@ -103,6 +103,36 @@ async def chat_completion(session, key, model: Union[str, List] = "gpt-4"):
     return await response.json()


+async def queue_chat_completion(
+    session, key, priority: int, model: Union[str, List] = "gpt-4"
+):
+    url = "http://0.0.0.0:4000/queue/chat/completions"
+    headers = {
+        "Authorization": f"Bearer {key}",
+        "Content-Type": "application/json",
+    }
+    data = {
+        "model": model,
+        "messages": [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": "Hello!"},
+        ],
+        "priority": priority,
+    }
+
+    async with session.post(url, headers=headers, json=data) as response:
+        status = response.status
+        response_text = await response.text()
+
+        print(response_text)
+        print()
+
+        if status != 200:
+            raise Exception(f"Request did not return a 200 status code: {status}")
+
+        return response.raw_headers
+
+
 async def chat_completion_with_headers(session, key, model="gpt-4"):
     url = "http://0.0.0.0:4000/chat/completions"
     headers = {