forked from phoenix/litellm-mirror
Merge pull request #4054 from BerriAI/litellm_aws_kms_support
feat(aws_secret_manager.py): Support AWS KMS for Master Key encryption
This commit is contained in:
commit
f6a262122b
8 changed files with 94 additions and 11 deletions
|
@ -1,11 +1,31 @@
|
|||
# Secret Manager
|
||||
LiteLLM supports reading secrets from Azure Key Vault and Infisical
|
||||
|
||||
- AWS Key Management Service
|
||||
- AWS Secret Manager
|
||||
- [Azure Key Vault](#azure-key-vault)
|
||||
- Google Key Management Service
|
||||
- [Infisical Secret Manager](#infisical-secret-manager)
|
||||
- [.env Files](#env-files)
|
||||
|
||||
## AWS Key Management Service
|
||||
|
||||
Use AWS KMS to store an encrypted copy of your Proxy Master Key in the environment.
|
||||
|
||||
```bash
|
||||
export LITELLM_MASTER_KEY="djZ9xjVaZ..." # 👈 ENCRYPTED KEY
|
||||
export AWS_REGION_NAME="us-west-2"
|
||||
```
|
||||
|
||||
```yaml
|
||||
general_settings:
|
||||
key_management_system: "aws_kms"
|
||||
key_management_settings:
|
||||
hosted_keys: ["LITELLM_MASTER_KEY"] # 👈 WHICH KEYS ARE STORED ON KMS
|
||||
```
|
||||
|
||||
[**See Decryption Code**](https://github.com/BerriAI/litellm/blob/a2da2a8f168d45648b61279d4795d647d94f90c9/litellm/utils.py#L10182)
|
||||
|
||||
## AWS Secret Manager
|
||||
|
||||
Store your proxy keys in AWS Secret Manager.
|
||||
|
|
|
@ -56,8 +56,10 @@ router_settings:
|
|||
|
||||
litellm_settings:
|
||||
success_callback: ["langfuse"]
|
||||
json_logs: true
|
||||
|
||||
general_settings:
|
||||
alerting: ["email"]
|
||||
key_management_system: "aws_kms"
|
||||
key_management_settings:
|
||||
hosted_keys: ["LITELLM_MASTER_KEY"]
|
||||
|
||||
|
|
|
@ -946,6 +946,7 @@ class KeyManagementSystem(enum.Enum):
|
|||
AZURE_KEY_VAULT = "azure_key_vault"
|
||||
AWS_SECRET_MANAGER = "aws_secret_manager"
|
||||
LOCAL = "local"
|
||||
AWS_KMS = "aws_kms"
|
||||
|
||||
|
||||
class KeyManagementSettings(LiteLLMBase):
|
||||
|
|
|
@ -113,7 +113,10 @@ from litellm import (
|
|||
CreateFileRequest,
|
||||
)
|
||||
from litellm.proxy.secret_managers.google_kms import load_google_kms
|
||||
from litellm.proxy.secret_managers.aws_secret_manager import load_aws_secret_manager
|
||||
from litellm.proxy.secret_managers.aws_secret_manager import (
|
||||
load_aws_secret_manager,
|
||||
load_aws_kms,
|
||||
)
|
||||
import pydantic
|
||||
from litellm.proxy._types import *
|
||||
from litellm.caching import DualCache, RedisCache
|
||||
|
@ -2745,10 +2748,12 @@ class ProxyConfig:
|
|||
load_google_kms(use_google_kms=True)
|
||||
elif (
|
||||
key_management_system
|
||||
== KeyManagementSystem.AWS_SECRET_MANAGER.value
|
||||
== KeyManagementSystem.AWS_SECRET_MANAGER.value # noqa: F405
|
||||
):
|
||||
### LOAD FROM AWS SECRET MANAGER ###
|
||||
load_aws_secret_manager(use_aws_secret_manager=True)
|
||||
elif key_management_system == KeyManagementSystem.AWS_KMS.value:
|
||||
load_aws_kms(use_aws_kms=True)
|
||||
else:
|
||||
raise ValueError("Invalid Key Management System selected")
|
||||
key_management_settings = general_settings.get(
|
||||
|
@ -2782,6 +2787,7 @@ class ProxyConfig:
|
|||
master_key = general_settings.get(
|
||||
"master_key", litellm.get_secret("LITELLM_MASTER_KEY", None)
|
||||
)
|
||||
|
||||
if master_key and master_key.startswith("os.environ/"):
|
||||
master_key = litellm.get_secret(master_key)
|
||||
if not isinstance(master_key, str):
|
||||
|
@ -4130,6 +4136,7 @@ async def chat_completion(
|
|||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
global general_settings, user_debug, proxy_logging_obj, llm_model_list
|
||||
|
||||
data = {}
|
||||
try:
|
||||
body = await request.body()
|
||||
|
|
|
@ -8,7 +8,8 @@ Requires:
|
|||
* `pip install boto3>=1.28.57`
|
||||
"""
|
||||
|
||||
import litellm, os
|
||||
import litellm
|
||||
import os
|
||||
from typing import Optional
|
||||
from litellm.proxy._types import KeyManagementSystem
|
||||
|
||||
|
@ -38,3 +39,21 @@ def load_aws_secret_manager(use_aws_secret_manager: Optional[bool]):
|
|||
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
def load_aws_kms(use_aws_kms: Optional[bool]):
    """Register an AWS KMS client as litellm's secret manager client.

    Does nothing when ``use_aws_kms`` is ``None`` or ``False``. Otherwise it
    calls ``validate_environment()``, builds a boto3 ``kms`` client for the
    region named by the ``AWS_REGION_NAME`` environment variable, stores that
    client on ``litellm.secret_manager_client`` and sets
    ``litellm._key_management_system`` to ``KeyManagementSystem.AWS_KMS``.

    Args:
        use_aws_kms: Enable flag; only values other than ``None``/``False``
            trigger the setup.

    Raises:
        Exception: Anything raised by the boto3 import, by
            ``validate_environment()`` or by the client construction
            propagates to the caller unchanged.
    """
    if use_aws_kms is None or use_aws_kms is False:
        return

    # Imported lazily so boto3 is only needed when AWS KMS is enabled.
    import boto3

    # NOTE(review): defined elsewhere in this module; presumably checks the
    # required AWS environment variables -- confirm against its definition.
    validate_environment()

    # Create a KMS client (not a Secrets Manager client).
    kms_client = boto3.client("kms", region_name=os.getenv("AWS_REGION_NAME"))

    litellm.secret_manager_client = kms_client
    litellm._key_management_system = KeyManagementSystem.AWS_KMS
|
||||
|
|
|
@ -198,7 +198,11 @@ async def test_aarun_thread_litellm(sync_mode, provider, is_streaming):
|
|||
)
|
||||
assert isinstance(messages.data[0], Message)
|
||||
else:
|
||||
pytest.fail("An unexpected error occurred when running the thread")
|
||||
pytest.fail(
|
||||
"An unexpected error occurred when running the thread, {}".format(
|
||||
run
|
||||
)
|
||||
)
|
||||
|
||||
else:
|
||||
added_message = await litellm.a_add_message(**data)
|
||||
|
@ -226,4 +230,8 @@ async def test_aarun_thread_litellm(sync_mode, provider, is_streaming):
|
|||
)
|
||||
assert isinstance(messages.data[0], Message)
|
||||
else:
|
||||
pytest.fail("An unexpected error occurred when running the thread")
|
||||
pytest.fail(
|
||||
"An unexpected error occurred when running the thread, {}".format(
|
||||
run
|
||||
)
|
||||
)
|
||||
|
|
|
@ -2539,6 +2539,7 @@ def test_replicate_custom_prompt_dict():
|
|||
"content": "what is yc write 1 paragraph",
|
||||
}
|
||||
],
|
||||
mock_response="Hello world",
|
||||
mock_response="hello world",
|
||||
repetition_penalty=0.1,
|
||||
num_retries=3,
|
||||
|
|
|
@ -10077,6 +10077,8 @@ def get_secret(
|
|||
):
|
||||
key_management_system = litellm._key_management_system
|
||||
key_management_settings = litellm._key_management_settings
|
||||
args = locals()
|
||||
|
||||
if secret_name.startswith("os.environ/"):
|
||||
secret_name = secret_name.replace("os.environ/", "")
|
||||
|
||||
|
@ -10164,13 +10166,13 @@ def get_secret(
|
|||
key_manager = "local"
|
||||
|
||||
if (
|
||||
key_manager == KeyManagementSystem.AZURE_KEY_VAULT
|
||||
key_manager == KeyManagementSystem.AZURE_KEY_VAULT.value
|
||||
or type(client).__module__ + "." + type(client).__name__
|
||||
== "azure.keyvault.secrets._client.SecretClient"
|
||||
): # support Azure Secret Client - from azure.keyvault.secrets import SecretClient
|
||||
secret = client.get_secret(secret_name).value
|
||||
elif (
|
||||
key_manager == KeyManagementSystem.GOOGLE_KMS
|
||||
key_manager == KeyManagementSystem.GOOGLE_KMS.value
|
||||
or client.__class__.__name__ == "KeyManagementServiceClient"
|
||||
):
|
||||
encrypted_secret: Any = os.getenv(secret_name)
|
||||
|
@ -10198,6 +10200,25 @@ def get_secret(
|
|||
secret = response.plaintext.decode(
|
||||
"utf-8"
|
||||
) # assumes the original value was encoded with utf-8
|
||||
elif key_manager == KeyManagementSystem.AWS_KMS.value:
|
||||
"""
|
||||
Only check the tokens which start with 'aws_kms/'. This prevents latency impact caused by checking all keys.
|
||||
"""
|
||||
encrypted_value = os.getenv(secret_name, None)
|
||||
if encrypted_value is None:
|
||||
raise Exception("encrypted value for AWS KMS cannot be None.")
|
||||
# Decode the base64 encoded ciphertext
|
||||
ciphertext_blob = base64.b64decode(encrypted_value)
|
||||
|
||||
# Set up the parameters for the decrypt call
|
||||
params = {"CiphertextBlob": ciphertext_blob}
|
||||
|
||||
# Perform the decryption
|
||||
response = client.decrypt(**params)
|
||||
|
||||
# Extract and decode the plaintext
|
||||
plaintext = response["Plaintext"]
|
||||
secret = plaintext.decode("utf-8")
|
||||
elif key_manager == KeyManagementSystem.AWS_SECRET_MANAGER.value:
|
||||
try:
|
||||
get_secret_value_response = client.get_secret_value(
|
||||
|
@ -10218,10 +10239,14 @@ def get_secret(
|
|||
for k, v in secret_dict.items():
|
||||
secret = v
|
||||
print_verbose(f"secret: {secret}")
|
||||
elif key_manager == "local":
|
||||
secret = os.getenv(secret_name)
|
||||
else: # assume the default is infisicial client
|
||||
secret = client.get_secret(secret_name).secret_value
|
||||
except Exception as e: # check if it's in os.environ
|
||||
print_verbose(f"An exception occurred - {str(e)}")
|
||||
verbose_logger.error(
|
||||
f"An exception occurred - {str(e)}\n\n{traceback.format_exc()}"
|
||||
)
|
||||
secret = os.getenv(secret_name)
|
||||
try:
|
||||
secret_value_as_bool = ast.literal_eval(secret)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue