fix(utils.py): fix aws secret manager + support key_management_settings

fixes the aws secret manager implementation and allows the user to set which keys they want to check thr
ough it
This commit is contained in:
Krrish Dholakia 2024-03-16 16:47:50 -07:00
parent d8956e9255
commit bc66ef9d5c
5 changed files with 67 additions and 4 deletions

View file

@ -3,7 +3,7 @@ import threading, requests, os
from typing import Callable, List, Optional, Dict, Union, Any from typing import Callable, List, Optional, Dict, Union, Any
from litellm.caching import Cache from litellm.caching import Cache
from litellm._logging import set_verbose, _turn_on_debug, verbose_logger from litellm._logging import set_verbose, _turn_on_debug, verbose_logger
from litellm.proxy._types import KeyManagementSystem from litellm.proxy._types import KeyManagementSystem, KeyManagementSettings
import httpx import httpx
import dotenv import dotenv
@ -187,6 +187,7 @@ secret_manager_client: Optional[Any] = (
) )
_google_kms_resource_name: Optional[str] = None _google_kms_resource_name: Optional[str] = None
_key_management_system: Optional[KeyManagementSystem] = None _key_management_system: Optional[KeyManagementSystem] = None
_key_management_settings: Optional[KeyManagementSettings] = None
#### PII MASKING #### #### PII MASKING ####
output_parse_pii: bool = False output_parse_pii: bool = False
############################################# #############################################

View file

@ -391,6 +391,10 @@ class KeyManagementSystem(enum.Enum):
LOCAL = "local" LOCAL = "local"
class KeyManagementSettings(LiteLLMBase):
hosted_keys: List
class TeamDefaultSettings(LiteLLMBase): class TeamDefaultSettings(LiteLLMBase):
team_id: str team_id: str

View file

