mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 11:43:54 +00:00
fix(proxy_cli.py): support passing the database url as an encrypted kms key
This commit is contained in:
parent
58cce8a922
commit
bee79f0b70
4 changed files with 74 additions and 14 deletions
|
@ -57,9 +57,9 @@ router_settings:
|
|||
litellm_settings:
|
||||
success_callback: ["langfuse"]
|
||||
|
||||
# general_settings:
|
||||
# alerting: ["email"]
|
||||
# key_management_system: "aws_kms"
|
||||
# key_management_settings:
|
||||
# hosted_keys: ["LITELLM_MASTER_KEY"]
|
||||
general_settings:
|
||||
alerting: ["email"]
|
||||
key_management_system: "aws_kms"
|
||||
key_management_settings:
|
||||
hosted_keys: ["LITELLM_MASTER_KEY", "DATABASE_URL"]
|
||||
|
||||
|
|
|
@ -21,7 +21,7 @@ import shutil
|
|||
telemetry = None
|
||||
|
||||
|
||||
def append_query_params(url, params):
|
||||
def append_query_params(url, params) -> str:
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
|
||||
verbose_proxy_logger.debug(f"url: {url}")
|
||||
|
@ -229,13 +229,29 @@ def run_server(
|
|||
):
|
||||
args = locals()
|
||||
if local:
|
||||
from proxy_server import app, save_worker_config, ProxyConfig
|
||||
from proxy_server import (
|
||||
app,
|
||||
save_worker_config,
|
||||
ProxyConfig,
|
||||
KeyManagementSystem,
|
||||
KeyManagementSettings,
|
||||
load_from_azure_key_vault,
|
||||
load_aws_kms,
|
||||
load_aws_secret_manager,
|
||||
load_google_kms,
|
||||
)
|
||||
else:
|
||||
try:
|
||||
from .proxy_server import (
|
||||
app,
|
||||
save_worker_config,
|
||||
ProxyConfig,
|
||||
KeyManagementSystem,
|
||||
KeyManagementSettings,
|
||||
load_from_azure_key_vault,
|
||||
load_aws_kms,
|
||||
load_aws_secret_manager,
|
||||
load_google_kms,
|
||||
)
|
||||
except ImportError as e:
|
||||
if "litellm[proxy]" in str(e):
|
||||
|
@ -247,6 +263,12 @@ def run_server(
|
|||
app,
|
||||
save_worker_config,
|
||||
ProxyConfig,
|
||||
KeyManagementSystem,
|
||||
KeyManagementSettings,
|
||||
load_from_azure_key_vault,
|
||||
load_aws_kms,
|
||||
load_aws_secret_manager,
|
||||
load_google_kms,
|
||||
)
|
||||
if version == True:
|
||||
pkg_version = importlib.metadata.version("litellm")
|
||||
|
@ -445,6 +467,40 @@ def run_server(
|
|||
general_settings = _config.get("general_settings", {})
|
||||
if general_settings is None:
|
||||
general_settings = {}
|
||||
if general_settings:
|
||||
### LOAD SECRET MANAGER ###
|
||||
key_management_system = general_settings.get(
|
||||
"key_management_system", None
|
||||
)
|
||||
if key_management_system is not None:
|
||||
if (
|
||||
key_management_system
|
||||
== KeyManagementSystem.AZURE_KEY_VAULT.value
|
||||
):
|
||||
### LOAD FROM AZURE KEY VAULT ###
|
||||
load_from_azure_key_vault(use_azure_key_vault=True)
|
||||
elif key_management_system == KeyManagementSystem.GOOGLE_KMS.value:
|
||||
### LOAD FROM GOOGLE KMS ###
|
||||
load_google_kms(use_google_kms=True)
|
||||
elif (
|
||||
key_management_system
|
||||
== KeyManagementSystem.AWS_SECRET_MANAGER.value # noqa: F405
|
||||
):
|
||||
### LOAD FROM AWS SECRET MANAGER ###
|
||||
load_aws_secret_manager(use_aws_secret_manager=True)
|
||||
elif key_management_system == KeyManagementSystem.AWS_KMS.value:
|
||||
load_aws_kms(use_aws_kms=True)
|
||||
else:
|
||||
raise ValueError("Invalid Key Management System selected")
|
||||
key_management_settings = general_settings.get(
|
||||
"key_management_settings", None
|
||||
)
|
||||
if key_management_settings is not None:
|
||||
import litellm
|
||||
|
||||
litellm._key_management_settings = KeyManagementSettings(
|
||||
**key_management_settings
|
||||
)
|
||||
database_url = general_settings.get("database_url", None)
|
||||
db_connection_pool_limit = general_settings.get(
|
||||
"database_connection_pool_limit", 100
|
||||
|
@ -460,7 +516,7 @@ def run_server(
|
|||
) # Adds the parent directory to the system path - for litellm local dev
|
||||
import litellm
|
||||
|
||||
database_url = litellm.get_secret(database_url)
|
||||
database_url = litellm.get_secret(database_url, default_value=None)
|
||||
os.chdir(original_dir)
|
||||
if database_url is not None and isinstance(database_url, str):
|
||||
os.environ["DATABASE_URL"] = database_url
|
||||
|
@ -470,13 +526,15 @@ def run_server(
|
|||
or os.getenv("DIRECT_URL", None) is not None
|
||||
):
|
||||
try:
|
||||
from litellm import get_secret
|
||||
|
||||
if os.getenv("DATABASE_URL", None) is not None:
|
||||
### add connection pool + pool timeout args
|
||||
params = {
|
||||
"connection_limit": db_connection_pool_limit,
|
||||
"pool_timeout": db_connection_timeout,
|
||||
}
|
||||
database_url = os.getenv("DATABASE_URL")
|
||||
database_url = get_secret("DATABASE_URL", default_value=None)
|
||||
modified_url = append_query_params(database_url, params)
|
||||
os.environ["DATABASE_URL"] = modified_url
|
||||
if os.getenv("DIRECT_URL", None) is not None:
|
||||
|
|
|
@ -3895,7 +3895,7 @@ async def startup_event():
|
|||
master_key = litellm.get_secret("LITELLM_MASTER_KEY", None)
|
||||
# check if DATABASE_URL in environment - load from there
|
||||
if prisma_client is None:
|
||||
prisma_setup(database_url=os.getenv("DATABASE_URL"))
|
||||
prisma_setup(database_url=litellm.get_secret("DATABASE_URL", None))
|
||||
|
||||
### LOAD CONFIG ###
|
||||
worker_config = litellm.get_secret("WORKER_CONFIG")
|
||||
|
|
|
@ -10119,7 +10119,6 @@ def get_secret(
|
|||
):
|
||||
key_management_system = litellm._key_management_system
|
||||
key_management_settings = litellm._key_management_settings
|
||||
args = locals()
|
||||
|
||||
if secret_name.startswith("os.environ/"):
|
||||
secret_name = secret_name.replace("os.environ/", "")
|
||||
|
@ -10248,13 +10247,16 @@ def get_secret(
|
|||
"""
|
||||
encrypted_value = os.getenv(secret_name, None)
|
||||
if encrypted_value is None:
|
||||
raise Exception("encrypted value for AWS KMS cannot be None.")
|
||||
raise Exception(
|
||||
"AWS KMS - Encrypted Value of Key={} is None".format(
|
||||
secret_name
|
||||
)
|
||||
)
|
||||
# Decode the base64 encoded ciphertext
|
||||
ciphertext_blob = base64.b64decode(encrypted_value)
|
||||
|
||||
# Set up the parameters for the decrypt call
|
||||
params = {"CiphertextBlob": ciphertext_blob}
|
||||
|
||||
# Perform the decryption
|
||||
response = client.decrypt(**params)
|
||||
|
||||
|
@ -10287,7 +10289,7 @@ def get_secret(
|
|||
secret = client.get_secret(secret_name).secret_value
|
||||
except Exception as e: # check if it's in os.environ
|
||||
verbose_logger.error(
|
||||
f"An exception occurred - {str(e)}\n\n{traceback.format_exc()}"
|
||||
f"Defaulting to os.environ value for key={secret_name}. An exception occurred - {str(e)}.\n\n{traceback.format_exc()}"
|
||||
)
|
||||
secret = os.getenv(secret_name)
|
||||
try:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue