fix(utils.py): initial commit for aws secret manager support

This commit is contained in:
Krrish Dholakia 2024-03-16 14:37:46 -07:00
parent 2c2f322d5a
commit d8956e9255
3 changed files with 54 additions and 0 deletions

View file

@ -387,6 +387,7 @@ class BudgetRequest(LiteLLMBase):
class KeyManagementSystem(enum.Enum):
GOOGLE_KMS = "google_kms"
AZURE_KEY_VAULT = "azure_key_vault"
AWS_SECRET_MANAGER = "aws_secret_manager"
LOCAL = "local"

View file

@ -0,0 +1,40 @@
"""
This is a file for the AWS Secret Manager Integration
Relevant issue: https://github.com/BerriAI/litellm/issues/1883
Requires:
* `os.environ["AWS_REGION_NAME"],
* `pip install boto3>=1.28.57`
"""
import litellm, os
from typing import Optional
from litellm.proxy._types import KeyManagementSystem
def validate_environment():
if "AWS_REGION_NAME" not in os.environ:
raise ValueError("Missing required environment variable - AWS_REGION_NAME")
def load_aws_secret_manager(use_aws_secret_manager: Optional[bool]):
if use_aws_secret_manager is None or use_aws_secret_manager == False:
return
try:
import boto3
from botocore.exceptions import ClientError
validate_environment()
# Create a Secrets Manager client
session = boto3.session.Session()
client = session.client(
service_name="secretsmanager", region_name=os.getenv("AWS_REGION_NAME")
)
litellm.secret_manager_client = client
litellm._key_management_system = KeyManagementSystem.AWS_SECRET_MANAGER
except Exception as e:
raise e

View file

@ -8332,6 +8332,19 @@ def get_secret(
secret = response.plaintext.decode(
"utf-8"
) # assumes the original value was encoded with utf-8
elif key_manager == KeyManagementSystem.AWS_SECRET_MANAGER.value:
try:
get_secret_value_response = client.get_secret_value(
SecretId=secret_name
)
except Exception as e:
# For a list of exceptions thrown, see
# https://docs.aws.amazon.com/secretsmanager/latest/apireference/API_GetSecretValue.html
raise e
# assume there is 1 secretstring per secret_name
for k, v in get_secret_value_response.items():
secret = v
else: # assume the default is infisicial client
secret = client.get_secret(secret_name).secret_value
except Exception as e: # check if it's in os.environ