mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 03:34:10 +00:00
feat(router.py): support request prioritization for text completion c… (#7540)
* feat(router.py): support request prioritization for text completion calls * fix(internal_user_endpoints.py): fix sql query to return all keys, including null team id keys on `/user/info` Fixes https://github.com/BerriAI/litellm/issues/7485 * fix: fix linting errors * fix: fix linting error * test(test_router_helper_utils.py): add direct test for '_schedule_factory' Fixes code qa test
This commit is contained in:
parent
fb1272b46b
commit
db82b3bb2a
7 changed files with 229 additions and 3 deletions
|
@ -1252,3 +1252,78 @@ def test_get_model_group_info():
|
|||
model_group="openai/tts-1",
|
||||
)
|
||||
assert len(model_list) == 1
|
||||
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, patch
|
||||
import json
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_team_data():
|
||||
return [
|
||||
{"team_id": "team1", "team_name": "Test Team 1"},
|
||||
{"team_id": "team2", "team_name": "Test Team 2"},
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_key_data():
|
||||
return [
|
||||
{"token": "test_token_1", "key_name": "key1", "team_id": None, "spend": 0},
|
||||
{"token": "test_token_2", "key_name": "key2", "team_id": "team1", "spend": 100},
|
||||
{
|
||||
"token": "test_token_3",
|
||||
"key_name": "key3",
|
||||
"team_id": "litellm-dashboard",
|
||||
"spend": 50,
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
class MockDb:
|
||||
def __init__(self, mock_team_data, mock_key_data):
|
||||
self.mock_team_data = mock_team_data
|
||||
self.mock_key_data = mock_key_data
|
||||
|
||||
async def query_raw(self, query: str, *args):
|
||||
# Simulate the SQL query response
|
||||
filtered_keys = [
|
||||
k
|
||||
for k in self.mock_key_data
|
||||
if k["team_id"] != "litellm-dashboard" or k["team_id"] is None
|
||||
]
|
||||
|
||||
return [{"teams": self.mock_team_data, "keys": filtered_keys}]
|
||||
|
||||
|
||||
class MockPrismaClientDB:
|
||||
def __init__(
|
||||
self,
|
||||
mock_team_data,
|
||||
mock_key_data,
|
||||
):
|
||||
self.db = MockDb(mock_team_data, mock_key_data)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_info_for_proxy_admin(mock_team_data, mock_key_data):
|
||||
# Patch the prisma_client import
|
||||
from litellm.proxy._types import UserInfoResponse
|
||||
|
||||
with patch(
|
||||
"litellm.proxy.proxy_server.prisma_client",
|
||||
MockPrismaClientDB(mock_team_data, mock_key_data),
|
||||
):
|
||||
|
||||
from litellm.proxy.management_endpoints.internal_user_endpoints import (
|
||||
_get_user_info_for_proxy_admin,
|
||||
)
|
||||
|
||||
# Execute the function
|
||||
result = await _get_user_info_for_proxy_admin()
|
||||
|
||||
# Verify the result structure
|
||||
assert isinstance(result, UserInfoResponse)
|
||||
assert len(result.keys) == 2
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue