fix importing AWSSecretsManagerV2

This commit is contained in:
Ishaan Jaff 2024-11-13 09:52:26 -08:00
parent 750439cd46
commit bf9e58f8ef
4 changed files with 225 additions and 32 deletions

View file

@ -265,7 +265,6 @@ def run_server( # noqa: PLR0915
ProxyConfig, ProxyConfig,
app, app,
load_aws_kms, load_aws_kms,
load_aws_secret_manager,
load_from_azure_key_vault, load_from_azure_key_vault,
load_google_kms, load_google_kms,
save_worker_config, save_worker_config,
@ -278,7 +277,6 @@ def run_server( # noqa: PLR0915
ProxyConfig, ProxyConfig,
app, app,
load_aws_kms, load_aws_kms,
load_aws_secret_manager,
load_from_azure_key_vault, load_from_azure_key_vault,
load_google_kms, load_google_kms,
save_worker_config, save_worker_config,
@ -295,7 +293,6 @@ def run_server( # noqa: PLR0915
ProxyConfig, ProxyConfig,
app, app,
load_aws_kms, load_aws_kms,
load_aws_secret_manager,
load_from_azure_key_vault, load_from_azure_key_vault,
load_google_kms, load_google_kms,
save_worker_config, save_worker_config,
@ -559,8 +556,14 @@ def run_server( # noqa: PLR0915
key_management_system key_management_system
== KeyManagementSystem.AWS_SECRET_MANAGER.value # noqa: F405 == KeyManagementSystem.AWS_SECRET_MANAGER.value # noqa: F405
): ):
from litellm.secret_managers.aws_secret_manager_v2 import (
AWSSecretsManagerV2,
)
### LOAD FROM AWS SECRET MANAGER ### ### LOAD FROM AWS SECRET MANAGER ###
load_aws_secret_manager(use_aws_secret_manager=True) AWSSecretsManagerV2.load_aws_secret_manager(
use_aws_secret_manager=True
)
elif key_management_system == KeyManagementSystem.AWS_KMS.value: elif key_management_system == KeyManagementSystem.AWS_KMS.value:
load_aws_kms(use_aws_kms=True) load_aws_kms(use_aws_kms=True)
elif ( elif (

View file

@ -245,10 +245,7 @@ from litellm.router import (
from litellm.router import ModelInfo as RouterModelInfo from litellm.router import ModelInfo as RouterModelInfo
from litellm.router import updateDeployment from litellm.router import updateDeployment
from litellm.scheduler import DefaultPriorities, FlowItem, Scheduler from litellm.scheduler import DefaultPriorities, FlowItem, Scheduler
from litellm.secret_managers.aws_secret_manager import ( from litellm.secret_managers.aws_secret_manager import load_aws_kms
load_aws_kms,
load_aws_secret_manager,
)
from litellm.secret_managers.google_kms import load_google_kms from litellm.secret_managers.google_kms import load_google_kms
from litellm.secret_managers.main import ( from litellm.secret_managers.main import (
get_secret, get_secret,
@ -1825,8 +1822,13 @@ class ProxyConfig:
key_management_system key_management_system
== KeyManagementSystem.AWS_SECRET_MANAGER.value # noqa: F405 == KeyManagementSystem.AWS_SECRET_MANAGER.value # noqa: F405
): ):
### LOAD FROM AWS SECRET MANAGER ### from litellm.secret_managers.aws_secret_manager_v2 import (
load_aws_secret_manager(use_aws_secret_manager=True) AWSSecretsManagerV2,
)
AWSSecretsManagerV2.load_aws_secret_manager(
use_aws_secret_manager=True
)
elif key_management_system == KeyManagementSystem.AWS_KMS.value: elif key_management_system == KeyManagementSystem.AWS_KMS.value:
load_aws_kms(use_aws_kms=True) load_aws_kms(use_aws_kms=True)
elif ( elif (

View file

@ -23,28 +23,6 @@ def validate_environment():
raise ValueError("Missing required environment variable - AWS_REGION_NAME") raise ValueError("Missing required environment variable - AWS_REGION_NAME")
def load_aws_secret_manager(use_aws_secret_manager: Optional[bool]):
if use_aws_secret_manager is None or use_aws_secret_manager is False:
return
try:
import boto3
from botocore.exceptions import ClientError
validate_environment()
# Create a Secrets Manager client
session = boto3.session.Session() # type: ignore
client = session.client(
service_name="secretsmanager", region_name=os.getenv("AWS_REGION_NAME")
)
litellm.secret_manager_client = client
litellm._key_management_system = KeyManagementSystem.AWS_SECRET_MANAGER
except Exception as e:
raise e
def load_aws_kms(use_aws_kms: Optional[bool]): def load_aws_kms(use_aws_kms: Optional[bool]):
if use_aws_kms is None or use_aws_kms is False: if use_aws_kms is None or use_aws_kms is False:
return return

View file

@ -0,0 +1,210 @@
"""
This is a file for the AWS Secret Manager Integration
Handles Async Operations for:
- Read Secret
- Write Secret
Relevant issue: https://github.com/BerriAI/litellm/issues/1883
Requires:
* `os.environ["AWS_REGION_NAME"],
* `pip install boto3>=1.28.57`
"""
import ast
import base64
import os
import re
import sys
from typing import Any, Dict, Optional, Union
# Ensure project root is first in sys.path
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
if project_root not in sys.path:
sys.path.insert(0, project_root)
import asyncio
import json
import httpx
import litellm
from litellm.llms.base_aws_llm import BaseAWSLLM
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
from litellm.llms.custom_httpx.types import httpxSpecialProvider
from litellm.proxy._types import KeyManagementSystem
class AWSSecretsManagerV2(BaseAWSLLM):
@classmethod
def validate_environment(cls):
if "AWS_REGION_NAME" not in os.environ:
raise ValueError("Missing required environment variable - AWS_REGION_NAME")
@classmethod
def load_aws_secret_manager(cls, use_aws_secret_manager: Optional[bool]):
"""
Initialize AWSSecretsManagerV2 and sets litellm.secret_manager_client = AWSSecretsManagerV2() and litellm._key_management_system = KeyManagementSystem.AWS_SECRET_MANAGER
"""
if use_aws_secret_manager is None or use_aws_secret_manager is False:
return
try:
import boto3
cls.validate_environment()
litellm.secret_manager_client = cls()
litellm._key_management_system = KeyManagementSystem.AWS_SECRET_MANAGER
except Exception as e:
raise e
async def async_read_secret(
self,
secret_name: str,
optional_params: Optional[dict] = None,
timeout: Optional[Union[float, httpx.Timeout]] = None,
) -> str:
"""
Async function to read a secret from AWS Secrets Manager
"""
endpoint_url, headers, body = self._prepare_request(
action="GetSecretValue",
secret_name=secret_name,
optional_params=optional_params,
)
async_client = get_async_httpx_client(
llm_provider=httpxSpecialProvider.SecretManager,
params={"timeout": timeout},
)
try:
response = await async_client.post(
url=endpoint_url, headers=headers, data=body.decode("utf-8")
)
response.raise_for_status()
return response.json()["SecretString"]
except httpx.HTTPStatusError as err:
raise ValueError(f"HTTP error occurred: {err.response.text}")
except httpx.TimeoutException:
raise ValueError("Timeout error occurred")
async def async_write_secret(
self,
secret_name: str,
secret_value: str,
description: Optional[str] = None,
client_request_token: Optional[str] = None,
optional_params: Optional[dict] = None,
timeout: Optional[Union[float, httpx.Timeout]] = None,
) -> dict:
"""
Async function to write a secret to AWS Secrets Manager
Args:
secret_name: Name of the secret
secret_value: Value to store (can be a JSON string)
description: Optional description for the secret
client_request_token: Optional unique identifier to ensure idempotency
optional_params: Additional AWS parameters
timeout: Request timeout
"""
import uuid
# Prepare the request data
data = {"Name": secret_name, "SecretString": secret_value}
if description:
data["Description"] = description
data["ClientRequestToken"] = str(uuid.uuid4())
endpoint_url, headers, body = self._prepare_request(
action="CreateSecret",
secret_name=secret_name,
secret_value=secret_value,
optional_params=optional_params,
request_data=data, # Pass the complete request data
)
async_client = get_async_httpx_client(
llm_provider=httpxSpecialProvider.SecretManager,
params={"timeout": timeout},
)
try:
response = await async_client.post(
url=endpoint_url, headers=headers, data=body.decode("utf-8")
)
response.raise_for_status()
return response.json()
except httpx.HTTPStatusError as err:
raise ValueError(f"HTTP error occurred: {err.response.text}")
except httpx.TimeoutException:
raise ValueError("Timeout error occurred")
def _prepare_request(
self,
action: str, # "GetSecretValue" or "PutSecretValue"
secret_name: str,
secret_value: Optional[str] = None,
optional_params: Optional[dict] = None,
request_data: Optional[dict] = None,
) -> tuple[str, Any, bytes]:
"""Prepare the AWS Secrets Manager request"""
try:
import boto3
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest
from botocore.credentials import Credentials
except ImportError:
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
optional_params = optional_params or {}
boto3_credentials_info = self._get_boto_credentials_from_optional_params(
optional_params
)
# Get endpoint
_, endpoint_url = self.get_runtime_endpoint(
api_base=None,
aws_bedrock_runtime_endpoint=boto3_credentials_info.aws_bedrock_runtime_endpoint,
aws_region_name=boto3_credentials_info.aws_region_name,
)
endpoint_url = endpoint_url.replace("bedrock-runtime", "secretsmanager")
# Use provided request_data if available, otherwise build default data
if request_data:
data = request_data
else:
data = {"SecretId": secret_name}
if secret_value and action == "PutSecretValue":
data["SecretString"] = secret_value
body = json.dumps(data).encode("utf-8")
headers = {
"Content-Type": "application/x-amz-json-1.1",
"X-Amz-Target": f"secretsmanager.{action}",
}
# Sign request
request = AWSRequest(
method="POST", url=endpoint_url, data=body, headers=headers
)
SigV4Auth(
boto3_credentials_info.credentials,
"secretsmanager",
boto3_credentials_info.aws_region_name,
).add_auth(request)
prepped = request.prepare()
return endpoint_url, prepped.headers, body
# if __name__ == "__main__":
# print("loading aws secret manager v2")
# aws_secret_manager_v2 = AWSSecretsManagerV2()
# print("writing secret to aws secret manager v2")
# asyncio.run(aws_secret_manager_v2.async_write_secret(secret_name="test_secret_3", secret_value="test_value_2"))
# print("reading secret from aws secret manager v2")