mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
Litellm stable UI 02 17 2025 p1 (#8599)
* fix(key_management_endpoints.py): initial commit with logic to get all keys for teams user is an admin for * fix(key_managements_endpoints.py): return all keys for teams user is an admin for * fix(key_management_endpoints.py): add query param to ensure user opts into seeing all team keys (not just their own) * fix(regenerate_key_modal.tsx): fix key regenerate * fix(proxy_server.py): fix model metrics check on none api base * test(test_key_generate_prisma.py): remove redundant test * test(test_proxy_utils.py): add unit test covering new management endpoint helper util * fix: fix test * test(test_proxy_server.py): fix test
This commit is contained in:
parent
9826f76288
commit
18bc9ddd3d
10 changed files with 277 additions and 121 deletions
14
litellm/proxy/management_endpoints/common_utils.py
Normal file
14
litellm/proxy/management_endpoints/common_utils.py
Normal file
|
@ -0,0 +1,14 @@
|
|||
from litellm.proxy._types import LiteLLM_TeamTable, UserAPIKeyAuth
|
||||
|
||||
|
||||
def _is_user_team_admin(
|
||||
user_api_key_dict: UserAPIKeyAuth, team_obj: LiteLLM_TeamTable
|
||||
) -> bool:
|
||||
for member in team_obj.members_with_roles:
|
||||
if (
|
||||
member.user_id is not None and member.user_id == user_api_key_dict.user_id
|
||||
) and member.role == "admin":
|
||||
|
||||
return True
|
||||
|
||||
return False
|
|
@ -35,6 +35,7 @@ from litellm.proxy.auth.auth_checks import (
|
|||
)
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.hooks.key_management_event_hooks import KeyManagementEventHooks
|
||||
from litellm.proxy.management_endpoints.common_utils import _is_user_team_admin
|
||||
from litellm.proxy.management_helpers.utils import management_endpoint_wrapper
|
||||
from litellm.proxy.utils import (
|
||||
PrismaClient,
|
||||
|
@ -1684,10 +1685,10 @@ async def validate_key_list_check(
|
|||
organization_id: Optional[str],
|
||||
key_alias: Optional[str],
|
||||
prisma_client: PrismaClient,
|
||||
):
|
||||
) -> Optional[LiteLLM_UserTable]:
|
||||
|
||||
if user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN.value:
|
||||
return
|
||||
return None
|
||||
|
||||
if user_api_key_dict.user_id is None:
|
||||
raise ProxyException(
|
||||
|
@ -1747,6 +1748,36 @@ async def validate_key_list_check(
|
|||
param="organization_id",
|
||||
code=status.HTTP_403_FORBIDDEN,
|
||||
)
|
||||
return complete_user_info
|
||||
|
||||
|
||||
async def get_admin_team_ids(
|
||||
complete_user_info: Optional[LiteLLM_UserTable],
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
prisma_client: PrismaClient,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Get all team IDs where the user is an admin.
|
||||
"""
|
||||
if complete_user_info is None:
|
||||
return []
|
||||
# Get all teams that user is an admin of
|
||||
teams: Optional[List[BaseModel]] = (
|
||||
await prisma_client.db.litellm_teamtable.find_many(
|
||||
where={"team_id": {"in": complete_user_info.teams}}
|
||||
)
|
||||
)
|
||||
if teams is None:
|
||||
return []
|
||||
|
||||
teams_pydantic_obj = [LiteLLM_TeamTable(**team.model_dump()) for team in teams]
|
||||
|
||||
admin_team_ids = [
|
||||
team.team_id
|
||||
for team in teams_pydantic_obj
|
||||
if _is_user_team_admin(user_api_key_dict=user_api_key_dict, team_obj=team)
|
||||
]
|
||||
return admin_team_ids
|
||||
|
||||
|
||||
@router.get(
|
||||
|
@ -1767,6 +1798,9 @@ async def list_keys(
|
|||
),
|
||||
key_alias: Optional[str] = Query(None, description="Filter keys by key alias"),
|
||||
return_full_object: bool = Query(False, description="Return full key object"),
|
||||
include_team_keys: bool = Query(
|
||||
False, description="Include all keys for teams that user is an admin of."
|
||||
),
|
||||
) -> KeyListResponseObject:
|
||||
"""
|
||||
List all keys for a given user / team / organization.
|
||||
|
@ -1782,32 +1816,13 @@ async def list_keys(
|
|||
try:
|
||||
from litellm.proxy.proxy_server import prisma_client
|
||||
|
||||
# Check for unsupported parameters
|
||||
supported_params = {
|
||||
"page",
|
||||
"size",
|
||||
"user_id",
|
||||
"team_id",
|
||||
"key_alias",
|
||||
"return_full_object",
|
||||
"organization_id",
|
||||
}
|
||||
unsupported_params = set(request.query_params.keys()) - supported_params
|
||||
if unsupported_params:
|
||||
raise ProxyException(
|
||||
message=f"Unsupported parameter(s): {', '.join(unsupported_params)}. Supported parameters: {', '.join(supported_params)}",
|
||||
type=ProxyErrorTypes.bad_request_error,
|
||||
param=", ".join(unsupported_params),
|
||||
code=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
verbose_proxy_logger.debug("Entering list_keys function")
|
||||
|
||||
if prisma_client is None:
|
||||
verbose_proxy_logger.error("Database not connected")
|
||||
raise Exception("Database not connected")
|
||||
|
||||
await validate_key_list_check(
|
||||
complete_user_info = await validate_key_list_check(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
user_id=user_id,
|
||||
team_id=team_id,
|
||||
|
@ -1816,6 +1831,15 @@ async def list_keys(
|
|||
prisma_client=prisma_client,
|
||||
)
|
||||
|
||||
if include_team_keys:
|
||||
admin_team_ids = await get_admin_team_ids(
|
||||
complete_user_info=complete_user_info,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
prisma_client=prisma_client,
|
||||
)
|
||||
else:
|
||||
admin_team_ids = None
|
||||
|
||||
if user_id is None and user_api_key_dict.user_role not in [
|
||||
LitellmUserRoles.PROXY_ADMIN.value,
|
||||
LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY.value,
|
||||
|
@ -1831,6 +1855,7 @@ async def list_keys(
|
|||
key_alias=key_alias,
|
||||
return_full_object=return_full_object,
|
||||
organization_id=organization_id,
|
||||
admin_team_ids=admin_team_ids,
|
||||
)
|
||||
|
||||
verbose_proxy_logger.debug("Successfully prepared response")
|
||||
|
@ -1866,6 +1891,9 @@ async def _list_key_helper(
|
|||
key_alias: Optional[str],
|
||||
exclude_team_id: Optional[str] = None,
|
||||
return_full_object: bool = False,
|
||||
admin_team_ids: Optional[
|
||||
List[str]
|
||||
] = None, # New parameter for teams where user is admin
|
||||
) -> KeyListResponseObject:
|
||||
"""
|
||||
Helper function to list keys
|
||||
|
@ -1877,6 +1905,7 @@ async def _list_key_helper(
|
|||
key_alias: Optional[str]
|
||||
exclude_team_id: Optional[str] # exclude a specific team_id
|
||||
return_full_object: bool # when true, will return UserAPIKeyAuth objects instead of just the token
|
||||
admin_team_ids: Optional[List[str]] # list of team IDs where the user is an admin
|
||||
|
||||
Returns:
|
||||
KeyListResponseObject
|
||||
|
@ -1889,19 +1918,37 @@ async def _list_key_helper(
|
|||
"""
|
||||
|
||||
# Prepare filter conditions
|
||||
where: Dict[str, Union[str, Dict[str, Any]]] = {}
|
||||
where: Dict[str, Union[str, Dict[str, Any], List[Dict[str, Any]]]] = {}
|
||||
where.update(_get_condition_to_filter_out_ui_session_tokens())
|
||||
|
||||
# Build the OR conditions for user's keys and admin team keys
|
||||
or_conditions: List[Dict[str, Any]] = []
|
||||
|
||||
# Base conditions for user's own keys
|
||||
user_condition: Dict[str, Any] = {}
|
||||
if user_id and isinstance(user_id, str):
|
||||
where["user_id"] = user_id
|
||||
user_condition["user_id"] = user_id
|
||||
if team_id and isinstance(team_id, str):
|
||||
where["team_id"] = team_id
|
||||
user_condition["team_id"] = team_id
|
||||
if key_alias and isinstance(key_alias, str):
|
||||
where["key_alias"] = key_alias
|
||||
user_condition["key_alias"] = key_alias
|
||||
if exclude_team_id and isinstance(exclude_team_id, str):
|
||||
where["team_id"] = {"not": exclude_team_id}
|
||||
user_condition["team_id"] = {"not": exclude_team_id}
|
||||
if organization_id and isinstance(organization_id, str):
|
||||
where["organization_id"] = organization_id
|
||||
user_condition["organization_id"] = organization_id
|
||||
|
||||
if user_condition:
|
||||
or_conditions.append(user_condition)
|
||||
|
||||
# Add condition for admin team keys if provided
|
||||
if admin_team_ids:
|
||||
or_conditions.append({"team_id": {"in": admin_team_ids}})
|
||||
|
||||
# Combine conditions with OR if we have multiple conditions
|
||||
if len(or_conditions) > 1:
|
||||
where["OR"] = or_conditions
|
||||
elif len(or_conditions) == 1:
|
||||
where.update(or_conditions[0])
|
||||
|
||||
verbose_proxy_logger.debug(f"Filter conditions: {where}")
|
||||
|
||||
|
|
|
@ -55,6 +55,7 @@ from litellm.proxy.auth.auth_checks import (
|
|||
get_team_object,
|
||||
)
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.management_endpoints.common_utils import _is_user_team_admin
|
||||
from litellm.proxy.management_helpers.utils import (
|
||||
add_new_member,
|
||||
management_endpoint_wrapper,
|
||||
|
@ -68,17 +69,6 @@ from litellm.proxy.utils import (
|
|||
router = APIRouter()
|
||||
|
||||
|
||||
def _is_user_team_admin(
|
||||
user_api_key_dict: UserAPIKeyAuth, team_obj: LiteLLM_TeamTable
|
||||
) -> bool:
|
||||
|
||||
for member in team_obj.members_with_roles:
|
||||
if member.user_id is not None and member.user_id == user_api_key_dict.user_id:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _is_available_team(team_id: str, user_api_key_dict: UserAPIKeyAuth) -> bool:
|
||||
if litellm.default_internal_user_params is None:
|
||||
return False
|
||||
|
|
|
@ -6472,9 +6472,9 @@ async def model_metrics(
|
|||
if _day not in _daily_entries:
|
||||
_daily_entries[_day] = {}
|
||||
_combined_model_name = str(_model)
|
||||
if "https://" in _api_base:
|
||||
if _api_base is not None and "https://" in _api_base:
|
||||
_combined_model_name = str(_api_base)
|
||||
if "/openai/" in _combined_model_name:
|
||||
if _combined_model_name is not None and "/openai/" in _combined_model_name:
|
||||
_combined_model_name = _combined_model_name.split("/openai/")[0]
|
||||
|
||||
_all_api_bases.add(_combined_model_name)
|
||||
|
|
|
@ -3436,36 +3436,6 @@ async def test_list_keys(prisma_client):
|
|||
assert _key in response["keys"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_key_list_unsupported_params(prisma_client):
|
||||
"""
|
||||
Test the list_keys function:
|
||||
- Test unsupported params
|
||||
"""
|
||||
|
||||
from litellm.proxy.proxy_server import hash_token
|
||||
|
||||
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
||||
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
|
||||
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||
|
||||
request = Request(scope={"type": "http", "query_string": b"alias=foo"})
|
||||
|
||||
try:
|
||||
await list_keys(
|
||||
request,
|
||||
UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN.value),
|
||||
page=1,
|
||||
size=10,
|
||||
)
|
||||
pytest.fail("Expected this call to fail")
|
||||
except Exception as e:
|
||||
print("error str=", str(e.message))
|
||||
error_str = str(e.message)
|
||||
assert "Unsupported parameter" in error_str
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auth_vertex_ai_route(prisma_client):
|
||||
"""
|
||||
|
|
|
@ -1232,6 +1232,7 @@ async def test_create_team_member_add_team_admin(
|
|||
except HTTPException as e:
|
||||
if user_role == "user":
|
||||
assert e.status_code == 403
|
||||
return
|
||||
else:
|
||||
raise e
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from typing import Any, Dict
|
||||
from typing import Any, Dict, Optional, List
|
||||
from unittest.mock import Mock
|
||||
from litellm.proxy.utils import _get_redoc_url, _get_docs_url
|
||||
import json
|
||||
|
@ -1618,6 +1618,10 @@ def test_provider_specific_header():
|
|||
},
|
||||
}
|
||||
|
||||
|
||||
from litellm.proxy._types import LiteLLM_UserTable
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"wildcard_model, expected_models",
|
||||
[
|
||||
|
@ -1643,6 +1647,7 @@ def test_get_known_models_from_wildcard(wildcard_model, expected_models):
|
|||
|
||||
assert all(model in wildcard_models for model in expected_models)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"data, user_api_key_dict, expected_model",
|
||||
[
|
||||
|
@ -1692,3 +1697,125 @@ def test_update_model_if_team_alias_exists(data, user_api_key_dict, expected_mod
|
|||
# Check if model was updated correctly
|
||||
assert test_data.get("model") == expected_model
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_prisma_client():
|
||||
client = MagicMock()
|
||||
client.db = MagicMock()
|
||||
client.db.litellm_teamtable = AsyncMock()
|
||||
return client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"test_id, user_info, user_role, mock_teams, expected_teams, should_query_db",
|
||||
[
|
||||
("no_user_info", None, "proxy_admin", None, [], False),
|
||||
(
|
||||
"no_teams_found",
|
||||
LiteLLM_UserTable(
|
||||
teams=["team1", "team2"],
|
||||
user_id="user1",
|
||||
max_budget=100,
|
||||
spend=0,
|
||||
user_email="user1@example.com",
|
||||
user_role="proxy_admin",
|
||||
),
|
||||
"proxy_admin",
|
||||
None,
|
||||
[],
|
||||
True,
|
||||
),
|
||||
(
|
||||
"admin_user_with_teams",
|
||||
LiteLLM_UserTable(
|
||||
teams=["team1", "team2"],
|
||||
user_id="user1",
|
||||
max_budget=100,
|
||||
spend=0,
|
||||
user_email="user1@example.com",
|
||||
user_role="proxy_admin",
|
||||
),
|
||||
"proxy_admin",
|
||||
[
|
||||
MagicMock(
|
||||
model_dump=lambda: {
|
||||
"team_id": "team1",
|
||||
"members_with_roles": [{"role": "admin", "user_id": "user1"}],
|
||||
}
|
||||
),
|
||||
MagicMock(
|
||||
model_dump=lambda: {
|
||||
"team_id": "team2",
|
||||
"members_with_roles": [
|
||||
{"role": "admin", "user_id": "user1"},
|
||||
{"role": "user", "user_id": "user2"},
|
||||
],
|
||||
}
|
||||
),
|
||||
],
|
||||
["team1", "team2"],
|
||||
True,
|
||||
),
|
||||
(
|
||||
"non_admin_user",
|
||||
LiteLLM_UserTable(
|
||||
teams=["team1", "team2"],
|
||||
user_id="user1",
|
||||
max_budget=100,
|
||||
spend=0,
|
||||
user_email="user1@example.com",
|
||||
user_role="internal_user",
|
||||
),
|
||||
"internal_user",
|
||||
[
|
||||
MagicMock(
|
||||
model_dump=lambda: {"team_id": "team1", "members": ["user1"]}
|
||||
),
|
||||
MagicMock(
|
||||
model_dump=lambda: {
|
||||
"team_id": "team2",
|
||||
"members": ["user1", "user2"],
|
||||
}
|
||||
),
|
||||
],
|
||||
[],
|
||||
True,
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_get_admin_team_ids(
|
||||
test_id: str,
|
||||
user_info: Optional[LiteLLM_UserTable],
|
||||
user_role: str,
|
||||
mock_teams: Optional[List[MagicMock]],
|
||||
expected_teams: List[str],
|
||||
should_query_db: bool,
|
||||
mock_prisma_client,
|
||||
):
|
||||
from litellm.proxy.management_endpoints.key_management_endpoints import (
|
||||
get_admin_team_ids,
|
||||
)
|
||||
|
||||
# Setup
|
||||
mock_prisma_client.db.litellm_teamtable.find_many.return_value = mock_teams
|
||||
user_api_key_dict = UserAPIKeyAuth(
|
||||
user_role=user_role, user_id=user_info.user_id if user_info else None
|
||||
)
|
||||
|
||||
# Execute
|
||||
result = await get_admin_team_ids(
|
||||
complete_user_info=user_info,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
prisma_client=mock_prisma_client,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == expected_teams, f"Expected {expected_teams}, but got {result}"
|
||||
|
||||
if should_query_db:
|
||||
mock_prisma_client.db.litellm_teamtable.find_many.assert_called_once_with(
|
||||
where={"team_id": {"in": user_info.teams}}
|
||||
)
|
||||
else:
|
||||
mock_prisma_client.db.litellm_teamtable.find_many.assert_not_called()
|
||||
|
|
|
@ -154,10 +154,6 @@ export default function KeyInfoView({ keyId, onClose, keyData, accessToken, user
|
|||
visible={isRegenerateModalOpen}
|
||||
onClose={() => setIsRegenerateModalOpen(false)}
|
||||
accessToken={accessToken}
|
||||
onSuccess={(newKeyData) => {
|
||||
// Handle the updated key data here if needed
|
||||
setIsRegenerateModalOpen(false);
|
||||
}}
|
||||
/>
|
||||
|
||||
{/* Delete Confirmation Modal */}
|
||||
|
|
|
@ -2176,6 +2176,7 @@ export const keyListCall = async (
|
|||
}
|
||||
|
||||
queryParams.append('return_full_object', 'true');
|
||||
queryParams.append('include_team_keys', 'true');
|
||||
|
||||
const queryString = queryParams.toString();
|
||||
if (queryString) {
|
||||
|
|
|
@ -11,7 +11,6 @@ interface RegenerateKeyModalProps {
|
|||
visible: boolean;
|
||||
onClose: () => void;
|
||||
accessToken: string | null;
|
||||
onSuccess?: (newKeyData: any) => void;
|
||||
}
|
||||
|
||||
export function RegenerateKeyModal({
|
||||
|
@ -19,12 +18,12 @@ export function RegenerateKeyModal({
|
|||
visible,
|
||||
onClose,
|
||||
accessToken,
|
||||
onSuccess,
|
||||
}: RegenerateKeyModalProps) {
|
||||
const [form] = Form.useForm();
|
||||
const [regeneratedKey, setRegeneratedKey] = useState<string | null>(null);
|
||||
const [regenerateFormData, setRegenerateFormData] = useState<any>(null);
|
||||
const [newExpiryTime, setNewExpiryTime] = useState<string | null>(null);
|
||||
const [isRegenerating, setIsRegenerating] = useState(false);
|
||||
|
||||
useEffect(() => {
|
||||
if (visible && selectedToken) {
|
||||
|
@ -38,6 +37,15 @@ export function RegenerateKeyModal({
|
|||
}
|
||||
}, [visible, selectedToken, form]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!visible) {
|
||||
// Reset states when modal is closed
|
||||
setRegeneratedKey(null);
|
||||
setIsRegenerating(false);
|
||||
form.resetFields();
|
||||
}
|
||||
}, [visible, form]);
|
||||
|
||||
useEffect(() => {
|
||||
const calculateNewExpiryTime = (duration: string | undefined) => {
|
||||
if (!duration) return null;
|
||||
|
@ -70,25 +78,24 @@ export function RegenerateKeyModal({
|
|||
}, [regenerateFormData?.duration]);
|
||||
|
||||
const handleRegenerateKey = async () => {
|
||||
|
||||
if (!selectedToken || !accessToken) return;
|
||||
|
||||
setIsRegenerating(true);
|
||||
try {
|
||||
const formValues = await form.validateFields();
|
||||
const response = await regenerateKeyCall(accessToken, selectedToken.token, formValues);
|
||||
setRegeneratedKey(response.key);
|
||||
if (onSuccess) {
|
||||
onSuccess({ ...selectedToken, key_name: response.key_name, ...formValues });
|
||||
}
|
||||
message.success("API Key regenerated successfully");
|
||||
} catch (error) {
|
||||
console.error("Error regenerating key:", error);
|
||||
message.error("Failed to regenerate API Key");
|
||||
setIsRegenerating(false); // Reset regenerating state on error
|
||||
}
|
||||
};
|
||||
|
||||
const handleClose = () => {
|
||||
setRegeneratedKey(null);
|
||||
setIsRegenerating(false);
|
||||
form.resetFields();
|
||||
onClose();
|
||||
};
|
||||
|
@ -96,7 +103,7 @@ export function RegenerateKeyModal({
|
|||
return (
|
||||
<Modal
|
||||
title="Regenerate API Key"
|
||||
visible={visible}
|
||||
open={visible}
|
||||
onCancel={handleClose}
|
||||
footer={regeneratedKey ? [
|
||||
<Button key="close" onClick={handleClose}>
|
||||
|
@ -106,8 +113,12 @@ export function RegenerateKeyModal({
|
|||
<Button key="cancel" onClick={handleClose} className="mr-2">
|
||||
Cancel
|
||||
</Button>,
|
||||
<Button key="regenerate" onClick={handleRegenerateKey} >
|
||||
Regenerate
|
||||
<Button
|
||||
key="regenerate"
|
||||
onClick={handleRegenerateKey}
|
||||
disabled={isRegenerating}
|
||||
>
|
||||
{isRegenerating ? "Regenerating..." : "Regenerate"}
|
||||
</Button>,
|
||||
]}
|
||||
>
|
||||
|
@ -175,8 +186,7 @@ export function RegenerateKeyModal({
|
|||
</div>
|
||||
)}
|
||||
</Form>
|
||||
)
|
||||
}
|
||||
)}
|
||||
</Modal>
|
||||
);
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue