diff --git a/litellm/secret_managers/aws_secret_manager_v2.py b/litellm/secret_managers/aws_secret_manager_v2.py index 794678c85..69add6f23 100644 --- a/litellm/secret_managers/aws_secret_manager_v2.py +++ b/litellm/secret_managers/aws_secret_manager_v2.py @@ -27,7 +27,10 @@ import httpx import litellm from litellm._logging import verbose_logger from litellm.llms.base_aws_llm import BaseAWSLLM -from litellm.llms.custom_httpx.http_handler import get_async_httpx_client +from litellm.llms.custom_httpx.http_handler import ( + _get_httpx_client, + get_async_httpx_client, +) from litellm.llms.custom_httpx.types import httpxSpecialProvider from litellm.proxy._types import KeyManagementSystem @@ -105,19 +108,41 @@ class AWSSecretsManagerV2(BaseAWSLLM): Done for backwards compatibility with existing codebase, since get_secret is a sync function """ - try: - loop = asyncio.get_event_loop() - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - return loop.run_until_complete( - self.async_read_secret( - secret_name=secret_name, - optional_params=optional_params, - timeout=timeout, - ) + + # self._prepare_request uses these env vars, we cannot read them from AWS Secrets Manager. If we do we'd get stuck in an infinite loop + if secret_name in [ + "AWS_ACCESS_KEY_ID", + "AWS_SECRET_ACCESS_KEY", + "AWS_REGION_NAME", + "AWS_REGION", + "AWS_BEDROCK_RUNTIME_ENDPOINT", + ]: + return os.getenv(secret_name) + + endpoint_url, headers, body = self._prepare_request( + action="GetSecretValue", + secret_name=secret_name, + optional_params=optional_params, ) + sync_client = _get_httpx_client( + params={"timeout": timeout}, + ) + + try: + response = sync_client.post( + url=endpoint_url, headers=headers, data=body.decode("utf-8") + ) + response.raise_for_status() + return response.json()["SecretString"] + except httpx.TimeoutException: + raise ValueError("Timeout error occurred") + except Exception as e: + verbose_logger.exception( + "Error reading secret from AWS Secrets Manager: %s", str(e) + ) + return None + async def async_write_secret( self, secret_name: str, diff --git a/litellm/secret_managers/main.py b/litellm/secret_managers/main.py index 6a3bd3d02..35274092c 100644 --- a/litellm/secret_managers/main.py +++ b/litellm/secret_managers/main.py @@ -5,7 +5,7 @@ import json import os import sys import traceback -from typing import Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union import httpx from dotenv import load_dotenv @@ -269,25 +269,13 @@ def get_secret( # noqa: PLR0915 if isinstance(secret, str): secret = secret.strip() elif key_manager == KeyManagementSystem.AWS_SECRET_MANAGER.value: - try: - get_secret_value_response = client.get_secret_value( - SecretId=secret_name - ) - print_verbose( - f"get_secret_value_response: {get_secret_value_response}" - ) - except Exception as e: - print_verbose(f"An error occurred - {str(e)}") - # For a list of exceptions thrown, see - # https://docs.aws.amazon.com/secretsmanager/latest/apireference/API_GetSecretValue.html - raise e + from litellm.secret_managers.aws_secret_manager_v2 import ( + AWSSecretsManagerV2, + ) - # assume there is 1 secret per secret_name - secret_dict = json.loads(get_secret_value_response["SecretString"]) - print_verbose(f"secret_dict: {secret_dict}") - for k, v in secret_dict.items(): - secret = v - print_verbose(f"secret: {secret}") + if isinstance(client, AWSSecretsManagerV2): + secret = client.sync_read_secret(secret_name=secret_name) + print_verbose(f"get_secret_value_response: {secret}") elif key_manager == KeyManagementSystem.GOOGLE_SECRET_MANAGER.value: try: secret = client.get_secret_from_google_secret_manager(