mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
Litellm stable UI 02 17 2025 p1 (#8599)
* fix(key_management_endpoints.py): initial commit with logic to get all keys for teams user is an admin for * fix(key_managements_endpoints.py): return all keys for teams user is an admin for * fix(key_management_endpoints.py): add query param to ensure user opts into seeing all team keys (not just their own) * fix(regenerate_key_modal.tsx): fix key regenerate * fix(proxy_server.py): fix model metrics check on none api base * test(test_key_generate_prisma.py): remove redundant test * test(test_proxy_utils.py): add unit test covering new management endpoint helper util * fix: fix test * test(test_proxy_server.py): fix test
This commit is contained in:
parent
9826f76288
commit
18bc9ddd3d
10 changed files with 277 additions and 121 deletions
14
litellm/proxy/management_endpoints/common_utils.py
Normal file
14
litellm/proxy/management_endpoints/common_utils.py
Normal file
|
@ -0,0 +1,14 @@
|
||||||
|
from litellm.proxy._types import LiteLLM_TeamTable, UserAPIKeyAuth
|
||||||
|
|
||||||
|
|
||||||
|
def _is_user_team_admin(
|
||||||
|
user_api_key_dict: UserAPIKeyAuth, team_obj: LiteLLM_TeamTable
|
||||||
|
) -> bool:
|
||||||
|
for member in team_obj.members_with_roles:
|
||||||
|
if (
|
||||||
|
member.user_id is not None and member.user_id == user_api_key_dict.user_id
|
||||||
|
) and member.role == "admin":
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
|
@ -35,6 +35,7 @@ from litellm.proxy.auth.auth_checks import (
|
||||||
)
|
)
|
||||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||||
from litellm.proxy.hooks.key_management_event_hooks import KeyManagementEventHooks
|
from litellm.proxy.hooks.key_management_event_hooks import KeyManagementEventHooks
|
||||||
|
from litellm.proxy.management_endpoints.common_utils import _is_user_team_admin
|
||||||
from litellm.proxy.management_helpers.utils import management_endpoint_wrapper
|
from litellm.proxy.management_helpers.utils import management_endpoint_wrapper
|
||||||
from litellm.proxy.utils import (
|
from litellm.proxy.utils import (
|
||||||
PrismaClient,
|
PrismaClient,
|
||||||
|
@ -1684,10 +1685,10 @@ async def validate_key_list_check(
|
||||||
organization_id: Optional[str],
|
organization_id: Optional[str],
|
||||||
key_alias: Optional[str],
|
key_alias: Optional[str],
|
||||||
prisma_client: PrismaClient,
|
prisma_client: PrismaClient,
|
||||||
):
|
) -> Optional[LiteLLM_UserTable]:
|
||||||
|
|
||||||
if user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN.value:
|
if user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN.value:
|
||||||
return
|
return None
|
||||||
|
|
||||||
if user_api_key_dict.user_id is None:
|
if user_api_key_dict.user_id is None:
|
||||||
raise ProxyException(
|
raise ProxyException(
|
||||||
|
@ -1747,6 +1748,36 @@ async def validate_key_list_check(
|
||||||
param="organization_id",
|
param="organization_id",
|
||||||
code=status.HTTP_403_FORBIDDEN,
|
code=status.HTTP_403_FORBIDDEN,
|
||||||
)
|
)
|
||||||
|
return complete_user_info
|
||||||
|
|
||||||
|
|
||||||
|
async def get_admin_team_ids(
|
||||||
|
complete_user_info: Optional[LiteLLM_UserTable],
|
||||||
|
user_api_key_dict: UserAPIKeyAuth,
|
||||||
|
prisma_client: PrismaClient,
|
||||||
|
) -> List[str]:
|
||||||
|
"""
|
||||||
|
Get all team IDs where the user is an admin.
|
||||||
|
"""
|
||||||
|
if complete_user_info is None:
|
||||||
|
return []
|
||||||
|
# Get all teams that user is an admin of
|
||||||
|
teams: Optional[List[BaseModel]] = (
|
||||||
|
await prisma_client.db.litellm_teamtable.find_many(
|
||||||
|
where={"team_id": {"in": complete_user_info.teams}}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if teams is None:
|
||||||
|
return []
|
||||||
|
|
||||||
|
teams_pydantic_obj = [LiteLLM_TeamTable(**team.model_dump()) for team in teams]
|
||||||
|
|
||||||
|
admin_team_ids = [
|
||||||
|
team.team_id
|
||||||
|
for team in teams_pydantic_obj
|
||||||
|
if _is_user_team_admin(user_api_key_dict=user_api_key_dict, team_obj=team)
|
||||||
|
]
|
||||||
|
return admin_team_ids
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
|
@ -1767,6 +1798,9 @@ async def list_keys(
|
||||||
),
|
),
|
||||||
key_alias: Optional[str] = Query(None, description="Filter keys by key alias"),
|
key_alias: Optional[str] = Query(None, description="Filter keys by key alias"),
|
||||||
return_full_object: bool = Query(False, description="Return full key object"),
|
return_full_object: bool = Query(False, description="Return full key object"),
|
||||||
|
include_team_keys: bool = Query(
|
||||||
|
False, description="Include all keys for teams that user is an admin of."
|
||||||
|
),
|
||||||
) -> KeyListResponseObject:
|
) -> KeyListResponseObject:
|
||||||
"""
|
"""
|
||||||
List all keys for a given user / team / organization.
|
List all keys for a given user / team / organization.
|
||||||
|
@ -1782,32 +1816,13 @@ async def list_keys(
|
||||||
try:
|
try:
|
||||||
from litellm.proxy.proxy_server import prisma_client
|
from litellm.proxy.proxy_server import prisma_client
|
||||||
|
|
||||||
# Check for unsupported parameters
|
|
||||||
supported_params = {
|
|
||||||
"page",
|
|
||||||
"size",
|
|
||||||
"user_id",
|
|
||||||
"team_id",
|
|
||||||
"key_alias",
|
|
||||||
"return_full_object",
|
|
||||||
"organization_id",
|
|
||||||
}
|
|
||||||
unsupported_params = set(request.query_params.keys()) - supported_params
|
|
||||||
if unsupported_params:
|
|
||||||
raise ProxyException(
|
|
||||||
message=f"Unsupported parameter(s): {', '.join(unsupported_params)}. Supported parameters: {', '.join(supported_params)}",
|
|
||||||
type=ProxyErrorTypes.bad_request_error,
|
|
||||||
param=", ".join(unsupported_params),
|
|
||||||
code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
)
|
|
||||||
|
|
||||||
verbose_proxy_logger.debug("Entering list_keys function")
|
verbose_proxy_logger.debug("Entering list_keys function")
|
||||||
|
|
||||||
if prisma_client is None:
|
if prisma_client is None:
|
||||||
verbose_proxy_logger.error("Database not connected")
|
verbose_proxy_logger.error("Database not connected")
|
||||||
raise Exception("Database not connected")
|
raise Exception("Database not connected")
|
||||||
|
|
||||||
await validate_key_list_check(
|
complete_user_info = await validate_key_list_check(
|
||||||
user_api_key_dict=user_api_key_dict,
|
user_api_key_dict=user_api_key_dict,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
team_id=team_id,
|
team_id=team_id,
|
||||||
|
@ -1816,6 +1831,15 @@ async def list_keys(
|
||||||
prisma_client=prisma_client,
|
prisma_client=prisma_client,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if include_team_keys:
|
||||||
|
admin_team_ids = await get_admin_team_ids(
|
||||||
|
complete_user_info=complete_user_info,
|
||||||
|
user_api_key_dict=user_api_key_dict,
|
||||||
|
prisma_client=prisma_client,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
admin_team_ids = None
|
||||||
|
|
||||||
if user_id is None and user_api_key_dict.user_role not in [
|
if user_id is None and user_api_key_dict.user_role not in [
|
||||||
LitellmUserRoles.PROXY_ADMIN.value,
|
LitellmUserRoles.PROXY_ADMIN.value,
|
||||||
LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY.value,
|
LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY.value,
|
||||||
|
@ -1831,6 +1855,7 @@ async def list_keys(
|
||||||
key_alias=key_alias,
|
key_alias=key_alias,
|
||||||
return_full_object=return_full_object,
|
return_full_object=return_full_object,
|
||||||
organization_id=organization_id,
|
organization_id=organization_id,
|
||||||
|
admin_team_ids=admin_team_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
verbose_proxy_logger.debug("Successfully prepared response")
|
verbose_proxy_logger.debug("Successfully prepared response")
|
||||||
|
@ -1866,6 +1891,9 @@ async def _list_key_helper(
|
||||||
key_alias: Optional[str],
|
key_alias: Optional[str],
|
||||||
exclude_team_id: Optional[str] = None,
|
exclude_team_id: Optional[str] = None,
|
||||||
return_full_object: bool = False,
|
return_full_object: bool = False,
|
||||||
|
admin_team_ids: Optional[
|
||||||
|
List[str]
|
||||||
|
] = None, # New parameter for teams where user is admin
|
||||||
) -> KeyListResponseObject:
|
) -> KeyListResponseObject:
|
||||||
"""
|
"""
|
||||||
Helper function to list keys
|
Helper function to list keys
|
||||||
|
@ -1877,6 +1905,7 @@ async def _list_key_helper(
|
||||||
key_alias: Optional[str]
|
key_alias: Optional[str]
|
||||||
exclude_team_id: Optional[str] # exclude a specific team_id
|
exclude_team_id: Optional[str] # exclude a specific team_id
|
||||||
return_full_object: bool # when true, will return UserAPIKeyAuth objects instead of just the token
|
return_full_object: bool # when true, will return UserAPIKeyAuth objects instead of just the token
|
||||||
|
admin_team_ids: Optional[List[str]] # list of team IDs where the user is an admin
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
KeyListResponseObject
|
KeyListResponseObject
|
||||||
|
@ -1889,19 +1918,37 @@ async def _list_key_helper(
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Prepare filter conditions
|
# Prepare filter conditions
|
||||||
where: Dict[str, Union[str, Dict[str, Any]]] = {}
|
where: Dict[str, Union[str, Dict[str, Any], List[Dict[str, Any]]]] = {}
|
||||||
where.update(_get_condition_to_filter_out_ui_session_tokens())
|
where.update(_get_condition_to_filter_out_ui_session_tokens())
|
||||||
|
|
||||||
|
# Build the OR conditions for user's keys and admin team keys
|
||||||
|
or_conditions: List[Dict[str, Any]] = []
|
||||||
|
|
||||||
|
# Base conditions for user's own keys
|
||||||
|
user_condition: Dict[str, Any] = {}
|
||||||
if user_id and isinstance(user_id, str):
|
if user_id and isinstance(user_id, str):
|
||||||
where["user_id"] = user_id
|
user_condition["user_id"] = user_id
|
||||||
if team_id and isinstance(team_id, str):
|
if team_id and isinstance(team_id, str):
|
||||||
where["team_id"] = team_id
|
user_condition["team_id"] = team_id
|
||||||
if key_alias and isinstance(key_alias, str):
|
if key_alias and isinstance(key_alias, str):
|
||||||
where["key_alias"] = key_alias
|
user_condition["key_alias"] = key_alias
|
||||||
if exclude_team_id and isinstance(exclude_team_id, str):
|
if exclude_team_id and isinstance(exclude_team_id, str):
|
||||||
where["team_id"] = {"not": exclude_team_id}
|
user_condition["team_id"] = {"not": exclude_team_id}
|
||||||
if organization_id and isinstance(organization_id, str):
|
if organization_id and isinstance(organization_id, str):
|
||||||
where["organization_id"] = organization_id
|
user_condition["organization_id"] = organization_id
|
||||||
|
|
||||||
|
if user_condition:
|
||||||
|
or_conditions.append(user_condition)
|
||||||
|
|
||||||
|
# Add condition for admin team keys if provided
|
||||||
|
if admin_team_ids:
|
||||||
|
or_conditions.append({"team_id": {"in": admin_team_ids}})
|
||||||
|
|
||||||
|
# Combine conditions with OR if we have multiple conditions
|
||||||
|
if len(or_conditions) > 1:
|
||||||
|
where["OR"] = or_conditions
|
||||||
|
elif len(or_conditions) == 1:
|
||||||
|
where.update(or_conditions[0])
|
||||||
|
|
||||||
verbose_proxy_logger.debug(f"Filter conditions: {where}")
|
verbose_proxy_logger.debug(f"Filter conditions: {where}")
|
||||||
|
|
||||||
|
|
|
@ -55,6 +55,7 @@ from litellm.proxy.auth.auth_checks import (
|
||||||
get_team_object,
|
get_team_object,
|
||||||
)
|
)
|
||||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||||
|
from litellm.proxy.management_endpoints.common_utils import _is_user_team_admin
|
||||||
from litellm.proxy.management_helpers.utils import (
|
from litellm.proxy.management_helpers.utils import (
|
||||||
add_new_member,
|
add_new_member,
|
||||||
management_endpoint_wrapper,
|
management_endpoint_wrapper,
|
||||||
|
@ -68,17 +69,6 @@ from litellm.proxy.utils import (
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
def _is_user_team_admin(
|
|
||||||
user_api_key_dict: UserAPIKeyAuth, team_obj: LiteLLM_TeamTable
|
|
||||||
) -> bool:
|
|
||||||
|
|
||||||
for member in team_obj.members_with_roles:
|
|
||||||
if member.user_id is not None and member.user_id == user_api_key_dict.user_id:
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def _is_available_team(team_id: str, user_api_key_dict: UserAPIKeyAuth) -> bool:
|
def _is_available_team(team_id: str, user_api_key_dict: UserAPIKeyAuth) -> bool:
|
||||||
if litellm.default_internal_user_params is None:
|
if litellm.default_internal_user_params is None:
|
||||||
return False
|
return False
|
||||||
|
|
|
@ -6472,9 +6472,9 @@ async def model_metrics(
|
||||||
if _day not in _daily_entries:
|
if _day not in _daily_entries:
|
||||||
_daily_entries[_day] = {}
|
_daily_entries[_day] = {}
|
||||||
_combined_model_name = str(_model)
|
_combined_model_name = str(_model)
|
||||||
if "https://" in _api_base:
|
if _api_base is not None and "https://" in _api_base:
|
||||||
_combined_model_name = str(_api_base)
|
_combined_model_name = str(_api_base)
|
||||||
if "/openai/" in _combined_model_name:
|
if _combined_model_name is not None and "/openai/" in _combined_model_name:
|
||||||
_combined_model_name = _combined_model_name.split("/openai/")[0]
|
_combined_model_name = _combined_model_name.split("/openai/")[0]
|
||||||
|
|
||||||
_all_api_bases.add(_combined_model_name)
|
_all_api_bases.add(_combined_model_name)
|
||||||
|
|
|
@ -3436,36 +3436,6 @@ async def test_list_keys(prisma_client):
|
||||||
assert _key in response["keys"]
|
assert _key in response["keys"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_key_list_unsupported_params(prisma_client):
|
|
||||||
"""
|
|
||||||
Test the list_keys function:
|
|
||||||
- Test unsupported params
|
|
||||||
"""
|
|
||||||
|
|
||||||
from litellm.proxy.proxy_server import hash_token
|
|
||||||
|
|
||||||
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
|
||||||
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
|
|
||||||
await litellm.proxy.proxy_server.prisma_client.connect()
|
|
||||||
|
|
||||||
request = Request(scope={"type": "http", "query_string": b"alias=foo"})
|
|
||||||
|
|
||||||
try:
|
|
||||||
await list_keys(
|
|
||||||
request,
|
|
||||||
UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN.value),
|
|
||||||
page=1,
|
|
||||||
size=10,
|
|
||||||
)
|
|
||||||
pytest.fail("Expected this call to fail")
|
|
||||||
except Exception as e:
|
|
||||||
print("error str=", str(e.message))
|
|
||||||
error_str = str(e.message)
|
|
||||||
assert "Unsupported parameter" in error_str
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_auth_vertex_ai_route(prisma_client):
|
async def test_auth_vertex_ai_route(prisma_client):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -1232,6 +1232,7 @@ async def test_create_team_member_add_team_admin(
|
||||||
except HTTPException as e:
|
except HTTPException as e:
|
||||||
if user_role == "user":
|
if user_role == "user":
|
||||||
assert e.status_code == 403
|
assert e.status_code == 403
|
||||||
|
return
|
||||||
else:
|
else:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict, Optional, List
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
from litellm.proxy.utils import _get_redoc_url, _get_docs_url
|
from litellm.proxy.utils import _get_redoc_url, _get_docs_url
|
||||||
import json
|
import json
|
||||||
|
@ -1618,6 +1618,10 @@ def test_provider_specific_header():
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
from litellm.proxy._types import LiteLLM_UserTable
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"wildcard_model, expected_models",
|
"wildcard_model, expected_models",
|
||||||
[
|
[
|
||||||
|
@ -1642,7 +1646,8 @@ def test_get_known_models_from_wildcard(wildcard_model, expected_models):
|
||||||
print(f"Missing expected model: {model}")
|
print(f"Missing expected model: {model}")
|
||||||
|
|
||||||
assert all(model in wildcard_models for model in expected_models)
|
assert all(model in wildcard_models for model in expected_models)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"data, user_api_key_dict, expected_model",
|
"data, user_api_key_dict, expected_model",
|
||||||
[
|
[
|
||||||
|
@ -1692,3 +1697,125 @@ def test_update_model_if_team_alias_exists(data, user_api_key_dict, expected_mod
|
||||||
# Check if model was updated correctly
|
# Check if model was updated correctly
|
||||||
assert test_data.get("model") == expected_model
|
assert test_data.get("model") == expected_model
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_prisma_client():
|
||||||
|
client = MagicMock()
|
||||||
|
client.db = MagicMock()
|
||||||
|
client.db.litellm_teamtable = AsyncMock()
|
||||||
|
return client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"test_id, user_info, user_role, mock_teams, expected_teams, should_query_db",
|
||||||
|
[
|
||||||
|
("no_user_info", None, "proxy_admin", None, [], False),
|
||||||
|
(
|
||||||
|
"no_teams_found",
|
||||||
|
LiteLLM_UserTable(
|
||||||
|
teams=["team1", "team2"],
|
||||||
|
user_id="user1",
|
||||||
|
max_budget=100,
|
||||||
|
spend=0,
|
||||||
|
user_email="user1@example.com",
|
||||||
|
user_role="proxy_admin",
|
||||||
|
),
|
||||||
|
"proxy_admin",
|
||||||
|
None,
|
||||||
|
[],
|
||||||
|
True,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"admin_user_with_teams",
|
||||||
|
LiteLLM_UserTable(
|
||||||
|
teams=["team1", "team2"],
|
||||||
|
user_id="user1",
|
||||||
|
max_budget=100,
|
||||||
|
spend=0,
|
||||||
|
user_email="user1@example.com",
|
||||||
|
user_role="proxy_admin",
|
||||||
|
),
|
||||||
|
"proxy_admin",
|
||||||
|
[
|
||||||
|
MagicMock(
|
||||||
|
model_dump=lambda: {
|
||||||
|
"team_id": "team1",
|
||||||
|
"members_with_roles": [{"role": "admin", "user_id": "user1"}],
|
||||||
|
}
|
||||||
|
),
|
||||||
|
MagicMock(
|
||||||
|
model_dump=lambda: {
|
||||||
|
"team_id": "team2",
|
||||||
|
"members_with_roles": [
|
||||||
|
{"role": "admin", "user_id": "user1"},
|
||||||
|
{"role": "user", "user_id": "user2"},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
),
|
||||||
|
],
|
||||||
|
["team1", "team2"],
|
||||||
|
True,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"non_admin_user",
|
||||||
|
LiteLLM_UserTable(
|
||||||
|
teams=["team1", "team2"],
|
||||||
|
user_id="user1",
|
||||||
|
max_budget=100,
|
||||||
|
spend=0,
|
||||||
|
user_email="user1@example.com",
|
||||||
|
user_role="internal_user",
|
||||||
|
),
|
||||||
|
"internal_user",
|
||||||
|
[
|
||||||
|
MagicMock(
|
||||||
|
model_dump=lambda: {"team_id": "team1", "members": ["user1"]}
|
||||||
|
),
|
||||||
|
MagicMock(
|
||||||
|
model_dump=lambda: {
|
||||||
|
"team_id": "team2",
|
||||||
|
"members": ["user1", "user2"],
|
||||||
|
}
|
||||||
|
),
|
||||||
|
],
|
||||||
|
[],
|
||||||
|
True,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_get_admin_team_ids(
|
||||||
|
test_id: str,
|
||||||
|
user_info: Optional[LiteLLM_UserTable],
|
||||||
|
user_role: str,
|
||||||
|
mock_teams: Optional[List[MagicMock]],
|
||||||
|
expected_teams: List[str],
|
||||||
|
should_query_db: bool,
|
||||||
|
mock_prisma_client,
|
||||||
|
):
|
||||||
|
from litellm.proxy.management_endpoints.key_management_endpoints import (
|
||||||
|
get_admin_team_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Setup
|
||||||
|
mock_prisma_client.db.litellm_teamtable.find_many.return_value = mock_teams
|
||||||
|
user_api_key_dict = UserAPIKeyAuth(
|
||||||
|
user_role=user_role, user_id=user_info.user_id if user_info else None
|
||||||
|
)
|
||||||
|
|
||||||
|
# Execute
|
||||||
|
result = await get_admin_team_ids(
|
||||||
|
complete_user_info=user_info,
|
||||||
|
user_api_key_dict=user_api_key_dict,
|
||||||
|
prisma_client=mock_prisma_client,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert result == expected_teams, f"Expected {expected_teams}, but got {result}"
|
||||||
|
|
||||||
|
if should_query_db:
|
||||||
|
mock_prisma_client.db.litellm_teamtable.find_many.assert_called_once_with(
|
||||||
|
where={"team_id": {"in": user_info.teams}}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
mock_prisma_client.db.litellm_teamtable.find_many.assert_not_called()
|
||||||
|
|
|
@ -154,10 +154,6 @@ export default function KeyInfoView({ keyId, onClose, keyData, accessToken, user
|
||||||
visible={isRegenerateModalOpen}
|
visible={isRegenerateModalOpen}
|
||||||
onClose={() => setIsRegenerateModalOpen(false)}
|
onClose={() => setIsRegenerateModalOpen(false)}
|
||||||
accessToken={accessToken}
|
accessToken={accessToken}
|
||||||
onSuccess={(newKeyData) => {
|
|
||||||
// Handle the updated key data here if needed
|
|
||||||
setIsRegenerateModalOpen(false);
|
|
||||||
}}
|
|
||||||
/>
|
/>
|
||||||
|
|
||||||
{/* Delete Confirmation Modal */}
|
{/* Delete Confirmation Modal */}
|
||||||
|
|
|
@ -2176,6 +2176,7 @@ export const keyListCall = async (
|
||||||
}
|
}
|
||||||
|
|
||||||
queryParams.append('return_full_object', 'true');
|
queryParams.append('return_full_object', 'true');
|
||||||
|
queryParams.append('include_team_keys', 'true');
|
||||||
|
|
||||||
const queryString = queryParams.toString();
|
const queryString = queryParams.toString();
|
||||||
if (queryString) {
|
if (queryString) {
|
||||||
|
|
|
@ -11,7 +11,6 @@ interface RegenerateKeyModalProps {
|
||||||
visible: boolean;
|
visible: boolean;
|
||||||
onClose: () => void;
|
onClose: () => void;
|
||||||
accessToken: string | null;
|
accessToken: string | null;
|
||||||
onSuccess?: (newKeyData: any) => void;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
export function RegenerateKeyModal({
|
export function RegenerateKeyModal({
|
||||||
|
@ -19,12 +18,12 @@ export function RegenerateKeyModal({
|
||||||
visible,
|
visible,
|
||||||
onClose,
|
onClose,
|
||||||
accessToken,
|
accessToken,
|
||||||
onSuccess,
|
|
||||||
}: RegenerateKeyModalProps) {
|
}: RegenerateKeyModalProps) {
|
||||||
const [form] = Form.useForm();
|
const [form] = Form.useForm();
|
||||||
const [regeneratedKey, setRegeneratedKey] = useState<string | null>(null);
|
const [regeneratedKey, setRegeneratedKey] = useState<string | null>(null);
|
||||||
const [regenerateFormData, setRegenerateFormData] = useState<any>(null);
|
const [regenerateFormData, setRegenerateFormData] = useState<any>(null);
|
||||||
const [newExpiryTime, setNewExpiryTime] = useState<string | null>(null);
|
const [newExpiryTime, setNewExpiryTime] = useState<string | null>(null);
|
||||||
|
const [isRegenerating, setIsRegenerating] = useState(false);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (visible && selectedToken) {
|
if (visible && selectedToken) {
|
||||||
|
@ -38,6 +37,15 @@ export function RegenerateKeyModal({
|
||||||
}
|
}
|
||||||
}, [visible, selectedToken, form]);
|
}, [visible, selectedToken, form]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (!visible) {
|
||||||
|
// Reset states when modal is closed
|
||||||
|
setRegeneratedKey(null);
|
||||||
|
setIsRegenerating(false);
|
||||||
|
form.resetFields();
|
||||||
|
}
|
||||||
|
}, [visible, form]);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
const calculateNewExpiryTime = (duration: string | undefined) => {
|
const calculateNewExpiryTime = (duration: string | undefined) => {
|
||||||
if (!duration) return null;
|
if (!duration) return null;
|
||||||
|
@ -70,25 +78,24 @@ export function RegenerateKeyModal({
|
||||||
}, [regenerateFormData?.duration]);
|
}, [regenerateFormData?.duration]);
|
||||||
|
|
||||||
const handleRegenerateKey = async () => {
|
const handleRegenerateKey = async () => {
|
||||||
|
|
||||||
if (!selectedToken || !accessToken) return;
|
if (!selectedToken || !accessToken) return;
|
||||||
|
|
||||||
|
setIsRegenerating(true);
|
||||||
try {
|
try {
|
||||||
const formValues = await form.validateFields();
|
const formValues = await form.validateFields();
|
||||||
const response = await regenerateKeyCall(accessToken, selectedToken.token, formValues);
|
const response = await regenerateKeyCall(accessToken, selectedToken.token, formValues);
|
||||||
setRegeneratedKey(response.key);
|
setRegeneratedKey(response.key);
|
||||||
if (onSuccess) {
|
|
||||||
onSuccess({ ...selectedToken, key_name: response.key_name, ...formValues });
|
|
||||||
}
|
|
||||||
message.success("API Key regenerated successfully");
|
message.success("API Key regenerated successfully");
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error("Error regenerating key:", error);
|
console.error("Error regenerating key:", error);
|
||||||
message.error("Failed to regenerate API Key");
|
message.error("Failed to regenerate API Key");
|
||||||
|
setIsRegenerating(false); // Reset regenerating state on error
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
const handleClose = () => {
|
const handleClose = () => {
|
||||||
setRegeneratedKey(null);
|
setRegeneratedKey(null);
|
||||||
|
setIsRegenerating(false);
|
||||||
form.resetFields();
|
form.resetFields();
|
||||||
onClose();
|
onClose();
|
||||||
};
|
};
|
||||||
|
@ -96,7 +103,7 @@ export function RegenerateKeyModal({
|
||||||
return (
|
return (
|
||||||
<Modal
|
<Modal
|
||||||
title="Regenerate API Key"
|
title="Regenerate API Key"
|
||||||
visible={visible}
|
open={visible}
|
||||||
onCancel={handleClose}
|
onCancel={handleClose}
|
||||||
footer={regeneratedKey ? [
|
footer={regeneratedKey ? [
|
||||||
<Button key="close" onClick={handleClose}>
|
<Button key="close" onClick={handleClose}>
|
||||||
|
@ -106,8 +113,12 @@ export function RegenerateKeyModal({
|
||||||
<Button key="cancel" onClick={handleClose} className="mr-2">
|
<Button key="cancel" onClick={handleClose} className="mr-2">
|
||||||
Cancel
|
Cancel
|
||||||
</Button>,
|
</Button>,
|
||||||
<Button key="regenerate" onClick={handleRegenerateKey} >
|
<Button
|
||||||
Regenerate
|
key="regenerate"
|
||||||
|
onClick={handleRegenerateKey}
|
||||||
|
disabled={isRegenerating}
|
||||||
|
>
|
||||||
|
{isRegenerating ? "Regenerating..." : "Regenerate"}
|
||||||
</Button>,
|
</Button>,
|
||||||
]}
|
]}
|
||||||
>
|
>
|
||||||
|
@ -142,41 +153,40 @@ export function RegenerateKeyModal({
|
||||||
</Col>
|
</Col>
|
||||||
</Grid>
|
</Grid>
|
||||||
) : (
|
) : (
|
||||||
<Form
|
<Form
|
||||||
form={form}
|
form={form}
|
||||||
layout="vertical"
|
layout="vertical"
|
||||||
onValuesChange={(changedValues) => {
|
onValuesChange={(changedValues) => {
|
||||||
if ("duration" in changedValues) {
|
if ("duration" in changedValues) {
|
||||||
setRegenerateFormData((prev: { duration?: string }) => ({ ...prev, duration: changedValues.duration }));
|
setRegenerateFormData((prev: { duration?: string }) => ({ ...prev, duration: changedValues.duration }));
|
||||||
}
|
}
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<Form.Item name="key_alias" label="Key Alias">
|
<Form.Item name="key_alias" label="Key Alias">
|
||||||
<TextInput disabled={true} />
|
<TextInput disabled={true} />
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
<Form.Item name="max_budget" label="Max Budget (USD)">
|
<Form.Item name="max_budget" label="Max Budget (USD)">
|
||||||
<InputNumber step={0.01} precision={2} style={{ width: "100%" }} />
|
<InputNumber step={0.01} precision={2} style={{ width: "100%" }} />
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
<Form.Item name="tpm_limit" label="TPM Limit">
|
<Form.Item name="tpm_limit" label="TPM Limit">
|
||||||
<InputNumber style={{ width: "100%" }} />
|
<InputNumber style={{ width: "100%" }} />
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
<Form.Item name="rpm_limit" label="RPM Limit">
|
<Form.Item name="rpm_limit" label="RPM Limit">
|
||||||
<InputNumber style={{ width: "100%" }} />
|
<InputNumber style={{ width: "100%" }} />
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
<Form.Item name="duration" label="Expire Key (eg: 30s, 30h, 30d)" className="mt-8">
|
<Form.Item name="duration" label="Expire Key (eg: 30s, 30h, 30d)" className="mt-8">
|
||||||
<TextInput placeholder="" />
|
<TextInput placeholder="" />
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
<div className="mt-2 text-sm text-gray-500">
|
<div className="mt-2 text-sm text-gray-500">
|
||||||
Current expiry: {selectedToken?.expires ? new Date(selectedToken.expires).toLocaleString() : "Never"}
|
Current expiry: {selectedToken?.expires ? new Date(selectedToken.expires).toLocaleString() : "Never"}
|
||||||
|
</div>
|
||||||
|
{newExpiryTime && (
|
||||||
|
<div className="mt-2 text-sm text-green-600">
|
||||||
|
New expiry: {newExpiryTime}
|
||||||
</div>
|
</div>
|
||||||
{newExpiryTime && (
|
)}
|
||||||
<div className="mt-2 text-sm text-green-600">
|
</Form>
|
||||||
New expiry: {newExpiryTime}
|
)}
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
</Form>
|
|
||||||
)
|
|
||||||
}
|
|
||||||
</Modal>
|
</Modal>
|
||||||
);
|
);
|
||||||
}
|
}
|
Loading…
Add table
Add a link
Reference in a new issue