mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
feat: working e2e credential management - support reusing existing credentials
This commit is contained in:
parent
2ec7830b66
commit
f56c5ca380
5 changed files with 79 additions and 17 deletions
|
@ -1,6 +1,11 @@
|
||||||
"""Utils for accessing credentials."""
|
"""Utils for accessing credentials."""
|
||||||
|
|
||||||
|
from typing import List, Union
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
|
from litellm.types.utils import CredentialItem
|
||||||
|
|
||||||
|
|
||||||
class CredentialAccessor:
|
class CredentialAccessor:
|
||||||
|
@ -13,3 +18,17 @@ class CredentialAccessor:
|
||||||
if credential.credential_name == credential_name:
|
if credential.credential_name == credential_name:
|
||||||
return credential.credential_values.copy()
|
return credential.credential_values.copy()
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def upsert_credentials(credentials: List[CredentialItem]):
|
||||||
|
"""Add a credential to the list of credentials."""
|
||||||
|
|
||||||
|
for credential in credentials:
|
||||||
|
if credential.credential_name in litellm.credential_list:
|
||||||
|
# Find and replace the existing credential in the list
|
||||||
|
for i, existing_cred in enumerate(litellm.credential_list):
|
||||||
|
if existing_cred.credential_name == credential.credential_name:
|
||||||
|
litellm.credential_list[i] = credential
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
litellm.credential_list.append(credential)
|
||||||
|
|
|
@ -20,17 +20,19 @@ from litellm.types.utils import CredentialItem
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
def encrypt_credential_values(credential: CredentialItem) -> CredentialItem:
|
class CredentialHelperUtils:
|
||||||
"""Encrypt values in credential.credential_values and add to DB"""
|
@staticmethod
|
||||||
encrypted_credential_values = {}
|
def encrypt_credential_values(credential: CredentialItem) -> CredentialItem:
|
||||||
for key, value in credential.credential_values.items():
|
"""Encrypt values in credential.credential_values and add to DB"""
|
||||||
encrypted_credential_values[key] = encrypt_value_helper(value)
|
encrypted_credential_values = {}
|
||||||
credential.credential_values = encrypted_credential_values
|
for key, value in credential.credential_values.items():
|
||||||
return credential
|
encrypted_credential_values[key] = encrypt_value_helper(value)
|
||||||
|
credential.credential_values = encrypted_credential_values
|
||||||
|
return credential
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
"/v1/credentials",
|
"/credentials",
|
||||||
dependencies=[Depends(user_api_key_auth)],
|
dependencies=[Depends(user_api_key_auth)],
|
||||||
tags=["credential management"],
|
tags=["credential management"],
|
||||||
)
|
)
|
||||||
|
@ -53,7 +55,7 @@ async def create_credential(
|
||||||
detail={"error": CommonProxyErrors.db_not_connected_error.value},
|
detail={"error": CommonProxyErrors.db_not_connected_error.value},
|
||||||
)
|
)
|
||||||
|
|
||||||
credential = encrypt_credential_values(credential)
|
credential = CredentialHelperUtils.encrypt_credential_values(credential)
|
||||||
credentials_dict = credential.model_dump()
|
credentials_dict = credential.model_dump()
|
||||||
credentials_dict_jsonified = jsonify_object(credentials_dict)
|
credentials_dict_jsonified = jsonify_object(credentials_dict)
|
||||||
await prisma_client.db.litellm_credentialstable.create(
|
await prisma_client.db.litellm_credentialstable.create(
|
||||||
|
@ -71,7 +73,7 @@ async def create_credential(
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/v1/credentials",
|
"/credentials",
|
||||||
dependencies=[Depends(user_api_key_auth)],
|
dependencies=[Depends(user_api_key_auth)],
|
||||||
tags=["credential management"],
|
tags=["credential management"],
|
||||||
)
|
)
|
||||||
|
@ -87,7 +89,7 @@ async def get_credentials(
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/v1/credentials/{credential_name}",
|
"/credentials/{credential_name}",
|
||||||
dependencies=[Depends(user_api_key_auth)],
|
dependencies=[Depends(user_api_key_auth)],
|
||||||
tags=["credential management"],
|
tags=["credential management"],
|
||||||
)
|
)
|
||||||
|
@ -107,7 +109,7 @@ async def get_credential(
|
||||||
|
|
||||||
|
|
||||||
@router.delete(
|
@router.delete(
|
||||||
"/v1/credentials/{credential_name}",
|
"/credentials/{credential_name}",
|
||||||
dependencies=[Depends(user_api_key_auth)],
|
dependencies=[Depends(user_api_key_auth)],
|
||||||
tags=["credential management"],
|
tags=["credential management"],
|
||||||
)
|
)
|
||||||
|
@ -134,7 +136,7 @@ async def delete_credential(
|
||||||
|
|
||||||
|
|
||||||
@router.put(
|
@router.put(
|
||||||
"/v1/credentials/{credential_name}",
|
"/credentials/{credential_name}",
|
||||||
dependencies=[Depends(user_api_key_auth)],
|
dependencies=[Depends(user_api_key_auth)],
|
||||||
tags=["credential management"],
|
tags=["credential management"],
|
||||||
)
|
)
|
||||||
|
|
|
@ -114,6 +114,7 @@ from litellm.litellm_core_utils.core_helpers import (
|
||||||
_get_parent_otel_span_from_kwargs,
|
_get_parent_otel_span_from_kwargs,
|
||||||
get_litellm_metadata_from_kwargs,
|
get_litellm_metadata_from_kwargs,
|
||||||
)
|
)
|
||||||
|
from litellm.litellm_core_utils.credential_accessor import CredentialAccessor
|
||||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||||
from litellm.proxy._types import *
|
from litellm.proxy._types import *
|
||||||
|
@ -2187,7 +2188,11 @@ class ProxyConfig:
|
||||||
)
|
)
|
||||||
|
|
||||||
## CREDENTIALS
|
## CREDENTIALS
|
||||||
litellm.credential_list = config.get("credential_list")
|
credential_list_dict = config.get("credential_list")
|
||||||
|
if credential_list_dict:
|
||||||
|
litellm.credential_list = [
|
||||||
|
CredentialItem(**cred) for cred in credential_list_dict
|
||||||
|
]
|
||||||
return router, router.get_model_list(), general_settings
|
return router, router.get_model_list(), general_settings
|
||||||
|
|
||||||
def _load_alerting_settings(self, general_settings: dict):
|
def _load_alerting_settings(self, general_settings: dict):
|
||||||
|
@ -2834,6 +2839,32 @@ class ProxyConfig:
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def decrypt_credentials(self, credential: Union[dict, BaseModel]) -> CredentialItem:
|
||||||
|
if isinstance(credential, dict):
|
||||||
|
credential_object = CredentialItem(**credential)
|
||||||
|
elif isinstance(credential, BaseModel):
|
||||||
|
credential_object = CredentialItem(**credential.model_dump())
|
||||||
|
|
||||||
|
decrypted_credential_values = {}
|
||||||
|
for k, v in credential_object.credential_values.items():
|
||||||
|
decrypted_credential_values[k] = decrypt_value_helper(v) or v
|
||||||
|
|
||||||
|
credential_object.credential_values = decrypted_credential_values
|
||||||
|
return credential_object
|
||||||
|
|
||||||
|
async def get_credentials(self, prisma_client: PrismaClient):
|
||||||
|
try:
|
||||||
|
credentials = await prisma_client.db.litellm_credentialstable.find_many()
|
||||||
|
credentials = [self.decrypt_credentials(cred) for cred in credentials]
|
||||||
|
CredentialAccessor.upsert_credentials(credentials)
|
||||||
|
except Exception as e:
|
||||||
|
verbose_proxy_logger.exception(
|
||||||
|
"litellm.proxy_server.py::get_credentials() - Error getting credentials from DB - {}".format(
|
||||||
|
str(e)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
proxy_config = ProxyConfig()
|
proxy_config = ProxyConfig()
|
||||||
|
|
||||||
|
@ -3255,6 +3286,14 @@ class ProxyStartupEvent:
|
||||||
prisma_client=prisma_client, proxy_logging_obj=proxy_logging_obj
|
prisma_client=prisma_client, proxy_logging_obj=proxy_logging_obj
|
||||||
)
|
)
|
||||||
|
|
||||||
|
### GET STORED CREDENTIALS ###
|
||||||
|
scheduler.add_job(
|
||||||
|
proxy_config.get_credentials,
|
||||||
|
"interval",
|
||||||
|
seconds=10,
|
||||||
|
args=[prisma_client],
|
||||||
|
)
|
||||||
|
await proxy_config.get_credentials(prisma_client=prisma_client)
|
||||||
if (
|
if (
|
||||||
proxy_logging_obj is not None
|
proxy_logging_obj is not None
|
||||||
and proxy_logging_obj.slack_alerting_instance.alerting is not None
|
and proxy_logging_obj.slack_alerting_instance.alerting is not None
|
||||||
|
|
|
@ -955,6 +955,7 @@ class Router:
|
||||||
specific_deployment=kwargs.pop("specific_deployment", None),
|
specific_deployment=kwargs.pop("specific_deployment", None),
|
||||||
request_kwargs=kwargs,
|
request_kwargs=kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
_timeout_debug_deployment_dict = deployment
|
_timeout_debug_deployment_dict = deployment
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
_duration = end_time - start_time
|
_duration = end_time - start_time
|
||||||
|
|
|
@ -499,9 +499,6 @@ def function_setup( # noqa: PLR0915
|
||||||
## GET APPLIED GUARDRAILS
|
## GET APPLIED GUARDRAILS
|
||||||
applied_guardrails = get_applied_guardrails(kwargs)
|
applied_guardrails = get_applied_guardrails(kwargs)
|
||||||
|
|
||||||
## LOAD CREDENTIALS
|
|
||||||
load_credentials_from_list(kwargs)
|
|
||||||
|
|
||||||
## LOGGING SETUP
|
## LOGGING SETUP
|
||||||
function_id: Optional[str] = kwargs["id"] if "id" in kwargs else None
|
function_id: Optional[str] = kwargs["id"] if "id" in kwargs else None
|
||||||
|
|
||||||
|
@ -1000,6 +997,8 @@ def client(original_function): # noqa: PLR0915
|
||||||
logging_obj, kwargs = function_setup(
|
logging_obj, kwargs = function_setup(
|
||||||
original_function.__name__, rules_obj, start_time, *args, **kwargs
|
original_function.__name__, rules_obj, start_time, *args, **kwargs
|
||||||
)
|
)
|
||||||
|
## LOAD CREDENTIALS
|
||||||
|
load_credentials_from_list(kwargs)
|
||||||
kwargs["litellm_logging_obj"] = logging_obj
|
kwargs["litellm_logging_obj"] = logging_obj
|
||||||
_llm_caching_handler: LLMCachingHandler = LLMCachingHandler(
|
_llm_caching_handler: LLMCachingHandler = LLMCachingHandler(
|
||||||
original_function=original_function,
|
original_function=original_function,
|
||||||
|
@ -1256,6 +1255,8 @@ def client(original_function): # noqa: PLR0915
|
||||||
original_function.__name__, rules_obj, start_time, *args, **kwargs
|
original_function.__name__, rules_obj, start_time, *args, **kwargs
|
||||||
)
|
)
|
||||||
kwargs["litellm_logging_obj"] = logging_obj
|
kwargs["litellm_logging_obj"] = logging_obj
|
||||||
|
## LOAD CREDENTIALS
|
||||||
|
load_credentials_from_list(kwargs)
|
||||||
logging_obj._llm_caching_handler = _llm_caching_handler
|
logging_obj._llm_caching_handler = _llm_caching_handler
|
||||||
# [OPTIONAL] CHECK BUDGET
|
# [OPTIONAL] CHECK BUDGET
|
||||||
if litellm.max_budget:
|
if litellm.max_budget:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue