Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 03:04:13 +00:00)
feat(router.py): support request prioritization for text completion calls (#7540)

* feat(router.py): support request prioritization for text completion calls
* fix(internal_user_endpoints.py): fix SQL query to return all keys, including null team id keys, on `/user/info`. Fixes https://github.com/BerriAI/litellm/issues/7485
* fix: fix linting errors
* fix: fix linting error
* test(test_router_helper_utils.py): add direct test for '_schedule_factory'. Fixes code QA test
parent f770dd0c95
commit d43d83f9ef
7 changed files with 229 additions and 3 deletions
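
Before the per-file hunks, here is a minimal, hedged usage sketch of the new behaviour (not taken from the PR itself): passing `priority` to `Router.atext_completion` now routes the call through the scheduler path added in `_schedule_factory`. The deployment name matches the config entry added in this commit; the API-key reference is a placeholder.

```python
# Hedged sketch of the new capability (illustrative, not part of the diff below):
# a `priority` kwarg on atext_completion is popped off and handed to the
# scheduler, so lower-priority requests wait in the queue under load.
import asyncio

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "openai-text-completion",  # matches the config entry added below
            "litellm_params": {
                "model": "openai/gpt-3.5-turbo",
                "api_key": "os.environ/OPENAI_API_KEY",
            },
        }
    ],
    timeout=10,  # requests still queued after this raise litellm.Timeout
)


async def main():
    # priority=0 is served before priority=2000 when deployments are saturated
    response = await router.atext_completion(
        model="openai-text-completion",
        prompt="Hello, how are you?",
        priority=0,
    )
    print(response)


asyncio.run(main())
```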
@@ -19,6 +19,11 @@ Prioritize LLM API requests in high-traffic.
 - Priority - The lower the number, the higher the priority:
   * e.g. `priority=0` > `priority=2000`
 
+Supported Router endpoints:
+- `acompletion` (`/v1/chat/completions` on Proxy)
+- `atext_completion` (`/v1/completions` on Proxy)
+
+
 ## Quick Start
 
 ```python
@@ -1928,6 +1928,10 @@ class OpenAIAssistantsAPI(BaseLLM):
         max_retries: Optional[int],
         organization: Optional[str],
         client: Optional[AsyncOpenAI],
+        order: Optional[str] = "desc",
+        limit: Optional[int] = 20,
+        before: Optional[str] = None,
+        after: Optional[str] = None,
     ) -> AsyncCursorPage[Assistant]:
         openai_client = self.async_get_openai_client(
             api_key=api_key,

@@ -1937,8 +1941,16 @@ class OpenAIAssistantsAPI(BaseLLM):
             organization=organization,
             client=client,
         )
+        request_params = {
+            "order": order,
+            "limit": limit,
+        }
+        if before:
+            request_params["before"] = before
+        if after:
+            request_params["after"] = after
 
-        response = await openai_client.beta.assistants.list()
+        response = await openai_client.beta.assistants.list(**request_params)  # type: ignore
 
         return response

@@ -1981,6 +1993,10 @@ class OpenAIAssistantsAPI(BaseLLM):
         organization: Optional[str],
         client=None,
         aget_assistants=None,
+        order: Optional[str] = "desc",
+        limit: Optional[int] = 20,
+        before: Optional[str] = None,
+        after: Optional[str] = None,
     ):
         if aget_assistants is not None and aget_assistants is True:
             return self.async_get_assistants(

@@ -2000,7 +2016,17 @@ class OpenAIAssistantsAPI(BaseLLM):
             client=client,
         )
 
-        response = openai_client.beta.assistants.list()
+        request_params = {
+            "order": order,
+            "limit": limit,
+        }
+
+        if before:
+            request_params["before"] = before
+        if after:
+            request_params["after"] = after
+
+        response = openai_client.beta.assistants.list(**request_params)  # type: ignore
 
         return response

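For context on what the new `request_params` expand to, here is a rough, self-contained sketch of the equivalent call made directly against the OpenAI SDK (not part of the commit; assumes `OPENAI_API_KEY` is set in the environment and the cursor value is a placeholder):

```python
# Rough equivalent of the new pagination handling, shown against the OpenAI SDK
# directly: order/limit are always sent, cursors only when actually provided.
import asyncio
from typing import Optional

from openai import AsyncOpenAI


async def list_assistants_page(after: Optional[str] = None):
    client = AsyncOpenAI()
    request_params = {"order": "desc", "limit": 20}
    if after:
        # only forward the cursor when the caller supplied one
        request_params["after"] = after
    return await client.beta.assistants.list(**request_params)


print(asyncio.run(list_assistants_page()))
```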
@@ -3,6 +3,10 @@ model_list:
     litellm_params:
       model: openai/gpt-3.5-turbo
       api_key: os.environ/OPENAI_API_KEY
+  - model_name: openai-text-completion
+    litellm_params:
+      model: openai/gpt-3.5-turbo
+      api_key: os.environ/OPENAI_API_KEY
   - model_name: chatbot_actions
     litellm_params:
       model: langfuse/azure/gpt-4o

@@ -11,5 +15,6 @@ model_list:
       tpm: 1000000
       prompt_id: "jokes"
 
+
 litellm_settings:
   callbacks: ["otel"]
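The new `openai-text-completion` entry above is the deployment a prioritized text-completion request through the proxy would target. A hedged sketch (not from the PR), assuming a proxy running locally on port 4000 with a placeholder key, and assuming the proxy forwards a `priority` field from the request body to the router, following the pattern the prioritization docs use for chat completions:

```python
# Hedged sketch: sending a prioritized /v1/completions request to a locally
# running LiteLLM proxy. The base_url, api_key, and the assumption that the
# proxy forwards `priority` to the router are illustrative, not from the PR.
from openai import OpenAI

client = OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

response = client.completions.create(
    model="openai-text-completion",
    prompt="Hello, how are you?",
    extra_body={"priority": 0},  # lower number = served first under load
)
print(response)
```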
@@ -404,7 +404,7 @@ async def _get_user_info_for_proxy_admin():
     sql_query = """
        SELECT
            (SELECT json_agg(t.*) FROM "LiteLLM_TeamTable" t) as teams,
-           (SELECT json_agg(k.*) FROM "LiteLLM_VerificationToken" k WHERE k.team_id != 'litellm-dashboard') as keys
+           (SELECT json_agg(k.*) FROM "LiteLLM_VerificationToken" k WHERE k.team_id != 'litellm-dashboard' OR k.team_id IS NULL) as keys
     """
     if prisma_client is None:
         raise Exception(

@@ -413,6 +413,8 @@ async def _get_user_info_for_proxy_admin():
 
     results = await prisma_client.db.query_raw(sql_query)
 
+    verbose_proxy_logger.debug("results_keys: %s", results)
+
     _keys_in_db: List = results[0]["keys"] or []
     # cast all keys to LiteLLM_VerificationToken
     keys_in_db = []
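The `OR k.team_id IS NULL` clause is the actual bug fix: under SQL's three-valued logic, `NULL != 'litellm-dashboard'` evaluates to NULL rather than TRUE, so keys with no team were silently dropped by the original filter. A small self-contained demonstration of that semantics (not part of the commit), using `sqlite3` from the standard library:

```python
# Demonstrates why `team_id != 'litellm-dashboard'` alone drops NULL-team keys.
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE keys (token TEXT, team_id TEXT)")
conn.executemany(
    "INSERT INTO keys VALUES (?, ?)",
    [("k1", None), ("k2", "team1"), ("k3", "litellm-dashboard")],
)

old_filter = conn.execute(
    "SELECT token FROM keys WHERE team_id != 'litellm-dashboard'"
).fetchall()
new_filter = conn.execute(
    "SELECT token FROM keys WHERE team_id != 'litellm-dashboard' OR team_id IS NULL"
).fetchall()

print(old_filter)  # only k2 -- the NULL-team key is missing
print(new_filter)  # k1 and k2 -- NULL-team keys are now returned
```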
@@ -1356,6 +1356,67 @@ class Router:
                 llm_provider="openai",
             )
 
+    async def _schedule_factory(
+        self,
+        model: str,
+        priority: int,
+        original_function: Callable,
+        args: Tuple[Any, ...],
+        kwargs: Dict[str, Any],
+    ):
+        parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
+        ### FLOW ITEM ###
+        _request_id = str(uuid.uuid4())
+        item = FlowItem(
+            priority=priority,  # 👈 SET PRIORITY FOR REQUEST
+            request_id=_request_id,  # 👈 SET REQUEST ID
+            model_name=model,  # 👈 SAME as 'Router'
+        )
+        ### [fin] ###
+
+        ## ADDS REQUEST TO QUEUE ##
+        await self.scheduler.add_request(request=item)
+
+        ## POLL QUEUE
+        end_time = time.time() + self.timeout
+        curr_time = time.time()
+        poll_interval = self.scheduler.polling_interval  # poll every 3ms
+        make_request = False
+
+        while curr_time < end_time:
+            _healthy_deployments, _ = await self._async_get_healthy_deployments(
+                model=model, parent_otel_span=parent_otel_span
+            )
+            make_request = await self.scheduler.poll(  ## POLL QUEUE ## - returns 'True' if there's healthy deployments OR if request is at top of queue
+                id=item.request_id,
+                model_name=item.model_name,
+                health_deployments=_healthy_deployments,
+            )
+            if make_request:  ## IF TRUE -> MAKE REQUEST
+                break
+            else:  ## ELSE -> loop till default_timeout
+                await asyncio.sleep(poll_interval)
+                curr_time = time.time()
+
+        if make_request:
+            try:
+                _response = await original_function(*args, **kwargs)
+                if isinstance(_response._hidden_params, dict):
+                    _response._hidden_params.setdefault("additional_headers", {})
+                    _response._hidden_params["additional_headers"].update(
+                        {"x-litellm-request-prioritization-used": True}
+                    )
+                return _response
+            except Exception as e:
+                setattr(e, "priority", priority)
+                raise e
+        else:
+            raise litellm.Timeout(
+                message="Request timed out while polling queue",
+                model=model,
+                llm_provider="openai",
+            )
+
     def image_generation(self, prompt: str, model: str, **kwargs):
         try:
             kwargs["model"] = model

@@ -1844,10 +1905,19 @@ class Router:
         is_async: Optional[bool] = False,
         **kwargs,
     ):
+        if kwargs.get("priority", None) is not None:
+            return await self._schedule_factory(
+                model=model,
+                priority=kwargs.pop("priority"),
+                original_function=self.atext_completion,
+                args=(model, prompt),
+                kwargs=kwargs,
+            )
         try:
             kwargs["model"] = model
             kwargs["prompt"] = prompt
             kwargs["original_function"] = self._atext_completion
 
             self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
             response = await self.async_function_with_fallbacks(**kwargs)

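`_schedule_factory` stamps successful responses with an `x-litellm-request-prioritization-used` marker inside `_hidden_params["additional_headers"]`. A small helper sketch (not part of the commit) for reading that marker back off a response object:

```python
# Illustrative helper (not from the PR): check whether a router response went
# through the prioritization path, based on the header set in _schedule_factory.
from typing import Any


def used_prioritization(response: Any) -> bool:
    hidden = getattr(response, "_hidden_params", None)
    if not isinstance(hidden, dict):
        return False
    headers = hidden.get("additional_headers", {})
    return bool(headers.get("x-litellm-request-prioritization-used", False))
```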
@@ -1252,3 +1252,78 @@ def test_get_model_group_info():
         model_group="openai/tts-1",
     )
     assert len(model_list) == 1
+
+
+import pytest
+import asyncio
+from unittest.mock import AsyncMock, patch
+import json
+
+
+@pytest.fixture
+def mock_team_data():
+    return [
+        {"team_id": "team1", "team_name": "Test Team 1"},
+        {"team_id": "team2", "team_name": "Test Team 2"},
+    ]
+
+
+@pytest.fixture
+def mock_key_data():
+    return [
+        {"token": "test_token_1", "key_name": "key1", "team_id": None, "spend": 0},
+        {"token": "test_token_2", "key_name": "key2", "team_id": "team1", "spend": 100},
+        {
+            "token": "test_token_3",
+            "key_name": "key3",
+            "team_id": "litellm-dashboard",
+            "spend": 50,
+        },
+    ]
+
+
+class MockDb:
+    def __init__(self, mock_team_data, mock_key_data):
+        self.mock_team_data = mock_team_data
+        self.mock_key_data = mock_key_data
+
+    async def query_raw(self, query: str, *args):
+        # Simulate the SQL query response
+        filtered_keys = [
+            k
+            for k in self.mock_key_data
+            if k["team_id"] != "litellm-dashboard" or k["team_id"] is None
+        ]
+
+        return [{"teams": self.mock_team_data, "keys": filtered_keys}]
+
+
+class MockPrismaClientDB:
+    def __init__(
+        self,
+        mock_team_data,
+        mock_key_data,
+    ):
+        self.db = MockDb(mock_team_data, mock_key_data)
+
+
+@pytest.mark.asyncio
+async def test_get_user_info_for_proxy_admin(mock_team_data, mock_key_data):
+    # Patch the prisma_client import
+    from litellm.proxy._types import UserInfoResponse
+
+    with patch(
+        "litellm.proxy.proxy_server.prisma_client",
+        MockPrismaClientDB(mock_team_data, mock_key_data),
+    ):
+        from litellm.proxy.management_endpoints.internal_user_endpoints import (
+            _get_user_info_for_proxy_admin,
+        )
+
+        # Execute the function
+        result = await _get_user_info_for_proxy_admin()
+
+        # Verify the result structure
+        assert isinstance(result, UserInfoResponse)
+        assert len(result.keys) == 2

@@ -166,6 +166,49 @@ async def test_router_schedule_acompletion(model_list):
     assert response["choices"][0]["message"]["content"] == "I'm fine, thank you!"
 
 
+@pytest.mark.asyncio
+async def test_router_schedule_atext_completion(model_list):
+    """Test if the 'schedule_atext_completion' function is working correctly"""
+    from litellm.types.utils import TextCompletionResponse
+
+    router = Router(model_list=model_list)
+    with patch.object(
+        router, "_atext_completion", AsyncMock()
+    ) as mock_atext_completion:
+        mock_atext_completion.return_value = TextCompletionResponse()
+        response = await router.atext_completion(
+            model="gpt-3.5-turbo",
+            prompt="Hello, how are you?",
+            priority=1,
+        )
+        mock_atext_completion.assert_awaited_once()
+        assert "priority" not in mock_atext_completion.call_args.kwargs
+
+
+@pytest.mark.asyncio
+async def test_router_schedule_factory(model_list):
+    """Test if the '_schedule_factory' function is working correctly"""
+    from litellm.types.utils import TextCompletionResponse
+
+    router = Router(model_list=model_list)
+    with patch.object(
+        router, "_atext_completion", AsyncMock()
+    ) as mock_atext_completion:
+        mock_atext_completion.return_value = TextCompletionResponse()
+        response = await router._schedule_factory(
+            model="gpt-3.5-turbo",
+            args=(
+                "gpt-3.5-turbo",
+                "Hello, how are you?",
+            ),
+            priority=1,
+            kwargs={},
+            original_function=router.atext_completion,
+        )
+        mock_atext_completion.assert_awaited_once()
+        assert "priority" not in mock_atext_completion.call_args.kwargs
+
+
 @pytest.mark.asyncio
 async def test_router_arealtime(model_list):
     """Test if the '_arealtime' function is working correctly"""