mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 11:43:54 +00:00
fix(proxy_cli.py): support passing the database url as an encrypted kms key
This commit is contained in:
parent
58cce8a922
commit
bee79f0b70
4 changed files with 74 additions and 14 deletions
|
@ -57,9 +57,9 @@ router_settings:
|
||||||
litellm_settings:
|
litellm_settings:
|
||||||
success_callback: ["langfuse"]
|
success_callback: ["langfuse"]
|
||||||
|
|
||||||
# general_settings:
|
general_settings:
|
||||||
# alerting: ["email"]
|
alerting: ["email"]
|
||||||
# key_management_system: "aws_kms"
|
key_management_system: "aws_kms"
|
||||||
# key_management_settings:
|
key_management_settings:
|
||||||
# hosted_keys: ["LITELLM_MASTER_KEY"]
|
hosted_keys: ["LITELLM_MASTER_KEY", "DATABASE_URL"]
|
||||||
|
|
||||||
|
|
|
@ -21,7 +21,7 @@ import shutil
|
||||||
telemetry = None
|
telemetry = None
|
||||||
|
|
||||||
|
|
||||||
def append_query_params(url, params):
|
def append_query_params(url, params) -> str:
|
||||||
from litellm._logging import verbose_proxy_logger
|
from litellm._logging import verbose_proxy_logger
|
||||||
|
|
||||||
verbose_proxy_logger.debug(f"url: {url}")
|
verbose_proxy_logger.debug(f"url: {url}")
|
||||||
|
@ -229,13 +229,29 @@ def run_server(
|
||||||
):
|
):
|
||||||
args = locals()
|
args = locals()
|
||||||
if local:
|
if local:
|
||||||
from proxy_server import app, save_worker_config, ProxyConfig
|
from proxy_server import (
|
||||||
|
app,
|
||||||
|
save_worker_config,
|
||||||
|
ProxyConfig,
|
||||||
|
KeyManagementSystem,
|
||||||
|
KeyManagementSettings,
|
||||||
|
load_from_azure_key_vault,
|
||||||
|
load_aws_kms,
|
||||||
|
load_aws_secret_manager,
|
||||||
|
load_google_kms,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
from .proxy_server import (
|
from .proxy_server import (
|
||||||
app,
|
app,
|
||||||
save_worker_config,
|
save_worker_config,
|
||||||
ProxyConfig,
|
ProxyConfig,
|
||||||
|
KeyManagementSystem,
|
||||||
|
KeyManagementSettings,
|
||||||
|
load_from_azure_key_vault,
|
||||||
|
load_aws_kms,
|
||||||
|
load_aws_secret_manager,
|
||||||
|
load_google_kms,
|
||||||
)
|
)
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
if "litellm[proxy]" in str(e):
|
if "litellm[proxy]" in str(e):
|
||||||
|
@ -247,6 +263,12 @@ def run_server(
|
||||||
app,
|
app,
|
||||||
save_worker_config,
|
save_worker_config,
|
||||||
ProxyConfig,
|
ProxyConfig,
|
||||||
|
KeyManagementSystem,
|
||||||
|
KeyManagementSettings,
|
||||||
|
load_from_azure_key_vault,
|
||||||
|
load_aws_kms,
|
||||||
|
load_aws_secret_manager,
|
||||||
|
load_google_kms,
|
||||||
)
|
)
|
||||||
if version == True:
|
if version == True:
|
||||||
pkg_version = importlib.metadata.version("litellm")
|
pkg_version = importlib.metadata.version("litellm")
|
||||||
|
@ -445,6 +467,40 @@ def run_server(
|
||||||
general_settings = _config.get("general_settings", {})
|
general_settings = _config.get("general_settings", {})
|
||||||
if general_settings is None:
|
if general_settings is None:
|
||||||
general_settings = {}
|
general_settings = {}
|
||||||
|
if general_settings:
|
||||||
|
### LOAD SECRET MANAGER ###
|
||||||
|
key_management_system = general_settings.get(
|
||||||
|
"key_management_system", None
|
||||||
|
)
|
||||||
|
if key_management_system is not None:
|
||||||
|
if (
|
||||||
|
key_management_system
|
||||||
|
== KeyManagementSystem.AZURE_KEY_VAULT.value
|
||||||
|
):
|
||||||
|
### LOAD FROM AZURE KEY VAULT ###
|
||||||
|
load_from_azure_key_vault(use_azure_key_vault=True)
|
||||||
|
elif key_management_system == KeyManagementSystem.GOOGLE_KMS.value:
|
||||||
|
### LOAD FROM GOOGLE KMS ###
|
||||||
|
load_google_kms(use_google_kms=True)
|
||||||
|
elif (
|
||||||
|
key_management_system
|
||||||
|
== KeyManagementSystem.AWS_SECRET_MANAGER.value # noqa: F405
|
||||||
|
):
|
||||||
|
### LOAD FROM AWS SECRET MANAGER ###
|
||||||
|
load_aws_secret_manager(use_aws_secret_manager=True)
|
||||||
|
elif key_management_system == KeyManagementSystem.AWS_KMS.value:
|
||||||
|
load_aws_kms(use_aws_kms=True)
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid Key Management System selected")
|
||||||
|
key_management_settings = general_settings.get(
|
||||||
|
"key_management_settings", None
|
||||||
|
)
|
||||||
|
if key_management_settings is not None:
|
||||||
|
import litellm
|
||||||
|
|
||||||
|
litellm._key_management_settings = KeyManagementSettings(
|
||||||
|
**key_management_settings
|
||||||
|
)
|
||||||
database_url = general_settings.get("database_url", None)
|
database_url = general_settings.get("database_url", None)
|
||||||
db_connection_pool_limit = general_settings.get(
|
db_connection_pool_limit = general_settings.get(
|
||||||
"database_connection_pool_limit", 100
|
"database_connection_pool_limit", 100
|
||||||
|
@ -460,7 +516,7 @@ def run_server(
|
||||||
) # Adds the parent directory to the system path - for litellm local dev
|
) # Adds the parent directory to the system path - for litellm local dev
|
||||||
import litellm
|
import litellm
|
||||||
|
|
||||||
database_url = litellm.get_secret(database_url)
|
database_url = litellm.get_secret(database_url, default_value=None)
|
||||||
os.chdir(original_dir)
|
os.chdir(original_dir)
|
||||||
if database_url is not None and isinstance(database_url, str):
|
if database_url is not None and isinstance(database_url, str):
|
||||||
os.environ["DATABASE_URL"] = database_url
|
os.environ["DATABASE_URL"] = database_url
|
||||||
|
@ -470,13 +526,15 @@ def run_server(
|
||||||
or os.getenv("DIRECT_URL", None) is not None
|
or os.getenv("DIRECT_URL", None) is not None
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
|
from litellm import get_secret
|
||||||
|
|
||||||
if os.getenv("DATABASE_URL", None) is not None:
|
if os.getenv("DATABASE_URL", None) is not None:
|
||||||
### add connection pool + pool timeout args
|
### add connection pool + pool timeout args
|
||||||
params = {
|
params = {
|
||||||
"connection_limit": db_connection_pool_limit,
|
"connection_limit": db_connection_pool_limit,
|
||||||
"pool_timeout": db_connection_timeout,
|
"pool_timeout": db_connection_timeout,
|
||||||
}
|
}
|
||||||
database_url = os.getenv("DATABASE_URL")
|
database_url = get_secret("DATABASE_URL", default_value=None)
|
||||||
modified_url = append_query_params(database_url, params)
|
modified_url = append_query_params(database_url, params)
|
||||||
os.environ["DATABASE_URL"] = modified_url
|
os.environ["DATABASE_URL"] = modified_url
|
||||||
if os.getenv("DIRECT_URL", None) is not None:
|
if os.getenv("DIRECT_URL", None) is not None:
|
||||||
|
|
|
@ -3895,7 +3895,7 @@ async def startup_event():
|
||||||
master_key = litellm.get_secret("LITELLM_MASTER_KEY", None)
|
master_key = litellm.get_secret("LITELLM_MASTER_KEY", None)
|
||||||
# check if DATABASE_URL in environment - load from there
|
# check if DATABASE_URL in environment - load from there
|
||||||
if prisma_client is None:
|
if prisma_client is None:
|
||||||
prisma_setup(database_url=os.getenv("DATABASE_URL"))
|
prisma_setup(database_url=litellm.get_secret("DATABASE_URL", None))
|
||||||
|
|
||||||
### LOAD CONFIG ###
|
### LOAD CONFIG ###
|
||||||
worker_config = litellm.get_secret("WORKER_CONFIG")
|
worker_config = litellm.get_secret("WORKER_CONFIG")
|
||||||
|
|
|
@ -10119,7 +10119,6 @@ def get_secret(
|
||||||
):
|
):
|
||||||
key_management_system = litellm._key_management_system
|
key_management_system = litellm._key_management_system
|
||||||
key_management_settings = litellm._key_management_settings
|
key_management_settings = litellm._key_management_settings
|
||||||
args = locals()
|
|
||||||
|
|
||||||
if secret_name.startswith("os.environ/"):
|
if secret_name.startswith("os.environ/"):
|
||||||
secret_name = secret_name.replace("os.environ/", "")
|
secret_name = secret_name.replace("os.environ/", "")
|
||||||
|
@ -10248,13 +10247,16 @@ def get_secret(
|
||||||
"""
|
"""
|
||||||
encrypted_value = os.getenv(secret_name, None)
|
encrypted_value = os.getenv(secret_name, None)
|
||||||
if encrypted_value is None:
|
if encrypted_value is None:
|
||||||
raise Exception("encrypted value for AWS KMS cannot be None.")
|
raise Exception(
|
||||||
|
"AWS KMS - Encrypted Value of Key={} is None".format(
|
||||||
|
secret_name
|
||||||
|
)
|
||||||
|
)
|
||||||
# Decode the base64 encoded ciphertext
|
# Decode the base64 encoded ciphertext
|
||||||
ciphertext_blob = base64.b64decode(encrypted_value)
|
ciphertext_blob = base64.b64decode(encrypted_value)
|
||||||
|
|
||||||
# Set up the parameters for the decrypt call
|
# Set up the parameters for the decrypt call
|
||||||
params = {"CiphertextBlob": ciphertext_blob}
|
params = {"CiphertextBlob": ciphertext_blob}
|
||||||
|
|
||||||
# Perform the decryption
|
# Perform the decryption
|
||||||
response = client.decrypt(**params)
|
response = client.decrypt(**params)
|
||||||
|
|
||||||
|
@ -10287,7 +10289,7 @@ def get_secret(
|
||||||
secret = client.get_secret(secret_name).secret_value
|
secret = client.get_secret(secret_name).secret_value
|
||||||
except Exception as e: # check if it's in os.environ
|
except Exception as e: # check if it's in os.environ
|
||||||
verbose_logger.error(
|
verbose_logger.error(
|
||||||
f"An exception occurred - {str(e)}\n\n{traceback.format_exc()}"
|
f"Defaulting to os.environ value for key={secret_name}. An exception occurred - {str(e)}.\n\n{traceback.format_exc()}"
|
||||||
)
|
)
|
||||||
secret = os.getenv(secret_name)
|
secret = os.getenv(secret_name)
|
||||||
try:
|
try:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue