Merge pull request #5379 from BerriAI/litellm_regen_keys_ui

[Feat-Proxy] Allow regenerating proxy virtual keys
This commit is contained in:
Ishaan Jaff 2024-08-26 18:59:42 -07:00 committed by GitHub
commit d963de4bf7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 382 additions and 4 deletions

View file

@ -1299,8 +1299,9 @@ class LiteLLM_VerificationToken(LiteLLMBase):
model_max_budget: Dict = {}
soft_budget_cooldown: bool = False
litellm_budget_table: Optional[dict] = None
org_id: Optional[str] = None # org id for a given key
created_at: Optional[datetime] = None
updated_at: Optional[datetime] = None
model_config = ConfigDict(protected_namespaces=())

View file

@ -966,3 +966,96 @@ async def delete_verification_token(tokens: List, user_id: Optional[str] = None)
verbose_proxy_logger.debug(traceback.format_exc())
raise e
return deleted_tokens
@router.post(
"/key/{key:path}/regenerate",
tags=["key management"],
dependencies=[Depends(user_api_key_auth)],
)
@management_endpoint_wrapper
async def regenerate_key_fn(
key: str,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
litellm_changed_by: Optional[str] = Header(
None,
description="The litellm-changed-by header enables tracking of actions performed by authorized users on behalf of other users, providing an audit trail for accountability",
),
) -> GenerateKeyResponse:
from litellm.proxy.proxy_server import (
hash_token,
premium_user,
prisma_client,
user_api_key_cache,
)
"""
Endpoint for regenerating a key
"""
if premium_user is not True:
raise ValueError(
f"Regenerating Virtual Keys is an Enterprise feature, {CommonProxyErrors.not_premium_user.value}"
)
# Check if key exists, raise exception if key is not in the DB
### 1. Create New copy that is duplicate of existing key
######################################################################
# create duplicate of existing key
# set token = new token generated
# insert new token in DB
# create hash of token
if prisma_client is None:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={"error": "DB not connected. prisma_client is None"},
)
if "sk" not in key:
hashed_api_key = key
else:
hashed_api_key = hash_token(key)
_key_in_db = await prisma_client.db.litellm_verificationtoken.find_unique(
where={"token": hashed_api_key},
)
if _key_in_db is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail={"error": f"Key {key} not found."},
)
verbose_proxy_logger.debug("key_in_db: %s", _key_in_db)
new_token = f"sk-{secrets.token_urlsafe(16)}"
new_token_hash = hash_token(new_token)
new_token_key_name = f"sk-...{new_token[-4:]}"
# update new token in DB
updated_token = await prisma_client.db.litellm_verificationtoken.update(
where={"token": hashed_api_key},
data={
"token": new_token_hash,
"key_name": new_token_key_name,
},
)
updated_token_dict = {}
if updated_token is not None:
updated_token_dict = dict(updated_token)
updated_token_dict["token"] = new_token
### 3. remove existing key entry from cache
######################################################################
if key:
user_api_key_cache.delete_cache(key)
if hashed_api_key:
user_api_key_cache.delete_cache(hashed_api_key)
return GenerateKeyResponse(
**updated_token_dict,
)

View file

@ -149,6 +149,8 @@ model LiteLLM_VerificationToken {
model_max_budget Json @default("{}")
budget_id String?
litellm_budget_table LiteLLM_BudgetTable? @relation(fields: [budget_id], references: [budget_id])
created_at DateTime @default(now()) @map("created_at")
updated_at DateTime @default(now()) @updatedAt @map("updated_at")
}
model LiteLLM_EndUserTable {

View file

@ -56,6 +56,7 @@ from litellm.proxy.management_endpoints.key_management_endpoints import (
generate_key_fn,
generate_key_helper_fn,
info_key_fn,
regenerate_key_fn,
update_key_fn,
)
from litellm.proxy.management_endpoints.team_endpoints import (
@ -2935,3 +2936,105 @@ async def test_team_access_groups(prisma_client):
"not allowed to call model" in e.message
and "Allowed team models" in e.message
)
################ Unit Tests for testing regeneration of keys ###########
@pytest.mark.asyncio()
async def test_regenerate_api_key(prisma_client):
litellm.set_verbose = True
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
await litellm.proxy.proxy_server.prisma_client.connect()
import uuid
# generate new key
key_alias = f"test_alias_regenerate_key-{uuid.uuid4()}"
spend = 100
max_budget = 400
models = ["fake-openai-endpoint"]
new_key = await generate_key_fn(
data=GenerateKeyRequest(
key_alias=key_alias, spend=spend, max_budget=max_budget, models=models
),
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="1234",
),
)
generated_key = new_key.key
print(generated_key)
# assert the new key works as expected
request = Request(scope={"type": "http"})
request._url = URL(url="/chat/completions")
async def return_body():
return_string = f'{{"model": "fake-openai-endpoint"}}'
# return string as bytes
return return_string.encode()
request.body = return_body
result = await user_api_key_auth(request=request, api_key=f"Bearer {generated_key}")
print(result)
# regenerate the key
new_key = await regenerate_key_fn(
key=generated_key,
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="1234",
),
)
print("response from regenerate_key_fn", new_key)
# assert the new key works as expected
request = Request(scope={"type": "http"})
request._url = URL(url="/chat/completions")
async def return_body_2():
return_string = f'{{"model": "fake-openai-endpoint"}}'
# return string as bytes
return return_string.encode()
request.body = return_body_2
result = await user_api_key_auth(request=request, api_key=f"Bearer {new_key.key}")
print(result)
# assert the old key stops working
request = Request(scope={"type": "http"})
request._url = URL(url="/chat/completions")
async def return_body_3():
return_string = f'{{"model": "fake-openai-endpoint"}}'
# return string as bytes
return return_string.encode()
request.body = return_body_3
try:
result = await user_api_key_auth(
request=request, api_key=f"Bearer {generated_key}"
)
print(result)
pytest.fail(f"This should have failed!. the key has been regenerated")
except Exception as e:
print("got expected exception", e)
assert "Invalid proxy server token passed" in e.message
# Check that the regenerated key has the same spend, max_budget, models and key_alias
assert new_key.spend == spend, f"Expected spend {spend} but got {new_key.spend}"
assert (
new_key.max_budget == max_budget
), f"Expected max_budget {max_budget} but got {new_key.max_budget}"
assert (
new_key.key_alias == key_alias
), f"Expected key_alias {key_alias} but got {new_key.key_alias}"
assert (
new_key.models == models
), f"Expected models {models} but got {new_key.models}"
assert new_key.key_name == f"sk-...{new_key.key[-4:]}"
pass

View file

