Merge pull request #4054 from BerriAI/litellm_aws_kms_support

feat(aws_secret_manager.py): Support AWS KMS for Master Key encrption
This commit is contained in:
Krish Dholakia 2024-06-07 08:49:25 -07:00 committed by GitHub
commit 2c70e55f77
8 changed files with 94 additions and 11 deletions

View file

@ -1,11 +1,31 @@
# Secret Manager
LiteLLM supports reading secrets from Azure Key Vault and Infisical
- AWS Key Managemenet Service
- AWS Secret Manager
- [Azure Key Vault](#azure-key-vault)
- Google Key Management Service
- [Infisical Secret Manager](#infisical-secret-manager)
- [.env Files](#env-files)
## AWS Key Management Service
Use AWS KMS to storing a hashed copy of your Proxy Master Key in the environment.
```bash
export LITELLM_MASTER_KEY="djZ9xjVaZ..." # 👈 ENCRYPTED KEY
export AWS_REGION_NAME="us-west-2"
```
```yaml
general_settings:
key_management_system: "aws_kms"
key_management_settings:
hosted_keys: ["LITELLM_MASTER_KEY"] # 👈 WHICH KEYS ARE STORED ON KMS
```
[**See Decryption Code**](https://github.com/BerriAI/litellm/blob/a2da2a8f168d45648b61279d4795d647d94f90c9/litellm/utils.py#L10182)
## AWS Secret Manager
Store your proxy keys in AWS Secret Manager.

View file

@ -56,8 +56,10 @@ router_settings:
litellm_settings:
success_callback: ["langfuse"]
json_logs: true
general_settings:
alerting: ["email"]
key_management_system: "aws_kms"
key_management_settings:
hosted_keys: ["LITELLM_MASTER_KEY"]

View file

@ -946,6 +946,7 @@ class KeyManagementSystem(enum.Enum):
AZURE_KEY_VAULT = "azure_key_vault"
AWS_SECRET_MANAGER = "aws_secret_manager"
LOCAL = "local"
AWS_KMS = "aws_kms"
class KeyManagementSettings(LiteLLMBase):

View file

@ -113,7 +113,10 @@ from litellm import (
CreateFileRequest,
)
from litellm.proxy.secret_managers.google_kms import load_google_kms
from litellm.proxy.secret_managers.aws_secret_manager import load_aws_secret_manager
from litellm.proxy.secret_managers.aws_secret_manager import (
load_aws_secret_manager,
load_aws_kms,
)
import pydantic
from litellm.proxy._types import *
from litellm.caching import DualCache, RedisCache
@ -2745,10 +2748,12 @@ class ProxyConfig:
load_google_kms(use_google_kms=True)
elif (
key_management_system
== KeyManagementSystem.AWS_SECRET_MANAGER.value
== KeyManagementSystem.AWS_SECRET_MANAGER.value # noqa: F405
):
### LOAD FROM AWS SECRET MANAGER ###
load_aws_secret_manager(use_aws_secret_manager=True)
elif key_management_system == KeyManagementSystem.AWS_KMS.value:
load_aws_kms(use_aws_kms=True)
else:
raise ValueError("Invalid Key Management System selected")
key_management_settings = general_settings.get(
@ -2782,6 +2787,7 @@ class ProxyConfig:
master_key = general_settings.get(
"master_key", litellm.get_secret("LITELLM_MASTER_KEY", None)
)
if master_key and master_key.startswith("os.environ/"):
master_key = litellm.get_secret(master_key)
if not isinstance(master_key, str):
@ -4130,6 +4136,7 @@ async def chat_completion(
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
global general_settings, user_debug, proxy_logging_obj, llm_model_list
data = {}
try:
body = await request.body()

View file

@ -8,7 +8,8 @@ Requires:
* `pip install boto3>=1.28.57`
"""
import litellm, os
import litellm
import os
from typing import Optional
from litellm.proxy._types import KeyManagementSystem
@ -38,3 +39,21 @@ def load_aws_secret_manager(use_aws_secret_manager: Optional[bool]):
except Exception as e:
raise e
def load_aws_kms(use_aws_kms: Optional[bool]):
if use_aws_kms is None or use_aws_kms is False:
return
try:
import boto3
validate_environment()
# Create a Secrets Manager client
kms_client = boto3.client("kms", region_name=os.getenv("AWS_REGION_NAME"))
litellm.secret_manager_client = kms_client
litellm._key_management_system = KeyManagementSystem.AWS_KMS
except Exception as e:
raise e

View file

@ -198,7 +198,11 @@ async def test_aarun_thread_litellm(sync_mode, provider, is_streaming):
)
assert isinstance(messages.data[0], Message)
else:
pytest.fail("An unexpected error occurred when running the thread")
pytest.fail(
"An unexpected error occurred when running the thread, {}".format(
run
)
)
else:
added_message = await litellm.a_add_message(**data)
@ -226,4 +230,8 @@ async def test_aarun_thread_litellm(sync_mode, provider, is_streaming):
)
assert isinstance(messages.data[0], Message)
else:
pytest.fail("An unexpected error occurred when running the thread")
pytest.fail(
"An unexpected error occurred when running the thread, {}".format(
run
)
)

View file

@ -2539,6 +2539,7 @@ def test_replicate_custom_prompt_dict():
"content": "what is yc write 1 paragraph",
}
],
mock_response="Hello world",
mock_response="hello world",
repetition_penalty=0.1,
num_retries=3,

View file

@ -10077,6 +10077,8 @@ def get_secret(
):
key_management_system = litellm._key_management_system
key_management_settings = litellm._key_management_settings
args = locals()
if secret_name.startswith("os.environ/"):
secret_name = secret_name.replace("os.environ/", "")
@ -10164,13 +10166,13 @@ def get_secret(
key_manager = "local"
if (
key_manager == KeyManagementSystem.AZURE_KEY_VAULT
key_manager == KeyManagementSystem.AZURE_KEY_VAULT.value
or type(client).__module__ + "." + type(client).__name__
== "azure.keyvault.secrets._client.SecretClient"
): # support Azure Secret Client - from azure.keyvault.secrets import SecretClient
secret = client.get_secret(secret_name).value
elif (
key_manager == KeyManagementSystem.GOOGLE_KMS
key_manager == KeyManagementSystem.GOOGLE_KMS.value
or client.__class__.__name__ == "KeyManagementServiceClient"
):
encrypted_secret: Any = os.getenv(secret_name)
@ -10198,6 +10200,25 @@ def get_secret(
secret = response.plaintext.decode(
"utf-8"
) # assumes the original value was encoded with utf-8
elif key_manager == KeyManagementSystem.AWS_KMS.value:
"""
Only check the tokens which start with 'aws_kms/'. This prevents latency impact caused by checking all keys.
"""
encrypted_value = os.getenv(secret_name, None)
if encrypted_value is None:
raise Exception("encrypted value for AWS KMS cannot be None.")
# Decode the base64 encoded ciphertext
ciphertext_blob = base64.b64decode(encrypted_value)
# Set up the parameters for the decrypt call
params = {"CiphertextBlob": ciphertext_blob}
# Perform the decryption
response = client.decrypt(**params)
# Extract and decode the plaintext
plaintext = response["Plaintext"]
secret = plaintext.decode("utf-8")
elif key_manager == KeyManagementSystem.AWS_SECRET_MANAGER.value:
try:
get_secret_value_response = client.get_secret_value(
@ -10218,10 +10239,14 @@ def get_secret(
for k, v in secret_dict.items():
secret = v
print_verbose(f"secret: {secret}")
elif key_manager == "local":
secret = os.getenv(secret_name)
else: # assume the default is infisicial client
secret = client.get_secret(secret_name).secret_value
except Exception as e: # check if it's in os.environ
print_verbose(f"An exception occurred - {str(e)}")
verbose_logger.error(
f"An exception occurred - {str(e)}\n\n{traceback.format_exc()}"
)
secret = os.getenv(secret_name)
try:
secret_value_as_bool = ast.literal_eval(secret)