fix(utils.py): fix aws secret manager + support key_management_settings

fixes the aws secret manager implementation and allows the user to set which keys they want to check thr
ough it
This commit is contained in:
Krrish Dholakia 2024-03-16 16:47:50 -07:00
parent d8956e9255
commit bc66ef9d5c
5 changed files with 67 additions and 4 deletions

View file

@ -3,7 +3,7 @@ import threading, requests, os
from typing import Callable, List, Optional, Dict, Union, Any
from litellm.caching import Cache
from litellm._logging import set_verbose, _turn_on_debug, verbose_logger
from litellm.proxy._types import KeyManagementSystem
from litellm.proxy._types import KeyManagementSystem, KeyManagementSettings
import httpx
import dotenv
@ -187,6 +187,7 @@ secret_manager_client: Optional[Any] = (
)
_google_kms_resource_name: Optional[str] = None
_key_management_system: Optional[KeyManagementSystem] = None
_key_management_settings: Optional[KeyManagementSettings] = None
#### PII MASKING ####
output_parse_pii: bool = False
#############################################

View file

@ -391,6 +391,10 @@ class KeyManagementSystem(enum.Enum):
LOCAL = "local"
class KeyManagementSettings(LiteLLMBase):
hosted_keys: List
class TeamDefaultSettings(LiteLLMBase):
team_id: str

View file

@ -98,6 +98,7 @@ from litellm.proxy.utils import (
_get_projected_spend_over_limit,
)
from litellm.proxy.secret_managers.google_kms import load_google_kms
from litellm.proxy.secret_managers.aws_secret_manager import load_aws_secret_manager
import pydantic
from litellm.proxy._types import *
from litellm.caching import DualCache
@ -1089,6 +1090,8 @@ async def update_database(
existing_token_obj = await user_api_key_cache.async_get_cache(
key=hashed_token
)
if existing_token_obj is None:
return
existing_user_obj = await user_api_key_cache.async_get_cache(key=user_id)
if existing_user_obj is not None and isinstance(existing_user_obj, dict):
existing_user_obj = LiteLLM_UserTable(**existing_user_obj)
@ -1417,7 +1420,8 @@ async def update_cache(
else:
hashed_token = token
existing_token_obj = await user_api_key_cache.async_get_cache(key=hashed_token)
existing_user_obj = await user_api_key_cache.async_get_cache(key=user_id)
if existing_token_obj is None:
return
if existing_token_obj.user_id != user_id: # an end-user id was passed in
end_user_id = user_id
user_ids = [existing_token_obj.user_id, litellm_proxy_budget_name, end_user_id]
@ -1903,8 +1907,21 @@ class ProxyConfig:
elif key_management_system == KeyManagementSystem.GOOGLE_KMS.value:
### LOAD FROM GOOGLE KMS ###
load_google_kms(use_google_kms=True)
elif (
key_management_system
== KeyManagementSystem.AWS_SECRET_MANAGER.value
):
### LOAD FROM AWS SECRET MANAGER ###
load_aws_secret_manager(use_aws_secret_manager=True)
else:
raise ValueError("Invalid Key Management System selected")
key_management_settings = general_settings.get(
"key_management_settings", None
)
if key_management_settings is not None:
litellm._key_management_settings = KeyManagementSettings(
**key_management_settings
)
### [DEPRECATED] LOAD FROM GOOGLE KMS ### old way of loading from google kms
use_google_kms = general_settings.get("use_google_kms", False)
load_google_kms(use_google_kms=use_google_kms)

View file

@ -0,0 +1,24 @@
import sys, os, uuid
import time
import traceback
from dotenv import load_dotenv
load_dotenv()
import os
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import pytest
from litellm import get_secret
from litellm.proxy.secret_managers.aws_secret_manager import load_aws_secret_manager
def test_aws_secret_manager():
load_aws_secret_manager(use_aws_secret_manager=True)
secret_val = get_secret("litellm_master_key")
print(f"secret_val: {secret_val}")
assert secret_val == "sk-1234"

View file

@ -8288,8 +8288,10 @@ def get_secret(
default_value: Optional[Union[str, bool]] = None,
):
key_management_system = litellm._key_management_system
key_management_settings = litellm._key_management_settings
if secret_name.startswith("os.environ/"):
secret_name = secret_name.replace("os.environ/", "")
try:
if litellm.secret_manager_client is not None:
try:
@ -8297,6 +8299,13 @@ def get_secret(
key_manager = "local"
if key_management_system is not None:
key_manager = key_management_system.value
if key_management_settings is not None:
if (
secret_name not in key_management_settings.hosted_keys
): # allow user to specify which keys to check in hosted key manager
key_manager = "local"
if (
key_manager == KeyManagementSystem.AZURE_KEY_VAULT
or type(client).__module__ + "." + type(client).__name__
@ -8337,17 +8346,25 @@ def get_secret(
get_secret_value_response = client.get_secret_value(
SecretId=secret_name
)
print_verbose(
f"get_secret_value_response: {get_secret_value_response}"
)
except Exception as e:
print_verbose(f"An error occurred - {str(e)}")
# For a list of exceptions thrown, see
# https://docs.aws.amazon.com/secretsmanager/latest/apireference/API_GetSecretValue.html
raise e
# assume there is 1 secretstring per secret_name
for k, v in get_secret_value_response.items():
# assume there is 1 secret per secret_name
secret_dict = json.loads(get_secret_value_response["SecretString"])
print_verbose(f"secret_dict: {secret_dict}")
for k, v in secret_dict.items():
secret = v
print_verbose(f"secret: {secret}")
else: # assume the default is infisicial client
secret = client.get_secret(secret_name).secret_value
except Exception as e: # check if it's in os.environ
print_verbose(f"An exception occurred - {str(e)}")
secret = os.getenv(secret_name)
try:
secret_value_as_bool = ast.literal_eval(secret)