forked from phoenix/litellm-mirror
Merge pull request #5379 from BerriAI/litellm_regen_keys_ui
[Feat-Proxy] Allow regenerating proxy virtual keys
This commit is contained in:
commit
d963de4bf7
9 changed files with 382 additions and 4 deletions
|
@ -1299,8 +1299,9 @@ class LiteLLM_VerificationToken(LiteLLMBase):
|
|||
model_max_budget: Dict = {}
|
||||
soft_budget_cooldown: bool = False
|
||||
litellm_budget_table: Optional[dict] = None
|
||||
|
||||
org_id: Optional[str] = None # org id for a given key
|
||||
created_at: Optional[datetime] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
|
|
@ -966,3 +966,96 @@ async def delete_verification_token(tokens: List, user_id: Optional[str] = None)
|
|||
verbose_proxy_logger.debug(traceback.format_exc())
|
||||
raise e
|
||||
return deleted_tokens
|
||||
|
||||
|
||||
@router.post(
|
||||
"/key/{key:path}/regenerate",
|
||||
tags=["key management"],
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
)
|
||||
@management_endpoint_wrapper
|
||||
async def regenerate_key_fn(
|
||||
key: str,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
litellm_changed_by: Optional[str] = Header(
|
||||
None,
|
||||
description="The litellm-changed-by header enables tracking of actions performed by authorized users on behalf of other users, providing an audit trail for accountability",
|
||||
),
|
||||
) -> GenerateKeyResponse:
|
||||
from litellm.proxy.proxy_server import (
|
||||
hash_token,
|
||||
premium_user,
|
||||
prisma_client,
|
||||
user_api_key_cache,
|
||||
)
|
||||
|
||||
"""
|
||||
Endpoint for regenerating a key
|
||||
"""
|
||||
|
||||
if premium_user is not True:
|
||||
raise ValueError(
|
||||
f"Regenerating Virtual Keys is an Enterprise feature, {CommonProxyErrors.not_premium_user.value}"
|
||||
)
|
||||
|
||||
# Check if key exists, raise exception if key is not in the DB
|
||||
|
||||
### 1. Create New copy that is duplicate of existing key
|
||||
######################################################################
|
||||
|
||||
# create duplicate of existing key
|
||||
# set token = new token generated
|
||||
# insert new token in DB
|
||||
|
||||
# create hash of token
|
||||
if prisma_client is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail={"error": "DB not connected. prisma_client is None"},
|
||||
)
|
||||
|
||||
if "sk" not in key:
|
||||
hashed_api_key = key
|
||||
else:
|
||||
hashed_api_key = hash_token(key)
|
||||
|
||||
_key_in_db = await prisma_client.db.litellm_verificationtoken.find_unique(
|
||||
where={"token": hashed_api_key},
|
||||
)
|
||||
if _key_in_db is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail={"error": f"Key {key} not found."},
|
||||
)
|
||||
|
||||
verbose_proxy_logger.debug("key_in_db: %s", _key_in_db)
|
||||
|
||||
new_token = f"sk-{secrets.token_urlsafe(16)}"
|
||||
new_token_hash = hash_token(new_token)
|
||||
new_token_key_name = f"sk-...{new_token[-4:]}"
|
||||
|
||||
# update new token in DB
|
||||
updated_token = await prisma_client.db.litellm_verificationtoken.update(
|
||||
where={"token": hashed_api_key},
|
||||
data={
|
||||
"token": new_token_hash,
|
||||
"key_name": new_token_key_name,
|
||||
},
|
||||
)
|
||||
updated_token_dict = {}
|
||||
if updated_token is not None:
|
||||
updated_token_dict = dict(updated_token)
|
||||
|
||||
updated_token_dict["token"] = new_token
|
||||
|
||||
### 3. remove existing key entry from cache
|
||||
######################################################################
|
||||
if key:
|
||||
user_api_key_cache.delete_cache(key)
|
||||
|
||||
if hashed_api_key:
|
||||
user_api_key_cache.delete_cache(hashed_api_key)
|
||||
|
||||
return GenerateKeyResponse(
|
||||
**updated_token_dict,
|
||||
)
|
||||
|
|
|
@ -149,6 +149,8 @@ model LiteLLM_VerificationToken {
|
|||
model_max_budget Json @default("{}")
|
||||
budget_id String?
|
||||
litellm_budget_table LiteLLM_BudgetTable? @relation(fields: [budget_id], references: [budget_id])
|
||||
created_at DateTime @default(now()) @map("created_at")
|
||||
updated_at DateTime @default(now()) @updatedAt @map("updated_at")
|
||||
}
|
||||
|
||||
model LiteLLM_EndUserTable {
|
||||
|
|
|
@ -56,6 +56,7 @@ from litellm.proxy.management_endpoints.key_management_endpoints import (
|
|||
generate_key_fn,
|
||||
generate_key_helper_fn,
|
||||
info_key_fn,
|
||||
regenerate_key_fn,
|
||||
update_key_fn,
|
||||
)
|
||||
from litellm.proxy.management_endpoints.team_endpoints import (
|
||||
|
@ -2935,3 +2936,105 @@ async def test_team_access_groups(prisma_client):
|
|||
"not allowed to call model" in e.message
|
||||
and "Allowed team models" in e.message
|
||||
)
|
||||
|
||||
|
||||
################ Unit Tests for testing regeneration of keys ###########
|
||||
@pytest.mark.asyncio()
|
||||
async def test_regenerate_api_key(prisma_client):
|
||||
litellm.set_verbose = True
|
||||
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
||||
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
|
||||
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||
import uuid
|
||||
|
||||
# generate new key
|
||||
key_alias = f"test_alias_regenerate_key-{uuid.uuid4()}"
|
||||
spend = 100
|
||||
max_budget = 400
|
||||
models = ["fake-openai-endpoint"]
|
||||
new_key = await generate_key_fn(
|
||||
data=GenerateKeyRequest(
|
||||
key_alias=key_alias, spend=spend, max_budget=max_budget, models=models
|
||||
),
|
||||
user_api_key_dict=UserAPIKeyAuth(
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||
api_key="sk-1234",
|
||||
user_id="1234",
|
||||
),
|
||||
)
|
||||
|
||||
generated_key = new_key.key
|
||||
print(generated_key)
|
||||
|
||||
# assert the new key works as expected
|
||||
request = Request(scope={"type": "http"})
|
||||
request._url = URL(url="/chat/completions")
|
||||
|
||||
async def return_body():
|
||||
return_string = f'{{"model": "fake-openai-endpoint"}}'
|
||||
# return string as bytes
|
||||
return return_string.encode()
|
||||
|
||||
request.body = return_body
|
||||
result = await user_api_key_auth(request=request, api_key=f"Bearer {generated_key}")
|
||||
print(result)
|
||||
|
||||
# regenerate the key
|
||||
new_key = await regenerate_key_fn(
|
||||
key=generated_key,
|
||||
user_api_key_dict=UserAPIKeyAuth(
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||
api_key="sk-1234",
|
||||
user_id="1234",
|
||||
),
|
||||
)
|
||||
print("response from regenerate_key_fn", new_key)
|
||||
|
||||
# assert the new key works as expected
|
||||
request = Request(scope={"type": "http"})
|
||||
request._url = URL(url="/chat/completions")
|
||||
|
||||
async def return_body_2():
|
||||
return_string = f'{{"model": "fake-openai-endpoint"}}'
|
||||
# return string as bytes
|
||||
return return_string.encode()
|
||||
|
||||
request.body = return_body_2
|
||||
result = await user_api_key_auth(request=request, api_key=f"Bearer {new_key.key}")
|
||||
print(result)
|
||||
|
||||
# assert the old key stops working
|
||||
request = Request(scope={"type": "http"})
|
||||
request._url = URL(url="/chat/completions")
|
||||
|
||||
async def return_body_3():
|
||||
return_string = f'{{"model": "fake-openai-endpoint"}}'
|
||||
# return string as bytes
|
||||
return return_string.encode()
|
||||
|
||||
request.body = return_body_3
|
||||
try:
|
||||
result = await user_api_key_auth(
|
||||
request=request, api_key=f"Bearer {generated_key}"
|
||||
)
|
||||
print(result)
|
||||
pytest.fail(f"This should have failed!. the key has been regenerated")
|
||||
except Exception as e:
|
||||
print("got expected exception", e)
|
||||
assert "Invalid proxy server token passed" in e.message
|
||||
|
||||
# Check that the regenerated key has the same spend, max_budget, models and key_alias
|
||||
assert new_key.spend == spend, f"Expected spend {spend} but got {new_key.spend}"
|
||||
assert (
|
||||
new_key.max_budget == max_budget
|
||||
), f"Expected max_budget {max_budget} but got {new_key.max_budget}"
|
||||
assert (
|
||||
new_key.key_alias == key_alias
|
||||
), f"Expected key_alias {key_alias} but got {new_key.key_alias}"
|
||||
assert (
|
||||
new_key.models == models
|
||||
), f"Expected models {models} but got {new_key.models}"
|
||||
|
||||
assert new_key.key_name == f"sk-...{new_key.key[-4:]}"
|
||||
|
||||
pass
|
||||
|
|
|
@ -149,6 +149,8 @@ model LiteLLM_VerificationToken {
|
|||
model_max_budget Json @default("{}")
|
||||
budget_id String?
|
||||
litellm_budget_table LiteLLM_BudgetTable? @relation(fields: [budget_id], references: [budget_id])
|
||||
created_at DateTime @default(now()) @map("created_at")
|
||||
updated_at DateTime @default(now()) @updatedAt @map("updated_at")
|
||||
}
|
||||
|
||||
model LiteLLM_EndUserTable {
|
||||
|
|
|
@ -141,6 +141,7 @@ const CreateKeyPage = () => {
|
|||
<UserDashboard
|
||||
userID={userID}
|
||||
userRole={userRole}
|
||||
premiumUser={premiumUser}
|
||||
teams={teams}
|
||||
keys={keys}
|
||||
setUserRole={setUserRole}
|
||||
|
@ -175,6 +176,7 @@ const CreateKeyPage = () => {
|
|||
<UserDashboard
|
||||
userID={userID}
|
||||
userRole={userRole}
|
||||
premiumUser={premiumUser}
|
||||
teams={teams}
|
||||
keys={keys}
|
||||
setUserRole={setUserRole}
|
||||
|
|
|
@ -770,6 +770,37 @@ export const claimOnboardingToken = async (
|
|||
throw error;
|
||||
}
|
||||
};
|
||||
|
||||
export const regenerateKeyCall = async (accessToken: string, keyToRegenerate: string) => {
|
||||
try {
|
||||
const url = proxyBaseUrl
|
||||
? `${proxyBaseUrl}/key/${keyToRegenerate}/regenerate`
|
||||
: `/key/${keyToRegenerate}/regenerate`;
|
||||
|
||||
const response = await fetch(url, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
[globalLitellmHeaderName]: `Bearer ${accessToken}`,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({}),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorData = await response.text();
|
||||
handleError(errorData);
|
||||
throw new Error("Network response was not ok");
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
console.log("Regenerate key Response:", data);
|
||||
return data;
|
||||
} catch (error) {
|
||||
console.error("Failed to regenerate key:", error);
|
||||
throw error;
|
||||
}
|
||||
};
|
||||
|
||||
let ModelListerrorShown = false;
|
||||
let errorTimer: NodeJS.Timeout | null = null;
|
||||
|
||||
|
|
|
@ -48,6 +48,7 @@ interface UserDashboardProps {
|
|||
setKeys: React.Dispatch<React.SetStateAction<Object[] | null>>;
|
||||
setProxySettings: React.Dispatch<React.SetStateAction<any>>;
|
||||
proxySettings: any;
|
||||
premiumUser: boolean;
|
||||
}
|
||||
|
||||
type TeamInterface = {
|
||||
|
@ -68,6 +69,7 @@ const UserDashboard: React.FC<UserDashboardProps> = ({
|
|||
setKeys,
|
||||
setProxySettings,
|
||||
proxySettings,
|
||||
premiumUser,
|
||||
}) => {
|
||||
const [userSpendData, setUserSpendData] = useState<UserSpendData | null>(
|
||||
null
|
||||
|
@ -328,6 +330,7 @@ const UserDashboard: React.FC<UserDashboardProps> = ({
|
|||
selectedTeam={selectedTeam ? selectedTeam : null}
|
||||
data={keys}
|
||||
setData={setKeys}
|
||||
premiumUser={premiumUser}
|
||||
teams={teams}
|
||||
/>
|
||||
<CreateKey
|
||||
|
|
|
@ -1,12 +1,14 @@
|
|||
"use client";
|
||||
import React, { useEffect, useState } from "react";
|
||||
import { keyDeleteCall, modelAvailableCall } from "./networking";
|
||||
import { InformationCircleIcon, StatusOnlineIcon, TrashIcon, PencilAltIcon } from "@heroicons/react/outline";
|
||||
import { keySpendLogsCall, PredictedSpendLogsCall, keyUpdateCall, modelInfoCall } from "./networking";
|
||||
import { InformationCircleIcon, StatusOnlineIcon, TrashIcon, PencilAltIcon, RefreshIcon } from "@heroicons/react/outline";
|
||||
import { keySpendLogsCall, PredictedSpendLogsCall, keyUpdateCall, modelInfoCall, regenerateKeyCall } from "./networking";
|
||||
import {
|
||||
Badge,
|
||||
Card,
|
||||
Table,
|
||||
Grid,
|
||||
Col,
|
||||
Button,
|
||||
TableBody,
|
||||
TableCell,
|
||||
|
@ -33,6 +35,8 @@ import {
|
|||
Select,
|
||||
} from "antd";
|
||||
|
||||
import { CopyToClipboard } from "react-copy-to-clipboard";
|
||||
|
||||
const { Option } = Select;
|
||||
const isLocal = process.env.NODE_ENV === "development";
|
||||
const proxyBaseUrl = isLocal ? "http://localhost:4000" : null;
|
||||
|
@ -65,6 +69,7 @@ interface ViewKeyTableProps {
|
|||
data: any[] | null;
|
||||
setData: React.Dispatch<React.SetStateAction<any[] | null>>;
|
||||
teams: any[] | null;
|
||||
premiumUser: boolean;
|
||||
}
|
||||
|
||||
interface ItemData {
|
||||
|
@ -92,7 +97,8 @@ const ViewKeyTable: React.FC<ViewKeyTableProps> = ({
|
|||
selectedTeam,
|
||||
data,
|
||||
setData,
|
||||
teams
|
||||
teams,
|
||||
premiumUser
|
||||
}) => {
|
||||
const [isButtonClicked, setIsButtonClicked] = useState(false);
|
||||
const [isDeleteModalOpen, setIsDeleteModalOpen] = useState(false);
|
||||
|
@ -109,6 +115,8 @@ const ViewKeyTable: React.FC<ViewKeyTableProps> = ({
|
|||
const [userModels, setUserModels] = useState([]);
|
||||
const initialKnownTeamIDs: Set<string> = new Set();
|
||||
const [modelLimitModalVisible, setModelLimitModalVisible] = useState(false);
|
||||
const [regenerateDialogVisible, setRegenerateDialogVisible] = useState(false);
|
||||
const [regeneratedKey, setRegeneratedKey] = useState<string | null>(null);
|
||||
|
||||
const [knownTeamIDs, setKnownTeamIDs] = useState(initialKnownTeamIDs);
|
||||
|
||||
|
@ -612,6 +620,38 @@ const ViewKeyTable: React.FC<ViewKeyTableProps> = ({
|
|||
setKeyToDelete(null);
|
||||
};
|
||||
|
||||
const handleRegenerateKey = async () => {
|
||||
if (!premiumUser) {
|
||||
message.error("Regenerate API Key is an Enterprise feature. Please upgrade to use this feature.");
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
if (selectedToken == null) {
|
||||
message.error("Please select a key to regenerate");
|
||||
return;
|
||||
}
|
||||
const response = await regenerateKeyCall(accessToken, selectedToken.token);
|
||||
setRegeneratedKey(response.key);
|
||||
|
||||
// Update the data state with the new key_name
|
||||
if (data) {
|
||||
const updatedData = data.map(item =>
|
||||
item.token === selectedToken.token
|
||||
? { ...item, key_name: response.key_name }
|
||||
: item
|
||||
);
|
||||
setData(updatedData);
|
||||
}
|
||||
|
||||
setRegenerateDialogVisible(false);
|
||||
message.success("API Key regenerated successfully");
|
||||
} catch (error) {
|
||||
console.error("Error regenerating key:", error);
|
||||
message.error("Failed to regenerate API Key");
|
||||
}
|
||||
};
|
||||
|
||||
if (data == null) {
|
||||
return;
|
||||
}
|
||||
|
@ -768,6 +808,7 @@ const ViewKeyTable: React.FC<ViewKeyTableProps> = ({
|
|||
size="sm"
|
||||
/>
|
||||
|
||||
|
||||
|
||||
<Modal
|
||||
open={infoDialogVisible}
|
||||
|
@ -867,6 +908,14 @@ const ViewKeyTable: React.FC<ViewKeyTableProps> = ({
|
|||
size="sm"
|
||||
onClick={() => handleEditClick(item)}
|
||||
/>
|
||||
<Icon
|
||||
onClick={() => {
|
||||
setSelectedToken(item);
|
||||
setRegenerateDialogVisible(true);
|
||||
}}
|
||||
icon={RefreshIcon}
|
||||
size="sm"
|
||||
/>
|
||||
<Icon
|
||||
onClick={() => handleDelete(item)}
|
||||
icon={TrashIcon}
|
||||
|
@ -942,6 +991,98 @@ const ViewKeyTable: React.FC<ViewKeyTableProps> = ({
|
|||
accessToken={accessToken}
|
||||
/>
|
||||
)}
|
||||
|
||||
{/* Regenerate Key Confirmation Dialog */}
|
||||
<Modal
|
||||
title="Regenerate API Key"
|
||||
visible={regenerateDialogVisible}
|
||||
onCancel={() => setRegenerateDialogVisible(false)}
|
||||
footer={[
|
||||
<Button key="cancel" onClick={() => setRegenerateDialogVisible(false)} className="mr-2">
|
||||
Cancel
|
||||
</Button>,
|
||||
<Button
|
||||
key="regenerate"
|
||||
onClick={handleRegenerateKey}
|
||||
disabled={!premiumUser}
|
||||
>
|
||||
{premiumUser ? "Regenerate" : "Upgrade to Regenerate"}
|
||||
</Button>
|
||||
]}
|
||||
>
|
||||
{premiumUser ? (
|
||||
<>
|
||||
<p>Are you sure you want to regenerate this key?</p>
|
||||
<p>Key Alias:</p>
|
||||
<pre>{selectedToken?.key_alias || 'No alias set'}</pre>
|
||||
</>
|
||||
) : (
|
||||
<div>
|
||||
<p className="mb-2 text-gray-500 italic text-[12px]">Upgrade to use this feature</p>
|
||||
<Button variant="primary" className="mb-2">
|
||||
<a href="https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat" target="_blank">
|
||||
Get Free Trial
|
||||
</a>
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
</Modal>
|
||||
|
||||
{/* Regenerated Key Display Modal */}
|
||||
{regeneratedKey && (
|
||||
<Modal
|
||||
visible={!!regeneratedKey}
|
||||
onCancel={() => setRegeneratedKey(null)}
|
||||
footer={[
|
||||
<Button key="close" onClick={() => setRegeneratedKey(null)}>
|
||||
Close
|
||||
</Button>
|
||||
]}
|
||||
>
|
||||
<Grid numItems={1} className="gap-2 w-full">
|
||||
<Title>Regenerated Key</Title>
|
||||
<Col numColSpan={1}>
|
||||
<p>
|
||||
Please replace your old key with the new key generated. For
|
||||
security reasons, <b>you will not be able to view it again</b> through
|
||||
your LiteLLM account. If you lose this secret key, you will need to
|
||||
generate a new one.
|
||||
</p>
|
||||
</Col>
|
||||
<Col numColSpan={1}>
|
||||
<Text className="mt-3">Key Alias:</Text>
|
||||
<div
|
||||
style={{
|
||||
background: "#f8f8f8",
|
||||
padding: "10px",
|
||||
borderRadius: "5px",
|
||||
marginBottom: "10px",
|
||||
}}
|
||||
>
|
||||
<pre style={{ wordWrap: "break-word", whiteSpace: "normal" }}>
|
||||
{selectedToken?.key_alias || 'No alias set'}
|
||||
</pre>
|
||||
</div>
|
||||
<Text className="mt-3">New API Key:</Text>
|
||||
<div
|
||||
style={{
|
||||
background: "#f8f8f8",
|
||||
padding: "10px",
|
||||
borderRadius: "5px",
|
||||
marginBottom: "10px",
|
||||
}}
|
||||
>
|
||||
<pre style={{ wordWrap: "break-word", whiteSpace: "normal" }}>
|
||||
{regeneratedKey}
|
||||
</pre>
|
||||
</div>
|
||||
<CopyToClipboard text={regeneratedKey} onCopy={() => message.success("API Key copied to clipboard")}>
|
||||
<Button className="mt-3">Copy API Key</Button>
|
||||
</CopyToClipboard>
|
||||
</Col>
|
||||
</Grid>
|
||||
</Modal>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue