diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 9b559adae..dc1d6aac2 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -656,6 +656,14 @@ class UpdateKeyRequest(GenerateKeyRequest): metadata: Optional[dict] = None +class RegenerateKeyRequest(GenerateKeyRequest): + # This needs to be different from UpdateKeyRequest, because "key" is optional for this + key: Optional[str] = None + duration: Optional[str] = None + spend: Optional[float] = None + metadata: Optional[dict] = None + + class KeyRequest(LiteLLMBase): keys: List[str] diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py index 00e17400c..553cdb177 100644 --- a/litellm/proxy/management_endpoints/key_management_endpoints.py +++ b/litellm/proxy/management_endpoints/key_management_endpoints.py @@ -280,6 +280,54 @@ async def generate_key_fn( ) +async def prepare_key_update_data( + data: Union[UpdateKeyRequest, RegenerateKeyRequest], existing_key_row +): + data_json: dict = data.dict(exclude_unset=True) + key = data_json.pop("key", None) + + _metadata_fields = ["model_rpm_limit", "model_tpm_limit", "guardrails"] + non_default_values = {} + for k, v in data_json.items(): + if k in _metadata_fields: + continue + if v is not None and v not in ([], {}, 0): + non_default_values[k] = v + + if "duration" in non_default_values: + duration = non_default_values.pop("duration") + duration_s = _duration_in_seconds(duration=duration) + expires = datetime.now(timezone.utc) + timedelta(seconds=duration_s) + non_default_values["expires"] = expires + + if "budget_duration" in non_default_values: + duration_s = _duration_in_seconds( + duration=non_default_values["budget_duration"] + ) + key_reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s) + non_default_values["budget_reset_at"] = key_reset_at + + _metadata = existing_key_row.metadata or {} + + if data.model_tpm_limit: + if "model_tpm_limit" not in _metadata: + _metadata["model_tpm_limit"] = {} + _metadata["model_tpm_limit"].update(data.model_tpm_limit) + non_default_values["metadata"] = _metadata + + if data.model_rpm_limit: + if "model_rpm_limit" not in _metadata: + _metadata["model_rpm_limit"] = {} + _metadata["model_rpm_limit"].update(data.model_rpm_limit) + non_default_values["metadata"] = _metadata + + if data.guardrails: + _metadata["guardrails"] = data.guardrails + non_default_values["metadata"] = _metadata + + return non_default_values + + @router.post( "/key/update", tags=["key management"], dependencies=[Depends(user_api_key_auth)] ) @@ -323,59 +371,9 @@ async def update_key_fn( detail={"error": f"Team not found, passed team_id={data.team_id}"}, ) - _metadata_fields = ["model_rpm_limit", "model_tpm_limit", "guardrails"] - # get non default values for key - non_default_values = {} - for k, v in data_json.items(): - # this field gets stored in metadata - if key in _metadata_fields: - continue - if v is not None and v not in ( - [], - {}, - 0, - ): # models default to [], spend defaults to 0, we should not reset these values - non_default_values[k] = v - - if "duration" in non_default_values: - duration = non_default_values.pop("duration") - duration_s = _duration_in_seconds(duration=duration) - expires = datetime.now(timezone.utc) + timedelta(seconds=duration_s) - non_default_values["expires"] = expires - - if "budget_duration" in non_default_values: - duration_s = _duration_in_seconds( - duration=non_default_values["budget_duration"] - ) - key_reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s) - non_default_values["budget_reset_at"] = key_reset_at - - # Update metadata for virtual Key - if data.model_tpm_limit: - _metadata = existing_key_row.metadata or {} - if "model_tpm_limit" not in _metadata: - _metadata["model_tpm_limit"] = {} - - _metadata["model_tpm_limit"].update(data.model_tpm_limit) - non_default_values["metadata"] = _metadata - non_default_values.pop("model_tpm_limit", None) - - if data.model_rpm_limit: - _metadata = existing_key_row.metadata or {} - if "model_rpm_limit" not in _metadata: - _metadata["model_rpm_limit"] = {} - - _metadata["model_rpm_limit"].update(data.model_rpm_limit) - non_default_values["metadata"] = _metadata - non_default_values.pop("model_rpm_limit", None) - - if data.guardrails: - _metadata = existing_key_row.metadata or {} - _metadata["guardrails"] = data.guardrails - - # update values that will be written to the DB - non_default_values["metadata"] = _metadata - non_default_values.pop("guardrails", None) + non_default_values = await prepare_key_update_data( + data=data, existing_key_row=existing_key_row + ) response = await prisma_client.update_data( token=key, data={**non_default_values, "token": key} @@ -983,6 +981,7 @@ async def delete_verification_token(tokens: List, user_id: Optional[str] = None) @management_endpoint_wrapper async def regenerate_key_fn( key: str, + data: Optional[RegenerateKeyRequest] = None, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), litellm_changed_by: Optional[str] = Header( None, @@ -1041,14 +1040,26 @@ async def regenerate_key_fn( new_token_hash = hash_token(new_token) new_token_key_name = f"sk-...{new_token[-4:]}" - # update new token in DB + # Prepare the update data + update_data = { + "token": new_token_hash, + "key_name": new_token_key_name, + } + + non_default_values = {} + if data is not None: + # Update with any provided parameters from GenerateKeyRequest + non_default_values = await prepare_key_update_data( + data=data, existing_key_row=_key_in_db + ) + + update_data.update(non_default_values) + # Update the token in the database updated_token = await prisma_client.db.litellm_verificationtoken.update( where={"token": hashed_api_key}, - data={ - "token": new_token_hash, - "key_name": new_token_key_name, - }, + data=update_data, # type: ignore ) + updated_token_dict = {} if updated_token is not None: updated_token_dict = dict(updated_token) diff --git a/litellm/tests/test_key_generate_prisma.py b/litellm/tests/test_key_generate_prisma.py index adf0e8aea..995d5c0f7 100644 --- a/litellm/tests/test_key_generate_prisma.py +++ b/litellm/tests/test_key_generate_prisma.py @@ -2946,108 +2946,6 @@ async def test_team_access_groups(prisma_client): ) -################ Unit Tests for testing regeneration of keys ########### -@pytest.mark.asyncio() -async def test_regenerate_api_key(prisma_client): - litellm.set_verbose = True - setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) - setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") - await litellm.proxy.proxy_server.prisma_client.connect() - import uuid - - # generate new key - key_alias = f"test_alias_regenerate_key-{uuid.uuid4()}" - spend = 100 - max_budget = 400 - models = ["fake-openai-endpoint"] - new_key = await generate_key_fn( - data=GenerateKeyRequest( - key_alias=key_alias, spend=spend, max_budget=max_budget, models=models - ), - user_api_key_dict=UserAPIKeyAuth( - user_role=LitellmUserRoles.PROXY_ADMIN, - api_key="sk-1234", - user_id="1234", - ), - ) - - generated_key = new_key.key - print(generated_key) - - # assert the new key works as expected - request = Request(scope={"type": "http"}) - request._url = URL(url="/chat/completions") - - async def return_body(): - return_string = f'{{"model": "fake-openai-endpoint"}}' - # return string as bytes - return return_string.encode() - - request.body = return_body - result = await user_api_key_auth(request=request, api_key=f"Bearer {generated_key}") - print(result) - - # regenerate the key - new_key = await regenerate_key_fn( - key=generated_key, - user_api_key_dict=UserAPIKeyAuth( - user_role=LitellmUserRoles.PROXY_ADMIN, - api_key="sk-1234", - user_id="1234", - ), - ) - print("response from regenerate_key_fn", new_key) - - # assert the new key works as expected - request = Request(scope={"type": "http"}) - request._url = URL(url="/chat/completions") - - async def return_body_2(): - return_string = f'{{"model": "fake-openai-endpoint"}}' - # return string as bytes - return return_string.encode() - - request.body = return_body_2 - result = await user_api_key_auth(request=request, api_key=f"Bearer {new_key.key}") - print(result) - - # assert the old key stops working - request = Request(scope={"type": "http"}) - request._url = URL(url="/chat/completions") - - async def return_body_3(): - return_string = f'{{"model": "fake-openai-endpoint"}}' - # return string as bytes - return return_string.encode() - - request.body = return_body_3 - try: - result = await user_api_key_auth( - request=request, api_key=f"Bearer {generated_key}" - ) - print(result) - pytest.fail(f"This should have failed!. the key has been regenerated") - except Exception as e: - print("got expected exception", e) - assert "Invalid proxy server token passed" in e.message - - # Check that the regenerated key has the same spend, max_budget, models and key_alias - assert new_key.spend == spend, f"Expected spend {spend} but got {new_key.spend}" - assert ( - new_key.max_budget == max_budget - ), f"Expected max_budget {max_budget} but got {new_key.max_budget}" - assert ( - new_key.key_alias == key_alias - ), f"Expected key_alias {key_alias} but got {new_key.key_alias}" - assert ( - new_key.models == models - ), f"Expected models {models} but got {new_key.models}" - - assert new_key.key_name == f"sk-...{new_key.key[-4:]}" - - pass - - @pytest.mark.asyncio() async def test_team_tags(prisma_client): """ diff --git a/tests/proxy_admin_ui_tests/test_key_management.py b/tests/proxy_admin_ui_tests/test_key_management.py new file mode 100644 index 000000000..ddc3adcc8 --- /dev/null +++ b/tests/proxy_admin_ui_tests/test_key_management.py @@ -0,0 +1,271 @@ +import os +import sys +import traceback +import uuid +import datetime as dt +from datetime import datetime + +from dotenv import load_dotenv +from fastapi import Request +from fastapi.routing import APIRoute + +load_dotenv() +import io +import os +import time + +# this file is to test litellm/proxy + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path +import asyncio +import logging + +import pytest + +import litellm +from litellm._logging import verbose_proxy_logger +from litellm.proxy.management_endpoints.internal_user_endpoints import ( + new_user, + user_info, + user_update, +) +from litellm.proxy.management_endpoints.key_management_endpoints import ( + delete_key_fn, + generate_key_fn, + generate_key_helper_fn, + info_key_fn, + regenerate_key_fn, + update_key_fn, +) +from litellm.proxy.management_endpoints.team_endpoints import ( + new_team, + team_info, + update_team, +) +from litellm.proxy.proxy_server import ( + LitellmUserRoles, + audio_transcriptions, + chat_completion, + completion, + embeddings, + image_generation, + model_list, + moderations, + new_end_user, + user_api_key_auth, +) +from litellm.proxy.spend_tracking.spend_management_endpoints import ( + global_spend, + global_spend_logs, + global_spend_models, + global_spend_keys, + spend_key_fn, + spend_user_fn, + view_spend_logs, +) +from litellm.proxy.utils import PrismaClient, ProxyLogging, hash_token, update_spend + +verbose_proxy_logger.setLevel(level=logging.DEBUG) + +from starlette.datastructures import URL + +from litellm.caching import DualCache +from litellm.proxy._types import ( + DynamoDBArgs, + GenerateKeyRequest, + KeyRequest, + LiteLLM_UpperboundKeyGenerateParams, + NewCustomerRequest, + NewTeamRequest, + NewUserRequest, + ProxyErrorTypes, + ProxyException, + UpdateKeyRequest, + RegenerateKeyRequest, + UpdateTeamRequest, + UpdateUserRequest, + UserAPIKeyAuth, +) +from litellm.proxy.utils import DBClient + +proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache()) + + +@pytest.fixture +def prisma_client(): + from litellm.proxy.proxy_cli import append_query_params + + ### add connection pool + pool timeout args + params = {"connection_limit": 100, "pool_timeout": 60} + database_url = os.getenv("DATABASE_URL") + modified_url = append_query_params(database_url, params) + os.environ["DATABASE_URL"] = modified_url + + # Assuming DBClient is a class that needs to be instantiated + prisma_client = PrismaClient( + database_url=os.environ["DATABASE_URL"], proxy_logging_obj=proxy_logging_obj + ) + + # Reset litellm.proxy.proxy_server.prisma_client to None + litellm.proxy.proxy_server.custom_db_client = None + litellm.proxy.proxy_server.litellm_proxy_budget_name = ( + f"litellm-proxy-budget-{time.time()}" + ) + litellm.proxy.proxy_server.user_custom_key_generate = None + + return prisma_client + + +################ Unit Tests for testing regeneration of keys ########### +@pytest.mark.asyncio() +async def test_regenerate_api_key(prisma_client): + litellm.set_verbose = True + setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) + setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") + await litellm.proxy.proxy_server.prisma_client.connect() + import uuid + + # generate new key + key_alias = f"test_alias_regenerate_key-{uuid.uuid4()}" + spend = 100 + max_budget = 400 + models = ["fake-openai-endpoint"] + new_key = await generate_key_fn( + data=GenerateKeyRequest( + key_alias=key_alias, spend=spend, max_budget=max_budget, models=models + ), + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-1234", + user_id="1234", + ), + ) + + generated_key = new_key.key + print(generated_key) + + # assert the new key works as expected + request = Request(scope={"type": "http"}) + request._url = URL(url="/chat/completions") + + async def return_body(): + return_string = f'{{"model": "fake-openai-endpoint"}}' + # return string as bytes + return return_string.encode() + + request.body = return_body + result = await user_api_key_auth(request=request, api_key=f"Bearer {generated_key}") + print(result) + + # regenerate the key + new_key = await regenerate_key_fn( + key=generated_key, + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-1234", + user_id="1234", + ), + ) + print("response from regenerate_key_fn", new_key) + + # assert the new key works as expected + request = Request(scope={"type": "http"}) + request._url = URL(url="/chat/completions") + + async def return_body_2(): + return_string = f'{{"model": "fake-openai-endpoint"}}' + # return string as bytes + return return_string.encode() + + request.body = return_body_2 + result = await user_api_key_auth(request=request, api_key=f"Bearer {new_key.key}") + print(result) + + # assert the old key stops working + request = Request(scope={"type": "http"}) + request._url = URL(url="/chat/completions") + + async def return_body_3(): + return_string = f'{{"model": "fake-openai-endpoint"}}' + # return string as bytes + return return_string.encode() + + request.body = return_body_3 + try: + result = await user_api_key_auth( + request=request, api_key=f"Bearer {generated_key}" + ) + print(result) + pytest.fail(f"This should have failed!. the key has been regenerated") + except Exception as e: + print("got expected exception", e) + assert "Invalid proxy server token passed" in e.message + + # Check that the regenerated key has the same spend, max_budget, models and key_alias + assert new_key.spend == spend, f"Expected spend {spend} but got {new_key.spend}" + assert ( + new_key.max_budget == max_budget + ), f"Expected max_budget {max_budget} but got {new_key.max_budget}" + assert ( + new_key.key_alias == key_alias + ), f"Expected key_alias {key_alias} but got {new_key.key_alias}" + assert ( + new_key.models == models + ), f"Expected models {models} but got {new_key.models}" + + assert new_key.key_name == f"sk-...{new_key.key[-4:]}" + + pass + + +@pytest.mark.asyncio() +async def test_regenerate_api_key_with_new_alias_and_expiration(prisma_client): + litellm.set_verbose = True + setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) + setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") + await litellm.proxy.proxy_server.prisma_client.connect() + import uuid + + # generate new key + key_alias = f"test_alias_regenerate_key-{uuid.uuid4()}" + spend = 100 + max_budget = 400 + models = ["fake-openai-endpoint"] + new_key = await generate_key_fn( + data=GenerateKeyRequest( + key_alias=key_alias, spend=spend, max_budget=max_budget, models=models + ), + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-1234", + user_id="1234", + ), + ) + + generated_key = new_key.key + print(generated_key) + + # regenerate the key with new alias and expiration + new_key = await regenerate_key_fn( + key=generated_key, + data=RegenerateKeyRequest( + key_alias="very_new_alias", + duration="30d", + ), + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-1234", + user_id="1234", + ), + ) + print("response from regenerate_key_fn", new_key) + + # assert the alias and duration are updated + assert new_key.key_alias == "very_new_alias" + + # assert the new key expires 30 days from now + now = datetime.now(dt.timezone.utc) + assert new_key.expires > now + dt.timedelta(days=29) + assert new_key.expires < now + dt.timedelta(days=31) diff --git a/ui/litellm-dashboard/src/components/networking.tsx b/ui/litellm-dashboard/src/components/networking.tsx index a6bd5d32c..27d109699 100644 --- a/ui/litellm-dashboard/src/components/networking.tsx +++ b/ui/litellm-dashboard/src/components/networking.tsx @@ -771,7 +771,7 @@ export const claimOnboardingToken = async ( } }; -export const regenerateKeyCall = async (accessToken: string, keyToRegenerate: string) => { +export const regenerateKeyCall = async (accessToken: string, keyToRegenerate: string, formData: any) => { try { const url = proxyBaseUrl ? `${proxyBaseUrl}/key/${keyToRegenerate}/regenerate` @@ -783,7 +783,7 @@ export const regenerateKeyCall = async (accessToken: string, keyToRegenerate: st [globalLitellmHeaderName]: `Bearer ${accessToken}`, "Content-Type": "application/json", }, - body: JSON.stringify({}), + body: JSON.stringify(formData), }); if (!response.ok) { diff --git a/ui/litellm-dashboard/src/components/view_key_table.tsx b/ui/litellm-dashboard/src/components/view_key_table.tsx index 70e8c5204..0595e31fd 100644 --- a/ui/litellm-dashboard/src/components/view_key_table.tsx +++ b/ui/litellm-dashboard/src/components/view_key_table.tsx @@ -1,6 +1,7 @@ "use client"; import React, { useEffect, useState } from "react"; import { keyDeleteCall, modelAvailableCall } from "./networking"; +import { add } from 'date-fns'; import { InformationCircleIcon, StatusOnlineIcon, TrashIcon, PencilAltIcon, RefreshIcon } from "@heroicons/react/outline"; import { keySpendLogsCall, PredictedSpendLogsCall, keyUpdateCall, modelInfoCall, regenerateKeyCall } from "./networking"; import { @@ -22,6 +23,7 @@ import { Subtitle, Icon, BarChart, + TextInput, } from "@tremor/react"; import { Select as Select3, SelectItem, MultiSelect, MultiSelectItem } from "@tremor/react"; import { @@ -33,7 +35,8 @@ import { InputNumber, message, Select, - Tooltip + Tooltip, + DatePicker, } from "antd"; import { CopyToClipboard } from "react-copy-to-clipboard"; @@ -120,9 +123,63 @@ const ViewKeyTable: React.FC = ({ const [modelLimitModalVisible, setModelLimitModalVisible] = useState(false); const [regenerateDialogVisible, setRegenerateDialogVisible] = useState(false); const [regeneratedKey, setRegeneratedKey] = useState(null); + const [regenerateFormData, setRegenerateFormData] = useState(null); + const [regenerateForm] = Form.useForm(); + const [newExpiryTime, setNewExpiryTime] = useState(null); const [knownTeamIDs, setKnownTeamIDs] = useState(initialKnownTeamIDs); + + useEffect(() => { + const calculateNewExpiryTime = (duration: string | undefined) => { + if (!duration) { + return null; + } + + try { + const now = new Date(); + let newExpiry: Date; + + if (duration.endsWith('s')) { + newExpiry = add(now, { seconds: parseInt(duration) }); + } else if (duration.endsWith('h')) { + newExpiry = add(now, { hours: parseInt(duration) }); + } else if (duration.endsWith('d')) { + newExpiry = add(now, { days: parseInt(duration) }); + } else { + throw new Error('Invalid duration format'); + } + + return newExpiry.toLocaleString('en-US', { + year: 'numeric', + month: 'numeric', + day: 'numeric', + hour: 'numeric', + minute: 'numeric', + second: 'numeric', + hour12: true + }); + } catch (error) { + return null; + } + }; + + console.log("in calculateNewExpiryTime for selectedToken", selectedToken); + + + // When a new duration is entered + if (regenerateFormData?.duration) { + setNewExpiryTime(calculateNewExpiryTime(regenerateFormData.duration)); + } else { + setNewExpiryTime(null); + } + + console.log("calculateNewExpiryTime:", newExpiryTime); + }, [selectedToken, regenerateFormData?.duration]); + + + + useEffect(() => { const fetchUserModels = async () => { try { @@ -146,6 +203,7 @@ const ViewKeyTable: React.FC = ({ fetchUserModels(); }, [accessToken, userID, userRole]); + const handleModelLimitClick = (token: ItemData) => { setSelectedToken(token); setModelLimitModalVisible(true); @@ -678,31 +736,53 @@ const ViewKeyTable: React.FC = ({ setKeyToDelete(null); }; + const handleRegenerateClick = (token: any) => { + setSelectedToken(token); + setNewExpiryTime(null); + regenerateForm.setFieldsValue({ + key_alias: token.key_alias, + max_budget: token.max_budget, + tpm_limit: token.tpm_limit, + rpm_limit: token.rpm_limit, + duration: token.duration || '', + }); + setRegenerateDialogVisible(true); + }; + + const handleRegenerateFormChange = (field: string, value: any) => { + setRegenerateFormData((prev: any) => ({ + ...prev, + [field]: value, + })); + }; + const handleRegenerateKey = async () => { if (!premiumUser) { message.error("Regenerate API Key is an Enterprise feature. Please upgrade to use this feature."); return; } + if (selectedToken == null) { + return; + } + try { - if (selectedToken == null) { - message.error("Please select a key to regenerate"); - return; - } - const response = await regenerateKeyCall(accessToken, selectedToken.token); + const formValues = await regenerateForm.validateFields(); + const response = await regenerateKeyCall(accessToken, selectedToken.token, formValues); setRegeneratedKey(response.key); // Update the data state with the new key_name if (data) { const updatedData = data.map(item => - item.token === selectedToken.token - ? { ...item, key_name: response.key_name } + item.token === selectedToken?.token + ? { ...item, key_name: response.key_name, ...formValues } : item ); setData(updatedData); } setRegenerateDialogVisible(false); + regenerateForm.resetFields(); message.success("API Key regenerated successfully"); } catch (error) { console.error("Error regenerating key:", error); @@ -997,10 +1077,7 @@ const ViewKeyTable: React.FC = ({ onClick={() => handleEditClick(item)} /> { - setSelectedToken(item); - setRegenerateDialogVisible(true); - }} + onClick={() => handleRegenerateClick(item)} icon={RefreshIcon} size="sm" /> @@ -1080,13 +1157,19 @@ const ViewKeyTable: React.FC = ({ /> )} - {/* Regenerate Key Confirmation Dialog */} + {/* Regenerate Key Form Modal */} setRegenerateDialogVisible(false)} + onCancel={() => { + setRegenerateDialogVisible(false); + regenerateForm.resetFields(); + }} footer={[ - ,