From aed59abe35a9bca887f9bc46f4153b6c808b61ee Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 6 Sep 2024 08:36:34 -0700 Subject: [PATCH 1/8] allow passing expiry time to /key/regenerate --- litellm/proxy/_types.py | 4 + .../key_management_endpoints.py | 125 ++++---- litellm/tests/test_key_generate_prisma.py | 102 ------- .../test_key_management.py | 271 ++++++++++++++++++ 4 files changed, 342 insertions(+), 160 deletions(-) create mode 100644 tests/proxy_admin_ui_tests/test_key_management.py diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 67acf71e5..812b572fb 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -656,6 +656,10 @@ class UpdateKeyRequest(GenerateKeyRequest): metadata: Optional[dict] = None +class RegenerateKeyRequest(UpdateKeyRequest): + key: Optional[str] = None + + class KeyRequest(LiteLLMBase): keys: List[str] diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py index 00e17400c..43f4c1fcb 100644 --- a/litellm/proxy/management_endpoints/key_management_endpoints.py +++ b/litellm/proxy/management_endpoints/key_management_endpoints.py @@ -280,6 +280,52 @@ async def generate_key_fn( ) +async def prepare_key_update_data(data: UpdateKeyRequest, existing_key_row): + data_json: dict = data.dict(exclude_unset=True) + key = data_json.pop("key", None) + + _metadata_fields = ["model_rpm_limit", "model_tpm_limit", "guardrails"] + non_default_values = {} + for k, v in data_json.items(): + if k in _metadata_fields: + continue + if v is not None and v not in ([], {}, 0): + non_default_values[k] = v + + if "duration" in non_default_values: + duration = non_default_values.pop("duration") + duration_s = _duration_in_seconds(duration=duration) + expires = datetime.now(timezone.utc) + timedelta(seconds=duration_s) + non_default_values["expires"] = expires + + if "budget_duration" in non_default_values: + duration_s = _duration_in_seconds( + duration=non_default_values["budget_duration"] + ) + key_reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s) + non_default_values["budget_reset_at"] = key_reset_at + + _metadata = existing_key_row.metadata or {} + + if data.model_tpm_limit: + if "model_tpm_limit" not in _metadata: + _metadata["model_tpm_limit"] = {} + _metadata["model_tpm_limit"].update(data.model_tpm_limit) + non_default_values["metadata"] = _metadata + + if data.model_rpm_limit: + if "model_rpm_limit" not in _metadata: + _metadata["model_rpm_limit"] = {} + _metadata["model_rpm_limit"].update(data.model_rpm_limit) + non_default_values["metadata"] = _metadata + + if data.guardrails: + _metadata["guardrails"] = data.guardrails + non_default_values["metadata"] = _metadata + + return non_default_values + + @router.post( "/key/update", tags=["key management"], dependencies=[Depends(user_api_key_auth)] ) @@ -323,59 +369,9 @@ async def update_key_fn( detail={"error": f"Team not found, passed team_id={data.team_id}"}, ) - _metadata_fields = ["model_rpm_limit", "model_tpm_limit", "guardrails"] - # get non default values for key - non_default_values = {} - for k, v in data_json.items(): - # this field gets stored in metadata - if key in _metadata_fields: - continue - if v is not None and v not in ( - [], - {}, - 0, - ): # models default to [], spend defaults to 0, we should not reset these values - non_default_values[k] = v - - if "duration" in non_default_values: - duration = non_default_values.pop("duration") - duration_s = _duration_in_seconds(duration=duration) - expires = datetime.now(timezone.utc) + timedelta(seconds=duration_s) - non_default_values["expires"] = expires - - if "budget_duration" in non_default_values: - duration_s = _duration_in_seconds( - duration=non_default_values["budget_duration"] - ) - key_reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s) - non_default_values["budget_reset_at"] = key_reset_at - - # Update metadata for virtual Key - if data.model_tpm_limit: - _metadata = existing_key_row.metadata or {} - if "model_tpm_limit" not in _metadata: - _metadata["model_tpm_limit"] = {} - - _metadata["model_tpm_limit"].update(data.model_tpm_limit) - non_default_values["metadata"] = _metadata - non_default_values.pop("model_tpm_limit", None) - - if data.model_rpm_limit: - _metadata = existing_key_row.metadata or {} - if "model_rpm_limit" not in _metadata: - _metadata["model_rpm_limit"] = {} - - _metadata["model_rpm_limit"].update(data.model_rpm_limit) - non_default_values["metadata"] = _metadata - non_default_values.pop("model_rpm_limit", None) - - if data.guardrails: - _metadata = existing_key_row.metadata or {} - _metadata["guardrails"] = data.guardrails - - # update values that will be written to the DB - non_default_values["metadata"] = _metadata - non_default_values.pop("guardrails", None) + non_default_values = await prepare_key_update_data( + data=data, existing_key_row=existing_key_row + ) response = await prisma_client.update_data( token=key, data={**non_default_values, "token": key} @@ -983,6 +979,7 @@ async def delete_verification_token(tokens: List, user_id: Optional[str] = None) @management_endpoint_wrapper async def regenerate_key_fn( key: str, + data: Optional[RegenerateKeyRequest] = None, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), litellm_changed_by: Optional[str] = Header( None, @@ -1041,14 +1038,26 @@ async def regenerate_key_fn( new_token_hash = hash_token(new_token) new_token_key_name = f"sk-...{new_token[-4:]}" - # update new token in DB + # Prepare the update data + update_data = { + "token": new_token_hash, + "key_name": new_token_key_name, + } + + non_default_values = {} + if data is not None: + # Update with any provided parameters from GenerateKeyRequest + non_default_values = await prepare_key_update_data( + data=data, existing_key_row=_key_in_db + ) + + update_data.update(non_default_values) + # Update the token in the database updated_token = await prisma_client.db.litellm_verificationtoken.update( where={"token": hashed_api_key}, - data={ - "token": new_token_hash, - "key_name": new_token_key_name, - }, + data=update_data, ) + updated_token_dict = {} if updated_token is not None: updated_token_dict = dict(updated_token) diff --git a/litellm/tests/test_key_generate_prisma.py b/litellm/tests/test_key_generate_prisma.py index adf0e8aea..995d5c0f7 100644 --- a/litellm/tests/test_key_generate_prisma.py +++ b/litellm/tests/test_key_generate_prisma.py @@ -2946,108 +2946,6 @@ async def test_team_access_groups(prisma_client): ) -################ Unit Tests for testing regeneration of keys ########### -@pytest.mark.asyncio() -async def test_regenerate_api_key(prisma_client): - litellm.set_verbose = True - setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) - setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") - await litellm.proxy.proxy_server.prisma_client.connect() - import uuid - - # generate new key - key_alias = f"test_alias_regenerate_key-{uuid.uuid4()}" - spend = 100 - max_budget = 400 - models = ["fake-openai-endpoint"] - new_key = await generate_key_fn( - data=GenerateKeyRequest( - key_alias=key_alias, spend=spend, max_budget=max_budget, models=models - ), - user_api_key_dict=UserAPIKeyAuth( - user_role=LitellmUserRoles.PROXY_ADMIN, - api_key="sk-1234", - user_id="1234", - ), - ) - - generated_key = new_key.key - print(generated_key) - - # assert the new key works as expected - request = Request(scope={"type": "http"}) - request._url = URL(url="/chat/completions") - - async def return_body(): - return_string = f'{{"model": "fake-openai-endpoint"}}' - # return string as bytes - return return_string.encode() - - request.body = return_body - result = await user_api_key_auth(request=request, api_key=f"Bearer {generated_key}") - print(result) - - # regenerate the key - new_key = await regenerate_key_fn( - key=generated_key, - user_api_key_dict=UserAPIKeyAuth( - user_role=LitellmUserRoles.PROXY_ADMIN, - api_key="sk-1234", - user_id="1234", - ), - ) - print("response from regenerate_key_fn", new_key) - - # assert the new key works as expected - request = Request(scope={"type": "http"}) - request._url = URL(url="/chat/completions") - - async def return_body_2(): - return_string = f'{{"model": "fake-openai-endpoint"}}' - # return string as bytes - return return_string.encode() - - request.body = return_body_2 - result = await user_api_key_auth(request=request, api_key=f"Bearer {new_key.key}") - print(result) - - # assert the old key stops working - request = Request(scope={"type": "http"}) - request._url = URL(url="/chat/completions") - - async def return_body_3(): - return_string = f'{{"model": "fake-openai-endpoint"}}' - # return string as bytes - return return_string.encode() - - request.body = return_body_3 - try: - result = await user_api_key_auth( - request=request, api_key=f"Bearer {generated_key}" - ) - print(result) - pytest.fail(f"This should have failed!. the key has been regenerated") - except Exception as e: - print("got expected exception", e) - assert "Invalid proxy server token passed" in e.message - - # Check that the regenerated key has the same spend, max_budget, models and key_alias - assert new_key.spend == spend, f"Expected spend {spend} but got {new_key.spend}" - assert ( - new_key.max_budget == max_budget - ), f"Expected max_budget {max_budget} but got {new_key.max_budget}" - assert ( - new_key.key_alias == key_alias - ), f"Expected key_alias {key_alias} but got {new_key.key_alias}" - assert ( - new_key.models == models - ), f"Expected models {models} but got {new_key.models}" - - assert new_key.key_name == f"sk-...{new_key.key[-4:]}" - - pass - - @pytest.mark.asyncio() async def test_team_tags(prisma_client): """ diff --git a/tests/proxy_admin_ui_tests/test_key_management.py b/tests/proxy_admin_ui_tests/test_key_management.py new file mode 100644 index 000000000..ddc3adcc8 --- /dev/null +++ b/tests/proxy_admin_ui_tests/test_key_management.py @@ -0,0 +1,271 @@ +import os +import sys +import traceback +import uuid +import datetime as dt +from datetime import datetime + +from dotenv import load_dotenv +from fastapi import Request +from fastapi.routing import APIRoute + +load_dotenv() +import io +import os +import time + +# this file is to test litellm/proxy + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path +import asyncio +import logging + +import pytest + +import litellm +from litellm._logging import verbose_proxy_logger +from litellm.proxy.management_endpoints.internal_user_endpoints import ( + new_user, + user_info, + user_update, +) +from litellm.proxy.management_endpoints.key_management_endpoints import ( + delete_key_fn, + generate_key_fn, + generate_key_helper_fn, + info_key_fn, + regenerate_key_fn, + update_key_fn, +) +from litellm.proxy.management_endpoints.team_endpoints import ( + new_team, + team_info, + update_team, +) +from litellm.proxy.proxy_server import ( + LitellmUserRoles, + audio_transcriptions, + chat_completion, + completion, + embeddings, + image_generation, + model_list, + moderations, + new_end_user, + user_api_key_auth, +) +from litellm.proxy.spend_tracking.spend_management_endpoints import ( + global_spend, + global_spend_logs, + global_spend_models, + global_spend_keys, + spend_key_fn, + spend_user_fn, + view_spend_logs, +) +from litellm.proxy.utils import PrismaClient, ProxyLogging, hash_token, update_spend + +verbose_proxy_logger.setLevel(level=logging.DEBUG) + +from starlette.datastructures import URL + +from litellm.caching import DualCache +from litellm.proxy._types import ( + DynamoDBArgs, + GenerateKeyRequest, + KeyRequest, + LiteLLM_UpperboundKeyGenerateParams, + NewCustomerRequest, + NewTeamRequest, + NewUserRequest, + ProxyErrorTypes, + ProxyException, + UpdateKeyRequest, + RegenerateKeyRequest, + UpdateTeamRequest, + UpdateUserRequest, + UserAPIKeyAuth, +) +from litellm.proxy.utils import DBClient + +proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache()) + + +@pytest.fixture +def prisma_client(): + from litellm.proxy.proxy_cli import append_query_params + + ### add connection pool + pool timeout args + params = {"connection_limit": 100, "pool_timeout": 60} + database_url = os.getenv("DATABASE_URL") + modified_url = append_query_params(database_url, params) + os.environ["DATABASE_URL"] = modified_url + + # Assuming DBClient is a class that needs to be instantiated + prisma_client = PrismaClient( + database_url=os.environ["DATABASE_URL"], proxy_logging_obj=proxy_logging_obj + ) + + # Reset litellm.proxy.proxy_server.prisma_client to None + litellm.proxy.proxy_server.custom_db_client = None + litellm.proxy.proxy_server.litellm_proxy_budget_name = ( + f"litellm-proxy-budget-{time.time()}" + ) + litellm.proxy.proxy_server.user_custom_key_generate = None + + return prisma_client + + +################ Unit Tests for testing regeneration of keys ########### +@pytest.mark.asyncio() +async def test_regenerate_api_key(prisma_client): + litellm.set_verbose = True + setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) + setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") + await litellm.proxy.proxy_server.prisma_client.connect() + import uuid + + # generate new key + key_alias = f"test_alias_regenerate_key-{uuid.uuid4()}" + spend = 100 + max_budget = 400 + models = ["fake-openai-endpoint"] + new_key = await generate_key_fn( + data=GenerateKeyRequest( + key_alias=key_alias, spend=spend, max_budget=max_budget, models=models + ), + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-1234", + user_id="1234", + ), + ) + + generated_key = new_key.key + print(generated_key) + + # assert the new key works as expected + request = Request(scope={"type": "http"}) + request._url = URL(url="/chat/completions") + + async def return_body(): + return_string = f'{{"model": "fake-openai-endpoint"}}' + # return string as bytes + return return_string.encode() + + request.body = return_body + result = await user_api_key_auth(request=request, api_key=f"Bearer {generated_key}") + print(result) + + # regenerate the key + new_key = await regenerate_key_fn( + key=generated_key, + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-1234", + user_id="1234", + ), + ) + print("response from regenerate_key_fn", new_key) + + # assert the new key works as expected + request = Request(scope={"type": "http"}) + request._url = URL(url="/chat/completions") + + async def return_body_2(): + return_string = f'{{"model": "fake-openai-endpoint"}}' + # return string as bytes + return return_string.encode() + + request.body = return_body_2 + result = await user_api_key_auth(request=request, api_key=f"Bearer {new_key.key}") + print(result) + + # assert the old key stops working + request = Request(scope={"type": "http"}) + request._url = URL(url="/chat/completions") + + async def return_body_3(): + return_string = f'{{"model": "fake-openai-endpoint"}}' + # return string as bytes + return return_string.encode() + + request.body = return_body_3 + try: + result = await user_api_key_auth( + request=request, api_key=f"Bearer {generated_key}" + ) + print(result) + pytest.fail(f"This should have failed!. the key has been regenerated") + except Exception as e: + print("got expected exception", e) + assert "Invalid proxy server token passed" in e.message + + # Check that the regenerated key has the same spend, max_budget, models and key_alias + assert new_key.spend == spend, f"Expected spend {spend} but got {new_key.spend}" + assert ( + new_key.max_budget == max_budget + ), f"Expected max_budget {max_budget} but got {new_key.max_budget}" + assert ( + new_key.key_alias == key_alias + ), f"Expected key_alias {key_alias} but got {new_key.key_alias}" + assert ( + new_key.models == models + ), f"Expected models {models} but got {new_key.models}" + + assert new_key.key_name == f"sk-...{new_key.key[-4:]}" + + pass + + +@pytest.mark.asyncio() +async def test_regenerate_api_key_with_new_alias_and_expiration(prisma_client): + litellm.set_verbose = True + setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) + setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") + await litellm.proxy.proxy_server.prisma_client.connect() + import uuid + + # generate new key + key_alias = f"test_alias_regenerate_key-{uuid.uuid4()}" + spend = 100 + max_budget = 400 + models = ["fake-openai-endpoint"] + new_key = await generate_key_fn( + data=GenerateKeyRequest( + key_alias=key_alias, spend=spend, max_budget=max_budget, models=models + ), + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-1234", + user_id="1234", + ), + ) + + generated_key = new_key.key + print(generated_key) + + # regenerate the key with new alias and expiration + new_key = await regenerate_key_fn( + key=generated_key, + data=RegenerateKeyRequest( + key_alias="very_new_alias", + duration="30d", + ), + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-1234", + user_id="1234", + ), + ) + print("response from regenerate_key_fn", new_key) + + # assert the alias and duration are updated + assert new_key.key_alias == "very_new_alias" + + # assert the new key expires 30 days from now + now = datetime.now(dt.timezone.utc) + assert new_key.expires > now + dt.timedelta(days=29) + assert new_key.expires < now + dt.timedelta(days=31) From 9bdb96ad9cf383d8635595bec97afbdb4c1692eb Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 6 Sep 2024 10:19:37 -0700 Subject: [PATCH 2/8] allow correct fields on regenerate key --- .../src/components/view_key_table.tsx | 143 +++++++++++++++--- 1 file changed, 125 insertions(+), 18 deletions(-) diff --git a/ui/litellm-dashboard/src/components/view_key_table.tsx b/ui/litellm-dashboard/src/components/view_key_table.tsx index 70e8c5204..659567949 100644 --- a/ui/litellm-dashboard/src/components/view_key_table.tsx +++ b/ui/litellm-dashboard/src/components/view_key_table.tsx @@ -22,6 +22,7 @@ import { Subtitle, Icon, BarChart, + TextInput, } from "@tremor/react"; import { Select as Select3, SelectItem, MultiSelect, MultiSelectItem } from "@tremor/react"; import { @@ -33,7 +34,8 @@ import { InputNumber, message, Select, - Tooltip + Tooltip, + DatePicker, } from "antd"; import { CopyToClipboard } from "react-copy-to-clipboard"; @@ -120,6 +122,7 @@ const ViewKeyTable: React.FC = ({ const [modelLimitModalVisible, setModelLimitModalVisible] = useState(false); const [regenerateDialogVisible, setRegenerateDialogVisible] = useState(false); const [regeneratedKey, setRegeneratedKey] = useState(null); + const [regenerateFormData, setRegenerateFormData] = useState(null); const [knownTeamIDs, setKnownTeamIDs] = useState(initialKnownTeamIDs); @@ -146,6 +149,42 @@ const ViewKeyTable: React.FC = ({ fetchUserModels(); }, [accessToken, userID, userRole]); + const [newExpiryTime, setNewExpiryTime] = useState(null); + + useEffect(() => { + if (regenerateFormData?.duration) { + try { + const now = new Date(); + const duration = regenerateFormData.duration; + let newExpiry: Date; + + if (duration.endsWith('s')) { + newExpiry = add(now, { seconds: parseInt(duration) }); + } else if (duration.endsWith('h')) { + newExpiry = add(now, { hours: parseInt(duration) }); + } else if (duration.endsWith('d')) { + newExpiry = add(now, { days: parseInt(duration) }); + } else { + throw new Error('Invalid duration format'); + } + + setNewExpiryTime(newExpiry.toLocaleString('en-US', { + year: 'numeric', + month: 'numeric', + day: 'numeric', + hour: 'numeric', + minute: 'numeric', + second: 'numeric', + hour12: true + })); + } catch (error) { + setNewExpiryTime(null); + } + } else { + setNewExpiryTime(null); + } + }, [regenerateFormData?.duration]); + const handleModelLimitClick = (token: ItemData) => { setSelectedToken(token); setModelLimitModalVisible(true); @@ -678,6 +717,22 @@ const ViewKeyTable: React.FC = ({ setKeyToDelete(null); }; + const handleRegenerateClick = (token: any) => { + setSelectedToken(token); + setRegenerateFormData({ + ...token, + duration: token.duration || 'none', // Set a default value if duration is not present + }); + setRegenerateDialogVisible(true); + }; + + const handleRegenerateFormChange = (field: string, value: any) => { + setRegenerateFormData((prev: any) => ({ + ...prev, + [field]: value, + })); + }; + const handleRegenerateKey = async () => { if (!premiumUser) { message.error("Regenerate API Key is an Enterprise feature. Please upgrade to use this feature."); @@ -685,24 +740,25 @@ const ViewKeyTable: React.FC = ({ } try { - if (selectedToken == null) { - message.error("Please select a key to regenerate"); + if (regenerateFormData == null) { + message.error("Please fill in the key details"); return; } - const response = await regenerateKeyCall(accessToken, selectedToken.token); + const response = await regenerateKeyCall(accessToken, regenerateFormData); setRegeneratedKey(response.key); // Update the data state with the new key_name if (data) { const updatedData = data.map(item => - item.token === selectedToken.token - ? { ...item, key_name: response.key_name } + item.token === selectedToken?.token + ? { ...item, key_name: response.key_name, ...regenerateFormData } : item ); setData(updatedData); } setRegenerateDialogVisible(false); + setRegenerateFormData(null); message.success("API Key regenerated successfully"); } catch (error) { console.error("Error regenerating key:", error); @@ -997,10 +1053,7 @@ const ViewKeyTable: React.FC = ({ onClick={() => handleEditClick(item)} /> { - setSelectedToken(item); - setRegenerateDialogVisible(true); - }} + onClick={() => handleRegenerateClick(item)} icon={RefreshIcon} size="sm" /> @@ -1080,13 +1133,19 @@ const ViewKeyTable: React.FC = ({ /> )} - {/* Regenerate Key Confirmation Dialog */} + {/* Regenerate Key Form Modal */} setRegenerateDialogVisible(false)} + onCancel={() => { + setRegenerateDialogVisible(false); + setRegenerateFormData(null); + }} footer={[ - , , @@ -1158,58 +1159,42 @@ const ViewKeyTable: React.FC = ({ ]} > {premiumUser ? ( -
- - handleRegenerateFormChange('key_alias', e.target.value)} - /> + + + - - handleRegenerateFormChange('max_budget', value)} - /> + + - - handleRegenerateFormChange('tpm_limit', value)} - /> + + - - handleRegenerateFormChange('rpm_limit', value)} - /> + + - handleRegenerateFormChange('duration', e.target.value)} - /> -
- Current expiry: { - selectedToken?.expires != null ? ( - new Date(selectedToken.expires).toLocaleString(undefined, { - day: 'numeric', - month: 'long', - year: 'numeric', - hour: 'numeric', - minute: 'numeric', - second: 'numeric' - }) - ) : ( - 'Never' - ) - } -
-
+ name="duration" + label="Expire Key (eg: 30s, 30h, 30d)" + className="mt-8" + > + +
+
+ Current expiry: { + selectedToken?.expires != null ? ( + new Date(selectedToken.expires).toLocaleString(undefined, { + day: 'numeric', + month: 'long', + year: 'numeric', + hour: 'numeric', + minute: 'numeric', + second: 'numeric' + }) + ) : ( + 'Never' + ) + } +
) : (
From 5c15e21a0562618a2424d3f8ff3af9ef5d4e468c Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 6 Sep 2024 10:34:50 -0700 Subject: [PATCH 4/8] ui allow rotating keys --- .../src/components/view_key_table.tsx | 76 +++++++++++-------- 1 file changed, 43 insertions(+), 33 deletions(-) diff --git a/ui/litellm-dashboard/src/components/view_key_table.tsx b/ui/litellm-dashboard/src/components/view_key_table.tsx index 4116e033b..d22e844d4 100644 --- a/ui/litellm-dashboard/src/components/view_key_table.tsx +++ b/ui/litellm-dashboard/src/components/view_key_table.tsx @@ -1,6 +1,7 @@ "use client"; import React, { useEffect, useState } from "react"; import { keyDeleteCall, modelAvailableCall } from "./networking"; +import { add } from 'date-fns'; import { InformationCircleIcon, StatusOnlineIcon, TrashIcon, PencilAltIcon, RefreshIcon } from "@heroicons/react/outline"; import { keySpendLogsCall, PredictedSpendLogsCall, keyUpdateCall, modelInfoCall, regenerateKeyCall } from "./networking"; import { @@ -124,33 +125,10 @@ const ViewKeyTable: React.FC = ({ const [regeneratedKey, setRegeneratedKey] = useState(null); const [regenerateFormData, setRegenerateFormData] = useState(null); const [regenerateForm] = Form.useForm(); + const [newExpiryTime, setNewExpiryTime] = useState(null); const [knownTeamIDs, setKnownTeamIDs] = useState(initialKnownTeamIDs); - useEffect(() => { - const fetchUserModels = async () => { - try { - if (userID === null) { - return; - } - - if (accessToken !== null && userRole !== null) { - const model_available = await modelAvailableCall(accessToken, userID, userRole); - let available_model_names = model_available["data"].map( - (element: { id: string }) => element.id - ); - console.log("available_model_names:", available_model_names); - setUserModels(available_model_names); - } - } catch (error) { - console.error("Error fetching user models:", error); - } - }; - - fetchUserModels(); - }, [accessToken, userID, userRole]); - - const [newExpiryTime, setNewExpiryTime] = useState(null); useEffect(() => { if (regenerateFormData?.duration) { @@ -186,6 +164,32 @@ const ViewKeyTable: React.FC = ({ } }, [regenerateFormData?.duration]); + + + useEffect(() => { + const fetchUserModels = async () => { + try { + if (userID === null) { + return; + } + + if (accessToken !== null && userRole !== null) { + const model_available = await modelAvailableCall(accessToken, userID, userRole); + let available_model_names = model_available["data"].map( + (element: { id: string }) => element.id + ); + console.log("available_model_names:", available_model_names); + setUserModels(available_model_names); + } + } catch (error) { + console.error("Error fetching user models:", error); + } + }; + + fetchUserModels(); + }, [accessToken, userID, userRole]); + + const handleModelLimitClick = (token: ItemData) => { setSelectedToken(token); setModelLimitModalVisible(true); @@ -1159,7 +1163,15 @@ const ViewKeyTable: React.FC = ({ ]} > {premiumUser ? ( -
+ { + if ('duration' in changedValues) { + handleRegenerateFormChange('duration', changedValues.duration); + } + }} + > @@ -1182,19 +1194,17 @@ const ViewKeyTable: React.FC = ({
Current expiry: { selectedToken?.expires != null ? ( - new Date(selectedToken.expires).toLocaleString(undefined, { - day: 'numeric', - month: 'long', - year: 'numeric', - hour: 'numeric', - minute: 'numeric', - second: 'numeric' - }) + new Date(selectedToken.expires).toLocaleString() ) : ( 'Never' ) }
+ {newExpiryTime && ( +
+ New expiry: {newExpiryTime} +
+ )} ) : (
From 5d9f2e7115d546d7359b0a4641197b29beba5e0d Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 6 Sep 2024 14:10:02 -0700 Subject: [PATCH 5/8] working regen flow --- ui/litellm-dashboard/src/components/networking.tsx | 4 ++-- ui/litellm-dashboard/src/components/view_key_table.tsx | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/ui/litellm-dashboard/src/components/networking.tsx b/ui/litellm-dashboard/src/components/networking.tsx index a6bd5d32c..27d109699 100644 --- a/ui/litellm-dashboard/src/components/networking.tsx +++ b/ui/litellm-dashboard/src/components/networking.tsx @@ -771,7 +771,7 @@ export const claimOnboardingToken = async ( } }; -export const regenerateKeyCall = async (accessToken: string, keyToRegenerate: string) => { +export const regenerateKeyCall = async (accessToken: string, keyToRegenerate: string, formData: any) => { try { const url = proxyBaseUrl ? `${proxyBaseUrl}/key/${keyToRegenerate}/regenerate` @@ -783,7 +783,7 @@ export const regenerateKeyCall = async (accessToken: string, keyToRegenerate: st [globalLitellmHeaderName]: `Bearer ${accessToken}`, "Content-Type": "application/json", }, - body: JSON.stringify({}), + body: JSON.stringify(formData), }); if (!response.ok) { diff --git a/ui/litellm-dashboard/src/components/view_key_table.tsx b/ui/litellm-dashboard/src/components/view_key_table.tsx index d22e844d4..5c482fcd4 100644 --- a/ui/litellm-dashboard/src/components/view_key_table.tsx +++ b/ui/litellm-dashboard/src/components/view_key_table.tsx @@ -749,7 +749,7 @@ const ViewKeyTable: React.FC = ({ try { const formValues = await regenerateForm.validateFields(); - const response = await regenerateKeyCall(accessToken, { ...formValues, token: selectedToken?.token }); + const response = await regenerateKeyCall(accessToken, selectedToken.token, formValues); setRegeneratedKey(response.key); // Update the data state with the new key_name @@ -1173,7 +1173,7 @@ const ViewKeyTable: React.FC = ({ }} > - + From 46b5a78c2195e2ac46e3c405b2827f6a7661bf32 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 6 Sep 2024 14:31:20 -0700 Subject: [PATCH 6/8] fix regen keys --- .../src/components/view_key_table.tsx | 33 ++++++++++++++----- 1 file changed, 24 insertions(+), 9 deletions(-) diff --git a/ui/litellm-dashboard/src/components/view_key_table.tsx b/ui/litellm-dashboard/src/components/view_key_table.tsx index 5c482fcd4..4f3238806 100644 --- a/ui/litellm-dashboard/src/components/view_key_table.tsx +++ b/ui/litellm-dashboard/src/components/view_key_table.tsx @@ -131,12 +131,15 @@ const ViewKeyTable: React.FC = ({ useEffect(() => { - if (regenerateFormData?.duration) { + const calculateNewExpiryTime = (duration: string | undefined) => { + if (!duration) { + return null; + } + try { const now = new Date(); - const duration = regenerateFormData.duration; let newExpiry: Date; - + if (duration.endsWith('s')) { newExpiry = add(now, { seconds: parseInt(duration) }); } else if (duration.endsWith('h')) { @@ -146,8 +149,8 @@ const ViewKeyTable: React.FC = ({ } else { throw new Error('Invalid duration format'); } - - setNewExpiryTime(newExpiry.toLocaleString('en-US', { + + return newExpiry.toLocaleString('en-US', { year: 'numeric', month: 'numeric', day: 'numeric', @@ -155,14 +158,25 @@ const ViewKeyTable: React.FC = ({ minute: 'numeric', second: 'numeric', hour12: true - })); + }); } catch (error) { - setNewExpiryTime(null); + return null; } + }; + + console.log("in calculateNewExpiryTime for selectedToken", selectedToken); + + + // When a new duration is entered + if (regenerateFormData?.duration) { + setNewExpiryTime(calculateNewExpiryTime(regenerateFormData.duration)); } else { setNewExpiryTime(null); } - }, [regenerateFormData?.duration]); + + console.log("calculateNewExpiryTime:", newExpiryTime); + }, [selectedToken, regenerateFormData?.duration]); + @@ -724,6 +738,7 @@ const ViewKeyTable: React.FC = ({ const handleRegenerateClick = (token: any) => { setSelectedToken(token); + setNewExpiryTime(null); regenerateForm.setFieldsValue({ key_alias: token.key_alias, max_budget: token.max_budget, @@ -1166,7 +1181,7 @@ const ViewKeyTable: React.FC = ({
{ + onValuesChange={(changedValues, allValues) => { if ('duration' in changedValues) { handleRegenerateFormChange('duration', changedValues.duration); } From 7f461dbf68f3e6551236a389f1aa5f70c1c9fb61 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 6 Sep 2024 16:54:43 -0700 Subject: [PATCH 7/8] fix linting --- litellm/proxy/_types.py | 6 +++++- .../proxy/management_endpoints/key_management_endpoints.py | 6 ++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 812b572fb..2bbd14a16 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -656,8 +656,12 @@ class UpdateKeyRequest(GenerateKeyRequest): metadata: Optional[dict] = None -class RegenerateKeyRequest(UpdateKeyRequest): +class RegenerateKeyRequest(GenerateKeyRequest): + # This needs to be different from UpdateKeyRequest, because "key" is optional for this key: Optional[str] = None + duration: Optional[str] = None + spend: Optional[float] = None + metadata: Optional[dict] = None class KeyRequest(LiteLLMBase): diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py index 43f4c1fcb..553cdb177 100644 --- a/litellm/proxy/management_endpoints/key_management_endpoints.py +++ b/litellm/proxy/management_endpoints/key_management_endpoints.py @@ -280,7 +280,9 @@ async def generate_key_fn( ) -async def prepare_key_update_data(data: UpdateKeyRequest, existing_key_row): +async def prepare_key_update_data( + data: Union[UpdateKeyRequest, RegenerateKeyRequest], existing_key_row +): data_json: dict = data.dict(exclude_unset=True) key = data_json.pop("key", None) @@ -1055,7 +1057,7 @@ async def regenerate_key_fn( # Update the token in the database updated_token = await prisma_client.db.litellm_verificationtoken.update( where={"token": hashed_api_key}, - data=update_data, + data=update_data, # type: ignore ) updated_token_dict = {} From b345db5011b577196aa0fc090e7812f8b8f20a61 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 6 Sep 2024 17:04:03 -0700 Subject: [PATCH 8/8] fix ui type --- ui/litellm-dashboard/src/components/view_key_table.tsx | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ui/litellm-dashboard/src/components/view_key_table.tsx b/ui/litellm-dashboard/src/components/view_key_table.tsx index 4f3238806..0595e31fd 100644 --- a/ui/litellm-dashboard/src/components/view_key_table.tsx +++ b/ui/litellm-dashboard/src/components/view_key_table.tsx @@ -762,6 +762,10 @@ const ViewKeyTable: React.FC = ({ return; } + if (selectedToken == null) { + return; + } + try { const formValues = await regenerateForm.validateFields(); const response = await regenerateKeyCall(accessToken, selectedToken.token, formValues);