@ -98,6 +98,7 @@ from litellm.proxy.utils import (
_get_projected_spend_over_limit, _get_projected_spend_over_limit,
) )
from litellm.proxy.secret_managers.google_kms import load_google_kms from litellm.proxy.secret_managers.google_kms import load_google_kms
from litellm.proxy.secret_managers.aws_secret_manager import load_aws_secret_manager
import pydantic import pydantic
from litellm.proxy._types import * from litellm.proxy._types import *
from litellm.caching import DualCache from litellm.caching import DualCache
@ -1089,6 +1090,8 @@ async def update_database(
existing_token_obj = await user_api_key_cache.async_get_cache( existing_token_obj = await user_api_key_cache.async_get_cache(
key=hashed_token key=hashed_token
) )
if existing_token_obj is None:
return
existing_user_obj = await user_api_key_cache.async_get_cache(key=user_id) existing_user_obj = await user_api_key_cache.async_get_cache(key=user_id)
if existing_user_obj is not None and isinstance(existing_user_obj, dict): if existing_user_obj is not None and isinstance(existing_user_obj, dict):
existing_user_obj = LiteLLM_UserTable(**existing_user_obj) existing_user_obj = LiteLLM_UserTable(**existing_user_obj)
@ -1417,7 +1420,8 @@ async def update_cache(
else: else:
hashed_token = token hashed_token = token
existing_token_obj = await user_api_key_cache.async_get_cache(key=hashed_token) existing_token_obj = await user_api_key_cache.async_get_cache(key=hashed_token)
existing_user_obj = await user_api_key_cache.async_get_cache(key=user_id) if existing_token_obj is None:
return
if existing_token_obj.user_id != user_id: # an end-user id was passed in if existing_token_obj.user_id != user_id: # an end-user id was passed in
end_user_id = user_id end_user_id = user_id
user_ids = [existing_token_obj.user_id, litellm_proxy_budget_name, end_user_id] user_ids = [existing_token_obj.user_id, litellm_proxy_budget_name, end_user_id]
@ -1903,8 +1907,21 @@ class ProxyConfig:
elif key_management_system == KeyManagementSystem.GOOGLE_KMS.value: elif key_management_system == KeyManagementSystem.GOOGLE_KMS.value:
### LOAD FROM GOOGLE KMS ### ### LOAD FROM GOOGLE KMS ###
load_google_kms(use_google_kms=True) load_google_kms(use_google_kms=True)
elif (
key_management_system
== KeyManagementSystem.AWS_SECRET_MANAGER.value
):
### LOAD FROM AWS SECRET MANAGER ###
load_aws_secret_manager(use_aws_secret_manager=True)
else: else:
raise ValueError("Invalid Key Management System selected") raise ValueError("Invalid Key Management System selected")
key_management_settings = general_settings.get(
"key_management_settings", None
)
if key_management_settings is not None:
litellm._key_management_settings = KeyManagementSettings(
**key_management_settings
)
### [DEPRECATED] LOAD FROM GOOGLE KMS ### old way of loading from google kms ### [DEPRECATED] LOAD FROM GOOGLE KMS ### old way of loading from google kms
use_google_kms = general_settings.get("use_google_kms", False) use_google_kms = general_settings.get("use_google_kms", False)
load_google_kms(use_google_kms=use_google_kms) load_google_kms(use_google_kms=use_google_kms)

View file

@ -0,0 +1,24 @@
import sys, os, uuid
import time
import traceback
from dotenv import load_dotenv
load_dotenv()
import os
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import pytest
from litellm import get_secret
from litellm.proxy.secret_managers.aws_secret_manager import load_aws_secret_manager
def test_aws_secret_manager():
load_aws_secret_manager(use_aws_secret_manager=True)
secret_val = get_secret("litellm_master_key")
print(f"secret_val: {secret_val}")
assert secret_val == "sk-1234"

View file

@ -8288,8 +8288,10 @@ def get_secret(
default_value: Optional[Union[str, bool]] = None, default_value: Optional[Union[str, bool]] = None,
): ):
key_management_system = litellm._key_management_system key_management_system = litellm._key_management_system
key_management_settings = litellm._key_management_settings
if secret_name.startswith("os.environ/"): if secret_name.startswith("os.environ/"):
secret_name = secret_name.replace("os.environ/", "") secret_name = secret_name.replace("os.environ/", "")
try: try:
if litellm.secret_manager_client is not None: if litellm.secret_manager_client is not None:
try: try:
@ -8297,6 +8299,13 @@ def get_secret(
key_manager = "local" key_manager = "local"
if key_management_system is not None: if key_management_system is not None:
key_manager = key_management_system.value key_manager = key_management_system.value
if key_management_settings is not None:
if (
secret_name not in key_management_settings.hosted_keys
): # allow user to specify which keys to check in hosted key manager
key_manager = "local"
if ( if (
key_manager == KeyManagementSystem.AZURE_KEY_VAULT key_manager == KeyManagementSystem.AZURE_KEY_VAULT
or type(client).__module__ + "." + type(client).__name__ or type(client).__module__ + "." + type(client).__name__
@ -8337,17 +8346,25 @@ def get_secret(
get_secret_value_response = client.get_secret_value( get_secret_value_response = client.get_secret_value(
SecretId=secret_name SecretId=secret_name
) )
print_verbose(
f"get_secret_value_response: {get_secret_value_response}"
)
except Exception as e: except Exception as e:
print_verbose(f"An error occurred - {str(e)}")
# For a list of exceptions thrown, see # For a list of exceptions thrown, see
# https://docs.aws.amazon.com/secretsmanager/latest/apireference/API_GetSecretValue.html # https://docs.aws.amazon.com/secretsmanager/latest/apireference/API_GetSecretValue.html
raise e raise e
# assume there is 1 secretstring per secret_name # assume there is 1 secret per secret_name
for k, v in get_secret_value_response.items(): secret_dict = json.loads(get_secret_value_response["SecretString"])
print_verbose(f"secret_dict: {secret_dict}")
for k, v in secret_dict.items():
secret = v secret = v
print_verbose(f"secret: {secret}")
else: # assume the default is infisicial client else: # assume the default is infisicial client
secret = client.get_secret(secret_name).secret_value secret = client.get_secret(secret_name).secret_value
except Exception as e: # check if it's in os.environ except Exception as e: # check if it's in os.environ
print_verbose(f"An exception occurred - {str(e)}")
secret = os.getenv(secret_name) secret = os.getenv(secret_name)
try: try:
secret_value_as_bool = ast.literal_eval(secret) secret_value_as_bool = ast.literal_eval(secret)