forked from phoenix/litellm-mirror
Merge pull request #2556 from BerriAI/litellm_aws_secret_manager_support
fix(utils.py): initial commit for aws secret manager support
This commit is contained in:
commit
e55a8c3570
6 changed files with 119 additions and 2 deletions
|
@ -3,7 +3,7 @@ import threading, requests, os
|
||||||
from typing import Callable, List, Optional, Dict, Union, Any
|
from typing import Callable, List, Optional, Dict, Union, Any
|
||||||
from litellm.caching import Cache
|
from litellm.caching import Cache
|
||||||
from litellm._logging import set_verbose, _turn_on_debug, verbose_logger
|
from litellm._logging import set_verbose, _turn_on_debug, verbose_logger
|
||||||
from litellm.proxy._types import KeyManagementSystem
|
from litellm.proxy._types import KeyManagementSystem, KeyManagementSettings
|
||||||
import httpx
|
import httpx
|
||||||
import dotenv
|
import dotenv
|
||||||
|
|
||||||
|
@ -187,6 +187,7 @@ secret_manager_client: Optional[Any] = (
|
||||||
)
|
)
|
||||||
_google_kms_resource_name: Optional[str] = None
|
_google_kms_resource_name: Optional[str] = None
|
||||||
_key_management_system: Optional[KeyManagementSystem] = None
|
_key_management_system: Optional[KeyManagementSystem] = None
|
||||||
|
_key_management_settings: Optional[KeyManagementSettings] = None
|
||||||
#### PII MASKING ####
|
#### PII MASKING ####
|
||||||
output_parse_pii: bool = False
|
output_parse_pii: bool = False
|
||||||
#############################################
|
#############################################
|
||||||
|
|
|
@ -387,9 +387,14 @@ class BudgetRequest(LiteLLMBase):
|
||||||
class KeyManagementSystem(enum.Enum):
|
class KeyManagementSystem(enum.Enum):
|
||||||
GOOGLE_KMS = "google_kms"
|
GOOGLE_KMS = "google_kms"
|
||||||
AZURE_KEY_VAULT = "azure_key_vault"
|
AZURE_KEY_VAULT = "azure_key_vault"
|
||||||
|
AWS_SECRET_MANAGER = "aws_secret_manager"
|
||||||
LOCAL = "local"
|
LOCAL = "local"
|
||||||
|
|
||||||
|
|
||||||
|
class KeyManagementSettings(LiteLLMBase):
|
||||||
|
hosted_keys: List
|
||||||
|
|
||||||
|
|
||||||
class TeamDefaultSettings(LiteLLMBase):
|
class TeamDefaultSettings(LiteLLMBase):
|
||||||
team_id: str
|
team_id: str
|
||||||
|
|
||||||
|
|
|
@ -98,6 +98,7 @@ from litellm.proxy.utils import (
|
||||||
_get_projected_spend_over_limit,
|
_get_projected_spend_over_limit,
|
||||||
)
|
)
|
||||||
from litellm.proxy.secret_managers.google_kms import load_google_kms
|
from litellm.proxy.secret_managers.google_kms import load_google_kms
|
||||||
|
from litellm.proxy.secret_managers.aws_secret_manager import load_aws_secret_manager
|
||||||
import pydantic
|
import pydantic
|
||||||
from litellm.proxy._types import *
|
from litellm.proxy._types import *
|
||||||
from litellm.caching import DualCache
|
from litellm.caching import DualCache
|
||||||
|
@ -1089,6 +1090,8 @@ async def update_database(
|
||||||
existing_token_obj = await user_api_key_cache.async_get_cache(
|
existing_token_obj = await user_api_key_cache.async_get_cache(
|
||||||
key=hashed_token
|
key=hashed_token
|
||||||
)
|
)
|
||||||
|
if existing_token_obj is None:
|
||||||
|
return
|
||||||
existing_user_obj = await user_api_key_cache.async_get_cache(key=user_id)
|
existing_user_obj = await user_api_key_cache.async_get_cache(key=user_id)
|
||||||
if existing_user_obj is not None and isinstance(existing_user_obj, dict):
|
if existing_user_obj is not None and isinstance(existing_user_obj, dict):
|
||||||
existing_user_obj = LiteLLM_UserTable(**existing_user_obj)
|
existing_user_obj = LiteLLM_UserTable(**existing_user_obj)
|
||||||
|
@ -1417,7 +1420,8 @@ async def update_cache(
|
||||||
else:
|
else:
|
||||||
hashed_token = token
|
hashed_token = token
|
||||||
existing_token_obj = await user_api_key_cache.async_get_cache(key=hashed_token)
|
existing_token_obj = await user_api_key_cache.async_get_cache(key=hashed_token)
|
||||||
existing_user_obj = await user_api_key_cache.async_get_cache(key=user_id)
|
if existing_token_obj is None:
|
||||||
|
return
|
||||||
if existing_token_obj.user_id != user_id: # an end-user id was passed in
|
if existing_token_obj.user_id != user_id: # an end-user id was passed in
|
||||||
end_user_id = user_id
|
end_user_id = user_id
|
||||||
user_ids = [existing_token_obj.user_id, litellm_proxy_budget_name, end_user_id]
|
user_ids = [existing_token_obj.user_id, litellm_proxy_budget_name, end_user_id]
|
||||||
|
@ -1905,8 +1909,21 @@ class ProxyConfig:
|
||||||
elif key_management_system == KeyManagementSystem.GOOGLE_KMS.value:
|
elif key_management_system == KeyManagementSystem.GOOGLE_KMS.value:
|
||||||
### LOAD FROM GOOGLE KMS ###
|
### LOAD FROM GOOGLE KMS ###
|
||||||
load_google_kms(use_google_kms=True)
|
load_google_kms(use_google_kms=True)
|
||||||
|
elif (
|
||||||
|
key_management_system
|
||||||
|
== KeyManagementSystem.AWS_SECRET_MANAGER.value
|
||||||
|
):
|
||||||
|
### LOAD FROM AWS SECRET MANAGER ###
|
||||||
|
load_aws_secret_manager(use_aws_secret_manager=True)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Invalid Key Management System selected")
|
raise ValueError("Invalid Key Management System selected")
|
||||||
|
key_management_settings = general_settings.get(
|
||||||
|
"key_management_settings", None
|
||||||
|
)
|
||||||
|
if key_management_settings is not None:
|
||||||
|
litellm._key_management_settings = KeyManagementSettings(
|
||||||
|
**key_management_settings
|
||||||
|
)
|
||||||
### [DEPRECATED] LOAD FROM GOOGLE KMS ### old way of loading from google kms
|
### [DEPRECATED] LOAD FROM GOOGLE KMS ### old way of loading from google kms
|
||||||
use_google_kms = general_settings.get("use_google_kms", False)
|
use_google_kms = general_settings.get("use_google_kms", False)
|
||||||
load_google_kms(use_google_kms=use_google_kms)
|
load_google_kms(use_google_kms=use_google_kms)
|
||||||
|
|
40
litellm/proxy/secret_managers/aws_secret_manager.py
Normal file
40
litellm/proxy/secret_managers/aws_secret_manager.py
Normal file
|
@ -0,0 +1,40 @@
|
||||||
|
"""
|
||||||
|
This is a file for the AWS Secret Manager Integration
|
||||||
|
|
||||||
|
Relevant issue: https://github.com/BerriAI/litellm/issues/1883
|
||||||
|
|
||||||
|
Requires:
|
||||||
|
* `os.environ["AWS_REGION_NAME"],
|
||||||
|
* `pip install boto3>=1.28.57`
|
||||||
|
"""
|
||||||
|
|
||||||
|
import litellm, os
|
||||||
|
from typing import Optional
|
||||||
|
from litellm.proxy._types import KeyManagementSystem
|
||||||
|
|
||||||
|
|
||||||
|
def validate_environment():
|
||||||
|
if "AWS_REGION_NAME" not in os.environ:
|
||||||
|
raise ValueError("Missing required environment variable - AWS_REGION_NAME")
|
||||||
|
|
||||||
|
|
||||||
|
def load_aws_secret_manager(use_aws_secret_manager: Optional[bool]):
|
||||||
|
if use_aws_secret_manager is None or use_aws_secret_manager == False:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
import boto3
|
||||||
|
from botocore.exceptions import ClientError
|
||||||
|
|
||||||
|
validate_environment()
|
||||||
|
|
||||||
|
# Create a Secrets Manager client
|
||||||
|
session = boto3.session.Session()
|
||||||
|
client = session.client(
|
||||||
|
service_name="secretsmanager", region_name=os.getenv("AWS_REGION_NAME")
|
||||||
|
)
|
||||||
|
|
||||||
|
litellm.secret_manager_client = client
|
||||||
|
litellm._key_management_system = KeyManagementSystem.AWS_SECRET_MANAGER
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise e
|
24
litellm/tests/test_secret_manager.py
Normal file
24
litellm/tests/test_secret_manager.py
Normal file
|
@ -0,0 +1,24 @@
|
||||||
|
import sys, os, uuid
|
||||||
|
import time
|
||||||
|
import traceback
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
import os
|
||||||
|
|
||||||
|
sys.path.insert(
|
||||||
|
0, os.path.abspath("../..")
|
||||||
|
) # Adds the parent directory to the system path
|
||||||
|
import pytest
|
||||||
|
from litellm import get_secret
|
||||||
|
from litellm.proxy.secret_managers.aws_secret_manager import load_aws_secret_manager
|
||||||
|
|
||||||
|
|
||||||
|
def test_aws_secret_manager():
|
||||||
|
load_aws_secret_manager(use_aws_secret_manager=True)
|
||||||
|
|
||||||
|
secret_val = get_secret("litellm_master_key")
|
||||||
|
|
||||||
|
print(f"secret_val: {secret_val}")
|
||||||
|
|
||||||
|
assert secret_val == "sk-1234"
|
|
@ -8288,8 +8288,10 @@ def get_secret(
|
||||||
default_value: Optional[Union[str, bool]] = None,
|
default_value: Optional[Union[str, bool]] = None,
|
||||||
):
|
):
|
||||||
key_management_system = litellm._key_management_system
|
key_management_system = litellm._key_management_system
|
||||||
|
key_management_settings = litellm._key_management_settings
|
||||||
if secret_name.startswith("os.environ/"):
|
if secret_name.startswith("os.environ/"):
|
||||||
secret_name = secret_name.replace("os.environ/", "")
|
secret_name = secret_name.replace("os.environ/", "")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if litellm.secret_manager_client is not None:
|
if litellm.secret_manager_client is not None:
|
||||||
try:
|
try:
|
||||||
|
@ -8297,6 +8299,13 @@ def get_secret(
|
||||||
key_manager = "local"
|
key_manager = "local"
|
||||||
if key_management_system is not None:
|
if key_management_system is not None:
|
||||||
key_manager = key_management_system.value
|
key_manager = key_management_system.value
|
||||||
|
|
||||||
|
if key_management_settings is not None:
|
||||||
|
if (
|
||||||
|
secret_name not in key_management_settings.hosted_keys
|
||||||
|
): # allow user to specify which keys to check in hosted key manager
|
||||||
|
key_manager = "local"
|
||||||
|
|
||||||
if (
|
if (
|
||||||
key_manager == KeyManagementSystem.AZURE_KEY_VAULT
|
key_manager == KeyManagementSystem.AZURE_KEY_VAULT
|
||||||
or type(client).__module__ + "." + type(client).__name__
|
or type(client).__module__ + "." + type(client).__name__
|
||||||
|
@ -8332,9 +8341,30 @@ def get_secret(
|
||||||
secret = response.plaintext.decode(
|
secret = response.plaintext.decode(
|
||||||
"utf-8"
|
"utf-8"
|
||||||
) # assumes the original value was encoded with utf-8
|
) # assumes the original value was encoded with utf-8
|
||||||
|
elif key_manager == KeyManagementSystem.AWS_SECRET_MANAGER.value:
|
||||||
|
try:
|
||||||
|
get_secret_value_response = client.get_secret_value(
|
||||||
|
SecretId=secret_name
|
||||||
|
)
|
||||||
|
print_verbose(
|
||||||
|
f"get_secret_value_response: {get_secret_value_response}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print_verbose(f"An error occurred - {str(e)}")
|
||||||
|
# For a list of exceptions thrown, see
|
||||||
|
# https://docs.aws.amazon.com/secretsmanager/latest/apireference/API_GetSecretValue.html
|
||||||
|
raise e
|
||||||
|
|
||||||
|
# assume there is 1 secret per secret_name
|
||||||
|
secret_dict = json.loads(get_secret_value_response["SecretString"])
|
||||||
|
print_verbose(f"secret_dict: {secret_dict}")
|
||||||
|
for k, v in secret_dict.items():
|
||||||
|
secret = v
|
||||||
|
print_verbose(f"secret: {secret}")
|
||||||
else: # assume the default is infisicial client
|
else: # assume the default is infisicial client
|
||||||
secret = client.get_secret(secret_name).secret_value
|
secret = client.get_secret(secret_name).secret_value
|
||||||
except Exception as e: # check if it's in os.environ
|
except Exception as e: # check if it's in os.environ
|
||||||
|
print_verbose(f"An exception occurred - {str(e)}")
|
||||||
secret = os.getenv(secret_name)
|
secret = os.getenv(secret_name)
|
||||||
try:
|
try:
|
||||||
secret_value_as_bool = ast.literal_eval(secret)
|
secret_value_as_bool = ast.literal_eval(secret)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue