diff --git a/docs/my-website/docs/secret.md b/docs/my-website/docs/secret.md index db5ec6910..15480ea3d 100644 --- a/docs/my-website/docs/secret.md +++ b/docs/my-website/docs/secret.md @@ -1,3 +1,6 @@ +import Tabs from '@theme/Tabs'; +import TabItem from '@theme/TabItem'; + # Secret Manager LiteLLM supports reading secrets from Azure Key Vault, Google Secret Manager @@ -59,14 +62,35 @@ os.environ["AWS_REGION_NAME"] = "" # us-east-1, us-east-2, us-west-1, us-west-2 ``` 2. Enable AWS Secret Manager in config. + + + + ```yaml general_settings: master_key: os.environ/litellm_master_key key_management_system: "aws_secret_manager" # 👈 KEY CHANGE key_management_settings: hosted_keys: ["litellm_master_key"] # 👈 Specify which env keys you stored on AWS + ``` + + + + +This will only store virtual keys in AWS Secret Manager. No keys will be read from AWS Secret Manager. + +```yaml +general_settings: + key_management_system: "aws_secret_manager" # 👈 KEY CHANGE + key_management_settings: + store_virtual_keys: true + access_mode: "write_only" # Literal["read_only", "write_only", "read_and_write"] +``` + + + 3. Run proxy ```bash @@ -181,16 +205,14 @@ litellm --config /path/to/config.yaml Use encrypted keys from Google KMS on the proxy -### Usage with LiteLLM Proxy Server - -## Step 1. Add keys to env +Step 1. Add keys to env ``` export GOOGLE_APPLICATION_CREDENTIALS="/path/to/credentials.json" export GOOGLE_KMS_RESOURCE_NAME="projects/*/locations/*/keyRings/*/cryptoKeys/*" export PROXY_DATABASE_URL_ENCRYPTED=b'\n$\x00D\xac\xb4/\x8e\xc...' ``` -## Step 2: Update Config +Step 2: Update Config ```yaml general_settings: @@ -199,7 +221,7 @@ general_settings: master_key: sk-1234 ``` -## Step 3: Start + test proxy +Step 3: Start + test proxy ``` $ litellm --config /path/to/config.yaml @@ -215,3 +237,17 @@ $ litellm --test + + +## All Secret Manager Settings + +All settings related to secret management + +```yaml +general_settings: + key_management_system: "aws_secret_manager" # REQUIRED + key_management_settings: + store_virtual_keys: true # OPTIONAL. Defaults to False, when True will store virtual keys in secret manager + access_mode: "write_only" # OPTIONAL. Literal["read_only", "write_only", "read_and_write"]. Defaults to "read_only" + hosted_keys: ["litellm_master_key"] # OPTIONAL. Specify which env keys you stored on AWS +``` \ No newline at end of file diff --git a/litellm/__init__.py b/litellm/__init__.py index 9812de1d8..5fdc9d0fc 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -304,7 +304,7 @@ secret_manager_client: Optional[Any] = ( ) _google_kms_resource_name: Optional[str] = None _key_management_system: Optional[KeyManagementSystem] = None -_key_management_settings: Optional[KeyManagementSettings] = None +_key_management_settings: KeyManagementSettings = KeyManagementSettings() #### PII MASKING #### output_parse_pii: bool = False ############################################# diff --git a/litellm/llms/custom_httpx/types.py b/litellm/llms/custom_httpx/types.py index dc0958118..8e6ad0eda 100644 --- a/litellm/llms/custom_httpx/types.py +++ b/litellm/llms/custom_httpx/types.py @@ -8,3 +8,4 @@ class httpxSpecialProvider(str, Enum): GuardrailCallback = "guardrail_callback" Caching = "caching" Oauth2Check = "oauth2_check" + SecretManager = "secret_manager" diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 2d869af85..4baf13b61 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -1128,7 +1128,16 @@ class KeyManagementSystem(enum.Enum): class KeyManagementSettings(LiteLLMBase): - hosted_keys: List + hosted_keys: Optional[List] = None + store_virtual_keys: Optional[bool] = False + """ + If True, virtual keys created by litellm will be stored in the secret manager + """ + + access_mode: Literal["read_only", "write_only", "read_and_write"] = "read_only" + """ + Access mode for the secret manager, when write_only will only use for writing secrets + """ class TeamDefaultSettings(LiteLLMBase): diff --git a/litellm/proxy/hooks/key_management_event_hooks.py b/litellm/proxy/hooks/key_management_event_hooks.py new file mode 100644 index 000000000..08645a468 --- /dev/null +++ b/litellm/proxy/hooks/key_management_event_hooks.py @@ -0,0 +1,267 @@ +import asyncio +import json +import uuid +from datetime import datetime, timezone +from re import A +from typing import Any, List, Optional + +from fastapi import status + +import litellm +from litellm._logging import verbose_proxy_logger +from litellm.proxy._types import ( + GenerateKeyRequest, + KeyManagementSystem, + KeyRequest, + LiteLLM_AuditLogs, + LiteLLM_VerificationToken, + LitellmTableNames, + ProxyErrorTypes, + ProxyException, + UpdateKeyRequest, + UserAPIKeyAuth, + WebhookEvent, +) + + +class KeyManagementEventHooks: + + @staticmethod + async def async_key_generated_hook( + data: GenerateKeyRequest, + response: dict, + user_api_key_dict: UserAPIKeyAuth, + litellm_changed_by: Optional[str] = None, + ): + """ + Hook that runs after a successful /key/generate request + + Handles the following: + - Sending Email with Key Details + - Storing Audit Logs for key generation + - Storing Generated Key in DB + """ + from litellm.proxy.management_helpers.audit_logs import ( + create_audit_log_for_update, + ) + from litellm.proxy.proxy_server import ( + general_settings, + litellm_proxy_admin_name, + proxy_logging_obj, + ) + + if data.send_invite_email is True: + await KeyManagementEventHooks._send_key_created_email(response) + + # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True + if litellm.store_audit_logs is True: + _updated_values = json.dumps(response, default=str) + asyncio.create_task( + create_audit_log_for_update( + request_data=LiteLLM_AuditLogs( + id=str(uuid.uuid4()), + updated_at=datetime.now(timezone.utc), + changed_by=litellm_changed_by + or user_api_key_dict.user_id + or litellm_proxy_admin_name, + changed_by_api_key=user_api_key_dict.api_key, + table_name=LitellmTableNames.KEY_TABLE_NAME, + object_id=response.get("token_id", ""), + action="created", + updated_values=_updated_values, + before_value=None, + ) + ) + ) + # store the generated key in the secret manager + await KeyManagementEventHooks._store_virtual_key_in_secret_manager( + secret_name=data.key_alias or f"virtual-key-{uuid.uuid4()}", + secret_token=response.get("token", ""), + ) + + @staticmethod + async def async_key_updated_hook( + data: UpdateKeyRequest, + existing_key_row: Any, + response: Any, + user_api_key_dict: UserAPIKeyAuth, + litellm_changed_by: Optional[str] = None, + ): + """ + Post /key/update processing hook + + Handles the following: + - Storing Audit Logs for key update + """ + from litellm.proxy.management_helpers.audit_logs import ( + create_audit_log_for_update, + ) + from litellm.proxy.proxy_server import litellm_proxy_admin_name + + # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True + if litellm.store_audit_logs is True: + _updated_values = json.dumps(data.json(exclude_none=True), default=str) + + _before_value = existing_key_row.json(exclude_none=True) + _before_value = json.dumps(_before_value, default=str) + + asyncio.create_task( + create_audit_log_for_update( + request_data=LiteLLM_AuditLogs( + id=str(uuid.uuid4()), + updated_at=datetime.now(timezone.utc), + changed_by=litellm_changed_by + or user_api_key_dict.user_id + or litellm_proxy_admin_name, + changed_by_api_key=user_api_key_dict.api_key, + table_name=LitellmTableNames.KEY_TABLE_NAME, + object_id=data.key, + action="updated", + updated_values=_updated_values, + before_value=_before_value, + ) + ) + ) + pass + + @staticmethod + async def async_key_deleted_hook( + data: KeyRequest, + keys_being_deleted: List[LiteLLM_VerificationToken], + response: dict, + user_api_key_dict: UserAPIKeyAuth, + litellm_changed_by: Optional[str] = None, + ): + """ + Post /key/delete processing hook + + Handles the following: + - Storing Audit Logs for key deletion + """ + from litellm.proxy.management_helpers.audit_logs import ( + create_audit_log_for_update, + ) + from litellm.proxy.proxy_server import litellm_proxy_admin_name, prisma_client + + # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True + # we do this after the first for loop, since first for loop is for validation. we only want this inserted after validation passes + if litellm.store_audit_logs is True: + # make an audit log for each team deleted + for key in data.keys: + key_row = await prisma_client.get_data( # type: ignore + token=key, table_name="key", query_type="find_unique" + ) + + if key_row is None: + raise ProxyException( + message=f"Key {key} not found", + type=ProxyErrorTypes.bad_request_error, + param="key", + code=status.HTTP_404_NOT_FOUND, + ) + + key_row = key_row.json(exclude_none=True) + _key_row = json.dumps(key_row, default=str) + + asyncio.create_task( + create_audit_log_for_update( + request_data=LiteLLM_AuditLogs( + id=str(uuid.uuid4()), + updated_at=datetime.now(timezone.utc), + changed_by=litellm_changed_by + or user_api_key_dict.user_id + or litellm_proxy_admin_name, + changed_by_api_key=user_api_key_dict.api_key, + table_name=LitellmTableNames.KEY_TABLE_NAME, + object_id=key, + action="deleted", + updated_values="{}", + before_value=_key_row, + ) + ) + ) + # delete the keys from the secret manager + await KeyManagementEventHooks._delete_virtual_keys_from_secret_manager( + keys_being_deleted=keys_being_deleted + ) + pass + + @staticmethod + async def _store_virtual_key_in_secret_manager(secret_name: str, secret_token: str): + """ + Store a virtual key in the secret manager + + Args: + secret_name: Name of the virtual key + secret_token: Value of the virtual key (example: sk-1234) + """ + if litellm._key_management_settings is not None: + if litellm._key_management_settings.store_virtual_keys is True: + from litellm.secret_managers.aws_secret_manager_v2 import ( + AWSSecretsManagerV2, + ) + + # store the key in the secret manager + if ( + litellm._key_management_system + == KeyManagementSystem.AWS_SECRET_MANAGER + and isinstance(litellm.secret_manager_client, AWSSecretsManagerV2) + ): + await litellm.secret_manager_client.async_write_secret( + secret_name=secret_name, + secret_value=secret_token, + ) + + @staticmethod + async def _delete_virtual_keys_from_secret_manager( + keys_being_deleted: List[LiteLLM_VerificationToken], + ): + """ + Deletes virtual keys from the secret manager + + Args: + keys_being_deleted: List of keys being deleted, this is passed down from the /key/delete operation + """ + if litellm._key_management_settings is not None: + if litellm._key_management_settings.store_virtual_keys is True: + from litellm.secret_managers.aws_secret_manager_v2 import ( + AWSSecretsManagerV2, + ) + + if isinstance(litellm.secret_manager_client, AWSSecretsManagerV2): + for key in keys_being_deleted: + if key.key_alias is not None: + await litellm.secret_manager_client.async_delete_secret( + secret_name=key.key_alias + ) + else: + verbose_proxy_logger.warning( + f"KeyManagementEventHooks._delete_virtual_key_from_secret_manager: Key alias not found for key {key.token}. Skipping deletion from secret manager." + ) + + @staticmethod + async def _send_key_created_email(response: dict): + from litellm.proxy.proxy_server import general_settings, proxy_logging_obj + + if "email" not in general_settings.get("alerting", []): + raise ValueError( + "Email alerting not setup on config.yaml. Please set `alerting=['email']. \nDocs: https://docs.litellm.ai/docs/proxy/email`" + ) + event = WebhookEvent( + event="key_created", + event_group="key", + event_message="API Key Created", + token=response.get("token", ""), + spend=response.get("spend", 0.0), + max_budget=response.get("max_budget", 0.0), + user_id=response.get("user_id", None), + team_id=response.get("team_id", "Default Team"), + key_alias=response.get("key_alias", None), + ) + + # If user configured email alerting - send an Email letting their end-user know the key was created + asyncio.create_task( + proxy_logging_obj.slack_alerting_instance.send_key_created_or_user_invited_email( + webhook_event=event, + ) + ) diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py index 01baa5a43..e38236e9b 100644 --- a/litellm/proxy/management_endpoints/key_management_endpoints.py +++ b/litellm/proxy/management_endpoints/key_management_endpoints.py @@ -17,7 +17,7 @@ import secrets import traceback import uuid from datetime import datetime, timedelta, timezone -from typing import List, Optional +from typing import List, Optional, Tuple import fastapi from fastapi import APIRouter, Depends, Header, HTTPException, Query, Request, status @@ -31,6 +31,7 @@ from litellm.proxy.auth.auth_checks import ( get_key_object, ) from litellm.proxy.auth.user_api_key_auth import user_api_key_auth +from litellm.proxy.hooks.key_management_event_hooks import KeyManagementEventHooks from litellm.proxy.management_helpers.utils import management_endpoint_wrapper from litellm.proxy.utils import _duration_in_seconds, _hash_token_if_needed from litellm.secret_managers.main import get_secret @@ -234,50 +235,14 @@ async def generate_key_fn( # noqa: PLR0915 data.soft_budget ) # include the user-input soft budget in the response - if data.send_invite_email is True: - if "email" not in general_settings.get("alerting", []): - raise ValueError( - "Email alerting not setup on config.yaml. Please set `alerting=['email']. \nDocs: https://docs.litellm.ai/docs/proxy/email`" - ) - event = WebhookEvent( - event="key_created", - event_group="key", - event_message="API Key Created", - token=response.get("token", ""), - spend=response.get("spend", 0.0), - max_budget=response.get("max_budget", 0.0), - user_id=response.get("user_id", None), - team_id=response.get("team_id", "Default Team"), - key_alias=response.get("key_alias", None), - ) - - # If user configured email alerting - send an Email letting their end-user know the key was created - asyncio.create_task( - proxy_logging_obj.slack_alerting_instance.send_key_created_or_user_invited_email( - webhook_event=event, - ) - ) - - # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True - if litellm.store_audit_logs is True: - _updated_values = json.dumps(response, default=str) - asyncio.create_task( - create_audit_log_for_update( - request_data=LiteLLM_AuditLogs( - id=str(uuid.uuid4()), - updated_at=datetime.now(timezone.utc), - changed_by=litellm_changed_by - or user_api_key_dict.user_id - or litellm_proxy_admin_name, - changed_by_api_key=user_api_key_dict.api_key, - table_name=LitellmTableNames.KEY_TABLE_NAME, - object_id=response.get("token_id", ""), - action="created", - updated_values=_updated_values, - before_value=None, - ) - ) + asyncio.create_task( + KeyManagementEventHooks.async_key_generated_hook( + data=data, + response=response, + user_api_key_dict=user_api_key_dict, + litellm_changed_by=litellm_changed_by, ) + ) return GenerateKeyResponse(**response) except Exception as e: @@ -407,30 +372,15 @@ async def update_key_fn( proxy_logging_obj=proxy_logging_obj, ) - # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True - if litellm.store_audit_logs is True: - _updated_values = json.dumps(data_json, default=str) - - _before_value = existing_key_row.json(exclude_none=True) - _before_value = json.dumps(_before_value, default=str) - - asyncio.create_task( - create_audit_log_for_update( - request_data=LiteLLM_AuditLogs( - id=str(uuid.uuid4()), - updated_at=datetime.now(timezone.utc), - changed_by=litellm_changed_by - or user_api_key_dict.user_id - or litellm_proxy_admin_name, - changed_by_api_key=user_api_key_dict.api_key, - table_name=LitellmTableNames.KEY_TABLE_NAME, - object_id=data.key, - action="updated", - updated_values=_updated_values, - before_value=_before_value, - ) - ) + asyncio.create_task( + KeyManagementEventHooks.async_key_updated_hook( + data=data, + existing_key_row=existing_key_row, + response=response, + user_api_key_dict=user_api_key_dict, + litellm_changed_by=litellm_changed_by, ) + ) if response is None: raise ValueError("Failed to update key got response = None") @@ -496,6 +446,9 @@ async def delete_key_fn( user_custom_key_generate, ) + if prisma_client is None: + raise Exception("Not connected to DB!") + keys = data.keys if len(keys) == 0: raise ProxyException( @@ -516,45 +469,7 @@ async def delete_key_fn( ): user_id = None # unless they're admin - # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True - # we do this after the first for loop, since first for loop is for validation. we only want this inserted after validation passes - if litellm.store_audit_logs is True: - # make an audit log for each team deleted - for key in data.keys: - key_row = await prisma_client.get_data( # type: ignore - token=key, table_name="key", query_type="find_unique" - ) - - if key_row is None: - raise ProxyException( - message=f"Key {key} not found", - type=ProxyErrorTypes.bad_request_error, - param="key", - code=status.HTTP_404_NOT_FOUND, - ) - - key_row = key_row.json(exclude_none=True) - _key_row = json.dumps(key_row, default=str) - - asyncio.create_task( - create_audit_log_for_update( - request_data=LiteLLM_AuditLogs( - id=str(uuid.uuid4()), - updated_at=datetime.now(timezone.utc), - changed_by=litellm_changed_by - or user_api_key_dict.user_id - or litellm_proxy_admin_name, - changed_by_api_key=user_api_key_dict.api_key, - table_name=LitellmTableNames.KEY_TABLE_NAME, - object_id=key, - action="deleted", - updated_values="{}", - before_value=_key_row, - ) - ) - ) - - number_deleted_keys = await delete_verification_token( + number_deleted_keys, _keys_being_deleted = await delete_verification_token( tokens=keys, user_id=user_id ) if number_deleted_keys is None: @@ -588,6 +503,16 @@ async def delete_key_fn( f"/keys/delete - cache after delete: {user_api_key_cache.in_memory_cache.cache_dict}" ) + asyncio.create_task( + KeyManagementEventHooks.async_key_deleted_hook( + data=data, + keys_being_deleted=_keys_being_deleted, + user_api_key_dict=user_api_key_dict, + litellm_changed_by=litellm_changed_by, + response=number_deleted_keys, + ) + ) + return {"deleted_keys": keys} except Exception as e: if isinstance(e, HTTPException): @@ -1026,11 +951,35 @@ async def generate_key_helper_fn( # noqa: PLR0915 return key_data -async def delete_verification_token(tokens: List, user_id: Optional[str] = None): +async def delete_verification_token( + tokens: List, user_id: Optional[str] = None +) -> Tuple[Optional[Dict], List[LiteLLM_VerificationToken]]: + """ + Helper that deletes the list of tokens from the database + + Args: + tokens: List of tokens to delete + user_id: Optional user_id to filter by + + Returns: + Tuple[Optional[Dict], List[LiteLLM_VerificationToken]]: + Optional[Dict]: + - Number of deleted tokens + List[LiteLLM_VerificationToken]: + - List of keys being deleted, this contains information about the key_alias, token, and user_id being deleted, + this is passed down to the KeyManagementEventHooks to delete the keys from the secret manager and handle audit logs + """ from litellm.proxy.proxy_server import litellm_proxy_admin_name, prisma_client try: if prisma_client: + tokens = [_hash_token_if_needed(token=key) for key in tokens] + _keys_being_deleted = ( + await prisma_client.db.litellm_verificationtoken.find_many( + where={"token": {"in": tokens}} + ) + ) + # Assuming 'db' is your Prisma Client instance # check if admin making request - don't filter by user-id if user_id == litellm_proxy_admin_name: @@ -1060,7 +1009,7 @@ async def delete_verification_token(tokens: List, user_id: Optional[str] = None) ) verbose_proxy_logger.debug(traceback.format_exc()) raise e - return deleted_tokens + return deleted_tokens, _keys_being_deleted @router.post( diff --git a/litellm/proxy/proxy_cli.py b/litellm/proxy/proxy_cli.py index f9f8276c7..094828de1 100644 --- a/litellm/proxy/proxy_cli.py +++ b/litellm/proxy/proxy_cli.py @@ -265,7 +265,6 @@ def run_server( # noqa: PLR0915 ProxyConfig, app, load_aws_kms, - load_aws_secret_manager, load_from_azure_key_vault, load_google_kms, save_worker_config, @@ -278,7 +277,6 @@ def run_server( # noqa: PLR0915 ProxyConfig, app, load_aws_kms, - load_aws_secret_manager, load_from_azure_key_vault, load_google_kms, save_worker_config, @@ -295,7 +293,6 @@ def run_server( # noqa: PLR0915 ProxyConfig, app, load_aws_kms, - load_aws_secret_manager, load_from_azure_key_vault, load_google_kms, save_worker_config, @@ -559,8 +556,14 @@ def run_server( # noqa: PLR0915 key_management_system == KeyManagementSystem.AWS_SECRET_MANAGER.value # noqa: F405 ): + from litellm.secret_managers.aws_secret_manager_v2 import ( + AWSSecretsManagerV2, + ) + ### LOAD FROM AWS SECRET MANAGER ### - load_aws_secret_manager(use_aws_secret_manager=True) + AWSSecretsManagerV2.load_aws_secret_manager( + use_aws_secret_manager=True + ) elif key_management_system == KeyManagementSystem.AWS_KMS.value: load_aws_kms(use_aws_kms=True) elif ( diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml index 29d14c910..71e3dee0e 100644 --- a/litellm/proxy/proxy_config.yaml +++ b/litellm/proxy/proxy_config.yaml @@ -7,6 +7,8 @@ model_list: -litellm_settings: - callbacks: ["gcs_bucket"] - +general_settings: + key_management_system: "aws_secret_manager" + key_management_settings: + store_virtual_keys: true + access_mode: "write_only" diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index c9c6af77f..34ac51481 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -245,10 +245,7 @@ from litellm.router import ( from litellm.router import ModelInfo as RouterModelInfo from litellm.router import updateDeployment from litellm.scheduler import DefaultPriorities, FlowItem, Scheduler -from litellm.secret_managers.aws_secret_manager import ( - load_aws_kms, - load_aws_secret_manager, -) +from litellm.secret_managers.aws_secret_manager import load_aws_kms from litellm.secret_managers.google_kms import load_google_kms from litellm.secret_managers.main import ( get_secret, @@ -1825,8 +1822,13 @@ class ProxyConfig: key_management_system == KeyManagementSystem.AWS_SECRET_MANAGER.value # noqa: F405 ): - ### LOAD FROM AWS SECRET MANAGER ### - load_aws_secret_manager(use_aws_secret_manager=True) + from litellm.secret_managers.aws_secret_manager_v2 import ( + AWSSecretsManagerV2, + ) + + AWSSecretsManagerV2.load_aws_secret_manager( + use_aws_secret_manager=True + ) elif key_management_system == KeyManagementSystem.AWS_KMS.value: load_aws_kms(use_aws_kms=True) elif ( diff --git a/litellm/secret_managers/aws_secret_manager.py b/litellm/secret_managers/aws_secret_manager.py index f0e510fa8..fbe951e64 100644 --- a/litellm/secret_managers/aws_secret_manager.py +++ b/litellm/secret_managers/aws_secret_manager.py @@ -23,28 +23,6 @@ def validate_environment(): raise ValueError("Missing required environment variable - AWS_REGION_NAME") -def load_aws_secret_manager(use_aws_secret_manager: Optional[bool]): - if use_aws_secret_manager is None or use_aws_secret_manager is False: - return - try: - import boto3 - from botocore.exceptions import ClientError - - validate_environment() - - # Create a Secrets Manager client - session = boto3.session.Session() # type: ignore - client = session.client( - service_name="secretsmanager", region_name=os.getenv("AWS_REGION_NAME") - ) - - litellm.secret_manager_client = client - litellm._key_management_system = KeyManagementSystem.AWS_SECRET_MANAGER - - except Exception as e: - raise e - - def load_aws_kms(use_aws_kms: Optional[bool]): if use_aws_kms is None or use_aws_kms is False: return diff --git a/litellm/secret_managers/aws_secret_manager_v2.py b/litellm/secret_managers/aws_secret_manager_v2.py new file mode 100644 index 000000000..69add6f23 --- /dev/null +++ b/litellm/secret_managers/aws_secret_manager_v2.py @@ -0,0 +1,310 @@ +""" +This is a file for the AWS Secret Manager Integration + +Handles Async Operations for: +- Read Secret +- Write Secret +- Delete Secret + +Relevant issue: https://github.com/BerriAI/litellm/issues/1883 + +Requires: +* `os.environ["AWS_REGION_NAME"], +* `pip install boto3>=1.28.57` +""" + +import ast +import asyncio +import base64 +import json +import os +import re +import sys +from typing import Any, Dict, Optional, Union + +import httpx + +import litellm +from litellm._logging import verbose_logger +from litellm.llms.base_aws_llm import BaseAWSLLM +from litellm.llms.custom_httpx.http_handler import ( + _get_httpx_client, + get_async_httpx_client, +) +from litellm.llms.custom_httpx.types import httpxSpecialProvider +from litellm.proxy._types import KeyManagementSystem + + +class AWSSecretsManagerV2(BaseAWSLLM): + @classmethod + def validate_environment(cls): + if "AWS_REGION_NAME" not in os.environ: + raise ValueError("Missing required environment variable - AWS_REGION_NAME") + + @classmethod + def load_aws_secret_manager(cls, use_aws_secret_manager: Optional[bool]): + """ + Initialize AWSSecretsManagerV2 and sets litellm.secret_manager_client = AWSSecretsManagerV2() and litellm._key_management_system = KeyManagementSystem.AWS_SECRET_MANAGER + """ + if use_aws_secret_manager is None or use_aws_secret_manager is False: + return + try: + import boto3 + + cls.validate_environment() + litellm.secret_manager_client = cls() + litellm._key_management_system = KeyManagementSystem.AWS_SECRET_MANAGER + + except Exception as e: + raise e + + async def async_read_secret( + self, + secret_name: str, + optional_params: Optional[dict] = None, + timeout: Optional[Union[float, httpx.Timeout]] = None, + ) -> Optional[str]: + """ + Async function to read a secret from AWS Secrets Manager + + Returns: + str: Secret value + Raises: + ValueError: If the secret is not found or an HTTP error occurs + """ + endpoint_url, headers, body = self._prepare_request( + action="GetSecretValue", + secret_name=secret_name, + optional_params=optional_params, + ) + + async_client = get_async_httpx_client( + llm_provider=httpxSpecialProvider.SecretManager, + params={"timeout": timeout}, + ) + + try: + response = await async_client.post( + url=endpoint_url, headers=headers, data=body.decode("utf-8") + ) + response.raise_for_status() + return response.json()["SecretString"] + except httpx.TimeoutException: + raise ValueError("Timeout error occurred") + except Exception as e: + verbose_logger.exception( + "Error reading secret from AWS Secrets Manager: %s", str(e) + ) + return None + + def sync_read_secret( + self, + secret_name: str, + optional_params: Optional[dict] = None, + timeout: Optional[Union[float, httpx.Timeout]] = None, + ) -> Optional[str]: + """ + Sync function to read a secret from AWS Secrets Manager + + Done for backwards compatibility with existing codebase, since get_secret is a sync function + """ + + # self._prepare_request uses these env vars, we cannot read them from AWS Secrets Manager. If we do we'd get stuck in an infinite loop + if secret_name in [ + "AWS_ACCESS_KEY_ID", + "AWS_SECRET_ACCESS_KEY", + "AWS_REGION_NAME", + "AWS_REGION", + "AWS_BEDROCK_RUNTIME_ENDPOINT", + ]: + return os.getenv(secret_name) + + endpoint_url, headers, body = self._prepare_request( + action="GetSecretValue", + secret_name=secret_name, + optional_params=optional_params, + ) + + sync_client = _get_httpx_client( + params={"timeout": timeout}, + ) + + try: + response = sync_client.post( + url=endpoint_url, headers=headers, data=body.decode("utf-8") + ) + response.raise_for_status() + return response.json()["SecretString"] + except httpx.TimeoutException: + raise ValueError("Timeout error occurred") + except Exception as e: + verbose_logger.exception( + "Error reading secret from AWS Secrets Manager: %s", str(e) + ) + return None + + async def async_write_secret( + self, + secret_name: str, + secret_value: str, + description: Optional[str] = None, + client_request_token: Optional[str] = None, + optional_params: Optional[dict] = None, + timeout: Optional[Union[float, httpx.Timeout]] = None, + ) -> dict: + """ + Async function to write a secret to AWS Secrets Manager + + Args: + secret_name: Name of the secret + secret_value: Value to store (can be a JSON string) + description: Optional description for the secret + client_request_token: Optional unique identifier to ensure idempotency + optional_params: Additional AWS parameters + timeout: Request timeout + """ + import uuid + + # Prepare the request data + data = {"Name": secret_name, "SecretString": secret_value} + if description: + data["Description"] = description + + data["ClientRequestToken"] = str(uuid.uuid4()) + + endpoint_url, headers, body = self._prepare_request( + action="CreateSecret", + secret_name=secret_name, + secret_value=secret_value, + optional_params=optional_params, + request_data=data, # Pass the complete request data + ) + + async_client = get_async_httpx_client( + llm_provider=httpxSpecialProvider.SecretManager, + params={"timeout": timeout}, + ) + + try: + response = await async_client.post( + url=endpoint_url, headers=headers, data=body.decode("utf-8") + ) + response.raise_for_status() + return response.json() + except httpx.HTTPStatusError as err: + raise ValueError(f"HTTP error occurred: {err.response.text}") + except httpx.TimeoutException: + raise ValueError("Timeout error occurred") + + async def async_delete_secret( + self, + secret_name: str, + recovery_window_in_days: Optional[int] = 7, + optional_params: Optional[dict] = None, + timeout: Optional[Union[float, httpx.Timeout]] = None, + ) -> dict: + """ + Async function to delete a secret from AWS Secrets Manager + + Args: + secret_name: Name of the secret to delete + recovery_window_in_days: Number of days before permanent deletion (default: 7) + optional_params: Additional AWS parameters + timeout: Request timeout + + Returns: + dict: Response from AWS Secrets Manager containing deletion details + """ + # Prepare the request data + data = { + "SecretId": secret_name, + "RecoveryWindowInDays": recovery_window_in_days, + } + + endpoint_url, headers, body = self._prepare_request( + action="DeleteSecret", + secret_name=secret_name, + optional_params=optional_params, + request_data=data, + ) + + async_client = get_async_httpx_client( + llm_provider=httpxSpecialProvider.SecretManager, + params={"timeout": timeout}, + ) + + try: + response = await async_client.post( + url=endpoint_url, headers=headers, data=body.decode("utf-8") + ) + response.raise_for_status() + return response.json() + except httpx.HTTPStatusError as err: + raise ValueError(f"HTTP error occurred: {err.response.text}") + except httpx.TimeoutException: + raise ValueError("Timeout error occurred") + + def _prepare_request( + self, + action: str, # "GetSecretValue" or "PutSecretValue" + secret_name: str, + secret_value: Optional[str] = None, + optional_params: Optional[dict] = None, + request_data: Optional[dict] = None, + ) -> tuple[str, Any, bytes]: + """Prepare the AWS Secrets Manager request""" + try: + import boto3 + from botocore.auth import SigV4Auth + from botocore.awsrequest import AWSRequest + from botocore.credentials import Credentials + except ImportError: + raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.") + optional_params = optional_params or {} + boto3_credentials_info = self._get_boto_credentials_from_optional_params( + optional_params + ) + + # Get endpoint + _, endpoint_url = self.get_runtime_endpoint( + api_base=None, + aws_bedrock_runtime_endpoint=boto3_credentials_info.aws_bedrock_runtime_endpoint, + aws_region_name=boto3_credentials_info.aws_region_name, + ) + endpoint_url = endpoint_url.replace("bedrock-runtime", "secretsmanager") + + # Use provided request_data if available, otherwise build default data + if request_data: + data = request_data + else: + data = {"SecretId": secret_name} + if secret_value and action == "PutSecretValue": + data["SecretString"] = secret_value + + body = json.dumps(data).encode("utf-8") + headers = { + "Content-Type": "application/x-amz-json-1.1", + "X-Amz-Target": f"secretsmanager.{action}", + } + + # Sign request + request = AWSRequest( + method="POST", url=endpoint_url, data=body, headers=headers + ) + SigV4Auth( + boto3_credentials_info.credentials, + "secretsmanager", + boto3_credentials_info.aws_region_name, + ).add_auth(request) + prepped = request.prepare() + + return endpoint_url, prepped.headers, body + + +# if __name__ == "__main__": +# print("loading aws secret manager v2") +# aws_secret_manager_v2 = AWSSecretsManagerV2() + +# print("writing secret to aws secret manager v2") +# asyncio.run(aws_secret_manager_v2.async_write_secret(secret_name="test_secret_3", secret_value="test_value_2")) +# print("reading secret from aws secret manager v2") diff --git a/litellm/secret_managers/main.py b/litellm/secret_managers/main.py index f3d6d420a..ce6d30755 100644 --- a/litellm/secret_managers/main.py +++ b/litellm/secret_managers/main.py @@ -5,7 +5,7 @@ import json import os import sys import traceback -from typing import Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union import httpx from dotenv import load_dotenv @@ -198,7 +198,10 @@ def get_secret( # noqa: PLR0915 raise ValueError("Unsupported OIDC provider") try: - if litellm.secret_manager_client is not None: + if ( + _should_read_secret_from_secret_manager() + and litellm.secret_manager_client is not None + ): try: client = litellm.secret_manager_client key_manager = "local" @@ -207,7 +210,8 @@ def get_secret( # noqa: PLR0915 if key_management_settings is not None: if ( - secret_name not in key_management_settings.hosted_keys + key_management_settings.hosted_keys is not None + and secret_name not in key_management_settings.hosted_keys ): # allow user to specify which keys to check in hosted key manager key_manager = "local" @@ -268,25 +272,13 @@ def get_secret( # noqa: PLR0915 if isinstance(secret, str): secret = secret.strip() elif key_manager == KeyManagementSystem.AWS_SECRET_MANAGER.value: - try: - get_secret_value_response = client.get_secret_value( - SecretId=secret_name - ) - print_verbose( - f"get_secret_value_response: {get_secret_value_response}" - ) - except Exception as e: - print_verbose(f"An error occurred - {str(e)}") - # For a list of exceptions thrown, see - # https://docs.aws.amazon.com/secretsmanager/latest/apireference/API_GetSecretValue.html - raise e + from litellm.secret_managers.aws_secret_manager_v2 import ( + AWSSecretsManagerV2, + ) - # assume there is 1 secret per secret_name - secret_dict = json.loads(get_secret_value_response["SecretString"]) - print_verbose(f"secret_dict: {secret_dict}") - for k, v in secret_dict.items(): - secret = v - print_verbose(f"secret: {secret}") + if isinstance(client, AWSSecretsManagerV2): + secret = client.sync_read_secret(secret_name=secret_name) + print_verbose(f"get_secret_value_response: {secret}") elif key_manager == KeyManagementSystem.GOOGLE_SECRET_MANAGER.value: try: secret = client.get_secret_from_google_secret_manager( @@ -332,3 +324,21 @@ def get_secret( # noqa: PLR0915 return default_value else: raise e + + +def _should_read_secret_from_secret_manager() -> bool: + """ + Returns True if the secret manager should be used to read the secret, False otherwise + + - If the secret manager client is not set, return False + - If the `_key_management_settings` access mode is "read_only" or "read_and_write", return True + - Otherwise, return False + """ + if litellm.secret_manager_client is not None: + if litellm._key_management_settings is not None: + if ( + litellm._key_management_settings.access_mode == "read_only" + or litellm._key_management_settings.access_mode == "read_and_write" + ): + return True + return False diff --git a/tests/local_testing/test_aws_secret_manager.py b/tests/local_testing/test_aws_secret_manager.py new file mode 100644 index 000000000..f2e2319cc --- /dev/null +++ b/tests/local_testing/test_aws_secret_manager.py @@ -0,0 +1,139 @@ +# What is this? + +import asyncio +import os +import sys +import traceback + +from dotenv import load_dotenv + +import litellm.types +import litellm.types.utils + + +load_dotenv() +import io + +import sys +import os + +# Ensure the project root is in the Python path +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))) + +print("Python Path:", sys.path) +print("Current Working Directory:", os.getcwd()) + + +from typing import Optional +from unittest.mock import MagicMock, patch + +import pytest +import uuid +import json +from litellm.secret_managers.aws_secret_manager_v2 import AWSSecretsManagerV2 + + +def check_aws_credentials(): + """Helper function to check if AWS credentials are set""" + required_vars = ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_REGION_NAME"] + missing_vars = [var for var in required_vars if not os.getenv(var)] + if missing_vars: + pytest.skip(f"Missing required AWS credentials: {', '.join(missing_vars)}") + + +@pytest.mark.asyncio +async def test_write_and_read_simple_secret(): + """Test writing and reading a simple string secret""" + check_aws_credentials() + + secret_manager = AWSSecretsManagerV2() + test_secret_name = f"litellm_test_{uuid.uuid4().hex[:8]}" + test_secret_value = "test_value_123" + + try: + # Write secret + write_response = await secret_manager.async_write_secret( + secret_name=test_secret_name, + secret_value=test_secret_value, + description="LiteLLM Test Secret", + ) + + print("Write Response:", write_response) + + assert write_response is not None + assert "ARN" in write_response + assert "Name" in write_response + assert write_response["Name"] == test_secret_name + + # Read secret back + read_value = await secret_manager.async_read_secret( + secret_name=test_secret_name + ) + + print("Read Value:", read_value) + + assert read_value == test_secret_value + finally: + # Cleanup: Delete the secret + delete_response = await secret_manager.async_delete_secret( + secret_name=test_secret_name + ) + print("Delete Response:", delete_response) + assert delete_response is not None + + +@pytest.mark.asyncio +async def test_write_and_read_json_secret(): + """Test writing and reading a JSON structured secret""" + check_aws_credentials() + + secret_manager = AWSSecretsManagerV2() + test_secret_name = f"litellm_test_{uuid.uuid4().hex[:8]}_json" + test_secret_value = { + "api_key": "test_key", + "model": "gpt-4", + "temperature": 0.7, + "metadata": {"team": "ml", "project": "litellm"}, + } + + try: + # Write JSON secret + write_response = await secret_manager.async_write_secret( + secret_name=test_secret_name, + secret_value=json.dumps(test_secret_value), + description="LiteLLM JSON Test Secret", + ) + + print("Write Response:", write_response) + + # Read and parse JSON secret + read_value = await secret_manager.async_read_secret( + secret_name=test_secret_name + ) + parsed_value = json.loads(read_value) + + print("Read Value:", read_value) + + assert parsed_value == test_secret_value + assert parsed_value["api_key"] == "test_key" + assert parsed_value["metadata"]["team"] == "ml" + finally: + # Cleanup: Delete the secret + delete_response = await secret_manager.async_delete_secret( + secret_name=test_secret_name + ) + print("Delete Response:", delete_response) + assert delete_response is not None + + +@pytest.mark.asyncio +async def test_read_nonexistent_secret(): + """Test reading a secret that doesn't exist""" + check_aws_credentials() + + secret_manager = AWSSecretsManagerV2() + nonexistent_secret = f"litellm_nonexistent_{uuid.uuid4().hex}" + + response = await secret_manager.async_read_secret(secret_name=nonexistent_secret) + + assert response is None diff --git a/tests/local_testing/test_completion.py b/tests/local_testing/test_completion.py index 7814d13c6..211a4cd19 100644 --- a/tests/local_testing/test_completion.py +++ b/tests/local_testing/test_completion.py @@ -24,7 +24,7 @@ from litellm import RateLimitError, Timeout, completion, completion_cost, embedd from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler from litellm.llms.prompt_templates.factory import anthropic_messages_pt -# litellm.num_retries = 3 +# litellm.num_retries=3 litellm.cache = None litellm.success_callback = [] diff --git a/tests/local_testing/test_secret_manager.py b/tests/local_testing/test_secret_manager.py index 397128ecb..1b95119a3 100644 --- a/tests/local_testing/test_secret_manager.py +++ b/tests/local_testing/test_secret_manager.py @@ -15,22 +15,29 @@ sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path import pytest - +import litellm from litellm.llms.AzureOpenAI.azure import get_azure_ad_token_from_oidc from litellm.llms.bedrock.chat import BedrockConverseLLM, BedrockLLM -from litellm.secret_managers.aws_secret_manager import load_aws_secret_manager -from litellm.secret_managers.main import get_secret +from litellm.secret_managers.aws_secret_manager_v2 import AWSSecretsManagerV2 +from litellm.secret_managers.main import ( + get_secret, + _should_read_secret_from_secret_manager, +) -@pytest.mark.skip(reason="AWS Suspended Account") def test_aws_secret_manager(): - load_aws_secret_manager(use_aws_secret_manager=True) + import json + + AWSSecretsManagerV2.load_aws_secret_manager(use_aws_secret_manager=True) secret_val = get_secret("litellm_master_key") print(f"secret_val: {secret_val}") - assert secret_val == "sk-1234" + # cast json to dict + secret_val = json.loads(secret_val) + + assert secret_val["litellm_master_key"] == "sk-1234" def redact_oidc_signature(secret_val): @@ -240,3 +247,71 @@ def test_google_secret_manager_read_in_memory(): ) print("secret_val: {}".format(secret_val)) assert secret_val == "lite-llm" + + +def test_should_read_secret_from_secret_manager(): + """ + Test that _should_read_secret_from_secret_manager returns correct values based on access mode + """ + from litellm.proxy._types import KeyManagementSettings + + # Test when secret manager client is None + litellm.secret_manager_client = None + litellm._key_management_settings = KeyManagementSettings() + assert _should_read_secret_from_secret_manager() is False + + # Test with secret manager client and read_only access + litellm.secret_manager_client = "dummy_client" + litellm._key_management_settings = KeyManagementSettings(access_mode="read_only") + assert _should_read_secret_from_secret_manager() is True + + # Test with secret manager client and read_and_write access + litellm._key_management_settings = KeyManagementSettings( + access_mode="read_and_write" + ) + assert _should_read_secret_from_secret_manager() is True + + # Test with secret manager client and write_only access + litellm._key_management_settings = KeyManagementSettings(access_mode="write_only") + assert _should_read_secret_from_secret_manager() is False + + # Reset global variables + litellm.secret_manager_client = None + litellm._key_management_settings = KeyManagementSettings() + + +def test_get_secret_with_access_mode(): + """ + Test that get_secret respects access mode settings + """ + from litellm.proxy._types import KeyManagementSettings + + # Set up test environment + test_secret_name = "TEST_SECRET_KEY" + test_secret_value = "test_secret_value" + os.environ[test_secret_name] = test_secret_value + + # Test with write_only access (should read from os.environ) + litellm.secret_manager_client = "dummy_client" + litellm._key_management_settings = KeyManagementSettings(access_mode="write_only") + assert get_secret(test_secret_name) == test_secret_value + + # Test with no KeyManagementSettings but secret_manager_client set + litellm.secret_manager_client = "dummy_client" + litellm._key_management_settings = KeyManagementSettings() + assert _should_read_secret_from_secret_manager() is True + + # Test with read_only access + litellm._key_management_settings = KeyManagementSettings(access_mode="read_only") + assert _should_read_secret_from_secret_manager() is True + + # Test with read_and_write access + litellm._key_management_settings = KeyManagementSettings( + access_mode="read_and_write" + ) + assert _should_read_secret_from_secret_manager() is True + + # Reset global variables + litellm.secret_manager_client = None + litellm._key_management_settings = KeyManagementSettings() + del os.environ[test_secret_name] diff --git a/tests/proxy_unit_tests/test_key_generate_prisma.py b/tests/proxy_unit_tests/test_key_generate_prisma.py index 78b558cd2..b97ab3514 100644 --- a/tests/proxy_unit_tests/test_key_generate_prisma.py +++ b/tests/proxy_unit_tests/test_key_generate_prisma.py @@ -3451,3 +3451,90 @@ async def test_user_api_key_auth_db_unavailable_not_allowed(): request=request, api_key="Bearer sk-123456789", ) + + +## E2E Virtual Key + Secret Manager Tests ######################################### + + +@pytest.mark.asyncio +async def test_key_generate_with_secret_manager_call(prisma_client): + """ + Generate a key + assert it exists in the secret manager + + delete the key + assert it is deleted from the secret manager + """ + from litellm.secret_managers.aws_secret_manager_v2 import AWSSecretsManagerV2 + from litellm.proxy._types import KeyManagementSystem, KeyManagementSettings + + litellm.set_verbose = True + + #### Test Setup ############################################################ + aws_secret_manager_client = AWSSecretsManagerV2() + litellm.secret_manager_client = aws_secret_manager_client + litellm._key_management_system = KeyManagementSystem.AWS_SECRET_MANAGER + litellm._key_management_settings = KeyManagementSettings( + store_virtual_keys=True, + ) + general_settings = { + "key_management_system": "aws_secret_manager", + "key_management_settings": { + "store_virtual_keys": True, + }, + } + + setattr(litellm.proxy.proxy_server, "general_settings", general_settings) + setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") + setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) + await litellm.proxy.proxy_server.prisma_client.connect() + ############################################################################ + + # generate new key + key_alias = f"test_alias_secret_manager_key-{uuid.uuid4()}" + spend = 100 + max_budget = 400 + models = ["fake-openai-endpoint"] + new_key = await generate_key_fn( + data=GenerateKeyRequest( + key_alias=key_alias, spend=spend, max_budget=max_budget, models=models + ), + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-1234", + user_id="1234", + ), + ) + + generated_key = new_key.key + print(generated_key) + + await asyncio.sleep(2) + + # read from the secret manager + result = await aws_secret_manager_client.async_read_secret(secret_name=key_alias) + + # Assert the correct key is stored in the secret manager + print("response from AWS Secret Manager") + print(result) + assert result == generated_key + + # delete the key + await delete_key_fn( + data=KeyRequest(keys=[generated_key]), + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, api_key="sk-1234", user_id="1234" + ), + ) + + await asyncio.sleep(2) + + # Assert the key is deleted from the secret manager + result = await aws_secret_manager_client.async_read_secret(secret_name=key_alias) + assert result is None + + # cleanup + setattr(litellm.proxy.proxy_server, "general_settings", {}) + + +################################################################################