(Feat) Add support for storing virtual keys in AWS SecretManager (#6728)

* add SecretManager to httpxSpecialProvider

* fix importing AWSSecretsManagerV2

* add unit testing for writing keys to AWS secret manager

* use KeyManagementEventHooks for key/generated events

* us event hooks for key management endpoints

* working AWSSecretsManagerV2

* fix write secret to AWS secret manager on /key/generate

* fix KeyManagementSettings

* use tasks for key management hooks

* add async_delete_secret

* add test for async_delete_secret

* use _delete_virtual_keys_from_secret_manager

* fix test secret manager

* test_key_generate_with_secret_manager_call

* fix check for key_management_settings

* sync_read_secret

* test_aws_secret_manager

* fix sync_read_secret

* use helper to check when _should_read_secret_from_secret_manager

* test_get_secret_with_access_mode

* test - handle eol model claude-2, use claude-2.1 instead

* docs AWS secret manager

* fix test_read_nonexistent_secret

* fix test_supports_response_schema

* ci/cd run again
This commit is contained in:
Ishaan Jaff 2024-11-14 09:25:07 -08:00 committed by GitHub
parent da84056e59
commit f8e700064e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 1046 additions and 178 deletions

View file

@ -1,3 +1,6 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# Secret Manager
LiteLLM supports reading secrets from Azure Key Vault, Google Secret Manager
@ -59,14 +62,35 @@ os.environ["AWS_REGION_NAME"] = "" # us-east-1, us-east-2, us-west-1, us-west-2
```
2. Enable AWS Secret Manager in config.
<Tabs>
<TabItem value="read_only" label="Read Keys from AWS Secret Manager">
```yaml
general_settings:
master_key: os.environ/litellm_master_key
key_management_system: "aws_secret_manager" # 👈 KEY CHANGE
key_management_settings:
hosted_keys: ["litellm_master_key"] # 👈 Specify which env keys you stored on AWS
```
</TabItem>
<TabItem value="write_only" label="Write Virtual Keys to AWS Secret Manager">
This will only store virtual keys in AWS Secret Manager. No keys will be read from AWS Secret Manager.
```yaml
general_settings:
key_management_system: "aws_secret_manager" # 👈 KEY CHANGE
key_management_settings:
store_virtual_keys: true
access_mode: "write_only" # Literal["read_only", "write_only", "read_and_write"]
```
</TabItem>
</Tabs>
3. Run proxy
```bash
@ -181,16 +205,14 @@ litellm --config /path/to/config.yaml
Use encrypted keys from Google KMS on the proxy
### Usage with LiteLLM Proxy Server
## Step 1. Add keys to env
Step 1. Add keys to env
```
export GOOGLE_APPLICATION_CREDENTIALS="/path/to/credentials.json"
export GOOGLE_KMS_RESOURCE_NAME="projects/*/locations/*/keyRings/*/cryptoKeys/*"
export PROXY_DATABASE_URL_ENCRYPTED=b'\n$\x00D\xac\xb4/\x8e\xc...'
```
## Step 2: Update Config
Step 2: Update Config
```yaml
general_settings:
@ -199,7 +221,7 @@ general_settings:
master_key: sk-1234
```
## Step 3: Start + test proxy
Step 3: Start + test proxy
```
$ litellm --config /path/to/config.yaml
@ -215,3 +237,17 @@ $ litellm --test
<!--
## .env Files
If no secret manager client is specified, Litellm automatically uses the `.env` file to manage sensitive data. -->
## All Secret Manager Settings
All settings related to secret management
```yaml
general_settings:
key_management_system: "aws_secret_manager" # REQUIRED
key_management_settings:
store_virtual_keys: true # OPTIONAL. Defaults to False, when True will store virtual keys in secret manager
access_mode: "write_only" # OPTIONAL. Literal["read_only", "write_only", "read_and_write"]. Defaults to "read_only"
hosted_keys: ["litellm_master_key"] # OPTIONAL. Specify which env keys you stored on AWS
```

View file

@ -304,7 +304,7 @@ secret_manager_client: Optional[Any] = (
)
_google_kms_resource_name: Optional[str] = None
_key_management_system: Optional[KeyManagementSystem] = None
_key_management_settings: Optional[KeyManagementSettings] = None
_key_management_settings: KeyManagementSettings = KeyManagementSettings()
#### PII MASKING ####
output_parse_pii: bool = False
#############################################

View file

@ -8,3 +8,4 @@ class httpxSpecialProvider(str, Enum):
GuardrailCallback = "guardrail_callback"
Caching = "caching"
Oauth2Check = "oauth2_check"
SecretManager = "secret_manager"

View file

@ -1128,7 +1128,16 @@ class KeyManagementSystem(enum.Enum):
class KeyManagementSettings(LiteLLMBase):
hosted_keys: List
hosted_keys: Optional[List] = None
store_virtual_keys: Optional[bool] = False
"""
If True, virtual keys created by litellm will be stored in the secret manager
"""
access_mode: Literal["read_only", "write_only", "read_and_write"] = "read_only"
"""
Access mode for the secret manager, when write_only will only use for writing secrets
"""
class TeamDefaultSettings(LiteLLMBase):

View file

@ -0,0 +1,267 @@
import asyncio
import json
import uuid
from datetime import datetime, timezone
from re import A
from typing import Any, List, Optional
from fastapi import status
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import (
GenerateKeyRequest,
KeyManagementSystem,
KeyRequest,
LiteLLM_AuditLogs,
LiteLLM_VerificationToken,
LitellmTableNames,
ProxyErrorTypes,
ProxyException,
UpdateKeyRequest,
UserAPIKeyAuth,
WebhookEvent,
)
class KeyManagementEventHooks:
@staticmethod
async def async_key_generated_hook(
data: GenerateKeyRequest,
response: dict,
user_api_key_dict: UserAPIKeyAuth,
litellm_changed_by: Optional[str] = None,
):
"""
Hook that runs after a successful /key/generate request
Handles the following:
- Sending Email with Key Details
- Storing Audit Logs for key generation
- Storing Generated Key in DB
"""
from litellm.proxy.management_helpers.audit_logs import (
create_audit_log_for_update,
)
from litellm.proxy.proxy_server import (
general_settings,
litellm_proxy_admin_name,
proxy_logging_obj,
)
if data.send_invite_email is True:
await KeyManagementEventHooks._send_key_created_email(response)
# Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
if litellm.store_audit_logs is True:
_updated_values = json.dumps(response, default=str)
asyncio.create_task(
create_audit_log_for_update(
request_data=LiteLLM_AuditLogs(
id=str(uuid.uuid4()),
updated_at=datetime.now(timezone.utc),
changed_by=litellm_changed_by
or user_api_key_dict.user_id
or litellm_proxy_admin_name,
changed_by_api_key=user_api_key_dict.api_key,
table_name=LitellmTableNames.KEY_TABLE_NAME,
object_id=response.get("token_id", ""),
action="created",
updated_values=_updated_values,
before_value=None,
)
)
)
# store the generated key in the secret manager
await KeyManagementEventHooks._store_virtual_key_in_secret_manager(
secret_name=data.key_alias or f"virtual-key-{uuid.uuid4()}",
secret_token=response.get("token", ""),
)
@staticmethod
async def async_key_updated_hook(
data: UpdateKeyRequest,
existing_key_row: Any,
response: Any,
user_api_key_dict: UserAPIKeyAuth,
litellm_changed_by: Optional[str] = None,
):
"""
Post /key/update processing hook
Handles the following:
- Storing Audit Logs for key update
"""
from litellm.proxy.management_helpers.audit_logs import (
create_audit_log_for_update,
)
from litellm.proxy.proxy_server import litellm_proxy_admin_name
# Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
if litellm.store_audit_logs is True:
_updated_values = json.dumps(data.json(exclude_none=True), default=str)
_before_value = existing_key_row.json(exclude_none=True)
_before_value = json.dumps(_before_value, default=str)
asyncio.create_task(
create_audit_log_for_update(
request_data=LiteLLM_AuditLogs(
id=str(uuid.uuid4()),
updated_at=datetime.now(timezone.utc),
changed_by=litellm_changed_by
or user_api_key_dict.user_id
or litellm_proxy_admin_name,
changed_by_api_key=user_api_key_dict.api_key,
table_name=LitellmTableNames.KEY_TABLE_NAME,
object_id=data.key,
action="updated",
updated_values=_updated_values,
before_value=_before_value,
)
)
)
pass
@staticmethod
async def async_key_deleted_hook(
data: KeyRequest,
keys_being_deleted: List[LiteLLM_VerificationToken],
response: dict,
user_api_key_dict: UserAPIKeyAuth,
litellm_changed_by: Optional[str] = None,
):
"""
Post /key/delete processing hook
Handles the following:
- Storing Audit Logs for key deletion
"""
from litellm.proxy.management_helpers.audit_logs import (
create_audit_log_for_update,
)
from litellm.proxy.proxy_server import litellm_proxy_admin_name, prisma_client
# Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
# we do this after the first for loop, since first for loop is for validation. we only want this inserted after validation passes
if litellm.store_audit_logs is True:
# make an audit log for each team deleted
for key in data.keys:
key_row = await prisma_client.get_data( # type: ignore
token=key, table_name="key", query_type="find_unique"
)
if key_row is None:
raise ProxyException(
message=f"Key {key} not found",
type=ProxyErrorTypes.bad_request_error,
param="key",
code=status.HTTP_404_NOT_FOUND,
)
key_row = key_row.json(exclude_none=True)
_key_row = json.dumps(key_row, default=str)
asyncio.create_task(
create_audit_log_for_update(
request_data=LiteLLM_AuditLogs(
id=str(uuid.uuid4()),
updated_at=datetime.now(timezone.utc),
changed_by=litellm_changed_by
or user_api_key_dict.user_id
or litellm_proxy_admin_name,
changed_by_api_key=user_api_key_dict.api_key,
table_name=LitellmTableNames.KEY_TABLE_NAME,
object_id=key,
action="deleted",
updated_values="{}",
before_value=_key_row,
)
)
)
# delete the keys from the secret manager
await KeyManagementEventHooks._delete_virtual_keys_from_secret_manager(
keys_being_deleted=keys_being_deleted
)
pass
@staticmethod
async def _store_virtual_key_in_secret_manager(secret_name: str, secret_token: str):
"""
Store a virtual key in the secret manager
Args:
secret_name: Name of the virtual key
secret_token: Value of the virtual key (example: sk-1234)
"""
if litellm._key_management_settings is not None:
if litellm._key_management_settings.store_virtual_keys is True:
from litellm.secret_managers.aws_secret_manager_v2 import (
AWSSecretsManagerV2,
)
# store the key in the secret manager
if (
litellm._key_management_system
== KeyManagementSystem.AWS_SECRET_MANAGER
and isinstance(litellm.secret_manager_client, AWSSecretsManagerV2)
):
await litellm.secret_manager_client.async_write_secret(
secret_name=secret_name,
secret_value=secret_token,
)
@staticmethod
async def _delete_virtual_keys_from_secret_manager(
keys_being_deleted: List[LiteLLM_VerificationToken],
):
"""
Deletes virtual keys from the secret manager
Args:
keys_being_deleted: List of keys being deleted, this is passed down from the /key/delete operation
"""
if litellm._key_management_settings is not None:
if litellm._key_management_settings.store_virtual_keys is True:
from litellm.secret_managers.aws_secret_manager_v2 import (
AWSSecretsManagerV2,
)
if isinstance(litellm.secret_manager_client, AWSSecretsManagerV2):
for key in keys_being_deleted:
if key.key_alias is not None:
await litellm.secret_manager_client.async_delete_secret(
secret_name=key.key_alias
)
else:
verbose_proxy_logger.warning(
f"KeyManagementEventHooks._delete_virtual_key_from_secret_manager: Key alias not found for key {key.token}. Skipping deletion from secret manager."
)
@staticmethod
async def _send_key_created_email(response: dict):
from litellm.proxy.proxy_server import general_settings, proxy_logging_obj
if "email" not in general_settings.get("alerting", []):
raise ValueError(
"Email alerting not setup on config.yaml. Please set `alerting=['email']. \nDocs: https://docs.litellm.ai/docs/proxy/email`"
)
event = WebhookEvent(
event="key_created",
event_group="key",
event_message="API Key Created",
token=response.get("token", ""),
spend=response.get("spend", 0.0),
max_budget=response.get("max_budget", 0.0),
user_id=response.get("user_id", None),
team_id=response.get("team_id", "Default Team"),
key_alias=response.get("key_alias", None),
)
# If user configured email alerting - send an Email letting their end-user know the key was created
asyncio.create_task(
proxy_logging_obj.slack_alerting_instance.send_key_created_or_user_invited_email(
webhook_event=event,
)
)

