mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
feat: working e2e credential management - support reusing existing credentials
This commit is contained in:
parent
2ec7830b66
commit
f56c5ca380
5 changed files with 79 additions and 17 deletions
|
@ -1,6 +1,11 @@
|
|||
"""Utils for accessing credentials."""
|
||||
|
||||
from typing import List, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
import litellm
|
||||
from litellm.types.utils import CredentialItem
|
||||
|
||||
|
||||
class CredentialAccessor:
|
||||
|
@ -13,3 +18,17 @@ class CredentialAccessor:
|
|||
if credential.credential_name == credential_name:
|
||||
return credential.credential_values.copy()
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
def upsert_credentials(credentials: List[CredentialItem]):
|
||||
"""Add a credential to the list of credentials."""
|
||||
|
||||
for credential in credentials:
|
||||
if credential.credential_name in litellm.credential_list:
|
||||
# Find and replace the existing credential in the list
|
||||
for i, existing_cred in enumerate(litellm.credential_list):
|
||||
if existing_cred.credential_name == credential.credential_name:
|
||||
litellm.credential_list[i] = credential
|
||||
break
|
||||
else:
|
||||
litellm.credential_list.append(credential)
|
||||
|
|
|
@ -20,7 +20,9 @@ from litellm.types.utils import CredentialItem
|
|||
router = APIRouter()
|
||||
|
||||
|
||||
def encrypt_credential_values(credential: CredentialItem) -> CredentialItem:
|
||||
class CredentialHelperUtils:
|
||||
@staticmethod
|
||||
def encrypt_credential_values(credential: CredentialItem) -> CredentialItem:
|
||||
"""Encrypt values in credential.credential_values and add to DB"""
|
||||
encrypted_credential_values = {}
|
||||
for key, value in credential.credential_values.items():
|
||||
|
@ -30,7 +32,7 @@ def encrypt_credential_values(credential: CredentialItem) -> CredentialItem:
|
|||
|
||||
|
||||
@router.post(
|
||||
"/v1/credentials",
|
||||
"/credentials",
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
tags=["credential management"],
|
||||
)
|
||||
|
@ -53,7 +55,7 @@ async def create_credential(
|
|||
detail={"error": CommonProxyErrors.db_not_connected_error.value},
|
||||
)
|
||||
|
||||
credential = encrypt_credential_values(credential)
|
||||
credential = CredentialHelperUtils.encrypt_credential_values(credential)
|
||||
credentials_dict = credential.model_dump()
|
||||
credentials_dict_jsonified = jsonify_object(credentials_dict)
|
||||
await prisma_client.db.litellm_credentialstable.create(
|
||||
|
@ -71,7 +73,7 @@ async def create_credential(
|
|||
|
||||
|
||||
@router.get(
|
||||
"/v1/credentials",
|
||||
"/credentials",
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
tags=["credential management"],
|
||||
)
|
||||
|
@ -87,7 +89,7 @@ async def get_credentials(
|
|||
|
||||
|
||||
@router.get(
|
||||
"/v1/credentials/{credential_name}",
|
||||
"/credentials/{credential_name}",
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
tags=["credential management"],
|
||||
)
|
||||
|
@ -107,7 +109,7 @@ async def get_credential(
|
|||
|
||||
|
||||
@router.delete(
|
||||
"/v1/credentials/{credential_name}",
|
||||
"/credentials/{credential_name}",
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
tags=["credential management"],
|
||||
)
|
||||
|
@ -134,7 +136,7 @@ async def delete_credential(
|
|||
|
||||
|
||||
@router.put(
|
||||
"/v1/credentials/{credential_name}",
|
||||
"/credentials/{credential_name}",
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
tags=["credential management"],
|
||||
)
|
||||
|
|
|
@ -114,6 +114,7 @@ from litellm.litellm_core_utils.core_helpers import (
|
|||
_get_parent_otel_span_from_kwargs,
|
||||
get_litellm_metadata_from_kwargs,
|
||||
)
|
||||
from litellm.litellm_core_utils.credential_accessor import CredentialAccessor
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
from litellm.proxy._types import *
|
||||
|
@ -2187,7 +2188,11 @@ class ProxyConfig:
|
|||
)
|
||||
|
||||
## CREDENTIALS
|
||||
litellm.credential_list = config.get("credential_list")
|
||||
credential_list_dict = config.get("credential_list")
|
||||
if credential_list_dict:
|
||||
litellm.credential_list = [
|
||||
CredentialItem(**cred) for cred in credential_list_dict
|
||||
]
|
||||
return router, router.get_model_list(), general_settings
|
||||
|
||||
def _load_alerting_settings(self, general_settings: dict):
|
||||
|
@ -2834,6 +2839,32 @@ class ProxyConfig:
|
|||
)
|
||||
)
|
||||
|
||||
def decrypt_credentials(self, credential: Union[dict, BaseModel]) -> CredentialItem:
|
||||
if isinstance(credential, dict):
|
||||
credential_object = CredentialItem(**credential)
|
||||
elif isinstance(credential, BaseModel):
|
||||
credential_object = CredentialItem(**credential.model_dump())
|
||||
|
||||
decrypted_credential_values = {}
|
||||
for k, v in credential_object.credential_values.items():
|
||||
decrypted_credential_values[k] = decrypt_value_helper(v) or v
|
||||
|
||||
credential_object.credential_values = decrypted_credential_values
|
||||
return credential_object
|
||||
|
||||
async def get_credentials(self, prisma_client: PrismaClient):
|
||||
try:
|
||||
credentials = await prisma_client.db.litellm_credentialstable.find_many()
|
||||
credentials = [self.decrypt_credentials(cred) for cred in credentials]
|
||||
CredentialAccessor.upsert_credentials(credentials)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(
|
||||
"litellm.proxy_server.py::get_credentials() - Error getting credentials from DB - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
return []
|
||||
|
||||
|
||||
proxy_config = ProxyConfig()
|
||||
|
||||
|
@ -3255,6 +3286,14 @@ class ProxyStartupEvent:
|
|||
prisma_client=prisma_client, proxy_logging_obj=proxy_logging_obj
|
||||
)
|
||||
|
||||
### GET STORED CREDENTIALS ###
|
||||
scheduler.add_job(
|
||||
proxy_config.get_credentials,
|
||||
"interval",
|
||||
seconds=10,
|
||||
args=[prisma_client],
|
||||
)
|
||||
await proxy_config.get_credentials(prisma_client=prisma_client)
|
||||
if (
|
||||
proxy_logging_obj is not None
|
||||
and proxy_logging_obj.slack_alerting_instance.alerting is not None
|
||||
|
|
|
@ -955,6 +955,7 @@ class Router:
|
|||
specific_deployment=kwargs.pop("specific_deployment", None),
|
||||
request_kwargs=kwargs,
|
||||
)
|
||||
|
||||
_timeout_debug_deployment_dict = deployment
|
||||
end_time = time.time()
|
||||
_duration = end_time - start_time
|
||||
|
|
|
@ -499,9 +499,6 @@ def function_setup( # noqa: PLR0915
|
|||
## GET APPLIED GUARDRAILS
|
||||
applied_guardrails = get_applied_guardrails(kwargs)
|
||||
|
||||
## LOAD CREDENTIALS
|
||||
load_credentials_from_list(kwargs)
|
||||
|
||||
## LOGGING SETUP
|
||||
function_id: Optional[str] = kwargs["id"] if "id" in kwargs else None
|
||||
|
||||
|
@ -1000,6 +997,8 @@ def client(original_function): # noqa: PLR0915
|
|||
logging_obj, kwargs = function_setup(
|
||||
original_function.__name__, rules_obj, start_time, *args, **kwargs
|
||||
)
|
||||
## LOAD CREDENTIALS
|
||||
load_credentials_from_list(kwargs)
|
||||
kwargs["litellm_logging_obj"] = logging_obj
|
||||
_llm_caching_handler: LLMCachingHandler = LLMCachingHandler(
|
||||
original_function=original_function,
|
||||
|
@ -1256,6 +1255,8 @@ def client(original_function): # noqa: PLR0915
|
|||
original_function.__name__, rules_obj, start_time, *args, **kwargs
|
||||
)
|
||||
kwargs["litellm_logging_obj"] = logging_obj
|
||||
## LOAD CREDENTIALS
|
||||
load_credentials_from_list(kwargs)
|
||||
logging_obj._llm_caching_handler = _llm_caching_handler
|
||||
# [OPTIONAL] CHECK BUDGET
|
||||
if litellm.max_budget:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue