feat: refactor add models tab on UI to enable setting credentials

This commit is contained in:
Krrish Dholakia 2025-03-12 20:32:01 -07:00
parent 52926408cd
commit d604f52884
6 changed files with 192 additions and 40 deletions

View file

@ -6,6 +6,7 @@ from fastapi import APIRouter, Depends, HTTPException, Request, Response
import litellm import litellm
from litellm._logging import verbose_proxy_logger from litellm._logging import verbose_proxy_logger
from litellm.litellm_core_utils.credential_accessor import CredentialAccessor
from litellm.proxy._types import CommonProxyErrors, UserAPIKeyAuth from litellm.proxy._types import CommonProxyErrors, UserAPIKeyAuth
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.common_utils.encrypt_decrypt_utils import encrypt_value_helper from litellm.proxy.common_utils.encrypt_decrypt_utils import encrypt_value_helper
@ -51,8 +52,10 @@ async def create_credential(
detail={"error": CommonProxyErrors.db_not_connected_error.value}, detail={"error": CommonProxyErrors.db_not_connected_error.value},
) )
credential = CredentialHelperUtils.encrypt_credential_values(credential) encrypted_credential = CredentialHelperUtils.encrypt_credential_values(
credentials_dict = credential.model_dump() credential
)
credentials_dict = encrypted_credential.model_dump()
credentials_dict_jsonified = jsonify_object(credentials_dict) credentials_dict_jsonified = jsonify_object(credentials_dict)
await prisma_client.db.litellm_credentialstable.create( await prisma_client.db.litellm_credentialstable.create(
data={ data={
@ -62,6 +65,9 @@ async def create_credential(
} }
) )
## ADD TO LITELLM ##
CredentialAccessor.upsert_credentials([credential])
return {"success": True, "message": "Credential created successfully"} return {"success": True, "message": "Credential created successfully"}
except Exception as e: except Exception as e:
verbose_proxy_logger.exception(e) verbose_proxy_logger.exception(e)
@ -146,6 +152,13 @@ async def delete_credential(
await prisma_client.db.litellm_credentialstable.delete( await prisma_client.db.litellm_credentialstable.delete(
where={"credential_name": credential_name} where={"credential_name": credential_name}
) )
## DELETE FROM LITELLM ##
litellm.credential_list = [
cred
for cred in litellm.credential_list
if cred.credential_name != credential_name
]
return {"success": True, "message": "Credential deleted successfully"} return {"success": True, "message": "Credential deleted successfully"}
except Exception as e: except Exception as e:
return handle_exception_on_proxy(e) return handle_exception_on_proxy(e)

View file

@ -1725,6 +1725,16 @@ class ProxyConfig:
) )
return {} return {}
def load_credential_list(self, config: dict) -> List[CredentialItem]:
"""
Load the credential list from the database
"""
credential_list_dict = config.get("credential_list")
credential_list = []
if credential_list_dict:
credential_list = [CredentialItem(**cred) for cred in credential_list_dict]
return credential_list
async def load_config( # noqa: PLR0915 async def load_config( # noqa: PLR0915
self, router: Optional[litellm.Router], config_file_path: str self, router: Optional[litellm.Router], config_file_path: str
): ):
@ -2190,11 +2200,8 @@ class ProxyConfig:
) )
## CREDENTIALS ## CREDENTIALS
credential_list_dict = config.get("credential_list") credential_list_dict = self.load_credential_list(config=config)
if credential_list_dict: litellm.credential_list = credential_list_dict
litellm.credential_list = [
CredentialItem(**cred) for cred in credential_list_dict
]
return router, router.get_model_list(), general_settings return router, router.get_model_list(), general_settings
def _load_alerting_settings(self, general_settings: dict): def _load_alerting_settings(self, general_settings: dict):
@ -2854,11 +2861,39 @@ class ProxyConfig:
credential_object.credential_values = decrypted_credential_values credential_object.credential_values = decrypted_credential_values
return credential_object return credential_object
async def delete_credentials(self, db_credentials: List[CredentialItem]):
"""
Create all-up list of db credentials + local credentials
Compare to the litellm.credential_list
Delete any from litellm.credential_list that are not in the all-up list
"""
## CONFIG credentials ##
config = await self.get_config(config_file_path=user_config_file_path)
credential_list = self.load_credential_list(config=config)
## COMBINED LIST ##
combined_list = db_credentials + credential_list
## DELETE ##
idx_to_delete = []
for idx, credential in enumerate(litellm.credential_list):
if credential.credential_name not in [
cred.credential_name for cred in combined_list
]:
idx_to_delete.append(idx)
for idx in sorted(idx_to_delete, reverse=True):
litellm.credential_list.pop(idx)
async def get_credentials(self, prisma_client: PrismaClient): async def get_credentials(self, prisma_client: PrismaClient):
try: try:
credentials = await prisma_client.db.litellm_credentialstable.find_many() credentials = await prisma_client.db.litellm_credentialstable.find_many()
credentials = [self.decrypt_credentials(cred) for cred in credentials] credentials = [self.decrypt_credentials(cred) for cred in credentials]
CredentialAccessor.upsert_credentials(credentials) await self.delete_credentials(
credentials
) # delete credentials that are not in the all-up list
CredentialAccessor.upsert_credentials(
credentials
) # upsert credentials that are in the all-up list
except Exception as e: except Exception as e:
verbose_proxy_logger.exception( verbose_proxy_logger.exception(
"litellm.proxy_server.py::get_credentials() - Error getting credentials from DB - {}".format( "litellm.proxy_server.py::get_credentials() - Error getting credentials from DB - {}".format(

View file

@ -8,7 +8,7 @@ import ProviderSpecificFields from "./provider_specific_fields";
import AdvancedSettings from "./advanced_settings"; import AdvancedSettings from "./advanced_settings";
import { Providers, providerLogoMap, getPlaceholder } from "../provider_info_helpers"; import { Providers, providerLogoMap, getPlaceholder } from "../provider_info_helpers";
import type { Team } from "../key_team_helpers/key_list"; import type { Team } from "../key_team_helpers/key_list";
import { CredentialItem } from "../networking";
interface AddModelTabProps { interface AddModelTabProps {
form: FormInstance; form: FormInstance;
handleOk: () => void; handleOk: () => void;
@ -21,6 +21,7 @@ interface AddModelTabProps {
showAdvancedSettings: boolean; showAdvancedSettings: boolean;
setShowAdvancedSettings: (show: boolean) => void; setShowAdvancedSettings: (show: boolean) => void;
teams: Team[] | null; teams: Team[] | null;
credentials: CredentialItem[];
} }
const { Title, Link } = Typography; const { Title, Link } = Typography;
@ -37,6 +38,7 @@ const AddModelTab: React.FC<AddModelTabProps> = ({
showAdvancedSettings, showAdvancedSettings,
setShowAdvancedSettings, setShowAdvancedSettings,
teams, teams,
credentials,
}) => { }) => {
return ( return (
<> <>
@ -108,6 +110,38 @@ const AddModelTab: React.FC<AddModelTabProps> = ({
{/* Conditionally Render "Public Model Name" */} {/* Conditionally Render "Public Model Name" */}
<ConditionalPublicModelName /> <ConditionalPublicModelName />
{/* Credentials */}
<div className="mb-4">
<Typography.Text className="text-sm text-gray-500 mb-2">
Either select existing credentials OR enter new provider credentials below
</Typography.Text>
</div>
<Form.Item
label="Existing Credentials"
name="credential_name"
>
<AntdSelect
showSearch
placeholder="Select or search for existing credentials"
optionFilterProp="children"
filterOption={(input, option) =>
(option?.label ?? '').toLowerCase().includes(input.toLowerCase())
}
options={credentials.map((credential) => ({
value: credential.credential_name,
label: credential.credential_name
}))}
allowClear
/>
</Form.Item>
<div className="flex items-center my-4">
<div className="flex-grow border-t border-gray-200"></div>
<span className="px-4 text-gray-500 text-sm">OR</span>
<div className="flex-grow border-t border-gray-200"></div>
</div>
<ProviderSpecificFields <ProviderSpecificFields
selectedProvider={selectedProvider} selectedProvider={selectedProvider}
uploadProps={uploadProps} uploadProps={uploadProps}

View file

@ -11,32 +11,29 @@ import {
Badge, Badge,
Button Button
} from "@tremor/react"; } from "@tremor/react";
import {
InformationCircleIcon,
PencilAltIcon,
PencilIcon,
RefreshIcon,
StatusOnlineIcon,
TrashIcon,
} from "@heroicons/react/outline";
import { UploadProps } from "antd/es/upload"; import { UploadProps } from "antd/es/upload";
import { PlusIcon } from "@heroicons/react/solid"; import { PlusIcon } from "@heroicons/react/solid";
import { credentialListCall, credentialCreateCall } from "@/components/networking"; // Assume this is your networking function import { credentialListCall, credentialCreateCall, credentialDeleteCall, CredentialItem, CredentialsResponse } from "@/components/networking"; // Assume this is your networking function
import AddCredentialsTab from "./add_credentials_tab"; import AddCredentialsTab from "./add_credentials_tab";
import { Form, message } from "antd"; import { Form, message } from "antd";
interface CredentialsPanelProps { interface CredentialsPanelProps {
accessToken: string | null; accessToken: string | null;
uploadProps: UploadProps; uploadProps: UploadProps;
credentialList: CredentialItem[];
fetchCredentials: (accessToken: string) => Promise<void>;
} }
interface CredentialsResponse {
credentials: CredentialItem[];
}
interface CredentialItem {
credential_name: string | null;
credential_values: object;
credential_info: {
custom_llm_provider?: string;
description?: string;
required?: boolean;
};
}
const CredentialsPanel: React.FC<CredentialsPanelProps> = ({ accessToken, uploadProps }) => { const CredentialsPanel: React.FC<CredentialsPanelProps> = ({ accessToken, uploadProps, credentialList, fetchCredentials }) => {
const [credentialsList, setCredentialsList] = useState<CredentialItem[]>([]);
const [isAddModalOpen, setIsAddModalOpen] = useState(false); const [isAddModalOpen, setIsAddModalOpen] = useState(false);
const [form] = Form.useForm(); const [form] = Form.useForm();
@ -63,24 +60,16 @@ const CredentialsPanel: React.FC<CredentialsPanelProps> = ({ accessToken, upload
message.success('Credential added successfully'); message.success('Credential added successfully');
console.log(`response: ${JSON.stringify(response)}`); console.log(`response: ${JSON.stringify(response)}`);
setIsAddModalOpen(false); setIsAddModalOpen(false);
form.resetFields(); fetchCredentials(accessToken);
}; };
useEffect(() => { useEffect(() => {
if (!accessToken) { if (!accessToken) {
return; return;
} }
fetchCredentials(accessToken);
const fetchCredentials = async () => {
try {
const response: CredentialsResponse = await credentialListCall(accessToken);
console.log(`credentials: ${JSON.stringify(response)}`);
setCredentialsList(response.credentials);
} catch (error) {
console.error('Error fetching credentials:', error);
}
};
fetchCredentials();
}, [accessToken]); }, [accessToken]);
const renderProviderBadge = (provider: string) => { const renderProviderBadge = (provider: string) => {
@ -99,6 +88,18 @@ const CredentialsPanel: React.FC<CredentialsPanelProps> = ({ accessToken, upload
); );
}; };
const handleDeleteCredential = async (credentialName: string) => {
if (!accessToken) {
console.error('No access token found');
return;
}
const response = await credentialDeleteCall(accessToken, credentialName);
console.log(`response: ${JSON.stringify(response)}`);
message.success('Credential deleted successfully');
fetchCredentials(accessToken);
};
return ( return (
<div className="w-full mx-auto flex-auto overflow-y-auto m-8 p-2"> <div className="w-full mx-auto flex-auto overflow-y-auto m-8 p-2">
<div className="flex justify-between items-center mb-4"> <div className="flex justify-between items-center mb-4">
@ -125,20 +126,33 @@ const CredentialsPanel: React.FC<CredentialsPanelProps> = ({ accessToken, upload
</TableRow> </TableRow>
</TableHead> </TableHead>
<TableBody> <TableBody>
{(!credentialsList || credentialsList.length === 0) ? ( {(!credentialList || credentialList.length === 0) ? (
<TableRow> <TableRow>
<TableCell colSpan={4} className="text-center py-4 text-gray-500"> <TableCell colSpan={4} className="text-center py-4 text-gray-500">
No credentials configured No credentials configured
</TableCell> </TableCell>
</TableRow> </TableRow>
) : ( ) : (
credentialsList.map((credential: CredentialItem, index: number) => ( credentialList.map((credential: CredentialItem, index: number) => (
<TableRow key={index}> <TableRow key={index}>
<TableCell>{credential.credential_name}</TableCell> <TableCell>{credential.credential_name}</TableCell>
<TableCell> <TableCell>
{renderProviderBadge(credential.credential_info?.custom_llm_provider as string || '-')} {renderProviderBadge(credential.credential_info?.custom_llm_provider as string || '-')}
</TableCell> </TableCell>
<TableCell>{credential.credential_info?.description || '-'}</TableCell> <TableCell>{credential.credential_info?.description || '-'}</TableCell>
<TableCell>
<Button
icon={PencilAltIcon}
variant="light"
size="sm"
/>
<Button
icon={TrashIcon}
variant="light"
size="sm"
onClick={() => handleDeleteCredential(credential.credential_name)}
/>
</TableCell>
</TableRow> </TableRow>
)) ))
)} )}

View file

@ -16,7 +16,7 @@ import {
AccordionHeader, AccordionHeader,
AccordionBody, AccordionBody,
} from "@tremor/react"; } from "@tremor/react";
import { CredentialItem, credentialListCall, CredentialsResponse } from "./networking";
import ConditionalPublicModelName from "./add_model/conditional_public_model_name"; import ConditionalPublicModelName from "./add_model/conditional_public_model_name";
import LiteLLMModelNameField from "./add_model/litellm_model_name"; import LiteLLMModelNameField from "./add_model/litellm_model_name";
@ -235,6 +235,8 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
const [allEndUsers, setAllEndUsers] = useState<any[]>([]); const [allEndUsers, setAllEndUsers] = useState<any[]>([]);
const [credentialsList, setCredentialsList] = useState<CredentialItem[]>([]);
// Add state for advanced settings visibility // Add state for advanced settings visibility
const [showAdvancedSettings, setShowAdvancedSettings] = useState<boolean>(false); const [showAdvancedSettings, setShowAdvancedSettings] = useState<boolean>(false);
@ -374,6 +376,16 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
} }
}; };
const fetchCredentials = async (accessToken: string) => {
try {
const response: CredentialsResponse = await credentialListCall(accessToken);
console.log(`credentials: ${JSON.stringify(response)}`);
setCredentialsList(response.credentials);
} catch (error) {
console.error('Error fetching credentials:', error);
}
};
useEffect(() => { useEffect(() => {
updateModelMetrics( updateModelMetrics(
@ -1126,6 +1138,7 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
showAdvancedSettings={showAdvancedSettings} showAdvancedSettings={showAdvancedSettings}
setShowAdvancedSettings={setShowAdvancedSettings} setShowAdvancedSettings={setShowAdvancedSettings}
teams={teams} teams={teams}
credentials={credentialsList}
/> />
</TabPanel> </TabPanel>
<TabPanel> <TabPanel>
@ -1484,7 +1497,7 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
</Button> </Button>
</TabPanel> </TabPanel>
<TabPanel> <TabPanel>
<CredentialsPanel accessToken={accessToken} uploadProps={uploadProps}/> <CredentialsPanel accessToken={accessToken} uploadProps={uploadProps} credentialList={credentialsList} fetchCredentials={fetchCredentials} />
</TabPanel> </TabPanel>
</TabPanels> </TabPanels>
</TabGroup> </TabGroup>

View file

@ -36,6 +36,21 @@ export interface Organization {
members: any[] | null; members: any[] | null;
} }
export interface CredentialItem {
credential_name: string;
credential_values: object;
credential_info: {
custom_llm_provider?: string;
description?: string;
required?: boolean;
};
}
export interface CredentialsResponse {
credentials: CredentialItem[];
}
const baseUrl = "/"; // Assuming the base URL is the root const baseUrl = "/"; // Assuming the base URL is the root
@ -2606,6 +2621,34 @@ export const credentialListCall = async (
} }
}; };
export const credentialDeleteCall = async (accessToken: String, credentialName: String) => {
try {
const url = proxyBaseUrl ? `${proxyBaseUrl}/credentials/${credentialName}` : `/credentials/${credentialName}`;
console.log("in credentialDeleteCall:", credentialName);
const response = await fetch(url, {
method: "DELETE",
headers: {
[globalLitellmHeaderName]: `Bearer ${accessToken}`,
"Content-Type": "application/json",
},
});
if (!response.ok) {
const errorData = await response.text();
handleError(errorData);
throw new Error("Network response was not ok");
}
const data = await response.json();
console.log(data);
return data;
// Handle success - you might want to update some state or UI based on the created key
} catch (error) {
console.error("Failed to delete key:", error);
throw error;
}
};
export const keyUpdateCall = async ( export const keyUpdateCall = async (
accessToken: string, accessToken: string,
formValues: Record<string, any> // Assuming formValues is an object formValues: Record<string, any> // Assuming formValues is an object