Litellm stable UI 02 17 2025 p1 (#8599)

* fix(key_management_endpoints.py): initial commit with logic to get all keys for teams user is an admin for

* fix(key_managements_endpoints.py): return all keys for teams user is an admin for

* fix(key_management_endpoints.py): add query param to ensure user opts into seeing all team keys (not just their own)

* fix(regenerate_key_modal.tsx): fix key regenerate

* fix(proxy_server.py): fix model metrics check on none api base

* test(test_key_generate_prisma.py): remove redundant test

* test(test_proxy_utils.py): add unit test covering new management endpoint helper util

* fix: fix test

* test(test_proxy_server.py): fix test
This commit is contained in:
Krish Dholakia 2025-02-17 17:55:05 -08:00 committed by GitHub
parent 9826f76288
commit 18bc9ddd3d
10 changed files with 277 additions and 121 deletions

View file

@ -0,0 +1,14 @@
from litellm.proxy._types import LiteLLM_TeamTable, UserAPIKeyAuth
def _is_user_team_admin(
user_api_key_dict: UserAPIKeyAuth, team_obj: LiteLLM_TeamTable
) -> bool:
for member in team_obj.members_with_roles:
if (
member.user_id is not None and member.user_id == user_api_key_dict.user_id
) and member.role == "admin":
return True
return False

View file

@ -35,6 +35,7 @@ from litellm.proxy.auth.auth_checks import (
)
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.hooks.key_management_event_hooks import KeyManagementEventHooks
from litellm.proxy.management_endpoints.common_utils import _is_user_team_admin
from litellm.proxy.management_helpers.utils import management_endpoint_wrapper
from litellm.proxy.utils import (
PrismaClient,
@ -1684,10 +1685,10 @@ async def validate_key_list_check(
organization_id: Optional[str],
key_alias: Optional[str],
prisma_client: PrismaClient,
):
) -> Optional[LiteLLM_UserTable]:
if user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN.value:
return
return None
if user_api_key_dict.user_id is None:
raise ProxyException(
@ -1747,6 +1748,36 @@ async def validate_key_list_check(
param="organization_id",
code=status.HTTP_403_FORBIDDEN,
)
return complete_user_info
async def get_admin_team_ids(
complete_user_info: Optional[LiteLLM_UserTable],
user_api_key_dict: UserAPIKeyAuth,
prisma_client: PrismaClient,
) -> List[str]:
"""
Get all team IDs where the user is an admin.
"""
if complete_user_info is None:
return []
# Get all teams that user is an admin of
teams: Optional[List[BaseModel]] = (
await prisma_client.db.litellm_teamtable.find_many(
where={"team_id": {"in": complete_user_info.teams}}
)
)
if teams is None:
return []
teams_pydantic_obj = [LiteLLM_TeamTable(**team.model_dump()) for team in teams]
admin_team_ids = [
team.team_id
for team in teams_pydantic_obj
if _is_user_team_admin(user_api_key_dict=user_api_key_dict, team_obj=team)
]
return admin_team_ids
@router.get(
@ -1767,6 +1798,9 @@ async def list_keys(
),
key_alias: Optional[str] = Query(None, description="Filter keys by key alias"),
return_full_object: bool = Query(False, description="Return full key object"),
include_team_keys: bool = Query(
False, description="Include all keys for teams that user is an admin of."
),
) -> KeyListResponseObject:
"""
List all keys for a given user / team / organization.
@ -1782,32 +1816,13 @@ async def list_keys(
try:
from litellm.proxy.proxy_server import prisma_client
# Check for unsupported parameters
supported_params = {
"page",
"size",
"user_id",
"team_id",
"key_alias",
"return_full_object",
"organization_id",
}
unsupported_params = set(request.query_params.keys()) - supported_params
if unsupported_params:
raise ProxyException(
message=f"Unsupported parameter(s): {', '.join(unsupported_params)}. Supported parameters: {', '.join(supported_params)}",
type=ProxyErrorTypes.bad_request_error,
param=", ".join(unsupported_params),
code=status.HTTP_400_BAD_REQUEST,
)
verbose_proxy_logger.debug("Entering list_keys function")
if prisma_client is None:
verbose_proxy_logger.error("Database not connected")
raise Exception("Database not connected")
await validate_key_list_check(
complete_user_info = await validate_key_list_check(
user_api_key_dict=user_api_key_dict,
user_id=user_id,
team_id=team_id,
@ -1816,6 +1831,15 @@ async def list_keys(
prisma_client=prisma_client,
)
if include_team_keys:
admin_team_ids = await get_admin_team_ids(
complete_user_info=complete_user_info,
user_api_key_dict=user_api_key_dict,
prisma_client=prisma_client,
)
else:
admin_team_ids = None
if user_id is None and user_api_key_dict.user_role not in [
LitellmUserRoles.PROXY_ADMIN.value,
LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY.value,
@ -1831,6 +1855,7 @@ async def list_keys(
key_alias=key_alias,
return_full_object=return_full_object,
organization_id=organization_id,
admin_team_ids=admin_team_ids,
)
verbose_proxy_logger.debug("Successfully prepared response")
@ -1866,6 +1891,9 @@ async def _list_key_helper(
key_alias: Optional[str],
exclude_team_id: Optional[str] = None,
return_full_object: bool = False,
admin_team_ids: Optional[
List[str]
] = None, # New parameter for teams where user is admin
) -> KeyListResponseObject:
"""
Helper function to list keys
@ -1877,6 +1905,7 @@ async def _list_key_helper(
key_alias: Optional[str]
exclude_team_id: Optional[str] # exclude a specific team_id
return_full_object: bool # when true, will return UserAPIKeyAuth objects instead of just the token
admin_team_ids: Optional[List[str]] # list of team IDs where the user is an admin
Returns:
KeyListResponseObject
@ -1889,19 +1918,37 @@ async def _list_key_helper(
"""
# Prepare filter conditions
where: Dict[str, Union[str, Dict[str, Any]]] = {}
where: Dict[str, Union[str, Dict[str, Any], List[Dict[str, Any]]]] = {}
where.update(_get_condition_to_filter_out_ui_session_tokens())
# Build the OR conditions for user's keys and admin team keys
or_conditions: List[Dict[str, Any]] = []
# Base conditions for user's own keys
user_condition: Dict[str, Any] = {}
if user_id and isinstance(user_id, str):
where["user_id"] = user_id
user_condition["user_id"] = user_id
if team_id and isinstance(team_id, str):
where["team_id"] = team_id
user_condition["team_id"] = team_id
if key_alias and isinstance(key_alias, str):
where["key_alias"] = key_alias
user_condition["key_alias"] = key_alias
if exclude_team_id and isinstance(exclude_team_id, str):
where["team_id"] = {"not": exclude_team_id}
user_condition["team_id"] = {"not": exclude_team_id}
if organization_id and isinstance(organization_id, str):
where["organization_id"] = organization_id
user_condition["organization_id"] = organization_id
if user_condition:
or_conditions.append(user_condition)
# Add condition for admin team keys if provided
if admin_team_ids:
or_conditions.append({"team_id": {"in": admin_team_ids}})
# Combine conditions with OR if we have multiple conditions
if len(or_conditions) > 1:
where["OR"] = or_conditions
elif len(or_conditions) == 1:
where.update(or_conditions[0])
verbose_proxy_logger.debug(f"Filter conditions: {where}")

View file

@ -55,6 +55,7 @@ from litellm.proxy.auth.auth_checks import (
get_team_object,
)
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.management_endpoints.common_utils import _is_user_team_admin
from litellm.proxy.management_helpers.utils import (
add_new_member,
management_endpoint_wrapper,
@ -68,17 +69,6 @@ from litellm.proxy.utils import (
router = APIRouter()
def _is_user_team_admin(
user_api_key_dict: UserAPIKeyAuth, team_obj: LiteLLM_TeamTable
) -> bool:
for member in team_obj.members_with_roles:
if member.user_id is not None and member.user_id == user_api_key_dict.user_id:
return True
return False
def _is_available_team(team_id: str, user_api_key_dict: UserAPIKeyAuth) -> bool:
if litellm.default_internal_user_params is None:
return False

View file

@ -6472,9 +6472,9 @@ async def model_metrics(
if _day not in _daily_entries:
_daily_entries[_day] = {}
_combined_model_name = str(_model)
if "https://" in _api_base:
if _api_base is not None and "https://" in _api_base:
_combined_model_name = str(_api_base)
if "/openai/" in _combined_model_name:
if _combined_model_name is not None and "/openai/" in _combined_model_name:
_combined_model_name = _combined_model_name.split("/openai/")[0]
_all_api_bases.add(_combined_model_name)

View file

@ -3436,36 +3436,6 @@ async def test_list_keys(prisma_client):
assert _key in response["keys"]
@pytest.mark.asyncio
async def test_key_list_unsupported_params(prisma_client):
"""
Test the list_keys function:
- Test unsupported params
"""
from litellm.proxy.proxy_server import hash_token
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
await litellm.proxy.proxy_server.prisma_client.connect()
request = Request(scope={"type": "http", "query_string": b"alias=foo"})
try:
await list_keys(
request,
UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN.value),
page=1,
size=10,
)
pytest.fail("Expected this call to fail")
except Exception as e:
print("error str=", str(e.message))
error_str = str(e.message)
assert "Unsupported parameter" in error_str
pass
@pytest.mark.asyncio
async def test_auth_vertex_ai_route(prisma_client):
"""

View file

@ -1232,6 +1232,7 @@ async def test_create_team_member_add_team_admin(
except HTTPException as e:
if user_role == "user":
assert e.status_code == 403
return
else:
raise e

View file

@ -1,7 +1,7 @@
import asyncio
import os
import sys
from typing import Any, Dict
from typing import Any, Dict, Optional, List
from unittest.mock import Mock
from litellm.proxy.utils import _get_redoc_url, _get_docs_url
import json
@ -1618,6 +1618,10 @@ def test_provider_specific_header():
},
}
from litellm.proxy._types import LiteLLM_UserTable
@pytest.mark.parametrize(
"wildcard_model, expected_models",
[
@ -1643,6 +1647,7 @@ def test_get_known_models_from_wildcard(wildcard_model, expected_models):
assert all(model in wildcard_models for model in expected_models)
@pytest.mark.parametrize(
"data, user_api_key_dict, expected_model",
[
@ -1692,3 +1697,125 @@ def test_update_model_if_team_alias_exists(data, user_api_key_dict, expected_mod
# Check if model was updated correctly
assert test_data.get("model") == expected_model
@pytest.fixture
def mock_prisma_client():
client = MagicMock()
client.db = MagicMock()
client.db.litellm_teamtable = AsyncMock()
return client
@pytest.mark.asyncio
@pytest.mark.parametrize(
"test_id, user_info, user_role, mock_teams, expected_teams, should_query_db",
[
("no_user_info", None, "proxy_admin", None, [], False),
(
"no_teams_found",
LiteLLM_UserTable(
teams=["team1", "team2"],
user_id="user1",
max_budget=100,
spend=0,
user_email="user1@example.com",
user_role="proxy_admin",
),
"proxy_admin",
None,
[],
True,
),
(
"admin_user_with_teams",
LiteLLM_UserTable(
teams=["team1", "team2"],
user_id="user1",
max_budget=100,
spend=0,
user_email="user1@example.com",
user_role="proxy_admin",
),
"proxy_admin",
[
MagicMock(
model_dump=lambda: {
"team_id": "team1",
"members_with_roles": [{"role": "admin", "user_id": "user1"}],
}
),
MagicMock(
model_dump=lambda: {
"team_id": "team2",
"members_with_roles": [
{"role": "admin", "user_id": "user1"},
{"role": "user", "user_id": "user2"},
],
}
),
],
["team1", "team2"],
True,
),
(
"non_admin_user",
LiteLLM_UserTable(
teams=["team1", "team2"],
user_id="user1",
max_budget=100,
spend=0,
user_email="user1@example.com",
user_role="internal_user",
),
"internal_user",
[
MagicMock(
model_dump=lambda: {"team_id": "team1", "members": ["user1"]}
),
MagicMock(
model_dump=lambda: {
"team_id": "team2",
"members": ["user1", "user2"],
}
),
],
[],
True,
),
],
)
async def test_get_admin_team_ids(
test_id: str,
user_info: Optional[LiteLLM_UserTable],
user_role: str,
mock_teams: Optional[List[MagicMock]],
expected_teams: List[str],
should_query_db: bool,
mock_prisma_client,
):
from litellm.proxy.management_endpoints.key_management_endpoints import (
get_admin_team_ids,
)
# Setup
mock_prisma_client.db.litellm_teamtable.find_many.return_value = mock_teams
user_api_key_dict = UserAPIKeyAuth(
user_role=user_role, user_id=user_info.user_id if user_info else None
)
# Execute
result = await get_admin_team_ids(
complete_user_info=user_info,
user_api_key_dict=user_api_key_dict,
prisma_client=mock_prisma_client,
)
# Assert
assert result == expected_teams, f"Expected {expected_teams}, but got {result}"
if should_query_db:
mock_prisma_client.db.litellm_teamtable.find_many.assert_called_once_with(
where={"team_id": {"in": user_info.teams}}
)
else:
mock_prisma_client.db.litellm_teamtable.find_many.assert_not_called()

View file

@ -154,10 +154,6 @@ export default function KeyInfoView({ keyId, onClose, keyData, accessToken, user
visible={isRegenerateModalOpen}
onClose={() => setIsRegenerateModalOpen(false)}
accessToken={accessToken}
onSuccess={(newKeyData) => {
// Handle the updated key data here if needed
setIsRegenerateModalOpen(false);
}}
/>
{/* Delete Confirmation Modal */}

View file

@ -2176,6 +2176,7 @@ export const keyListCall = async (
}
queryParams.append('return_full_object', 'true');
queryParams.append('include_team_keys', 'true');
const queryString = queryParams.toString();
if (queryString) {

View file

@ -11,7 +11,6 @@ interface RegenerateKeyModalProps {
visible: boolean;
onClose: () => void;
accessToken: string | null;
onSuccess?: (newKeyData: any) => void;
}
export function RegenerateKeyModal({
@ -19,12 +18,12 @@ export function RegenerateKeyModal({
visible,
onClose,
accessToken,
onSuccess,
}: RegenerateKeyModalProps) {
const [form] = Form.useForm();
const [regeneratedKey, setRegeneratedKey] = useState<string | null>(null);
const [regenerateFormData, setRegenerateFormData] = useState<any>(null);
const [newExpiryTime, setNewExpiryTime] = useState<string | null>(null);
const [isRegenerating, setIsRegenerating] = useState(false);
useEffect(() => {
if (visible && selectedToken) {
@ -38,6 +37,15 @@ export function RegenerateKeyModal({
}
}, [visible, selectedToken, form]);
useEffect(() => {
if (!visible) {
// Reset states when modal is closed
setRegeneratedKey(null);
setIsRegenerating(false);
form.resetFields();
}
}, [visible, form]);
useEffect(() => {
const calculateNewExpiryTime = (duration: string | undefined) => {
if (!duration) return null;
@ -70,25 +78,24 @@ export function RegenerateKeyModal({
}, [regenerateFormData?.duration]);
const handleRegenerateKey = async () => {
if (!selectedToken || !accessToken) return;
setIsRegenerating(true);
try {
const formValues = await form.validateFields();
const response = await regenerateKeyCall(accessToken, selectedToken.token, formValues);
setRegeneratedKey(response.key);
if (onSuccess) {
onSuccess({ ...selectedToken, key_name: response.key_name, ...formValues });
}
message.success("API Key regenerated successfully");
} catch (error) {
console.error("Error regenerating key:", error);
message.error("Failed to regenerate API Key");
setIsRegenerating(false); // Reset regenerating state on error
}
};
const handleClose = () => {
setRegeneratedKey(null);
setIsRegenerating(false);
form.resetFields();
onClose();
};
@ -96,7 +103,7 @@ export function RegenerateKeyModal({
return (
<Modal
title="Regenerate API Key"
visible={visible}
open={visible}
onCancel={handleClose}
footer={regeneratedKey ? [
<Button key="close" onClick={handleClose}>
@ -106,8 +113,12 @@ export function RegenerateKeyModal({
<Button key="cancel" onClick={handleClose} className="mr-2">
Cancel
</Button>,
<Button key="regenerate" onClick={handleRegenerateKey} >
Regenerate
<Button
key="regenerate"
onClick={handleRegenerateKey}
disabled={isRegenerating}
>
{isRegenerating ? "Regenerating..." : "Regenerate"}
</Button>,
]}
>
@ -175,8 +186,7 @@ export function RegenerateKeyModal({
</div>
)}
</Form>
)
}
)}
</Modal>
);
}