@ -149,6 +149,8 @@ model LiteLLM_VerificationToken {
model_max_budget Json @default("{}")
budget_id String?
litellm_budget_table LiteLLM_BudgetTable? @relation(fields: [budget_id], references: [budget_id])
created_at DateTime @default(now()) @map("created_at")
updated_at DateTime @default(now()) @updatedAt @map("updated_at")
}
model LiteLLM_EndUserTable {

View file

@ -141,6 +141,7 @@ const CreateKeyPage = () => {
<UserDashboard
userID={userID}
userRole={userRole}
premiumUser={premiumUser}
teams={teams}
keys={keys}
setUserRole={setUserRole}
@ -175,6 +176,7 @@ const CreateKeyPage = () => {
<UserDashboard
userID={userID}
userRole={userRole}
premiumUser={premiumUser}
teams={teams}
keys={keys}
setUserRole={setUserRole}

View file

@ -770,6 +770,37 @@ export const claimOnboardingToken = async (
throw error;
}
};
export const regenerateKeyCall = async (accessToken: string, keyToRegenerate: string) => {
try {
const url = proxyBaseUrl
? `${proxyBaseUrl}/key/${keyToRegenerate}/regenerate`
: `/key/${keyToRegenerate}/regenerate`;
const response = await fetch(url, {
method: "POST",
headers: {
[globalLitellmHeaderName]: `Bearer ${accessToken}`,
"Content-Type": "application/json",
},
body: JSON.stringify({}),
});
if (!response.ok) {
const errorData = await response.text();
handleError(errorData);
throw new Error("Network response was not ok");
}
const data = await response.json();
console.log("Regenerate key Response:", data);
return data;
} catch (error) {
console.error("Failed to regenerate key:", error);
throw error;
}
};
let ModelListerrorShown = false;
let errorTimer: NodeJS.Timeout | null = null;

View file

@ -48,6 +48,7 @@ interface UserDashboardProps {
setKeys: React.Dispatch<React.SetStateAction<Object[] | null>>;
setProxySettings: React.Dispatch<React.SetStateAction<any>>;
proxySettings: any;
premiumUser: boolean;
}
type TeamInterface = {
@ -68,6 +69,7 @@ const UserDashboard: React.FC<UserDashboardProps> = ({
setKeys,
setProxySettings,
proxySettings,
premiumUser,
}) => {
const [userSpendData, setUserSpendData] = useState<UserSpendData | null>(
null
@ -328,6 +330,7 @@ const UserDashboard: React.FC<UserDashboardProps> = ({
selectedTeam={selectedTeam ? selectedTeam : null}
data={keys}
setData={setKeys}
premiumUser={premiumUser}
teams={teams}
/>
<CreateKey

View file

@ -1,12 +1,14 @@
"use client";
import React, { useEffect, useState } from "react";
import { keyDeleteCall, modelAvailableCall } from "./networking";
import { InformationCircleIcon, StatusOnlineIcon, TrashIcon, PencilAltIcon } from "@heroicons/react/outline";
import { keySpendLogsCall, PredictedSpendLogsCall, keyUpdateCall, modelInfoCall } from "./networking";
import { InformationCircleIcon, StatusOnlineIcon, TrashIcon, PencilAltIcon, RefreshIcon } from "@heroicons/react/outline";
import { keySpendLogsCall, PredictedSpendLogsCall, keyUpdateCall, modelInfoCall, regenerateKeyCall } from "./networking";
import {
Badge,
Card,
Table,
Grid,
Col,
Button,
TableBody,
TableCell,
@ -33,6 +35,8 @@ import {
Select,
} from "antd";
import { CopyToClipboard } from "react-copy-to-clipboard";
const { Option } = Select;
const isLocal = process.env.NODE_ENV === "development";
const proxyBaseUrl = isLocal ? "http://localhost:4000" : null;
@ -65,6 +69,7 @@ interface ViewKeyTableProps {
data: any[] | null;
setData: React.Dispatch<React.SetStateAction<any[] | null>>;
teams: any[] | null;
premiumUser: boolean;
}
interface ItemData {
@ -92,7 +97,8 @@ const ViewKeyTable: React.FC<ViewKeyTableProps> = ({
selectedTeam,
data,
setData,
teams
teams,
premiumUser
}) => {
const [isButtonClicked, setIsButtonClicked] = useState(false);
const [isDeleteModalOpen, setIsDeleteModalOpen] = useState(false);
@ -109,6 +115,8 @@ const ViewKeyTable: React.FC<ViewKeyTableProps> = ({
const [userModels, setUserModels] = useState([]);
const initialKnownTeamIDs: Set<string> = new Set();
const [modelLimitModalVisible, setModelLimitModalVisible] = useState(false);
const [regenerateDialogVisible, setRegenerateDialogVisible] = useState(false);
const [regeneratedKey, setRegeneratedKey] = useState<string | null>(null);
const [knownTeamIDs, setKnownTeamIDs] = useState(initialKnownTeamIDs);
@ -612,6 +620,38 @@ const ViewKeyTable: React.FC<ViewKeyTableProps> = ({
setKeyToDelete(null);
};
const handleRegenerateKey = async () => {
if (!premiumUser) {
message.error("Regenerate API Key is an Enterprise feature. Please upgrade to use this feature.");
return;
}
try {
if (selectedToken == null) {
message.error("Please select a key to regenerate");
return;
}
const response = await regenerateKeyCall(accessToken, selectedToken.token);
setRegeneratedKey(response.key);
// Update the data state with the new key_name
if (data) {
const updatedData = data.map(item =>
item.token === selectedToken.token
? { ...item, key_name: response.key_name }
: item
);
setData(updatedData);
}
setRegenerateDialogVisible(false);
message.success("API Key regenerated successfully");
} catch (error) {
console.error("Error regenerating key:", error);
message.error("Failed to regenerate API Key");
}
};
if (data == null) {
return;
}
@ -768,6 +808,7 @@ const ViewKeyTable: React.FC<ViewKeyTableProps> = ({
size="sm"
/>
<Modal
open={infoDialogVisible}
@ -867,6 +908,14 @@ const ViewKeyTable: React.FC<ViewKeyTableProps> = ({
size="sm"
onClick={() => handleEditClick(item)}
/>
<Icon
onClick={() => {
setSelectedToken(item);
setRegenerateDialogVisible(true);
}}
icon={RefreshIcon}
size="sm"
/>
<Icon
onClick={() => handleDelete(item)}
icon={TrashIcon}
@ -942,6 +991,98 @@ const ViewKeyTable: React.FC<ViewKeyTableProps> = ({
accessToken={accessToken}
/>
)}
{/* Regenerate Key Confirmation Dialog */}
<Modal
title="Regenerate API Key"
visible={regenerateDialogVisible}
onCancel={() => setRegenerateDialogVisible(false)}
footer={[
<Button key="cancel" onClick={() => setRegenerateDialogVisible(false)} className="mr-2">
Cancel
</Button>,
<Button
key="regenerate"
onClick={handleRegenerateKey}
disabled={!premiumUser}
>
{premiumUser ? "Regenerate" : "Upgrade to Regenerate"}
</Button>
]}
>
{premiumUser ? (
<>
<p>Are you sure you want to regenerate this key?</p>
<p>Key Alias:</p>
<pre>{selectedToken?.key_alias || 'No alias set'}</pre>
</>
) : (
<div>
<p className="mb-2 text-gray-500 italic text-[12px]">Upgrade to use this feature</p>
<Button variant="primary" className="mb-2">
<a href="https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat" target="_blank">
Get Free Trial
</a>
</Button>
</div>
)}
</Modal>
{/* Regenerated Key Display Modal */}
{regeneratedKey && (
<Modal
visible={!!regeneratedKey}
onCancel={() => setRegeneratedKey(null)}
footer={[
<Button key="close" onClick={() => setRegeneratedKey(null)}>
Close
</Button>
]}
>
<Grid numItems={1} className="gap-2 w-full">
<Title>Regenerated Key</Title>
<Col numColSpan={1}>
<p>
Please replace your old key with the new key generated. For
security reasons, <b>you will not be able to view it again</b> through
your LiteLLM account. If you lose this secret key, you will need to
generate a new one.
</p>
</Col>
<Col numColSpan={1}>
<Text className="mt-3">Key Alias:</Text>
<div
style={{
background: "#f8f8f8",
padding: "10px",
borderRadius: "5px",
marginBottom: "10px",
}}
>
<pre style={{ wordWrap: "break-word", whiteSpace: "normal" }}>
{selectedToken?.key_alias || 'No alias set'}
</pre>
</div>
<Text className="mt-3">New API Key:</Text>
<div
style={{
background: "#f8f8f8",
padding: "10px",
borderRadius: "5px",
marginBottom: "10px",
}}
>
<pre style={{ wordWrap: "break-word", whiteSpace: "normal" }}>
{regeneratedKey}
</pre>
</div>
<CopyToClipboard text={regeneratedKey} onCopy={() => message.success("API Key copied to clipboard")}>
<Button className="mt-3">Copy API Key</Button>
</CopyToClipboard>
</Col>
</Grid>
</Modal>
)}
</div>
);
};