Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 03:04:13 +00:00)

Merge pull request #3954 from BerriAI/litellm_simple_request_prioritization

feat(scheduler.py): add request prioritization scheduler

Commit 8375e9621c: 12 changed files with 612 additions and 149 deletions
docs/my-website/docs/scheduler.md (new file, 141 lines)

@@ -0,0 +1,141 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';

# [BETA] Request Prioritization

:::info

Beta feature. Use for testing only.

[Help us improve this](https://github.com/BerriAI/litellm/issues)

:::

Prioritize LLM API requests in high-traffic environments.

- Add the request to a priority queue.
- Poll the queue to check if the request can be made. Returns 'True':
    * if there are healthy deployments,
    * OR if the request is at the top of the queue.
- Priority: the lower the number, the higher the priority (see the sketch after this list):
    * e.g. `priority=0` > `priority=2000`
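A standalone sketch of that ordering rule (it mirrors how the scheduler pushes `(priority, request_id)` tuples onto a heap; the request ids here are made up for illustration):

```python
import heapq

# Lower number = higher priority. The scheduler stores (priority, request_id) tuples.
queue = []
heapq.heappush(queue, (2000, "low-priority-request"))  # hypothetical request id
heapq.heappush(queue, (0, "high-priority-request"))    # hypothetical request id

print(heapq.heappop(queue))  # (0, 'high-priority-request') comes out first
```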
## Quick Start

```python
import asyncio, time, uuid  # used by the snippet below
from litellm import Scheduler, FlowItem, Router

scheduler = Scheduler()

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "gpt-3.5-turbo",
                "mock_response": "Hello world this is Macintosh!",  # fakes the LLM API call
                "rpm": 1,
            },
        },
    ],
    timeout=2,  # timeout request if takes > 2s
    routing_strategy="usage-based-routing-v2",
)

scheduler.update_variables(llm_router=router)

### 🚨 IMPORTANT ###

item = FlowItem(
    priority=0,  # 👈 SET PRIORITY FOR REQUEST
    request_id=str(uuid.uuid4()),  # 👈 SET REQUEST ID
    model_name="gpt-3.5-turbo"  # 👈 SAME as 'Router'
)

### [fin] IMPORTANT ###

## ADD REQUEST TO QUEUE ##
await scheduler.add_request(request=item)

## POLL QUEUE
default_timeout = router.timeout
end_time = time.time() + default_timeout
poll_interval = 0.03  # poll every 30ms
curr_time = time.time()

make_request = False

while curr_time < end_time:
    make_request = await scheduler.poll(  ## POLL QUEUE ## - returns 'True' if there's a healthy deployment OR if the request is at the top of the queue
        id=item.request_id, model_name=item.model_name
    )
    if make_request:  ## IF TRUE -> MAKE REQUEST
        break
    else:  ## ELSE -> loop till default_timeout
        await asyncio.sleep(poll_interval)
        curr_time = time.time()

if make_request:
    try:
        _response = await router.acompletion(
            model=item.model_name,
            messages=[{"role": "user", "content": "Hey!"}],
        )
    except Exception as e:
        print("{}, {}, {}".format(item.priority, item.request_id, "Error occurred"))

    print("{}, {}, {}".format(item.priority, item.request_id, time.time()))
else:
    print("didn't make request")
```

## LiteLLM Proxy

To prioritize requests on the LiteLLM Proxy, call our beta openai-compatible `http://localhost:4000/queue` endpoint.

<Tabs>
<TabItem value="curl" label="curl">

```curl
curl -X POST 'http://localhost:4000/queue/chat/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer sk-1234' \
-d '{
    "model": "gpt-3.5-turbo-fake-model",
    "messages": [
        {
            "role": "user",
            "content": "what is the meaning of the universe? 1234"
        }],
    "priority": 0 👈 SET VALUE HERE
}'
```

</TabItem>
<TabItem value="openai-sdk" label="OpenAI SDK">

```python
import openai
client = openai.OpenAI(
    api_key="anything",
    base_url="http://0.0.0.0:4000"
)

# request sent to model set on litellm proxy, `litellm --model`
response = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[
        {
            "role": "user",
            "content": "this is a test request, write a short poem"
        }
    ],
    extra_body={
        "priority": 0  # 👈 SET VALUE HERE
    }
)

print(response)
```

</TabItem>
</Tabs>

@@ -164,6 +164,7 @@ const sidebars = {
            },
            "proxy/custom_pricing",
            "routing",
+           "scheduler",
            "rules",
            "set_keys",
            "budget_manager",
@@ -808,3 +808,4 @@ from .proxy.proxy_cli import run_server
from .router import Router
from .assistants.main import *
from .batches.main import *
+from .scheduler import *
@@ -431,6 +431,10 @@ def mock_completion(
            model=model,  # type: ignore
            request=httpx.Request(method="POST", url="https://api.openai.com/v1/"),
        )
+    time_delay = kwargs.get("mock_delay", None)
+    if time_delay is not None:
+        time.sleep(time_delay)
+
    model_response = ModelResponse(stream=stream)
    if stream is True:
        # don't try to access stream object,

@@ -881,6 +885,7 @@ def completion(
            mock_response=mock_response,
            logging=logging,
            acompletion=acompletion,
+            mock_delay=kwargs.get("mock_delay", None),
        )
    if custom_llm_provider == "azure":
        # azure configs
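The two hunks above wire a new `mock_delay` kwarg through `completion()` into `mock_completion()`. A minimal sketch of how it could be exercised together with `mock_response` (the delay value is illustrative):

```python
import litellm

# mock_response fakes the LLM API call; mock_delay (added above) sleeps for the
# given number of seconds before the mocked response is returned.
response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hey!"}],
    mock_response="Hello world this is Macintosh!",
    mock_delay=0.5,  # illustrative value, in seconds
)
print(response)
```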
@@ -5,12 +5,12 @@ model_list:
      model: openai/my-fake-model
      rpm: 800
    model_name: gpt-3.5-turbo-fake-model
-  - litellm_params:
-      api_base: https://my-endpoint-europe-berri-992.openai.azure.com/
-      api_key: os.environ/AZURE_EUROPE_API_KEY
-      model: azure/gpt-35-turbo
-      rpm: 10
-    model_name: gpt-3.5-turbo-fake-model
+  # - litellm_params:
+  #     api_base: https://my-endpoint-europe-berri-992.openai.azure.com/
+  #     api_key: os.environ/AZURE_EUROPE_API_KEY
+  #     model: azure/gpt-35-turbo
+  #     rpm: 10
+  #   model_name: gpt-3.5-turbo-fake-model
  - litellm_params:
      api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
      api_key: os.environ/AZURE_API_KEY
@@ -142,6 +142,7 @@ from litellm.proxy.auth.auth_checks import (
from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
from litellm.exceptions import RejectedRequestError
from litellm.integrations.slack_alerting import SlackAlertingArgs, SlackAlerting
+from litellm.scheduler import Scheduler, FlowItem, DefaultPriorities

try:
    from litellm._version import version

@@ -397,6 +398,8 @@ proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache)
async_result = None
celery_app_conn = None
celery_fn = None  # Redis Queue for handling requests
+### SIMPLE QUEUE ###
+simple_scheduler = Scheduler()
### DB WRITER ###
db_writer_client: Optional[HTTPHandler] = None
### logger ###

@@ -3702,7 +3705,7 @@ def on_backoff(details):

@router.on_event("startup")
async def startup_event():
-    global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings, proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time, litellm_proxy_admin_name, db_writer_client, store_model_in_db
+    global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings, proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time, litellm_proxy_admin_name, db_writer_client, store_model_in_db, simple_scheduler
    import json

    ### LOAD MASTER KEY ###

@@ -3738,6 +3741,10 @@ async def startup_event():
    ## Error Tracking ##
    error_tracking()

+    ## Priority Workload Scheduler ##
+    if llm_router is not None:
+        simple_scheduler.update_variables(llm_router=llm_router)
+
    ## UPDATE SLACK ALERTING ##
    proxy_logging_obj.slack_alerting_instance.update_values(llm_router=llm_router)
@@ -12076,118 +12083,7 @@ async def alerting_settings(
    return return_val

-# @router.post(
-#     "/alerting/update",
-#     description="Update the slack alerting settings. Persist value in db.",
-#     tags=["alerting"],
-#     dependencies=[Depends(user_api_key_auth)],
-#     include_in_schema=False,
-# )
-# async def alerting_update(
-#     data: SlackAlertingArgs,
-#     user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
-# ):
-#     """Allows updating slack alerting values. Used by UI."""
-#     global prisma_client
-#     if prisma_client is None:
-#         raise HTTPException(
-#             status_code=400,
-#             detail={"error": CommonProxyErrors.db_not_connected_error.value},
-#         )
-
-#     if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
-#         raise HTTPException(
-#             status_code=400,
-#             detail={"error": CommonProxyErrors.not_allowed_access.value},
-#         )
-
-#     ## get general settings from db
-#     db_general_settings = await prisma_client.db.litellm_config.find_first(
-#         where={"param_name": "general_settings"}
-#     )
-#     ### update value
-
-#     alerting_args_dict = {}
-#     if db_general_settings is None or db_general_settings.param_value is None:
-#         general_settings = {}
-#         alerting_args_dict = {}
-#     else:
-#         general_settings = dict(db_general_settings.param_value)
-#         _alerting_args_dict = general_settings.get("alerting_args", None)
-#         if _alerting_args_dict is not None and isinstance(_alerting_args_dict, dict):
-#             alerting_args_dict = _alerting_args_dict
-
-#     alerting_args_dict = data.model
-
-#     response = await prisma_client.db.litellm_config.upsert(
-#         where={"param_name": "general_settings"},
-#         data={
-#             "create": {"param_name": "general_settings", "param_value": json.dumps(general_settings)},  # type: ignore
-#             "update": {"param_value": json.dumps(general_settings)},  # type: ignore
-#         },
-#     )
-
-#     return response
-
#### EXPERIMENTAL QUEUING ####
-async def _litellm_chat_completions_worker(data, user_api_key_dict):
-    """
-    worker to make litellm completions calls
-    """
-    while True:
-        try:
-            ### CALL HOOKS ### - modify incoming data before calling the model
-            data = await proxy_logging_obj.pre_call_hook(
-                user_api_key_dict=user_api_key_dict, data=data, call_type="completion"
-            )
-
-            verbose_proxy_logger.debug("_litellm_chat_completions_worker started")
-            ### ROUTE THE REQUEST ###
-            router_model_names = (
-                [m["model_name"] for m in llm_model_list]
-                if llm_model_list is not None
-                else []
-            )
-            if (
-                llm_router is not None and data["model"] in router_model_names
-            ):  # model in router model list
-                response = await llm_router.acompletion(**data)
-            elif (
-                llm_router is not None and data["model"] in llm_router.deployment_names
-            ):  # model in router deployments, calling a specific deployment on the router
-                response = await llm_router.acompletion(
-                    **data, specific_deployment=True
-                )
-            elif (
-                llm_router is not None
-                and llm_router.model_group_alias is not None
-                and data["model"] in llm_router.model_group_alias
-            ):  # model set in model_group_alias
-                response = await llm_router.acompletion(**data)
-            else:  # router is not set
-                response = await litellm.acompletion(**data)
-
-            verbose_proxy_logger.debug("final response: {response}")
-            return response
-        except HTTPException as e:
-            verbose_proxy_logger.debug(
-                f"EXCEPTION RAISED IN _litellm_chat_completions_worker - {e.status_code}; {e.detail}"
-            )
-            if (
-                e.status_code == 429
-                and "Max parallel request limit reached" in e.detail
-            ):
-                verbose_proxy_logger.debug("Max parallel request limit reached!")
-                timeout = litellm._calculate_retry_after(
-                    remaining_retries=3, max_retries=3, min_timeout=1
-                )
-                await asyncio.sleep(timeout)
-            else:
-                raise e
-
@router.post(
    "/queue/chat/completions",
    tags=["experimental"],

@@ -12195,6 +12091,7 @@ async def _litellm_chat_completions_worker(data, user_api_key_dict):
)
async def async_queue_request(
    request: Request,
+    fastapi_response: Response,
    model: Optional[str] = None,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
@@ -12260,12 +12157,47 @@ async def async_queue_request(
        if user_api_base:
            data["api_base"] = user_api_base

-        response = await asyncio.wait_for(
-            _litellm_chat_completions_worker(
-                data=data, user_api_key_dict=user_api_key_dict
-            ),
-            timeout=litellm.request_timeout,
-        )
+        ## FLOW ITEM ##
+        request_id = str(uuid.uuid4())
+        flow_item = FlowItem(
+            priority=data.pop("priority", DefaultPriorities.Medium.value),
+            request_id=request_id,
+            model_name=data["model"],
+        )
+        # [TODO] only allow premium users to set non default priorities
+
+        ## ADD REQUEST TO QUEUE
+        response = await simple_scheduler.add_request(request=flow_item)
+
+        if llm_router is None:
+            raise HTTPException(
+                status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
+            )
+        ## POLL QUEUE
+        default_timeout = llm_router.timeout
+        end_time = time.time() + default_timeout
+        poll_interval = 0.03  # poll every 30ms
+        curr_time = time.time()
+
+        make_request = False
+
+        if llm_router is None:
+            raise HTTPException(
+                status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
+            )
+
+        while curr_time < end_time:
+            make_request = await simple_scheduler.poll(
+                id=request_id, model_name=data["model"]
+            )
+            if make_request:  ## IF TRUE -> MAKE REQUEST
+                break
+            else:  ## ELSE -> loop till default_timeout
+                await asyncio.sleep(poll_interval)
+                curr_time = time.time()
+
+        if make_request:
+            response = await llm_router.acompletion(**data)
+
        if (
            "stream" in data and data["stream"] == True
@@ -12279,6 +12211,7 @@ async def async_queue_request(
                media_type="text/event-stream",
            )

+        fastapi_response.headers.update({"x-litellm-priority": str(flow_item.priority)})
        return response
    except Exception as e:
        await proxy_logging_obj.post_call_failure_hook(
@@ -12301,6 +12234,19 @@ async def async_queue_request(
        )


+@router.get(
+    "/queue/info",
+    tags=["experimental"],
+    dependencies=[Depends(user_api_key_auth)],
+)
+async def queue_info(
+    request: Request,
+    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+) -> List:
+    """Help user know the status of an item in the queue"""
+    return simple_scheduler.get_queue_status()
+
+
@router.get(
    "/ollama_logs", dependencies=[Depends(user_api_key_auth)], tags=["experimental"]
)
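For reference, a minimal sketch of calling the experimental `/queue/info` endpoint added above. It assumes a locally running proxy at `http://localhost:4000` and the `sk-1234` key used in this PR's docs, and uses the `requests` package rather than any official client:

```python
import requests

# GET the experimental queue status endpoint (auth handled by user_api_key_auth).
resp = requests.get(
    "http://localhost:4000/queue/info",
    headers={"Authorization": "Bearer sk-1234"},
)
print(resp.json())  # the scheduler's current queue status, returned as a list
```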
@@ -641,7 +641,6 @@ class Router:
            kwargs=kwargs,
            client_type="max_parallel_requests",
        )
-
        if rpm_semaphore is not None and isinstance(
            rpm_semaphore, asyncio.Semaphore
        ):

@@ -1987,6 +1986,7 @@ class Router:
                    error=e,
                    healthy_deployments=_healthy_deployments,
                    context_window_fallbacks=context_window_fallbacks,
+                    regular_fallbacks=fallbacks,
                )

                # decides how long to sleep before retry

@@ -1996,7 +1996,6 @@ class Router:
                    num_retries=num_retries,
                    healthy_deployments=_healthy_deployments,
                )
-
                # sleeps for the length of the timeout
                await asyncio.sleep(_timeout)

@@ -2041,6 +2040,7 @@ class Router:
                        healthy_deployments=_healthy_deployments,
                    )
                    await asyncio.sleep(_timeout)
+
                try:
                    cooldown_deployments = await self._async_get_cooldown_deployments()
                    original_exception.message += f"\nNumber Retries = {current_attempt + 1}, Max Retries={num_retries}\nCooldown Deployments={cooldown_deployments}"

@@ -2053,6 +2053,7 @@ class Router:
        error: Exception,
        healthy_deployments: Optional[List] = None,
        context_window_fallbacks: Optional[List] = None,
+        regular_fallbacks: Optional[List] = None,
    ):
        """
        1. raise an exception for ContextWindowExceededError if context_window_fallbacks is not None
@@ -2069,7 +2070,7 @@ class Router:
        ### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR w/ fallbacks available / Bad Request Error
        if (
            isinstance(error, litellm.ContextWindowExceededError)
-            and context_window_fallbacks is None
+            and context_window_fallbacks is not None
        ):
            raise error

@@ -2077,7 +2078,11 @@ class Router:
        if isinstance(error, openai.RateLimitError) or isinstance(
            error, openai.AuthenticationError
        ):
-            if _num_healthy_deployments <= 0:
+            if (
+                _num_healthy_deployments <= 0
+                and regular_fallbacks is not None
+                and len(regular_fallbacks) > 0
+            ):
                raise error

        return True

@@ -2252,6 +2257,7 @@ class Router:
                    error=e,
                    healthy_deployments=_healthy_deployments,
                    context_window_fallbacks=context_window_fallbacks,
+                    regular_fallbacks=fallbacks,
                )

                # decides how long to sleep before retry

@@ -2460,7 +2466,7 @@ class Router:

        the exception is not one that should be immediately retried (e.g. 401)
        """
-        args = locals()
        if deployment is None:
            return

@@ -2631,7 +2637,17 @@ class Router:
        """
        for _callback in litellm.callbacks:
            if isinstance(_callback, CustomLogger):
-                response = await _callback.async_pre_call_check(deployment)
+                try:
+                    response = await _callback.async_pre_call_check(deployment)
+                except litellm.RateLimitError as e:
+                    self._set_cooldown_deployments(
+                        exception_status=e.status_code,
+                        deployment=deployment["model_info"]["id"],
+                        time_to_cooldown=self.cooldown_time,
+                    )
+                    raise e
+                except Exception as e:
+                    raise e

    def set_client(self, model: dict):
        """

litellm/scheduler.py (new file, 139 lines)

@@ -0,0 +1,139 @@
import heapq, time
from pydantic import BaseModel
from typing import Optional
import enum
from litellm.caching import DualCache
from litellm import Router
from litellm import print_verbose


class SchedulerCacheKeys(enum.Enum):
    queue = "scheduler:queue"


class DefaultPriorities(enum.Enum):
    High = 0
    Medium = 128
    Low = 255


class FlowItem(BaseModel):
    priority: int  # Priority between 0 and 255
    request_id: str
    model_name: str


class Scheduler:
    cache: DualCache
    llm_router: Optional[Router] = None

    def __init__(self):
        self.queue = []
        self.cache = DualCache()

    def update_variables(self, llm_router: Router, cache: Optional[DualCache] = None):
        self.llm_router = llm_router
        if cache is not None:
            self.cache = cache

    async def add_request(self, request: FlowItem):
        # We use the priority directly, as lower values indicate higher priority
        # get the queue
        queue = await self.get_queue(model_name=request.model_name)
        # update the queue
        heapq.heappush(queue, (request.priority, request.request_id))

        # save the queue
        await self.save_queue(queue=queue, model_name=request.model_name)

    async def poll(self, id: str, model_name: str) -> bool:
        """
        Return if request can be processed.

        Returns:
        - True:
            * If healthy deployments are available
            * OR if the request is at the top of the queue
        - False:
            * If no healthy deployments are available
            * AND the request is not at the top of the queue
        """
        queue = await self.get_queue(model_name=model_name)
        if not queue or not self.llm_router:
            raise Exception(
                "Incorrectly setup. Queue or Router is invalid. Queue={}, Router={}".format(
                    queue, self.llm_router
                )
            )

        # ------------
        # Setup values
        # ------------
        _healthy_deployments = await self.llm_router._async_get_healthy_deployments(
            model=model_name
        )

        print_verbose(f"len(_healthy_deployments): {len(_healthy_deployments)}")
        if len(_healthy_deployments) == 0:
            print_verbose(f"queue: {queue}, seeking id={id}")
            # Check if the id is at the top of the heap
            if queue[0][1] == id:
                # Remove the item from the queue
                heapq.heappop(queue)
                print_verbose(f"Popped id: {id}")
                return True
            else:
                return False

        return True

    async def peek(self, id: str, model_name: str) -> bool:
        """Return if the id is at the top of the queue. Don't pop the value from heap."""
        queue = await self.get_queue(model_name=model_name)
        if not queue or not self.llm_router:
            raise Exception(
                "Incorrectly setup. Queue or Router is invalid. Queue={}, Router={}".format(
                    queue, self.llm_router
                )
            )

        # ------------
        # Setup values
        # ------------
        _healthy_deployments = await self.llm_router._async_get_healthy_deployments(
            model=model_name
        )
        if len(_healthy_deployments) == 0:
            return False

        # Check if the id is at the top of the heap
        if queue[0][1] == id:
            return True

        return False

    def get_queue_status(self):
        """Get the status of items in the queue"""
        return self.queue

    async def get_queue(self, model_name: str) -> list:
        """
        Return a queue for that specific model group
        """
        if self.cache is not None:
            _cache_key = "{}:{}".format(SchedulerCacheKeys.queue.value, model_name)
            response = await self.cache.async_get_cache(key=_cache_key)
            if response is None or not isinstance(response, list):
                return []
            elif isinstance(response, list):
                return response
        return self.queue

    async def save_queue(self, queue: list, model_name: str) -> None:
        """
        Save the updated queue of the model group
        """
        if self.cache is not None:
            _cache_key = "{}:{}".format(SchedulerCacheKeys.queue.value, model_name)
            await self.cache.async_set_cache(key=_cache_key, value=queue)
        return None
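One note on how this class is consumed by the proxy changes earlier in this diff: when a caller omits the "priority" field, the endpoint falls back to `DefaultPriorities.Medium.value`. A small standalone sketch of that default (the request body here is made up):

```python
from litellm.scheduler import DefaultPriorities

data = {"model": "gpt-3.5-turbo"}  # hypothetical request body without a "priority" key

# Mirrors the proxy's `data.pop("priority", DefaultPriorities.Medium.value)`.
priority = data.pop("priority", DefaultPriorities.Medium.value)
print(priority)  # 128, i.e. between High (0) and Low (255)
```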

@@ -134,7 +134,7 @@ async def test_router_retry_policy(error_type):
        ContentPolicyViolationErrorRetries=3, AuthenticationErrorRetries=0
    )

-    router = litellm.Router(
+    router = Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",  # openai model name

@@ -334,13 +334,13 @@ def test_retry_rate_limit_error_with_healthy_deployments():
    )


-def test_do_not_retry_rate_limit_error_with_no_fallbacks_and_no_healthy_deployments():
+def test_do_retry_rate_limit_error_with_no_fallbacks_and_no_healthy_deployments():
    """
-    Test 2. It SHOULD NOT Retry, when healthy_deployments is [] and fallbacks is None
+    Test 2. It SHOULD Retry, when healthy_deployments is [] and fallbacks is None
    """
    healthy_deployments = []

-    router = litellm.Router(
+    router = Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",

@@ -359,14 +359,14 @@ def test_do_not_retry_rate_limit_error_with_no_fallbacks_and_no_healthy_deployme
        response = router.should_retry_this_error(
            error=rate_limit_error, healthy_deployments=healthy_deployments
        )
-        assert response != True, "Should have raised RateLimitError"
-    except openai.RateLimitError:
-        pass
+        assert response == True
+    except Exception as e:
+        pytest.fail("Should not have failed this error - {}".format(str(e)))


def test_raise_context_window_exceeded_error():
    """
-    Retry Context Window Exceeded Error, when context_window_fallbacks is not None
+    Trigger Context Window fallback, when context_window_fallbacks is not None
    """
    context_window_error = litellm.ContextWindowExceededError(
        message="Context window exceeded",
@@ -379,7 +379,7 @@ def test_raise_context_window_exceeded_error():
    )
    context_window_fallbacks = [{"gpt-3.5-turbo": ["azure/chatgpt-v-2"]}]

-    router = litellm.Router(
+    router = Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",

@@ -393,14 +393,17 @@ def test_raise_context_window_exceeded_error():
        ]
    )

-    response = router.should_retry_this_error(
-        error=context_window_error,
-        healthy_deployments=None,
-        context_window_fallbacks=context_window_fallbacks,
-    )
-    assert (
-        response == True
-    ), "Should not have raised exception since we have context window fallbacks"
+    try:
+        response = router.should_retry_this_error(
+            error=context_window_error,
+            healthy_deployments=None,
+            context_window_fallbacks=context_window_fallbacks,
+        )
+        pytest.fail(
+            "Expected to raise context window exceeded error -> trigger fallback"
+        )
+    except Exception as e:
+        pass


def test_raise_context_window_exceeded_error_no_retry():

@@ -418,7 +421,7 @@ def test_raise_context_window_exceeded_error_no_retry():
    )
    context_window_fallbacks = None

-    router = litellm.Router(
+    router = Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",

@@ -439,8 +442,8 @@ def test_raise_context_window_exceeded_error_no_retry():
            context_window_fallbacks=context_window_fallbacks,
        )
        assert (
-            response != True
-        ), "Should have raised exception since we do not have context window fallbacks"
+            response == True
+        ), "Should not have raised exception since we do not have context window fallbacks"
    except litellm.ContextWindowExceededError:
        pass

litellm/tests/test_scheduler.py (new file, 179 lines)

@@ -0,0 +1,179 @@
# What is this?
## Unit tests for the Scheduler.py (workload prioritization scheduler)

import sys, os, time, openai, uuid
import traceback, asyncio
import pytest

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
from litellm import Router
from litellm.scheduler import FlowItem, Scheduler


@pytest.mark.asyncio
async def test_scheduler_diff_model_names():
    """
    Assert 2 requests to 2 diff model groups are top of their respective queues
    """
    scheduler = Scheduler()

    router = Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {
                    "model": "gpt-3.5-turbo",
                },
            },
            {"model_name": "gpt-4", "litellm_params": {"model": "gpt-4"}},
        ]
    )

    scheduler.update_variables(llm_router=router)

    item1 = FlowItem(priority=0, request_id="10", model_name="gpt-3.5-turbo")
    item2 = FlowItem(priority=0, request_id="11", model_name="gpt-4")
    await scheduler.add_request(item1)
    await scheduler.add_request(item2)

    assert await scheduler.poll(id="10", model_name="gpt-3.5-turbo") == True
    assert await scheduler.poll(id="11", model_name="gpt-4") == True


@pytest.mark.parametrize("p0, p1", [(0, 0), (0, 1), (1, 0)])
@pytest.mark.asyncio
async def test_scheduler_prioritized_requests(p0, p1):
    """
    2 requests for same model group
    """
    scheduler = Scheduler()

    router = Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {
                    "model": "gpt-3.5-turbo",
                },
            },
            {"model_name": "gpt-4", "litellm_params": {"model": "gpt-4"}},
        ]
    )

    scheduler.update_variables(llm_router=router)

    item1 = FlowItem(priority=p0, request_id="10", model_name="gpt-3.5-turbo")
    item2 = FlowItem(priority=p1, request_id="11", model_name="gpt-3.5-turbo")
    await scheduler.add_request(item1)
    await scheduler.add_request(item2)

    if p0 == 0:
        assert await scheduler.peek(id="10", model_name="gpt-3.5-turbo") == True
        assert await scheduler.peek(id="11", model_name="gpt-3.5-turbo") == False
    else:
        assert await scheduler.peek(id="11", model_name="gpt-3.5-turbo") == True
        assert await scheduler.peek(id="10", model_name="gpt-3.5-turbo") == False


@pytest.mark.parametrize("p0, p1", [(0, 1)])  # (0, 0), (1, 0)
@pytest.mark.asyncio
async def test_scheduler_prioritized_requests_mock_response(p0, p1):
    """
    2 requests for same model group

    if model is at rate limit, ensure the higher priority request gets done first
    """
    scheduler = Scheduler()

    router = Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {
                    "model": "gpt-3.5-turbo",
                    "mock_response": "Hello world this is Macintosh!",
                    "rpm": 1,
                },
            },
        ],
        timeout=10,
        num_retries=3,
        cooldown_time=5,
        routing_strategy="usage-based-routing-v2",
    )

    scheduler.update_variables(llm_router=router)

    await router.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hey!"}],
    )

    async def _make_prioritized_call(flow_item: FlowItem):
        ## POLL QUEUE
        default_timeout = router.timeout
        end_time = time.time() + default_timeout
        poll_interval = 0.03  # poll every 30ms
        curr_time = time.time()

        make_request = False

        if router is None:
            raise Exception("No llm router value")

        while curr_time < end_time:
            make_request = await scheduler.poll(
                id=flow_item.request_id, model_name=flow_item.model_name
            )
            print(f"make_request={make_request}, priority={flow_item.priority}")
            if make_request:  ## IF TRUE -> MAKE REQUEST
                break
            else:  ## ELSE -> loop till default_timeout
                await asyncio.sleep(poll_interval)
                curr_time = time.time()

        if make_request:
            try:
                _response = await router.acompletion(
                    model=flow_item.model_name,
                    messages=[{"role": "user", "content": "Hey!"}],
                )
            except Exception as e:
                print("Received error - {}".format(str(e)))
                return flow_item.priority, flow_item.request_id, time.time()

            return flow_item.priority, flow_item.request_id, time.time()

        raise Exception("didn't make request")

    tasks = []

    item = FlowItem(
        priority=p0, request_id=str(uuid.uuid4()), model_name="gpt-3.5-turbo"
    )
    await scheduler.add_request(request=item)
    tasks.append(_make_prioritized_call(flow_item=item))

    item = FlowItem(
        priority=p1, request_id=str(uuid.uuid4()), model_name="gpt-3.5-turbo"
    )
    await scheduler.add_request(request=item)
    tasks.append(_make_prioritized_call(flow_item=item))

    # Running the tasks and getting responses in order of completion
    completed_responses = []
    for task in asyncio.as_completed(tasks):
        result = await task
        completed_responses.append(result)
        print(f"Received response: {result}")

    print(f"responses: {completed_responses}")

    assert (
        completed_responses[0][0] == 0
    )  # assert higher priority request got done first
    assert (
        completed_responses[0][2] < completed_responses[1][2]
    )  # higher priority request tried first

@@ -314,6 +314,8 @@ class LiteLLMParamsTypedDict(TypedDict, total=False):
    output_cost_per_token: Optional[float]
    input_cost_per_second: Optional[float]
    output_cost_per_second: Optional[float]
+    ## MOCK RESPONSES ##
+    mock_response: Optional[str]


class DeploymentTypedDict(TypedDict):
@@ -103,6 +103,36 @@ async def chat_completion(session, key, model: Union[str, List] = "gpt-4"):
    return await response.json()


+async def queue_chat_completion(
+    session, key, priority: int, model: Union[str, List] = "gpt-4"
+):
+    url = "http://0.0.0.0:4000/queue/chat/completions"
+    headers = {
+        "Authorization": f"Bearer {key}",
+        "Content-Type": "application/json",
+    }
+    data = {
+        "model": model,
+        "messages": [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": "Hello!"},
+        ],
+        "priority": priority,
+    }
+
+    async with session.post(url, headers=headers, json=data) as response:
+        status = response.status
+        response_text = await response.text()
+
+        print(response_text)
+        print()
+
+        if status != 200:
+            raise Exception(f"Request did not return a 200 status code: {status}")
+
+        return response.raw_headers
+
+
async def chat_completion_with_headers(session, key, model="gpt-4"):
    url = "http://0.0.0.0:4000/chat/completions"
    headers = {