forked from phoenix/litellm-mirror
fix(google_kms.py): support enums for key management system
This commit is contained in:
parent
4cc59d21d0
commit
9ba520cc8b
6 changed files with 75 additions and 16 deletions
|
@ -3,6 +3,7 @@ import threading, requests
|
|||
from typing import Callable, List, Optional, Dict, Union, Any
|
||||
from litellm.caching import Cache
|
||||
from litellm._logging import set_verbose
|
||||
from litellm.proxy._types import KeyManagementSystem
|
||||
import httpx
|
||||
|
||||
input_callback: List[Union[str, Callable]] = []
|
||||
|
@ -144,6 +145,7 @@ secret_manager_client: Optional[
|
|||
Any
|
||||
] = None # list of instantiated key management clients - e.g. azure kv, infisical, etc.
|
||||
_google_kms_resource_name: Optional[str] = None
|
||||
_key_management_system: Optional[KeyManagementSystem] = None
|
||||
#############################################
|
||||
|
||||
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
from pydantic import BaseModel, Extra, Field, root_validator
|
||||
import enum
|
||||
from typing import Optional, List, Union, Dict, Literal
|
||||
from datetime import datetime
|
||||
import uuid, json
|
||||
|
@ -175,6 +176,12 @@ class NewUserResponse(GenerateKeyResponse):
|
|||
max_budget: Optional[float] = None
|
||||
|
||||
|
||||
class KeyManagementSystem(enum.Enum):
|
||||
GOOGLE_KMS = "google_kms"
|
||||
AZURE_KEY_VAULT = "azure_key_vault"
|
||||
LOCAL = "local"
|
||||
|
||||
|
||||
class ConfigGeneralSettings(LiteLLMBase):
|
||||
"""
|
||||
Documents all the fields supported by `general_settings` in config.yaml
|
||||
|
@ -183,6 +190,12 @@ class ConfigGeneralSettings(LiteLLMBase):
|
|||
completion_model: Optional[str] = Field(
|
||||
None, description="proxy level default model for all chat completion calls"
|
||||
)
|
||||
key_management_system: Optional[KeyManagementSystem] = Field(
|
||||
None, description="key manager to load keys from / decrypt keys with"
|
||||
)
|
||||
use_google_kms: Optional[bool] = Field(
|
||||
None, description="decrypt keys with google kms"
|
||||
)
|
||||
use_azure_key_vault: Optional[bool] = Field(
|
||||
None, description="load keys from azure key vault"
|
||||
)
|
||||
|
|
|
@ -415,6 +415,7 @@ def load_from_azure_key_vault(use_azure_key_vault: bool = False):
|
|||
client = SecretClient(vault_url=KVUri, credential=credential)
|
||||
|
||||
litellm.secret_manager_client = client
|
||||
litellm._key_management_system = KeyManagementSystem.AZURE_KEY_VAULT
|
||||
else:
|
||||
raise Exception(
|
||||
f"Missing KVUri or client_id or client_secret or tenant_id from environment"
|
||||
|
@ -691,10 +692,21 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
|
|||
if general_settings is None:
|
||||
general_settings = {}
|
||||
if general_settings:
|
||||
### LOAD SECRET MANAGER ###
|
||||
key_management_system = general_settings.get("key_management_system", None)
|
||||
if key_management_system is not None:
|
||||
if key_management_system == KeyManagementSystem.AZURE_KEY_VAULT.value:
|
||||
### LOAD FROM AZURE KEY VAULT ###
|
||||
load_from_azure_key_vault(use_azure_key_vault=True)
|
||||
elif key_management_system == KeyManagementSystem.GOOGLE_KMS.value:
|
||||
### LOAD FROM GOOGLE KMS ###
|
||||
load_google_kms(use_google_kms=True)
|
||||
else:
|
||||
raise ValueError("Invalid Key Management System selected")
|
||||
### [DEPRECATED] LOAD FROM GOOGLE KMS ###
|
||||
use_google_kms = general_settings.get("use_google_kms", False)
|
||||
load_google_kms(use_google_kms=use_google_kms)
|
||||
### LOAD FROM AZURE KEY VAULT ###
|
||||
### [DEPRECATED] LOAD FROM AZURE KEY VAULT ###
|
||||
use_azure_key_vault = general_settings.get("use_azure_key_vault", False)
|
||||
load_from_azure_key_vault(use_azure_key_vault=use_azure_key_vault)
|
||||
### CONNECT TO DATABASE ###
|
||||
|
|
|
@ -9,6 +9,7 @@ Requires:
|
|||
"""
|
||||
import litellm, os
|
||||
from typing import Optional
|
||||
from litellm.proxy._types import KeyManagementSystem
|
||||
|
||||
|
||||
def validate_environment():
|
||||
|
@ -25,7 +26,7 @@ def validate_environment():
|
|||
def load_google_kms(use_google_kms: Optional[bool]):
|
||||
if use_google_kms is None or use_google_kms == False:
|
||||
return
|
||||
|
||||
try:
|
||||
from google.cloud import kms_v1 # type: ignore
|
||||
|
||||
validate_environment()
|
||||
|
@ -33,4 +34,7 @@ def load_google_kms(use_google_kms: Optional[bool]):
|
|||
# Create the KMS client
|
||||
client = kms_v1.KeyManagementServiceClient()
|
||||
litellm.secret_manager_client = client
|
||||
litellm._key_management_system = KeyManagementSystem.GOOGLE_KMS
|
||||
litellm._google_kms_resource_name = os.getenv("GOOGLE_KMS_RESOURCE_NAME")
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
#
|
||||
# Thank you users! We ❤️ you! - Krrish & Ishaan
|
||||
|
||||
import sys, re
|
||||
import sys, re, binascii
|
||||
import litellm
|
||||
import dotenv, json, traceback, threading, base64
|
||||
import subprocess, os
|
||||
|
@ -43,6 +43,7 @@ from .integrations.custom_logger import CustomLogger
|
|||
from .integrations.langfuse import LangFuseLogger
|
||||
from .integrations.dynamodb import DyanmoDBLogger
|
||||
from .integrations.litedebugger import LiteDebugger
|
||||
from .proxy._types import KeyManagementSystem
|
||||
from openai import OpenAIError as OriginalError
|
||||
from openai._models import BaseModel as OpenAIObject
|
||||
from .exceptions import (
|
||||
|
@ -59,7 +60,7 @@ from .exceptions import (
|
|||
BudgetExceededError,
|
||||
UnprocessableEntityError,
|
||||
)
|
||||
from typing import cast, List, Dict, Union, Optional, Literal
|
||||
from typing import cast, List, Dict, Union, Optional, Literal, Any
|
||||
from .caching import Cache
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
|
@ -6331,24 +6332,45 @@ def litellm_telemetry(data):
|
|||
######### Secret Manager ############################
|
||||
# checks if user has passed in a secret manager client
|
||||
# if passed in then checks the secret there
|
||||
def get_secret(secret_name: str, default_value: Optional[str] = None):
|
||||
def _is_base64(s):
|
||||
try:
|
||||
return base64.b64encode(base64.b64decode(s)).decode() == s
|
||||
except binascii.Error:
|
||||
return False
|
||||
|
||||
|
||||
def get_secret(
|
||||
secret_name: str,
|
||||
default_value: Optional[str] = None,
|
||||
):
|
||||
key_management_system = litellm._key_management_system
|
||||
if secret_name.startswith("os.environ/"):
|
||||
secret_name = secret_name.replace("os.environ/", "")
|
||||
try:
|
||||
if litellm.secret_manager_client is not None:
|
||||
try:
|
||||
client = litellm.secret_manager_client
|
||||
key_manager = "local"
|
||||
if key_management_system is not None:
|
||||
key_manager = key_management_system.value
|
||||
if (
|
||||
type(client).__module__ + "." + type(client).__name__
|
||||
key_manager == KeyManagementSystem.AZURE_KEY_VAULT
|
||||
or type(client).__module__ + "." + type(client).__name__
|
||||
== "azure.keyvault.secrets._client.SecretClient"
|
||||
): # support Azure Secret Client - from azure.keyvault.secrets import SecretClient
|
||||
secret = retrieved_secret = client.get_secret(secret_name).value
|
||||
elif client.__class__.__name__ == "KeyManagementServiceClient":
|
||||
encrypted_secret = os.getenv(secret_name)
|
||||
secret = client.get_secret(secret_name).value
|
||||
elif (
|
||||
key_manager == KeyManagementSystem.GOOGLE_KMS
|
||||
or client.__class__.__name__ == "KeyManagementServiceClient"
|
||||
):
|
||||
encrypted_secret: Any = os.getenv(secret_name)
|
||||
if encrypted_secret is None:
|
||||
raise ValueError(
|
||||
f"Google KMS requires the encrypted secret to be in the environment!"
|
||||
)
|
||||
b64_flag = _is_base64(encrypted_secret)
|
||||
if b64_flag == True: # if passed in as encoded b64 string
|
||||
encrypted_secret = base64.b64decode(encrypted_secret)
|
||||
if not isinstance(encrypted_secret, bytes):
|
||||
# If it's not, assume it's a string and encode it to bytes
|
||||
ciphertext = eval(
|
||||
|
|
6
mypy.ini
Normal file
6
mypy.ini
Normal file
|
@ -0,0 +1,6 @@
|
|||
[mypy]
|
||||
warn_return_any = False
|
||||
ignore_missing_imports = False
|
||||
|
||||
[mypy-google.*]
|
||||
ignore_missing_imports = True
|
Loading…
Add table
Add a link
Reference in a new issue