diff --git a/litellm/proxy/_super_secret_config.yaml b/litellm/proxy/_super_secret_config.yaml index 72479bd5d..7fa1bbc19 100644 --- a/litellm/proxy/_super_secret_config.yaml +++ b/litellm/proxy/_super_secret_config.yaml @@ -56,8 +56,10 @@ router_settings: litellm_settings: success_callback: ["langfuse"] - json_logs: true general_settings: alerting: ["email"] + key_management_system: "aws_kms" + key_management_settings: + hosted_keys: ["LITELLM_MASTER_KEY"] diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 8a95f4e1d..c19fd7e69 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -946,6 +946,7 @@ class KeyManagementSystem(enum.Enum): AZURE_KEY_VAULT = "azure_key_vault" AWS_SECRET_MANAGER = "aws_secret_manager" LOCAL = "local" + AWS_KMS = "aws_kms" class KeyManagementSettings(LiteLLMBase): diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 8cf2fa118..e188841eb 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -112,7 +112,10 @@ from litellm import ( CreateFileRequest, ) from litellm.proxy.secret_managers.google_kms import load_google_kms -from litellm.proxy.secret_managers.aws_secret_manager import load_aws_secret_manager +from litellm.proxy.secret_managers.aws_secret_manager import ( + load_aws_secret_manager, + load_aws_kms, +) import pydantic from litellm.proxy._types import * from litellm.caching import DualCache, RedisCache @@ -2736,10 +2739,12 @@ class ProxyConfig: load_google_kms(use_google_kms=True) elif ( key_management_system - == KeyManagementSystem.AWS_SECRET_MANAGER.value + == KeyManagementSystem.AWS_SECRET_MANAGER.value # noqa: F405 ): ### LOAD FROM AWS SECRET MANAGER ### load_aws_secret_manager(use_aws_secret_manager=True) + elif key_management_system == KeyManagementSystem.AWS_KMS.value: + load_aws_kms(use_aws_kms=True) else: raise ValueError("Invalid Key Management System selected") key_management_settings = general_settings.get( @@ -2773,6 +2778,7 @@ 
class ProxyConfig: master_key = general_settings.get( "master_key", litellm.get_secret("LITELLM_MASTER_KEY", None) ) + if master_key and master_key.startswith("os.environ/"): master_key = litellm.get_secret(master_key) if not isinstance(master_key, str): @@ -4098,6 +4104,7 @@ async def chat_completion( user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), ): global general_settings, user_debug, proxy_logging_obj, llm_model_list + data = {} try: body = await request.body() diff --git a/litellm/proxy/secret_managers/aws_secret_manager.py b/litellm/proxy/secret_managers/aws_secret_manager.py index a40b1dffa..9e6a777c4 100644 --- a/litellm/proxy/secret_managers/aws_secret_manager.py +++ b/litellm/proxy/secret_managers/aws_secret_manager.py @@ -8,6 +8,7 @@ Requires: * `pip install boto3>=1.28.57` """ +import boto3.session import litellm, os from typing import Optional from litellm.proxy._types import KeyManagementSystem @@ -38,3 +39,21 @@ def load_aws_secret_manager(use_aws_secret_manager: Optional[bool]): except Exception as e: raise e + + +def load_aws_kms(use_aws_kms: Optional[bool]): + if use_aws_kms is None or use_aws_kms is False: + return + try: + import boto3 + + validate_environment() + + # Create a KMS client + kms_client = boto3.client("kms", region_name=os.getenv("AWS_REGION_NAME")) + + litellm.secret_manager_client = kms_client + litellm._key_management_system = KeyManagementSystem.AWS_KMS + + except Exception as e: + raise e diff --git a/litellm/utils.py b/litellm/utils.py index ba6a37467..6ce41ad84 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -7351,10 +7351,10 @@ def get_provider_fields(custom_llm_provider: str) -> List[ProviderField]: if custom_llm_provider == "databricks": return litellm.DatabricksConfig().get_required_params() - + elif custom_llm_provider == "ollama": return litellm.OllamaConfig().get_required_params() - + else: return [] @@ -10052,6 +10052,8 @@ def get_secret( ): key_management_system = 
litellm._key_management_system key_management_settings = litellm._key_management_settings + args = locals() + if secret_name.startswith("os.environ/"): secret_name = secret_name.replace("os.environ/", "") @@ -10139,13 +10141,13 @@ def get_secret( key_manager = "local" if ( - key_manager == KeyManagementSystem.AZURE_KEY_VAULT + key_manager == KeyManagementSystem.AZURE_KEY_VAULT.value or type(client).__module__ + "." + type(client).__name__ == "azure.keyvault.secrets._client.SecretClient" ): # support Azure Secret Client - from azure.keyvault.secrets import SecretClient secret = client.get_secret(secret_name).value elif ( - key_manager == KeyManagementSystem.GOOGLE_KMS + key_manager == KeyManagementSystem.GOOGLE_KMS.value or client.__class__.__name__ == "KeyManagementServiceClient" ): encrypted_secret: Any = os.getenv(secret_name) @@ -10173,6 +10175,25 @@ def get_secret( secret = response.plaintext.decode( "utf-8" ) # assumes the original value was encoded with utf-8 + elif key_manager == KeyManagementSystem.AWS_KMS.value: + """ + Only check the tokens which start with 'aws_kms/'. This prevents latency impact caused by checking all keys. 
+ """ + encrypted_value = os.getenv(secret_name, None) + if encrypted_value is None: + raise Exception("encrypted value for AWS KMS cannot be None.") + # Decode the base64 encoded ciphertext + ciphertext_blob = base64.b64decode(encrypted_value) + + # Set up the parameters for the decrypt call + params = {"CiphertextBlob": ciphertext_blob} + + # Perform the decryption + response = client.decrypt(**params) + + # Extract and decode the plaintext + plaintext = response["Plaintext"] + secret = plaintext.decode("utf-8") elif key_manager == KeyManagementSystem.AWS_SECRET_MANAGER.value: try: get_secret_value_response = client.get_secret_value( @@ -10193,10 +10214,14 @@ def get_secret( for k, v in secret_dict.items(): secret = v print_verbose(f"secret: {secret}") + elif key_manager == "local": + secret = os.getenv(secret_name) else: # assume the default is infisicial client secret = client.get_secret(secret_name).secret_value except Exception as e: # check if it's in os.environ - print_verbose(f"An exception occurred - {str(e)}") + verbose_logger.error( + f"An exception occurred - {str(e)}\n\n{traceback.format_exc()}" + ) secret = os.getenv(secret_name) try: secret_value_as_bool = ast.literal_eval(secret)