diff --git a/litellm/__init__.py b/litellm/__init__.py
index 7eae39097..b14b07f5a 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -3,7 +3,7 @@ import threading, requests, os
 from typing import Callable, List, Optional, Dict, Union, Any
 from litellm.caching import Cache
 from litellm._logging import set_verbose, _turn_on_debug, verbose_logger
-from litellm.proxy._types import KeyManagementSystem
+from litellm.proxy._types import KeyManagementSystem, KeyManagementSettings
 import httpx
 import dotenv
 
@@ -187,6 +187,7 @@ secret_manager_client: Optional[Any] = (
 )
 _google_kms_resource_name: Optional[str] = None
 _key_management_system: Optional[KeyManagementSystem] = None
+_key_management_settings: Optional[KeyManagementSettings] = None
 #### PII MASKING ####
 output_parse_pii: bool = False
 #############################################
diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py
index a8c0c3d27..c8ec25704 100644
--- a/litellm/proxy/_types.py
+++ b/litellm/proxy/_types.py
@@ -387,9 +387,16 @@ class BudgetRequest(LiteLLMBase):
 class KeyManagementSystem(enum.Enum):
     GOOGLE_KMS = "google_kms"
     AZURE_KEY_VAULT = "azure_key_vault"
+    AWS_SECRET_MANAGER = "aws_secret_manager"
     LOCAL = "local"
 
 
+class KeyManagementSettings(LiteLLMBase):
+    # secret names to look up in the hosted key manager; get_secret() treats
+    # every other name as "local"
+    hosted_keys: List
+
+
 class TeamDefaultSettings(LiteLLMBase):
     team_id: str
 
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index a85cc69d7..a4da4e4a8 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -98,6 +98,7 @@ from litellm.proxy.utils import (
     _get_projected_spend_over_limit,
 )
 from litellm.proxy.secret_managers.google_kms import load_google_kms
+from litellm.proxy.secret_managers.aws_secret_manager import load_aws_secret_manager
 import pydantic
 from litellm.proxy._types import *
 from litellm.caching import DualCache
