fix(proxy_cli.py): support passing the database url as an encrypted kms key

This commit is contained in:
Krrish Dholakia 2024-06-10 15:48:27 -07:00
parent 58cce8a922
commit bee79f0b70
4 changed files with 74 additions and 14 deletions

View file

@ -21,7 +21,7 @@ import shutil
telemetry = None
def append_query_params(url, params):
def append_query_params(url, params) -> str:
from litellm._logging import verbose_proxy_logger
verbose_proxy_logger.debug(f"url: {url}")
@ -229,13 +229,29 @@ def run_server(
):
args = locals()
if local:
from proxy_server import app, save_worker_config, ProxyConfig
from proxy_server import (
app,
save_worker_config,
ProxyConfig,
KeyManagementSystem,
KeyManagementSettings,
load_from_azure_key_vault,
load_aws_kms,
load_aws_secret_manager,
load_google_kms,
)
else:
try:
from .proxy_server import (
app,
save_worker_config,
ProxyConfig,
KeyManagementSystem,
KeyManagementSettings,
load_from_azure_key_vault,
load_aws_kms,
load_aws_secret_manager,
load_google_kms,
)
except ImportError as e:
if "litellm[proxy]" in str(e):
@ -247,6 +263,12 @@ def run_server(
app,
save_worker_config,
ProxyConfig,
KeyManagementSystem,
KeyManagementSettings,
load_from_azure_key_vault,
load_aws_kms,
load_aws_secret_manager,
load_google_kms,
)
if version == True:
pkg_version = importlib.metadata.version("litellm")
@ -445,6 +467,40 @@ def run_server(
general_settings = _config.get("general_settings", {})
if general_settings is None:
general_settings = {}
if general_settings:
### LOAD SECRET MANAGER ###
key_management_system = general_settings.get(
"key_management_system", None
)
if key_management_system is not None:
if (
key_management_system
== KeyManagementSystem.AZURE_KEY_VAULT.value
):
### LOAD FROM AZURE KEY VAULT ###
load_from_azure_key_vault(use_azure_key_vault=True)
elif key_management_system == KeyManagementSystem.GOOGLE_KMS.value:
### LOAD FROM GOOGLE KMS ###
load_google_kms(use_google_kms=True)
elif (
key_management_system
== KeyManagementSystem.AWS_SECRET_MANAGER.value # noqa: F405
):
### LOAD FROM AWS SECRET MANAGER ###
load_aws_secret_manager(use_aws_secret_manager=True)
elif key_management_system == KeyManagementSystem.AWS_KMS.value:
load_aws_kms(use_aws_kms=True)
else:
raise ValueError("Invalid Key Management System selected")
key_management_settings = general_settings.get(
"key_management_settings", None
)
if key_management_settings is not None:
import litellm
litellm._key_management_settings = KeyManagementSettings(
**key_management_settings
)
database_url = general_settings.get("database_url", None)
db_connection_pool_limit = general_settings.get(
"database_connection_pool_limit", 100
@ -460,7 +516,7 @@ def run_server(
) # Adds the parent directory to the system path - for litellm local dev
import litellm
database_url = litellm.get_secret(database_url)
database_url = litellm.get_secret(database_url, default_value=None)
os.chdir(original_dir)
if database_url is not None and isinstance(database_url, str):
os.environ["DATABASE_URL"] = database_url
@ -470,13 +526,15 @@ def run_server(
or os.getenv("DIRECT_URL", None) is not None
):
try:
from litellm import get_secret
if os.getenv("DATABASE_URL", None) is not None:
### add connection pool + pool timeout args
params = {
"connection_limit": db_connection_pool_limit,
"pool_timeout": db_connection_timeout,
}
database_url = os.getenv("DATABASE_URL")
database_url = get_secret("DATABASE_URL", default_value=None)
modified_url = append_query_params(database_url, params)
os.environ["DATABASE_URL"] = modified_url
if os.getenv("DIRECT_URL", None) is not None: