fix sync_read_secret

This commit is contained in:
Ishaan Jaff 2024-11-13 14:21:08 -08:00
parent b9b5d60a38
commit 33dc97df93
2 changed files with 44 additions and 31 deletions

View file

@ -27,7 +27,10 @@ import httpx
import litellm
from litellm._logging import verbose_logger
from litellm.llms.base_aws_llm import BaseAWSLLM
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
from litellm.llms.custom_httpx.http_handler import (
_get_httpx_client,
get_async_httpx_client,
)
from litellm.llms.custom_httpx.types import httpxSpecialProvider
from litellm.proxy._types import KeyManagementSystem
@ -105,19 +108,41 @@ class AWSSecretsManagerV2(BaseAWSLLM):
Done for backwards compatibility with existing codebase, since get_secret is a sync function
"""
try:
loop = asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
return loop.run_until_complete(
self.async_read_secret(
secret_name=secret_name,
optional_params=optional_params,
timeout=timeout,
)
# self._prepare_request uses these env vars, we cannot read them from AWS Secrets Manager. If we do we'd get stuck in an infinite loop
if secret_name in [
"AWS_ACCESS_KEY_ID",
"AWS_SECRET_ACCESS_KEY",
"AWS_REGION_NAME",
"AWS_REGION",
"AWS_BEDROCK_RUNTIME_ENDPOINT",
]:
return os.getenv(secret_name)
endpoint_url, headers, body = self._prepare_request(
action="GetSecretValue",
secret_name=secret_name,
optional_params=optional_params,
)
sync_client = _get_httpx_client(
params={"timeout": timeout},
)
try:
response = sync_client.post(
url=endpoint_url, headers=headers, data=body.decode("utf-8")
)
response.raise_for_status()
return response.json()["SecretString"]
except httpx.TimeoutException:
raise ValueError("Timeout error occurred")
except Exception as e:
verbose_logger.exception(
"Error reading secret from AWS Secrets Manager: %s", str(e)
)
return None
async def async_write_secret(
self,
secret_name: str,

View file

@ -5,7 +5,7 @@ import json
import os
import sys
import traceback
from typing import Any, Optional, Union
from typing import TYPE_CHECKING, Any, Optional, Union
import httpx
from dotenv import load_dotenv
@ -269,25 +269,13 @@ def get_secret( # noqa: PLR0915
if isinstance(secret, str):
secret = secret.strip()
elif key_manager == KeyManagementSystem.AWS_SECRET_MANAGER.value:
try:
get_secret_value_response = client.get_secret_value(
SecretId=secret_name
)
print_verbose(
f"get_secret_value_response: {get_secret_value_response}"
)
except Exception as e:
print_verbose(f"An error occurred - {str(e)}")
# For a list of exceptions thrown, see
# https://docs.aws.amazon.com/secretsmanager/latest/apireference/API_GetSecretValue.html
raise e
from litellm.secret_managers.aws_secret_manager_v2 import (
AWSSecretsManagerV2,
)
# assume there is 1 secret per secret_name
secret_dict = json.loads(get_secret_value_response["SecretString"])
print_verbose(f"secret_dict: {secret_dict}")
for k, v in secret_dict.items():
secret = v
print_verbose(f"secret: {secret}")
if isinstance(client, AWSSecretsManagerV2):
secret = client.sync_read_secret(secret_name=secret_name)
print_verbose(f"get_secret_value_response: {secret}")
elif key_manager == KeyManagementSystem.GOOGLE_SECRET_MANAGER.value:
try:
secret = client.get_secret_from_google_secret_manager(