mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
feat(aws_secret_manager.py): allows user to keep a hash of the proxy master key in their env
This commit is contained in:
parent
707a41fd6c
commit
a2da2a8f16
5 changed files with 62 additions and 8 deletions
|
@ -56,8 +56,10 @@ router_settings:
|
||||||
|
|
||||||
litellm_settings:
|
litellm_settings:
|
||||||
success_callback: ["langfuse"]
|
success_callback: ["langfuse"]
|
||||||
json_logs: true
|
|
||||||
|
|
||||||
general_settings:
|
general_settings:
|
||||||
alerting: ["email"]
|
alerting: ["email"]
|
||||||
|
key_management_system: "aws_kms"
|
||||||
|
key_management_settings:
|
||||||
|
hosted_keys: ["LITELLM_MASTER_KEY"]
|
||||||
|
|
||||||
|
|
|
@ -946,6 +946,7 @@ class KeyManagementSystem(enum.Enum):
|
||||||
AZURE_KEY_VAULT = "azure_key_vault"
|
AZURE_KEY_VAULT = "azure_key_vault"
|
||||||
AWS_SECRET_MANAGER = "aws_secret_manager"
|
AWS_SECRET_MANAGER = "aws_secret_manager"
|
||||||
LOCAL = "local"
|
LOCAL = "local"
|
||||||
|
AWS_KMS = "aws_kms"
|
||||||
|
|
||||||
|
|
||||||
class KeyManagementSettings(LiteLLMBase):
|
class KeyManagementSettings(LiteLLMBase):
|
||||||
|
|
|
@ -112,7 +112,10 @@ from litellm import (
|
||||||
CreateFileRequest,
|
CreateFileRequest,
|
||||||
)
|
)
|
||||||
from litellm.proxy.secret_managers.google_kms import load_google_kms
|
from litellm.proxy.secret_managers.google_kms import load_google_kms
|
||||||
from litellm.proxy.secret_managers.aws_secret_manager import load_aws_secret_manager
|
from litellm.proxy.secret_managers.aws_secret_manager import (
|
||||||
|
load_aws_secret_manager,
|
||||||
|
load_aws_kms,
|
||||||
|
)
|
||||||
import pydantic
|
import pydantic
|
||||||
from litellm.proxy._types import *
|
from litellm.proxy._types import *
|
||||||
from litellm.caching import DualCache, RedisCache
|
from litellm.caching import DualCache, RedisCache
|
||||||
|
@ -2736,10 +2739,12 @@ class ProxyConfig:
|
||||||
load_google_kms(use_google_kms=True)
|
load_google_kms(use_google_kms=True)
|
||||||
elif (
|
elif (
|
||||||
key_management_system
|
key_management_system
|
||||||
== KeyManagementSystem.AWS_SECRET_MANAGER.value
|
== KeyManagementSystem.AWS_SECRET_MANAGER.value # noqa: F405
|
||||||
):
|
):
|
||||||
### LOAD FROM AWS SECRET MANAGER ###
|
### LOAD FROM AWS SECRET MANAGER ###
|
||||||
load_aws_secret_manager(use_aws_secret_manager=True)
|
load_aws_secret_manager(use_aws_secret_manager=True)
|
||||||
|
elif key_management_system == KeyManagementSystem.AWS_KMS.value:
|
||||||
|
load_aws_kms(use_aws_kms=True)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Invalid Key Management System selected")
|
raise ValueError("Invalid Key Management System selected")
|
||||||
key_management_settings = general_settings.get(
|
key_management_settings = general_settings.get(
|
||||||
|
@ -2773,6 +2778,7 @@ class ProxyConfig:
|
||||||
master_key = general_settings.get(
|
master_key = general_settings.get(
|
||||||
"master_key", litellm.get_secret("LITELLM_MASTER_KEY", None)
|
"master_key", litellm.get_secret("LITELLM_MASTER_KEY", None)
|
||||||
)
|
)
|
||||||
|
|
||||||
if master_key and master_key.startswith("os.environ/"):
|
if master_key and master_key.startswith("os.environ/"):
|
||||||
master_key = litellm.get_secret(master_key)
|
master_key = litellm.get_secret(master_key)
|
||||||
if not isinstance(master_key, str):
|
if not isinstance(master_key, str):
|
||||||
|
@ -4098,6 +4104,7 @@ async def chat_completion(
|
||||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||||
):
|
):
|
||||||
global general_settings, user_debug, proxy_logging_obj, llm_model_list
|
global general_settings, user_debug, proxy_logging_obj, llm_model_list
|
||||||
|
|
||||||
data = {}
|
data = {}
|
||||||
try:
|
try:
|
||||||
body = await request.body()
|
body = await request.body()
|
||||||
|
|
|
@ -8,6 +8,7 @@ Requires:
|
||||||
* `pip install boto3>=1.28.57`
|
* `pip install boto3>=1.28.57`
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import boto3.session
|
||||||
import litellm, os
|
import litellm, os
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from litellm.proxy._types import KeyManagementSystem
|
from litellm.proxy._types import KeyManagementSystem
|
||||||
|
@ -38,3 +39,21 @@ def load_aws_secret_manager(use_aws_secret_manager: Optional[bool]):
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
|
||||||
|
def load_aws_kms(use_aws_kms: Optional[bool]):
|
||||||
|
if use_aws_kms is None or use_aws_kms is False:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
import boto3
|
||||||
|
|
||||||
|
validate_environment()
|
||||||
|
|
||||||
|
# Create a Secrets Manager client
|
||||||
|
kms_client = boto3.client("kms", region_name=os.getenv("AWS_REGION_NAME"))
|
||||||
|
|
||||||
|
litellm.secret_manager_client = kms_client
|
||||||
|
litellm._key_management_system = KeyManagementSystem.AWS_KMS
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise e
|
||||||
|
|
|
@ -7351,10 +7351,10 @@ def get_provider_fields(custom_llm_provider: str) -> List[ProviderField]:
|
||||||
|
|
||||||
if custom_llm_provider == "databricks":
|
if custom_llm_provider == "databricks":
|
||||||
return litellm.DatabricksConfig().get_required_params()
|
return litellm.DatabricksConfig().get_required_params()
|
||||||
|
|
||||||
elif custom_llm_provider == "ollama":
|
elif custom_llm_provider == "ollama":
|
||||||
return litellm.OllamaConfig().get_required_params()
|
return litellm.OllamaConfig().get_required_params()
|
||||||
|
|
||||||
else:
|
else:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
@ -10052,6 +10052,8 @@ def get_secret(
|
||||||
):
|
):
|
||||||
key_management_system = litellm._key_management_system
|
key_management_system = litellm._key_management_system
|
||||||
key_management_settings = litellm._key_management_settings
|
key_management_settings = litellm._key_management_settings
|
||||||
|
args = locals()
|
||||||
|
|
||||||
if secret_name.startswith("os.environ/"):
|
if secret_name.startswith("os.environ/"):
|
||||||
secret_name = secret_name.replace("os.environ/", "")
|
secret_name = secret_name.replace("os.environ/", "")
|
||||||
|
|
||||||
|
@ -10139,13 +10141,13 @@ def get_secret(
|
||||||
key_manager = "local"
|
key_manager = "local"
|
||||||
|
|
||||||
if (
|
if (
|
||||||
key_manager == KeyManagementSystem.AZURE_KEY_VAULT
|
key_manager == KeyManagementSystem.AZURE_KEY_VAULT.value
|
||||||
or type(client).__module__ + "." + type(client).__name__
|
or type(client).__module__ + "." + type(client).__name__
|
||||||
== "azure.keyvault.secrets._client.SecretClient"
|
== "azure.keyvault.secrets._client.SecretClient"
|
||||||
): # support Azure Secret Client - from azure.keyvault.secrets import SecretClient
|
): # support Azure Secret Client - from azure.keyvault.secrets import SecretClient
|
||||||
secret = client.get_secret(secret_name).value
|
secret = client.get_secret(secret_name).value
|
||||||
elif (
|
elif (
|
||||||
key_manager == KeyManagementSystem.GOOGLE_KMS
|
key_manager == KeyManagementSystem.GOOGLE_KMS.value
|
||||||
or client.__class__.__name__ == "KeyManagementServiceClient"
|
or client.__class__.__name__ == "KeyManagementServiceClient"
|
||||||
):
|
):
|
||||||
encrypted_secret: Any = os.getenv(secret_name)
|
encrypted_secret: Any = os.getenv(secret_name)
|
||||||
|
@ -10173,6 +10175,25 @@ def get_secret(
|
||||||
secret = response.plaintext.decode(
|
secret = response.plaintext.decode(
|
||||||
"utf-8"
|
"utf-8"
|
||||||
) # assumes the original value was encoded with utf-8
|
) # assumes the original value was encoded with utf-8
|
||||||
|
elif key_manager == KeyManagementSystem.AWS_KMS.value:
|
||||||
|
"""
|
||||||
|
Only check the tokens which start with 'aws_kms/'. This prevents latency impact caused by checking all keys.
|
||||||
|
"""
|
||||||
|
encrypted_value = os.getenv(secret_name, None)
|
||||||
|
if encrypted_value is None:
|
||||||
|
raise Exception("encrypted value for AWS KMS cannot be None.")
|
||||||
|
# Decode the base64 encoded ciphertext
|
||||||
|
ciphertext_blob = base64.b64decode(encrypted_value)
|
||||||
|
|
||||||
|
# Set up the parameters for the decrypt call
|
||||||
|
params = {"CiphertextBlob": ciphertext_blob}
|
||||||
|
|
||||||
|
# Perform the decryption
|
||||||
|
response = client.decrypt(**params)
|
||||||
|
|
||||||
|
# Extract and decode the plaintext
|
||||||
|
plaintext = response["Plaintext"]
|
||||||
|
secret = plaintext.decode("utf-8")
|
||||||
elif key_manager == KeyManagementSystem.AWS_SECRET_MANAGER.value:
|
elif key_manager == KeyManagementSystem.AWS_SECRET_MANAGER.value:
|
||||||
try:
|
try:
|
||||||
get_secret_value_response = client.get_secret_value(
|
get_secret_value_response = client.get_secret_value(
|
||||||
|
@ -10193,10 +10214,14 @@ def get_secret(
|
||||||
for k, v in secret_dict.items():
|
for k, v in secret_dict.items():
|
||||||
secret = v
|
secret = v
|
||||||
print_verbose(f"secret: {secret}")
|
print_verbose(f"secret: {secret}")
|
||||||
|
elif key_manager == "local":
|
||||||
|
secret = os.getenv(secret_name)
|
||||||
else: # assume the default is infisicial client
|
else: # assume the default is infisicial client
|
||||||
secret = client.get_secret(secret_name).secret_value
|
secret = client.get_secret(secret_name).secret_value
|
||||||
except Exception as e: # check if it's in os.environ
|
except Exception as e: # check if it's in os.environ
|
||||||
print_verbose(f"An exception occurred - {str(e)}")
|
verbose_logger.error(
|
||||||
|
f"An exception occurred - {str(e)}\n\n{traceback.format_exc()}"
|
||||||
|
)
|
||||||
secret = os.getenv(secret_name)
|
secret = os.getenv(secret_name)
|
||||||
try:
|
try:
|
||||||
secret_value_as_bool = ast.literal_eval(secret)
|
secret_value_as_bool = ast.literal_eval(secret)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue