mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
feat(aws_secret_manager.py): allows user to keep a hash of the proxy master key in their env
This commit is contained in:
parent
707a41fd6c
commit
a2da2a8f16
5 changed files with 62 additions and 8 deletions
|
@ -56,8 +56,10 @@ router_settings:
|
|||
|
||||
litellm_settings:
|
||||
success_callback: ["langfuse"]
|
||||
json_logs: true
|
||||
|
||||
general_settings:
|
||||
alerting: ["email"]
|
||||
key_management_system: "aws_kms"
|
||||
key_management_settings:
|
||||
hosted_keys: ["LITELLM_MASTER_KEY"]
|
||||
|
||||
|
|
|
@ -946,6 +946,7 @@ class KeyManagementSystem(enum.Enum):
|
|||
AZURE_KEY_VAULT = "azure_key_vault"
|
||||
AWS_SECRET_MANAGER = "aws_secret_manager"
|
||||
LOCAL = "local"
|
||||
AWS_KMS = "aws_kms"
|
||||
|
||||
|
||||
class KeyManagementSettings(LiteLLMBase):
|
||||
|
|
|
@ -112,7 +112,10 @@ from litellm import (
|
|||
CreateFileRequest,
|
||||
)
|
||||
from litellm.proxy.secret_managers.google_kms import load_google_kms
|
||||
from litellm.proxy.secret_managers.aws_secret_manager import load_aws_secret_manager
|
||||
from litellm.proxy.secret_managers.aws_secret_manager import (
|
||||
load_aws_secret_manager,
|
||||
load_aws_kms,
|
||||
)
|
||||
import pydantic
|
||||
from litellm.proxy._types import *
|
||||
from litellm.caching import DualCache, RedisCache
|
||||
|
@ -2736,10 +2739,12 @@ class ProxyConfig:
|
|||
load_google_kms(use_google_kms=True)
|
||||
elif (
|
||||
key_management_system
|
||||
== KeyManagementSystem.AWS_SECRET_MANAGER.value
|
||||
== KeyManagementSystem.AWS_SECRET_MANAGER.value # noqa: F405
|
||||
):
|
||||
### LOAD FROM AWS SECRET MANAGER ###
|
||||
load_aws_secret_manager(use_aws_secret_manager=True)
|
||||
elif key_management_system == KeyManagementSystem.AWS_KMS.value:
|
||||
load_aws_kms(use_aws_kms=True)
|
||||
else:
|
||||
raise ValueError("Invalid Key Management System selected")
|
||||
key_management_settings = general_settings.get(
|
||||
|
@ -2773,6 +2778,7 @@ class ProxyConfig:
|
|||
master_key = general_settings.get(
|
||||
"master_key", litellm.get_secret("LITELLM_MASTER_KEY", None)
|
||||
)
|
||||
|
||||
if master_key and master_key.startswith("os.environ/"):
|
||||
master_key = litellm.get_secret(master_key)
|
||||
if not isinstance(master_key, str):
|
||||
|
@ -4098,6 +4104,7 @@ async def chat_completion(
|
|||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
global general_settings, user_debug, proxy_logging_obj, llm_model_list
|
||||
|
||||
data = {}
|
||||
try:
|
||||
body = await request.body()
|
||||
|
|
|
@ -8,6 +8,7 @@ Requires:
|
|||
* `pip install boto3>=1.28.57`
|
||||
"""
|
||||
|
||||
import boto3.session
|
||||
import litellm, os
|
||||
from typing import Optional
|
||||
from litellm.proxy._types import KeyManagementSystem
|
||||
|
@ -38,3 +39,21 @@ def load_aws_secret_manager(use_aws_secret_manager: Optional[bool]):
|
|||
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
def load_aws_kms(use_aws_kms: Optional[bool]):
|
||||
if use_aws_kms is None or use_aws_kms is False:
|
||||
return
|
||||
try:
|
||||
import boto3
|
||||
|
||||
validate_environment()
|
||||
|
||||
# Create a Secrets Manager client
|
||||
kms_client = boto3.client("kms", region_name=os.getenv("AWS_REGION_NAME"))
|
||||
|
||||
litellm.secret_manager_client = kms_client
|
||||
litellm._key_management_system = KeyManagementSystem.AWS_KMS
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
|
|
@ -7351,10 +7351,10 @@ def get_provider_fields(custom_llm_provider: str) -> List[ProviderField]:
|
|||
|
||||
if custom_llm_provider == "databricks":
|
||||
return litellm.DatabricksConfig().get_required_params()
|
||||
|
||||
|
||||
elif custom_llm_provider == "ollama":
|
||||
return litellm.OllamaConfig().get_required_params()
|
||||
|
||||
|
||||
else:
|
||||
return []
|
||||
|
||||
|
@ -10052,6 +10052,8 @@ def get_secret(
|
|||
):
|
||||
key_management_system = litellm._key_management_system
|
||||
key_management_settings = litellm._key_management_settings
|
||||
args = locals()
|
||||
|
||||
if secret_name.startswith("os.environ/"):
|
||||
secret_name = secret_name.replace("os.environ/", "")
|
||||
|
||||
|
@ -10139,13 +10141,13 @@ def get_secret(
|
|||
key_manager = "local"
|
||||
|
||||
if (
|
||||
key_manager == KeyManagementSystem.AZURE_KEY_VAULT
|
||||
key_manager == KeyManagementSystem.AZURE_KEY_VAULT.value
|
||||
or type(client).__module__ + "." + type(client).__name__
|
||||
== "azure.keyvault.secrets._client.SecretClient"
|
||||
): # support Azure Secret Client - from azure.keyvault.secrets import SecretClient
|
||||
secret = client.get_secret(secret_name).value
|
||||
elif (
|
||||
key_manager == KeyManagementSystem.GOOGLE_KMS
|
||||
key_manager == KeyManagementSystem.GOOGLE_KMS.value
|
||||
or client.__class__.__name__ == "KeyManagementServiceClient"
|
||||
):
|
||||
encrypted_secret: Any = os.getenv(secret_name)
|
||||
|
@ -10173,6 +10175,25 @@ def get_secret(
|
|||
secret = response.plaintext.decode(
|
||||
"utf-8"
|
||||
) # assumes the original value was encoded with utf-8
|
||||
elif key_manager == KeyManagementSystem.AWS_KMS.value:
|
||||
"""
|
||||
Only check the tokens which start with 'aws_kms/'. This prevents latency impact caused by checking all keys.
|
||||
"""
|
||||
encrypted_value = os.getenv(secret_name, None)
|
||||
if encrypted_value is None:
|
||||
raise Exception("encrypted value for AWS KMS cannot be None.")
|
||||
# Decode the base64 encoded ciphertext
|
||||
ciphertext_blob = base64.b64decode(encrypted_value)
|
||||
|
||||
# Set up the parameters for the decrypt call
|
||||
params = {"CiphertextBlob": ciphertext_blob}
|
||||
|
||||
# Perform the decryption
|
||||
response = client.decrypt(**params)
|
||||
|
||||
# Extract and decode the plaintext
|
||||
plaintext = response["Plaintext"]
|
||||
secret = plaintext.decode("utf-8")
|
||||
elif key_manager == KeyManagementSystem.AWS_SECRET_MANAGER.value:
|
||||
try:
|
||||
get_secret_value_response = client.get_secret_value(
|
||||
|
@ -10193,10 +10214,14 @@ def get_secret(
|
|||
for k, v in secret_dict.items():
|
||||
secret = v
|
||||
print_verbose(f"secret: {secret}")
|
||||
elif key_manager == "local":
|
||||
secret = os.getenv(secret_name)
|
||||
else: # assume the default is infisicial client
|
||||
secret = client.get_secret(secret_name).secret_value
|
||||
except Exception as e: # check if it's in os.environ
|
||||
print_verbose(f"An exception occurred - {str(e)}")
|
||||
verbose_logger.error(
|
||||
f"An exception occurred - {str(e)}\n\n{traceback.format_exc()}"
|
||||
)
|
||||
secret = os.getenv(secret_name)
|
||||
try:
|
||||
secret_value_as_bool = ast.literal_eval(secret)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue