From d8956e92558c61ed1a7d4e0bfa9df7a74699e8fe Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 16 Mar 2024 14:37:46 -0700 Subject: [PATCH 1/2] fix(utils.py): initial commit for aws secret manager support --- litellm/proxy/_types.py | 1 + .../secret_managers/aws_secret_manager.py | 40 +++++++++++++++++++ litellm/utils.py | 13 ++++++ 3 files changed, 54 insertions(+) create mode 100644 litellm/proxy/secret_managers/aws_secret_manager.py diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 8a7efa1a1..2b564e079 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -387,6 +387,7 @@ class BudgetRequest(LiteLLMBase): class KeyManagementSystem(enum.Enum): GOOGLE_KMS = "google_kms" AZURE_KEY_VAULT = "azure_key_vault" + AWS_SECRET_MANAGER = "aws_secret_manager" LOCAL = "local" diff --git a/litellm/proxy/secret_managers/aws_secret_manager.py b/litellm/proxy/secret_managers/aws_secret_manager.py new file mode 100644 index 000000000..a40b1dffa --- /dev/null +++ b/litellm/proxy/secret_managers/aws_secret_manager.py @@ -0,0 +1,40 @@ +""" +This is a file for the AWS Secret Manager Integration + +Relevant issue: https://github.com/BerriAI/litellm/issues/1883 + +Requires: +* `os.environ["AWS_REGION_NAME"], +* `pip install boto3>=1.28.57` +""" + +import litellm, os +from typing import Optional +from litellm.proxy._types import KeyManagementSystem + + +def validate_environment(): + if "AWS_REGION_NAME" not in os.environ: + raise ValueError("Missing required environment variable - AWS_REGION_NAME") + + +def load_aws_secret_manager(use_aws_secret_manager: Optional[bool]): + if use_aws_secret_manager is None or use_aws_secret_manager == False: + return + try: + import boto3 + from botocore.exceptions import ClientError + + validate_environment() + + # Create a Secrets Manager client + session = boto3.session.Session() + client = session.client( + service_name="secretsmanager", region_name=os.getenv("AWS_REGION_NAME") + ) + + litellm.secret_manager_client = client + litellm._key_management_system = KeyManagementSystem.AWS_SECRET_MANAGER + + except Exception as e: + raise e diff --git a/litellm/utils.py b/litellm/utils.py index 95b18421f..45b748661 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -8332,6 +8332,19 @@ def get_secret( secret = response.plaintext.decode( "utf-8" ) # assumes the original value was encoded with utf-8 + elif key_manager == KeyManagementSystem.AWS_SECRET_MANAGER.value: + try: + get_secret_value_response = client.get_secret_value( + SecretId=secret_name + ) + except Exception as e: + # For a list of exceptions thrown, see + # https://docs.aws.amazon.com/secretsmanager/latest/apireference/API_GetSecretValue.html + raise e + + # assume there is 1 secretstring per secret_name + for k, v in get_secret_value_response.items(): + secret = v else: # assume the default is infisicial client secret = client.get_secret(secret_name).secret_value except Exception as e: # check if it's in os.environ From bc66ef9d5c73b8241b67bf5b6d70b107853dc696 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 16 Mar 2024 16:47:50 -0700 Subject: [PATCH 2/2] fix(utils.py): fix aws secret manager + support key_management_settings fixes the aws secret manager implementation and allows the user to set which keys they want to check thr ough it --- litellm/__init__.py | 3 ++- litellm/proxy/_types.py | 4 ++++ litellm/proxy/proxy_server.py | 19 ++++++++++++++++++- litellm/tests/test_secret_manager.py | 24 ++++++++++++++++++++++++ litellm/utils.py | 21 +++++++++++++++++++-- 5 files changed, 67 insertions(+), 4 deletions(-) create mode 100644 litellm/tests/test_secret_manager.py diff --git a/litellm/__init__.py b/litellm/__init__.py index 7eae39097..b14b07f5a 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -3,7 +3,7 @@ import threading, requests, os from typing import Callable, List, Optional, Dict, Union, Any from litellm.caching import Cache from litellm._logging import set_verbose, _turn_on_debug, verbose_logger -from litellm.proxy._types import KeyManagementSystem +from litellm.proxy._types import KeyManagementSystem, KeyManagementSettings import httpx import dotenv @@ -187,6 +187,7 @@ secret_manager_client: Optional[Any] = ( ) _google_kms_resource_name: Optional[str] = None _key_management_system: Optional[KeyManagementSystem] = None +_key_management_settings: Optional[KeyManagementSettings] = None #### PII MASKING #### output_parse_pii: bool = False ############################################# diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 2b564e079..9afe36b4a 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -391,6 +391,10 @@ class KeyManagementSystem(enum.Enum): LOCAL = "local" +class KeyManagementSettings(LiteLLMBase): + hosted_keys: List + + class TeamDefaultSettings(LiteLLMBase): team_id: str diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index a6510c7f2..044c23a2c 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -98,6 +98,7 @@ from litellm.proxy.utils import ( _get_projected_spend_over_limit, ) from litellm.proxy.secret_managers.google_kms import load_google_kms +from litellm.proxy.secret_managers.aws_secret_manager import load_aws_secret_manager import pydantic from litellm.proxy._types import * from litellm.caching import DualCache @@ -1089,6 +1090,8 @@ async def update_database( existing_token_obj = await user_api_key_cache.async_get_cache( key=hashed_token ) + if existing_token_obj is None: + return existing_user_obj = await user_api_key_cache.async_get_cache(key=user_id) if existing_user_obj is not None and isinstance(existing_user_obj, dict): existing_user_obj = LiteLLM_UserTable(**existing_user_obj) @@ -1417,7 +1420,8 @@ async def update_cache( else: hashed_token = token existing_token_obj = await user_api_key_cache.async_get_cache(key=hashed_token) - existing_user_obj = await user_api_key_cache.async_get_cache(key=user_id) + if existing_token_obj is None: + return if existing_token_obj.user_id != user_id: # an end-user id was passed in end_user_id = user_id user_ids = [existing_token_obj.user_id, litellm_proxy_budget_name, end_user_id] @@ -1903,8 +1907,21 @@ class ProxyConfig: elif key_management_system == KeyManagementSystem.GOOGLE_KMS.value: ### LOAD FROM GOOGLE KMS ### load_google_kms(use_google_kms=True) + elif ( + key_management_system + == KeyManagementSystem.AWS_SECRET_MANAGER.value + ): + ### LOAD FROM AWS SECRET MANAGER ### + load_aws_secret_manager(use_aws_secret_manager=True) else: raise ValueError("Invalid Key Management System selected") + key_management_settings = general_settings.get( + "key_management_settings", None + ) + if key_management_settings is not None: + litellm._key_management_settings = KeyManagementSettings( + **key_management_settings + ) ### [DEPRECATED] LOAD FROM GOOGLE KMS ### old way of loading from google kms use_google_kms = general_settings.get("use_google_kms", False) load_google_kms(use_google_kms=use_google_kms) diff --git a/litellm/tests/test_secret_manager.py b/litellm/tests/test_secret_manager.py new file mode 100644 index 000000000..7a411f185 --- /dev/null +++ b/litellm/tests/test_secret_manager.py @@ -0,0 +1,24 @@ +import sys, os, uuid +import time +import traceback +from dotenv import load_dotenv + +load_dotenv() +import os + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path +import pytest +from litellm import get_secret +from litellm.proxy.secret_managers.aws_secret_manager import load_aws_secret_manager + + +def test_aws_secret_manager(): + load_aws_secret_manager(use_aws_secret_manager=True) + + secret_val = get_secret("litellm_master_key") + + print(f"secret_val: {secret_val}") + + assert secret_val == "sk-1234" diff --git a/litellm/utils.py b/litellm/utils.py index 45b748661..d36ba4e1a 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -8288,8 +8288,10 @@ def get_secret( default_value: Optional[Union[str, bool]] = None, ): key_management_system = litellm._key_management_system + key_management_settings = litellm._key_management_settings if secret_name.startswith("os.environ/"): secret_name = secret_name.replace("os.environ/", "") + try: if litellm.secret_manager_client is not None: try: @@ -8297,6 +8299,13 @@ def get_secret( key_manager = "local" if key_management_system is not None: key_manager = key_management_system.value + + if key_management_settings is not None: + if ( + secret_name not in key_management_settings.hosted_keys + ): # allow user to specify which keys to check in hosted key manager + key_manager = "local" + if ( key_manager == KeyManagementSystem.AZURE_KEY_VAULT or type(client).__module__ + "." + type(client).__name__ @@ -8337,17 +8346,25 @@ def get_secret( get_secret_value_response = client.get_secret_value( SecretId=secret_name ) + print_verbose( + f"get_secret_value_response: {get_secret_value_response}" + ) except Exception as e: + print_verbose(f"An error occurred - {str(e)}") # For a list of exceptions thrown, see # https://docs.aws.amazon.com/secretsmanager/latest/apireference/API_GetSecretValue.html raise e - # assume there is 1 secretstring per secret_name - for k, v in get_secret_value_response.items(): + # assume there is 1 secret per secret_name + secret_dict = json.loads(get_secret_value_response["SecretString"]) + print_verbose(f"secret_dict: {secret_dict}") + for k, v in secret_dict.items(): secret = v + print_verbose(f"secret: {secret}") else: # assume the default is infisicial client secret = client.get_secret(secret_name).secret_value except Exception as e: # check if it's in os.environ + print_verbose(f"An exception occurred - {str(e)}") secret = os.getenv(secret_name) try: secret_value_as_bool = ast.literal_eval(secret)