Litellm stable UI 02 17 2025 p1 (#8599)

* fix(key_management_endpoints.py): initial commit with logic to get all keys for teams user is an admin for

* fix(key_managements_endpoints.py): return all keys for teams user is an admin for

* fix(key_management_endpoints.py): add query param to ensure user opts into seeing all team keys (not just their own)

* fix(regenerate_key_modal.tsx): fix key regenerate

* fix(proxy_server.py): fix model metrics check on none api base

* test(test_key_generate_prisma.py): remove redundant test

* test(test_proxy_utils.py): add unit test covering new management endpoint helper util

* fix: fix test

* test(test_proxy_server.py): fix test
This commit is contained in:
Krish Dholakia 2025-02-17 17:55:05 -08:00 committed by GitHub
parent 9826f76288
commit 18bc9ddd3d
10 changed files with 277 additions and 121 deletions

View file

@ -0,0 +1,14 @@
from litellm.proxy._types import LiteLLM_TeamTable, UserAPIKeyAuth
def _is_user_team_admin(
user_api_key_dict: UserAPIKeyAuth, team_obj: LiteLLM_TeamTable
) -> bool:
for member in team_obj.members_with_roles:
if (
member.user_id is not None and member.user_id == user_api_key_dict.user_id
) and member.role == "admin":
return True
return False

View file

@ -35,6 +35,7 @@ from litellm.proxy.auth.auth_checks import (
) )
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.hooks.key_management_event_hooks import KeyManagementEventHooks from litellm.proxy.hooks.key_management_event_hooks import KeyManagementEventHooks
from litellm.proxy.management_endpoints.common_utils import _is_user_team_admin
from litellm.proxy.management_helpers.utils import management_endpoint_wrapper from litellm.proxy.management_helpers.utils import management_endpoint_wrapper
from litellm.proxy.utils import ( from litellm.proxy.utils import (
PrismaClient, PrismaClient,
@ -1684,10 +1685,10 @@ async def validate_key_list_check(
organization_id: Optional[str], organization_id: Optional[str],
key_alias: Optional[str], key_alias: Optional[str],
prisma_client: PrismaClient, prisma_client: PrismaClient,
): ) -> Optional[LiteLLM_UserTable]:
if user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN.value: if user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN.value:
return return None
if user_api_key_dict.user_id is None: if user_api_key_dict.user_id is None:
raise ProxyException( raise ProxyException(
@ -1747,6 +1748,36 @@ async def validate_key_list_check(
param="organization_id", param="organization_id",
code=status.HTTP_403_FORBIDDEN, code=status.HTTP_403_FORBIDDEN,
) )
return complete_user_info
async def get_admin_team_ids(
complete_user_info: Optional[LiteLLM_UserTable],
user_api_key_dict: UserAPIKeyAuth,
prisma_client: PrismaClient,
) -> List[str]:
"""
Get all team IDs where the user is an admin.
"""
if complete_user_info is None:
return []
# Get all teams that user is an admin of
teams: Optional[List[BaseModel]] = (
await prisma_client.db.litellm_teamtable.find_many(
where={"team_id": {"in": complete_user_info.teams}}
)
)
if teams is None:
return []
teams_pydantic_obj = [LiteLLM_TeamTable(**team.model_dump()) for team in teams]
admin_team_ids = [
team.team_id
for team in teams_pydantic_obj
if _is_user_team_admin(user_api_key_dict=user_api_key_dict, team_obj=team)
]
return admin_team_ids
@router.get( @router.get(
@ -1767,6 +1798,9 @@ async def list_keys(
), ),
key_alias: Optional[str] = Query(None, description="Filter keys by key alias"), key_alias: Optional[str] = Query(None, description="Filter keys by key alias"),
return_full_object: bool = Query(False, description="Return full key object"), return_full_object: bool = Query(False, description="Return full key object"),
include_team_keys: bool = Query(
False, description="Include all keys for teams that user is an admin of."
),
) -> KeyListResponseObject: ) -> KeyListResponseObject:
""" """
List all keys for a given user / team / organization. List all keys for a given user / team / organization.
@ -1782,32 +1816,13 @@ async def list_keys(
try: try:
from litellm.proxy.proxy_server import prisma_client from litellm.proxy.proxy_server import prisma_client
# Check for unsupported parameters
supported_params = {
"page",
"size",
"user_id",
"team_id",
"key_alias",
"return_full_object",
"organization_id",
}
unsupported_params = set(request.query_params.keys()) - supported_params
if unsupported_params:
raise ProxyException(
message=f"Unsupported parameter(s): {', '.join(unsupported_params)}. Supported parameters: {', '.join(supported_params)}",
type=ProxyErrorTypes.bad_request_error,
param=", ".join(unsupported_params),
code=status.HTTP_400_BAD_REQUEST,
)
verbose_proxy_logger.debug("Entering list_keys function") verbose_proxy_logger.debug("Entering list_keys function")
if prisma_client is None: if prisma_client is None:
verbose_proxy_logger.error("Database not connected") verbose_proxy_logger.error("Database not connected")
raise Exception("Database not connected") raise Exception("Database not connected")
await validate_key_list_check( complete_user_info = await validate_key_list_check(
user_api_key_dict=user_api_key_dict, user_api_key_dict=user_api_key_dict,
user_id=user_id, user_id=user_id,
team_id=team_id, team_id=team_id,
@ -1816,6 +1831,15 @@ async def list_keys(
prisma_client=prisma_client, prisma_client=prisma_client,
) )
if include_team_keys:
admin_team_ids = await get_admin_team_ids(
complete_user_info=complete_user_info,
user_api_key_dict=user_api_key_dict,
prisma_client=prisma_client,
)
else:
admin_team_ids = None
if user_id is None and user_api_key_dict.user_role not in [ if user_id is None and user_api_key_dict.user_role not in [
LitellmUserRoles.PROXY_ADMIN.value, LitellmUserRoles.PROXY_ADMIN.value,
LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY.value, LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY.value,
@ -1831,6 +1855,7 @@ async def list_keys(
key_alias=key_alias, key_alias=key_alias,
return_full_object=return_full_object, return_full_object=return_full_object,
organization_id=organization_id, organization_id=organization_id,
admin_team_ids=admin_team_ids,
) )
verbose_proxy_logger.debug("Successfully prepared response") verbose_proxy_logger.debug("Successfully prepared response")
@ -1866,6 +1891,9 @@ async def _list_key_helper(
key_alias: Optional[str], key_alias: Optional[str],
exclude_team_id: Optional[str] = None, exclude_team_id: Optional[str] = None,
return_full_object: bool = False, return_full_object: bool = False,
admin_team_ids: Optional[
List[str]
] = None, # New parameter for teams where user is admin
) -> KeyListResponseObject: ) -> KeyListResponseObject:
""" """
Helper function to list keys Helper function to list keys
@ -1877,6 +1905,7 @@ async def _list_key_helper(
key_alias: Optional[str] key_alias: Optional[str]
exclude_team_id: Optional[str] # exclude a specific team_id exclude_team_id: Optional[str] # exclude a specific team_id
return_full_object: bool # when true, will return UserAPIKeyAuth objects instead of just the token return_full_object: bool # when true, will return UserAPIKeyAuth objects instead of just the token
admin_team_ids: Optional[List[str]] # list of team IDs where the user is an admin
Returns: Returns:
KeyListResponseObject KeyListResponseObject
@ -1889,19 +1918,37 @@ async def _list_key_helper(
""" """
# Prepare filter conditions # Prepare filter conditions
where: Dict[str, Union[str, Dict[str, Any]]] = {} where: Dict[str, Union[str, Dict[str, Any], List[Dict[str, Any]]]] = {}
where.update(_get_condition_to_filter_out_ui_session_tokens()) where.update(_get_condition_to_filter_out_ui_session_tokens())
# Build the OR conditions for user's keys and admin team keys
or_conditions: List[Dict[str, Any]] = []
# Base conditions for user's own keys
user_condition: Dict[str, Any] = {}
if user_id and isinstance(user_id, str): if user_id and isinstance(user_id, str):
where["user_id"] = user_id user_condition["user_id"] = user_id
if team_id and isinstance(team_id, str): if team_id and isinstance(team_id, str):
where["team_id"] = team_id user_condition["team_id"] = team_id
if key_alias and isinstance(key_alias, str): if key_alias and isinstance(key_alias, str):
where["key_alias"] = key_alias user_condition["key_alias"] = key_alias
if exclude_team_id and isinstance(exclude_team_id, str): if exclude_team_id and isinstance(exclude_team_id, str):
where["team_id"] = {"not": exclude_team_id} user_condition["team_id"] = {"not": exclude_team_id}
if organization_id and isinstance(organization_id, str): if organization_id and isinstance(organization_id, str):
where["organization_id"] = organization_id user_condition["organization_id"] = organization_id
if user_condition:
or_conditions.append(user_condition)
# Add condition for admin team keys if provided
if admin_team_ids:
or_conditions.append({"team_id": {"in": admin_team_ids}})
# Combine conditions with OR if we have multiple conditions
if len(or_conditions) > 1:
where["OR"] = or_conditions
elif len(or_conditions) == 1:
where.update(or_conditions[0])
verbose_proxy_logger.debug(f"Filter conditions: {where}") verbose_proxy_logger.debug(f"Filter conditions: {where}")

View file

@ -55,6 +55,7 @@ from litellm.proxy.auth.auth_checks import (
get_team_object, get_team_object,
) )
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.management_endpoints.common_utils import _is_user_team_admin
from litellm.proxy.management_helpers.utils import ( from litellm.proxy.management_helpers.utils import (
add_new_member, add_new_member,
management_endpoint_wrapper, management_endpoint_wrapper,
@ -68,17 +69,6 @@ from litellm.proxy.utils import (
router = APIRouter() router = APIRouter()
def _is_user_team_admin(
user_api_key_dict: UserAPIKeyAuth, team_obj: LiteLLM_TeamTable
) -> bool:
for member in team_obj.members_with_roles:
if member.user_id is not None and member.user_id == user_api_key_dict.user_id:
return True
return False
def _is_available_team(team_id: str, user_api_key_dict: UserAPIKeyAuth) -> bool: def _is_available_team(team_id: str, user_api_key_dict: UserAPIKeyAuth) -> bool:
if litellm.default_internal_user_params is None: if litellm.default_internal_user_params is None:
return False return False

View file

@ -6472,9 +6472,9 @@ async def model_metrics(
if _day not in _daily_entries: if _day not in _daily_entries:
_daily_entries[_day] = {} _daily_entries[_day] = {}
_combined_model_name = str(_model) _combined_model_name = str(_model)
if "https://" in _api_base: if _api_base is not None and "https://" in _api_base:
_combined_model_name = str(_api_base) _combined_model_name = str(_api_base)
if "/openai/" in _combined_model_name: if _combined_model_name is not None and "/openai/" in _combined_model_name:
_combined_model_name = _combined_model_name.split("/openai/")[0] _combined_model_name = _combined_model_name.split("/openai/")[0]
_all_api_bases.add(_combined_model_name) _all_api_bases.add(_combined_model_name)

View file

@ -3436,36 +3436,6 @@ async def test_list_keys(prisma_client):
assert _key in response["keys"] assert _key in response["keys"]
@pytest.mark.asyncio
async def test_key_list_unsupported_params(prisma_client):
"""
Test the list_keys function:
- Test unsupported params
"""
from litellm.proxy.proxy_server import hash_token
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
await litellm.proxy.proxy_server.prisma_client.connect()
request = Request(scope={"type": "http", "query_string": b"alias=foo"})
try:
await list_keys(
request,
UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN.value),
page=1,
size=10,
)
pytest.fail("Expected this call to fail")
except Exception as e:
print("error str=", str(e.message))
error_str = str(e.message)
assert "Unsupported parameter" in error_str
pass
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_auth_vertex_ai_route(prisma_client): async def test_auth_vertex_ai_route(prisma_client):
""" """

View file

@ -1232,6 +1232,7 @@ async def test_create_team_member_add_team_admin(
except HTTPException as e: except HTTPException as e:
if user_role == "user": if user_role == "user":
assert e.status_code == 403 assert e.status_code == 403
return
else: else:
raise e raise e

View file

@ -1,7 +1,7 @@
import asyncio import asyncio
import os import os
import sys import sys
from typing import Any, Dict from typing import Any, Dict, Optional, List
from unittest.mock import Mock from unittest.mock import Mock
from litellm.proxy.utils import _get_redoc_url, _get_docs_url from litellm.proxy.utils import _get_redoc_url, _get_docs_url
import json import json
@ -1618,6 +1618,10 @@ def test_provider_specific_header():
}, },
} }
from litellm.proxy._types import LiteLLM_UserTable
@pytest.mark.parametrize( @pytest.mark.parametrize(
"wildcard_model, expected_models", "wildcard_model, expected_models",
[ [
@ -1642,7 +1646,8 @@ def test_get_known_models_from_wildcard(wildcard_model, expected_models):
print(f"Missing expected model: {model}") print(f"Missing expected model: {model}")
assert all(model in wildcard_models for model in expected_models) assert all(model in wildcard_models for model in expected_models)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"data, user_api_key_dict, expected_model", "data, user_api_key_dict, expected_model",
[ [
@ -1692,3 +1697,125 @@ def test_update_model_if_team_alias_exists(data, user_api_key_dict, expected_mod
# Check if model was updated correctly # Check if model was updated correctly
assert test_data.get("model") == expected_model assert test_data.get("model") == expected_model
@pytest.fixture
def mock_prisma_client():
client = MagicMock()
client.db = MagicMock()
client.db.litellm_teamtable = AsyncMock()
return client
@pytest.mark.asyncio
@pytest.mark.parametrize(
"test_id, user_info, user_role, mock_teams, expected_teams, should_query_db",
[
("no_user_info", None, "proxy_admin", None, [], False),
(
"no_teams_found",
LiteLLM_UserTable(
teams=["team1", "team2"],
user_id="user1",
max_budget=100,
spend=0,
user_email="user1@example.com",
user_role="proxy_admin",
),
"proxy_admin",
None,
[],
True,
),
(
"admin_user_with_teams",
LiteLLM_UserTable(
teams=["team1", "team2"],
user_id="user1",
max_budget=100,
spend=0,
user_email="user1@example.com",
user_role="proxy_admin",
),
"proxy_admin",
[
MagicMock(
model_dump=lambda: {
"team_id": "team1",
"members_with_roles": [{"role": "admin", "user_id": "user1"}],
}
),
MagicMock(
model_dump=lambda: {
"team_id": "team2",
"members_with_roles": [
{"role": "admin", "user_id": "user1"},
{"role": "user", "user_id": "user2"},
],
}
),
],
["team1", "team2"],
True,
),
(
"non_admin_user",
LiteLLM_UserTable(
teams=["team1", "team2"],
user_id="user1",
max_budget=100,
spend=0,
user_email="user1@example.com",
user_role="internal_user",
),
"internal_user",
[
MagicMock(
model_dump=lambda: {"team_id": "team1", "members": ["user1"]}
),
MagicMock(
model_dump=lambda: {
"team_id": "team2",
"members": ["user1", "user2"],
}
),
],
[],
True,
),
],
)
async def test_get_admin_team_ids(
test_id: str,
user_info: Optional[LiteLLM_UserTable],
user_role: str,
mock_teams: Optional[List[MagicMock]],
expected_teams: List[str],
should_query_db: bool,
mock_prisma_client,
):
from litellm.proxy.management_endpoints.key_management_endpoints import (
get_admin_team_ids,
)
# Setup
mock_prisma_client.db.litellm_teamtable.find_many.return_value = mock_teams
user_api_key_dict = UserAPIKeyAuth(
user_role=user_role, user_id=user_info.user_id if user_info else None
)
# Execute
result = await get_admin_team_ids(
complete_user_info=user_info,
user_api_key_dict=user_api_key_dict,
prisma_client=mock_prisma_client,
)
# Assert
assert result == expected_teams, f"Expected {expected_teams}, but got {result}"
if should_query_db:
mock_prisma_client.db.litellm_teamtable.find_many.assert_called_once_with(
where={"team_id": {"in": user_info.teams}}
)
else:
mock_prisma_client.db.litellm_teamtable.find_many.assert_not_called()

View file

@ -154,10 +154,6 @@ export default function KeyInfoView({ keyId, onClose, keyData, accessToken, user
visible={isRegenerateModalOpen} visible={isRegenerateModalOpen}
onClose={() => setIsRegenerateModalOpen(false)} onClose={() => setIsRegenerateModalOpen(false)}
accessToken={accessToken} accessToken={accessToken}
onSuccess={(newKeyData) => {
// Handle the updated key data here if needed
setIsRegenerateModalOpen(false);
}}
/> />
{/* Delete Confirmation Modal */} {/* Delete Confirmation Modal */}

View file

@ -2176,6 +2176,7 @@ export const keyListCall = async (
} }
queryParams.append('return_full_object', 'true'); queryParams.append('return_full_object', 'true');
queryParams.append('include_team_keys', 'true');
const queryString = queryParams.toString(); const queryString = queryParams.toString();
if (queryString) { if (queryString) {

View file

@ -11,7 +11,6 @@ interface RegenerateKeyModalProps {
visible: boolean; visible: boolean;
onClose: () => void; onClose: () => void;
accessToken: string | null; accessToken: string | null;
onSuccess?: (newKeyData: any) => void;
} }
export function RegenerateKeyModal({ export function RegenerateKeyModal({
@ -19,12 +18,12 @@ export function RegenerateKeyModal({
visible, visible,
onClose, onClose,
accessToken, accessToken,
onSuccess,
}: RegenerateKeyModalProps) { }: RegenerateKeyModalProps) {
const [form] = Form.useForm(); const [form] = Form.useForm();
const [regeneratedKey, setRegeneratedKey] = useState<string | null>(null); const [regeneratedKey, setRegeneratedKey] = useState<string | null>(null);
const [regenerateFormData, setRegenerateFormData] = useState<any>(null); const [regenerateFormData, setRegenerateFormData] = useState<any>(null);
const [newExpiryTime, setNewExpiryTime] = useState<string | null>(null); const [newExpiryTime, setNewExpiryTime] = useState<string | null>(null);
const [isRegenerating, setIsRegenerating] = useState(false);
useEffect(() => { useEffect(() => {
if (visible && selectedToken) { if (visible && selectedToken) {
@ -38,6 +37,15 @@ export function RegenerateKeyModal({
} }
}, [visible, selectedToken, form]); }, [visible, selectedToken, form]);
useEffect(() => {
if (!visible) {
// Reset states when modal is closed
setRegeneratedKey(null);
setIsRegenerating(false);
form.resetFields();
}
}, [visible, form]);
useEffect(() => { useEffect(() => {
const calculateNewExpiryTime = (duration: string | undefined) => { const calculateNewExpiryTime = (duration: string | undefined) => {
if (!duration) return null; if (!duration) return null;
@ -70,25 +78,24 @@ export function RegenerateKeyModal({
}, [regenerateFormData?.duration]); }, [regenerateFormData?.duration]);
const handleRegenerateKey = async () => { const handleRegenerateKey = async () => {
if (!selectedToken || !accessToken) return; if (!selectedToken || !accessToken) return;
setIsRegenerating(true);
try { try {
const formValues = await form.validateFields(); const formValues = await form.validateFields();
const response = await regenerateKeyCall(accessToken, selectedToken.token, formValues); const response = await regenerateKeyCall(accessToken, selectedToken.token, formValues);
setRegeneratedKey(response.key); setRegeneratedKey(response.key);
if (onSuccess) {
onSuccess({ ...selectedToken, key_name: response.key_name, ...formValues });
}
message.success("API Key regenerated successfully"); message.success("API Key regenerated successfully");
} catch (error) { } catch (error) {
console.error("Error regenerating key:", error); console.error("Error regenerating key:", error);
message.error("Failed to regenerate API Key"); message.error("Failed to regenerate API Key");
setIsRegenerating(false); // Reset regenerating state on error
} }
}; };
const handleClose = () => { const handleClose = () => {
setRegeneratedKey(null); setRegeneratedKey(null);
setIsRegenerating(false);
form.resetFields(); form.resetFields();
onClose(); onClose();
}; };
@ -96,7 +103,7 @@ export function RegenerateKeyModal({
return ( return (
<Modal <Modal
title="Regenerate API Key" title="Regenerate API Key"
visible={visible} open={visible}
onCancel={handleClose} onCancel={handleClose}
footer={regeneratedKey ? [ footer={regeneratedKey ? [
<Button key="close" onClick={handleClose}> <Button key="close" onClick={handleClose}>
@ -106,8 +113,12 @@ export function RegenerateKeyModal({
<Button key="cancel" onClick={handleClose} className="mr-2"> <Button key="cancel" onClick={handleClose} className="mr-2">
Cancel Cancel
</Button>, </Button>,
<Button key="regenerate" onClick={handleRegenerateKey} > <Button
Regenerate key="regenerate"
onClick={handleRegenerateKey}
disabled={isRegenerating}
>
{isRegenerating ? "Regenerating..." : "Regenerate"}
</Button>, </Button>,
]} ]}
> >
@ -142,41 +153,40 @@ export function RegenerateKeyModal({
</Col> </Col>
</Grid> </Grid>
) : ( ) : (
<Form <Form
form={form} form={form}
layout="vertical" layout="vertical"
onValuesChange={(changedValues) => { onValuesChange={(changedValues) => {
if ("duration" in changedValues) { if ("duration" in changedValues) {
setRegenerateFormData((prev: { duration?: string }) => ({ ...prev, duration: changedValues.duration })); setRegenerateFormData((prev: { duration?: string }) => ({ ...prev, duration: changedValues.duration }));
} }
}} }}
> >
<Form.Item name="key_alias" label="Key Alias"> <Form.Item name="key_alias" label="Key Alias">
<TextInput disabled={true} /> <TextInput disabled={true} />
</Form.Item> </Form.Item>
<Form.Item name="max_budget" label="Max Budget (USD)"> <Form.Item name="max_budget" label="Max Budget (USD)">
<InputNumber step={0.01} precision={2} style={{ width: "100%" }} /> <InputNumber step={0.01} precision={2} style={{ width: "100%" }} />
</Form.Item> </Form.Item>
<Form.Item name="tpm_limit" label="TPM Limit"> <Form.Item name="tpm_limit" label="TPM Limit">
<InputNumber style={{ width: "100%" }} /> <InputNumber style={{ width: "100%" }} />
</Form.Item> </Form.Item>
<Form.Item name="rpm_limit" label="RPM Limit"> <Form.Item name="rpm_limit" label="RPM Limit">
<InputNumber style={{ width: "100%" }} /> <InputNumber style={{ width: "100%" }} />
</Form.Item> </Form.Item>
<Form.Item name="duration" label="Expire Key (eg: 30s, 30h, 30d)" className="mt-8"> <Form.Item name="duration" label="Expire Key (eg: 30s, 30h, 30d)" className="mt-8">
<TextInput placeholder="" /> <TextInput placeholder="" />
</Form.Item> </Form.Item>
<div className="mt-2 text-sm text-gray-500"> <div className="mt-2 text-sm text-gray-500">
Current expiry: {selectedToken?.expires ? new Date(selectedToken.expires).toLocaleString() : "Never"} Current expiry: {selectedToken?.expires ? new Date(selectedToken.expires).toLocaleString() : "Never"}
</div>
{newExpiryTime && (
<div className="mt-2 text-sm text-green-600">
New expiry: {newExpiryTime}
</div> </div>
{newExpiryTime && ( )}
<div className="mt-2 text-sm text-green-600"> </Form>
New expiry: {newExpiryTime} )}
</div>
)}
</Form>
)
}
</Modal> </Modal>
); );
} }