diff --git a/docs/my-website/docs/scheduler.md b/docs/my-website/docs/scheduler.md
new file mode 100644
index 000000000..347406ade
--- /dev/null
+++ b/docs/my-website/docs/scheduler.md
@@ -0,0 +1,141 @@
+import Tabs from '@theme/Tabs';
+import TabItem from '@theme/TabItem';
+
+# [BETA] Request Prioritization
+
+:::info
+
+Beta feature. Use for testing only.
+
+[Help us improve this](https://github.com/BerriAI/litellm/issues)
+:::
+
+Prioritize LLM API requests during periods of high traffic.
+
+- Add the request to the priority queue
+- Poll the queue to check if the request can be made. Returns `True`:
+   * if there are healthy deployments
+   * OR if the request is at the top of the queue
+- Priority - the lower the number, the higher the priority:
+   * e.g. `priority=0` > `priority=2000`
+
+## Quick Start
+
+```python
+import asyncio, time, uuid
+
+from litellm import Scheduler, FlowItem, Router
+
+scheduler = Scheduler()
+
+router = Router(
+    model_list=[
+        {
+            "model_name": "gpt-3.5-turbo",
+            "litellm_params": {
+                "model": "gpt-3.5-turbo",
+                "mock_response": "Hello world this is Macintosh!", # fakes the LLM API call
+                "rpm": 1,
+            },
+        },
+    ],
+    timeout=2, # timeout request if it takes > 2s
+    routing_strategy="usage-based-routing-v2",
+)
+
+scheduler.update_variables(llm_router=router)
+
+### 🚨 IMPORTANT ###
+
+item = FlowItem(
+    priority=0, # 👈 SET PRIORITY FOR REQUEST
+    request_id=str(uuid.uuid4()), # 👈 SET REQUEST ID
+    model_name="gpt-3.5-turbo" # 👈 SAME as 'Router'
+)
+
+### [END] IMPORTANT ###
+
+## ADDS REQUEST TO QUEUE ##
+await scheduler.add_request(request=item)
+
+## POLL QUEUE
+default_timeout = router.timeout
+end_time = time.time() + default_timeout
+poll_interval = 0.03 # poll every 30ms
+curr_time = time.time()
+
+make_request = False
+
+while curr_time < end_time:
+    make_request = await scheduler.poll( ## POLL QUEUE ## - returns 'True' if there are healthy deployments OR if the request is at the top of the queue
+        id=item.request_id, model_name=item.model_name
+    )
+    if make_request: ## IF TRUE -> MAKE REQUEST
+        break
+    else: ## ELSE -> loop till default_timeout
+        await asyncio.sleep(poll_interval)
+        curr_time = time.time()
+
+if make_request:
+    try:
+        _response = await router.acompletion(
+            model=item.model_name,
+            messages=[{"role": "user", "content": "Hey!"}],
+        )
+    except Exception as e:
+        print("{}, {}, {}".format(item.priority, item.request_id, "Error occurred"))
+
+    print("{}, {}, {}".format(item.priority, item.request_id, time.time()))
+else:
+    print("didn't make request")
+```
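+
+For reference, here is one way to wrap the flow above into a runnable script that queues two requests with different priorities for the same model group. This is a minimal sketch: it assumes the `scheduler` and `router` objects defined at the top of the Quick Start, and the `prioritized_call` helper and the priority values `0` / `100` are illustrative, not part of the library.
+
+```python
+import asyncio, time, uuid
+
+from litellm import FlowItem
+
+async def prioritized_call(flow_item: FlowItem):
+    # poll until the scheduler lets this request through, or the router timeout is hit
+    end_time = time.time() + router.timeout
+    while time.time() < end_time:
+        if await scheduler.poll(id=flow_item.request_id, model_name=flow_item.model_name):
+            return await router.acompletion(
+                model=flow_item.model_name,
+                messages=[{"role": "user", "content": "Hey!"}],
+            )
+        await asyncio.sleep(0.03)  # poll every 30ms
+    raise TimeoutError("request {} timed out in queue".format(flow_item.request_id))
+
+async def main():
+    items = [
+        FlowItem(priority=0, request_id=str(uuid.uuid4()), model_name="gpt-3.5-turbo"),    # higher priority
+        FlowItem(priority=100, request_id=str(uuid.uuid4()), model_name="gpt-3.5-turbo"),  # lower priority
+    ]
+    for _item in items:
+        await scheduler.add_request(request=_item)
+
+    # failed calls (e.g. rate limited) come back as exception objects instead of raising
+    responses = await asyncio.gather(
+        *[prioritized_call(flow_item=_item) for _item in items], return_exceptions=True
+    )
+    for _item, response in zip(items, responses):
+        print(_item.priority, _item.request_id, response)
+
+asyncio.run(main())
+```
+
+Lower `priority` values sit closer to the top of the queue, so when capacity is constrained the `priority=0` request is attempted before the `priority=100` one.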
1234" + }], + "priority": 0 👈 SET VALUE HERE +}' +``` + + + + +```python +import openai +client = openai.OpenAI( + api_key="anything", + base_url="http://0.0.0.0:4000" +) + +# request sent to model set on litellm proxy, `litellm --model` +response = client.chat.completions.create( + model="gpt-3.5-turbo", + messages = [ + { + "role": "user", + "content": "this is a test request, write a short poem" + } + ], + extra_body={ + "priority": 0 👈 SET VALUE HERE + } +) + +print(response) +``` + + + \ No newline at end of file diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js index 29095d41f..d5c6f8bc2 100644 --- a/docs/my-website/sidebars.js +++ b/docs/my-website/sidebars.js @@ -164,6 +164,7 @@ const sidebars = { }, "proxy/custom_pricing", "routing", + "scheduler", "rules", "set_keys", "budget_manager", diff --git a/litellm/__init__.py b/litellm/__init__.py index fe9f025bb..35f931788 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -808,3 +808,4 @@ from .proxy.proxy_cli import run_server from .router import Router from .assistants.main import * from .batches.main import * +from .scheduler import * diff --git a/litellm/main.py b/litellm/main.py index 37565ff50..ad14ec92b 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -431,6 +431,10 @@ def mock_completion( model=model, # type: ignore request=httpx.Request(method="POST", url="https://api.openai.com/v1/"), ) + time_delay = kwargs.get("mock_delay", None) + if time_delay is not None: + time.sleep(time_delay) + model_response = ModelResponse(stream=stream) if stream is True: # don't try to access stream object, @@ -881,6 +885,7 @@ def completion( mock_response=mock_response, logging=logging, acompletion=acompletion, + mock_delay=kwargs.get("mock_delay", None), ) if custom_llm_provider == "azure": # azure configs diff --git a/litellm/proxy/_super_secret_config.yaml b/litellm/proxy/_super_secret_config.yaml index 00efca522..97408262c 100644 --- a/litellm/proxy/_super_secret_config.yaml +++ b/litellm/proxy/_super_secret_config.yaml @@ -5,12 +5,12 @@ model_list: model: openai/my-fake-model rpm: 800 model_name: gpt-3.5-turbo-fake-model -- litellm_params: - api_base: https://my-endpoint-europe-berri-992.openai.azure.com/ - api_key: os.environ/AZURE_EUROPE_API_KEY - model: azure/gpt-35-turbo - rpm: 10 - model_name: gpt-3.5-turbo-fake-model +# - litellm_params: +# api_base: https://my-endpoint-europe-berri-992.openai.azure.com/ +# api_key: os.environ/AZURE_EUROPE_API_KEY +# model: azure/gpt-35-turbo +# rpm: 10 +# model_name: gpt-3.5-turbo-fake-model - litellm_params: api_base: https://openai-gpt-4-test-v-1.openai.azure.com/ api_key: os.environ/AZURE_API_KEY diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 1ff38578a..3d0b4b587 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -142,6 +142,7 @@ from litellm.proxy.auth.auth_checks import ( from litellm.llms.custom_httpx.httpx_handler import HTTPHandler from litellm.exceptions import RejectedRequestError from litellm.integrations.slack_alerting import SlackAlertingArgs, SlackAlerting +from litellm.scheduler import Scheduler, FlowItem, DefaultPriorities try: from litellm._version import version @@ -397,6 +398,8 @@ proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache) async_result = None celery_app_conn = None celery_fn = None # Redis Queue for handling requests +### SIMPLE QUEUE ### +simple_scheduler = Scheduler() ### DB WRITER ### db_writer_client: Optional[HTTPHandler] = None ### logger ### @@ 
-3702,7 +3705,7 @@ def on_backoff(details): @router.on_event("startup") async def startup_event(): - global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings, proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time, litellm_proxy_admin_name, db_writer_client, store_model_in_db + global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings, proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time, litellm_proxy_admin_name, db_writer_client, store_model_in_db, simple_scheduler import json ### LOAD MASTER KEY ### @@ -3738,6 +3741,10 @@ async def startup_event(): ## Error Tracking ## error_tracking() + ## Priority Workload Scheduler ## + if llm_router is not None: + simple_scheduler.update_variables(llm_router=llm_router) + ## UPDATE SLACK ALERTING ## proxy_logging_obj.slack_alerting_instance.update_values(llm_router=llm_router) @@ -12076,118 +12083,7 @@ async def alerting_settings( return return_val -# @router.post( -# "/alerting/update", -# description="Update the slack alerting settings. Persist value in db.", -# tags=["alerting"], -# dependencies=[Depends(user_api_key_auth)], -# include_in_schema=False, -# ) -# async def alerting_update( -# data: SlackAlertingArgs, -# user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), -# ): -# """Allows updating slack alerting values. Used by UI.""" -# global prisma_client -# if prisma_client is None: -# raise HTTPException( -# status_code=400, -# detail={"error": CommonProxyErrors.db_not_connected_error.value}, -# ) - -# if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN: -# raise HTTPException( -# status_code=400, -# detail={"error": CommonProxyErrors.not_allowed_access.value}, -# ) - -# ## get general settings from db -# db_general_settings = await prisma_client.db.litellm_config.find_first( -# where={"param_name": "general_settings"} -# ) -# ### update value - -# alerting_args_dict = {} -# if db_general_settings is None or db_general_settings.param_value is None: -# general_settings = {} -# alerting_args_dict = {} -# else: -# general_settings = dict(db_general_settings.param_value) -# _alerting_args_dict = general_settings.get("alerting_args", None) -# if _alerting_args_dict is not None and isinstance(_alerting_args_dict, dict): -# alerting_args_dict = _alerting_args_dict - - -# alerting_args_dict = data.model - -# response = await prisma_client.db.litellm_config.upsert( -# where={"param_name": "general_settings"}, -# data={ -# "create": {"param_name": "general_settings", "param_value": json.dumps(general_settings)}, # type: ignore -# "update": {"param_value": json.dumps(general_settings)}, # type: ignore -# }, -# ) - -# return response - - #### EXPERIMENTAL QUEUING #### -async def _litellm_chat_completions_worker(data, user_api_key_dict): - """ - worker to make litellm completions calls - """ - while True: - try: - ### CALL HOOKS ### - modify incoming data before calling the model - data = await proxy_logging_obj.pre_call_hook( - user_api_key_dict=user_api_key_dict, data=data, call_type="completion" - ) - - verbose_proxy_logger.debug("_litellm_chat_completions_worker started") - ### ROUTE THE REQUEST ### - router_model_names = ( - [m["model_name"] for m in llm_model_list] - if llm_model_list is not None - else [] - ) - if ( - llm_router is not None and data["model"] in router_model_names - ): # model in router model list - response = await llm_router.acompletion(**data) - elif ( - llm_router is not None 
and data["model"] in llm_router.deployment_names - ): # model in router deployments, calling a specific deployment on the router - response = await llm_router.acompletion( - **data, specific_deployment=True - ) - elif ( - llm_router is not None - and llm_router.model_group_alias is not None - and data["model"] in llm_router.model_group_alias - ): # model set in model_group_alias - response = await llm_router.acompletion(**data) - else: # router is not set - response = await litellm.acompletion(**data) - - verbose_proxy_logger.debug("final response: {response}") - return response - except HTTPException as e: - verbose_proxy_logger.debug( - f"EXCEPTION RAISED IN _litellm_chat_completions_worker - {e.status_code}; {e.detail}" - ) - if ( - e.status_code == 429 - and "Max parallel request limit reached" in e.detail - ): - verbose_proxy_logger.debug("Max parallel request limit reached!") - timeout = litellm._calculate_retry_after( - remaining_retries=3, max_retries=3, min_timeout=1 - ) - await asyncio.sleep(timeout) - else: - raise e - - @router.post( "/queue/chat/completions", tags=["experimental"], @@ -12195,6 +12091,7 @@ async def _litellm_chat_completions_worker(data, user_api_key_dict): ) async def async_queue_request( request: Request, + fastapi_response: Response, model: Optional[str] = None, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), ): @@ -12260,12 +12157,47 @@ async def async_queue_request( if user_api_base: data["api_base"] = user_api_base - response = await asyncio.wait_for( - _litellm_chat_completions_worker( - data=data, user_api_key_dict=user_api_key_dict - ), - timeout=litellm.request_timeout, + ## FLOW ITEM ## + request_id = str(uuid.uuid4()) + flow_item = FlowItem( + priority=data.pop("priority", DefaultPriorities.Medium.value), + request_id=request_id, + model_name=data["model"], ) + # [TODO] only allow premium users to set non default priorities + + ## ADD REQUEST TO QUEUE + response = await simple_scheduler.add_request(request=flow_item) + + if llm_router is None: + raise HTTPException( + status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value} + ) + ## POLL QUEUE + default_timeout = llm_router.timeout + end_time = time.time() + default_timeout + poll_interval = 0.03 # poll every 3ms + curr_time = time.time() + + make_request = False + + if llm_router is None: + raise HTTPException( + status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value} + ) + + while curr_time < end_time: + make_request = await simple_scheduler.poll( + id=request_id, model_name=data["model"] + ) + if make_request: ## IF TRUE -> MAKE REQUEST + break + else: ## ELSE -> loop till default_timeout + await asyncio.sleep(poll_interval) + curr_time = time.time() + + if make_request: + response = await llm_router.acompletion(**data) if ( "stream" in data and data["stream"] == True @@ -12279,6 +12211,7 @@ async def async_queue_request( media_type="text/event-stream", ) + fastapi_response.headers.update({"x-litellm-priority": str(flow_item.priority)}) return response except Exception as e: await proxy_logging_obj.post_call_failure_hook( @@ -12301,6 +12234,19 @@ async def async_queue_request( ) +@router.get( + "/queue/info", + tags=["experimental"], + dependencies=[Depends(user_api_key_auth)], +) +async def queue_info( + request: Request, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +) -> List: + """Help user know the status of an item in the queue""" + return simple_scheduler.get_queue_status() + + @router.get( "/ollama_logs", 
dependencies=[Depends(user_api_key_auth)], tags=["experimental"] ) diff --git a/litellm/router.py b/litellm/router.py index e8ca2156d..88eb54a04 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -641,7 +641,6 @@ class Router: kwargs=kwargs, client_type="max_parallel_requests", ) - if rpm_semaphore is not None and isinstance( rpm_semaphore, asyncio.Semaphore ): @@ -1987,6 +1986,7 @@ class Router: error=e, healthy_deployments=_healthy_deployments, context_window_fallbacks=context_window_fallbacks, + regular_fallbacks=fallbacks, ) # decides how long to sleep before retry @@ -1996,7 +1996,6 @@ class Router: num_retries=num_retries, healthy_deployments=_healthy_deployments, ) - # sleeps for the length of the timeout await asyncio.sleep(_timeout) @@ -2041,6 +2040,7 @@ class Router: healthy_deployments=_healthy_deployments, ) await asyncio.sleep(_timeout) + try: cooldown_deployments = await self._async_get_cooldown_deployments() original_exception.message += f"\nNumber Retries = {current_attempt + 1}, Max Retries={num_retries}\nCooldown Deployments={cooldown_deployments}" @@ -2053,6 +2053,7 @@ class Router: error: Exception, healthy_deployments: Optional[List] = None, context_window_fallbacks: Optional[List] = None, + regular_fallbacks: Optional[List] = None, ): """ 1. raise an exception for ContextWindowExceededError if context_window_fallbacks is not None @@ -2069,7 +2070,7 @@ class Router: ### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR w/ fallbacks available / Bad Request Error if ( isinstance(error, litellm.ContextWindowExceededError) - and context_window_fallbacks is None + and context_window_fallbacks is not None ): raise error @@ -2077,7 +2078,11 @@ class Router: if isinstance(error, openai.RateLimitError) or isinstance( error, openai.AuthenticationError ): - if _num_healthy_deployments <= 0: + if ( + _num_healthy_deployments <= 0 + and regular_fallbacks is not None + and len(regular_fallbacks) > 0 + ): raise error return True @@ -2252,6 +2257,7 @@ class Router: error=e, healthy_deployments=_healthy_deployments, context_window_fallbacks=context_window_fallbacks, + regular_fallbacks=fallbacks, ) # decides how long to sleep before retry @@ -2460,7 +2466,7 @@ class Router: the exception is not one that should be immediately retried (e.g. 
401) """ - args = locals() + if deployment is None: return @@ -2631,7 +2637,17 @@ class Router: """ for _callback in litellm.callbacks: if isinstance(_callback, CustomLogger): - response = await _callback.async_pre_call_check(deployment) + try: + response = await _callback.async_pre_call_check(deployment) + except litellm.RateLimitError as e: + self._set_cooldown_deployments( + exception_status=e.status_code, + deployment=deployment["model_info"]["id"], + time_to_cooldown=self.cooldown_time, + ) + raise e + except Exception as e: + raise e def set_client(self, model: dict): """ diff --git a/litellm/scheduler.py b/litellm/scheduler.py new file mode 100644 index 000000000..3bbd3916e --- /dev/null +++ b/litellm/scheduler.py @@ -0,0 +1,139 @@ +import heapq, time +from pydantic import BaseModel +from typing import Optional +import enum +from litellm.caching import DualCache +from litellm import Router +from litellm import print_verbose + + +class SchedulerCacheKeys(enum.Enum): + queue = "scheduler:queue" + + +class DefaultPriorities(enum.Enum): + High = 0 + Medium = 128 + Low = 255 + + +class FlowItem(BaseModel): + priority: int # Priority between 0 and 255 + request_id: str + model_name: str + + +class Scheduler: + cache: DualCache + llm_router: Optional[Router] = None + + def __init__(self): + self.queue = [] + self.cache = DualCache() + + def update_variables(self, llm_router: Router, cache: Optional[DualCache] = None): + self.llm_router = llm_router + if cache is not None: + self.cache = cache + + async def add_request(self, request: FlowItem): + # We use the priority directly, as lower values indicate higher priority + # get the queue + queue = await self.get_queue(model_name=request.model_name) + # update the queue + heapq.heappush(queue, (request.priority, request.request_id)) + + # save the queue + await self.save_queue(queue=queue, model_name=request.model_name) + + async def poll(self, id: str, model_name: str) -> bool: + """ + Return if request can be processed. + + Returns: + - True: + * If healthy deployments are available + * OR If request at the top of queue + - False: + * If no healthy deployments available + * AND request not at the top of queue + """ + queue = await self.get_queue(model_name=model_name) + if not queue or not self.llm_router: + raise Exception( + "Incorrectly setup. Queue or Router is invalid. Queue={}, Router={}".format( + queue, self.llm_router + ) + ) + + # ------------ + # Setup values + # ------------ + _healthy_deployments = await self.llm_router._async_get_healthy_deployments( + model=model_name + ) + + print_verbose(f"len(_healthy_deployments): {len(_healthy_deployments)}") + if len(_healthy_deployments) == 0: + print_verbose(f"queue: {queue}, seeking id={id}") + # Check if the id is at the top of the heap + if queue[0][1] == id: + # Remove the item from the queue + heapq.heappop(queue) + print_verbose(f"Popped id: {id}") + return True + else: + return False + + return True + + async def peek(self, id: str, model_name: str) -> bool: + """Return if the id is at the top of the queue. Don't pop the value from heap.""" + queue = await self.get_queue(model_name=model_name) + if not queue or not self.llm_router: + raise Exception( + "Incorrectly setup. Queue or Router is invalid. 
Queue={}, Router={}".format( + queue, self.llm_router + ) + ) + + # ------------ + # Setup values + # ------------ + _healthy_deployments = await self.llm_router._async_get_healthy_deployments( + model=model_name + ) + if len(_healthy_deployments) == 0: + return False + + # Check if the id is at the top of the heap + if queue[0][1] == id: + return True + + return False + + def get_queue_status(self): + """Get the status of items in the queue""" + return self.queue + + async def get_queue(self, model_name: str) -> list: + """ + Return a queue for that specific model group + """ + if self.cache is not None: + _cache_key = "{}:{}".format(SchedulerCacheKeys.queue.value, model_name) + response = await self.cache.async_get_cache(key=_cache_key) + if response is None or not isinstance(response, list): + return [] + elif isinstance(response, list): + return response + return self.queue + + async def save_queue(self, queue: list, model_name: str) -> None: + """ + Save the updated queue of the model group + """ + if self.cache is not None: + _cache_key = "{}:{}".format(SchedulerCacheKeys.queue.value, model_name) + await self.cache.async_set_cache(key=_cache_key, value=queue) + return None diff --git a/litellm/tests/test_router_retries.py b/litellm/tests/test_router_retries.py index 7273fd6e9..1c0c3e49f 100644 --- a/litellm/tests/test_router_retries.py +++ b/litellm/tests/test_router_retries.py @@ -134,7 +134,7 @@ async def test_router_retry_policy(error_type): ContentPolicyViolationErrorRetries=3, AuthenticationErrorRetries=0 ) - router = litellm.Router( + router = Router( model_list=[ { "model_name": "gpt-3.5-turbo", # openai model name @@ -334,13 +334,13 @@ def test_retry_rate_limit_error_with_healthy_deployments(): ) -def test_do_not_retry_rate_limit_error_with_no_fallbacks_and_no_healthy_deployments(): +def test_do_retry_rate_limit_error_with_no_fallbacks_and_no_healthy_deployments(): """ - Test 2. It SHOULD NOT Retry, when healthy_deployments is [] and fallbacks is None + Test 2. 
It SHOULD Retry, when healthy_deployments is [] and fallbacks is None """ healthy_deployments = [] - router = litellm.Router( + router = Router( model_list=[ { "model_name": "gpt-3.5-turbo", @@ -359,14 +359,14 @@ def test_do_not_retry_rate_limit_error_with_no_fallbacks_and_no_healthy_deployme response = router.should_retry_this_error( error=rate_limit_error, healthy_deployments=healthy_deployments ) - assert response != True, "Should have raised RateLimitError" - except openai.RateLimitError: - pass + assert response == True + except Exception as e: + pytest.fail("Should not have failed this error - {}".format(str(e))) def test_raise_context_window_exceeded_error(): """ - Retry Context Window Exceeded Error, when context_window_fallbacks is not None + Trigger Context Window fallback, when context_window_fallbacks is not None """ context_window_error = litellm.ContextWindowExceededError( message="Context window exceeded", @@ -379,7 +379,7 @@ def test_raise_context_window_exceeded_error(): ) context_window_fallbacks = [{"gpt-3.5-turbo": ["azure/chatgpt-v-2"]}] - router = litellm.Router( + router = Router( model_list=[ { "model_name": "gpt-3.5-turbo", @@ -393,14 +393,17 @@ def test_raise_context_window_exceeded_error(): ] ) - response = router.should_retry_this_error( - error=context_window_error, - healthy_deployments=None, - context_window_fallbacks=context_window_fallbacks, - ) - assert ( - response == True - ), "Should not have raised exception since we have context window fallbacks" + try: + response = router.should_retry_this_error( + error=context_window_error, + healthy_deployments=None, + context_window_fallbacks=context_window_fallbacks, + ) + pytest.fail( + "Expected to raise context window exceeded error -> trigger fallback" + ) + except Exception as e: + pass def test_raise_context_window_exceeded_error_no_retry(): @@ -418,7 +421,7 @@ def test_raise_context_window_exceeded_error_no_retry(): ) context_window_fallbacks = None - router = litellm.Router( + router = Router( model_list=[ { "model_name": "gpt-3.5-turbo", @@ -439,8 +442,8 @@ def test_raise_context_window_exceeded_error_no_retry(): context_window_fallbacks=context_window_fallbacks, ) assert ( - response != True - ), "Should have raised exception since we do not have context window fallbacks" + response == True + ), "Should not have raised exception since we do not have context window fallbacks" except litellm.ContextWindowExceededError: pass diff --git a/litellm/tests/test_scheduler.py b/litellm/tests/test_scheduler.py new file mode 100644 index 000000000..bba06d587 --- /dev/null +++ b/litellm/tests/test_scheduler.py @@ -0,0 +1,179 @@ +# What is this? 
+## Unit tests for the Scheduler.py (workload prioritization scheduler) + +import sys, os, time, openai, uuid +import traceback, asyncio +import pytest + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path +from litellm import Router +from litellm.scheduler import FlowItem, Scheduler + + +@pytest.mark.asyncio +async def test_scheduler_diff_model_names(): + """ + Assert 2 requests to 2 diff model groups are top of their respective queue's + """ + scheduler = Scheduler() + + router = Router( + model_list=[ + { + "model_name": "gpt-3.5-turbo", + "litellm_params": { + "model": "gpt-3.5-turbo", + }, + }, + {"model_name": "gpt-4", "litellm_params": {"model": "gpt-4"}}, + ] + ) + + scheduler.update_variables(llm_router=router) + + item1 = FlowItem(priority=0, request_id="10", model_name="gpt-3.5-turbo") + item2 = FlowItem(priority=0, request_id="11", model_name="gpt-4") + await scheduler.add_request(item1) + await scheduler.add_request(item2) + + assert await scheduler.poll(id="10", model_name="gpt-3.5-turbo") == True + assert await scheduler.poll(id="11", model_name="gpt-4") == True + + +@pytest.mark.parametrize("p0, p1", [(0, 0), (0, 1), (1, 0)]) +@pytest.mark.asyncio +async def test_scheduler_prioritized_requests(p0, p1): + """ + 2 requests for same model group + """ + scheduler = Scheduler() + + router = Router( + model_list=[ + { + "model_name": "gpt-3.5-turbo", + "litellm_params": { + "model": "gpt-3.5-turbo", + }, + }, + {"model_name": "gpt-4", "litellm_params": {"model": "gpt-4"}}, + ] + ) + + scheduler.update_variables(llm_router=router) + + item1 = FlowItem(priority=p0, request_id="10", model_name="gpt-3.5-turbo") + item2 = FlowItem(priority=p1, request_id="11", model_name="gpt-3.5-turbo") + await scheduler.add_request(item1) + await scheduler.add_request(item2) + + if p0 == 0: + assert await scheduler.peek(id="10", model_name="gpt-3.5-turbo") == True + assert await scheduler.peek(id="11", model_name="gpt-3.5-turbo") == False + else: + assert await scheduler.peek(id="11", model_name="gpt-3.5-turbo") == True + assert await scheduler.peek(id="10", model_name="gpt-3.5-turbo") == False + + +@pytest.mark.parametrize("p0, p1", [(0, 1)]) # (0, 0), (1, 0) +@pytest.mark.asyncio +async def test_scheduler_prioritized_requests_mock_response(p0, p1): + """ + 2 requests for same model group + + if model is at rate limit, ensure the higher priority request gets done first + """ + scheduler = Scheduler() + + router = Router( + model_list=[ + { + "model_name": "gpt-3.5-turbo", + "litellm_params": { + "model": "gpt-3.5-turbo", + "mock_response": "Hello world this is Macintosh!", + "rpm": 1, + }, + }, + ], + timeout=10, + num_retries=3, + cooldown_time=5, + routing_strategy="usage-based-routing-v2", + ) + + scheduler.update_variables(llm_router=router) + + await router.acompletion( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "Hey!"}], + ) + + async def _make_prioritized_call(flow_item: FlowItem): + ## POLL QUEUE + default_timeout = router.timeout + end_time = time.time() + default_timeout + poll_interval = 0.03 # poll every 3ms + curr_time = time.time() + + make_request = False + + if router is None: + raise Exception("No llm router value") + + while curr_time < end_time: + make_request = await scheduler.poll( + id=flow_item.request_id, model_name=flow_item.model_name + ) + print(f"make_request={make_request}, priority={flow_item.priority}") + if make_request: ## IF TRUE -> MAKE REQUEST + break + else: ## ELSE -> loop till default_timeout 
+ await asyncio.sleep(poll_interval) + curr_time = time.time() + + if make_request: + try: + _response = await router.acompletion( + model=flow_item.model_name, + messages=[{"role": "user", "content": "Hey!"}], + ) + except Exception as e: + print("Received error - {}".format(str(e))) + return flow_item.priority, flow_item.request_id, time.time() + + return flow_item.priority, flow_item.request_id, time.time() + + raise Exception("didn't make request") + + tasks = [] + + item = FlowItem( + priority=p0, request_id=str(uuid.uuid4()), model_name="gpt-3.5-turbo" + ) + await scheduler.add_request(request=item) + tasks.append(_make_prioritized_call(flow_item=item)) + + item = FlowItem( + priority=p1, request_id=str(uuid.uuid4()), model_name="gpt-3.5-turbo" + ) + await scheduler.add_request(request=item) + tasks.append(_make_prioritized_call(flow_item=item)) + + # Running the tasks and getting responses in order of completion + completed_responses = [] + for task in asyncio.as_completed(tasks): + result = await task + completed_responses.append(result) + print(f"Received response: {result}") + + print(f"responses: {completed_responses}") + + assert ( + completed_responses[0][0] == 0 + ) # assert higher priority request got done first + assert ( + completed_responses[0][2] < completed_responses[1][2] + ) # higher priority request tried first diff --git a/litellm/types/router.py b/litellm/types/router.py index a35e7a77d..8fed461cb 100644 --- a/litellm/types/router.py +++ b/litellm/types/router.py @@ -314,6 +314,8 @@ class LiteLLMParamsTypedDict(TypedDict, total=False): output_cost_per_token: Optional[float] input_cost_per_second: Optional[float] output_cost_per_second: Optional[float] + ## MOCK RESPONSES ## + mock_response: Optional[str] class DeploymentTypedDict(TypedDict): diff --git a/tests/test_openai_endpoints.py b/tests/test_openai_endpoints.py index 43dcae3cd..83d387ffb 100644 --- a/tests/test_openai_endpoints.py +++ b/tests/test_openai_endpoints.py @@ -103,6 +103,36 @@ async def chat_completion(session, key, model: Union[str, List] = "gpt-4"): return await response.json() +async def queue_chat_completion( + session, key, priority: int, model: Union[str, List] = "gpt-4" +): + url = "http://0.0.0.0:4000/queue/chat/completions" + headers = { + "Authorization": f"Bearer {key}", + "Content-Type": "application/json", + } + data = { + "model": model, + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello!"}, + ], + "priority": priority, + } + + async with session.post(url, headers=headers, json=data) as response: + status = response.status + response_text = await response.text() + + print(response_text) + print() + + if status != 200: + raise Exception(f"Request did not return a 200 status code: {status}") + + return response.raw_headers + + async def chat_completion_with_headers(session, key, model="gpt-4"): url = "http://0.0.0.0:4000/chat/completions" headers = {