forked from phoenix/litellm-mirror
(Feat) Add support for storing virtual keys in AWS SecretManager (#6728)
* add SecretManager to httpxSpecialProvider * fix importing AWSSecretsManagerV2 * add unit testing for writing keys to AWS secret manager * use KeyManagementEventHooks for key/generated events * use event hooks for key management endpoints * working AWSSecretsManagerV2 * fix write secret to AWS secret manager on /key/generate * fix KeyManagementSettings * use tasks for key management hooks * add async_delete_secret * add test for async_delete_secret * use _delete_virtual_keys_from_secret_manager * fix test secret manager * test_key_generate_with_secret_manager_call * fix check for key_management_settings * sync_read_secret * test_aws_secret_manager * fix sync_read_secret * use helper to check when _should_read_secret_from_secret_manager * test_get_secret_with_access_mode * test - handle eol model claude-2, use claude-2.1 instead * docs AWS secret manager * fix test_read_nonexistent_secret * fix test_supports_response_schema * ci/cd run again
This commit is contained in:
parent
da84056e59
commit
f8e700064e
16 changed files with 1046 additions and 178 deletions
|
@ -1,3 +1,6 @@
|
|||
import Tabs from '@theme/Tabs';
|
||||
import TabItem from '@theme/TabItem';
|
||||
|
||||
# Secret Manager
|
||||
LiteLLM supports reading secrets from AWS Secret Manager, Azure Key Vault, and Google Secret Manager
|
||||
|
||||
|
@ -59,14 +62,35 @@ os.environ["AWS_REGION_NAME"] = "" # us-east-1, us-east-2, us-west-1, us-west-2
|
|||
```
|
||||
|
||||
2. Enable AWS Secret Manager in config.
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="read_only" label="Read Keys from AWS Secret Manager">
|
||||
|
||||
```yaml
|
||||
general_settings:
|
||||
master_key: os.environ/litellm_master_key
|
||||
key_management_system: "aws_secret_manager" # 👈 KEY CHANGE
|
||||
key_management_settings:
|
||||
hosted_keys: ["litellm_master_key"] # 👈 Specify which env keys you stored on AWS
|
||||
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
|
||||
<TabItem value="write_only" label="Write Virtual Keys to AWS Secret Manager">
|
||||
|
||||
This will only store virtual keys in AWS Secret Manager. No keys will be read from AWS Secret Manager.
|
||||
|
||||
```yaml
|
||||
general_settings:
|
||||
key_management_system: "aws_secret_manager" # 👈 KEY CHANGE
|
||||
key_management_settings:
|
||||
store_virtual_keys: true
|
||||
access_mode: "write_only" # Literal["read_only", "write_only", "read_and_write"]
|
||||
```
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
3. Run proxy
|
||||
|
||||
```bash
|
||||
|
@ -181,16 +205,14 @@ litellm --config /path/to/config.yaml
|
|||
|
||||
Use encrypted keys from Google KMS on the proxy
|
||||
|
||||
### Usage with LiteLLM Proxy Server
|
||||
|
||||
## Step 1. Add keys to env
|
||||
Step 1. Add keys to env
|
||||
```
|
||||
export GOOGLE_APPLICATION_CREDENTIALS="/path/to/credentials.json"
|
||||
export GOOGLE_KMS_RESOURCE_NAME="projects/*/locations/*/keyRings/*/cryptoKeys/*"
|
||||
export PROXY_DATABASE_URL_ENCRYPTED=b'\n$\x00D\xac\xb4/\x8e\xc...'
|
||||
```
|
||||
|
||||
## Step 2: Update Config
|
||||
Step 2: Update Config
|
||||
|
||||
```yaml
|
||||
general_settings:
|
||||
|
@ -199,7 +221,7 @@ general_settings:
|
|||
master_key: sk-1234
|
||||
```
|
||||
|
||||
## Step 3: Start + test proxy
|
||||
Step 3: Start + test proxy
|
||||
|
||||
```
|
||||
$ litellm --config /path/to/config.yaml
|
||||
|
@ -215,3 +237,17 @@ $ litellm --test
|
|||
<!--
|
||||
## .env Files
|
||||
If no secret manager client is specified, Litellm automatically uses the `.env` file to manage sensitive data. -->
|
||||
|
||||
|
||||
## All Secret Manager Settings
|
||||
|
||||
All settings related to secret management
|
||||
|
||||
```yaml
|
||||
general_settings:
|
||||
key_management_system: "aws_secret_manager" # REQUIRED
|
||||
key_management_settings:
|
||||
store_virtual_keys: true # OPTIONAL. Defaults to False. When True, virtual keys are stored in the secret manager
|
||||
access_mode: "write_only" # OPTIONAL. Literal["read_only", "write_only", "read_and_write"]. Defaults to "read_only"
|
||||
hosted_keys: ["litellm_master_key"] # OPTIONAL. Specify which env keys you stored on AWS
|
||||
```
|
|
@ -304,7 +304,7 @@ secret_manager_client: Optional[Any] = (
|
|||
)
|
||||
_google_kms_resource_name: Optional[str] = None
|
||||
_key_management_system: Optional[KeyManagementSystem] = None
|
||||
_key_management_settings: Optional[KeyManagementSettings] = None
|
||||
_key_management_settings: KeyManagementSettings = KeyManagementSettings()
|
||||
#### PII MASKING ####
|
||||
output_parse_pii: bool = False
|
||||
#############################################
|
||||
|
|
|
@ -8,3 +8,4 @@ class httpxSpecialProvider(str, Enum):
|
|||
GuardrailCallback = "guardrail_callback"
|
||||
Caching = "caching"
|
||||
Oauth2Check = "oauth2_check"
|
||||
SecretManager = "secret_manager"
|
||||
|
|
|
@ -1128,7 +1128,16 @@ class KeyManagementSystem(enum.Enum):
|
|||
|
||||
|
||||
class KeyManagementSettings(LiteLLMBase):
|
||||
hosted_keys: List
|
||||
hosted_keys: Optional[List] = None
|
||||
store_virtual_keys: Optional[bool] = False
|
||||
"""
|
||||
If True, virtual keys created by litellm will be stored in the secret manager
|
||||
"""
|
||||
|
||||
access_mode: Literal["read_only", "write_only", "read_and_write"] = "read_only"
|
||||
"""
|
||||
Access mode for the secret manager, when write_only will only use for writing secrets
|
||||
"""
|
||||
|
||||
|
||||
class TeamDefaultSettings(LiteLLMBase):
|
||||
|
|
267
litellm/proxy/hooks/key_management_event_hooks.py
Normal file
267
litellm/proxy/hooks/key_management_event_hooks.py
Normal file
|
@ -0,0 +1,267 @@
|
|||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from re import A
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from fastapi import status
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import (
|
||||
GenerateKeyRequest,
|
||||
KeyManagementSystem,
|
||||
KeyRequest,
|
||||
LiteLLM_AuditLogs,
|
||||
LiteLLM_VerificationToken,
|
||||
LitellmTableNames,
|
||||
ProxyErrorTypes,
|
||||
ProxyException,
|
||||
UpdateKeyRequest,
|
||||
UserAPIKeyAuth,
|
||||
WebhookEvent,
|
||||
)
|
||||
|
||||
|
||||
class KeyManagementEventHooks:
    """
    Follow-up hooks for the key management endpoints.

    Each public hook is written to run after the corresponding /key/* operation
    and takes care of the secondary work: invite emails, audit logs, and
    mirroring virtual keys into the configured secret manager.
    """

    @staticmethod
    async def async_key_generated_hook(
        data: GenerateKeyRequest,
        response: dict,
        user_api_key_dict: UserAPIKeyAuth,
        litellm_changed_by: Optional[str] = None,
    ):
        """
        Hook that runs after a successful /key/generate request

        Handles the following:
        - Sending Email with Key Details
        - Storing Audit Logs for key generation
        - Storing Generated Key in the secret manager

        Args:
            data: The original /key/generate request
            response: The generated key details (read via `.get`, e.g. "token", "token_id")
            user_api_key_dict: Auth info of the caller, used to attribute the audit log
            litellm_changed_by: Optional override for the audit log `changed_by` field

        Raises:
            ValueError: If `data.send_invite_email` is True but email alerting is not configured
        """
        from litellm.proxy.management_helpers.audit_logs import (
            create_audit_log_for_update,
        )
        from litellm.proxy.proxy_server import litellm_proxy_admin_name

        if data.send_invite_email is True:
            await KeyManagementEventHooks._send_key_created_email(response)

        # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
        if litellm.store_audit_logs is True:
            _updated_values = json.dumps(response, default=str)
            asyncio.create_task(
                create_audit_log_for_update(
                    request_data=LiteLLM_AuditLogs(
                        id=str(uuid.uuid4()),
                        updated_at=datetime.now(timezone.utc),
                        changed_by=litellm_changed_by
                        or user_api_key_dict.user_id
                        or litellm_proxy_admin_name,
                        changed_by_api_key=user_api_key_dict.api_key,
                        table_name=LitellmTableNames.KEY_TABLE_NAME,
                        object_id=response.get("token_id", ""),
                        action="created",
                        updated_values=_updated_values,
                        before_value=None,
                    )
                )
            )

        # store the generated key in the secret manager
        # NOTE(review): keys without a key_alias are stored under a random
        # "virtual-key-<uuid>" name, while the delete path below looks secrets up
        # by key_alias only - such secrets are never cleaned up. Confirm intended.
        await KeyManagementEventHooks._store_virtual_key_in_secret_manager(
            secret_name=data.key_alias or f"virtual-key-{uuid.uuid4()}",
            secret_token=response.get("token", ""),
        )

    @staticmethod
    async def async_key_updated_hook(
        data: UpdateKeyRequest,
        existing_key_row: Any,
        response: Any,
        user_api_key_dict: UserAPIKeyAuth,
        litellm_changed_by: Optional[str] = None,
    ):
        """
        Post /key/update processing hook

        Handles the following:
        - Storing Audit Logs for key update

        Args:
            data: The /key/update request (its non-None fields become `updated_values`)
            existing_key_row: The key row as it was before the update (becomes `before_value`)
            response: Result of the update operation (not used by this hook)
            user_api_key_dict: Auth info of the caller, used to attribute the audit log
            litellm_changed_by: Optional override for the audit log `changed_by` field
        """
        from litellm.proxy.management_helpers.audit_logs import (
            create_audit_log_for_update,
        )
        from litellm.proxy.proxy_server import litellm_proxy_admin_name

        # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
        if litellm.store_audit_logs is True:
            _updated_values = json.dumps(data.json(exclude_none=True), default=str)

            _before_value = existing_key_row.json(exclude_none=True)
            _before_value = json.dumps(_before_value, default=str)

            asyncio.create_task(
                create_audit_log_for_update(
                    request_data=LiteLLM_AuditLogs(
                        id=str(uuid.uuid4()),
                        updated_at=datetime.now(timezone.utc),
                        changed_by=litellm_changed_by
                        or user_api_key_dict.user_id
                        or litellm_proxy_admin_name,
                        changed_by_api_key=user_api_key_dict.api_key,
                        table_name=LitellmTableNames.KEY_TABLE_NAME,
                        object_id=data.key,
                        action="updated",
                        updated_values=_updated_values,
                        before_value=_before_value,
                    )
                )
            )

    @staticmethod
    async def async_key_deleted_hook(
        data: KeyRequest,
        keys_being_deleted: List[LiteLLM_VerificationToken],
        response: dict,
        user_api_key_dict: UserAPIKeyAuth,
        litellm_changed_by: Optional[str] = None,
    ):
        """
        Post /key/delete processing hook

        Handles the following:
        - Storing Audit Logs for key deletion
        - Deleting the keys from the secret manager

        Args:
            data: The /key/delete request (kept for interface compatibility)
            keys_being_deleted: Key rows captured before deletion; the source of
                truth for both the audit log and the secret manager cleanup
            response: Result of the delete operation (not used by this hook)
            user_api_key_dict: Auth info of the caller, used to attribute the audit log
            litellm_changed_by: Optional override for the audit log `changed_by` field
        """
        from litellm.proxy.management_helpers.audit_logs import (
            create_audit_log_for_update,
        )
        from litellm.proxy.proxy_server import litellm_proxy_admin_name

        # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
        if litellm.store_audit_logs is True:
            # make an audit log for each key deleted.
            # The rows passed in are used instead of re-querying the DB: the
            # endpoint schedules this hook after the delete has already run, so a
            # lookup by token can no longer find the row, and failing here would
            # also skip the secret manager cleanup below.
            for key_row in keys_being_deleted:
                _before_value = key_row.json(exclude_none=True)
                # NOTE(review): `.json()` may already return a JSON string
                # depending on the row's model class - only encode when it does not.
                if not isinstance(_before_value, str):
                    _before_value = json.dumps(_before_value, default=str)

                asyncio.create_task(
                    create_audit_log_for_update(
                        request_data=LiteLLM_AuditLogs(
                            id=str(uuid.uuid4()),
                            updated_at=datetime.now(timezone.utc),
                            changed_by=litellm_changed_by
                            or user_api_key_dict.user_id
                            or litellm_proxy_admin_name,
                            changed_by_api_key=user_api_key_dict.api_key,
                            table_name=LitellmTableNames.KEY_TABLE_NAME,
                            object_id=key_row.token or "",
                            action="deleted",
                            updated_values="{}",
                            before_value=_before_value,
                        )
                    )
                )

        # delete the keys from the secret manager
        await KeyManagementEventHooks._delete_virtual_keys_from_secret_manager(
            keys_being_deleted=keys_being_deleted
        )

    @staticmethod
    async def _store_virtual_key_in_secret_manager(secret_name: str, secret_token: str):
        """
        Store a virtual key in the secret manager

        Does nothing unless `store_virtual_keys` is enabled in the key management
        settings and the active secret manager is AWS Secrets Manager.

        Args:
            secret_name: Name of the virtual key
            secret_token: Value of the virtual key (example: sk-1234)
        """
        settings = litellm._key_management_settings
        if settings is None or settings.store_virtual_keys is not True:
            return

        from litellm.secret_managers.aws_secret_manager_v2 import (
            AWSSecretsManagerV2,
        )

        # store the key in the secret manager
        if (
            litellm._key_management_system == KeyManagementSystem.AWS_SECRET_MANAGER
            and isinstance(litellm.secret_manager_client, AWSSecretsManagerV2)
        ):
            await litellm.secret_manager_client.async_write_secret(
                secret_name=secret_name,
                secret_value=secret_token,
            )

    @staticmethod
    async def _delete_virtual_keys_from_secret_manager(
        keys_being_deleted: List[LiteLLM_VerificationToken],
    ):
        """
        Deletes virtual keys from the secret manager

        Secrets are looked up by `key_alias`; keys without an alias are skipped
        with a warning. Does nothing unless `store_virtual_keys` is enabled and
        the secret manager client is an AWSSecretsManagerV2 instance.

        Args:
            keys_being_deleted: List of keys being deleted, this is passed down from the /key/delete operation
        """
        settings = litellm._key_management_settings
        if settings is None or settings.store_virtual_keys is not True:
            return

        from litellm.secret_managers.aws_secret_manager_v2 import (
            AWSSecretsManagerV2,
        )

        if not isinstance(litellm.secret_manager_client, AWSSecretsManagerV2):
            return

        for key in keys_being_deleted:
            if key.key_alias is not None:
                await litellm.secret_manager_client.async_delete_secret(
                    secret_name=key.key_alias
                )
            else:
                verbose_proxy_logger.warning(
                    f"KeyManagementEventHooks._delete_virtual_key_from_secret_manager: Key alias not found for key {key.token}. Skipping deletion from secret manager."
                )

    @staticmethod
    async def _send_key_created_email(response: dict):
        """
        Schedule the "API Key Created" email for a newly generated key

        Args:
            response: The generated key details used to populate the webhook event

        Raises:
            ValueError: If "email" is not listed under `alerting` in general_settings
        """
        from litellm.proxy.proxy_server import general_settings, proxy_logging_obj

        if "email" not in general_settings.get("alerting", []):
            raise ValueError(
                "Email alerting not setup on config.yaml. Please set `alerting=['email']. \nDocs: https://docs.litellm.ai/docs/proxy/email`"
            )
        event = WebhookEvent(
            event="key_created",
            event_group="key",
            event_message="API Key Created",
            token=response.get("token", ""),
            spend=response.get("spend", 0.0),
            max_budget=response.get("max_budget", 0.0),
            user_id=response.get("user_id", None),
            team_id=response.get("team_id", "Default Team"),
            key_alias=response.get("key_alias", None),
        )

        # If user configured email alerting - send an Email letting their end-user know the key was created
        asyncio.create_task(
            proxy_logging_obj.slack_alerting_instance.send_key_created_or_user_invited_email(
                webhook_event=event,
            )
        )
|
|
@ -17,7 +17,7 @@ import secrets
|
|||
import traceback
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import fastapi
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Request, status
|
||||
|
@ -31,6 +31,7 @@ from litellm.proxy.auth.auth_checks import (
|
|||
get_key_object,
|
||||
)
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.hooks.key_management_event_hooks import KeyManagementEventHooks
|
||||
from litellm.proxy.management_helpers.utils import management_endpoint_wrapper
|
||||
from litellm.proxy.utils import _duration_in_seconds, _hash_token_if_needed
|
||||
from litellm.secret_managers.main import get_secret
|
||||
|
@ -234,50 +235,14 @@ async def generate_key_fn( # noqa: PLR0915
|
|||
data.soft_budget
|
||||
) # include the user-input soft budget in the response
|
||||
|
||||
if data.send_invite_email is True:
|
||||
if "email" not in general_settings.get("alerting", []):
|
||||
raise ValueError(
|
||||
"Email alerting not setup on config.yaml. Please set `alerting=['email']. \nDocs: https://docs.litellm.ai/docs/proxy/email`"
|
||||
)
|
||||
event = WebhookEvent(
|
||||
event="key_created",
|
||||
event_group="key",
|
||||
event_message="API Key Created",
|
||||
token=response.get("token", ""),
|
||||
spend=response.get("spend", 0.0),
|
||||
max_budget=response.get("max_budget", 0.0),
|
||||
user_id=response.get("user_id", None),
|
||||
team_id=response.get("team_id", "Default Team"),
|
||||
key_alias=response.get("key_alias", None),
|
||||
)
|
||||
|
||||
# If user configured email alerting - send an Email letting their end-user know the key was created
|
||||
asyncio.create_task(
|
||||
proxy_logging_obj.slack_alerting_instance.send_key_created_or_user_invited_email(
|
||||
webhook_event=event,
|
||||
)
|
||||
)
|
||||
|
||||
# Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
|
||||
if litellm.store_audit_logs is True:
|
||||
_updated_values = json.dumps(response, default=str)
|
||||
asyncio.create_task(
|
||||
create_audit_log_for_update(
|
||||
request_data=LiteLLM_AuditLogs(
|
||||
id=str(uuid.uuid4()),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
changed_by=litellm_changed_by
|
||||
or user_api_key_dict.user_id
|
||||
or litellm_proxy_admin_name,
|
||||
changed_by_api_key=user_api_key_dict.api_key,
|
||||
table_name=LitellmTableNames.KEY_TABLE_NAME,
|
||||
object_id=response.get("token_id", ""),
|
||||
action="created",
|
||||
updated_values=_updated_values,
|
||||
before_value=None,
|
||||
)
|
||||
)
|
||||
asyncio.create_task(
|
||||
KeyManagementEventHooks.async_key_generated_hook(
|
||||
data=data,
|
||||
response=response,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
litellm_changed_by=litellm_changed_by,
|
||||
)
|
||||
)
|
||||
|
||||
return GenerateKeyResponse(**response)
|
||||
except Exception as e:
|
||||
|
@ -407,30 +372,15 @@ async def update_key_fn(
|
|||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
|
||||
# Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
|
||||
if litellm.store_audit_logs is True:
|
||||
_updated_values = json.dumps(data_json, default=str)
|
||||
|
||||
_before_value = existing_key_row.json(exclude_none=True)
|
||||
_before_value = json.dumps(_before_value, default=str)
|
||||
|
||||
asyncio.create_task(
|
||||
create_audit_log_for_update(
|
||||
request_data=LiteLLM_AuditLogs(
|
||||
id=str(uuid.uuid4()),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
changed_by=litellm_changed_by
|
||||
or user_api_key_dict.user_id
|
||||
or litellm_proxy_admin_name,
|
||||
changed_by_api_key=user_api_key_dict.api_key,
|
||||
table_name=LitellmTableNames.KEY_TABLE_NAME,
|
||||
object_id=data.key,
|
||||
action="updated",
|
||||
updated_values=_updated_values,
|
||||
before_value=_before_value,
|
||||
)
|
||||
)
|
||||
asyncio.create_task(
|
||||
KeyManagementEventHooks.async_key_updated_hook(
|
||||
data=data,
|
||||
existing_key_row=existing_key_row,
|
||||
response=response,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
litellm_changed_by=litellm_changed_by,
|
||||
)
|
||||
)
|
||||
|
||||
if response is None:
|
||||
raise ValueError("Failed to update key got response = None")
|
||||
|
@ -496,6 +446,9 @@ async def delete_key_fn(
|
|||
user_custom_key_generate,
|
||||
)
|
||||
|
||||
if prisma_client is None:
|
||||
raise Exception("Not connected to DB!")
|
||||
|
||||
keys = data.keys
|
||||
if len(keys) == 0:
|
||||
raise ProxyException(
|
||||
|
@ -516,45 +469,7 @@ async def delete_key_fn(
|
|||
):
|
||||
user_id = None # unless they're admin
|
||||
|
||||
# Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
|
||||
# we do this after the first for loop, since first for loop is for validation. we only want this inserted after validation passes
|
||||
if litellm.store_audit_logs is True:
|
||||
# make an audit log for each team deleted
|
||||
for key in data.keys:
|
||||
key_row = await prisma_client.get_data( # type: ignore
|
||||
token=key, table_name="key", query_type="find_unique"
|
||||
)
|
||||
|
||||
if key_row is None:
|
||||
raise ProxyException(
|
||||
message=f"Key {key} not found",
|
||||
type=ProxyErrorTypes.bad_request_error,
|
||||
param="key",
|
||||
code=status.HTTP_404_NOT_FOUND,
|
||||
)
|
||||
|
||||
key_row = key_row.json(exclude_none=True)
|
||||
_key_row = json.dumps(key_row, default=str)
|
||||
|
||||
asyncio.create_task(
|
||||
create_audit_log_for_update(
|
||||
request_data=LiteLLM_AuditLogs(
|
||||
id=str(uuid.uuid4()),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
changed_by=litellm_changed_by
|
||||
or user_api_key_dict.user_id
|
||||
or litellm_proxy_admin_name,
|
||||
changed_by_api_key=user_api_key_dict.api_key,
|
||||
table_name=LitellmTableNames.KEY_TABLE_NAME,
|
||||
object_id=key,
|
||||
action="deleted",
|
||||
updated_values="{}",
|
||||
before_value=_key_row,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
number_deleted_keys = await delete_verification_token(
|
||||
number_deleted_keys, _keys_being_deleted = await delete_verification_token(
|
||||
tokens=keys, user_id=user_id
|
||||
)
|
||||
if number_deleted_keys is None:
|
||||
|
@ -588,6 +503,16 @@ async def delete_key_fn(
|
|||
f"/keys/delete - cache after delete: {user_api_key_cache.in_memory_cache.cache_dict}"
|
||||
)
|
||||
|
||||
asyncio.create_task(
|
||||
KeyManagementEventHooks.async_key_deleted_hook(
|
||||
data=data,
|
||||
keys_being_deleted=_keys_being_deleted,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
litellm_changed_by=litellm_changed_by,
|
||||
response=number_deleted_keys,
|
||||
)
|
||||
)
|
||||
|
||||
return {"deleted_keys": keys}
|
||||
except Exception as e:
|
||||
if isinstance(e, HTTPException):
|
||||
|
@ -1026,11 +951,35 @@ async def generate_key_helper_fn( # noqa: PLR0915
|
|||
return key_data
|
||||
|
||||
|
||||
async def delete_verification_token(tokens: List, user_id: Optional[str] = None):
|
||||
async def delete_verification_token(
|
||||
tokens: List, user_id: Optional[str] = None
|
||||
) -> Tuple[Optional[Dict], List[LiteLLM_VerificationToken]]:
|
||||
"""
|
||||
Helper that deletes the list of tokens from the database
|
||||
|
||||
Args:
|
||||
tokens: List of tokens to delete
|
||||
user_id: Optional user_id to filter by
|
||||
|
||||
Returns:
|
||||
Tuple[Optional[Dict], List[LiteLLM_VerificationToken]]:
|
||||
Optional[Dict]:
|
||||
- Number of deleted tokens
|
||||
List[LiteLLM_VerificationToken]:
|
||||
- List of keys being deleted, this contains information about the key_alias, token, and user_id being deleted,
|
||||
this is passed down to the KeyManagementEventHooks to delete the keys from the secret manager and handle audit logs
|
||||
"""
|
||||
from litellm.proxy.proxy_server import litellm_proxy_admin_name, prisma_client
|
||||
|
||||
try:
|
||||
if prisma_client:
|
||||
tokens = [_hash_token_if_needed(token=key) for key in tokens]
|
||||
_keys_being_deleted = (
|
||||
await prisma_client.db.litellm_verificationtoken.find_many(
|
||||
where={"token": {"in": tokens}}
|
||||
)
|
||||
)
|
||||
|
||||
# Assuming 'db' is your Prisma Client instance
|
||||
# check if admin making request - don't filter by user-id
|
||||
if user_id == litellm_proxy_admin_name:
|
||||
|
@ -1060,7 +1009,7 @@ async def delete_verification_token(tokens: List, user_id: Optional[str] = None)
|
|||
)
|
||||
verbose_proxy_logger.debug(traceback.format_exc())
|
||||
raise e
|
||||
return deleted_tokens
|
||||
return deleted_tokens, _keys_being_deleted
|
||||
|
||||
|
||||
@router.post(
|
||||
|
|
|
@ -265,7 +265,6 @@ def run_server( # noqa: PLR0915
|
|||
ProxyConfig,
|
||||
app,
|
||||
load_aws_kms,
|
||||
load_aws_secret_manager,
|
||||
load_from_azure_key_vault,
|
||||
load_google_kms,
|
||||
save_worker_config,
|
||||
|
@ -278,7 +277,6 @@ def run_server( # noqa: PLR0915
|
|||
ProxyConfig,
|
||||
app,
|
||||
load_aws_kms,
|
||||
load_aws_secret_manager,
|
||||
load_from_azure_key_vault,
|
||||
load_google_kms,
|
||||
save_worker_config,
|
||||
|
@ -295,7 +293,6 @@ def run_server( # noqa: PLR0915
|
|||
ProxyConfig,
|
||||
app,
|
||||
load_aws_kms,
|
||||
load_aws_secret_manager,
|
||||
load_from_azure_key_vault,
|
||||
load_google_kms,
|
||||
save_worker_config,
|
||||
|
@ -559,8 +556,14 @@ def run_server( # noqa: PLR0915
|
|||
key_management_system
|
||||
== KeyManagementSystem.AWS_SECRET_MANAGER.value # noqa: F405
|
||||
):
|
||||
from litellm.secret_managers.aws_secret_manager_v2 import (
|
||||
AWSSecretsManagerV2,
|
||||
)
|
||||
|
||||
### LOAD FROM AWS SECRET MANAGER ###
|
||||
load_aws_secret_manager(use_aws_secret_manager=True)
|
||||
AWSSecretsManagerV2.load_aws_secret_manager(
|
||||
use_aws_secret_manager=True
|
||||
)
|
||||
elif key_management_system == KeyManagementSystem.AWS_KMS.value:
|
||||
load_aws_kms(use_aws_kms=True)
|
||||
elif (
|
||||
|
|
|
@ -7,6 +7,8 @@ model_list:
|
|||
|
||||
|
||||
|
||||
litellm_settings:
|
||||
callbacks: ["gcs_bucket"]
|
||||
|
||||
general_settings:
|
||||
key_management_system: "aws_secret_manager"
|
||||
key_management_settings:
|
||||
store_virtual_keys: true
|
||||
access_mode: "write_only"
|
||||
|
|
|
@ -245,10 +245,7 @@ from litellm.router import (
|
|||
from litellm.router import ModelInfo as RouterModelInfo
|
||||
from litellm.router import updateDeployment
|
||||
from litellm.scheduler import DefaultPriorities, FlowItem, Scheduler
|
||||
from litellm.secret_managers.aws_secret_manager import (
|
||||
load_aws_kms,
|
||||
load_aws_secret_manager,
|
||||
)
|
||||
from litellm.secret_managers.aws_secret_manager import load_aws_kms
|
||||
from litellm.secret_managers.google_kms import load_google_kms
|
||||
from litellm.secret_managers.main import (
|
||||
get_secret,
|
||||
|
@ -1825,8 +1822,13 @@ class ProxyConfig:
|
|||
key_management_system
|
||||
== KeyManagementSystem.AWS_SECRET_MANAGER.value # noqa: F405
|
||||
):
|
||||
### LOAD FROM AWS SECRET MANAGER ###
|
||||
load_aws_secret_manager(use_aws_secret_manager=True)
|
||||
from litellm.secret_managers.aws_secret_manager_v2 import (
|
||||
AWSSecretsManagerV2,
|
||||
)
|
||||
|
||||
AWSSecretsManagerV2.load_aws_secret_manager(
|
||||
use_aws_secret_manager=True
|
||||
)
|
||||
elif key_management_system == KeyManagementSystem.AWS_KMS.value:
|
||||
load_aws_kms(use_aws_kms=True)
|
||||
elif (
|
||||
|
|
|
@ -23,28 +23,6 @@ def validate_environment():
|
|||
raise ValueError("Missing required environment variable - AWS_REGION_NAME")
|
||||
|
||||
|
||||
def load_aws_secret_manager(use_aws_secret_manager: Optional[bool]):
|
||||
if use_aws_secret_manager is None or use_aws_secret_manager is False:
|
||||
return
|
||||
try:
|
||||
import boto3
|
||||
from botocore.exceptions import ClientError
|
||||
|
||||
validate_environment()
|
||||
|
||||
# Create a Secrets Manager client
|
||||
session = boto3.session.Session() # type: ignore
|
||||
client = session.client(
|
||||
service_name="secretsmanager", region_name=os.getenv("AWS_REGION_NAME")
|
||||
)
|
||||
|
||||
litellm.secret_manager_client = client
|
||||
litellm._key_management_system = KeyManagementSystem.AWS_SECRET_MANAGER
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
def load_aws_kms(use_aws_kms: Optional[bool]):
|
||||
if use_aws_kms is None or use_aws_kms is False:
|
||||
return
|
||||
|
|
310
litellm/secret_managers/aws_secret_manager_v2.py
Normal file
310
litellm/secret_managers/aws_secret_manager_v2.py
Normal file
|
@ -0,0 +1,310 @@
|
|||
"""
|
||||
This is a file for the AWS Secret Manager Integration
|
||||
|
||||
Handles Async Operations for:
|
||||
- Read Secret
|
||||
- Write Secret
|
||||
- Delete Secret
|
||||
|
||||
Relevant issue: https://github.com/BerriAI/litellm/issues/1883
|
||||
|
||||
Requires:
|
||||
* `os.environ["AWS_REGION_NAME"]`,
|
||||
* `pip install boto3>=1.28.57`
|
||||
"""
|
||||
|
||||
import ast
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.llms.base_aws_llm import BaseAWSLLM
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
_get_httpx_client,
|
||||
get_async_httpx_client,
|
||||
)
|
||||
from litellm.llms.custom_httpx.types import httpxSpecialProvider
|
||||
from litellm.proxy._types import KeyManagementSystem
|
||||
|
||||
|
||||
class AWSSecretsManagerV2(BaseAWSLLM):
|
||||
@classmethod
def validate_environment(cls):
    """
    Check that the environment is configured for AWS Secrets Manager

    Raises:
        ValueError: If `AWS_REGION_NAME` is not set in the environment
    """
    if "AWS_REGION_NAME" in os.environ:
        return
    raise ValueError("Missing required environment variable - AWS_REGION_NAME")
|
||||
|
||||
@classmethod
def load_aws_secret_manager(cls, use_aws_secret_manager: Optional[bool]):
    """
    Initialize AWSSecretsManagerV2 and sets litellm.secret_manager_client = AWSSecretsManagerV2() and litellm._key_management_system = KeyManagementSystem.AWS_SECRET_MANAGER

    Args:
        use_aws_secret_manager: When None or False this is a no-op

    Raises:
        ImportError: If boto3 is not installed
        ValueError: If the environment validation fails
    """
    if use_aws_secret_manager is None or use_aws_secret_manager is False:
        return

    # Not referenced below - presumably imported so a missing boto3 install
    # fails here, at startup, rather than on first use. TODO confirm.
    import boto3  # noqa: F401

    # Errors propagate unchanged; the previous catch-and-reraise added nothing.
    cls.validate_environment()
    litellm.secret_manager_client = cls()
    litellm._key_management_system = KeyManagementSystem.AWS_SECRET_MANAGER
|
||||
|
||||
async def async_read_secret(
    self,
    secret_name: str,
    optional_params: Optional[dict] = None,
    timeout: Optional[Union[float, httpx.Timeout]] = None,
) -> Optional[str]:
    """
    Async function to read a secret from AWS Secrets Manager

    Args:
        secret_name: Name of the secret to read
        optional_params: Additional AWS parameters, forwarded to the request builder
        timeout: Request timeout

    Returns:
        Optional[str]: The `SecretString` field of the GetSecretValue response,
        or None when the request fails for any reason other than a timeout
        (HTTP error, missing `SecretString`, ...). The failure is logged.

    Raises:
        ValueError: If the request times out
    """
    # Built outside the try block: errors while preparing the request propagate.
    endpoint_url, headers, body = self._prepare_request(
        action="GetSecretValue",
        secret_name=secret_name,
        optional_params=optional_params,
    )

    async_client = get_async_httpx_client(
        llm_provider=httpxSpecialProvider.SecretManager,
        params={"timeout": timeout},
    )

    try:
        response = await async_client.post(
            url=endpoint_url, headers=headers, data=body.decode("utf-8")
        )
        response.raise_for_status()
        return response.json()["SecretString"]
    except httpx.TimeoutException as err:
        raise ValueError("Timeout error occurred") from err
    except Exception as e:
        # Best-effort read: log and fall through to return None.
        verbose_logger.exception(
            "Error reading secret from AWS Secrets Manager: %s", str(e)
        )
    return None
|
||||
|
||||
def sync_read_secret(
    self,
    secret_name: str,
    optional_params: Optional[dict] = None,
    timeout: Optional[Union[float, httpx.Timeout]] = None,
) -> Optional[str]:
    """
    Blocking counterpart of the async secret read

    Kept for backwards compatibility with the existing codebase, since get_secret is a sync function.

    Returns:
        Optional[str]: The `SecretString` of the secret, the local environment
        value for the AWS bootstrap variables listed below, or None when the
        request fails for a reason other than a timeout (the failure is logged).

    Raises:
        ValueError: If the request times out
    """
    # self._prepare_request uses these env vars, so they can never be fetched from
    # AWS Secrets Manager itself - doing so would loop forever.
    env_only_names = (
        "AWS_ACCESS_KEY_ID",
        "AWS_SECRET_ACCESS_KEY",
        "AWS_REGION_NAME",
        "AWS_REGION",
        "AWS_BEDROCK_RUNTIME_ENDPOINT",
    )
    if secret_name in env_only_names:
        return os.getenv(secret_name)

    url, request_headers, payload = self._prepare_request(
        action="GetSecretValue",
        secret_name=secret_name,
        optional_params=optional_params,
    )
    http_client = _get_httpx_client(params={"timeout": timeout})

    try:
        aws_response = http_client.post(
            url=url, headers=request_headers, data=payload.decode("utf-8")
        )
        aws_response.raise_for_status()
        secret_string = aws_response.json()["SecretString"]
    except httpx.TimeoutException:
        raise ValueError("Timeout error occurred")
    except Exception as err:
        verbose_logger.exception(
            "Error reading secret from AWS Secrets Manager: %s", str(err)
        )
        return None
    return secret_string
|
||||
|
||||
async def async_write_secret(
    self,
    secret_name: str,
    secret_value: str,
    description: Optional[str] = None,
    client_request_token: Optional[str] = None,
    optional_params: Optional[dict] = None,
    timeout: Optional[Union[float, httpx.Timeout]] = None,
) -> dict:
    """
    Async function to write a secret to AWS Secrets Manager

    Sends a `CreateSecret` request signed by `_prepare_request`.

    Args:
        secret_name: Name of the secret
        secret_value: Value to store (can be a JSON string)
        description: Optional description for the secret
        client_request_token: Optional unique identifier to ensure idempotency.
            When omitted, a random UUID4 is generated for this request.
        optional_params: Additional AWS parameters
        timeout: Request timeout

    Returns:
        dict: The parsed JSON body of the CreateSecret response

    Raises:
        ValueError: when AWS answers with an HTTP error status, or the
            request times out
    """
    import uuid

    # Prepare the request data
    data = {"Name": secret_name, "SecretString": secret_value}
    if description:
        data["Description"] = description

    # Honor the caller-supplied token so retries of the same logical write are
    # idempotent; only fall back to a fresh UUID when none was given.
    data["ClientRequestToken"] = client_request_token or str(uuid.uuid4())

    endpoint_url, headers, body = self._prepare_request(
        action="CreateSecret",
        secret_name=secret_name,
        secret_value=secret_value,
        optional_params=optional_params,
        request_data=data,  # Pass the complete request data
    )

    async_client = get_async_httpx_client(
        llm_provider=httpxSpecialProvider.SecretManager,
        params={"timeout": timeout},
    )

    try:
        response = await async_client.post(
            url=endpoint_url, headers=headers, data=body.decode("utf-8")
        )
        response.raise_for_status()
        return response.json()
    except httpx.HTTPStatusError as err:
        # Keep the original exception as __cause__ so the status code is not lost
        raise ValueError(f"HTTP error occurred: {err.response.text}") from err
    except httpx.TimeoutException as err:
        raise ValueError("Timeout error occurred") from err
|
||||
|
||||
async def async_delete_secret(
    self,
    secret_name: str,
    recovery_window_in_days: Optional[int] = 7,
    optional_params: Optional[dict] = None,
    timeout: Optional[Union[float, httpx.Timeout]] = None,
) -> dict:
    """
    Async function to delete a secret from AWS Secrets Manager

    Sends a `DeleteSecret` request signed by `_prepare_request`.

    Args:
        secret_name: Name of the secret to delete
        recovery_window_in_days: Number of days before permanent deletion (default: 7).
            When None, the field is left out of the request so AWS applies
            its own default instead of receiving a JSON null.
        optional_params: Additional AWS parameters
        timeout: Request timeout

    Returns:
        dict: Response from AWS Secrets Manager containing deletion details

    Raises:
        ValueError: when AWS answers with an HTTP error status, or the
            request times out
    """
    # Prepare the request data
    data: dict = {"SecretId": secret_name}
    if recovery_window_in_days is not None:
        data["RecoveryWindowInDays"] = recovery_window_in_days

    endpoint_url, headers, body = self._prepare_request(
        action="DeleteSecret",
        secret_name=secret_name,
        optional_params=optional_params,
        request_data=data,
    )

    async_client = get_async_httpx_client(
        llm_provider=httpxSpecialProvider.SecretManager,
        params={"timeout": timeout},
    )

    try:
        response = await async_client.post(
            url=endpoint_url, headers=headers, data=body.decode("utf-8")
        )
        response.raise_for_status()
        return response.json()
    except httpx.HTTPStatusError as err:
        # Keep the original exception as __cause__ so the status code is not lost
        raise ValueError(f"HTTP error occurred: {err.response.text}") from err
    except httpx.TimeoutException as err:
        raise ValueError("Timeout error occurred") from err
|
||||
|
||||
def _prepare_request(
    self,
    action: str,  # Secrets Manager action, e.g. "GetSecretValue", "CreateSecret", "DeleteSecret"
    secret_name: str,
    secret_value: Optional[str] = None,
    optional_params: Optional[dict] = None,
    request_data: Optional[dict] = None,
) -> tuple[str, Any, bytes]:
    """
    Prepare the AWS Secrets Manager request

    Builds the JSON body for `action`, and signs a POST to the Secrets Manager
    endpoint with SigV4.

    Args:
        action: Secrets Manager action name; sent as `X-Amz-Target: secretsmanager.<action>`
        secret_name: Used as `SecretId` when `request_data` is not provided
        secret_value: Added as `SecretString` only when `request_data` is not
            provided and `action` is "PutSecretValue"
        optional_params: Additional AWS parameters used to resolve credentials/region
        request_data: Complete request payload; when truthy it is sent as-is

    Returns:
        (endpoint_url, signed headers, encoded JSON body)

    Raises:
        ImportError: when boto3/botocore is not installed
    """
    try:
        import boto3  # noqa: F401  # availability check only; signing uses botocore
        from botocore.auth import SigV4Auth
        from botocore.awsrequest import AWSRequest
    except ImportError as err:
        raise ImportError(
            "Missing boto3 to call AWS Secrets Manager. Run 'pip install boto3'."
        ) from err

    optional_params = optional_params or {}
    boto3_credentials_info = self._get_boto_credentials_from_optional_params(
        optional_params
    )

    # Get endpoint: reuse the runtime-endpoint resolver, then point the
    # resulting URL at the secretsmanager service instead of bedrock-runtime
    _, endpoint_url = self.get_runtime_endpoint(
        api_base=None,
        aws_bedrock_runtime_endpoint=boto3_credentials_info.aws_bedrock_runtime_endpoint,
        aws_region_name=boto3_credentials_info.aws_region_name,
    )
    endpoint_url = endpoint_url.replace("bedrock-runtime", "secretsmanager")

    # Use provided request_data if available, otherwise build default data
    if request_data:
        data = request_data
    else:
        data = {"SecretId": secret_name}
        if secret_value and action == "PutSecretValue":
            data["SecretString"] = secret_value

    body = json.dumps(data).encode("utf-8")
    headers = {
        "Content-Type": "application/x-amz-json-1.1",
        "X-Amz-Target": f"secretsmanager.{action}",
    }

    # Sign request
    request = AWSRequest(method="POST", url=endpoint_url, data=body, headers=headers)
    SigV4Auth(
        boto3_credentials_info.credentials,
        "secretsmanager",
        boto3_credentials_info.aws_region_name,
    ).add_auth(request)
    prepped = request.prepare()

    return endpoint_url, prepped.headers, body
|
||||
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# print("loading aws secret manager v2")
|
||||
# aws_secret_manager_v2 = AWSSecretsManagerV2()
|
||||
|
||||
# print("writing secret to aws secret manager v2")
|
||||
# asyncio.run(aws_secret_manager_v2.async_write_secret(secret_name="test_secret_3", secret_value="test_value_2"))
|
||||
# print("reading secret from aws secret manager v2")
|
|
@ -5,7 +5,7 @@ import json
|
|||
import os
|
||||
import sys
|
||||
import traceback
|
||||
from typing import Any, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
import httpx
|
||||
from dotenv import load_dotenv
|
||||
|
@ -198,7 +198,10 @@ def get_secret( # noqa: PLR0915
|
|||
raise ValueError("Unsupported OIDC provider")
|
||||
|
||||
try:
|
||||
if litellm.secret_manager_client is not None:
|
||||
if (
|
||||
_should_read_secret_from_secret_manager()
|
||||
and litellm.secret_manager_client is not None
|
||||
):
|
||||
try:
|
||||
client = litellm.secret_manager_client
|
||||
key_manager = "local"
|
||||
|
@ -207,7 +210,8 @@ def get_secret( # noqa: PLR0915
|
|||
|
||||
if key_management_settings is not None:
|
||||
if (
|
||||
secret_name not in key_management_settings.hosted_keys
|
||||
key_management_settings.hosted_keys is not None
|
||||
and secret_name not in key_management_settings.hosted_keys
|
||||
): # allow user to specify which keys to check in hosted key manager
|
||||
key_manager = "local"
|
||||
|
||||
|
@ -268,25 +272,13 @@ def get_secret( # noqa: PLR0915
|
|||
if isinstance(secret, str):
|
||||
secret = secret.strip()
|
||||
elif key_manager == KeyManagementSystem.AWS_SECRET_MANAGER.value:
|
||||
try:
|
||||
get_secret_value_response = client.get_secret_value(
|
||||
SecretId=secret_name
|
||||
)
|
||||
print_verbose(
|
||||
f"get_secret_value_response: {get_secret_value_response}"
|
||||
)
|
||||
except Exception as e:
|
||||
print_verbose(f"An error occurred - {str(e)}")
|
||||
# For a list of exceptions thrown, see
|
||||
# https://docs.aws.amazon.com/secretsmanager/latest/apireference/API_GetSecretValue.html
|
||||
raise e
|
||||
from litellm.secret_managers.aws_secret_manager_v2 import (
|
||||
AWSSecretsManagerV2,
|
||||
)
|
||||
|
||||
# assume there is 1 secret per secret_name
|
||||
secret_dict = json.loads(get_secret_value_response["SecretString"])
|
||||
print_verbose(f"secret_dict: {secret_dict}")
|
||||
for k, v in secret_dict.items():
|
||||
secret = v
|
||||
print_verbose(f"secret: {secret}")
|
||||
if isinstance(client, AWSSecretsManagerV2):
|
||||
secret = client.sync_read_secret(secret_name=secret_name)
|
||||
print_verbose(f"get_secret_value_response: {secret}")
|
||||
elif key_manager == KeyManagementSystem.GOOGLE_SECRET_MANAGER.value:
|
||||
try:
|
||||
secret = client.get_secret_from_google_secret_manager(
|
||||
|
@ -332,3 +324,21 @@ def get_secret( # noqa: PLR0915
|
|||
return default_value
|
||||
else:
|
||||
raise e
|
||||
|
||||
|
||||
def _should_read_secret_from_secret_manager() -> bool:
    """
    Decide whether secrets should be looked up in the configured secret manager.

    - False when `litellm.secret_manager_client` is not set
    - False when `litellm._key_management_settings` is not set
    - True when the settings' access mode is "read_only" or "read_and_write"
    - False for any other access mode
    """
    if litellm.secret_manager_client is None:
        return False

    settings = litellm._key_management_settings
    if settings is None:
        return False

    return settings.access_mode in ("read_only", "read_and_write")
|
||||
|
|
139
tests/local_testing/test_aws_secret_manager.py
Normal file
139
tests/local_testing/test_aws_secret_manager.py
Normal file
|
@ -0,0 +1,139 @@
|
|||
# What is this?
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
import litellm.types
|
||||
import litellm.types.utils
|
||||
|
||||
|
||||
load_dotenv()
|
||||
import io
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
# Ensure the project root is in the Python path
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..")))
|
||||
|
||||
print("Python Path:", sys.path)
|
||||
print("Current Working Directory:", os.getcwd())
|
||||
|
||||
|
||||
from typing import Optional
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import uuid
|
||||
import json
|
||||
from litellm.secret_managers.aws_secret_manager_v2 import AWSSecretsManagerV2
|
||||
|
||||
|
||||
def check_aws_credentials():
    """Skip the calling test unless every required AWS env var is set and non-empty"""
    missing_vars = [
        name
        for name in ("AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_REGION_NAME")
        if not os.environ.get(name)
    ]
    if not missing_vars:
        return
    pytest.skip(f"Missing required AWS credentials: {', '.join(missing_vars)}")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_write_and_read_simple_secret():
    """Round-trip a plain string secret: write it, read it back, then delete it"""
    check_aws_credentials()

    manager = AWSSecretsManagerV2()
    name = f"litellm_test_{uuid.uuid4().hex[:8]}"
    expected_value = "test_value_123"

    try:
        # Create the secret and check the response describes it
        created = await manager.async_write_secret(
            secret_name=name,
            secret_value=expected_value,
            description="LiteLLM Test Secret",
        )
        print("Write Response:", created)

        assert created is not None
        for field in ("ARN", "Name"):
            assert field in created
        assert created["Name"] == name

        # The stored value must come back unchanged
        fetched = await manager.async_read_secret(secret_name=name)
        print("Read Value:", fetched)
        assert fetched == expected_value
    finally:
        # Always remove the secret so test runs do not accumulate in AWS
        deleted = await manager.async_delete_secret(secret_name=name)
        print("Delete Response:", deleted)
        assert deleted is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_write_and_read_json_secret():
    """Round-trip a JSON-encoded secret: write it, read and parse it, then delete it"""
    check_aws_credentials()

    manager = AWSSecretsManagerV2()
    name = f"litellm_test_{uuid.uuid4().hex[:8]}_json"
    payload = {
        "api_key": "test_key",
        "model": "gpt-4",
        "temperature": 0.7,
        "metadata": {"team": "ml", "project": "litellm"},
    }

    try:
        # Store the structure as a JSON string
        created = await manager.async_write_secret(
            secret_name=name,
            secret_value=json.dumps(payload),
            description="LiteLLM JSON Test Secret",
        )
        print("Write Response:", created)

        # Fetch the raw string and decode it back into a dict
        raw = await manager.async_read_secret(secret_name=name)
        decoded = json.loads(raw)
        print("Read Value:", raw)

        assert decoded == payload
        assert decoded["api_key"] == "test_key"
        assert decoded["metadata"]["team"] == "ml"
    finally:
        # Always remove the secret so test runs do not accumulate in AWS
        deleted = await manager.async_delete_secret(secret_name=name)
        print("Delete Response:", deleted)
        assert deleted is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_read_nonexistent_secret():
    """Reading a secret name that was never created must yield None"""
    check_aws_credentials()

    manager = AWSSecretsManagerV2()
    # A full random hex suffix makes this name unique
    missing_name = f"litellm_nonexistent_{uuid.uuid4().hex}"

    result = await manager.async_read_secret(secret_name=missing_name)

    assert result is None
|
|
@ -24,7 +24,7 @@ from litellm import RateLimitError, Timeout, completion, completion_cost, embedd
|
|||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
from litellm.llms.prompt_templates.factory import anthropic_messages_pt
|
||||
|
||||
# litellm.num_retries = 3
|
||||
# litellm.num_retries=3
|
||||
|
||||
litellm.cache = None
|
||||
litellm.success_callback = []
|
||||
|
|
|
@ -15,22 +15,29 @@ sys.path.insert(
|
|||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
import pytest
|
||||
|
||||
import litellm
|
||||
from litellm.llms.AzureOpenAI.azure import get_azure_ad_token_from_oidc
|
||||
from litellm.llms.bedrock.chat import BedrockConverseLLM, BedrockLLM
|
||||
from litellm.secret_managers.aws_secret_manager import load_aws_secret_manager
|
||||
from litellm.secret_managers.main import get_secret
|
||||
from litellm.secret_managers.aws_secret_manager_v2 import AWSSecretsManagerV2
|
||||
from litellm.secret_managers.main import (
|
||||
get_secret,
|
||||
_should_read_secret_from_secret_manager,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="AWS Suspended Account")
|
||||
def test_aws_secret_manager():
|
||||
load_aws_secret_manager(use_aws_secret_manager=True)
|
||||
import json
|
||||
|
||||
AWSSecretsManagerV2.load_aws_secret_manager(use_aws_secret_manager=True)
|
||||
|
||||
secret_val = get_secret("litellm_master_key")
|
||||
|
||||
print(f"secret_val: {secret_val}")
|
||||
|
||||
assert secret_val == "sk-1234"
|
||||
# cast json to dict
|
||||
secret_val = json.loads(secret_val)
|
||||
|
||||
assert secret_val["litellm_master_key"] == "sk-1234"
|
||||
|
||||
|
||||
def redact_oidc_signature(secret_val):
|
||||
|
@ -240,3 +247,71 @@ def test_google_secret_manager_read_in_memory():
|
|||
)
|
||||
print("secret_val: {}".format(secret_val))
|
||||
assert secret_val == "lite-llm"
|
||||
|
||||
|
||||
def test_should_read_secret_from_secret_manager():
    """
    Test that _should_read_secret_from_secret_manager returns correct values based on access mode
    """
    from litellm.proxy._types import KeyManagementSettings

    try:
        # Test when secret manager client is None
        litellm.secret_manager_client = None
        litellm._key_management_settings = KeyManagementSettings()
        assert _should_read_secret_from_secret_manager() is False

        # Test with secret manager client and read_only access
        litellm.secret_manager_client = "dummy_client"
        litellm._key_management_settings = KeyManagementSettings(
            access_mode="read_only"
        )
        assert _should_read_secret_from_secret_manager() is True

        # Test with secret manager client and read_and_write access
        litellm._key_management_settings = KeyManagementSettings(
            access_mode="read_and_write"
        )
        assert _should_read_secret_from_secret_manager() is True

        # Test with secret manager client and write_only access
        litellm._key_management_settings = KeyManagementSettings(
            access_mode="write_only"
        )
        assert _should_read_secret_from_secret_manager() is False
    finally:
        # Reset global variables even when an assertion above fails, so the
        # dummy client does not leak into other tests
        litellm.secret_manager_client = None
        litellm._key_management_settings = KeyManagementSettings()
|
||||
|
||||
|
||||
def test_get_secret_with_access_mode():
    """
    Test that get_secret respects access mode settings
    """
    from litellm.proxy._types import KeyManagementSettings

    # Set up test environment
    test_secret_name = "TEST_SECRET_KEY"
    test_secret_value = "test_secret_value"
    os.environ[test_secret_name] = test_secret_value

    try:
        # Test with write_only access (should read from os.environ)
        litellm.secret_manager_client = "dummy_client"
        litellm._key_management_settings = KeyManagementSettings(
            access_mode="write_only"
        )
        assert get_secret(test_secret_name) == test_secret_value

        # Test with default-constructed KeyManagementSettings and secret_manager_client set
        litellm.secret_manager_client = "dummy_client"
        litellm._key_management_settings = KeyManagementSettings()
        assert _should_read_secret_from_secret_manager() is True

        # Test with read_only access
        litellm._key_management_settings = KeyManagementSettings(
            access_mode="read_only"
        )
        assert _should_read_secret_from_secret_manager() is True

        # Test with read_and_write access
        litellm._key_management_settings = KeyManagementSettings(
            access_mode="read_and_write"
        )
        assert _should_read_secret_from_secret_manager() is True
    finally:
        # Reset global variables and the env var even when an assertion above
        # fails, so neither leaks into other tests
        litellm.secret_manager_client = None
        litellm._key_management_settings = KeyManagementSettings()
        os.environ.pop(test_secret_name, None)
|
||||
|
|
|
@ -3451,3 +3451,90 @@ async def test_user_api_key_auth_db_unavailable_not_allowed():
|
|||
request=request,
|
||||
api_key="Bearer sk-123456789",
|
||||
)
|
||||
|
||||
|
||||
## E2E Virtual Key + Secret Manager Tests #########################################
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_key_generate_with_secret_manager_call(prisma_client):
    """
    Generate a key
    assert it exists in the secret manager

    delete the key
    assert it is deleted from the secret manager
    """
    from litellm.secret_managers.aws_secret_manager_v2 import AWSSecretsManagerV2
    from litellm.proxy._types import KeyManagementSystem, KeyManagementSettings

    # Snapshot the litellm globals this test overrides so they can be put back
    # afterwards instead of leaking into later tests
    overridden_attrs = (
        "set_verbose",
        "secret_manager_client",
        "_key_management_system",
        "_key_management_settings",
    )
    previous_values = {name: getattr(litellm, name, None) for name in overridden_attrs}

    litellm.set_verbose = True

    try:
        #### Test Setup ############################################################
        aws_secret_manager_client = AWSSecretsManagerV2()
        litellm.secret_manager_client = aws_secret_manager_client
        litellm._key_management_system = KeyManagementSystem.AWS_SECRET_MANAGER
        litellm._key_management_settings = KeyManagementSettings(
            store_virtual_keys=True,
        )
        general_settings = {
            "key_management_system": "aws_secret_manager",
            "key_management_settings": {
                "store_virtual_keys": True,
            },
        }

        setattr(litellm.proxy.proxy_server, "general_settings", general_settings)
        setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
        setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
        await litellm.proxy.proxy_server.prisma_client.connect()
        ############################################################################

        # generate new key
        key_alias = f"test_alias_secret_manager_key-{uuid.uuid4()}"
        spend = 100
        max_budget = 400
        models = ["fake-openai-endpoint"]
        new_key = await generate_key_fn(
            data=GenerateKeyRequest(
                key_alias=key_alias, spend=spend, max_budget=max_budget, models=models
            ),
            user_api_key_dict=UserAPIKeyAuth(
                user_role=LitellmUserRoles.PROXY_ADMIN,
                api_key="sk-1234",
                user_id="1234",
            ),
        )

        generated_key = new_key.key
        print(generated_key)

        # give the secret-manager write time to complete before reading
        await asyncio.sleep(2)

        # read from the secret manager
        result = await aws_secret_manager_client.async_read_secret(
            secret_name=key_alias
        )

        # Assert the correct key is stored in the secret manager
        print("response from AWS Secret Manager")
        print(result)
        assert result == generated_key

        # delete the key
        await delete_key_fn(
            data=KeyRequest(keys=[generated_key]),
            user_api_key_dict=UserAPIKeyAuth(
                user_role=LitellmUserRoles.PROXY_ADMIN,
                api_key="sk-1234",
                user_id="1234",
            ),
        )

        # give the secret-manager delete time to complete before reading
        await asyncio.sleep(2)

        # Assert the key is deleted from the secret manager
        result = await aws_secret_manager_client.async_read_secret(
            secret_name=key_alias
        )
        assert result is None
    finally:
        # cleanup: runs even when an assertion above fails
        setattr(litellm.proxy.proxy_server, "general_settings", {})
        for name, value in previous_values.items():
            setattr(litellm, name, value)
|
||||
|
||||
|
||||
################################################################################
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue