fix(router.py): simplify scheduler

move the scheduler's add-to-queue and polling logic into the Router class, so a prioritized request is a single router.schedule_acompletion() call instead of a hand-rolled queue/poll loop
Krrish Dholakia 2024-06-01 16:09:41 -07:00
parent 27087f60c2
commit 7715267989
5 changed files with 177 additions and 131 deletions
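For context, a minimal sketch of the new call path introduced by this commit, adapted from the updated quickstart and test below (the model list, mock response, and message are illustrative placeholders, not part of the diff):

```python
import asyncio

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "gpt-3.5-turbo",
                "mock_response": "Hello world this is Macintosh!",  # mocked, so no API key is needed
            },
        },
    ],
    timeout=2,  # timeout request if it takes > 2s
    routing_strategy="usage-based-routing-v2",
    polling_interval=0.03,  # poll the queue every 3ms if no healthy deployments
)


async def main():
    # schedule_acompletion queues the request, polls until a deployment is
    # healthy (or the request reaches the top of the queue), then makes the call
    response = await router.schedule_acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hey!"}],
        priority=0,  # lower value = higher priority
    )
    print(response)


asyncio.run(main())
```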


@@ -22,9 +22,7 @@ Prioritize LLM API requests in high-traffic.
 ## Quick Start
 ```python
-from litellm import Scheduler, FlowItem, Router
-
-scheduler = Scheduler()
+from litellm import Router

 router = Router(
     model_list=[
@@ -39,53 +37,17 @@ router = Router(
     ],
     timeout=2, # timeout request if takes > 2s
     routing_strategy="usage-based-routing-v2",
+    polling_interval=0.03 # poll queue every 3ms if no healthy deployments
 )

-scheduler.update_variables(llm_router=router)
-
-### 🚨 IMPORTANT ###
-item = FlowItem(
-    priority=0, # 👈 SET PRIORITY FOR REQUEST
-    request_id=str(uuid.uuid4()), # 👈 SET REQUEST ID
-    model_name="gpt-3.5-turbo" # 👈 SAME as 'Router'
-)
-### [fin] IMPORTANT ###
-
-## ADDS REQUEST TO QUEUE ##
-await scheduler.add_request(request=item)
-
-## POLL QUEUE
-default_timeout = router.timeout
-end_time = time.time() + default_timeout
-poll_interval = 0.03 # poll every 3ms
-curr_time = time.time()
-make_request = False
-
-while curr_time < end_time:
-    make_request = await scheduler.poll( ## POLL QUEUE ## - returns 'True' if there's healthy deployments OR if request is at top of queue
-        id=item.request_id, model_name=item.model_name
-    )
-    if make_request: ## IF TRUE -> MAKE REQUEST
-        break
-    else: ## ELSE -> loop till default_timeout
-        await asyncio.sleep(poll_interval)
-        curr_time = time.time()
-
-if make_request:
-    try:
-        _response = await router.acompletion(
-            model=item.model_name,
-            messages=[{"role": "user", "content": "Hey!"}],
-        )
-    except Exception as e:
-        print("{}, {}, {}".format(item.priority, item.request_id, "Error occurred"))
-    print("{}, {}, {}".format(item.priority, item.request_id, time.time()))
-print("didn't make request")
+try:
+    _response = await router.schedule_acompletion( # 👈 ADDS TO QUEUE + POLLS + MAKES CALL
+        model=item.model_name,
+        messages=[{"role": "user", "content": "Hey!"}],
+        priority=0, # 👈 LOWER IS BETTER
+    )
+except Exception as e:
+    print("didn't make request")
 ```
 ## LiteLLM Proxy


@@ -398,8 +398,6 @@ proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache)
 async_result = None
 celery_app_conn = None
 celery_fn = None  # Redis Queue for handling requests
-### SIMPLE QUEUE ###
-simple_scheduler = Scheduler()
 ### DB WRITER ###
 db_writer_client: Optional[HTTPHandler] = None
 ### logger ###
@@ -3705,7 +3703,7 @@ def on_backoff(details):
 @router.on_event("startup")
 async def startup_event():
-    global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings, proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time, litellm_proxy_admin_name, db_writer_client, store_model_in_db, simple_scheduler
+    global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings, proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time, litellm_proxy_admin_name, db_writer_client, store_model_in_db
     import json

     ### LOAD MASTER KEY ###
@@ -3741,10 +3739,6 @@ async def startup_event():
     ## Error Tracking ##
     error_tracking()

-    ## Priority Workload Scheduler ##
-    if llm_router is not None:
-        simple_scheduler.update_variables(llm_router=llm_router)
-
     ## UPDATE SLACK ALERTING ##
     proxy_logging_obj.slack_alerting_instance.update_values(llm_router=llm_router)
@@ -12183,47 +12177,12 @@ async def async_queue_request(
         if user_api_base:
             data["api_base"] = user_api_base

-        ## FLOW ITEM ##
-        request_id = str(uuid.uuid4())
-        flow_item = FlowItem(
-            priority=data.pop("priority", DefaultPriorities.Medium.value),
-            request_id=request_id,
-            model_name=data["model"],
-        )
-        # [TODO] only allow premium users to set non default priorities
-
-        ## ADD REQUEST TO QUEUE
-        response = await simple_scheduler.add_request(request=flow_item)
-
-        if llm_router is None:
-            raise HTTPException(
-                status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
-            )
-
-        ## POLL QUEUE
-        default_timeout = llm_router.timeout
-        end_time = time.time() + default_timeout
-        poll_interval = 0.03  # poll every 3ms
-        curr_time = time.time()
-        make_request = False
-
         if llm_router is None:
             raise HTTPException(
                 status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
             )

-        while curr_time < end_time:
-            make_request = await simple_scheduler.poll(
-                id=request_id, model_name=data["model"]
-            )
-            if make_request:  ## IF TRUE -> MAKE REQUEST
-                break
-            else:  ## ELSE -> loop till default_timeout
-                await asyncio.sleep(poll_interval)
-                curr_time = time.time()
-
-        if make_request:
-            response = await llm_router.acompletion(**data)
+        response = await llm_router.schedule_acompletion(**data)

         if (
             "stream" in data and data["stream"] == True
@@ -12237,7 +12196,7 @@ async def async_queue_request(
                 media_type="text/event-stream",
             )

-        fastapi_response.headers.update({"x-litellm-priority": str(flow_item.priority)})
+        fastapi_response.headers.update({"x-litellm-priority": str(data["priority"])})
         return response
     except Exception as e:
         await proxy_logging_obj.post_call_failure_hook(
@@ -12260,19 +12219,6 @@ async def async_queue_request(
     )

-@router.get(
-    "/queue/info",
-    tags=["experimental"],
-    dependencies=[Depends(user_api_key_auth)],
-)
-async def queue_info(
-    request: Request,
-    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
-) -> List:
-    """Help user know the status of an item in the queue"""
-    return simple_scheduler.get_queue_status()
-
 @router.get(
     "/ollama_logs", dependencies=[Depends(user_api_key_auth)], tags=["experimental"]
 )


@@ -62,6 +62,7 @@ from litellm.types.llms.openai import (
     Run,
     AssistantToolParam,
 )
+from litellm.scheduler import Scheduler, FlowItem
 from typing import Iterable
@@ -87,6 +88,8 @@ class Router:
             List[tuple]
         ] = None,  # if you want to cache across model groups
         client_ttl: int = 3600,  # ttl for cached clients - will re-initialize after this time in seconds
+        ## SCHEDULER ##
+        polling_interval: Optional[float] = None,
         ## RELIABILITY ##
         num_retries: Optional[int] = None,
         timeout: Optional[float] = None,
@@ -141,7 +144,8 @@ class Router:
             cache_kwargs (dict): Additional kwargs to pass to RedisCache. Defaults to {}.
             caching_groups (Optional[List[tuple]]): List of model groups for caching across model groups. Defaults to None.
             client_ttl (int): Time-to-live for cached clients in seconds. Defaults to 3600.
-            num_retries (int): Number of retries for failed requests. Defaults to 0.
+            polling_interval (Optional[float]): frequency of polling the queue. Only used by '.schedule_acompletion()'. Default is 3ms.
+            num_retries (Optional[int]): Number of retries for failed requests. Defaults to 2.
             timeout (Optional[float]): Timeout for requests. Defaults to None.
             default_litellm_params (dict): Default parameters for Router.chat.completion.create. Defaults to {}.
             set_verbose (bool): Flag to set verbose mode. Defaults to False.
@@ -208,6 +212,8 @@ class Router:
             []
         )  # names of models under litellm_params. ex. azure/chatgpt-v-2
         self.deployment_latency_map = {}
+        ### SCHEDULER ###
+        self.scheduler = Scheduler(polling_interval=polling_interval)
         ### CACHING ###
         cache_type: Literal["local", "redis"] = "local"  # default to an in-memory cache
         redis_cache = None
@@ -533,11 +539,17 @@ class Router:
     ) -> ModelResponse:
         ...

+    @overload
+    async def acompletion(
+        self, model: str, messages: List[Dict[str, str]], stream: Union[Literal[True], Literal[False]] = False, **kwargs
+    ) -> Union[CustomStreamWrapper, ModelResponse]:
+        ...
+
     # fmt: on

     # The actual implementation of the function
     async def acompletion(
-        self, model: str, messages: List[Dict[str, str]], stream=False, **kwargs
+        self, model: str, messages: List[Dict[str, str]], stream: bool = False, **kwargs
     ):
         try:
             kwargs["model"] = model
@@ -905,6 +917,81 @@ class Router:
         # If we exit the loop without returning, all tasks failed
         raise Exception("All tasks failed")

+    ### SCHEDULER ###
+
+    # fmt: off
+
+    @overload
+    async def schedule_acompletion(
+        self, model: str, messages: List[Dict[str, str]], priority: int, stream: Literal[False] = False, **kwargs
+    ) -> ModelResponse:
+        ...
+
+    @overload
+    async def schedule_acompletion(
+        self, model: str, messages: List[Dict[str, str]], priority: int, stream: Literal[True], **kwargs
+    ) -> CustomStreamWrapper:
+        ...
+
+    # fmt: on
+
+    async def schedule_acompletion(
+        self,
+        model: str,
+        messages: List[Dict[str, str]],
+        priority: int,
+        stream=False,
+        **kwargs,
+    ):
+        ### FLOW ITEM ###
+        _request_id = str(uuid.uuid4())
+        item = FlowItem(
+            priority=priority,  # 👈 SET PRIORITY FOR REQUEST
+            request_id=_request_id,  # 👈 SET REQUEST ID
+            model_name="gpt-3.5-turbo",  # 👈 SAME as 'Router'
+        )
+        ### [fin] ###
+
+        ## ADDS REQUEST TO QUEUE ##
+        await self.scheduler.add_request(request=item)
+
+        ## POLL QUEUE
+        end_time = time.time() + self.timeout
+        curr_time = time.time()
+        poll_interval = self.scheduler.polling_interval  # poll every 3ms
+        make_request = False
+
+        while curr_time < end_time:
+            _healthy_deployments = await self._async_get_healthy_deployments(
+                model=model
+            )
+            make_request = await self.scheduler.poll(  ## POLL QUEUE ## - returns 'True' if there's healthy deployments OR if request is at top of queue
+                id=item.request_id,
+                model_name=item.model_name,
+                health_deployments=_healthy_deployments,
+            )
+            if make_request:  ## IF TRUE -> MAKE REQUEST
+                break
+            else:  ## ELSE -> loop till default_timeout
+                await asyncio.sleep(poll_interval)
+                curr_time = time.time()
+
+        if make_request:
+            try:
+                _response = await self.acompletion(
+                    model=model, messages=messages, stream=stream, **kwargs
+                )
+                return _response
+            except Exception as e:
+                setattr(e, "priority", priority)
+                raise e
+        else:
+            raise litellm.Timeout(
+                message="Request timed out while polling queue",
+                model=model,
+                llm_provider="openai",
+            )
+
     def image_generation(self, prompt: str, model: str, **kwargs):
         try:
             kwargs["model"] = model


@@ -3,7 +3,6 @@ from pydantic import BaseModel
 from typing import Optional
 import enum
 from litellm.caching import DualCache
-from litellm import Router
 from litellm import print_verbose
@@ -25,14 +24,16 @@ class FlowItem(BaseModel):
 class Scheduler:
     cache: DualCache
-    llm_router: Optional[Router] = None

-    def __init__(self):
-        self.queue = []
+    def __init__(self, polling_interval: Optional[float] = None):
+        """
+        polling_interval: float or null - frequency of polling queue. Default is 3ms.
+        """
+        self.queue: list = []
         self.cache = DualCache()
+        self.polling_interval = polling_interval or 0.03  # default to 3ms

-    def update_variables(self, llm_router: Router, cache: Optional[DualCache] = None):
-        self.llm_router = llm_router
+    def update_variables(self, cache: Optional[DualCache] = None):
         if cache is not None:
             self.cache = cache
@@ -46,7 +47,7 @@ class Scheduler:
         # save the queue
         await self.save_queue(queue=queue, model_name=request.model_name)

-    async def poll(self, id: str, model_name: str) -> bool:
+    async def poll(self, id: str, model_name: str, health_deployments: list) -> bool:
         """
         Return if request can be processed.
@@ -59,22 +60,17 @@ class Scheduler:
         * AND request not at the top of queue
         """
         queue = await self.get_queue(model_name=model_name)
-        if not queue or not self.llm_router:
+        if not queue:
             raise Exception(
-                "Incorrectly setup. Queue or Router is invalid. Queue={}, Router={}".format(
-                    queue, self.llm_router
-                )
+                "Incorrectly setup. Queue is invalid. Queue={}".format(queue)
             )

         # ------------
         # Setup values
         # ------------
-        _healthy_deployments = await self.llm_router._async_get_healthy_deployments(
-            model=model_name
-        )

-        print_verbose(f"len(_healthy_deployments): {len(_healthy_deployments)}")
-        if len(_healthy_deployments) == 0:
+        print_verbose(f"len(health_deployments): {len(health_deployments)}")
+        if len(health_deployments) == 0:
             print_verbose(f"queue: {queue}, seeking id={id}")
             # Check if the id is at the top of the heap
             if queue[0][1] == id:
@@ -87,23 +83,19 @@ class Scheduler:
         return True

-    async def peek(self, id: str, model_name: str) -> bool:
+    async def peek(self, id: str, model_name: str, health_deployments: list) -> bool:
         """Return if the id is at the top of the queue. Don't pop the value from heap."""
         queue = await self.get_queue(model_name=model_name)
-        if not queue or not self.llm_router:
+        if not queue:
             raise Exception(
-                "Incorrectly setup. Queue or Router is invalid. Queue={}, Router={}".format(
-                    queue, self.llm_router
-                )
+                "Incorrectly setup. Queue is invalid. Queue={}".format(queue)
             )

         # ------------
         # Setup values
         # ------------
-        _healthy_deployments = await self.llm_router._async_get_healthy_deployments(
-            model=model_name
-        )
-        if len(_healthy_deployments) == 0:
+        if len(health_deployments) == 0:
             return False

         # Check if the id is at the top of the heap

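One consequence of the scheduler.py hunks above: the Scheduler no longer reaches back into the Router for healthy deployments; the caller fetches them and hands them to poll()/peek(). A rough standalone sketch of that contract, using only names from this diff (driving the Scheduler directly, rather than through Router.schedule_acompletion, is purely for illustration):

```python
import asyncio
import uuid

from litellm.scheduler import FlowItem, Scheduler


async def demo() -> None:
    scheduler = Scheduler(polling_interval=0.03)  # falls back to 3ms when omitted

    item = FlowItem(
        priority=0,
        request_id=str(uuid.uuid4()),
        model_name="gpt-3.5-turbo",
    )
    await scheduler.add_request(request=item)

    # The caller now owns the health lookup; inside litellm this is
    # `await router._async_get_healthy_deployments(model=...)`.
    healthy_deployments: list = []  # pretend every deployment is rate-limited

    # With no healthy deployments, poll() only returns True once this request
    # reaches the top of the priority queue.
    can_run = await scheduler.poll(
        id=item.request_id,
        model_name=item.model_name,
        health_deployments=healthy_deployments,
    )
    print(f"can_run={can_run}")


asyncio.run(demo())
```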

@@ -4,12 +4,14 @@
 import sys, os, time, openai, uuid
 import traceback, asyncio
 import pytest
+from typing import List

 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
 from litellm import Router
 from litellm.scheduler import FlowItem, Scheduler
+from litellm import ModelResponse


 @pytest.mark.asyncio
@@ -172,3 +174,60 @@ async def test_aascheduler_prioritized_requests_mock_response(p0, p1):
     assert (
         completed_responses[0][2] < completed_responses[1][2]
     )  # higher priority request tried first
+
+
+@pytest.mark.parametrize("p0, p1", [(0, 1), (0, 0)])  #
+@pytest.mark.asyncio
+async def test_aascheduler_prioritized_requests_mock_response_simplified(p0, p1):
+    """
+    2 requests for same model group
+
+    if model is at rate limit, ensure the higher priority request gets done first
+    """
+    router = Router(
+        model_list=[
+            {
+                "model_name": "gpt-3.5-turbo",
+                "litellm_params": {
+                    "model": "gpt-3.5-turbo",
+                    "mock_response": "Hello world this is Macintosh!",
+                    "rpm": 0,
+                },
+            },
+        ],
+        timeout=10,
+        num_retries=3,
+        cooldown_time=5,
+        routing_strategy="usage-based-routing-v2",
+    )
+
+    tasks = []
+
+    data = {
+        "model": "gpt-3.5-turbo",
+        "messages": [{"role": "user", "content": "Hey, how's it going?"}],
+    }
+
+    tasks.append(router.schedule_acompletion(**data, priority=p0))
+    tasks.append(router.schedule_acompletion(**data, priority=p1))
+
+    # Running the tasks and getting responses in order of completion
+    completed_responses: List[dict] = []
+    for task in asyncio.as_completed(tasks):
+        try:
+            result = await task
+        except Exception as e:
+            result = {"priority": e.priority, "response_completed_at": time.time()}
+            completed_responses.append(result)
+            print(f"Received response: {result}")
+
+    print(f"responses: {completed_responses}")
+
+    assert (
+        completed_responses[0]["priority"] == 0
+    )  # assert higher priority request got done first
+    assert (
+        completed_responses[0]["response_completed_at"]
+        < completed_responses[1]["response_completed_at"]
+    )  # higher priority request tried first