View file

@ -17,7 +17,7 @@ import secrets
import traceback
import uuid
from datetime import datetime, timedelta, timezone
from typing import List, Optional
from typing import List, Optional, Tuple
import fastapi
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Request, status
@ -31,6 +31,7 @@ from litellm.proxy.auth.auth_checks import (
get_key_object,
)
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.hooks.key_management_event_hooks import KeyManagementEventHooks
from litellm.proxy.management_helpers.utils import management_endpoint_wrapper
from litellm.proxy.utils import _duration_in_seconds, _hash_token_if_needed
from litellm.secret_managers.main import get_secret
@ -234,50 +235,14 @@ async def generate_key_fn( # noqa: PLR0915
data.soft_budget
) # include the user-input soft budget in the response
if data.send_invite_email is True:
if "email" not in general_settings.get("alerting", []):
raise ValueError(
"Email alerting not setup on config.yaml. Please set `alerting=['email']. \nDocs: https://docs.litellm.ai/docs/proxy/email`"
)
event = WebhookEvent(
event="key_created",
event_group="key",
event_message="API Key Created",
token=response.get("token", ""),
spend=response.get("spend", 0.0),
max_budget=response.get("max_budget", 0.0),
user_id=response.get("user_id", None),
team_id=response.get("team_id", "Default Team"),
key_alias=response.get("key_alias", None),
)
# If user configured email alerting - send an Email letting their end-user know the key was created
asyncio.create_task(
proxy_logging_obj.slack_alerting_instance.send_key_created_or_user_invited_email(
webhook_event=event,
)
)
# Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
if litellm.store_audit_logs is True:
_updated_values = json.dumps(response, default=str)
asyncio.create_task(
create_audit_log_for_update(
request_data=LiteLLM_AuditLogs(
id=str(uuid.uuid4()),
updated_at=datetime.now(timezone.utc),
changed_by=litellm_changed_by
or user_api_key_dict.user_id
or litellm_proxy_admin_name,
changed_by_api_key=user_api_key_dict.api_key,
table_name=LitellmTableNames.KEY_TABLE_NAME,
object_id=response.get("token_id", ""),
action="created",
updated_values=_updated_values,
before_value=None,
)
)
asyncio.create_task(
KeyManagementEventHooks.async_key_generated_hook(
data=data,
response=response,
user_api_key_dict=user_api_key_dict,
litellm_changed_by=litellm_changed_by,
)
)
return GenerateKeyResponse(**response)
except Exception as e:
@ -407,30 +372,15 @@ async def update_key_fn(
proxy_logging_obj=proxy_logging_obj,
)
# Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
if litellm.store_audit_logs is True:
_updated_values = json.dumps(data_json, default=str)
_before_value = existing_key_row.json(exclude_none=True)
_before_value = json.dumps(_before_value, default=str)
asyncio.create_task(
create_audit_log_for_update(
request_data=LiteLLM_AuditLogs(
id=str(uuid.uuid4()),
updated_at=datetime.now(timezone.utc),
changed_by=litellm_changed_by
or user_api_key_dict.user_id
or litellm_proxy_admin_name,
changed_by_api_key=user_api_key_dict.api_key,
table_name=LitellmTableNames.KEY_TABLE_NAME,
object_id=data.key,
action="updated",
updated_values=_updated_values,
before_value=_before_value,
)
)
asyncio.create_task(
KeyManagementEventHooks.async_key_updated_hook(
data=data,
existing_key_row=existing_key_row,
response=response,
user_api_key_dict=user_api_key_dict,
litellm_changed_by=litellm_changed_by,
)
)
if response is None:
raise ValueError("Failed to update key got response = None")
@ -496,6 +446,9 @@ async def delete_key_fn(
user_custom_key_generate,
)
if prisma_client is None:
raise Exception("Not connected to DB!")
keys = data.keys
if len(keys) == 0:
raise ProxyException(
@ -516,45 +469,7 @@ async def delete_key_fn(
):
user_id = None # unless they're admin
# Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
# we do this after the first for loop, since first for loop is for validation. we only want this inserted after validation passes
if litellm.store_audit_logs is True:
# make an audit log for each team deleted
for key in data.keys:
key_row = await prisma_client.get_data( # type: ignore
token=key, table_name="key", query_type="find_unique"
)
if key_row is None:
raise ProxyException(
message=f"Key {key} not found",
type=ProxyErrorTypes.bad_request_error,
param="key",
code=status.HTTP_404_NOT_FOUND,
)
key_row = key_row.json(exclude_none=True)
_key_row = json.dumps(key_row, default=str)
asyncio.create_task(
create_audit_log_for_update(
request_data=LiteLLM_AuditLogs(
id=str(uuid.uuid4()),
updated_at=datetime.now(timezone.utc),
changed_by=litellm_changed_by
or user_api_key_dict.user_id
or litellm_proxy_admin_name,
changed_by_api_key=user_api_key_dict.api_key,
table_name=LitellmTableNames.KEY_TABLE_NAME,
object_id=key,
action="deleted",
updated_values="{}",
before_value=_key_row,
)
)
)
number_deleted_keys = await delete_verification_token(
number_deleted_keys, _keys_being_deleted = await delete_verification_token(
tokens=keys, user_id=user_id
)
if number_deleted_keys is None:
@ -588,6 +503,16 @@ async def delete_key_fn(
f"/keys/delete - cache after delete: {user_api_key_cache.in_memory_cache.cache_dict}"
)
asyncio.create_task(
KeyManagementEventHooks.async_key_deleted_hook(
data=data,
keys_being_deleted=_keys_being_deleted,
user_api_key_dict=user_api_key_dict,
litellm_changed_by=litellm_changed_by,
response=number_deleted_keys,
)
)
return {"deleted_keys": keys}
except Exception as e:
if isinstance(e, HTTPException):
@ -1026,11 +951,35 @@ async def generate_key_helper_fn( # noqa: PLR0915
return key_data
async def delete_verification_token(tokens: List, user_id: Optional[str] = None):
async def delete_verification_token(
tokens: List, user_id: Optional[str] = None
) -> Tuple[Optional[Dict], List[LiteLLM_VerificationToken]]:
"""
Helper that deletes the list of tokens from the database
Args:
tokens: List of tokens to delete
user_id: Optional user_id to filter by
Returns:
Tuple[Optional[Dict], List[LiteLLM_VerificationToken]]:
Optional[Dict]:
- Number of deleted tokens
List[LiteLLM_VerificationToken]:
- List of keys being deleted, this contains information about the key_alias, token, and user_id being deleted,
this is passed down to the KeyManagementEventHooks to delete the keys from the secret manager and handle audit logs
"""
from litellm.proxy.proxy_server import litellm_proxy_admin_name, prisma_client
try:
if prisma_client:
tokens = [_hash_token_if_needed(token=key) for key in tokens]
_keys_being_deleted = (
await prisma_client.db.litellm_verificationtoken.find_many(
where={"token": {"in": tokens}}
)
)
# Assuming 'db' is your Prisma Client instance
# check if admin making request - don't filter by user-id
if user_id == litellm_proxy_admin_name:
@ -1060,7 +1009,7 @@ async def delete_verification_token(tokens: List, user_id: Optional[str] = None)
)
verbose_proxy_logger.debug(traceback.format_exc())
raise e
return deleted_tokens
return deleted_tokens, _keys_being_deleted
@router.post(

View file

@ -265,7 +265,6 @@ def run_server( # noqa: PLR0915
ProxyConfig,
app,
load_aws_kms,
load_aws_secret_manager,
load_from_azure_key_vault,
load_google_kms,
save_worker_config,
@ -278,7 +277,6 @@ def run_server( # noqa: PLR0915
ProxyConfig,
app,
load_aws_kms,
load_aws_secret_manager,
load_from_azure_key_vault,
load_google_kms,
save_worker_config,
@ -295,7 +293,6 @@ def run_server( # noqa: PLR0915
ProxyConfig,
app,
load_aws_kms,
load_aws_secret_manager,
load_from_azure_key_vault,
load_google_kms,
save_worker_config,
@ -559,8 +556,14 @@ def run_server( # noqa: PLR0915
key_management_system
== KeyManagementSystem.AWS_SECRET_MANAGER.value # noqa: F405
):
from litellm.secret_managers.aws_secret_manager_v2 import (
AWSSecretsManagerV2,
)
### LOAD FROM AWS SECRET MANAGER ###
load_aws_secret_manager(use_aws_secret_manager=True)
AWSSecretsManagerV2.load_aws_secret_manager(
use_aws_secret_manager=True
)
elif key_management_system == KeyManagementSystem.AWS_KMS.value:
load_aws_kms(use_aws_kms=True)
elif (

View file

@ -7,6 +7,8 @@ model_list:
litellm_settings:
callbacks: ["gcs_bucket"]
general_settings:
key_management_system: "aws_secret_manager"
key_management_settings:
store_virtual_keys: true
access_mode: "write_only"

View file

@ -245,10 +245,7 @@ from litellm.router import (
from litellm.router import ModelInfo as RouterModelInfo
from litellm.router import updateDeployment
from litellm.scheduler import DefaultPriorities, FlowItem, Scheduler
from litellm.secret_managers.aws_secret_manager import (
load_aws_kms,
load_aws_secret_manager,
)
from litellm.secret_managers.aws_secret_manager import load_aws_kms
from litellm.secret_managers.google_kms import load_google_kms
from litellm.secret_managers.main import (
get_secret,
@ -1825,8 +1822,13 @@ class ProxyConfig:
key_management_system
== KeyManagementSystem.AWS_SECRET_MANAGER.value # noqa: F405
):
### LOAD FROM AWS SECRET MANAGER ###
load_aws_secret_manager(use_aws_secret_manager=True)
from litellm.secret_managers.aws_secret_manager_v2 import (
AWSSecretsManagerV2,
)
AWSSecretsManagerV2.load_aws_secret_manager(
use_aws_secret_manager=True
)
elif key_management_system == KeyManagementSystem.AWS_KMS.value:
load_aws_kms(use_aws_kms=True)
elif (

View file

@ -23,28 +23,6 @@ def validate_environment():
raise ValueError("Missing required environment variable - AWS_REGION_NAME")
def load_aws_secret_manager(use_aws_secret_manager: Optional[bool]):
if use_aws_secret_manager is None or use_aws_secret_manager is False:
return
try:
import boto3
from botocore.exceptions import ClientError
validate_environment()
# Create a Secrets Manager client
session = boto3.session.Session() # type: ignore
client = session.client(
service_name="secretsmanager", region_name=os.getenv("AWS_REGION_NAME")
)
litellm.secret_manager_client = client
litellm._key_management_system = KeyManagementSystem.AWS_SECRET_MANAGER
except Exception as e:
raise e
def load_aws_kms(use_aws_kms: Optional[bool]):
if use_aws_kms is None or use_aws_kms is False:
return

View file

@ -0,0 +1,310 @@
"""
This is a file for the AWS Secret Manager Integration
Handles Async Operations for:
- Read Secret
- Write Secret
- Delete Secret
Relevant issue: https://github.com/BerriAI/litellm/issues/1883
Requires:
* `os.environ["AWS_REGION_NAME"],
* `pip install boto3>=1.28.57`
"""
import ast
import asyncio
import base64
import json
import os
import re
import sys
from typing import Any, Dict, Optional, Union
import httpx
import litellm
from litellm._logging import verbose_logger
from litellm.llms.base_aws_llm import BaseAWSLLM
from litellm.llms.custom_httpx.http_handler import (
_get_httpx_client,
get_async_httpx_client,
)
from litellm.llms.custom_httpx.types import httpxSpecialProvider
from litellm.proxy._types import KeyManagementSystem
class AWSSecretsManagerV2(BaseAWSLLM):
@classmethod
def validate_environment(cls):
if "AWS_REGION_NAME" not in os.environ:
raise ValueError("Missing required environment variable - AWS_REGION_NAME")
@classmethod
def load_aws_secret_manager(cls, use_aws_secret_manager: Optional[bool]):
"""
Initialize AWSSecretsManagerV2 and sets litellm.secret_manager_client = AWSSecretsManagerV2() and litellm._key_management_system = KeyManagementSystem.AWS_SECRET_MANAGER
"""
if use_aws_secret_manager is None or use_aws_secret_manager is False:
return
try:
import boto3
cls.validate_environment()
litellm.secret_manager_client = cls()
litellm._key_management_system = KeyManagementSystem.AWS_SECRET_MANAGER
except Exception as e:
raise e
async def async_read_secret(
self,
secret_name: str,
optional_params: Optional[dict] = None,
timeout: Optional[Union[float, httpx.Timeout]] = None,
) -> Optional[str]:
"""
Async function to read a secret from AWS Secrets Manager
Returns:
str: Secret value
Raises:
ValueError: If the secret is not found or an HTTP error occurs
"""
endpoint_url, headers, body = self._prepare_request(
action="GetSecretValue",
secret_name=secret_name,
optional_params=optional_params,
)
async_client = get_async_httpx_client(
llm_provider=httpxSpecialProvider.SecretManager,
params={"timeout": timeout},
)
try:
response = await async_client.post(
url=endpoint_url, headers=headers, data=body.decode("utf-8")
)
response.raise_for_status()
return response.json()["SecretString"]
except httpx.TimeoutException:
raise ValueError("Timeout error occurred")
except Exception as e:
verbose_logger.exception(
"Error reading secret from AWS Secrets Manager: %s", str(e)
)
return None
def sync_read_secret(
self,
secret_name: str,
optional_params: Optional[dict] = None,
timeout: Optional[Union[float, httpx.Timeout]] = None,
) -> Optional[str]:
"""
Sync function to read a secret from AWS Secrets Manager
Done for backwards compatibility with existing codebase, since get_secret is a sync function
"""
# self._prepare_request uses these env vars, we cannot read them from AWS Secrets Manager. If we do we'd get stuck in an infinite loop
if secret_name in [
"AWS_ACCESS_KEY_ID",
"AWS_SECRET_ACCESS_KEY",
"AWS_REGION_NAME",
"AWS_REGION",
"AWS_BEDROCK_RUNTIME_ENDPOINT",
]:
return os.getenv(secret_name)
endpoint_url, headers, body = self._prepare_request(
action="GetSecretValue",
secret_name=secret_name,
optional_params=optional_params,
)
sync_client = _get_httpx_client(
params={"timeout": timeout},
)
try:
response = sync_client.post(
url=endpoint_url, headers=headers, data=body.decode("utf-8")
)
response.raise_for_status()
return response.json()["SecretString"]
except httpx.TimeoutException:
raise ValueError("Timeout error occurred")
except Exception as e:
verbose_logger.exception(
"Error reading secret from AWS Secrets Manager: %s", str(e)
)
return None
async def async_write_secret(
self,
secret_name: str,
secret_value: str,
description: Optional[str] = None,
client_request_token: Optional[str] = None,
optional_params: Optional[dict] = None,
timeout: Optional[Union[float, httpx.Timeout]] = None,
) -> dict:
"""
Async function to write a secret to AWS Secrets Manager
Args:
secret_name: Name of the secret
secret_value: Value to store (can be a JSON string)
description: Optional description for the secret
client_request_token: Optional unique identifier to ensure idempotency
optional_params: Additional AWS parameters
timeout: Request timeout
"""
import uuid
# Prepare the request data
data = {"Name": secret_name, "SecretString": secret_value}
if description:
data["Description"] = description
data["ClientRequestToken"] = str(uuid.uuid4())
endpoint_url, headers, body = self._prepare_request(
action="CreateSecret",
secret_name=secret_name,
secret_value=secret_value,
optional_params=optional_params,
request_data=data, # Pass the complete request data
)
async_client = get_async_httpx_client(
llm_provider=httpxSpecialProvider.SecretManager,
params={"timeout": timeout},
)
try:
response = await async_client.post(
url=endpoint_url, headers=headers, data=body.decode("utf-8")
)
response.raise_for_status()
return response.json()
except httpx.HTTPStatusError as err:
raise ValueError(f"HTTP error occurred: {err.response.text}")
except httpx.TimeoutException:
raise ValueError("Timeout error occurred")
async def async_delete_secret(
self,
secret_name: str,
recovery_window_in_days: Optional[int] = 7,
optional_params: Optional[dict] = None,
timeout: Optional[Union[float, httpx.Timeout]] = None,
) -> dict:
"""
Async function to delete a secret from AWS Secrets Manager
Args:
secret_name: Name of the secret to delete
recovery_window_in_days: Number of days before permanent deletion (default: 7)
optional_params: Additional AWS parameters
timeout: Request timeout
Returns:
dict: Response from AWS Secrets Manager containing deletion details
"""
# Prepare the request data
data = {
"SecretId": secret_name,
"RecoveryWindowInDays": recovery_window_in_days,
}
endpoint_url, headers, body = self._prepare_request(
action="DeleteSecret",
secret_name=secret_name,
optional_params=optional_params,
request_data=data,
)
async_client = get_async_httpx_client(
llm_provider=httpxSpecialProvider.SecretManager,
params={"timeout": timeout},
)
try:
response = await async_client.post(
url=endpoint_url, headers=headers, data=body.decode("utf-8")
)
response.raise_for_status()
return response.json()
except httpx.HTTPStatusError as err:
raise ValueError(f"HTTP error occurred: {err.response.text}")
except httpx.TimeoutException:
raise ValueError("Timeout error occurred")
def _prepare_request(
self,
action: str, # "GetSecretValue" or "PutSecretValue"
secret_name: str,
secret_value: Optional[str] = None,
optional_params: Optional[dict] = None,
request_data: Optional[dict] = None,
) -> tuple[str, Any, bytes]:
"""Prepare the AWS Secrets Manager request"""
try:
import boto3
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest
from botocore.credentials import Credentials
except ImportError:
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
optional_params = optional_params or {}
boto3_credentials_info = self._get_boto_credentials_from_optional_params(
optional_params
)
# Get endpoint
_, endpoint_url = self.get_runtime_endpoint(
api_base=None,
aws_bedrock_runtime_endpoint=boto3_credentials_info.aws_bedrock_runtime_endpoint,
aws_region_name=boto3_credentials_info.aws_region_name,
)
endpoint_url = endpoint_url.replace("bedrock-runtime", "secretsmanager")
# Use provided request_data if available, otherwise build default data
if request_data:
data = request_data
else:
data = {"SecretId": secret_name}
if secret_value and action == "PutSecretValue":
data["SecretString"] = secret_value
body = json.dumps(data).encode("utf-8")
headers = {
"Content-Type": "application/x-amz-json-1.1",
"X-Amz-Target": f"secretsmanager.{action}",
}
# Sign request
request = AWSRequest(
method="POST", url=endpoint_url, data=body, headers=headers
)
SigV4Auth(
boto3_credentials_info.credentials,
"secretsmanager",
boto3_credentials_info.aws_region_name,
).add_auth(request)
prepped = request.prepare()
return endpoint_url, prepped.headers, body
# if __name__ == "__main__":
# print("loading aws secret manager v2")
# aws_secret_manager_v2 = AWSSecretsManagerV2()
# print("writing secret to aws secret manager v2")
# asyncio.run(aws_secret_manager_v2.async_write_secret(secret_name="test_secret_3", secret_value="test_value_2"))
# print("reading secret from aws secret manager v2")

View file

@ -5,7 +5,7 @@ import json
import os
import sys
import traceback
from typing import Any, Optional, Union
from typing import TYPE_CHECKING, Any, Optional, Union
import httpx
from dotenv import load_dotenv
@ -198,7 +198,10 @@ def get_secret( # noqa: PLR0915
raise ValueError("Unsupported OIDC provider")
try:
if litellm.secret_manager_client is not None:
if (
_should_read_secret_from_secret_manager()
and litellm.secret_manager_client is not None
):
try:
client = litellm.secret_manager_client
key_manager = "local"
@ -207,7 +210,8 @@ def get_secret( # noqa: PLR0915
if key_management_settings is not None:
if (
secret_name not in key_management_settings.hosted_keys
key_management_settings.hosted_keys is not None
and secret_name not in key_management_settings.hosted_keys
): # allow user to specify which keys to check in hosted key manager
key_manager = "local"
@ -268,25 +272,13 @@ def get_secret( # noqa: PLR0915
if isinstance(secret, str):
secret = secret.strip()
elif key_manager == KeyManagementSystem.AWS_SECRET_MANAGER.value:
try:
get_secret_value_response = client.get_secret_value(
SecretId=secret_name
)
print_verbose(
f"get_secret_value_response: {get_secret_value_response}"
)
except Exception as e:
print_verbose(f"An error occurred - {str(e)}")
# For a list of exceptions thrown, see
# https://docs.aws.amazon.com/secretsmanager/latest/apireference/API_GetSecretValue.html
raise e
from litellm.secret_managers.aws_secret_manager_v2 import (
AWSSecretsManagerV2,
)
# assume there is 1 secret per secret_name
secret_dict = json.loads(get_secret_value_response["SecretString"])
print_verbose(f"secret_dict: {secret_dict}")
for k, v in secret_dict.items():
secret = v
print_verbose(f"secret: {secret}")
if isinstance(client, AWSSecretsManagerV2):
secret = client.sync_read_secret(secret_name=secret_name)
print_verbose(f"get_secret_value_response: {secret}")
elif key_manager == KeyManagementSystem.GOOGLE_SECRET_MANAGER.value:
try:
secret = client.get_secret_from_google_secret_manager(
@ -332,3 +324,21 @@ def get_secret( # noqa: PLR0915
return default_value
else:
raise e
def _should_read_secret_from_secret_manager() -> bool:
"""
Returns True if the secret manager should be used to read the secret, False otherwise
- If the secret manager client is not set, return False
- If the `_key_management_settings` access mode is "read_only" or "read_and_write", return True
- Otherwise, return False
"""
if litellm.secret_manager_client is not None:
if litellm._key_management_settings is not None:
if (
litellm._key_management_settings.access_mode == "read_only"
or litellm._key_management_settings.access_mode == "read_and_write"
):
return True
return False

View file

@ -0,0 +1,139 @@
# What is this?
import asyncio
import os
import sys
import traceback
from dotenv import load_dotenv
import litellm.types
import litellm.types.utils
load_dotenv()
import io
import sys
import os
# Ensure the project root is in the Python path
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..")))
print("Python Path:", sys.path)
print("Current Working Directory:", os.getcwd())
from typing import Optional
from unittest.mock import MagicMock, patch
import pytest
import uuid
import json
from litellm.secret_managers.aws_secret_manager_v2 import AWSSecretsManagerV2
def check_aws_credentials():
"""Helper function to check if AWS credentials are set"""
required_vars = ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_REGION_NAME"]
missing_vars = [var for var in required_vars if not os.getenv(var)]
if missing_vars:
pytest.skip(f"Missing required AWS credentials: {', '.join(missing_vars)}")
@pytest.mark.asyncio
async def test_write_and_read_simple_secret():
"""Test writing and reading a simple string secret"""
check_aws_credentials()
secret_manager = AWSSecretsManagerV2()
test_secret_name = f"litellm_test_{uuid.uuid4().hex[:8]}"
test_secret_value = "test_value_123"
try:
# Write secret
write_response = await secret_manager.async_write_secret(
secret_name=test_secret_name,
secret_value=test_secret_value,
description="LiteLLM Test Secret",
)
print("Write Response:", write_response)
assert write_response is not None
assert "ARN" in write_response
assert "Name" in write_response
assert write_response["Name"] == test_secret_name
# Read secret back
read_value = await secret_manager.async_read_secret(
secret_name=test_secret_name
)
print("Read Value:", read_value)
assert read_value == test_secret_value
finally:
# Cleanup: Delete the secret
delete_response = await secret_manager.async_delete_secret(
secret_name=test_secret_name
)
print("Delete Response:", delete_response)
assert delete_response is not None
@pytest.mark.asyncio
async def test_write_and_read_json_secret():
"""Test writing and reading a JSON structured secret"""
check_aws_credentials()
secret_manager = AWSSecretsManagerV2()
test_secret_name = f"litellm_test_{uuid.uuid4().hex[:8]}_json"
test_secret_value = {
"api_key": "test_key",
"model": "gpt-4",
"temperature": 0.7,
"metadata": {"team": "ml", "project": "litellm"},
}
try:
# Write JSON secret
write_response = await secret_manager.async_write_secret(
secret_name=test_secret_name,
secret_value=json.dumps(test_secret_value),
description="LiteLLM JSON Test Secret",
)
print("Write Response:", write_response)
# Read and parse JSON secret
read_value = await secret_manager.async_read_secret(
secret_name=test_secret_name
)
parsed_value = json.loads(read_value)
print("Read Value:", read_value)
assert parsed_value == test_secret_value
assert parsed_value["api_key"] == "test_key"
assert parsed_value["metadata"]["team"] == "ml"
finally:
# Cleanup: Delete the secret
delete_response = await secret_manager.async_delete_secret(
secret_name=test_secret_name
)
print("Delete Response:", delete_response)
assert delete_response is not None
@pytest.mark.asyncio
async def test_read_nonexistent_secret():
"""Test reading a secret that doesn't exist"""
check_aws_credentials()
secret_manager = AWSSecretsManagerV2()
nonexistent_secret = f"litellm_nonexistent_{uuid.uuid4().hex}"
response = await secret_manager.async_read_secret(secret_name=nonexistent_secret)
assert response is None

View file

@ -24,7 +24,7 @@ from litellm import RateLimitError, Timeout, completion, completion_cost, embedd
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.llms.prompt_templates.factory import anthropic_messages_pt
# litellm.num_retries = 3
# litellm.num_retries=3
litellm.cache = None
litellm.success_callback = []

View file

@ -15,22 +15,29 @@ sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import pytest
import litellm
from litellm.llms.AzureOpenAI.azure import get_azure_ad_token_from_oidc
from litellm.llms.bedrock.chat import BedrockConverseLLM, BedrockLLM
from litellm.secret_managers.aws_secret_manager import load_aws_secret_manager
from litellm.secret_managers.main import get_secret
from litellm.secret_managers.aws_secret_manager_v2 import AWSSecretsManagerV2
from litellm.secret_managers.main import (
get_secret,
_should_read_secret_from_secret_manager,
)
@pytest.mark.skip(reason="AWS Suspended Account")
def test_aws_secret_manager():
load_aws_secret_manager(use_aws_secret_manager=True)
import json
AWSSecretsManagerV2.load_aws_secret_manager(use_aws_secret_manager=True)
secret_val = get_secret("litellm_master_key")
print(f"secret_val: {secret_val}")
assert secret_val == "sk-1234"
# cast json to dict
secret_val = json.loads(secret_val)
assert secret_val["litellm_master_key"] == "sk-1234"
def redact_oidc_signature(secret_val):
@ -240,3 +247,71 @@ def test_google_secret_manager_read_in_memory():
)
print("secret_val: {}".format(secret_val))
assert secret_val == "lite-llm"
def test_should_read_secret_from_secret_manager():
"""
Test that _should_read_secret_from_secret_manager returns correct values based on access mode
"""
from litellm.proxy._types import KeyManagementSettings
# Test when secret manager client is None
litellm.secret_manager_client = None
litellm._key_management_settings = KeyManagementSettings()
assert _should_read_secret_from_secret_manager() is False
# Test with secret manager client and read_only access
litellm.secret_manager_client = "dummy_client"
litellm._key_management_settings = KeyManagementSettings(access_mode="read_only")
assert _should_read_secret_from_secret_manager() is True
# Test with secret manager client and read_and_write access
litellm._key_management_settings = KeyManagementSettings(
access_mode="read_and_write"
)
assert _should_read_secret_from_secret_manager() is True
# Test with secret manager client and write_only access
litellm._key_management_settings = KeyManagementSettings(access_mode="write_only")
assert _should_read_secret_from_secret_manager() is False
# Reset global variables
litellm.secret_manager_client = None
litellm._key_management_settings = KeyManagementSettings()
def test_get_secret_with_access_mode():
"""
Test that get_secret respects access mode settings
"""
from litellm.proxy._types import KeyManagementSettings
# Set up test environment
test_secret_name = "TEST_SECRET_KEY"
test_secret_value = "test_secret_value"
os.environ[test_secret_name] = test_secret_value
# Test with write_only access (should read from os.environ)
litellm.secret_manager_client = "dummy_client"
litellm._key_management_settings = KeyManagementSettings(access_mode="write_only")
assert get_secret(test_secret_name) == test_secret_value
# Test with no KeyManagementSettings but secret_manager_client set
litellm.secret_manager_client = "dummy_client"
litellm._key_management_settings = KeyManagementSettings()
assert _should_read_secret_from_secret_manager() is True
# Test with read_only access
litellm._key_management_settings = KeyManagementSettings(access_mode="read_only")
assert _should_read_secret_from_secret_manager() is True
# Test with read_and_write access
litellm._key_management_settings = KeyManagementSettings(
access_mode="read_and_write"
)
assert _should_read_secret_from_secret_manager() is True
# Reset global variables
litellm.secret_manager_client = None
litellm._key_management_settings = KeyManagementSettings()
del os.environ[test_secret_name]

View file

@ -3451,3 +3451,90 @@ async def test_user_api_key_auth_db_unavailable_not_allowed():
request=request,
api_key="Bearer sk-123456789",
)
## E2E Virtual Key + Secret Manager Tests #########################################
@pytest.mark.asyncio
async def test_key_generate_with_secret_manager_call(prisma_client):
"""
Generate a key
assert it exists in the secret manager
delete the key
assert it is deleted from the secret manager
"""
from litellm.secret_managers.aws_secret_manager_v2 import AWSSecretsManagerV2
from litellm.proxy._types import KeyManagementSystem, KeyManagementSettings
litellm.set_verbose = True
#### Test Setup ############################################################
aws_secret_manager_client = AWSSecretsManagerV2()
litellm.secret_manager_client = aws_secret_manager_client
litellm._key_management_system = KeyManagementSystem.AWS_SECRET_MANAGER
litellm._key_management_settings = KeyManagementSettings(
store_virtual_keys=True,
)
general_settings = {
"key_management_system": "aws_secret_manager",
"key_management_settings": {
"store_virtual_keys": True,
},
}
setattr(litellm.proxy.proxy_server, "general_settings", general_settings)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
await litellm.proxy.proxy_server.prisma_client.connect()
############################################################################
# generate new key
key_alias = f"test_alias_secret_manager_key-{uuid.uuid4()}"
spend = 100
max_budget = 400
models = ["fake-openai-endpoint"]
new_key = await generate_key_fn(
data=GenerateKeyRequest(
key_alias=key_alias, spend=spend, max_budget=max_budget, models=models
),
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="1234",
),
)
generated_key = new_key.key
print(generated_key)
await asyncio.sleep(2)
# read from the secret manager
result = await aws_secret_manager_client.async_read_secret(secret_name=key_alias)
# Assert the correct key is stored in the secret manager
print("response from AWS Secret Manager")
print(result)
assert result == generated_key
# delete the key
await delete_key_fn(
data=KeyRequest(keys=[generated_key]),
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN, api_key="sk-1234", user_id="1234"
),
)
await asyncio.sleep(2)
# Assert the key is deleted from the secret manager
result = await aws_secret_manager_client.async_read_secret(secret_name=key_alias)
assert result is None
# cleanup
setattr(litellm.proxy.proxy_server, "general_settings", {})
################################################################################