feat: working e2e credential management - support reusing existing credentials

This commit is contained in:
Krrish Dholakia 2025-03-10 19:29:24 -07:00
parent 2ec7830b66
commit f56c5ca380
5 changed files with 79 additions and 17 deletions

View file

@ -1,6 +1,11 @@
"""Utils for accessing credentials.""" """Utils for accessing credentials."""
from typing import List, Union
from pydantic import BaseModel
import litellm import litellm
from litellm.types.utils import CredentialItem
class CredentialAccessor: class CredentialAccessor:
@ -13,3 +18,17 @@ class CredentialAccessor:
if credential.credential_name == credential_name: if credential.credential_name == credential_name:
return credential.credential_values.copy() return credential.credential_values.copy()
return {} return {}
@staticmethod
def upsert_credentials(credentials: List[CredentialItem]):
"""Add a credential to the list of credentials."""
for credential in credentials:
if credential.credential_name in litellm.credential_list:
# Find and replace the existing credential in the list
for i, existing_cred in enumerate(litellm.credential_list):
if existing_cred.credential_name == credential.credential_name:
litellm.credential_list[i] = credential
break
else:
litellm.credential_list.append(credential)

View file

@ -20,6 +20,8 @@ from litellm.types.utils import CredentialItem
router = APIRouter() router = APIRouter()
class CredentialHelperUtils:
@staticmethod
def encrypt_credential_values(credential: CredentialItem) -> CredentialItem: def encrypt_credential_values(credential: CredentialItem) -> CredentialItem:
"""Encrypt values in credential.credential_values and add to DB""" """Encrypt values in credential.credential_values and add to DB"""
encrypted_credential_values = {} encrypted_credential_values = {}
@ -30,7 +32,7 @@ def encrypt_credential_values(credential: CredentialItem) -> CredentialItem:
@router.post( @router.post(
"/v1/credentials", "/credentials",
dependencies=[Depends(user_api_key_auth)], dependencies=[Depends(user_api_key_auth)],
tags=["credential management"], tags=["credential management"],
) )
@ -53,7 +55,7 @@ async def create_credential(
detail={"error": CommonProxyErrors.db_not_connected_error.value}, detail={"error": CommonProxyErrors.db_not_connected_error.value},
) )
credential = encrypt_credential_values(credential) credential = CredentialHelperUtils.encrypt_credential_values(credential)
credentials_dict = credential.model_dump() credentials_dict = credential.model_dump()
credentials_dict_jsonified = jsonify_object(credentials_dict) credentials_dict_jsonified = jsonify_object(credentials_dict)
await prisma_client.db.litellm_credentialstable.create( await prisma_client.db.litellm_credentialstable.create(
@ -71,7 +73,7 @@ async def create_credential(
@router.get( @router.get(
"/v1/credentials", "/credentials",
dependencies=[Depends(user_api_key_auth)], dependencies=[Depends(user_api_key_auth)],
tags=["credential management"], tags=["credential management"],
) )
@ -87,7 +89,7 @@ async def get_credentials(
@router.get( @router.get(
"/v1/credentials/{credential_name}", "/credentials/{credential_name}",
dependencies=[Depends(user_api_key_auth)], dependencies=[Depends(user_api_key_auth)],
tags=["credential management"], tags=["credential management"],
) )
@ -107,7 +109,7 @@ async def get_credential(
@router.delete( @router.delete(
"/v1/credentials/{credential_name}", "/credentials/{credential_name}",
dependencies=[Depends(user_api_key_auth)], dependencies=[Depends(user_api_key_auth)],
tags=["credential management"], tags=["credential management"],
) )
@ -134,7 +136,7 @@ async def delete_credential(
@router.put( @router.put(
"/v1/credentials/{credential_name}", "/credentials/{credential_name}",
dependencies=[Depends(user_api_key_auth)], dependencies=[Depends(user_api_key_auth)],
tags=["credential management"], tags=["credential management"],
) )

View file

@ -114,6 +114,7 @@ from litellm.litellm_core_utils.core_helpers import (
_get_parent_otel_span_from_kwargs, _get_parent_otel_span_from_kwargs,
get_litellm_metadata_from_kwargs, get_litellm_metadata_from_kwargs,
) )
from litellm.litellm_core_utils.credential_accessor import CredentialAccessor
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.proxy._types import * from litellm.proxy._types import *
@ -2187,7 +2188,11 @@ class ProxyConfig:
) )
## CREDENTIALS ## CREDENTIALS
litellm.credential_list = config.get("credential_list") credential_list_dict = config.get("credential_list")
if credential_list_dict:
litellm.credential_list = [
CredentialItem(**cred) for cred in credential_list_dict
]
return router, router.get_model_list(), general_settings return router, router.get_model_list(), general_settings
def _load_alerting_settings(self, general_settings: dict): def _load_alerting_settings(self, general_settings: dict):
@ -2834,6 +2839,32 @@ class ProxyConfig:
) )
) )
def decrypt_credentials(self, credential: Union[dict, BaseModel]) -> CredentialItem:
if isinstance(credential, dict):
credential_object = CredentialItem(**credential)
elif isinstance(credential, BaseModel):
credential_object = CredentialItem(**credential.model_dump())
decrypted_credential_values = {}
for k, v in credential_object.credential_values.items():
decrypted_credential_values[k] = decrypt_value_helper(v) or v
credential_object.credential_values = decrypted_credential_values
return credential_object
async def get_credentials(self, prisma_client: PrismaClient):
try:
credentials = await prisma_client.db.litellm_credentialstable.find_many()
credentials = [self.decrypt_credentials(cred) for cred in credentials]
CredentialAccessor.upsert_credentials(credentials)
except Exception as e:
verbose_proxy_logger.exception(
"litellm.proxy_server.py::get_credentials() - Error getting credentials from DB - {}".format(
str(e)
)
)
return []
proxy_config = ProxyConfig() proxy_config = ProxyConfig()
@ -3255,6 +3286,14 @@ class ProxyStartupEvent:
prisma_client=prisma_client, proxy_logging_obj=proxy_logging_obj prisma_client=prisma_client, proxy_logging_obj=proxy_logging_obj
) )
### GET STORED CREDENTIALS ###
scheduler.add_job(
proxy_config.get_credentials,
"interval",
seconds=10,
args=[prisma_client],
)
await proxy_config.get_credentials(prisma_client=prisma_client)
if ( if (
proxy_logging_obj is not None proxy_logging_obj is not None
and proxy_logging_obj.slack_alerting_instance.alerting is not None and proxy_logging_obj.slack_alerting_instance.alerting is not None

View file

@ -955,6 +955,7 @@ class Router:
specific_deployment=kwargs.pop("specific_deployment", None), specific_deployment=kwargs.pop("specific_deployment", None),
request_kwargs=kwargs, request_kwargs=kwargs,
) )
_timeout_debug_deployment_dict = deployment _timeout_debug_deployment_dict = deployment
end_time = time.time() end_time = time.time()
_duration = end_time - start_time _duration = end_time - start_time

