From 9ba520cc8b000b8910b503941e4875de6c2fa701 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 27 Dec 2023 13:19:19 +0530 Subject: [PATCH] fix(google_kms.py): support enums for key management system --- litellm/__init__.py | 2 ++ litellm/proxy/_types.py | 13 ++++++++ litellm/proxy/proxy_server.py | 16 +++++++-- litellm/proxy/secret_managers/google_kms.py | 18 +++++++---- litellm/utils.py | 36 +++++++++++++++++---- mypy.ini | 6 ++++ 6 files changed, 75 insertions(+), 16 deletions(-) create mode 100644 mypy.ini diff --git a/litellm/__init__.py b/litellm/__init__.py index 4cf303b35..f1fbbc7d3 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -3,6 +3,7 @@ import threading, requests from typing import Callable, List, Optional, Dict, Union, Any from litellm.caching import Cache from litellm._logging import set_verbose +from litellm.proxy._types import KeyManagementSystem import httpx input_callback: List[Union[str, Callable]] = [] @@ -144,6 +145,7 @@ secret_manager_client: Optional[ Any ] = None # list of instantiated key management clients - e.g. azure kv, infisical, etc. _google_kms_resource_name: Optional[str] = None +_key_management_system: Optional[KeyManagementSystem] = None ############################################# diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 5fe4fd44e..f98862314 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -1,4 +1,5 @@ from pydantic import BaseModel, Extra, Field, root_validator +import enum from typing import Optional, List, Union, Dict, Literal from datetime import datetime import uuid, json @@ -175,6 +176,12 @@ class NewUserResponse(GenerateKeyResponse): max_budget: Optional[float] = None +class KeyManagementSystem(enum.Enum): + GOOGLE_KMS = "google_kms" + AZURE_KEY_VAULT = "azure_key_vault" + LOCAL = "local" + + class ConfigGeneralSettings(LiteLLMBase): """ Documents all the fields supported by `general_settings` in config.yaml @@ -183,6 +190,12 @@ class ConfigGeneralSettings(LiteLLMBase): completion_model: Optional[str] = Field( None, description="proxy level default model for all chat completion calls" ) + key_management_system: Optional[KeyManagementSystem] = Field( + None, description="key manager to load keys from / decrypt keys with" + ) + use_google_kms: Optional[bool] = Field( + None, description="decrypt keys with google kms" + ) use_azure_key_vault: Optional[bool] = Field( None, description="load keys from azure key vault" ) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index c79da79cc..7864371d3 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -415,6 +415,7 @@ def load_from_azure_key_vault(use_azure_key_vault: bool = False): client = SecretClient(vault_url=KVUri, credential=credential) litellm.secret_manager_client = client + litellm._key_management_system = KeyManagementSystem.AZURE_KEY_VAULT else: raise Exception( f"Missing KVUri or client_id or client_secret or tenant_id from environment" @@ -691,10 +692,21 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str): if general_settings is None: general_settings = {} if general_settings: - ### LOAD FROM GOOGLE KMS ### + ### LOAD SECRET MANAGER ### + key_management_system = general_settings.get("key_management_system", None) + if key_management_system is not None: + if key_management_system == KeyManagementSystem.AZURE_KEY_VAULT.value: + ### LOAD FROM AZURE KEY VAULT ### + load_from_azure_key_vault(use_azure_key_vault=True) + elif key_management_system == KeyManagementSystem.GOOGLE_KMS.value: + ### LOAD FROM GOOGLE KMS ### + load_google_kms(use_google_kms=True) + else: + raise ValueError("Invalid Key Management System selected") + ### [DEPRECATED] LOAD FROM GOOGLE KMS ### use_google_kms = general_settings.get("use_google_kms", False) load_google_kms(use_google_kms=use_google_kms) - ### LOAD FROM AZURE KEY VAULT ### + ### [DEPRECATED] LOAD FROM AZURE KEY VAULT ### use_azure_key_vault = general_settings.get("use_azure_key_vault", False) load_from_azure_key_vault(use_azure_key_vault=use_azure_key_vault) ### CONNECT TO DATABASE ### diff --git a/litellm/proxy/secret_managers/google_kms.py b/litellm/proxy/secret_managers/google_kms.py index 4901e40cd..5e83d54c7 100644 --- a/litellm/proxy/secret_managers/google_kms.py +++ b/litellm/proxy/secret_managers/google_kms.py @@ -9,6 +9,7 @@ Requires: """ import litellm, os from typing import Optional +from litellm.proxy._types import KeyManagementSystem def validate_environment(): @@ -25,12 +26,15 @@ def validate_environment(): def load_google_kms(use_google_kms: Optional[bool]): if use_google_kms is None or use_google_kms == False: return + try: + from google.cloud import kms_v1 # type: ignore - from google.cloud import kms_v1 # type: ignore + validate_environment() - validate_environment() - - # Create the KMS client - client = kms_v1.KeyManagementServiceClient() - litellm.secret_manager_client = client - litellm._google_kms_resource_name = os.getenv("GOOGLE_KMS_RESOURCE_NAME") + # Create the KMS client + client = kms_v1.KeyManagementServiceClient() + litellm.secret_manager_client = client + litellm._key_management_system = KeyManagementSystem.GOOGLE_KMS + litellm._google_kms_resource_name = os.getenv("GOOGLE_KMS_RESOURCE_NAME") + except Exception as e: + raise e diff --git a/litellm/utils.py b/litellm/utils.py index 53483830e..e37bf10d4 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -7,7 +7,7 @@ # # Thank you users! We ❤️ you! - Krrish & Ishaan -import sys, re +import sys, re, binascii import litellm import dotenv, json, traceback, threading, base64 import subprocess, os @@ -43,6 +43,7 @@ from .integrations.custom_logger import CustomLogger from .integrations.langfuse import LangFuseLogger from .integrations.dynamodb import DyanmoDBLogger from .integrations.litedebugger import LiteDebugger +from .proxy._types import KeyManagementSystem from openai import OpenAIError as OriginalError from openai._models import BaseModel as OpenAIObject from .exceptions import ( @@ -59,7 +60,7 @@ from .exceptions import ( BudgetExceededError, UnprocessableEntityError, ) -from typing import cast, List, Dict, Union, Optional, Literal +from typing import cast, List, Dict, Union, Optional, Literal, Any from .caching import Cache from concurrent.futures import ThreadPoolExecutor @@ -6331,24 +6332,45 @@ def litellm_telemetry(data): ######### Secret Manager ############################ # checks if user has passed in a secret manager client # if passed in then checks the secret there -def get_secret(secret_name: str, default_value: Optional[str] = None): +def _is_base64(s): + try: + return base64.b64encode(base64.b64decode(s)).decode() == s + except binascii.Error: + return False + + +def get_secret( + secret_name: str, + default_value: Optional[str] = None, +): + key_management_system = litellm._key_management_system if secret_name.startswith("os.environ/"): secret_name = secret_name.replace("os.environ/", "") try: if litellm.secret_manager_client is not None: try: client = litellm.secret_manager_client + key_manager = "local" + if key_management_system is not None: + key_manager = key_management_system.value if ( - type(client).__module__ + "." + type(client).__name__ + key_manager == KeyManagementSystem.AZURE_KEY_VAULT + or type(client).__module__ + "." + type(client).__name__ == "azure.keyvault.secrets._client.SecretClient" ): # support Azure Secret Client - from azure.keyvault.secrets import SecretClient - secret = retrieved_secret = client.get_secret(secret_name).value - elif client.__class__.__name__ == "KeyManagementServiceClient": - encrypted_secret = os.getenv(secret_name) + secret = client.get_secret(secret_name).value + elif ( + key_manager == KeyManagementSystem.GOOGLE_KMS + or client.__class__.__name__ == "KeyManagementServiceClient" + ): + encrypted_secret: Any = os.getenv(secret_name) if encrypted_secret is None: raise ValueError( f"Google KMS requires the encrypted secret to be in the environment!" ) + b64_flag = _is_base64(encrypted_secret) + if b64_flag == True: # if passed in as encoded b64 string + encrypted_secret = base64.b64decode(encrypted_secret) if not isinstance(encrypted_secret, bytes): # If it's not, assume it's a string and encode it to bytes ciphertext = eval( diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 000000000..faea73ef5 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,6 @@ +[mypy] +warn_return_any = False +ignore_missing_imports = False + +[mypy-google.*] +ignore_missing_imports = True