feat(endpoints.py): support adding credentials by model id

Allows user to reuse existing model credentials
This commit is contained in:
Krrish Dholakia 2025-03-14 12:32:32 -07:00
parent 605a4d1121
commit f089b1e23f
5 changed files with 99 additions and 47 deletions

View file

@ -2,7 +2,7 @@
CRUD endpoints for storing reusable credentials.
"""
from typing import Optional
from typing import Optional, Union
from fastapi import APIRouter, Depends, HTTPException, Request, Response
@ -14,8 +14,7 @@ from litellm.proxy._types import CommonProxyErrors, UserAPIKeyAuth
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.common_utils.encrypt_decrypt_utils import encrypt_value_helper
from litellm.proxy.utils import handle_exception_on_proxy, jsonify_object
from litellm.types.router import CredentialLiteLLMParams
from litellm.types.utils import CredentialItem
from litellm.types.utils import CreateCredentialItem, CredentialItem
router = APIRouter()
@ -39,7 +38,7 @@ class CredentialHelperUtils:
async def create_credential(
request: Request,
fastapi_response: Response,
credential: CredentialItem,
credential: CreateCredentialItem,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
@ -47,7 +46,7 @@ async def create_credential(
Stores credential in DB.
Reloads credentials in memory.
"""
from litellm.proxy.proxy_server import prisma_client
from litellm.proxy.proxy_server import llm_router, prisma_client
try:
if prisma_client is None:
@ -55,9 +54,35 @@ async def create_credential(
status_code=500,
detail={"error": CommonProxyErrors.db_not_connected_error.value},
)
if credential.model_id:
if llm_router is None:
raise HTTPException(
status_code=500,
detail="LLM router not found. Please ensure you have a valid router instance.",
)
# get model from router
model = llm_router.get_deployment(credential.model_id)
if model is None:
raise HTTPException(status_code=404, detail="Model not found")
credential_values = llm_router.get_deployment_credentials(
credential.model_id
)
if credential_values is None:
raise HTTPException(status_code=404, detail="Model not found")
credential.credential_values = credential_values
if credential.credential_values is None:
raise HTTPException(
status_code=400,
detail="Credential values are required. Unable to infer credential values from model ID.",
)
processed_credential = CredentialItem(
credential_name=credential.credential_name,
credential_values=credential.credential_values,
credential_info=credential.credential_info,
)
encrypted_credential = CredentialHelperUtils.encrypt_credential_values(
credential
processed_credential
)
credentials_dict = encrypted_credential.model_dump()
credentials_dict_jsonified = jsonify_object(credentials_dict)
@ -70,7 +95,7 @@ async def create_credential(
)
## ADD TO LITELLM ##
CredentialAccessor.upsert_credentials([credential])
CredentialAccessor.upsert_credentials([processed_credential])
return {"success": True, "message": "Credential created successfully"}
except Exception as e:
@ -95,6 +120,7 @@ async def get_credentials(
masked_credentials = [
{
"credential_name": credential.credential_name,
"credential_values": _get_masked_values(credential.credential_values),
"credential_info": credential.credential_info,
}
for credential in litellm.credential_list
@ -132,23 +158,20 @@ async def get_credential(
if model_id:
if llm_router is None:
raise HTTPException(status_code=500, detail="LLM router not found")
# get model from router
model = llm_router.get_deployment(model_id)
if model is None:
raise HTTPException(
status_code=404, detail="Model not found. Got model ID: " + model_id
)
# get credential object from model
credential_values = _get_masked_values(
CredentialLiteLLMParams(**model.litellm_params.model_dump()).model_dump(
exclude_none=True
),
raise HTTPException(status_code=404, detail="Model not found")
credential_values = llm_router.get_deployment_credentials(model_id)
if credential_values is None:
raise HTTPException(status_code=404, detail="Model not found")
masked_credential_values = _get_masked_values(
credential_values,
unmasked_length=4,
number_of_asterisks=4,
)
credential = CredentialItem(
credential_name="{}-credential-{}".format(model.model_name, model_id),
credential_values=credential_values,
credential_values=masked_credential_values,
credential_info={},
)
# return credential object
@ -159,7 +182,9 @@ async def get_credential(
masked_credential = CredentialItem(
credential_name=credential.credential_name,
credential_values=_get_masked_values(
credential.credential_values
credential.credential_values,
unmasked_length=4,
number_of_asterisks=4,
),
credential_info=credential.credential_info,
)