View file

@ -499,9 +499,6 @@ def function_setup( # noqa: PLR0915
## GET APPLIED GUARDRAILS ## GET APPLIED GUARDRAILS
applied_guardrails = get_applied_guardrails(kwargs) applied_guardrails = get_applied_guardrails(kwargs)
## LOAD CREDENTIALS
load_credentials_from_list(kwargs)
## LOGGING SETUP ## LOGGING SETUP
function_id: Optional[str] = kwargs["id"] if "id" in kwargs else None function_id: Optional[str] = kwargs["id"] if "id" in kwargs else None
@ -1000,6 +997,8 @@ def client(original_function): # noqa: PLR0915
logging_obj, kwargs = function_setup( logging_obj, kwargs = function_setup(
original_function.__name__, rules_obj, start_time, *args, **kwargs original_function.__name__, rules_obj, start_time, *args, **kwargs
) )
## LOAD CREDENTIALS
load_credentials_from_list(kwargs)
kwargs["litellm_logging_obj"] = logging_obj kwargs["litellm_logging_obj"] = logging_obj
_llm_caching_handler: LLMCachingHandler = LLMCachingHandler( _llm_caching_handler: LLMCachingHandler = LLMCachingHandler(
original_function=original_function, original_function=original_function,
@ -1256,6 +1255,8 @@ def client(original_function): # noqa: PLR0915
original_function.__name__, rules_obj, start_time, *args, **kwargs original_function.__name__, rules_obj, start_time, *args, **kwargs
) )
kwargs["litellm_logging_obj"] = logging_obj kwargs["litellm_logging_obj"] = logging_obj
## LOAD CREDENTIALS
load_credentials_from_list(kwargs)
logging_obj._llm_caching_handler = _llm_caching_handler logging_obj._llm_caching_handler = _llm_caching_handler
# [OPTIONAL] CHECK BUDGET # [OPTIONAL] CHECK BUDGET
if litellm.max_budget: if litellm.max_budget: