fix(google_kms.py): support enums for key management system

This commit is contained in:
Krrish Dholakia 2023-12-27 13:19:19 +05:30
parent 021d7fab65
commit 85549c3d66
6 changed files with 75 additions and 16 deletions

View file

@ -7,7 +7,7 @@
#
# Thank you users! We ❤️ you! - Krrish & Ishaan
import sys, re
import sys, re, binascii
import litellm
import dotenv, json, traceback, threading, base64
import subprocess, os
@ -43,6 +43,7 @@ from .integrations.custom_logger import CustomLogger
from .integrations.langfuse import LangFuseLogger
from .integrations.dynamodb import DyanmoDBLogger
from .integrations.litedebugger import LiteDebugger
from .proxy._types import KeyManagementSystem
from openai import OpenAIError as OriginalError
from openai._models import BaseModel as OpenAIObject
from .exceptions import (
@ -59,7 +60,7 @@ from .exceptions import (
BudgetExceededError,
UnprocessableEntityError,
)
from typing import cast, List, Dict, Union, Optional, Literal
from typing import cast, List, Dict, Union, Optional, Literal, Any
from .caching import Cache
from concurrent.futures import ThreadPoolExecutor
@ -6331,24 +6332,45 @@ def litellm_telemetry(data):
######### Secret Manager ############################
# checks if user has passed in a secret manager client
# if passed in then checks the secret there
def get_secret(secret_name: str, default_value: Optional[str] = None):
def _is_base64(s):
try:
return base64.b64encode(base64.b64decode(s)).decode() == s
except binascii.Error:
return False
def get_secret(
secret_name: str,
default_value: Optional[str] = None,
):
key_management_system = litellm._key_management_system
if secret_name.startswith("os.environ/"):
secret_name = secret_name.replace("os.environ/", "")
try:
if litellm.secret_manager_client is not None:
try:
client = litellm.secret_manager_client
key_manager = "local"
if key_management_system is not None:
key_manager = key_management_system.value
if (
type(client).__module__ + "." + type(client).__name__
key_manager == KeyManagementSystem.AZURE_KEY_VAULT
or type(client).__module__ + "." + type(client).__name__
== "azure.keyvault.secrets._client.SecretClient"
): # support Azure Secret Client - from azure.keyvault.secrets import SecretClient
secret = retrieved_secret = client.get_secret(secret_name).value
elif client.__class__.__name__ == "KeyManagementServiceClient":
encrypted_secret = os.getenv(secret_name)
secret = client.get_secret(secret_name).value
elif (
key_manager == KeyManagementSystem.GOOGLE_KMS
or client.__class__.__name__ == "KeyManagementServiceClient"
):
encrypted_secret: Any = os.getenv(secret_name)
if encrypted_secret is None:
raise ValueError(
f"Google KMS requires the encrypted secret to be in the environment!"
)
b64_flag = _is_base64(encrypted_secret)
if b64_flag == True: # if passed in as encoded b64 string
encrypted_secret = base64.b64decode(encrypted_secret)
if not isinstance(encrypted_secret, bytes):
# If it's not, assume it's a string and encode it to bytes
ciphertext = eval(