forked from phoenix/litellm-mirror
(Feat) Add support for storing virtual keys in AWS SecretManager (#6728)
* add SecretManager to httpxSpecialProvider * fix importing AWSSecretsManagerV2 * add unit testing for writing keys to AWS secret manager * use KeyManagementEventHooks for key/generated events * use event hooks for key management endpoints * working AWSSecretsManagerV2 * fix write secret to AWS secret manager on /key/generate * fix KeyManagementSettings * use tasks for key management hooks * add async_delete_secret * add test for async_delete_secret * use _delete_virtual_keys_from_secret_manager * fix test secret manager * test_key_generate_with_secret_manager_call * fix check for key_management_settings * sync_read_secret * test_aws_secret_manager * fix sync_read_secret * use helper to check when _should_read_secret_from_secret_manager * test_get_secret_with_access_mode * test - handle eol model claude-2, use claude-2.1 instead * docs AWS secret manager * fix test_read_nonexistent_secret * fix test_supports_response_schema * ci/cd run again
This commit is contained in:
parent
da84056e59
commit
f8e700064e
16 changed files with 1046 additions and 178 deletions
|
@ -1,3 +1,6 @@
|
||||||
|
import Tabs from '@theme/Tabs';
|
||||||
|
import TabItem from '@theme/TabItem';
|
||||||
|
|
||||||
# Secret Manager
|
# Secret Manager
|
||||||
LiteLLM supports reading secrets from Azure Key Vault, Google Secret Manager
|
LiteLLM supports reading secrets from Azure Key Vault, Google Secret Manager
|
||||||
|
|
||||||
|
@ -59,14 +62,35 @@ os.environ["AWS_REGION_NAME"] = "" # us-east-1, us-east-2, us-west-1, us-west-2
|
||||||
```
|
```
|
||||||
|
|
||||||
2. Enable AWS Secret Manager in config.
|
2. Enable AWS Secret Manager in config.
|
||||||
|
|
||||||
|
<Tabs>
|
||||||
|
<TabItem value="read_only" label="Read Keys from AWS Secret Manager">
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
general_settings:
|
general_settings:
|
||||||
master_key: os.environ/litellm_master_key
|
master_key: os.environ/litellm_master_key
|
||||||
key_management_system: "aws_secret_manager" # 👈 KEY CHANGE
|
key_management_system: "aws_secret_manager" # 👈 KEY CHANGE
|
||||||
key_management_settings:
|
key_management_settings:
|
||||||
hosted_keys: ["litellm_master_key"] # 👈 Specify which env keys you stored on AWS
|
hosted_keys: ["litellm_master_key"] # 👈 Specify which env keys you stored on AWS
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
|
||||||
|
<TabItem value="write_only" label="Write Virtual Keys to AWS Secret Manager">
|
||||||
|
|
||||||
|
This will only store virtual keys in AWS Secret Manager. No keys will be read from AWS Secret Manager.
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
general_settings:
|
||||||
|
key_management_system: "aws_secret_manager" # 👈 KEY CHANGE
|
||||||
|
key_management_settings:
|
||||||
|
store_virtual_keys: true
|
||||||
|
access_mode: "write_only" # Literal["read_only", "write_only", "read_and_write"]
|
||||||
|
```
|
||||||
|
</TabItem>
|
||||||
|
</Tabs>
|
||||||
|
|
||||||
3. Run proxy
|
3. Run proxy
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
@ -181,16 +205,14 @@ litellm --config /path/to/config.yaml
|
||||||
|
|
||||||
Use encrypted keys from Google KMS on the proxy
|
Use encrypted keys from Google KMS on the proxy
|
||||||
|
|
||||||
### Usage with LiteLLM Proxy Server
|
Step 1. Add keys to env
|
||||||
|
|
||||||
## Step 1. Add keys to env
|
|
||||||
```
|
```
|
||||||
export GOOGLE_APPLICATION_CREDENTIALS="/path/to/credentials.json"
|
export GOOGLE_APPLICATION_CREDENTIALS="/path/to/credentials.json"
|
||||||
export GOOGLE_KMS_RESOURCE_NAME="projects/*/locations/*/keyRings/*/cryptoKeys/*"
|
export GOOGLE_KMS_RESOURCE_NAME="projects/*/locations/*/keyRings/*/cryptoKeys/*"
|
||||||
export PROXY_DATABASE_URL_ENCRYPTED=b'\n$\x00D\xac\xb4/\x8e\xc...'
|
export PROXY_DATABASE_URL_ENCRYPTED=b'\n$\x00D\xac\xb4/\x8e\xc...'
|
||||||
```
|
```
|
||||||
|
|
||||||
## Step 2: Update Config
|
Step 2: Update Config
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
general_settings:
|
general_settings:
|
||||||
|
@ -199,7 +221,7 @@ general_settings:
|
||||||
master_key: sk-1234
|
master_key: sk-1234
|
||||||
```
|
```
|
||||||
|
|
||||||
## Step 3: Start + test proxy
|
Step 3: Start + test proxy
|
||||||
|
|
||||||
```
|
```
|
||||||
$ litellm --config /path/to/config.yaml
|
$ litellm --config /path/to/config.yaml
|
||||||
|
@ -215,3 +237,17 @@ $ litellm --test
|
||||||
<!--
|
<!--
|
||||||
## .env Files
|
## .env Files
|
||||||
If no secret manager client is specified, Litellm automatically uses the `.env` file to manage sensitive data. -->
|
If no secret manager client is specified, Litellm automatically uses the `.env` file to manage sensitive data. -->
|
||||||
|
|
||||||
|
|
||||||
|
## All Secret Manager Settings
|
||||||
|
|
||||||
|
All settings related to secret management
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
general_settings:
|
||||||
|
key_management_system: "aws_secret_manager" # REQUIRED
|
||||||
|
key_management_settings:
|
||||||
|
store_virtual_keys: true # OPTIONAL. Defaults to False, when True will store virtual keys in secret manager
|
||||||
|
access_mode: "write_only" # OPTIONAL. Literal["read_only", "write_only", "read_and_write"]. Defaults to "read_only"
|
||||||
|
hosted_keys: ["litellm_master_key"] # OPTIONAL. Specify which env keys you stored on AWS
|
||||||
|
```
|
|
@ -304,7 +304,7 @@ secret_manager_client: Optional[Any] = (
|
||||||
)
|
)
|
||||||
_google_kms_resource_name: Optional[str] = None
|
_google_kms_resource_name: Optional[str] = None
|
||||||
_key_management_system: Optional[KeyManagementSystem] = None
|
_key_management_system: Optional[KeyManagementSystem] = None
|
||||||
_key_management_settings: Optional[KeyManagementSettings] = None
|
_key_management_settings: KeyManagementSettings = KeyManagementSettings()
|
||||||
#### PII MASKING ####
|
#### PII MASKING ####
|
||||||
output_parse_pii: bool = False
|
output_parse_pii: bool = False
|
||||||
#############################################
|
#############################################
|
||||||
|
|
|
@ -8,3 +8,4 @@ class httpxSpecialProvider(str, Enum):
|
||||||
GuardrailCallback = "guardrail_callback"
|
GuardrailCallback = "guardrail_callback"
|
||||||
Caching = "caching"
|
Caching = "caching"
|
||||||
Oauth2Check = "oauth2_check"
|
Oauth2Check = "oauth2_check"
|
||||||
|
SecretManager = "secret_manager"
|
||||||
|
|
|
@ -1128,7 +1128,16 @@ class KeyManagementSystem(enum.Enum):
|
||||||
|
|
||||||
|
|
||||||
class KeyManagementSettings(LiteLLMBase):
|
class KeyManagementSettings(LiteLLMBase):
|
||||||
hosted_keys: List
|
hosted_keys: Optional[List] = None
|
||||||
|
store_virtual_keys: Optional[bool] = False
|
||||||
|
"""
|
||||||
|
If True, virtual keys created by litellm will be stored in the secret manager
|
||||||
|
"""
|
||||||
|
|
||||||
|
access_mode: Literal["read_only", "write_only", "read_and_write"] = "read_only"
|
||||||
|
"""
|
||||||
|
Access mode for the secret manager, when write_only will only use for writing secrets
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
class TeamDefaultSettings(LiteLLMBase):
|
class TeamDefaultSettings(LiteLLMBase):
|
||||||
|
|
267
litellm/proxy/hooks/key_management_event_hooks.py
Normal file
267
litellm/proxy/hooks/key_management_event_hooks.py
Normal file
|
@ -0,0 +1,267 @@
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from re import A
|
||||||
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
|
from fastapi import status
|
||||||
|
|
||||||
|
import litellm
|
||||||
|
from litellm._logging import verbose_proxy_logger
|
||||||
|
from litellm.proxy._types import (
|
||||||
|
GenerateKeyRequest,
|
||||||
|
KeyManagementSystem,
|
||||||
|
KeyRequest,
|
||||||
|
LiteLLM_AuditLogs,
|
||||||
|
LiteLLM_VerificationToken,
|
||||||
|
LitellmTableNames,
|
||||||
|
ProxyErrorTypes,
|
||||||
|
ProxyException,
|
||||||
|
UpdateKeyRequest,
|
||||||
|
UserAPIKeyAuth,
|
||||||
|
WebhookEvent,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class KeyManagementEventHooks:
    """
    Post-processing hooks for the virtual-key management operations.

    Each public hook runs after its /key/* operation has finished and fans
    out to the optional side effects visible below: invite e-mails, audit
    logs (when `litellm.store_audit_logs` is True) and mirroring virtual keys
    into the secret manager (when
    `litellm._key_management_settings.store_virtual_keys` is True).
    """

    @staticmethod
    async def async_key_generated_hook(
        data: GenerateKeyRequest,
        response: dict,
        user_api_key_dict: UserAPIKeyAuth,
        litellm_changed_by: Optional[str] = None,
    ):
        """
        Hook that runs after a successful /key/generate request

        Handles the following:
        - Sending Email with Key Details
        - Storing Audit Logs for key generation
        - Storing Generated Key in the secret manager

        Args:
            data: The original key generation request.
            response: The generated key payload; "token" and "token_id" are read from it.
            user_api_key_dict: Auth info of the caller that created the key.
            litellm_changed_by: Optional override for the audit log `changed_by` field.

        Raises:
            ValueError: if `data.send_invite_email` is True but email alerting is not configured.
        """
        from litellm.proxy.management_helpers.audit_logs import (
            create_audit_log_for_update,
        )
        from litellm.proxy.proxy_server import litellm_proxy_admin_name

        # NOTE(review): a ValueError raised here aborts the audit log and the
        # secret manager write below - confirm this is the intended ordering.
        if data.send_invite_email is True:
            await KeyManagementEventHooks._send_key_created_email(response)

        # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
        if litellm.store_audit_logs is True:
            _updated_values = json.dumps(response, default=str)
            asyncio.create_task(
                create_audit_log_for_update(
                    request_data=LiteLLM_AuditLogs(
                        id=str(uuid.uuid4()),
                        updated_at=datetime.now(timezone.utc),
                        changed_by=litellm_changed_by
                        or user_api_key_dict.user_id
                        or litellm_proxy_admin_name,
                        changed_by_api_key=user_api_key_dict.api_key,
                        table_name=LitellmTableNames.KEY_TABLE_NAME,
                        object_id=response.get("token_id", ""),
                        action="created",
                        updated_values=_updated_values,
                        before_value=None,
                    )
                )
            )

        # store the generated key in the secret manager
        # NOTE(review): keys without an alias are stored under a random
        # "virtual-key-<uuid>" name, while deletion only removes secrets by
        # key_alias - such secrets look like they are never cleaned up.
        await KeyManagementEventHooks._store_virtual_key_in_secret_manager(
            secret_name=data.key_alias or f"virtual-key-{uuid.uuid4()}",
            secret_token=response.get("token", ""),
        )

    @staticmethod
    async def async_key_updated_hook(
        data: UpdateKeyRequest,
        existing_key_row: Any,
        response: Any,
        user_api_key_dict: UserAPIKeyAuth,
        litellm_changed_by: Optional[str] = None,
    ):
        """
        Post /key/update processing hook

        Handles the following:
        - Storing Audit Logs for key update

        Args:
            data: The update request; serialized as the audit log `updated_values`.
            existing_key_row: The key row before the update; serialized as `before_value`.
            response: Result of the update (not used by this hook).
            user_api_key_dict: Auth info of the caller that updated the key.
            litellm_changed_by: Optional override for the audit log `changed_by` field.
        """
        from litellm.proxy.management_helpers.audit_logs import (
            create_audit_log_for_update,
        )
        from litellm.proxy.proxy_server import litellm_proxy_admin_name

        # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
        if litellm.store_audit_logs is not True:
            return

        _updated_values = json.dumps(data.json(exclude_none=True), default=str)
        _before_value = json.dumps(
            existing_key_row.json(exclude_none=True), default=str
        )

        asyncio.create_task(
            create_audit_log_for_update(
                request_data=LiteLLM_AuditLogs(
                    id=str(uuid.uuid4()),
                    updated_at=datetime.now(timezone.utc),
                    changed_by=litellm_changed_by
                    or user_api_key_dict.user_id
                    or litellm_proxy_admin_name,
                    changed_by_api_key=user_api_key_dict.api_key,
                    table_name=LitellmTableNames.KEY_TABLE_NAME,
                    object_id=data.key,
                    action="updated",
                    updated_values=_updated_values,
                    before_value=_before_value,
                )
            )
        )

    @staticmethod
    async def async_key_deleted_hook(
        data: KeyRequest,
        keys_being_deleted: List[LiteLLM_VerificationToken],
        response: dict,
        user_api_key_dict: UserAPIKeyAuth,
        litellm_changed_by: Optional[str] = None,
    ):
        """
        Post /key/delete processing hook

        Handles the following:
        - Storing Audit Logs for key deletion
        - Deleting the keys from the secret manager

        Args:
            data: The original delete request (not used by this hook).
            keys_being_deleted: Key rows captured before the delete ran.
            response: Result of the delete (not used by this hook).
            user_api_key_dict: Auth info of the caller that deleted the keys.
            litellm_changed_by: Optional override for the audit log `changed_by` field.
        """
        from litellm.proxy.management_helpers.audit_logs import (
            create_audit_log_for_update,
        )
        from litellm.proxy.proxy_server import litellm_proxy_admin_name

        # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
        if litellm.store_audit_logs is True:
            # Make an audit log for each key deleted. The rows passed in are
            # used rather than re-querying the DB: this hook runs after the
            # delete, so a fresh lookup would not find the key and the failure
            # would also skip the secret manager cleanup below.
            for deleted_key in keys_being_deleted:
                _before_value = json.dumps(
                    deleted_key.json(exclude_none=True), default=str
                )

                asyncio.create_task(
                    create_audit_log_for_update(
                        request_data=LiteLLM_AuditLogs(
                            id=str(uuid.uuid4()),
                            updated_at=datetime.now(timezone.utc),
                            changed_by=litellm_changed_by
                            or user_api_key_dict.user_id
                            or litellm_proxy_admin_name,
                            changed_by_api_key=user_api_key_dict.api_key,
                            table_name=LitellmTableNames.KEY_TABLE_NAME,
                            # NOTE(review): this is the token stored on the row,
                            # not the key string sent in the request - confirm
                            # audit log consumers expect that.
                            object_id=deleted_key.token or "",
                            action="deleted",
                            updated_values="{}",
                            before_value=_before_value,
                        )
                    )
                )

        # delete the keys from the secret manager
        await KeyManagementEventHooks._delete_virtual_keys_from_secret_manager(
            keys_being_deleted=keys_being_deleted
        )

    @staticmethod
    async def _store_virtual_key_in_secret_manager(secret_name: str, secret_token: str):
        """
        Store a virtual key in the secret manager

        Does nothing unless `store_virtual_keys` is enabled and the configured
        secret manager is AWS Secret Manager backed by `AWSSecretsManagerV2`.

        Args:
            secret_name: Name of the virtual key
            secret_token: Value of the virtual key (example: sk-1234)
        """
        settings = litellm._key_management_settings
        if settings is None or settings.store_virtual_keys is not True:
            return

        from litellm.secret_managers.aws_secret_manager_v2 import (
            AWSSecretsManagerV2,
        )

        client = litellm.secret_manager_client
        if (
            litellm._key_management_system == KeyManagementSystem.AWS_SECRET_MANAGER
            and isinstance(client, AWSSecretsManagerV2)
        ):
            await client.async_write_secret(
                secret_name=secret_name,
                secret_value=secret_token,
            )

    @staticmethod
    async def _delete_virtual_keys_from_secret_manager(
        keys_being_deleted: List[LiteLLM_VerificationToken],
    ):
        """
        Deletes virtual keys from the secret manager

        Secrets are addressed by `key_alias`; keys without an alias are
        skipped with a warning. Does nothing unless `store_virtual_keys` is
        enabled and the secret manager client is an `AWSSecretsManagerV2`.

        Args:
            keys_being_deleted: List of keys being deleted, this is passed down from the /key/delete operation
        """
        settings = litellm._key_management_settings
        if settings is None or settings.store_virtual_keys is not True:
            return

        from litellm.secret_managers.aws_secret_manager_v2 import (
            AWSSecretsManagerV2,
        )

        client = litellm.secret_manager_client
        if not isinstance(client, AWSSecretsManagerV2):
            return

        for key in keys_being_deleted:
            if key.key_alias is None:
                verbose_proxy_logger.warning(
                    f"KeyManagementEventHooks._delete_virtual_keys_from_secret_manager: Key alias not found for key {key.token}. Skipping deletion from secret manager."
                )
                continue
            await client.async_delete_secret(secret_name=key.key_alias)

    @staticmethod
    async def _send_key_created_email(response: dict):
        """
        Schedule the "API Key Created" e-mail for a newly generated key.

        Args:
            response: The generated key payload used to populate the webhook event.

        Raises:
            ValueError: if "email" is not listed under `alerting` in general_settings.
        """
        from litellm.proxy.proxy_server import general_settings, proxy_logging_obj

        if "email" not in general_settings.get("alerting", []):
            raise ValueError(
                "Email alerting not setup on config.yaml. Please set `alerting=['email']. \nDocs: https://docs.litellm.ai/docs/proxy/email`"
            )
        event = WebhookEvent(
            event="key_created",
            event_group="key",
            event_message="API Key Created",
            token=response.get("token", ""),
            spend=response.get("spend", 0.0),
            max_budget=response.get("max_budget", 0.0),
            user_id=response.get("user_id", None),
            team_id=response.get("team_id", "Default Team"),
            key_alias=response.get("key_alias", None),
        )

        # If user configured email alerting - send an Email letting their end-user know the key was created
        asyncio.create_task(
            proxy_logging_obj.slack_alerting_instance.send_key_created_or_user_invited_email(
                webhook_event=event,
            )
        )
|
|
@ -17,7 +17,7 @@ import secrets
|
||||||
import traceback
|
import traceback
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from typing import List, Optional
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import fastapi
|
import fastapi
|
||||||
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Request, status
|
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Request, status
|
||||||
|
@ -31,6 +31,7 @@ from litellm.proxy.auth.auth_checks import (
|
||||||
get_key_object,
|
get_key_object,
|
||||||
)
|
)
|
||||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||||
|
from litellm.proxy.hooks.key_management_event_hooks import KeyManagementEventHooks
|
||||||
from litellm.proxy.management_helpers.utils import management_endpoint_wrapper
|
from litellm.proxy.management_helpers.utils import management_endpoint_wrapper
|
||||||
from litellm.proxy.utils import _duration_in_seconds, _hash_token_if_needed
|
from litellm.proxy.utils import _duration_in_seconds, _hash_token_if_needed
|
||||||
from litellm.secret_managers.main import get_secret
|
from litellm.secret_managers.main import get_secret
|
||||||
|
@ -234,50 +235,14 @@ async def generate_key_fn( # noqa: PLR0915
|
||||||
data.soft_budget
|
data.soft_budget
|
||||||
) # include the user-input soft budget in the response
|
) # include the user-input soft budget in the response
|
||||||
|
|
||||||
if data.send_invite_email is True:
|
asyncio.create_task(
|
||||||
if "email" not in general_settings.get("alerting", []):
|
KeyManagementEventHooks.async_key_generated_hook(
|
||||||
raise ValueError(
|
data=data,
|
||||||
"Email alerting not setup on config.yaml. Please set `alerting=['email']. \nDocs: https://docs.litellm.ai/docs/proxy/email`"
|
response=response,
|
||||||
)
|
user_api_key_dict=user_api_key_dict,
|
||||||
event = WebhookEvent(
|
litellm_changed_by=litellm_changed_by,
|
||||||
event="key_created",
|
|
||||||
event_group="key",
|
|
||||||
event_message="API Key Created",
|
|
||||||
token=response.get("token", ""),
|
|
||||||
spend=response.get("spend", 0.0),
|
|
||||||
max_budget=response.get("max_budget", 0.0),
|
|
||||||
user_id=response.get("user_id", None),
|
|
||||||
team_id=response.get("team_id", "Default Team"),
|
|
||||||
key_alias=response.get("key_alias", None),
|
|
||||||
)
|
|
||||||
|
|
||||||
# If user configured email alerting - send an Email letting their end-user know the key was created
|
|
||||||
asyncio.create_task(
|
|
||||||
proxy_logging_obj.slack_alerting_instance.send_key_created_or_user_invited_email(
|
|
||||||
webhook_event=event,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
|
|
||||||
if litellm.store_audit_logs is True:
|
|
||||||
_updated_values = json.dumps(response, default=str)
|
|
||||||
asyncio.create_task(
|
|
||||||
create_audit_log_for_update(
|
|
||||||
request_data=LiteLLM_AuditLogs(
|
|
||||||
id=str(uuid.uuid4()),
|
|
||||||
updated_at=datetime.now(timezone.utc),
|
|
||||||
changed_by=litellm_changed_by
|
|
||||||
or user_api_key_dict.user_id
|
|
||||||
or litellm_proxy_admin_name,
|
|
||||||
changed_by_api_key=user_api_key_dict.api_key,
|
|
||||||
table_name=LitellmTableNames.KEY_TABLE_NAME,
|
|
||||||
object_id=response.get("token_id", ""),
|
|
||||||
action="created",
|
|
||||||
updated_values=_updated_values,
|
|
||||||
before_value=None,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return GenerateKeyResponse(**response)
|
return GenerateKeyResponse(**response)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -407,30 +372,15 @@ async def update_key_fn(
|
||||||
proxy_logging_obj=proxy_logging_obj,
|
proxy_logging_obj=proxy_logging_obj,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
|
asyncio.create_task(
|
||||||
if litellm.store_audit_logs is True:
|
KeyManagementEventHooks.async_key_updated_hook(
|
||||||
_updated_values = json.dumps(data_json, default=str)
|
data=data,
|
||||||
|
existing_key_row=existing_key_row,
|
||||||
_before_value = existing_key_row.json(exclude_none=True)
|
response=response,
|
||||||
_before_value = json.dumps(_before_value, default=str)
|
user_api_key_dict=user_api_key_dict,
|
||||||
|
litellm_changed_by=litellm_changed_by,
|
||||||
asyncio.create_task(
|
|
||||||
create_audit_log_for_update(
|
|
||||||
request_data=LiteLLM_AuditLogs(
|
|
||||||
id=str(uuid.uuid4()),
|
|
||||||
updated_at=datetime.now(timezone.utc),
|
|
||||||
changed_by=litellm_changed_by
|
|
||||||
or user_api_key_dict.user_id
|
|
||||||
or litellm_proxy_admin_name,
|
|
||||||
changed_by_api_key=user_api_key_dict.api_key,
|
|
||||||
table_name=LitellmTableNames.KEY_TABLE_NAME,
|
|
||||||
object_id=data.key,
|
|
||||||
action="updated",
|
|
||||||
updated_values=_updated_values,
|
|
||||||
before_value=_before_value,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
if response is None:
|
if response is None:
|
||||||
raise ValueError("Failed to update key got response = None")
|
raise ValueError("Failed to update key got response = None")
|
||||||
|
@ -496,6 +446,9 @@ async def delete_key_fn(
|
||||||
user_custom_key_generate,
|
user_custom_key_generate,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if prisma_client is None:
|
||||||
|
raise Exception("Not connected to DB!")
|
||||||
|
|
||||||
keys = data.keys
|
keys = data.keys
|
||||||
if len(keys) == 0:
|
if len(keys) == 0:
|
||||||
raise ProxyException(
|
raise ProxyException(
|
||||||
|
@ -516,45 +469,7 @@ async def delete_key_fn(
|
||||||
):
|
):
|
||||||
user_id = None # unless they're admin
|
user_id = None # unless they're admin
|
||||||
|
|
||||||
# Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
|
number_deleted_keys, _keys_being_deleted = await delete_verification_token(
|
||||||
# we do this after the first for loop, since first for loop is for validation. we only want this inserted after validation passes
|
|
||||||
if litellm.store_audit_logs is True:
|
|
||||||
# make an audit log for each team deleted
|
|
||||||
for key in data.keys:
|
|
||||||
key_row = await prisma_client.get_data( # type: ignore
|
|
||||||
token=key, table_name="key", query_type="find_unique"
|
|
||||||
)
|
|
||||||
|
|
||||||
if key_row is None:
|
|
||||||
raise ProxyException(
|
|
||||||
message=f"Key {key} not found",
|
|
||||||
type=ProxyErrorTypes.bad_request_error,
|
|
||||||
param="key",
|
|
||||||
code=status.HTTP_404_NOT_FOUND,
|
|
||||||
)
|
|
||||||
|
|
||||||
key_row = key_row.json(exclude_none=True)
|
|
||||||
_key_row = json.dumps(key_row, default=str)
|
|
||||||
|
|
||||||
asyncio.create_task(
|
|
||||||
create_audit_log_for_update(
|
|
||||||
request_data=LiteLLM_AuditLogs(
|
|
||||||
id=str(uuid.uuid4()),
|
|
||||||
updated_at=datetime.now(timezone.utc),
|
|
||||||
changed_by=litellm_changed_by
|
|
||||||
or user_api_key_dict.user_id
|
|
||||||
or litellm_proxy_admin_name,
|
|
||||||
changed_by_api_key=user_api_key_dict.api_key,
|
|
||||||
table_name=LitellmTableNames.KEY_TABLE_NAME,
|
|
||||||
object_id=key,
|
|
||||||
action="deleted",
|
|
||||||
updated_values="{}",
|
|
||||||
before_value=_key_row,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
number_deleted_keys = await delete_verification_token(
|
|
||||||
tokens=keys, user_id=user_id
|
tokens=keys, user_id=user_id
|
||||||
)
|
)
|
||||||
if number_deleted_keys is None:
|
if number_deleted_keys is None:
|
||||||
|
@ -588,6 +503,16 @@ async def delete_key_fn(
|
||||||
f"/keys/delete - cache after delete: {user_api_key_cache.in_memory_cache.cache_dict}"
|
f"/keys/delete - cache after delete: {user_api_key_cache.in_memory_cache.cache_dict}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
asyncio.create_task(
|
||||||
|
KeyManagementEventHooks.async_key_deleted_hook(
|
||||||
|
data=data,
|
||||||
|
keys_being_deleted=_keys_being_deleted,
|
||||||
|
user_api_key_dict=user_api_key_dict,
|
||||||
|
litellm_changed_by=litellm_changed_by,
|
||||||
|
response=number_deleted_keys,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return {"deleted_keys": keys}
|
return {"deleted_keys": keys}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if isinstance(e, HTTPException):
|
if isinstance(e, HTTPException):
|
||||||
|
@ -1026,11 +951,35 @@ async def generate_key_helper_fn( # noqa: PLR0915
|
||||||
return key_data
|
return key_data
|
||||||
|
|
||||||
|
|
||||||
async def delete_verification_token(tokens: List, user_id: Optional[str] = None):
|
async def delete_verification_token(
|
||||||
|
tokens: List, user_id: Optional[str] = None
|
||||||
|
) -> Tuple[Optional[Dict], List[LiteLLM_VerificationToken]]:
|
||||||
|
"""
|
||||||
|
Helper that deletes the list of tokens from the database
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tokens: List of tokens to delete
|
||||||
|
user_id: Optional user_id to filter by
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[Optional[Dict], List[LiteLLM_VerificationToken]]:
|
||||||
|
Optional[Dict]:
|
||||||
|
- Number of deleted tokens
|
||||||
|
List[LiteLLM_VerificationToken]:
|
||||||
|
- List of keys being deleted, this contains information about the key_alias, token, and user_id being deleted,
|
||||||
|
this is passed down to the KeyManagementEventHooks to delete the keys from the secret manager and handle audit logs
|
||||||
|
"""
|
||||||
from litellm.proxy.proxy_server import litellm_proxy_admin_name, prisma_client
|
from litellm.proxy.proxy_server import litellm_proxy_admin_name, prisma_client
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if prisma_client:
|
if prisma_client:
|
||||||
|
tokens = [_hash_token_if_needed(token=key) for key in tokens]
|
||||||
|
_keys_being_deleted = (
|
||||||
|
await prisma_client.db.litellm_verificationtoken.find_many(
|
||||||
|
where={"token": {"in": tokens}}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Assuming 'db' is your Prisma Client instance
|
# Assuming 'db' is your Prisma Client instance
|
||||||
# check if admin making request - don't filter by user-id
|
# check if admin making request - don't filter by user-id
|
||||||
if user_id == litellm_proxy_admin_name:
|
if user_id == litellm_proxy_admin_name:
|
||||||
|
@ -1060,7 +1009,7 @@ async def delete_verification_token(tokens: List, user_id: Optional[str] = None)
|
||||||
)
|
)
|
||||||
verbose_proxy_logger.debug(traceback.format_exc())
|
verbose_proxy_logger.debug(traceback.format_exc())
|
||||||
raise e
|
raise e
|
||||||
return deleted_tokens
|
return deleted_tokens, _keys_being_deleted
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
|
|
|
@ -265,7 +265,6 @@ def run_server( # noqa: PLR0915
|
||||||
ProxyConfig,
|
ProxyConfig,
|
||||||
app,
|
app,
|
||||||
load_aws_kms,
|
load_aws_kms,
|
||||||
load_aws_secret_manager,
|
|
||||||
load_from_azure_key_vault,
|
load_from_azure_key_vault,
|
||||||
load_google_kms,
|
load_google_kms,
|
||||||
save_worker_config,
|
save_worker_config,
|
||||||
|
@ -278,7 +277,6 @@ def run_server( # noqa: PLR0915
|
||||||
ProxyConfig,
|
ProxyConfig,
|
||||||
app,
|
app,
|
||||||
load_aws_kms,
|
load_aws_kms,
|
||||||
load_aws_secret_manager,
|
|
||||||
load_from_azure_key_vault,
|
load_from_azure_key_vault,
|
||||||
load_google_kms,
|
load_google_kms,
|
||||||
save_worker_config,
|
save_worker_config,
|
||||||
|
@ -295,7 +293,6 @@ def run_server( # noqa: PLR0915
|
||||||
ProxyConfig,
|
ProxyConfig,
|
||||||
app,
|
app,
|
||||||
load_aws_kms,
|
load_aws_kms,
|
||||||
load_aws_secret_manager,
|
|
||||||
load_from_azure_key_vault,
|
load_from_azure_key_vault,
|
||||||
load_google_kms,
|
load_google_kms,
|
||||||
save_worker_config,
|
save_worker_config,
|
||||||
|
@ -559,8 +556,14 @@ def run_server( # noqa: PLR0915
|
||||||
key_management_system
|
key_management_system
|
||||||
== KeyManagementSystem.AWS_SECRET_MANAGER.value # noqa: F405
|
== KeyManagementSystem.AWS_SECRET_MANAGER.value # noqa: F405
|
||||||
):
|
):
|
||||||
|
from litellm.secret_managers.aws_secret_manager_v2 import (
|
||||||
|
AWSSecretsManagerV2,
|
||||||
|
)
|
||||||
|
|
||||||
### LOAD FROM AWS SECRET MANAGER ###
|
### LOAD FROM AWS SECRET MANAGER ###
|
||||||
load_aws_secret_manager(use_aws_secret_manager=True)
|
AWSSecretsManagerV2.load_aws_secret_manager(
|
||||||
|
use_aws_secret_manager=True
|
||||||
|
)
|
||||||
elif key_management_system == KeyManagementSystem.AWS_KMS.value:
|
elif key_management_system == KeyManagementSystem.AWS_KMS.value:
|
||||||
load_aws_kms(use_aws_kms=True)
|
load_aws_kms(use_aws_kms=True)
|
||||||
elif (
|
elif (
|
||||||
|
|
|
@ -7,6 +7,8 @@ model_list:
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
litellm_settings:
|
general_settings:
|
||||||
callbacks: ["gcs_bucket"]
|
key_management_system: "aws_secret_manager"
|
||||||
|
key_management_settings:
|
||||||
|
store_virtual_keys: true
|
||||||
|
access_mode: "write_only"
|
||||||
|
|
|
@ -245,10 +245,7 @@ from litellm.router import (
|
||||||
from litellm.router import ModelInfo as RouterModelInfo
|
from litellm.router import ModelInfo as RouterModelInfo
|
||||||
from litellm.router import updateDeployment
|
from litellm.router import updateDeployment
|
||||||
from litellm.scheduler import DefaultPriorities, FlowItem, Scheduler
|
from litellm.scheduler import DefaultPriorities, FlowItem, Scheduler
|
||||||
from litellm.secret_managers.aws_secret_manager import (
|
from litellm.secret_managers.aws_secret_manager import load_aws_kms
|
||||||
load_aws_kms,
|
|
||||||
load_aws_secret_manager,
|
|
||||||
)
|
|
||||||
from litellm.secret_managers.google_kms import load_google_kms
|
from litellm.secret_managers.google_kms import load_google_kms
|
||||||
from litellm.secret_managers.main import (
|
from litellm.secret_managers.main import (
|
||||||
get_secret,
|
get_secret,
|
||||||
|
@ -1825,8 +1822,13 @@ class ProxyConfig:
|
||||||
key_management_system
|
key_management_system
|
||||||
== KeyManagementSystem.AWS_SECRET_MANAGER.value # noqa: F405
|
== KeyManagementSystem.AWS_SECRET_MANAGER.value # noqa: F405
|
||||||
):
|
):
|
||||||
### LOAD FROM AWS SECRET MANAGER ###
|
from litellm.secret_managers.aws_secret_manager_v2 import (
|
||||||
load_aws_secret_manager(use_aws_secret_manager=True)
|
AWSSecretsManagerV2,
|
||||||
|
)
|
||||||
|
|
||||||
|
AWSSecretsManagerV2.load_aws_secret_manager(
|
||||||
|
use_aws_secret_manager=True
|
||||||
|
)
|
||||||
elif key_management_system == KeyManagementSystem.AWS_KMS.value:
|
elif key_management_system == KeyManagementSystem.AWS_KMS.value:
|
||||||
load_aws_kms(use_aws_kms=True)
|
load_aws_kms(use_aws_kms=True)
|
||||||
elif (
|
elif (
|
||||||
|
|
|
@ -23,28 +23,6 @@ def validate_environment():
|
||||||
raise ValueError("Missing required environment variable - AWS_REGION_NAME")
|
raise ValueError("Missing required environment variable - AWS_REGION_NAME")
|
||||||
|
|
||||||
|
|
||||||
def load_aws_secret_manager(use_aws_secret_manager: Optional[bool]):
|
|
||||||
if use_aws_secret_manager is None or use_aws_secret_manager is False:
|
|
||||||
return
|
|
||||||
try:
|
|
||||||
import boto3
|
|
||||||
from botocore.exceptions import ClientError
|
|
||||||
|
|
||||||
validate_environment()
|
|
||||||
|
|
||||||
# Create a Secrets Manager client
|
|
||||||
session = boto3.session.Session() # type: ignore
|
|
||||||
client = session.client(
|
|
||||||
service_name="secretsmanager", region_name=os.getenv("AWS_REGION_NAME")
|
|
||||||
)
|
|
||||||
|
|
||||||
litellm.secret_manager_client = client
|
|
||||||
litellm._key_management_system = KeyManagementSystem.AWS_SECRET_MANAGER
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
raise e
|
|
||||||
|
|
||||||
|
|
||||||
def load_aws_kms(use_aws_kms: Optional[bool]):
|
def load_aws_kms(use_aws_kms: Optional[bool]):
|
||||||
if use_aws_kms is None or use_aws_kms is False:
|
if use_aws_kms is None or use_aws_kms is False:
|
||||||
return
|
return
|
||||||
|
|
310
litellm/secret_managers/aws_secret_manager_v2.py
Normal file
310
litellm/secret_managers/aws_secret_manager_v2.py
Normal file
|
@ -0,0 +1,310 @@
|
||||||
|
"""
|
||||||
|
This is a file for the AWS Secret Manager Integration
|
||||||
|
|
||||||
|
Handles Async Operations for:
|
||||||
|
- Read Secret
|
||||||
|
- Write Secret
|
||||||
|
- Delete Secret
|
||||||
|
|
||||||
|
Relevant issue: https://github.com/BerriAI/litellm/issues/1883
|
||||||
|
|
||||||
|
Requires:
|
||||||
|
* `os.environ["AWS_REGION_NAME"]`
|
||||||
|
* `pip install boto3>=1.28.57`
|
||||||
|
"""
|
||||||
|
|
||||||
|
import ast
|
||||||
|
import asyncio
|
||||||
|
import base64
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
from typing import Any, Dict, Optional, Union
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
import litellm
|
||||||
|
from litellm._logging import verbose_logger
|
||||||
|
from litellm.llms.base_aws_llm import BaseAWSLLM
|
||||||
|
from litellm.llms.custom_httpx.http_handler import (
|
||||||
|
_get_httpx_client,
|
||||||
|
get_async_httpx_client,
|
||||||
|
)
|
||||||
|
from litellm.llms.custom_httpx.types import httpxSpecialProvider
|
||||||
|
from litellm.proxy._types import KeyManagementSystem
|
||||||
|
|
||||||
|
|
||||||
|
class AWSSecretsManagerV2(BaseAWSLLM):
    """
    AWS Secrets Manager client that signs requests with SigV4 and sends them via httpx.

    Provides sync/async reads plus async create and delete of secrets.
    """

    @classmethod
    def validate_environment(cls):
        """
        Check that the AWS region is configured.

        Raises:
            ValueError: If `AWS_REGION_NAME` is not present in `os.environ`.
        """
        if "AWS_REGION_NAME" not in os.environ:
            raise ValueError("Missing required environment variable - AWS_REGION_NAME")

    @classmethod
    def load_aws_secret_manager(cls, use_aws_secret_manager: Optional[bool]):
        """
        Install an AWSSecretsManagerV2 instance as the global secret manager.

        Sets `litellm.secret_manager_client = AWSSecretsManagerV2()` and
        `litellm._key_management_system = KeyManagementSystem.AWS_SECRET_MANAGER`.
        Does nothing when `use_aws_secret_manager` is None or False.

        Raises:
            ImportError: If boto3 is not installed.
            ValueError: If `AWS_REGION_NAME` is not set (see `validate_environment`).
        """
        if use_aws_secret_manager is None or use_aws_secret_manager is False:
            return

        # Not used directly here; importing it surfaces a missing dependency
        # at load time instead of on the first request.
        import boto3  # noqa: F401

        cls.validate_environment()
        litellm.secret_manager_client = cls()
        litellm._key_management_system = KeyManagementSystem.AWS_SECRET_MANAGER

    async def async_read_secret(
        self,
        secret_name: str,
        optional_params: Optional[dict] = None,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
    ) -> Optional[str]:
        """
        Async function to read a secret from AWS Secrets Manager

        Args:
            secret_name: Name of the secret to read
            optional_params: Additional AWS parameters
            timeout: Request timeout

        Returns:
            Optional[str]: The `SecretString` of the secret, or None when the
            request fails for any reason other than a timeout (the error is logged).

        Raises:
            ValueError: If the request times out
        """
        endpoint_url, headers, body = self._prepare_request(
            action="GetSecretValue",
            secret_name=secret_name,
            optional_params=optional_params,
        )

        async_client = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.SecretManager,
            params={"timeout": timeout},
        )

        try:
            response = await async_client.post(
                url=endpoint_url, headers=headers, data=body.decode("utf-8")
            )
            response.raise_for_status()
            return response.json()["SecretString"]
        except httpx.TimeoutException:
            raise ValueError("Timeout error occurred")
        except Exception as e:
            # Best-effort read: HTTP errors, missing secrets and malformed
            # responses are logged and reported to the caller as None.
            verbose_logger.exception(
                "Error reading secret from AWS Secrets Manager: %s", str(e)
            )
        return None

    def sync_read_secret(
        self,
        secret_name: str,
        optional_params: Optional[dict] = None,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
    ) -> Optional[str]:
        """
        Sync function to read a secret from AWS Secrets Manager

        Done for backwards compatibility with existing codebase, since get_secret is a sync function

        Args:
            secret_name: Name of the secret to read
            optional_params: Additional AWS parameters
            timeout: Request timeout

        Returns:
            Optional[str]: The `SecretString` of the secret, the matching
            environment variable for the AWS bootstrap names listed below, or
            None when the request fails for any reason other than a timeout.

        Raises:
            ValueError: If the request times out
        """

        # self._prepare_request uses these env vars, we cannot read them from AWS Secrets Manager. If we do we'd get stuck in an infinite loop
        if secret_name in [
            "AWS_ACCESS_KEY_ID",
            "AWS_SECRET_ACCESS_KEY",
            "AWS_REGION_NAME",
            "AWS_REGION",
            "AWS_BEDROCK_RUNTIME_ENDPOINT",
        ]:
            return os.getenv(secret_name)

        endpoint_url, headers, body = self._prepare_request(
            action="GetSecretValue",
            secret_name=secret_name,
            optional_params=optional_params,
        )

        sync_client = _get_httpx_client(
            params={"timeout": timeout},
        )

        try:
            response = sync_client.post(
                url=endpoint_url, headers=headers, data=body.decode("utf-8")
            )
            response.raise_for_status()
            return response.json()["SecretString"]
        except httpx.TimeoutException:
            raise ValueError("Timeout error occurred")
        except Exception as e:
            # Best-effort read, same contract as async_read_secret.
            verbose_logger.exception(
                "Error reading secret from AWS Secrets Manager: %s", str(e)
            )
        return None

    async def async_write_secret(
        self,
        secret_name: str,
        secret_value: str,
        description: Optional[str] = None,
        client_request_token: Optional[str] = None,
        optional_params: Optional[dict] = None,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
    ) -> dict:
        """
        Async function to write a secret to AWS Secrets Manager

        Args:
            secret_name: Name of the secret
            secret_value: Value to store (can be a JSON string)
            description: Optional description for the secret
            client_request_token: Optional unique identifier to ensure idempotency;
                a random UUID is generated when omitted
            optional_params: Additional AWS parameters
            timeout: Request timeout

        Returns:
            dict: Parsed JSON response of the `CreateSecret` call

        Raises:
            ValueError: If AWS returns an HTTP error status or the request times out
        """
        import uuid

        # Prepare the request data
        data = {"Name": secret_name, "SecretString": secret_value}
        if description:
            data["Description"] = description

        # Honor the caller's token so retries of the same logical write stay
        # idempotent; only fall back to a fresh UUID when none was supplied.
        data["ClientRequestToken"] = client_request_token or str(uuid.uuid4())

        endpoint_url, headers, body = self._prepare_request(
            action="CreateSecret",
            secret_name=secret_name,
            secret_value=secret_value,
            optional_params=optional_params,
            request_data=data,  # Pass the complete request data
        )

        async_client = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.SecretManager,
            params={"timeout": timeout},
        )

        try:
            response = await async_client.post(
                url=endpoint_url, headers=headers, data=body.decode("utf-8")
            )
            response.raise_for_status()
            return response.json()
        except httpx.HTTPStatusError as err:
            raise ValueError(f"HTTP error occurred: {err.response.text}")
        except httpx.TimeoutException:
            raise ValueError("Timeout error occurred")

    async def async_delete_secret(
        self,
        secret_name: str,
        recovery_window_in_days: Optional[int] = 7,
        optional_params: Optional[dict] = None,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
    ) -> dict:
        """
        Async function to delete a secret from AWS Secrets Manager

        Args:
            secret_name: Name of the secret to delete
            recovery_window_in_days: Number of days before permanent deletion (default: 7)
            optional_params: Additional AWS parameters
            timeout: Request timeout

        Returns:
            dict: Response from AWS Secrets Manager containing deletion details

        Raises:
            ValueError: If AWS returns an HTTP error status or the request times out
        """
        # Prepare the request data
        data = {
            "SecretId": secret_name,
            "RecoveryWindowInDays": recovery_window_in_days,
        }

        endpoint_url, headers, body = self._prepare_request(
            action="DeleteSecret",
            secret_name=secret_name,
            optional_params=optional_params,
            request_data=data,
        )

        async_client = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.SecretManager,
            params={"timeout": timeout},
        )

        try:
            response = await async_client.post(
                url=endpoint_url, headers=headers, data=body.decode("utf-8")
            )
            response.raise_for_status()
            return response.json()
        except httpx.HTTPStatusError as err:
            raise ValueError(f"HTTP error occurred: {err.response.text}")
        except httpx.TimeoutException:
            raise ValueError("Timeout error occurred")

    def _prepare_request(
        self,
        action: str,  # e.g. "GetSecretValue", "CreateSecret", "DeleteSecret", "PutSecretValue"
        secret_name: str,
        secret_value: Optional[str] = None,
        optional_params: Optional[dict] = None,
        request_data: Optional[dict] = None,
    ) -> tuple[str, Any, bytes]:
        """
        Build and SigV4-sign an AWS Secrets Manager request.

        Args:
            action: Secrets Manager API action, sent as `X-Amz-Target: secretsmanager.<action>`
            secret_name: Secret identifier, used as `SecretId` in the default payload
            secret_value: Added to the default payload as `SecretString` only
                when `action` is "PutSecretValue"
            optional_params: Credential/region overrides, forwarded to
                `_get_boto_credentials_from_optional_params`
            request_data: Complete request payload; when truthy it is sent as-is
                and the default payload is not built

        Returns:
            Tuple of (endpoint URL, signed headers, JSON-encoded body bytes)

        Raises:
            ImportError: If boto3/botocore are not installed
        """
        try:
            import boto3  # noqa: F401  (dependency check only)
            from botocore.auth import SigV4Auth
            from botocore.awsrequest import AWSRequest
        except ImportError:
            raise ImportError(
                "Missing boto3 to call AWS Secrets Manager. Run 'pip install boto3'."
            )
        optional_params = optional_params or {}
        boto3_credentials_info = self._get_boto_credentials_from_optional_params(
            optional_params
        )

        # Get endpoint
        _, endpoint_url = self.get_runtime_endpoint(
            api_base=None,
            aws_bedrock_runtime_endpoint=boto3_credentials_info.aws_bedrock_runtime_endpoint,
            aws_region_name=boto3_credentials_info.aws_region_name,
        )
        # The endpoint helper is called with Bedrock arguments; swap the service
        # name in the URL it returns so the request targets Secrets Manager.
        endpoint_url = endpoint_url.replace("bedrock-runtime", "secretsmanager")

        # Use provided request_data if available, otherwise build default data
        if request_data:
            data = request_data
        else:
            data = {"SecretId": secret_name}
            if secret_value and action == "PutSecretValue":
                data["SecretString"] = secret_value

        body = json.dumps(data).encode("utf-8")
        headers = {
            "Content-Type": "application/x-amz-json-1.1",
            "X-Amz-Target": f"secretsmanager.{action}",
        }

        # Sign request
        request = AWSRequest(
            method="POST", url=endpoint_url, data=body, headers=headers
        )
        SigV4Auth(
            boto3_credentials_info.credentials,
            "secretsmanager",
            boto3_credentials_info.aws_region_name,
        ).add_auth(request)
        prepped = request.prepare()

        return endpoint_url, prepped.headers, body
|
||||||
|
|
||||||
|
|
||||||
|
# if __name__ == "__main__":
|
||||||
|
# print("loading aws secret manager v2")
|
||||||
|
# aws_secret_manager_v2 = AWSSecretsManagerV2()
|
||||||
|
|
||||||
|
# print("writing secret to aws secret manager v2")
|
||||||
|
# asyncio.run(aws_secret_manager_v2.async_write_secret(secret_name="test_secret_3", secret_value="test_value_2"))
|
||||||
|
# print("reading secret from aws secret manager v2")
|
|
@ -5,7 +5,7 @@ import json
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
from typing import Any, Optional, Union
|
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
@ -198,7 +198,10 @@ def get_secret( # noqa: PLR0915
|
||||||
raise ValueError("Unsupported OIDC provider")
|
raise ValueError("Unsupported OIDC provider")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if litellm.secret_manager_client is not None:
|
if (
|
||||||
|
_should_read_secret_from_secret_manager()
|
||||||
|
and litellm.secret_manager_client is not None
|
||||||
|
):
|
||||||
try:
|
try:
|
||||||
client = litellm.secret_manager_client
|
client = litellm.secret_manager_client
|
||||||
key_manager = "local"
|
key_manager = "local"
|
||||||
|
@ -207,7 +210,8 @@ def get_secret( # noqa: PLR0915
|
||||||
|
|
||||||
if key_management_settings is not None:
|
if key_management_settings is not None:
|
||||||
if (
|
if (
|
||||||
secret_name not in key_management_settings.hosted_keys
|
key_management_settings.hosted_keys is not None
|
||||||
|
and secret_name not in key_management_settings.hosted_keys
|
||||||
): # allow user to specify which keys to check in hosted key manager
|
): # allow user to specify which keys to check in hosted key manager
|
||||||
key_manager = "local"
|
key_manager = "local"
|
||||||
|
|
||||||
|
@ -268,25 +272,13 @@ def get_secret( # noqa: PLR0915
|
||||||
if isinstance(secret, str):
|
if isinstance(secret, str):
|
||||||
secret = secret.strip()
|
secret = secret.strip()
|
||||||
elif key_manager == KeyManagementSystem.AWS_SECRET_MANAGER.value:
|
elif key_manager == KeyManagementSystem.AWS_SECRET_MANAGER.value:
|
||||||
try:
|
from litellm.secret_managers.aws_secret_manager_v2 import (
|
||||||
get_secret_value_response = client.get_secret_value(
|
AWSSecretsManagerV2,
|
||||||
SecretId=secret_name
|
)
|
||||||
)
|
|
||||||
print_verbose(
|
|
||||||
f"get_secret_value_response: {get_secret_value_response}"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
print_verbose(f"An error occurred - {str(e)}")
|
|
||||||
# For a list of exceptions thrown, see
|
|
||||||
# https://docs.aws.amazon.com/secretsmanager/latest/apireference/API_GetSecretValue.html
|
|
||||||
raise e
|
|
||||||
|
|
||||||
# assume there is 1 secret per secret_name
|
if isinstance(client, AWSSecretsManagerV2):
|
||||||
secret_dict = json.loads(get_secret_value_response["SecretString"])
|
secret = client.sync_read_secret(secret_name=secret_name)
|
||||||
print_verbose(f"secret_dict: {secret_dict}")
|
print_verbose(f"get_secret_value_response: {secret}")
|
||||||
for k, v in secret_dict.items():
|
|
||||||
secret = v
|
|
||||||
print_verbose(f"secret: {secret}")
|
|
||||||
elif key_manager == KeyManagementSystem.GOOGLE_SECRET_MANAGER.value:
|
elif key_manager == KeyManagementSystem.GOOGLE_SECRET_MANAGER.value:
|
||||||
try:
|
try:
|
||||||
secret = client.get_secret_from_google_secret_manager(
|
secret = client.get_secret_from_google_secret_manager(
|
||||||
|
@ -332,3 +324,21 @@ def get_secret( # noqa: PLR0915
|
||||||
return default_value
|
return default_value
|
||||||
else:
|
else:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
|
||||||
|
def _should_read_secret_from_secret_manager() -> bool:
    """
    Decide whether secrets may be read from the configured secret manager.

    Returns True only when all of the following hold, and False otherwise:
    - `litellm.secret_manager_client` is set
    - `litellm._key_management_settings` is set
    - its `access_mode` is "read_only" or "read_and_write"
    """
    if litellm.secret_manager_client is None:
        return False

    settings = litellm._key_management_settings
    if settings is None:
        return False

    return settings.access_mode in ("read_only", "read_and_write")
|
||||||
|
|
139
tests/local_testing/test_aws_secret_manager.py
Normal file
139
tests/local_testing/test_aws_secret_manager.py
Normal file
|
@ -0,0 +1,139 @@
|
||||||
|
# What is this?
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
import litellm.types
|
||||||
|
import litellm.types.utils
|
||||||
|
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
import io
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Ensure the project root is in the Python path
|
||||||
|
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..")))
|
||||||
|
|
||||||
|
print("Python Path:", sys.path)
|
||||||
|
print("Current Working Directory:", os.getcwd())
|
||||||
|
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import uuid
|
||||||
|
import json
|
||||||
|
from litellm.secret_managers.aws_secret_manager_v2 import AWSSecretsManagerV2
|
||||||
|
|
||||||
|
|
||||||
|
def check_aws_credentials():
    """Skip the calling test unless every required AWS environment variable is set and non-empty."""
    unset = []
    for name in ("AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_REGION_NAME"):
        if not os.getenv(name):
            unset.append(name)

    if not unset:
        return
    pytest.skip(f"Missing required AWS credentials: {', '.join(unset)}")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_write_and_read_simple_secret():
    """Round-trip a plain string secret: create it, read it back, then delete it."""
    check_aws_credentials()

    manager = AWSSecretsManagerV2()
    name = f"litellm_test_{uuid.uuid4().hex[:8]}"
    value = "test_value_123"

    try:
        # Create the secret and check the fields echoed back in the response
        created = await manager.async_write_secret(
            secret_name=name,
            secret_value=value,
            description="LiteLLM Test Secret",
        )
        print("Write Response:", created)

        assert created is not None
        assert "ARN" in created
        assert "Name" in created
        assert created["Name"] == name

        # The value read back must match what was written
        fetched = await manager.async_read_secret(secret_name=name)
        print("Read Value:", fetched)

        assert fetched == value
    finally:
        # Always remove the secret so test runs do not accumulate in AWS
        deleted = await manager.async_delete_secret(secret_name=name)
        print("Delete Response:", deleted)
        assert deleted is not None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_write_and_read_json_secret():
    """Test writing and reading a JSON structured secret"""
    check_aws_credentials()

    secret_manager = AWSSecretsManagerV2()
    test_secret_name = f"litellm_test_{uuid.uuid4().hex[:8]}_json"
    test_secret_value = {
        "api_key": "test_key",
        "model": "gpt-4",
        "temperature": 0.7,
        "metadata": {"team": "ml", "project": "litellm"},
    }

    try:
        # Write JSON secret
        write_response = await secret_manager.async_write_secret(
            secret_name=test_secret_name,
            secret_value=json.dumps(test_secret_value),
            description="LiteLLM JSON Test Secret",
        )

        print("Write Response:", write_response)

        # Read and parse JSON secret
        read_value = await secret_manager.async_read_secret(
            secret_name=test_secret_name
        )

        print("Read Value:", read_value)

        # async_read_secret returns None on failure; assert explicitly so the
        # test fails with a clear message rather than a TypeError in json.loads.
        assert read_value is not None
        parsed_value = json.loads(read_value)

        assert parsed_value == test_secret_value
        assert parsed_value["api_key"] == "test_key"
        assert parsed_value["metadata"]["team"] == "ml"
    finally:
        # Cleanup: Delete the secret
        delete_response = await secret_manager.async_delete_secret(
            secret_name=test_secret_name
        )
        print("Delete Response:", delete_response)
        assert delete_response is not None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_read_nonexistent_secret():
    """Reading a secret name that was never created must yield None."""
    check_aws_credentials()

    # A full random hex suffix makes a name collision with a real secret implausible
    missing_name = f"litellm_nonexistent_{uuid.uuid4().hex}"
    manager = AWSSecretsManagerV2()

    result = await manager.async_read_secret(secret_name=missing_name)

    assert result is None
|
|
@ -24,7 +24,7 @@ from litellm import RateLimitError, Timeout, completion, completion_cost, embedd
|
||||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||||
from litellm.llms.prompt_templates.factory import anthropic_messages_pt
|
from litellm.llms.prompt_templates.factory import anthropic_messages_pt
|
||||||
|
|
||||||
# litellm.num_retries = 3
|
# litellm.num_retries=3
|
||||||
|
|
||||||
litellm.cache = None
|
litellm.cache = None
|
||||||
litellm.success_callback = []
|
litellm.success_callback = []
|
||||||
|
|
|
@ -15,22 +15,29 @@ sys.path.insert(
|
||||||
0, os.path.abspath("../..")
|
0, os.path.abspath("../..")
|
||||||
) # Adds the parent directory to the system path
|
) # Adds the parent directory to the system path
|
||||||
import pytest
|
import pytest
|
||||||
|
import litellm
|
||||||
from litellm.llms.AzureOpenAI.azure import get_azure_ad_token_from_oidc
|
from litellm.llms.AzureOpenAI.azure import get_azure_ad_token_from_oidc
|
||||||
from litellm.llms.bedrock.chat import BedrockConverseLLM, BedrockLLM
|
from litellm.llms.bedrock.chat import BedrockConverseLLM, BedrockLLM
|
||||||
from litellm.secret_managers.aws_secret_manager import load_aws_secret_manager
|
from litellm.secret_managers.aws_secret_manager_v2 import AWSSecretsManagerV2
|
||||||
from litellm.secret_managers.main import get_secret
|
from litellm.secret_managers.main import (
|
||||||
|
get_secret,
|
||||||
|
_should_read_secret_from_secret_manager,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(reason="AWS Suspended Account")
|
|
||||||
def test_aws_secret_manager():
|
def test_aws_secret_manager():
|
||||||
load_aws_secret_manager(use_aws_secret_manager=True)
|
import json
|
||||||
|
|
||||||
|
AWSSecretsManagerV2.load_aws_secret_manager(use_aws_secret_manager=True)
|
||||||
|
|
||||||
secret_val = get_secret("litellm_master_key")
|
secret_val = get_secret("litellm_master_key")
|
||||||
|
|
||||||
print(f"secret_val: {secret_val}")
|
print(f"secret_val: {secret_val}")
|
||||||
|
|
||||||
assert secret_val == "sk-1234"
|
# cast json to dict
|
||||||
|
secret_val = json.loads(secret_val)
|
||||||
|
|
||||||
|
assert secret_val["litellm_master_key"] == "sk-1234"
|
||||||
|
|
||||||
|
|
||||||
def redact_oidc_signature(secret_val):
|
def redact_oidc_signature(secret_val):
|
||||||
|
@ -240,3 +247,71 @@ def test_google_secret_manager_read_in_memory():
|
||||||
)
|
)
|
||||||
print("secret_val: {}".format(secret_val))
|
print("secret_val: {}".format(secret_val))
|
||||||
assert secret_val == "lite-llm"
|
assert secret_val == "lite-llm"
|
||||||
|
|
||||||
|
|
||||||
|
def test_should_read_secret_from_secret_manager():
    """
    Test that _should_read_secret_from_secret_manager returns correct values based on access mode
    """
    from litellm.proxy._types import KeyManagementSettings

    try:
        # Test when secret manager client is None
        litellm.secret_manager_client = None
        litellm._key_management_settings = KeyManagementSettings()
        assert _should_read_secret_from_secret_manager() is False

        # Test with secret manager client and read_only access
        litellm.secret_manager_client = "dummy_client"
        litellm._key_management_settings = KeyManagementSettings(
            access_mode="read_only"
        )
        assert _should_read_secret_from_secret_manager() is True

        # Test with secret manager client and read_and_write access
        litellm._key_management_settings = KeyManagementSettings(
            access_mode="read_and_write"
        )
        assert _should_read_secret_from_secret_manager() is True

        # Test with secret manager client and write_only access
        litellm._key_management_settings = KeyManagementSettings(
            access_mode="write_only"
        )
        assert _should_read_secret_from_secret_manager() is False
    finally:
        # Reset global variables even when an assertion fails, so the dummy
        # client does not leak into other tests.
        litellm.secret_manager_client = None
        litellm._key_management_settings = KeyManagementSettings()
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_secret_with_access_mode():
    """
    Test that get_secret respects access mode settings
    """
    from litellm.proxy._types import KeyManagementSettings

    # Set up test environment
    test_secret_name = "TEST_SECRET_KEY"
    test_secret_value = "test_secret_value"
    os.environ[test_secret_name] = test_secret_value

    try:
        # Test with write_only access (should read from os.environ)
        litellm.secret_manager_client = "dummy_client"
        litellm._key_management_settings = KeyManagementSettings(
            access_mode="write_only"
        )
        assert get_secret(test_secret_name) == test_secret_value

        # Test with default KeyManagementSettings and secret_manager_client set
        litellm.secret_manager_client = "dummy_client"
        litellm._key_management_settings = KeyManagementSettings()
        assert _should_read_secret_from_secret_manager() is True

        # Test with read_only access
        litellm._key_management_settings = KeyManagementSettings(
            access_mode="read_only"
        )
        assert _should_read_secret_from_secret_manager() is True

        # Test with read_and_write access
        litellm._key_management_settings = KeyManagementSettings(
            access_mode="read_and_write"
        )
        assert _should_read_secret_from_secret_manager() is True
    finally:
        # Reset global variables and the env var even when an assertion fails,
        # so state does not leak into other tests.
        litellm.secret_manager_client = None
        litellm._key_management_settings = KeyManagementSettings()
        os.environ.pop(test_secret_name, None)
|
||||||
|
|
|
@ -3451,3 +3451,90 @@ async def test_user_api_key_auth_db_unavailable_not_allowed():
|
||||||
request=request,
|
request=request,
|
||||||
api_key="Bearer sk-123456789",
|
api_key="Bearer sk-123456789",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
## E2E Virtual Key + Secret Manager Tests #########################################
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_key_generate_with_secret_manager_call(prisma_client):
    """
    E2E test: virtual keys are written to / removed from AWS Secrets Manager.

    Steps:
    1. Generate a key
    2. Assert it exists in the secret manager
    3. Delete the key
    4. Assert it is deleted from the secret manager

    The litellm secret-manager globals and the proxy `general_settings` that
    this test overrides are restored in a `finally` block, so a failed
    assertion cannot leak the AWS configuration into later tests.
    """
    from litellm.secret_managers.aws_secret_manager_v2 import AWSSecretsManagerV2
    from litellm.proxy._types import KeyManagementSystem, KeyManagementSettings

    litellm.set_verbose = True

    # Remember the global state we are about to override so it can be restored.
    previous_secret_manager_client = getattr(litellm, "secret_manager_client", None)
    previous_key_management_system = getattr(litellm, "_key_management_system", None)
    previous_key_management_settings = getattr(
        litellm, "_key_management_settings", None
    )

    try:
        #### Test Setup ########################################################
        aws_secret_manager_client = AWSSecretsManagerV2()
        litellm.secret_manager_client = aws_secret_manager_client
        litellm._key_management_system = KeyManagementSystem.AWS_SECRET_MANAGER
        litellm._key_management_settings = KeyManagementSettings(
            store_virtual_keys=True,
        )
        general_settings = {
            "key_management_system": "aws_secret_manager",
            "key_management_settings": {
                "store_virtual_keys": True,
            },
        }

        setattr(litellm.proxy.proxy_server, "general_settings", general_settings)
        setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
        setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
        await litellm.proxy.proxy_server.prisma_client.connect()
        ########################################################################

        # generate new key
        key_alias = f"test_alias_secret_manager_key-{uuid.uuid4()}"
        spend = 100
        max_budget = 400
        models = ["fake-openai-endpoint"]
        new_key = await generate_key_fn(
            data=GenerateKeyRequest(
                key_alias=key_alias, spend=spend, max_budget=max_budget, models=models
            ),
            user_api_key_dict=UserAPIKeyAuth(
                user_role=LitellmUserRoles.PROXY_ADMIN,
                api_key="sk-1234",
                user_id="1234",
            ),
        )

        generated_key = new_key.key
        print(generated_key)

        # NOTE(review): the secret-manager write presumably runs as a
        # background task after /key/generate returns, hence the wait.
        await asyncio.sleep(2)

        # read from the secret manager
        result = await aws_secret_manager_client.async_read_secret(
            secret_name=key_alias
        )

        # Assert the correct key is stored in the secret manager
        print("response from AWS Secret Manager")
        print(result)
        assert result == generated_key

        # delete the key
        await delete_key_fn(
            data=KeyRequest(keys=[generated_key]),
            user_api_key_dict=UserAPIKeyAuth(
                user_role=LitellmUserRoles.PROXY_ADMIN,
                api_key="sk-1234",
                user_id="1234",
            ),
        )

        # Same reasoning as above: give the delete hook time to complete.
        await asyncio.sleep(2)

        # Assert the key is deleted from the secret manager
        result = await aws_secret_manager_client.async_read_secret(
            secret_name=key_alias
        )
        assert result is None
    finally:
        # cleanup - always runs, even when an assertion above fails
        setattr(litellm.proxy.proxy_server, "general_settings", {})
        litellm.secret_manager_client = previous_secret_manager_client
        litellm._key_management_system = previous_key_management_system
        litellm._key_management_settings = previous_key_management_settings
|
||||||
|
|
||||||
|
|
||||||
|
################################################################################
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue