feat(utils.py): support google kms for secret management

https://github.com/BerriAI/litellm/issues/1235
This commit is contained in:
Krrish Dholakia 2023-12-26 15:39:24 +05:30
parent e29dcf595e
commit 2070a785a4
9 changed files with 72 additions and 4 deletions

1
.gitignore vendored
View file

@ -30,3 +30,4 @@ litellm/proxy/log.txt
proxy_server_config_@.yaml proxy_server_config_@.yaml
.gitignore .gitignore
proxy_server_config_2.yaml proxy_server_config_2.yaml
litellm/proxy/secret_managers/credentials.json

View file

@ -143,6 +143,7 @@ allowed_fails: int = 0
secret_manager_client: Optional[ secret_manager_client: Optional[
Any Any
] = None # list of instantiated key management clients - e.g. azure kv, infisical, etc. ] = None # list of instantiated key management clients - e.g. azure kv, infisical, etc.
_google_kms_resource_name: Optional[str] = None
############################################# #############################################

View file

@ -97,6 +97,7 @@ from litellm.proxy.utils import (
ProxyLogging, ProxyLogging,
_cache_user_row, _cache_user_row,
) )
from litellm.proxy.secret_managers.google_kms import load_google_kms
import pydantic import pydantic
from litellm.proxy._types import * from litellm.proxy._types import *
from litellm.caching import DualCache from litellm.caching import DualCache
@ -690,13 +691,18 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
if general_settings is None: if general_settings is None:
general_settings = {} general_settings = {}
if general_settings: if general_settings:
### LOAD FROM GOOGLE KMS ###
use_google_kms = general_settings.get("use_google_kms", False)
load_google_kms(use_google_kms=use_google_kms)
### LOAD FROM AZURE KEY VAULT ### ### LOAD FROM AZURE KEY VAULT ###
use_azure_key_vault = general_settings.get("use_azure_key_vault", False) use_azure_key_vault = general_settings.get("use_azure_key_vault", False)
load_from_azure_key_vault(use_azure_key_vault=use_azure_key_vault) load_from_azure_key_vault(use_azure_key_vault=use_azure_key_vault)
### CONNECT TO DATABASE ### ### CONNECT TO DATABASE ###
database_url = general_settings.get("database_url", None) database_url = general_settings.get("database_url", None)
if database_url and database_url.startswith("os.environ/"): if database_url and database_url.startswith("os.environ/"):
print(f"GOING INTO LITELLM.GET_SECRET!")
database_url = litellm.get_secret(database_url) database_url = litellm.get_secret(database_url)
print(f"RETRIEVED DB URL: {database_url}")
prisma_setup(database_url=database_url) prisma_setup(database_url=database_url)
## COST TRACKING ## ## COST TRACKING ##
cost_tracking() cost_tracking()

View file

@ -0,0 +1,36 @@
"""
This is a file for the Google KMS integration
Relevant issue: https://github.com/BerriAI/litellm/issues/1235
Requires:
* `os.environ["GOOGLE_APPLICATION_CREDENTIALS"], os.environ["GOOGLE_KMS_RESOURCE_NAME"]`
* `pip install google-cloud-kms`
"""
import litellm, os
from typing import Optional
def validate_environment():
if "GOOGLE_APPLICATION_CREDENTIALS" not in os.environ:
raise ValueError(
"Missing required environment variable - GOOGLE_APPLICATION_CREDENTIALS"
)
if "GOOGLE_KMS_RESOURCE_NAME" not in os.environ:
raise ValueError(
"Missing required environment variable - GOOGLE_KMS_RESOURCE_NAME"
)
def load_google_kms(use_google_kms: Optional[bool]):
if use_google_kms is None or use_google_kms == False:
return
from google.cloud import kms_v1
validate_environment()
# Create the KMS client
client = kms_v1.KeyManagementServiceClient()
litellm.secret_manager_client = client
litellm._google_kms_resource_name = os.getenv("GOOGLE_KMS_RESOURCE_NAME")

View file

@ -9,7 +9,7 @@
import sys, re import sys, re
import litellm import litellm
import dotenv, json, traceback, threading import dotenv, json, traceback, threading, base64
import subprocess, os import subprocess, os
import litellm, openai import litellm, openai
import itertools import itertools
@ -6341,10 +6341,33 @@ def get_secret(secret_name: str, default_value: Optional[str] = None):
== "azure.keyvault.secrets._client.SecretClient" == "azure.keyvault.secrets._client.SecretClient"
): # support Azure Secret Client - from azure.keyvault.secrets import SecretClient ): # support Azure Secret Client - from azure.keyvault.secrets import SecretClient
secret = retrieved_secret = client.get_secret(secret_name).value secret = retrieved_secret = client.get_secret(secret_name).value
elif client.__class__.__name__ == "KeyManagementServiceClient":
encrypted_secret = os.getenv(secret_name)
if encrypted_secret is None:
raise ValueError(
f"Google KMS requires the encrypted secret to be in the environment!"
)
if not isinstance(encrypted_secret, bytes):
# If it's not, assume it's a string and encode it to bytes
ciphertext = eval(
encrypted_secret.encode()
) # assuming encrypted_secret is something like - b'\n$\x00D\xac\xb4/t)07\xe5\xf6..'
else:
ciphertext = encrypted_secret
response = client.decrypt(
request={
"name": litellm._google_kms_resource_name,
"ciphertext": ciphertext,
}
)
secret = response.plaintext.decode(
"utf-8"
) # assumes the original value was encoded with utf-8
else: # assume the default is infisicial client else: # assume the default is infisicial client
secret = client.get_secret(secret_name).secret_value secret = client.get_secret(secret_name).secret_value
except: # check if it's in os.environ except Exception as e: # check if it's in os.environ
secret = os.environ.get(secret_name) secret = os.getenv(secret_name)
return secret return secret
else: else:
return os.environ.get(secret_name) return os.environ.get(secret_name)

View file

@ -37,7 +37,8 @@ proxy = [
extra_proxy = [ extra_proxy = [
"prisma", "prisma",
"azure-identity", "azure-identity",
"azure-keyvault-secrets" "azure-keyvault-secrets",
"google-cloud-kms"
] ]
proxy_otel = [ proxy_otel = [