@@ -1089,6 +1090,8 @@ async def update_database(
             existing_token_obj = await user_api_key_cache.async_get_cache(
                 key=hashed_token
             )
+            if existing_token_obj is None:
+                return
             existing_user_obj = await user_api_key_cache.async_get_cache(key=user_id)
             if existing_user_obj is not None and isinstance(existing_user_obj, dict):
                 existing_user_obj = LiteLLM_UserTable(**existing_user_obj)
@@ -1417,7 +1420,8 @@ async def update_cache(
         else:
             hashed_token = token
         existing_token_obj = await user_api_key_cache.async_get_cache(key=hashed_token)
-        existing_user_obj = await user_api_key_cache.async_get_cache(key=user_id)
+        if existing_token_obj is None:
+            return
         if existing_token_obj.user_id != user_id:  # an end-user id was passed in
             end_user_id = user_id
         user_ids = [existing_token_obj.user_id, litellm_proxy_budget_name, end_user_id]
@@ -1905,8 +1909,21 @@ class ProxyConfig:
                 elif key_management_system == KeyManagementSystem.GOOGLE_KMS.value:
                     ### LOAD FROM GOOGLE KMS ###
                     load_google_kms(use_google_kms=True)
+                elif (
+                    key_management_system
+                    == KeyManagementSystem.AWS_SECRET_MANAGER.value
+                ):
+                    ### LOAD FROM AWS SECRET MANAGER ###
+                    load_aws_secret_manager(use_aws_secret_manager=True)
                 else:
                     raise ValueError("Invalid Key Management System selected")
+            key_management_settings = general_settings.get(
+                "key_management_settings", None
+            )
+            if key_management_settings is not None:
+                litellm._key_management_settings = KeyManagementSettings(
+                    **key_management_settings
+                )
             ### [DEPRECATED] LOAD FROM GOOGLE KMS ### old way of loading from google kms
             use_google_kms = general_settings.get("use_google_kms", False)
             load_google_kms(use_google_kms=use_google_kms)
diff --git a/litellm/proxy/secret_managers/aws_secret_manager.py b/litellm/proxy/secret_managers/aws_secret_manager.py
new file mode 100644
index 000000000..a40b1dffa
--- /dev/null
+++ b/litellm/proxy/secret_managers/aws_secret_manager.py
@@ -0,0 +1,47 @@
+"""
+This is a file for the AWS Secret Manager Integration
+
+Relevant issue: https://github.com/BerriAI/litellm/issues/1883
+
+Requires:
+* `os.environ["AWS_REGION_NAME"]`
+* `pip install boto3>=1.28.57`
+"""
+
+import litellm, os
+from typing import Optional
+from litellm.proxy._types import KeyManagementSystem
+
+
+def validate_environment():
+    """Raise ValueError when AWS_REGION_NAME is not set in the environment."""
+    if "AWS_REGION_NAME" not in os.environ:
+        raise ValueError("Missing required environment variable - AWS_REGION_NAME")
+
+
+def load_aws_secret_manager(use_aws_secret_manager: Optional[bool]):
+    """
+    Create a boto3 Secrets Manager client and register it on `litellm`.
+
+    Does nothing when `use_aws_secret_manager` is None or False. Otherwise sets
+    `litellm.secret_manager_client` and `litellm._key_management_system`.
+
+    Raises ValueError if AWS_REGION_NAME is not set, ImportError if boto3 is
+    not installed.
+    """
+    if not use_aws_secret_manager:
+        return
+
+    # imported here so boto3 is only required when this integration is enabled
+    import boto3
+
+    validate_environment()
+
+    # Create a Secrets Manager client
+    session = boto3.session.Session()
+    client = session.client(
+        service_name="secretsmanager", region_name=os.getenv("AWS_REGION_NAME")
+    )
+
+    litellm.secret_manager_client = client
+    litellm._key_management_system = KeyManagementSystem.AWS_SECRET_MANAGER
diff --git a/litellm/tests/test_secret_manager.py b/litellm/tests/test_secret_manager.py
new file mode 100644
index 000000000..7a411f185
--- /dev/null
+++ b/litellm/tests/test_secret_manager.py
@@ -0,0 +1,24 @@
+import sys, os, uuid
+import time
+import traceback
+from dotenv import load_dotenv
+
+load_dotenv()
+import os
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import pytest
+from litellm import get_secret
+from litellm.proxy.secret_managers.aws_secret_manager import load_aws_secret_manager
+
+
+def test_aws_secret_manager():
+    load_aws_secret_manager(use_aws_secret_manager=True)
+
+    secret_val = get_secret("litellm_master_key")
+
+    print(f"secret_val: {secret_val}")
+
+    assert secret_val == "sk-1234"
diff --git a/litellm/utils.py b/litellm/utils.py
index 95b18421f..d36ba4e1a 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -8288,8 +8288,10 @@ def get_secret(
     default_value: Optional[Union[str, bool]] = None,
 ):
     key_management_system = litellm._key_management_system
+    key_management_settings = litellm._key_management_settings
     if secret_name.startswith("os.environ/"):
         secret_name = secret_name.replace("os.environ/", "")
+
     try:
         if litellm.secret_manager_client is not None:
             try:
@@ -8297,6 +8299,13 @@ def get_secret(
                 key_manager = "local"
                 if key_management_system is not None:
                     key_manager = key_management_system.value
+
+                if key_management_settings is not None:
+                    if (
+                        secret_name not in key_management_settings.hosted_keys
+                    ):  # allow user to specify which keys to check in hosted key manager
+                        key_manager = "local"
+
                 if (
                     key_manager == KeyManagementSystem.AZURE_KEY_VAULT
                     or type(client).__module__ + "." + type(client).__name__
@@ -8332,9 +8341,31 @@ def get_secret(
                     secret = response.plaintext.decode(
                         "utf-8"
                     )  # assumes the original value was encoded with utf-8
+                elif key_manager == KeyManagementSystem.AWS_SECRET_MANAGER.value:
+                    # For a list of exceptions thrown, see
+                    # https://docs.aws.amazon.com/secretsmanager/latest/apireference/API_GetSecretValue.html
+                    # (errors propagate to the handler below, which falls back to os.environ)
+                    get_secret_value_response = client.get_secret_value(
+                        SecretId=secret_name
+                    )
+                    # never log the response or anything parsed from it - it holds the secret
+                    secret_string = get_secret_value_response["SecretString"]
+                    try:
+                        secret_dict = json.loads(secret_string)
+                    except ValueError:  # plaintext (non-JSON) secret
+                        secret_dict = None
+                    if isinstance(secret_dict, dict):
+                        if secret_name in secret_dict:
+                            secret = secret_dict[secret_name]
+                        else:
+                            # assume there is 1 secret per secret_name
+                            secret = list(secret_dict.values())[-1]
+                    else:
+                        secret = secret_string
                 else:  # assume the default is infisicial client
                     secret = client.get_secret(secret_name).secret_value
             except Exception as e:  # check if it's in os.environ
+                print_verbose(f"An exception occurred - {str(e)}")
                 secret = os.getenv(secret_name)
             try:
                 secret_value_as_bool = ast.literal_eval(secret)