From db82b3bb2a3b40011a027f2ce2d6be071f16a366 Mon Sep 17 00:00:00 2001
From: Krish Dholakia
Date: Fri, 3 Jan 2025 19:35:44 -0800
Subject: [PATCH] =?UTF-8?q?feat(router.py):=20support=20request=20prioriti?=
 =?UTF-8?q?zation=20for=20text=20completion=20c=E2=80=A6=20(#7540)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* feat(router.py): support request prioritization for text completion calls

* fix(internal_user_endpoints.py): fix sql query to return all keys, including null team id keys on `/user/info`

Fixes https://github.com/BerriAI/litellm/issues/7485

* fix: fix linting errors

* fix: fix linting error

* test(test_router_helper_utils.py): add direct test for '_schedule_factory'

Fixes code qa test
---
 docs/my-website/docs/scheduler.md          |  5 ++
 litellm/llms/openai/openai.py              | 30 +++++++-
 litellm/proxy/_new_secret_config.yaml      |  5 ++
 .../internal_user_endpoints.py             |  4 +-
 litellm/router.py                          | 70 +++++++++++++++++
 tests/proxy_unit_tests/test_proxy_utils.py | 75 +++++++++++++++++++
 .../test_router_helper_utils.py            | 43 +++++++++++
 7 files changed, 229 insertions(+), 3 deletions(-)

diff --git a/docs/my-website/docs/scheduler.md b/docs/my-website/docs/scheduler.md
index e59b03eacb..2b0a582626 100644
--- a/docs/my-website/docs/scheduler.md
+++ b/docs/my-website/docs/scheduler.md
@@ -19,6 +19,11 @@ Prioritize LLM API requests in high-traffic.
 - Priority - The lower the number, the higher the priority:
   * e.g. `priority=0` > `priority=2000`
 
+Supported Router endpoints:
+- `acompletion` (`/v1/chat/completions` on Proxy)
+- `atext_completion` (`/v1/completions` on Proxy)
+
+
 ## Quick Start
 
 ```python
diff --git a/litellm/llms/openai/openai.py b/litellm/llms/openai/openai.py
index a7ab3a72e0..ff40fa1853 100644
--- a/litellm/llms/openai/openai.py
+++ b/litellm/llms/openai/openai.py
@@ -1928,6 +1928,10 @@ class OpenAIAssistantsAPI(BaseLLM):
         max_retries: Optional[int],
         organization: Optional[str],
         client: Optional[AsyncOpenAI],
+        order: Optional[str] = "desc",
+        limit: Optional[int] = 20,
+        before: Optional[str] = None,
+        after: Optional[str] = None,
     ) -> AsyncCursorPage[Assistant]:
         openai_client = self.async_get_openai_client(
             api_key=api_key,
@@ -1937,8 +1941,16 @@
             organization=organization,
             client=client,
         )
+        request_params = {
+            "order": order,
+            "limit": limit,
+        }
+        if before:
+            request_params["before"] = before
+        if after:
+            request_params["after"] = after
 
-        response = await openai_client.beta.assistants.list()
+        response = await openai_client.beta.assistants.list(**request_params)  # type: ignore
 
         return response
 
@@ -1981,6 +1993,10 @@
         organization: Optional[str],
         client=None,
         aget_assistants=None,
+        order: Optional[str] = "desc",
+        limit: Optional[int] = 20,
+        before: Optional[str] = None,
+        after: Optional[str] = None,
     ):
         if aget_assistants is not None and aget_assistants is True:
             return self.async_get_assistants(
@@ -2000,7 +2016,17 @@
             client=client,
         )
 
-        response = openai_client.beta.assistants.list()
+        request_params = {
+            "order": order,
+            "limit": limit,
+        }
+
+        if before:
+            request_params["before"] = before
+        if after:
+            request_params["after"] = after
+
+        response = openai_client.beta.assistants.list(**request_params)  # type: ignore
 
         return response
 
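The `order`/`limit`/`before`/`after` parameters threaded through `async_get_assistants` and `get_assistants` above map onto the cursor pagination of the OpenAI Assistants list endpoint. The sketch below shows the underlying SDK call the helpers now forward; the client construction and page sizes are illustrative, not part of the patch.

```python
# Sketch of the cursor pagination the patched helpers forward to the OpenAI SDK.
# Requires OPENAI_API_KEY in the environment; page-size values are illustrative.
import asyncio

from openai import AsyncOpenAI


async def list_assistants():
    client = AsyncOpenAI()
    first_page = await client.beta.assistants.list(order="desc", limit=20)
    if first_page.data:
        # `before` / `after` are only sent when set, mirroring the
        # request_params construction in the hunk above.
        next_page = await client.beta.assistants.list(
            order="desc", limit=20, after=first_page.data[-1].id
        )
        return next_page
    return first_page


asyncio.run(list_assistants())
```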
diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml
index 61dc6774c0..7f135f8b79 100644
--- a/litellm/proxy/_new_secret_config.yaml
+++ b/litellm/proxy/_new_secret_config.yaml
@@ -3,6 +3,10 @@ model_list:
     litellm_params:
       model: openai/gpt-3.5-turbo
       api_key: os.environ/OPENAI_API_KEY
+  - model_name: openai-text-completion
+    litellm_params:
+      model: openai/gpt-3.5-turbo
+      api_key: os.environ/OPENAI_API_KEY
   - model_name: chatbot_actions
     litellm_params:
       model: langfuse/azure/gpt-4o
@@ -11,5 +15,6 @@ model_list:
       tpm: 1000000
       prompt_id: "jokes"
 
+
 litellm_settings:
   callbacks: ["otel"]
\ No newline at end of file
diff --git a/litellm/proxy/management_endpoints/internal_user_endpoints.py b/litellm/proxy/management_endpoints/internal_user_endpoints.py
index 5ca4a6251f..e4a1740ba5 100644
--- a/litellm/proxy/management_endpoints/internal_user_endpoints.py
+++ b/litellm/proxy/management_endpoints/internal_user_endpoints.py
@@ -404,7 +404,7 @@ async def _get_user_info_for_proxy_admin():
     sql_query = """
         SELECT
             (SELECT json_agg(t.*) FROM "LiteLLM_TeamTable" t) as teams,
-            (SELECT json_agg(k.*) FROM "LiteLLM_VerificationToken" k WHERE k.team_id != 'litellm-dashboard') as keys
+            (SELECT json_agg(k.*) FROM "LiteLLM_VerificationToken" k WHERE k.team_id != 'litellm-dashboard' OR k.team_id IS NULL) as keys
     """
     if prisma_client is None:
         raise Exception(
@@ -413,6 +413,8 @@
 
     results = await prisma_client.db.query_raw(sql_query)
 
+    verbose_proxy_logger.debug("results_keys: %s", results)
+
     _keys_in_db: List = results[0]["keys"] or []
     # cast all keys to LiteLLM_VerificationToken
     keys_in_db = []
diff --git a/litellm/router.py b/litellm/router.py
index ad36ebb13d..d837b203ad 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -1356,6 +1356,67 @@ class Router:
                 llm_provider="openai",
             )
 
+    async def _schedule_factory(
+        self,
+        model: str,
+        priority: int,
+        original_function: Callable,
+        args: Tuple[Any, ...],
+        kwargs: Dict[str, Any],
+    ):
+        parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
+        ### FLOW ITEM ###
+        _request_id = str(uuid.uuid4())
+        item = FlowItem(
+            priority=priority,  # 👈 SET PRIORITY FOR REQUEST
+            request_id=_request_id,  # 👈 SET REQUEST ID
+            model_name=model,  # 👈 SAME as 'Router'
+        )
+        ### [fin] ###
+
+        ## ADDS REQUEST TO QUEUE ##
+        await self.scheduler.add_request(request=item)
+
+        ## POLL QUEUE
+        end_time = time.time() + self.timeout
+        curr_time = time.time()
+        poll_interval = self.scheduler.polling_interval  # poll every 3ms
+        make_request = False
+
+        while curr_time < end_time:
+            _healthy_deployments, _ = await self._async_get_healthy_deployments(
+                model=model, parent_otel_span=parent_otel_span
+            )
+            make_request = await self.scheduler.poll(  ## POLL QUEUE ## - returns 'True' if there's healthy deployments OR if request is at top of queue
+                id=item.request_id,
+                model_name=item.model_name,
+                health_deployments=_healthy_deployments,
+            )
+            if make_request:  ## IF TRUE -> MAKE REQUEST
+                break
+            else:  ## ELSE -> loop till default_timeout
+                await asyncio.sleep(poll_interval)
+                curr_time = time.time()
+
+        if make_request:
+            try:
+                _response = await original_function(*args, **kwargs)
+                if isinstance(_response._hidden_params, dict):
+                    _response._hidden_params.setdefault("additional_headers", {})
+                    _response._hidden_params["additional_headers"].update(
+                        {"x-litellm-request-prioritization-used": True}
+                    )
+                return _response
+            except Exception as e:
+                setattr(e, "priority", priority)
+                raise e
+        else:
+            raise litellm.Timeout(
+                message="Request timed out while polling queue",
+                model=model,
+                llm_provider="openai",
+            )
+
     def image_generation(self, prompt: str, model: str, **kwargs):
         try:
             kwargs["model"] = model
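To make the new scheduling path concrete, here is a minimal usage sketch of prioritized text completion through the Router, in the spirit of the scheduler docs' Quick Start. The model alias, API key source, and timeout are illustrative; the `priority` kwarg and its "lower number = higher priority" semantics come from this patch and the docs change above.

```python
# Minimal sketch: prioritized text completion via the Router.
# Model alias, API key env var, and timeout are illustrative assumptions.
import asyncio
import os

from litellm import Router


async def main():
    router = Router(
        model_list=[
            {
                "model_name": "openai-text-completion",
                "litellm_params": {
                    "model": "openai/gpt-3.5-turbo",
                    "api_key": os.getenv("OPENAI_API_KEY"),
                },
            }
        ],
        timeout=30,  # _schedule_factory polls the scheduler queue until this deadline
    )
    response = await router.atext_completion(
        model="openai-text-completion",
        prompt="Hello, how are you?",
        priority=0,  # lower number = higher priority; popped from kwargs before dispatch
    )
    print(response)


asyncio.run(main())
```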
@@ -1844,10 +1905,19 @@
         is_async: Optional[bool] = False,
         **kwargs,
     ):
+        if kwargs.get("priority", None) is not None:
+            return await self._schedule_factory(
+                model=model,
+                priority=kwargs.pop("priority"),
+                original_function=self.atext_completion,
+                args=(model, prompt),
+                kwargs=kwargs,
+            )
        try:
            kwargs["model"] = model
            kwargs["prompt"] = prompt
            kwargs["original_function"] = self._atext_completion
+
            self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
 
            response = await self.async_function_with_fallbacks(**kwargs)
diff --git a/tests/proxy_unit_tests/test_proxy_utils.py b/tests/proxy_unit_tests/test_proxy_utils.py
index d1b797b34a..bed171df70 100644
--- a/tests/proxy_unit_tests/test_proxy_utils.py
+++ b/tests/proxy_unit_tests/test_proxy_utils.py
@@ -1252,3 +1252,78 @@ def test_get_model_group_info():
         model_group="openai/tts-1",
     )
     assert len(model_list) == 1
+
+
+import pytest
+import asyncio
+from unittest.mock import AsyncMock, patch
+import json
+
+
+@pytest.fixture
+def mock_team_data():
+    return [
+        {"team_id": "team1", "team_name": "Test Team 1"},
+        {"team_id": "team2", "team_name": "Test Team 2"},
+    ]
+
+
+@pytest.fixture
+def mock_key_data():
+    return [
+        {"token": "test_token_1", "key_name": "key1", "team_id": None, "spend": 0},
+        {"token": "test_token_2", "key_name": "key2", "team_id": "team1", "spend": 100},
+        {
+            "token": "test_token_3",
+            "key_name": "key3",
+            "team_id": "litellm-dashboard",
+            "spend": 50,
+        },
+    ]
+
+
+class MockDb:
+    def __init__(self, mock_team_data, mock_key_data):
+        self.mock_team_data = mock_team_data
+        self.mock_key_data = mock_key_data
+
+    async def query_raw(self, query: str, *args):
+        # Simulate the SQL query response
+        filtered_keys = [
+            k
+            for k in self.mock_key_data
+            if k["team_id"] != "litellm-dashboard" or k["team_id"] is None
+        ]
+
+        return [{"teams": self.mock_team_data, "keys": filtered_keys}]
+
+
+class MockPrismaClientDB:
+    def __init__(
+        self,
+        mock_team_data,
+        mock_key_data,
+    ):
+        self.db = MockDb(mock_team_data, mock_key_data)
+
+
+@pytest.mark.asyncio
+async def test_get_user_info_for_proxy_admin(mock_team_data, mock_key_data):
+    # Patch the prisma_client import
+    from litellm.proxy._types import UserInfoResponse
+
+    with patch(
+        "litellm.proxy.proxy_server.prisma_client",
+        MockPrismaClientDB(mock_team_data, mock_key_data),
+    ):
+
+        from litellm.proxy.management_endpoints.internal_user_endpoints import (
+            _get_user_info_for_proxy_admin,
+        )
+
+        # Execute the function
+        result = await _get_user_info_for_proxy_admin()
+
+        # Verify the result structure
+        assert isinstance(result, UserInfoResponse)
+        assert len(result.keys) == 2
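One reasoning step behind the query change this test exercises: under SQL's three-valued logic, `k.team_id != 'litellm-dashboard'` is not true when `team_id` is NULL, so keys without a team were silently dropped until the `OR k.team_id IS NULL` branch was added. A small self-contained illustration of that behavior follows; sqlite3 is used purely for demonstration, while the proxy itself runs the query through Prisma.

```python
# Demonstration only: why the /user/info query needs the explicit IS NULL branch.
# Table and column names mirror the patch; the data is made up.
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE keys (token TEXT, team_id TEXT)")
conn.executemany(
    "INSERT INTO keys VALUES (?, ?)",
    [("k1", None), ("k2", "team1"), ("k3", "litellm-dashboard")],
)

without_null = conn.execute(
    "SELECT token FROM keys WHERE team_id != 'litellm-dashboard'"
).fetchall()
with_null = conn.execute(
    "SELECT token FROM keys WHERE team_id != 'litellm-dashboard' OR team_id IS NULL"
).fetchall()

print(without_null)  # [('k2',)] -- the NULL-team key is silently dropped
print(with_null)     # [('k1',), ('k2',)] -- matches the fixed query in the patch
```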
diff --git a/tests/router_unit_tests/test_router_helper_utils.py b/tests/router_unit_tests/test_router_helper_utils.py
index 2d0c702d58..e3ca281508 100644
--- a/tests/router_unit_tests/test_router_helper_utils.py
+++ b/tests/router_unit_tests/test_router_helper_utils.py
@@ -166,6 +166,49 @@ async def test_router_schedule_acompletion(model_list):
     assert response["choices"][0]["message"]["content"] == "I'm fine, thank you!"
 
 
+@pytest.mark.asyncio
+async def test_router_schedule_atext_completion(model_list):
+    """Test if the 'schedule_atext_completion' function is working correctly"""
+    from litellm.types.utils import TextCompletionResponse
+
+    router = Router(model_list=model_list)
+    with patch.object(
+        router, "_atext_completion", AsyncMock()
+    ) as mock_atext_completion:
+        mock_atext_completion.return_value = TextCompletionResponse()
+        response = await router.atext_completion(
+            model="gpt-3.5-turbo",
+            prompt="Hello, how are you?",
+            priority=1,
+        )
+        mock_atext_completion.assert_awaited_once()
+        assert "priority" not in mock_atext_completion.call_args.kwargs
+
+
+@pytest.mark.asyncio
+async def test_router_schedule_factory(model_list):
+    """Test if the '_schedule_factory' function is working correctly"""
+    from litellm.types.utils import TextCompletionResponse
+
+    router = Router(model_list=model_list)
+    with patch.object(
+        router, "_atext_completion", AsyncMock()
+    ) as mock_atext_completion:
+        mock_atext_completion.return_value = TextCompletionResponse()
+        response = await router._schedule_factory(
+            model="gpt-3.5-turbo",
+            args=(
+                "gpt-3.5-turbo",
+                "Hello, how are you?",
+            ),
+            priority=1,
+            kwargs={},
+            original_function=router.atext_completion,
+        )
+        mock_atext_completion.assert_awaited_once()
+        assert "priority" not in mock_atext_completion.call_args.kwargs
+
+
 @pytest.mark.asyncio
 async def test_router_arealtime(model_list):
     """Test if the '_arealtime' function is working correctly"""
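As a usage-level companion to the tests above, the sketch below shows what a caller can observe after a prioritized call: `_schedule_factory` stamps an `x-litellm-request-prioritization-used` header into the response's hidden params and attaches the request's `priority` to any exception it re-raises. The wrapper function and model name are illustrative; only the header and attribute names come from the patch.

```python
# Illustrative helper: `router` is any configured litellm.Router instance.
from litellm import Router


async def prioritized_text_completion(router: Router, prompt: str):
    try:
        response = await router.atext_completion(
            model="gpt-3.5-turbo",  # assumed model alias
            prompt=prompt,
            priority=1,
        )
        if isinstance(response._hidden_params, dict):
            headers = response._hidden_params.get("additional_headers", {})
            # Stamped by _schedule_factory once the queued request is dispatched.
            print(headers.get("x-litellm-request-prioritization-used"))
        return response
    except Exception as e:
        # _schedule_factory attaches the request's priority before re-raising,
        # including on the litellm.Timeout raised when queue polling exceeds router.timeout.
        print("failed with priority:", getattr(e, "priority", None))
        raise
```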