fix(google_kms.py): support enums for key management system

This commit is contained in:
Krrish Dholakia 2023-12-27 13:19:19 +05:30
parent 4cc59d21d0
commit 9ba520cc8b
6 changed files with 75 additions and 16 deletions

View file

@ -3,6 +3,7 @@ import threading, requests
from typing import Callable, List, Optional, Dict, Union, Any from typing import Callable, List, Optional, Dict, Union, Any
from litellm.caching import Cache from litellm.caching import Cache
from litellm._logging import set_verbose from litellm._logging import set_verbose
from litellm.proxy._types import KeyManagementSystem
import httpx import httpx
input_callback: List[Union[str, Callable]] = [] input_callback: List[Union[str, Callable]] = []
@ -144,6 +145,7 @@ secret_manager_client: Optional[
Any Any
] = None # list of instantiated key management clients - e.g. azure kv, infisical, etc. ] = None # list of instantiated key management clients - e.g. azure kv, infisical, etc.
_google_kms_resource_name: Optional[str] = None _google_kms_resource_name: Optional[str] = None
_key_management_system: Optional[KeyManagementSystem] = None
############################################# #############################################

View file

@ -1,4 +1,5 @@
from pydantic import BaseModel, Extra, Field, root_validator from pydantic import BaseModel, Extra, Field, root_validator
import enum
from typing import Optional, List, Union, Dict, Literal from typing import Optional, List, Union, Dict, Literal
from datetime import datetime from datetime import datetime
import uuid, json import uuid, json
@ -175,6 +176,12 @@ class NewUserResponse(GenerateKeyResponse):
max_budget: Optional[float] = None max_budget: Optional[float] = None
class KeyManagementSystem(enum.Enum):
GOOGLE_KMS = "google_kms"
AZURE_KEY_VAULT = "azure_key_vault"
LOCAL = "local"
class ConfigGeneralSettings(LiteLLMBase): class ConfigGeneralSettings(LiteLLMBase):
""" """
Documents all the fields supported by `general_settings` in config.yaml Documents all the fields supported by `general_settings` in config.yaml
@ -183,6 +190,12 @@ class ConfigGeneralSettings(LiteLLMBase):
completion_model: Optional[str] = Field( completion_model: Optional[str] = Field(
None, description="proxy level default model for all chat completion calls" None, description="proxy level default model for all chat completion calls"
) )
key_management_system: Optional[KeyManagementSystem] = Field(
None, description="key manager to load keys from / decrypt keys with"
)
use_google_kms: Optional[bool] = Field(
None, description="decrypt keys with google kms"
)
use_azure_key_vault: Optional[bool] = Field( use_azure_key_vault: Optional[bool] = Field(
None, description="load keys from azure key vault" None, description="load keys from azure key vault"
) )

View file

@ -415,6 +415,7 @@ def load_from_azure_key_vault(use_azure_key_vault: bool = False):
client = SecretClient(vault_url=KVUri, credential=credential) client = SecretClient(vault_url=KVUri, credential=credential)
litellm.secret_manager_client = client litellm.secret_manager_client = client
litellm._key_management_system = KeyManagementSystem.AZURE_KEY_VAULT
else: else:
raise Exception( raise Exception(
f"Missing KVUri or client_id or client_secret or tenant_id from environment" f"Missing KVUri or client_id or client_secret or tenant_id from environment"
@ -691,10 +692,21 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
if general_settings is None: if general_settings is None:
general_settings = {} general_settings = {}
if general_settings: if general_settings:
### LOAD FROM GOOGLE KMS ### ### LOAD SECRET MANAGER ###
key_management_system = general_settings.get("key_management_system", None)
if key_management_system is not None:
if key_management_system == KeyManagementSystem.AZURE_KEY_VAULT.value:
### LOAD FROM AZURE KEY VAULT ###
load_from_azure_key_vault(use_azure_key_vault=True)
elif key_management_system == KeyManagementSystem.GOOGLE_KMS.value:
### LOAD FROM GOOGLE KMS ###
load_google_kms(use_google_kms=True)
else:
raise ValueError("Invalid Key Management System selected")
### [DEPRECATED] LOAD FROM GOOGLE KMS ###
use_google_kms = general_settings.get("use_google_kms", False) use_google_kms = general_settings.get("use_google_kms", False)
load_google_kms(use_google_kms=use_google_kms) load_google_kms(use_google_kms=use_google_kms)
### LOAD FROM AZURE KEY VAULT ### ### [DEPRECATED] LOAD FROM AZURE KEY VAULT ###
use_azure_key_vault = general_settings.get("use_azure_key_vault", False) use_azure_key_vault = general_settings.get("use_azure_key_vault", False)
load_from_azure_key_vault(use_azure_key_vault=use_azure_key_vault) load_from_azure_key_vault(use_azure_key_vault=use_azure_key_vault)
### CONNECT TO DATABASE ### ### CONNECT TO DATABASE ###

View file

@ -9,6 +9,7 @@ Requires:
""" """
import litellm, os import litellm, os
from typing import Optional from typing import Optional
from litellm.proxy._types import KeyManagementSystem
def validate_environment(): def validate_environment():
@ -25,12 +26,15 @@ def validate_environment():
def load_google_kms(use_google_kms: Optional[bool]): def load_google_kms(use_google_kms: Optional[bool]):
if use_google_kms is None or use_google_kms == False: if use_google_kms is None or use_google_kms == False:
return return
try:
from google.cloud import kms_v1 # type: ignore
from google.cloud import kms_v1 # type: ignore validate_environment()
validate_environment() # Create the KMS client
client = kms_v1.KeyManagementServiceClient()
# Create the KMS client litellm.secret_manager_client = client
client = kms_v1.KeyManagementServiceClient() litellm._key_management_system = KeyManagementSystem.GOOGLE_KMS
litellm.secret_manager_client = client litellm._google_kms_resource_name = os.getenv("GOOGLE_KMS_RESOURCE_NAME")
litellm._google_kms_resource_name = os.getenv("GOOGLE_KMS_RESOURCE_NAME") except Exception as e:
raise e

View file

@ -7,7 +7,7 @@
# #
# Thank you users! We ❤️ you! - Krrish & Ishaan # Thank you users! We ❤️ you! - Krrish & Ishaan
import sys, re import sys, re, binascii
import litellm import litellm
import dotenv, json, traceback, threading, base64 import dotenv, json, traceback, threading, base64
import subprocess, os import subprocess, os
@ -43,6 +43,7 @@ from .integrations.custom_logger import CustomLogger
from .integrations.langfuse import LangFuseLogger from .integrations.langfuse import LangFuseLogger
from .integrations.dynamodb import DyanmoDBLogger from .integrations.dynamodb import DyanmoDBLogger
from .integrations.litedebugger import LiteDebugger from .integrations.litedebugger import LiteDebugger
from .proxy._types import KeyManagementSystem
from openai import OpenAIError as OriginalError from openai import OpenAIError as OriginalError
from openai._models import BaseModel as OpenAIObject from openai._models import BaseModel as OpenAIObject
from .exceptions import ( from .exceptions import (
@ -59,7 +60,7 @@ from .exceptions import (
BudgetExceededError, BudgetExceededError,
UnprocessableEntityError, UnprocessableEntityError,
) )
from typing import cast, List, Dict, Union, Optional, Literal from typing import cast, List, Dict, Union, Optional, Literal, Any
from .caching import Cache from .caching import Cache
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
@ -6331,24 +6332,45 @@ def litellm_telemetry(data):
######### Secret Manager ############################ ######### Secret Manager ############################
# checks if user has passed in a secret manager client # checks if user has passed in a secret manager client
# if passed in then checks the secret there # if passed in then checks the secret there
def get_secret(secret_name: str, default_value: Optional[str] = None): def _is_base64(s):
try:
return base64.b64encode(base64.b64decode(s)).decode() == s
except binascii.Error:
return False
def get_secret(
secret_name: str,
default_value: Optional[str] = None,
):
key_management_system = litellm._key_management_system
if secret_name.startswith("os.environ/"): if secret_name.startswith("os.environ/"):
secret_name = secret_name.replace("os.environ/", "") secret_name = secret_name.replace("os.environ/", "")
try: try:
if litellm.secret_manager_client is not None: if litellm.secret_manager_client is not None:
try: try:
client = litellm.secret_manager_client client = litellm.secret_manager_client
key_manager = "local"
if key_management_system is not None:
key_manager = key_management_system.value
if ( if (
type(client).__module__ + "." + type(client).__name__ key_manager == KeyManagementSystem.AZURE_KEY_VAULT
or type(client).__module__ + "." + type(client).__name__
== "azure.keyvault.secrets._client.SecretClient" == "azure.keyvault.secrets._client.SecretClient"
): # support Azure Secret Client - from azure.keyvault.secrets import SecretClient ): # support Azure Secret Client - from azure.keyvault.secrets import SecretClient
secret = retrieved_secret = client.get_secret(secret_name).value secret = client.get_secret(secret_name).value
elif client.__class__.__name__ == "KeyManagementServiceClient": elif (
encrypted_secret = os.getenv(secret_name) key_manager == KeyManagementSystem.GOOGLE_KMS
or client.__class__.__name__ == "KeyManagementServiceClient"
):
encrypted_secret: Any = os.getenv(secret_name)
if encrypted_secret is None: if encrypted_secret is None:
raise ValueError( raise ValueError(
f"Google KMS requires the encrypted secret to be in the environment!" f"Google KMS requires the encrypted secret to be in the environment!"
) )
b64_flag = _is_base64(encrypted_secret)
if b64_flag == True: # if passed in as encoded b64 string
encrypted_secret = base64.b64decode(encrypted_secret)
if not isinstance(encrypted_secret, bytes): if not isinstance(encrypted_secret, bytes):
# If it's not, assume it's a string and encode it to bytes # If it's not, assume it's a string and encode it to bytes
ciphertext = eval( ciphertext = eval(

6
mypy.ini Normal file
View file

@ -0,0 +1,6 @@
[mypy]
warn_return_any = False
ignore_missing_imports = False
[mypy-google.*]
ignore_missing_imports = True