fix(proxy_cli.py): support passing the database url as an encrypted kms key

This commit is contained in:
Krrish Dholakia 2024-06-10 15:48:27 -07:00
parent 58cce8a922
commit bee79f0b70
4 changed files with 74 additions and 14 deletions

View file

@ -57,9 +57,9 @@ router_settings:
litellm_settings: litellm_settings:
success_callback: ["langfuse"] success_callback: ["langfuse"]
# general_settings: general_settings:
# alerting: ["email"] alerting: ["email"]
# key_management_system: "aws_kms" key_management_system: "aws_kms"
# key_management_settings: key_management_settings:
# hosted_keys: ["LITELLM_MASTER_KEY"] hosted_keys: ["LITELLM_MASTER_KEY", "DATABASE_URL"]

View file

@ -21,7 +21,7 @@ import shutil
telemetry = None telemetry = None
def append_query_params(url, params): def append_query_params(url, params) -> str:
from litellm._logging import verbose_proxy_logger from litellm._logging import verbose_proxy_logger
verbose_proxy_logger.debug(f"url: {url}") verbose_proxy_logger.debug(f"url: {url}")
@ -229,13 +229,29 @@ def run_server(
): ):
args = locals() args = locals()
if local: if local:
from proxy_server import app, save_worker_config, ProxyConfig from proxy_server import (
app,
save_worker_config,
ProxyConfig,
KeyManagementSystem,
KeyManagementSettings,
load_from_azure_key_vault,
load_aws_kms,
load_aws_secret_manager,
load_google_kms,
)
else: else:
try: try:
from .proxy_server import ( from .proxy_server import (
app, app,
save_worker_config, save_worker_config,
ProxyConfig, ProxyConfig,
KeyManagementSystem,
KeyManagementSettings,
load_from_azure_key_vault,
load_aws_kms,
load_aws_secret_manager,
load_google_kms,
) )
except ImportError as e: except ImportError as e:
if "litellm[proxy]" in str(e): if "litellm[proxy]" in str(e):
@ -247,6 +263,12 @@ def run_server(
app, app,
save_worker_config, save_worker_config,
ProxyConfig, ProxyConfig,
KeyManagementSystem,
KeyManagementSettings,
load_from_azure_key_vault,
load_aws_kms,
load_aws_secret_manager,
load_google_kms,
) )
if version == True: if version == True:
pkg_version = importlib.metadata.version("litellm") pkg_version = importlib.metadata.version("litellm")
@ -445,6 +467,40 @@ def run_server(
general_settings = _config.get("general_settings", {}) general_settings = _config.get("general_settings", {})
if general_settings is None: if general_settings is None:
general_settings = {} general_settings = {}
if general_settings:
### LOAD SECRET MANAGER ###
key_management_system = general_settings.get(
"key_management_system", None
)
if key_management_system is not None:
if (
key_management_system
== KeyManagementSystem.AZURE_KEY_VAULT.value
):
### LOAD FROM AZURE KEY VAULT ###
load_from_azure_key_vault(use_azure_key_vault=True)
elif key_management_system == KeyManagementSystem.GOOGLE_KMS.value:
### LOAD FROM GOOGLE KMS ###
load_google_kms(use_google_kms=True)
elif (
key_management_system
== KeyManagementSystem.AWS_SECRET_MANAGER.value # noqa: F405
):
### LOAD FROM AWS SECRET MANAGER ###
load_aws_secret_manager(use_aws_secret_manager=True)
elif key_management_system == KeyManagementSystem.AWS_KMS.value:
load_aws_kms(use_aws_kms=True)
else:
raise ValueError("Invalid Key Management System selected")
key_management_settings = general_settings.get(
"key_management_settings", None
)
if key_management_settings is not None:
import litellm
litellm._key_management_settings = KeyManagementSettings(
**key_management_settings
)
database_url = general_settings.get("database_url", None) database_url = general_settings.get("database_url", None)
db_connection_pool_limit = general_settings.get( db_connection_pool_limit = general_settings.get(
"database_connection_pool_limit", 100 "database_connection_pool_limit", 100
@ -460,7 +516,7 @@ def run_server(
) # Adds the parent directory to the system path - for litellm local dev ) # Adds the parent directory to the system path - for litellm local dev
import litellm import litellm
database_url = litellm.get_secret(database_url) database_url = litellm.get_secret(database_url, default_value=None)
os.chdir(original_dir) os.chdir(original_dir)
if database_url is not None and isinstance(database_url, str): if database_url is not None and isinstance(database_url, str):
os.environ["DATABASE_URL"] = database_url os.environ["DATABASE_URL"] = database_url
@ -470,13 +526,15 @@ def run_server(
or os.getenv("DIRECT_URL", None) is not None or os.getenv("DIRECT_URL", None) is not None
): ):
try: try:
from litellm import get_secret
if os.getenv("DATABASE_URL", None) is not None: if os.getenv("DATABASE_URL", None) is not None:
### add connection pool + pool timeout args ### add connection pool + pool timeout args
params = { params = {
"connection_limit": db_connection_pool_limit, "connection_limit": db_connection_pool_limit,
"pool_timeout": db_connection_timeout, "pool_timeout": db_connection_timeout,
} }
database_url = os.getenv("DATABASE_URL") database_url = get_secret("DATABASE_URL", default_value=None)
modified_url = append_query_params(database_url, params) modified_url = append_query_params(database_url, params)
os.environ["DATABASE_URL"] = modified_url os.environ["DATABASE_URL"] = modified_url
if os.getenv("DIRECT_URL", None) is not None: if os.getenv("DIRECT_URL", None) is not None:

View file

@ -3895,7 +3895,7 @@ async def startup_event():
master_key = litellm.get_secret("LITELLM_MASTER_KEY", None) master_key = litellm.get_secret("LITELLM_MASTER_KEY", None)
# check if DATABASE_URL in environment - load from there # check if DATABASE_URL in environment - load from there
if prisma_client is None: if prisma_client is None:
prisma_setup(database_url=os.getenv("DATABASE_URL")) prisma_setup(database_url=litellm.get_secret("DATABASE_URL", None))
### LOAD CONFIG ### ### LOAD CONFIG ###
worker_config = litellm.get_secret("WORKER_CONFIG") worker_config = litellm.get_secret("WORKER_CONFIG")

View file

@ -10119,7 +10119,6 @@ def get_secret(
): ):
key_management_system = litellm._key_management_system key_management_system = litellm._key_management_system
key_management_settings = litellm._key_management_settings key_management_settings = litellm._key_management_settings
args = locals()
if secret_name.startswith("os.environ/"): if secret_name.startswith("os.environ/"):
secret_name = secret_name.replace("os.environ/", "") secret_name = secret_name.replace("os.environ/", "")
@ -10248,13 +10247,16 @@ def get_secret(
""" """
encrypted_value = os.getenv(secret_name, None) encrypted_value = os.getenv(secret_name, None)
if encrypted_value is None: if encrypted_value is None:
raise Exception("encrypted value for AWS KMS cannot be None.") raise Exception(
"AWS KMS - Encrypted Value of Key={} is None".format(
secret_name
)
)
# Decode the base64 encoded ciphertext # Decode the base64 encoded ciphertext
ciphertext_blob = base64.b64decode(encrypted_value) ciphertext_blob = base64.b64decode(encrypted_value)
# Set up the parameters for the decrypt call # Set up the parameters for the decrypt call
params = {"CiphertextBlob": ciphertext_blob} params = {"CiphertextBlob": ciphertext_blob}
# Perform the decryption # Perform the decryption
response = client.decrypt(**params) response = client.decrypt(**params)
@ -10287,7 +10289,7 @@ def get_secret(
secret = client.get_secret(secret_name).secret_value secret = client.get_secret(secret_name).secret_value
except Exception as e: # check if it's in os.environ except Exception as e: # check if it's in os.environ
verbose_logger.error( verbose_logger.error(
f"An exception occurred - {str(e)}\n\n{traceback.format_exc()}" f"Defaulting to os.environ value for key={secret_name}. An exception occurred - {str(e)}.\n\n{traceback.format_exc()}"
) )
secret = os.getenv(secret_name) secret = os.getenv(secret_name)
try: try: