fix sync_read_secret

This commit is contained in:
Ishaan Jaff 2024-11-13 14:21:08 -08:00
parent b9b5d60a38
commit 33dc97df93
2 changed files with 44 additions and 31 deletions

View file

@ -27,7 +27,10 @@ import httpx
import litellm import litellm
from litellm._logging import verbose_logger from litellm._logging import verbose_logger
from litellm.llms.base_aws_llm import BaseAWSLLM from litellm.llms.base_aws_llm import BaseAWSLLM
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client from litellm.llms.custom_httpx.http_handler import (
_get_httpx_client,
get_async_httpx_client,
)
from litellm.llms.custom_httpx.types import httpxSpecialProvider from litellm.llms.custom_httpx.types import httpxSpecialProvider
from litellm.proxy._types import KeyManagementSystem from litellm.proxy._types import KeyManagementSystem
@ -105,19 +108,41 @@ class AWSSecretsManagerV2(BaseAWSLLM):
Done for backwards compatibility with existing codebase, since get_secret is a sync function Done for backwards compatibility with existing codebase, since get_secret is a sync function
""" """
try:
loop = asyncio.get_event_loop() # self._prepare_request uses these env vars, we cannot read them from AWS Secrets Manager. If we do we'd get stuck in an infinite loop
except RuntimeError: if secret_name in [
loop = asyncio.new_event_loop() "AWS_ACCESS_KEY_ID",
asyncio.set_event_loop(loop) "AWS_SECRET_ACCESS_KEY",
return loop.run_until_complete( "AWS_REGION_NAME",
self.async_read_secret( "AWS_REGION",
"AWS_BEDROCK_RUNTIME_ENDPOINT",
]:
return os.getenv(secret_name)
endpoint_url, headers, body = self._prepare_request(
action="GetSecretValue",
secret_name=secret_name, secret_name=secret_name,
optional_params=optional_params, optional_params=optional_params,
timeout=timeout,
) )
sync_client = _get_httpx_client(
params={"timeout": timeout},
) )
try:
response = sync_client.post(
url=endpoint_url, headers=headers, data=body.decode("utf-8")
)
response.raise_for_status()
return response.json()["SecretString"]
except httpx.TimeoutException:
raise ValueError("Timeout error occurred")
except Exception as e:
verbose_logger.exception(
"Error reading secret from AWS Secrets Manager: %s", str(e)
)
return None
async def async_write_secret( async def async_write_secret(
self, self,
secret_name: str, secret_name: str,

View file

@ -5,7 +5,7 @@ import json
import os import os
import sys import sys
import traceback import traceback
from typing import Any, Optional, Union from typing import TYPE_CHECKING, Any, Optional, Union
import httpx import httpx
from dotenv import load_dotenv from dotenv import load_dotenv
@ -269,25 +269,13 @@ def get_secret( # noqa: PLR0915
if isinstance(secret, str): if isinstance(secret, str):
secret = secret.strip() secret = secret.strip()
elif key_manager == KeyManagementSystem.AWS_SECRET_MANAGER.value: elif key_manager == KeyManagementSystem.AWS_SECRET_MANAGER.value:
try: from litellm.secret_managers.aws_secret_manager_v2 import (
get_secret_value_response = client.get_secret_value( AWSSecretsManagerV2,
SecretId=secret_name
) )
print_verbose(
f"get_secret_value_response: {get_secret_value_response}"
)
except Exception as e:
print_verbose(f"An error occurred - {str(e)}")
# For a list of exceptions thrown, see
# https://docs.aws.amazon.com/secretsmanager/latest/apireference/API_GetSecretValue.html
raise e
# assume there is 1 secret per secret_name if isinstance(client, AWSSecretsManagerV2):
secret_dict = json.loads(get_secret_value_response["SecretString"]) secret = client.sync_read_secret(secret_name=secret_name)
print_verbose(f"secret_dict: {secret_dict}") print_verbose(f"get_secret_value_response: {secret}")
for k, v in secret_dict.items():
secret = v
print_verbose(f"secret: {secret}")
elif key_manager == KeyManagementSystem.GOOGLE_SECRET_MANAGER.value: elif key_manager == KeyManagementSystem.GOOGLE_SECRET_MANAGER.value:
try: try:
secret = client.get_secret_from_google_secret_manager( secret = client.get_secret_from_google_secret_manager(