fix(google_kms.py): support enums for key management system

This commit is contained in:
Krrish Dholakia 2023-12-27 13:19:19 +05:30
parent 4cc59d21d0
commit 9ba520cc8b
6 changed files with 75 additions and 16 deletions

View file

@ -3,6 +3,7 @@ import threading, requests
from typing import Callable, List, Optional, Dict, Union, Any
from litellm.caching import Cache
from litellm._logging import set_verbose
from litellm.proxy._types import KeyManagementSystem
import httpx
input_callback: List[Union[str, Callable]] = []
@ -144,6 +145,7 @@ secret_manager_client: Optional[
Any
] = None # list of instantiated key management clients - e.g. azure kv, infisical, etc.
_google_kms_resource_name: Optional[str] = None
_key_management_system: Optional[KeyManagementSystem] = None
#############################################

View file

@ -1,4 +1,5 @@
from pydantic import BaseModel, Extra, Field, root_validator
import enum
from typing import Optional, List, Union, Dict, Literal
from datetime import datetime
import uuid, json
@ -175,6 +176,12 @@ class NewUserResponse(GenerateKeyResponse):
max_budget: Optional[float] = None
class KeyManagementSystem(enum.Enum):
GOOGLE_KMS = "google_kms"
AZURE_KEY_VAULT = "azure_key_vault"
LOCAL = "local"
class ConfigGeneralSettings(LiteLLMBase):
"""
Documents all the fields supported by `general_settings` in config.yaml
@ -183,6 +190,12 @@ class ConfigGeneralSettings(LiteLLMBase):
completion_model: Optional[str] = Field(
None, description="proxy level default model for all chat completion calls"
)
key_management_system: Optional[KeyManagementSystem] = Field(
None, description="key manager to load keys from / decrypt keys with"
)
use_google_kms: Optional[bool] = Field(
None, description="decrypt keys with google kms"
)
use_azure_key_vault: Optional[bool] = Field(
None, description="load keys from azure key vault"
)

View file

@ -415,6 +415,7 @@ def load_from_azure_key_vault(use_azure_key_vault: bool = False):
client = SecretClient(vault_url=KVUri, credential=credential)
litellm.secret_manager_client = client
litellm._key_management_system = KeyManagementSystem.AZURE_KEY_VAULT
else:
raise Exception(
f"Missing KVUri or client_id or client_secret or tenant_id from environment"
@ -691,10 +692,21 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
if general_settings is None:
general_settings = {}
if general_settings:
### LOAD SECRET MANAGER ###
key_management_system = general_settings.get("key_management_system", None)
if key_management_system is not None:
if key_management_system == KeyManagementSystem.AZURE_KEY_VAULT.value:
### LOAD FROM AZURE KEY VAULT ###
load_from_azure_key_vault(use_azure_key_vault=True)
elif key_management_system == KeyManagementSystem.GOOGLE_KMS.value:
### LOAD FROM GOOGLE KMS ###
load_google_kms(use_google_kms=True)
else:
raise ValueError("Invalid Key Management System selected")
### [DEPRECATED] LOAD FROM GOOGLE KMS ###
use_google_kms = general_settings.get("use_google_kms", False)
load_google_kms(use_google_kms=use_google_kms)
### LOAD FROM AZURE KEY VAULT ###
### [DEPRECATED] LOAD FROM AZURE KEY VAULT ###
use_azure_key_vault = general_settings.get("use_azure_key_vault", False)
load_from_azure_key_vault(use_azure_key_vault=use_azure_key_vault)
### CONNECT TO DATABASE ###

View file

@ -9,6 +9,7 @@ Requires:
"""
import litellm, os
from typing import Optional
from litellm.proxy._types import KeyManagementSystem
def validate_environment():
@ -25,7 +26,7 @@ def validate_environment():
def load_google_kms(use_google_kms: Optional[bool]):
if use_google_kms is None or use_google_kms == False:
return
try:
from google.cloud import kms_v1 # type: ignore
validate_environment()
@ -33,4 +34,7 @@ def load_google_kms(use_google_kms: Optional[bool]):
# Create the KMS client
client = kms_v1.KeyManagementServiceClient()
litellm.secret_manager_client = client
litellm._key_management_system = KeyManagementSystem.GOOGLE_KMS
litellm._google_kms_resource_name = os.getenv("GOOGLE_KMS_RESOURCE_NAME")
except Exception as e:
raise e

View file

@ -7,7 +7,7 @@
#
# Thank you users! We ❤️ you! - Krrish & Ishaan
import sys, re
import sys, re, binascii
import litellm
import dotenv, json, traceback, threading, base64
import subprocess, os
@ -43,6 +43,7 @@ from .integrations.custom_logger import CustomLogger
from .integrations.langfuse import LangFuseLogger
from .integrations.dynamodb import DyanmoDBLogger
from .integrations.litedebugger import LiteDebugger
from .proxy._types import KeyManagementSystem
from openai import OpenAIError as OriginalError
from openai._models import BaseModel as OpenAIObject
from .exceptions import (
@ -59,7 +60,7 @@ from .exceptions import (
BudgetExceededError,
UnprocessableEntityError,
)
from typing import cast, List, Dict, Union, Optional, Literal
from typing import cast, List, Dict, Union, Optional, Literal, Any
from .caching import Cache
from concurrent.futures import ThreadPoolExecutor
@ -6331,24 +6332,45 @@ def litellm_telemetry(data):
######### Secret Manager ############################
# checks if user has passed in a secret manager client
# if passed in then checks the secret there
def get_secret(secret_name: str, default_value: Optional[str] = None):
def _is_base64(s):
try:
return base64.b64encode(base64.b64decode(s)).decode() == s
except binascii.Error:
return False
def get_secret(
secret_name: str,
default_value: Optional[str] = None,
):
key_management_system = litellm._key_management_system
if secret_name.startswith("os.environ/"):
secret_name = secret_name.replace("os.environ/", "")
try:
if litellm.secret_manager_client is not None:
try:
client = litellm.secret_manager_client
key_manager = "local"
if key_management_system is not None:
key_manager = key_management_system.value
if (
type(client).__module__ + "." + type(client).__name__
key_manager == KeyManagementSystem.AZURE_KEY_VAULT
or type(client).__module__ + "." + type(client).__name__
== "azure.keyvault.secrets._client.SecretClient"
): # support Azure Secret Client - from azure.keyvault.secrets import SecretClient
secret = retrieved_secret = client.get_secret(secret_name).value
elif client.__class__.__name__ == "KeyManagementServiceClient":
encrypted_secret = os.getenv(secret_name)
secret = client.get_secret(secret_name).value
elif (
key_manager == KeyManagementSystem.GOOGLE_KMS
or client.__class__.__name__ == "KeyManagementServiceClient"
):
encrypted_secret: Any = os.getenv(secret_name)
if encrypted_secret is None:
raise ValueError(
f"Google KMS requires the encrypted secret to be in the environment!"
)
b64_flag = _is_base64(encrypted_secret)
if b64_flag == True: # if passed in as encoded b64 string
encrypted_secret = base64.b64decode(encrypted_secret)
if not isinstance(encrypted_secret, bytes):
# If it's not, assume it's a string and encode it to bytes
ciphertext = eval(

6
mypy.ini Normal file
View file

@ -0,0 +1,6 @@
[mypy]
warn_return_any = False
ignore_missing_imports = False
[mypy-google.*]
ignore_missing_imports = True