From bf9e58f8efea68518efeeecd367101f64a1782ca Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Wed, 13 Nov 2024 09:52:26 -0800 Subject: [PATCH] fix importing AWSSecretsManagerV2 --- litellm/proxy/proxy_cli.py | 11 +- litellm/proxy/proxy_server.py | 14 +- litellm/secret_managers/aws_secret_manager.py | 22 -- .../secret_managers/aws_secret_manager_v2.py | 210 ++++++++++++++++++ 4 files changed, 225 insertions(+), 32 deletions(-) create mode 100644 litellm/secret_managers/aws_secret_manager_v2.py diff --git a/litellm/proxy/proxy_cli.py b/litellm/proxy/proxy_cli.py index f9f8276c7..094828de1 100644 --- a/litellm/proxy/proxy_cli.py +++ b/litellm/proxy/proxy_cli.py @@ -265,7 +265,6 @@ def run_server( # noqa: PLR0915 ProxyConfig, app, load_aws_kms, - load_aws_secret_manager, load_from_azure_key_vault, load_google_kms, save_worker_config, @@ -278,7 +277,6 @@ def run_server( # noqa: PLR0915 ProxyConfig, app, load_aws_kms, - load_aws_secret_manager, load_from_azure_key_vault, load_google_kms, save_worker_config, @@ -295,7 +293,6 @@ def run_server( # noqa: PLR0915 ProxyConfig, app, load_aws_kms, - load_aws_secret_manager, load_from_azure_key_vault, load_google_kms, save_worker_config, @@ -559,8 +556,14 @@ def run_server( # noqa: PLR0915 key_management_system == KeyManagementSystem.AWS_SECRET_MANAGER.value # noqa: F405 ): + from litellm.secret_managers.aws_secret_manager_v2 import ( + AWSSecretsManagerV2, + ) + ### LOAD FROM AWS SECRET MANAGER ### - load_aws_secret_manager(use_aws_secret_manager=True) + AWSSecretsManagerV2.load_aws_secret_manager( + use_aws_secret_manager=True + ) elif key_management_system == KeyManagementSystem.AWS_KMS.value: load_aws_kms(use_aws_kms=True) elif ( diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index c9c6af77f..34ac51481 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -245,10 +245,7 @@ from litellm.router import ( from litellm.router import ModelInfo as RouterModelInfo from litellm.router import updateDeployment from litellm.scheduler import DefaultPriorities, FlowItem, Scheduler -from litellm.secret_managers.aws_secret_manager import ( - load_aws_kms, - load_aws_secret_manager, -) +from litellm.secret_managers.aws_secret_manager import load_aws_kms from litellm.secret_managers.google_kms import load_google_kms from litellm.secret_managers.main import ( get_secret, @@ -1825,8 +1822,13 @@ class ProxyConfig: key_management_system == KeyManagementSystem.AWS_SECRET_MANAGER.value # noqa: F405 ): - ### LOAD FROM AWS SECRET MANAGER ### - load_aws_secret_manager(use_aws_secret_manager=True) + from litellm.secret_managers.aws_secret_manager_v2 import ( + AWSSecretsManagerV2, + ) + + AWSSecretsManagerV2.load_aws_secret_manager( + use_aws_secret_manager=True + ) elif key_management_system == KeyManagementSystem.AWS_KMS.value: load_aws_kms(use_aws_kms=True) elif ( diff --git a/litellm/secret_managers/aws_secret_manager.py b/litellm/secret_managers/aws_secret_manager.py index f0e510fa8..fbe951e64 100644 --- a/litellm/secret_managers/aws_secret_manager.py +++ b/litellm/secret_managers/aws_secret_manager.py @@ -23,28 +23,6 @@ def validate_environment(): raise ValueError("Missing required environment variable - AWS_REGION_NAME") -def load_aws_secret_manager(use_aws_secret_manager: Optional[bool]): - if use_aws_secret_manager is None or use_aws_secret_manager is False: - return - try: - import boto3 - from botocore.exceptions import ClientError - - validate_environment() - - # Create a Secrets Manager client - session = boto3.session.Session() # type: ignore - client = session.client( - service_name="secretsmanager", region_name=os.getenv("AWS_REGION_NAME") - ) - - litellm.secret_manager_client = client - litellm._key_management_system = KeyManagementSystem.AWS_SECRET_MANAGER - - except Exception as e: - raise e - - def load_aws_kms(use_aws_kms: Optional[bool]): if use_aws_kms is None or use_aws_kms is False: return diff --git a/litellm/secret_managers/aws_secret_manager_v2.py b/litellm/secret_managers/aws_secret_manager_v2.py new file mode 100644 index 000000000..1e54a6399 --- /dev/null +++ b/litellm/secret_managers/aws_secret_manager_v2.py @@ -0,0 +1,210 @@ +""" +This is a file for the AWS Secret Manager Integration + +Handles Async Operations for: +- Read Secret +- Write Secret + +Relevant issue: https://github.com/BerriAI/litellm/issues/1883 + +Requires: +* `os.environ["AWS_REGION_NAME"], +* `pip install boto3>=1.28.57` +""" + +import ast +import base64 +import os +import re +import sys +from typing import Any, Dict, Optional, Union + +# Ensure project root is first in sys.path +project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) +if project_root not in sys.path: + sys.path.insert(0, project_root) + + +import asyncio +import json + +import httpx + +import litellm +from litellm.llms.base_aws_llm import BaseAWSLLM +from litellm.llms.custom_httpx.http_handler import get_async_httpx_client +from litellm.llms.custom_httpx.types import httpxSpecialProvider +from litellm.proxy._types import KeyManagementSystem + + +class AWSSecretsManagerV2(BaseAWSLLM): + @classmethod + def validate_environment(cls): + if "AWS_REGION_NAME" not in os.environ: + raise ValueError("Missing required environment variable - AWS_REGION_NAME") + + @classmethod + def load_aws_secret_manager(cls, use_aws_secret_manager: Optional[bool]): + """ + Initialize AWSSecretsManagerV2 and sets litellm.secret_manager_client = AWSSecretsManagerV2() and litellm._key_management_system = KeyManagementSystem.AWS_SECRET_MANAGER + """ + if use_aws_secret_manager is None or use_aws_secret_manager is False: + return + try: + import boto3 + + cls.validate_environment() + litellm.secret_manager_client = cls() + litellm._key_management_system = KeyManagementSystem.AWS_SECRET_MANAGER + + except Exception as e: + raise e + + async def async_read_secret( + self, + secret_name: str, + optional_params: Optional[dict] = None, + timeout: Optional[Union[float, httpx.Timeout]] = None, + ) -> str: + """ + Async function to read a secret from AWS Secrets Manager + """ + endpoint_url, headers, body = self._prepare_request( + action="GetSecretValue", + secret_name=secret_name, + optional_params=optional_params, + ) + + async_client = get_async_httpx_client( + llm_provider=httpxSpecialProvider.SecretManager, + params={"timeout": timeout}, + ) + + try: + response = await async_client.post( + url=endpoint_url, headers=headers, data=body.decode("utf-8") + ) + response.raise_for_status() + return response.json()["SecretString"] + except httpx.HTTPStatusError as err: + raise ValueError(f"HTTP error occurred: {err.response.text}") + except httpx.TimeoutException: + raise ValueError("Timeout error occurred") + + async def async_write_secret( + self, + secret_name: str, + secret_value: str, + description: Optional[str] = None, + client_request_token: Optional[str] = None, + optional_params: Optional[dict] = None, + timeout: Optional[Union[float, httpx.Timeout]] = None, + ) -> dict: + """ + Async function to write a secret to AWS Secrets Manager + + Args: + secret_name: Name of the secret + secret_value: Value to store (can be a JSON string) + description: Optional description for the secret + client_request_token: Optional unique identifier to ensure idempotency + optional_params: Additional AWS parameters + timeout: Request timeout + """ + import uuid + + # Prepare the request data + data = {"Name": secret_name, "SecretString": secret_value} + if description: + data["Description"] = description + + data["ClientRequestToken"] = str(uuid.uuid4()) + + endpoint_url, headers, body = self._prepare_request( + action="CreateSecret", + secret_name=secret_name, + secret_value=secret_value, + optional_params=optional_params, + request_data=data, # Pass the complete request data + ) + + async_client = get_async_httpx_client( + llm_provider=httpxSpecialProvider.SecretManager, + params={"timeout": timeout}, + ) + + try: + response = await async_client.post( + url=endpoint_url, headers=headers, data=body.decode("utf-8") + ) + response.raise_for_status() + return response.json() + except httpx.HTTPStatusError as err: + raise ValueError(f"HTTP error occurred: {err.response.text}") + except httpx.TimeoutException: + raise ValueError("Timeout error occurred") + + def _prepare_request( + self, + action: str, # "GetSecretValue" or "PutSecretValue" + secret_name: str, + secret_value: Optional[str] = None, + optional_params: Optional[dict] = None, + request_data: Optional[dict] = None, + ) -> tuple[str, Any, bytes]: + """Prepare the AWS Secrets Manager request""" + try: + import boto3 + from botocore.auth import SigV4Auth + from botocore.awsrequest import AWSRequest + from botocore.credentials import Credentials + except ImportError: + raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.") + optional_params = optional_params or {} + boto3_credentials_info = self._get_boto_credentials_from_optional_params( + optional_params + ) + + # Get endpoint + _, endpoint_url = self.get_runtime_endpoint( + api_base=None, + aws_bedrock_runtime_endpoint=boto3_credentials_info.aws_bedrock_runtime_endpoint, + aws_region_name=boto3_credentials_info.aws_region_name, + ) + endpoint_url = endpoint_url.replace("bedrock-runtime", "secretsmanager") + + # Use provided request_data if available, otherwise build default data + if request_data: + data = request_data + else: + data = {"SecretId": secret_name} + if secret_value and action == "PutSecretValue": + data["SecretString"] = secret_value + + body = json.dumps(data).encode("utf-8") + headers = { + "Content-Type": "application/x-amz-json-1.1", + "X-Amz-Target": f"secretsmanager.{action}", + } + + # Sign request + request = AWSRequest( + method="POST", url=endpoint_url, data=body, headers=headers + ) + SigV4Auth( + boto3_credentials_info.credentials, + "secretsmanager", + boto3_credentials_info.aws_region_name, + ).add_auth(request) + prepped = request.prepare() + + return endpoint_url, prepped.headers, body + + +# if __name__ == "__main__": +# print("loading aws secret manager v2") +# aws_secret_manager_v2 = AWSSecretsManagerV2() + +# print("writing secret to aws secret manager v2") +# asyncio.run(aws_secret_manager_v2.async_write_secret(secret_name="test_secret_3", secret_value="test_value_2")) +# print("reading secret from aws secret manager v2")