Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 11:14:04 +00:00)

fix(router.py): simplify scheduler

Move the scheduler poll/queuing logic into the Router class, making it easier to use.

Commit 7715267989 (parent 27087f60c2)
5 changed files with 177 additions and 131 deletions
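The practical effect is visible in the docs diff below: callers no longer construct a `Scheduler`, build a `FlowItem`, and poll the queue themselves; the Router owns that loop. Here is a minimal, non-streaming sketch of the new call path based on those docs changes — the deployment entry with `mock_response` is illustrative and not part of this commit:

```python
import asyncio
from litellm import Router

# The Router now creates and polls its own Scheduler internally;
# polling_interval is simply forwarded to it (see the router.py hunks below).
router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "gpt-3.5-turbo",
                "mock_response": "Hello world!",  # illustrative, avoids a real API key
            },
        }
    ],
    timeout=2,  # give up if the request takes > 2s
    routing_strategy="usage-based-routing-v2",
    polling_interval=0.03,  # poll the queue every 3ms if no healthy deployments
)

async def main():
    # schedule_acompletion adds the request to the queue, polls, and makes the call
    response = await router.schedule_acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hey!"}],
        priority=0,  # lower is better
    )
    print(response.choices[0].message.content)

asyncio.run(main())
```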
@@ -22,9 +22,7 @@ Prioritize LLM API requests in high-traffic.
 ## Quick Start
 
 ```python
-from litellm import Scheduler, FlowItem, Router
+from litellm import Router
 
-scheduler = Scheduler()
-
 router = Router(
     model_list=[
@@ -39,52 +37,16 @@ router = Router(
     ],
     timeout=2, # timeout request if takes > 2s
     routing_strategy="usage-based-routing-v2",
+    polling_interval=0.03 # poll queue every 3ms if no healthy deployments
 )
 
-scheduler.update_variables(llm_router=router)
-
-### 🚨 IMPORTANT ###
-
-item = FlowItem(
-    priority=0, # 👈 SET PRIORITY FOR REQUEST
-    request_id=str(uuid.uuid4()), # 👈 SET REQUEST ID
-    model_name="gpt-3.5-turbo" # 👈 SAME as 'Router'
-)
-
-### [fin] IMPORTANT ###
-
-## ADDS REQUEST TO QUEUE ##
-await scheduler.add_request(request=item)
-
-## POLL QUEUE
-default_timeout = router.timeout
-end_time = time.time() + default_timeout
-poll_interval = 0.03 # poll every 3ms
-curr_time = time.time()
-
-make_request = False
-
-while curr_time < end_time:
-    make_request = await scheduler.poll( ## POLL QUEUE ## - returns 'True' if there's healthy deployments OR if request is at top of queue
-        id=item.request_id, model_name=item.model_name
-    )
-    if make_request: ## IF TRUE -> MAKE REQUEST
-        break
-    else: ## ELSE -> loop till default_timeout
-        await asyncio.sleep(poll_interval)
-        curr_time = time.time()
-
-if make_request:
 try:
-    _response = await router.acompletion(
+    _response = await router.schedule_acompletion( # 👈 ADDS TO QUEUE + POLLS + MAKES CALL
         model=item.model_name,
         messages=[{"role": "user", "content": "Hey!"}],
+        priority=0, # 👈 LOWER IS BETTER
     )
 except Exception as e:
-    print("{}, {}, {}".format(item.priority, item.request_id, "Error occurred"))
-
-    print("{}, {}, {}".format(item.priority, item.request_id, time.time()))
-
     print("didn't make request")
 ```
 
@@ -398,8 +398,6 @@ proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache)
 async_result = None
 celery_app_conn = None
 celery_fn = None # Redis Queue for handling requests
-### SIMPLE QUEUE ###
-simple_scheduler = Scheduler()
 ### DB WRITER ###
 db_writer_client: Optional[HTTPHandler] = None
 ### logger ###
@@ -3705,7 +3703,7 @@ def on_backoff(details):
 
 @router.on_event("startup")
 async def startup_event():
-    global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings, proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time, litellm_proxy_admin_name, db_writer_client, store_model_in_db, simple_scheduler
+    global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings, proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time, litellm_proxy_admin_name, db_writer_client, store_model_in_db
     import json
 
     ### LOAD MASTER KEY ###
@@ -3741,10 +3739,6 @@ async def startup_event():
     ## Error Tracking ##
     error_tracking()
 
-    ## Priority Workload Scheduler ##
-    if llm_router is not None:
-        simple_scheduler.update_variables(llm_router=llm_router)
-
     ## UPDATE SLACK ALERTING ##
     proxy_logging_obj.slack_alerting_instance.update_values(llm_router=llm_router)
 
@@ -12183,47 +12177,12 @@ async def async_queue_request(
         if user_api_base:
             data["api_base"] = user_api_base
 
-        ## FLOW ITEM ##
-        request_id = str(uuid.uuid4())
-        flow_item = FlowItem(
-            priority=data.pop("priority", DefaultPriorities.Medium.value),
-            request_id=request_id,
-            model_name=data["model"],
-        )
-        # [TODO] only allow premium users to set non default priorities
-
-        ## ADD REQUEST TO QUEUE
-        response = await simple_scheduler.add_request(request=flow_item)
-
-        if llm_router is None:
-            raise HTTPException(
-                status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
-            )
-        ## POLL QUEUE
-        default_timeout = llm_router.timeout
-        end_time = time.time() + default_timeout
-        poll_interval = 0.03 # poll every 3ms
-        curr_time = time.time()
-
-        make_request = False
-
         if llm_router is None:
             raise HTTPException(
                 status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
            )
 
-        while curr_time < end_time:
-            make_request = await simple_scheduler.poll(
-                id=request_id, model_name=data["model"]
-            )
-            if make_request: ## IF TRUE -> MAKE REQUEST
-                break
-            else: ## ELSE -> loop till default_timeout
-                await asyncio.sleep(poll_interval)
-                curr_time = time.time()
-
-        if make_request:
-            response = await llm_router.acompletion(**data)
+        response = await llm_router.schedule_acompletion(**data)
 
         if (
             "stream" in data and data["stream"] == True
@@ -12237,7 +12196,7 @@ async def async_queue_request(
                 media_type="text/event-stream",
             )
 
-        fastapi_response.headers.update({"x-litellm-priority": str(flow_item.priority)})
+        fastapi_response.headers.update({"x-litellm-priority": str(data["priority"])})
         return response
     except Exception as e:
         await proxy_logging_obj.post_call_failure_hook(
@@ -12260,19 +12219,6 @@ async def async_queue_request(
         )
 
 
-@router.get(
-    "/queue/info",
-    tags=["experimental"],
-    dependencies=[Depends(user_api_key_auth)],
-)
-async def queue_info(
-    request: Request,
-    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
-) -> List:
-    """Help user know the status of an item in the queue"""
-    return simple_scheduler.get_queue_status()
-
-
 @router.get(
     "/ollama_logs", dependencies=[Depends(user_api_key_auth)], tags=["experimental"]
 )
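The route decorator for `async_queue_request` is not shown in these hunks. Assuming the handler is mounted at `/queue/chat/completions` on a locally running proxy (host, port, path, and key below are placeholders, not taken from this diff), a caller would now put the priority in the request body; the handler forwards the whole body to `llm_router.schedule_acompletion(**data)` and echoes the priority back in the `x-litellm-priority` header. A hedged sketch:

```python
# Sketch only: the endpoint path, host/port, and API key are assumptions.
import requests

resp = requests.post(
    "http://0.0.0.0:4000/queue/chat/completions",  # assumed route for async_queue_request
    headers={"Authorization": "Bearer sk-1234"},   # placeholder proxy key
    json={
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "Hey!"}],
        "priority": 0,  # lower is better; forwarded to schedule_acompletion
    },
    timeout=30,
)
print(resp.status_code, resp.headers.get("x-litellm-priority"))
print(resp.json())
```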
@@ -62,6 +62,7 @@ from litellm.types.llms.openai import (
     Run,
     AssistantToolParam,
 )
+from litellm.scheduler import Scheduler, FlowItem
 from typing import Iterable
 
 
@@ -87,6 +88,8 @@ class Router:
            List[tuple]
        ] = None, # if you want to cache across model groups
        client_ttl: int = 3600, # ttl for cached clients - will re-initialize after this time in seconds
+        ## SCHEDULER ##
+        polling_interval: Optional[float] = None,
        ## RELIABILITY ##
        num_retries: Optional[int] = None,
        timeout: Optional[float] = None,
@@ -141,7 +144,8 @@ class Router:
            cache_kwargs (dict): Additional kwargs to pass to RedisCache. Defaults to {}.
            caching_groups (Optional[List[tuple]]): List of model groups for caching across model groups. Defaults to None.
            client_ttl (int): Time-to-live for cached clients in seconds. Defaults to 3600.
-            num_retries (int): Number of retries for failed requests. Defaults to 0.
+            polling_interval: (Optional[float]): frequency of polling queue. Only for '.scheduler_acompletion()'. Default is 3ms.
+            num_retries (Optional[int]): Number of retries for failed requests. Defaults to 2.
            timeout (Optional[float]): Timeout for requests. Defaults to None.
            default_litellm_params (dict): Default parameters for Router.chat.completion.create. Defaults to {}.
            set_verbose (bool): Flag to set verbose mode. Defaults to False.
@@ -208,6 +212,8 @@ class Router:
            []
        ) # names of models under litellm_params. ex. azure/chatgpt-v-2
        self.deployment_latency_map = {}
+        ### SCHEDULER ###
+        self.scheduler = Scheduler(polling_interval=polling_interval)
        ### CACHING ###
        cache_type: Literal["local", "redis"] = "local" # default to an in-memory cache
        redis_cache = None
@@ -533,11 +539,17 @@ class Router:
     ) -> ModelResponse:
         ...
 
+    @overload
+    async def acompletion(
+        self, model: str, messages: List[Dict[str, str]], stream: Union[Literal[True], Literal[False]] = False, **kwargs
+    ) -> Union[CustomStreamWrapper, ModelResponse]:
+        ...
+
     # fmt: on
 
     # The actual implementation of the function
     async def acompletion(
-        self, model: str, messages: List[Dict[str, str]], stream=False, **kwargs
+        self, model: str, messages: List[Dict[str, str]], stream: bool = False, **kwargs
     ):
         try:
             kwargs["model"] = model
@@ -905,6 +917,81 @@ class Router:
         # If we exit the loop without returning, all tasks failed
         raise Exception("All tasks failed")
 
+    ### SCHEDULER ###
+
+    # fmt: off
+
+    @overload
+    async def schedule_acompletion(
+        self, model: str, messages: List[Dict[str, str]], priority: int, stream: Literal[False] = False, **kwargs
+    ) -> ModelResponse:
+        ...
+
+    @overload
+    async def schedule_acompletion(
+        self, model: str, messages: List[Dict[str, str]], priority: int, stream: Literal[True], **kwargs
+    ) -> CustomStreamWrapper:
+        ...
+
+    # fmt: on
+
+    async def schedule_acompletion(
+        self,
+        model: str,
+        messages: List[Dict[str, str]],
+        priority: int,
+        stream=False,
+        **kwargs,
+    ):
+        ### FLOW ITEM ###
+        _request_id = str(uuid.uuid4())
+        item = FlowItem(
+            priority=priority, # 👈 SET PRIORITY FOR REQUEST
+            request_id=_request_id, # 👈 SET REQUEST ID
+            model_name="gpt-3.5-turbo", # 👈 SAME as 'Router'
+        )
+        ### [fin] ###
+
+        ## ADDS REQUEST TO QUEUE ##
+        await self.scheduler.add_request(request=item)
+
+        ## POLL QUEUE
+        end_time = time.time() + self.timeout
+        curr_time = time.time()
+        poll_interval = self.scheduler.polling_interval # poll every 3ms
+        make_request = False
+
+        while curr_time < end_time:
+            _healthy_deployments = await self._async_get_healthy_deployments(
+                model=model
+            )
+            make_request = await self.scheduler.poll( ## POLL QUEUE ## - returns 'True' if there's healthy deployments OR if request is at top of queue
+                id=item.request_id,
+                model_name=item.model_name,
+                health_deployments=_healthy_deployments,
+            )
+            if make_request: ## IF TRUE -> MAKE REQUEST
+                break
+            else: ## ELSE -> loop till default_timeout
+                await asyncio.sleep(poll_interval)
+                curr_time = time.time()
+
+        if make_request:
+            try:
+                _response = await self.acompletion(
+                    model=model, messages=messages, stream=stream, **kwargs
+                )
+                return _response
+            except Exception as e:
+                setattr(e, "priority", priority)
+                raise e
+        else:
+            raise litellm.Timeout(
+                message="Request timed out while polling queue",
+                model=model,
+                llm_provider="openai",
+            )
+
     def image_generation(self, prompt: str, model: str, **kwargs):
         try:
             kwargs["model"] = model
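The overloads above say that `stream=True` should return a `CustomStreamWrapper` that can be iterated with `async for`. A small sketch of that path follows; the deployment entry with `mock_response` is illustrative (so no real key is needed), and the chunk fields assume the usual OpenAI-style streaming deltas litellm emits:

```python
import asyncio
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "gpt-3.5-turbo",
                "mock_response": "Hello world, this is a mock stream!",  # illustrative
            },
        }
    ],
    timeout=5,
    polling_interval=0.03,  # forwarded to the Router's internal Scheduler
)

async def main():
    stream = await router.schedule_acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hey!"}],
        priority=0,   # lower value = scheduled ahead of higher values
        stream=True,  # streaming overload -> CustomStreamWrapper
    )
    async for chunk in stream:
        print(chunk.choices[0].delta.content or "", end="")
    print()

asyncio.run(main())
```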
@@ -3,7 +3,6 @@ from pydantic import BaseModel
 from typing import Optional
 import enum
 from litellm.caching import DualCache
-from litellm import Router
 from litellm import print_verbose
 
 
@@ -25,14 +24,16 @@ class FlowItem(BaseModel):
 
 class Scheduler:
     cache: DualCache
-    llm_router: Optional[Router] = None
 
-    def __init__(self):
-        self.queue = []
+    def __init__(self, polling_interval: Optional[float] = None):
+        """
+        polling_interval: float or null - frequency of polling queue. Default is 3ms.
+        """
+        self.queue: list = []
         self.cache = DualCache()
+        self.polling_interval = polling_interval or 0.03 # default to 3ms
 
-    def update_variables(self, llm_router: Router, cache: Optional[DualCache] = None):
-        self.llm_router = llm_router
+    def update_variables(self, cache: Optional[DualCache] = None):
         if cache is not None:
             self.cache = cache
 
@@ -46,7 +47,7 @@ class Scheduler:
         # save the queue
         await self.save_queue(queue=queue, model_name=request.model_name)
 
-    async def poll(self, id: str, model_name: str) -> bool:
+    async def poll(self, id: str, model_name: str, health_deployments: list) -> bool:
         """
         Return if request can be processed.
 
@@ -59,22 +60,17 @@ class Scheduler:
            * AND request not at the top of queue
         """
         queue = await self.get_queue(model_name=model_name)
-        if not queue or not self.llm_router:
+        if not queue:
             raise Exception(
-                "Incorrectly setup. Queue or Router is invalid. Queue={}, Router={}".format(
-                    queue, self.llm_router
-                )
+                "Incorrectly setup. Queue is invalid. Queue={}".format(queue)
             )
 
         # ------------
         # Setup values
         # ------------
-        _healthy_deployments = await self.llm_router._async_get_healthy_deployments(
-            model=model_name
-        )
 
-        print_verbose(f"len(_healthy_deployments): {len(_healthy_deployments)}")
-        if len(_healthy_deployments) == 0:
+        print_verbose(f"len(health_deployments): {len(health_deployments)}")
+        if len(health_deployments) == 0:
             print_verbose(f"queue: {queue}, seeking id={id}")
             # Check if the id is at the top of the heap
             if queue[0][1] == id:
@@ -87,23 +83,19 @@ class Scheduler:
 
                 return True
 
-    async def peek(self, id: str, model_name: str) -> bool:
+    async def peek(self, id: str, model_name: str, health_deployments: list) -> bool:
         """Return if the id is at the top of the queue. Don't pop the value from heap."""
         queue = await self.get_queue(model_name=model_name)
-        if not queue or not self.llm_router:
+        if not queue:
             raise Exception(
-                "Incorrectly setup. Queue or Router is invalid. Queue={}, Router={}".format(
-                    queue, self.llm_router
-                )
+                "Incorrectly setup. Queue is invalid. Queue={}".format(queue)
            )
 
         # ------------
         # Setup values
         # ------------
-        _healthy_deployments = await self.llm_router._async_get_healthy_deployments(
-            model=model_name
-        )
-        if len(_healthy_deployments) == 0:
+        if len(health_deployments) == 0:
             return False
 
         # Check if the id is at the top of the heap
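Since the Scheduler no longer holds a Router reference, it can now be exercised on its own. A minimal sketch of the decoupled API as changed above — the request values are illustrative, and `poll` should return True here because the single queued item is at the top of the queue (per the docstring in this diff):

```python
import asyncio
from litellm.scheduler import Scheduler, FlowItem

async def main():
    # The Scheduler now receives polling_interval directly instead of a Router.
    scheduler = Scheduler(polling_interval=0.03)

    item = FlowItem(priority=0, request_id="req-1", model_name="gpt-3.5-turbo")
    await scheduler.add_request(request=item)

    # The caller (the Router, in this commit) now passes healthy deployments in.
    # With no healthy deployments, only the request at the top of the queue may run.
    can_run = await scheduler.poll(
        id=item.request_id,
        model_name=item.model_name,
        health_deployments=[],  # e.g. every deployment rate limited / cooling down
    )
    print("can run:", can_run)

asyncio.run(main())
```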
@@ -4,12 +4,14 @@
 import sys, os, time, openai, uuid
 import traceback, asyncio
 import pytest
+from typing import List
 
 sys.path.insert(
     0, os.path.abspath("../..")
 ) # Adds the parent directory to the system path
 from litellm import Router
 from litellm.scheduler import FlowItem, Scheduler
+from litellm import ModelResponse
 
 
 @pytest.mark.asyncio
@@ -172,3 +174,60 @@ async def test_aascheduler_prioritized_requests_mock_response(p0, p1):
     assert (
         completed_responses[0][2] < completed_responses[1][2]
     ) # higher priority request tried first
+
+
+@pytest.mark.parametrize("p0, p1", [(0, 1), (0, 0)]) #
+@pytest.mark.asyncio
+async def test_aascheduler_prioritized_requests_mock_response_simplified(p0, p1):
+    """
+    2 requests for same model group
+
+    if model is at rate limit, ensure the higher priority request gets done first
+    """
+
+    router = Router(
+        model_list=[
+            {
+                "model_name": "gpt-3.5-turbo",
+                "litellm_params": {
+                    "model": "gpt-3.5-turbo",
+                    "mock_response": "Hello world this is Macintosh!",
+                    "rpm": 0,
+                },
+            },
+        ],
+        timeout=10,
+        num_retries=3,
+        cooldown_time=5,
+        routing_strategy="usage-based-routing-v2",
+    )
+
+    tasks = []
+
+    data = {
+        "model": "gpt-3.5-turbo",
+        "messages": [{"role": "user", "content": "Hey, how's it going?"}],
+    }
+
+    tasks.append(router.schedule_acompletion(**data, priority=p0))
+    tasks.append(router.schedule_acompletion(**data, priority=p1))
+
+    # Running the tasks and getting responses in order of completion
+    completed_responses: List[dict] = []
+    for task in asyncio.as_completed(tasks):
+        try:
+            result = await task
+        except Exception as e:
+            result = {"priority": e.priority, "response_completed_at": time.time()}
+        completed_responses.append(result)
+        print(f"Received response: {result}")
+
+    print(f"responses: {completed_responses}")
+
+    assert (
+        completed_responses[0]["priority"] == 0
+    ) # assert higher priority request got done first
+    assert (
+        completed_responses[0]["response_completed_at"]
+        < completed_responses[1]["response_completed_at"]
+    ) # higher priority request tried first
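A small, hypothetical convenience for running only the new test above; the test module's path is not shown in this diff view, so pytest is run from the repository checkout and filtered by test name:

```python
# Hypothetical helper: run only the new scheduler test by name. Assumes pytest
# and the repo's test dependencies are installed; run from the repository root.
import pytest

exit_code = pytest.main(
    ["-q", "-k", "test_aascheduler_prioritized_requests_mock_response_simplified"]
)
print("pytest exit code:", exit_code)
```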