mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
fix(proxy_cli.py): support passing the database url as an encrypted kms key
This commit is contained in:
parent
6306914e56
commit
e4dbb9b2db
4 changed files with 74 additions and 14 deletions
|
@ -21,7 +21,7 @@ import shutil
|
|||
telemetry = None
|
||||
|
||||
|
||||
def append_query_params(url, params):
|
||||
def append_query_params(url, params) -> str:
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
|
||||
verbose_proxy_logger.debug(f"url: {url}")
|
||||
|
@ -229,13 +229,29 @@ def run_server(
|
|||
):
|
||||
args = locals()
|
||||
if local:
|
||||
from proxy_server import app, save_worker_config, ProxyConfig
|
||||
from proxy_server import (
|
||||
app,
|
||||
save_worker_config,
|
||||
ProxyConfig,
|
||||
KeyManagementSystem,
|
||||
KeyManagementSettings,
|
||||
load_from_azure_key_vault,
|
||||
load_aws_kms,
|
||||
load_aws_secret_manager,
|
||||
load_google_kms,
|
||||
)
|
||||
else:
|
||||
try:
|
||||
from .proxy_server import (
|
||||
app,
|
||||
save_worker_config,
|
||||
ProxyConfig,
|
||||
KeyManagementSystem,
|
||||
KeyManagementSettings,
|
||||
load_from_azure_key_vault,
|
||||
load_aws_kms,
|
||||
load_aws_secret_manager,
|
||||
load_google_kms,
|
||||
)
|
||||
except ImportError as e:
|
||||
if "litellm[proxy]" in str(e):
|
||||
|
@ -247,6 +263,12 @@ def run_server(
|
|||
app,
|
||||
save_worker_config,
|
||||
ProxyConfig,
|
||||
KeyManagementSystem,
|
||||
KeyManagementSettings,
|
||||
load_from_azure_key_vault,
|
||||
load_aws_kms,
|
||||
load_aws_secret_manager,
|
||||
load_google_kms,
|
||||
)
|
||||
if version == True:
|
||||
pkg_version = importlib.metadata.version("litellm")
|
||||
|
@ -445,6 +467,40 @@ def run_server(
|
|||
general_settings = _config.get("general_settings", {})
|
||||
if general_settings is None:
|
||||
general_settings = {}
|
||||
if general_settings:
|
||||
### LOAD SECRET MANAGER ###
|
||||
key_management_system = general_settings.get(
|
||||
"key_management_system", None
|
||||
)
|
||||
if key_management_system is not None:
|
||||
if (
|
||||
key_management_system
|
||||
== KeyManagementSystem.AZURE_KEY_VAULT.value
|
||||
):
|
||||
### LOAD FROM AZURE KEY VAULT ###
|
||||
load_from_azure_key_vault(use_azure_key_vault=True)
|
||||
elif key_management_system == KeyManagementSystem.GOOGLE_KMS.value:
|
||||
### LOAD FROM GOOGLE KMS ###
|
||||
load_google_kms(use_google_kms=True)
|
||||
elif (
|
||||
key_management_system
|
||||
== KeyManagementSystem.AWS_SECRET_MANAGER.value # noqa: F405
|
||||
):
|
||||
### LOAD FROM AWS SECRET MANAGER ###
|
||||
load_aws_secret_manager(use_aws_secret_manager=True)
|
||||
elif key_management_system == KeyManagementSystem.AWS_KMS.value:
|
||||
load_aws_kms(use_aws_kms=True)
|
||||
else:
|
||||
raise ValueError("Invalid Key Management System selected")
|
||||
key_management_settings = general_settings.get(
|
||||
"key_management_settings", None
|
||||
)
|
||||
if key_management_settings is not None:
|
||||
import litellm
|
||||
|
||||
litellm._key_management_settings = KeyManagementSettings(
|
||||
**key_management_settings
|
||||
)
|
||||
database_url = general_settings.get("database_url", None)
|
||||
db_connection_pool_limit = general_settings.get(
|
||||
"database_connection_pool_limit", 100
|
||||
|
@ -460,7 +516,7 @@ def run_server(
|
|||
) # Adds the parent directory to the system path - for litellm local dev
|
||||
import litellm
|
||||
|
||||
database_url = litellm.get_secret(database_url)
|
||||
database_url = litellm.get_secret(database_url, default_value=None)
|
||||
os.chdir(original_dir)
|
||||
if database_url is not None and isinstance(database_url, str):
|
||||
os.environ["DATABASE_URL"] = database_url
|
||||
|
@ -470,13 +526,15 @@ def run_server(
|
|||
or os.getenv("DIRECT_URL", None) is not None
|
||||
):
|
||||
try:
|
||||
from litellm import get_secret
|
||||
|
||||
if os.getenv("DATABASE_URL", None) is not None:
|
||||
### add connection pool + pool timeout args
|
||||
params = {
|
||||
"connection_limit": db_connection_pool_limit,
|
||||
"pool_timeout": db_connection_timeout,
|
||||
}
|
||||
database_url = os.getenv("DATABASE_URL")
|
||||
database_url = get_secret("DATABASE_URL", default_value=None)
|
||||
modified_url = append_query_params(database_url, params)
|
||||
os.environ["DATABASE_URL"] = modified_url
|
||||
if os.getenv("DIRECT_URL", None) is not None:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue