diff --git a/.gitignore b/.gitignore index f8449afea..29c296915 100644 --- a/.gitignore +++ b/.gitignore @@ -30,3 +30,4 @@ litellm/proxy/log.txt proxy_server_config_@.yaml .gitignore proxy_server_config_2.yaml +litellm/proxy/secret_managers/credentials.json diff --git a/litellm/__init__.py b/litellm/__init__.py index c5e1e7bd8..4cf303b35 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -143,6 +143,7 @@ allowed_fails: int = 0 secret_manager_client: Optional[ Any ] = None # list of instantiated key management clients - e.g. azure kv, infisical, etc. +_google_kms_resource_name: Optional[str] = None ############################################# diff --git a/litellm/proxy/_health_check_test_config.yaml b/litellm/proxy/example_config_yaml/_health_check_test_config.yaml similarity index 100% rename from litellm/proxy/_health_check_test_config.yaml rename to litellm/proxy/example_config_yaml/_health_check_test_config.yaml diff --git a/litellm/proxy/custom_auth.py b/litellm/proxy/example_config_yaml/custom_auth.py similarity index 100% rename from litellm/proxy/custom_auth.py rename to litellm/proxy/example_config_yaml/custom_auth.py diff --git a/litellm/proxy/custom_callbacks.py b/litellm/proxy/example_config_yaml/custom_callbacks.py similarity index 100% rename from litellm/proxy/custom_callbacks.py rename to litellm/proxy/example_config_yaml/custom_callbacks.py diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 48881d24a..70b32cbf3 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -97,6 +97,7 @@ from litellm.proxy.utils import ( ProxyLogging, _cache_user_row, ) +from litellm.proxy.secret_managers.google_kms import load_google_kms import pydantic from litellm.proxy._types import * from litellm.caching import DualCache @@ -690,13 +691,18 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str): if general_settings is None: general_settings = {} if general_settings: + ### LOAD FROM GOOGLE KMS ### + use_google_kms = general_settings.get("use_google_kms", False) + load_google_kms(use_google_kms=use_google_kms) ### LOAD FROM AZURE KEY VAULT ### use_azure_key_vault = general_settings.get("use_azure_key_vault", False) load_from_azure_key_vault(use_azure_key_vault=use_azure_key_vault) ### CONNECT TO DATABASE ### database_url = general_settings.get("database_url", None) if database_url and database_url.startswith("os.environ/"): + print(f"GOING INTO LITELLM.GET_SECRET!") database_url = litellm.get_secret(database_url) + print(f"RETRIEVED DB URL: {database_url}") prisma_setup(database_url=database_url) ## COST TRACKING ## cost_tracking() diff --git a/litellm/proxy/secret_managers/google_kms.py b/litellm/proxy/secret_managers/google_kms.py new file mode 100644 index 000000000..748a11f8e --- /dev/null +++ b/litellm/proxy/secret_managers/google_kms.py @@ -0,0 +1,36 @@ +""" +This is a file for the Google KMS integration + +Relevant issue: https://github.com/BerriAI/litellm/issues/1235 + +Requires: +* `os.environ["GOOGLE_APPLICATION_CREDENTIALS"], os.environ["GOOGLE_KMS_RESOURCE_NAME"]` +* `pip install google-cloud-kms` +""" +import litellm, os +from typing import Optional + + +def validate_environment(): + if "GOOGLE_APPLICATION_CREDENTIALS" not in os.environ: + raise ValueError( + "Missing required environment variable - GOOGLE_APPLICATION_CREDENTIALS" + ) + if "GOOGLE_KMS_RESOURCE_NAME" not in os.environ: + raise ValueError( + "Missing required environment variable - GOOGLE_KMS_RESOURCE_NAME" + ) + + +def load_google_kms(use_google_kms: Optional[bool]): + if use_google_kms is None or use_google_kms == False: + return + + from google.cloud import kms_v1 + + validate_environment() + + # Create the KMS client + client = kms_v1.KeyManagementServiceClient() + litellm.secret_manager_client = client + litellm._google_kms_resource_name = os.getenv("GOOGLE_KMS_RESOURCE_NAME") diff --git a/litellm/utils.py b/litellm/utils.py index 85775da11..260bd8d6a 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -9,7 +9,7 @@ import sys, re import litellm -import dotenv, json, traceback, threading +import dotenv, json, traceback, threading, base64 import subprocess, os import litellm, openai import itertools @@ -6341,10 +6341,33 @@ def get_secret(secret_name: str, default_value: Optional[str] = None): == "azure.keyvault.secrets._client.SecretClient" ): # support Azure Secret Client - from azure.keyvault.secrets import SecretClient secret = retrieved_secret = client.get_secret(secret_name).value + elif client.__class__.__name__ == "KeyManagementServiceClient": + encrypted_secret = os.getenv(secret_name) + if encrypted_secret is None: + raise ValueError( + f"Google KMS requires the encrypted secret to be in the environment!" + ) + if not isinstance(encrypted_secret, bytes): + # If it's not, assume it's a string and encode it to bytes + ciphertext = eval( + encrypted_secret.encode() + ) # assuming encrypted_secret is something like - b'\n$\x00D\xac\xb4/t)07\xe5\xf6..' + else: + ciphertext = encrypted_secret + + response = client.decrypt( + request={ + "name": litellm._google_kms_resource_name, + "ciphertext": ciphertext, + } + ) + secret = response.plaintext.decode( + "utf-8" + ) # assumes the original value was encoded with utf-8 else: # assume the default is infisicial client secret = client.get_secret(secret_name).secret_value - except: # check if it's in os.environ - secret = os.environ.get(secret_name) + except Exception as e: # check if it's in os.environ + secret = os.getenv(secret_name) return secret else: return os.environ.get(secret_name) diff --git a/pyproject.toml b/pyproject.toml index ac6be5f50..cc489f4c3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,8 @@ proxy = [ extra_proxy = [ "prisma", "azure-identity", - "azure-keyvault-secrets" + "azure-keyvault-secrets", + "google-cloud-kms" ] proxy_otel = [