forked from phoenix/litellm-mirror
(Feat) Add support for storing virtual keys in AWS SecretManager (#6728)
* add SecretManager to httpxSpecialProvider * fix importing AWSSecretsManagerV2 * add unit testing for writing keys to AWS secret manager * use KeyManagementEventHooks for key/generated events * us event hooks for key management endpoints * working AWSSecretsManagerV2 * fix write secret to AWS secret manager on /key/generate * fix KeyManagementSettings * use tasks for key management hooks * add async_delete_secret * add test for async_delete_secret * use _delete_virtual_keys_from_secret_manager * fix test secret manager * test_key_generate_with_secret_manager_call * fix check for key_management_settings * sync_read_secret * test_aws_secret_manager * fix sync_read_secret * use helper to check when _should_read_secret_from_secret_manager * test_get_secret_with_access_mode * test - handle eol model claude-2, use claude-2.1 instead * docs AWS secret manager * fix test_read_nonexistent_secret * fix test_supports_response_schema * ci/cd run again
This commit is contained in:
parent
da84056e59
commit
f8e700064e
16 changed files with 1046 additions and 178 deletions
139
tests/local_testing/test_aws_secret_manager.py
Normal file
139
tests/local_testing/test_aws_secret_manager.py
Normal file
|
@ -0,0 +1,139 @@
|
|||
# What is this?
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
import litellm.types
|
||||
import litellm.types.utils
|
||||
|
||||
|
||||
load_dotenv()
|
||||
import io
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
# Ensure the project root is in the Python path
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..")))
|
||||
|
||||
print("Python Path:", sys.path)
|
||||
print("Current Working Directory:", os.getcwd())
|
||||
|
||||
|
||||
from typing import Optional
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import uuid
|
||||
import json
|
||||
from litellm.secret_managers.aws_secret_manager_v2 import AWSSecretsManagerV2
|
||||
|
||||
|
||||
def check_aws_credentials():
|
||||
"""Helper function to check if AWS credentials are set"""
|
||||
required_vars = ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_REGION_NAME"]
|
||||
missing_vars = [var for var in required_vars if not os.getenv(var)]
|
||||
if missing_vars:
|
||||
pytest.skip(f"Missing required AWS credentials: {', '.join(missing_vars)}")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_and_read_simple_secret():
|
||||
"""Test writing and reading a simple string secret"""
|
||||
check_aws_credentials()
|
||||
|
||||
secret_manager = AWSSecretsManagerV2()
|
||||
test_secret_name = f"litellm_test_{uuid.uuid4().hex[:8]}"
|
||||
test_secret_value = "test_value_123"
|
||||
|
||||
try:
|
||||
# Write secret
|
||||
write_response = await secret_manager.async_write_secret(
|
||||
secret_name=test_secret_name,
|
||||
secret_value=test_secret_value,
|
||||
description="LiteLLM Test Secret",
|
||||
)
|
||||
|
||||
print("Write Response:", write_response)
|
||||
|
||||
assert write_response is not None
|
||||
assert "ARN" in write_response
|
||||
assert "Name" in write_response
|
||||
assert write_response["Name"] == test_secret_name
|
||||
|
||||
# Read secret back
|
||||
read_value = await secret_manager.async_read_secret(
|
||||
secret_name=test_secret_name
|
||||
)
|
||||
|
||||
print("Read Value:", read_value)
|
||||
|
||||
assert read_value == test_secret_value
|
||||
finally:
|
||||
# Cleanup: Delete the secret
|
||||
delete_response = await secret_manager.async_delete_secret(
|
||||
secret_name=test_secret_name
|
||||
)
|
||||
print("Delete Response:", delete_response)
|
||||
assert delete_response is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_and_read_json_secret():
|
||||
"""Test writing and reading a JSON structured secret"""
|
||||
check_aws_credentials()
|
||||
|
||||
secret_manager = AWSSecretsManagerV2()
|
||||
test_secret_name = f"litellm_test_{uuid.uuid4().hex[:8]}_json"
|
||||
test_secret_value = {
|
||||
"api_key": "test_key",
|
||||
"model": "gpt-4",
|
||||
"temperature": 0.7,
|
||||
"metadata": {"team": "ml", "project": "litellm"},
|
||||
}
|
||||
|
||||
try:
|
||||
# Write JSON secret
|
||||
write_response = await secret_manager.async_write_secret(
|
||||
secret_name=test_secret_name,
|
||||
secret_value=json.dumps(test_secret_value),
|
||||
description="LiteLLM JSON Test Secret",
|
||||
)
|
||||
|
||||
print("Write Response:", write_response)
|
||||
|
||||
# Read and parse JSON secret
|
||||
read_value = await secret_manager.async_read_secret(
|
||||
secret_name=test_secret_name
|
||||
)
|
||||
parsed_value = json.loads(read_value)
|
||||
|
||||
print("Read Value:", read_value)
|
||||
|
||||
assert parsed_value == test_secret_value
|
||||
assert parsed_value["api_key"] == "test_key"
|
||||
assert parsed_value["metadata"]["team"] == "ml"
|
||||
finally:
|
||||
# Cleanup: Delete the secret
|
||||
delete_response = await secret_manager.async_delete_secret(
|
||||
secret_name=test_secret_name
|
||||
)
|
||||
print("Delete Response:", delete_response)
|
||||
assert delete_response is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_nonexistent_secret():
|
||||
"""Test reading a secret that doesn't exist"""
|
||||
check_aws_credentials()
|
||||
|
||||
secret_manager = AWSSecretsManagerV2()
|
||||
nonexistent_secret = f"litellm_nonexistent_{uuid.uuid4().hex}"
|
||||
|
||||
response = await secret_manager.async_read_secret(secret_name=nonexistent_secret)
|
||||
|
||||
assert response is None
|
|
@ -24,7 +24,7 @@ from litellm import RateLimitError, Timeout, completion, completion_cost, embedd
|
|||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
from litellm.llms.prompt_templates.factory import anthropic_messages_pt
|
||||
|
||||
# litellm.num_retries = 3
|
||||
# litellm.num_retries=3
|
||||
|
||||
litellm.cache = None
|
||||
litellm.success_callback = []
|
||||
|
|
|
@ -15,22 +15,29 @@ sys.path.insert(
|
|||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
import pytest
|
||||
|
||||
import litellm
|
||||
from litellm.llms.AzureOpenAI.azure import get_azure_ad_token_from_oidc
|
||||
from litellm.llms.bedrock.chat import BedrockConverseLLM, BedrockLLM
|
||||
from litellm.secret_managers.aws_secret_manager import load_aws_secret_manager
|
||||
from litellm.secret_managers.main import get_secret
|
||||
from litellm.secret_managers.aws_secret_manager_v2 import AWSSecretsManagerV2
|
||||
from litellm.secret_managers.main import (
|
||||
get_secret,
|
||||
_should_read_secret_from_secret_manager,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="AWS Suspended Account")
|
||||
def test_aws_secret_manager():
|
||||
load_aws_secret_manager(use_aws_secret_manager=True)
|
||||
import json
|
||||
|
||||
AWSSecretsManagerV2.load_aws_secret_manager(use_aws_secret_manager=True)
|
||||
|
||||
secret_val = get_secret("litellm_master_key")
|
||||
|
||||
print(f"secret_val: {secret_val}")
|
||||
|
||||
assert secret_val == "sk-1234"
|
||||
# cast json to dict
|
||||
secret_val = json.loads(secret_val)
|
||||
|
||||
assert secret_val["litellm_master_key"] == "sk-1234"
|
||||
|
||||
|
||||
def redact_oidc_signature(secret_val):
|
||||
|
@ -240,3 +247,71 @@ def test_google_secret_manager_read_in_memory():
|
|||
)
|
||||
print("secret_val: {}".format(secret_val))
|
||||
assert secret_val == "lite-llm"
|
||||
|
||||
|
||||
def test_should_read_secret_from_secret_manager():
|
||||
"""
|
||||
Test that _should_read_secret_from_secret_manager returns correct values based on access mode
|
||||
"""
|
||||
from litellm.proxy._types import KeyManagementSettings
|
||||
|
||||
# Test when secret manager client is None
|
||||
litellm.secret_manager_client = None
|
||||
litellm._key_management_settings = KeyManagementSettings()
|
||||
assert _should_read_secret_from_secret_manager() is False
|
||||
|
||||
# Test with secret manager client and read_only access
|
||||
litellm.secret_manager_client = "dummy_client"
|
||||
litellm._key_management_settings = KeyManagementSettings(access_mode="read_only")
|
||||
assert _should_read_secret_from_secret_manager() is True
|
||||
|
||||
# Test with secret manager client and read_and_write access
|
||||
litellm._key_management_settings = KeyManagementSettings(
|
||||
access_mode="read_and_write"
|
||||
)
|
||||
assert _should_read_secret_from_secret_manager() is True
|
||||
|
||||
# Test with secret manager client and write_only access
|
||||
litellm._key_management_settings = KeyManagementSettings(access_mode="write_only")
|
||||
assert _should_read_secret_from_secret_manager() is False
|
||||
|
||||
# Reset global variables
|
||||
litellm.secret_manager_client = None
|
||||
litellm._key_management_settings = KeyManagementSettings()
|
||||
|
||||
|
||||
def test_get_secret_with_access_mode():
|
||||
"""
|
||||
Test that get_secret respects access mode settings
|
||||
"""
|
||||
from litellm.proxy._types import KeyManagementSettings
|
||||
|
||||
# Set up test environment
|
||||
test_secret_name = "TEST_SECRET_KEY"
|
||||
test_secret_value = "test_secret_value"
|
||||
os.environ[test_secret_name] = test_secret_value
|
||||
|
||||
# Test with write_only access (should read from os.environ)
|
||||
litellm.secret_manager_client = "dummy_client"
|
||||
litellm._key_management_settings = KeyManagementSettings(access_mode="write_only")
|
||||
assert get_secret(test_secret_name) == test_secret_value
|
||||
|
||||
# Test with no KeyManagementSettings but secret_manager_client set
|
||||
litellm.secret_manager_client = "dummy_client"
|
||||
litellm._key_management_settings = KeyManagementSettings()
|
||||
assert _should_read_secret_from_secret_manager() is True
|
||||
|
||||
# Test with read_only access
|
||||
litellm._key_management_settings = KeyManagementSettings(access_mode="read_only")
|
||||
assert _should_read_secret_from_secret_manager() is True
|
||||
|
||||
# Test with read_and_write access
|
||||
litellm._key_management_settings = KeyManagementSettings(
|
||||
access_mode="read_and_write"
|
||||
)
|
||||
assert _should_read_secret_from_secret_manager() is True
|
||||
|
||||
# Reset global variables
|
||||
litellm.secret_manager_client = None
|
||||
litellm._key_management_settings = KeyManagementSettings()
|
||||
del os.environ[test_secret_name]
|
||||
|
|
|
@ -3451,3 +3451,90 @@ async def test_user_api_key_auth_db_unavailable_not_allowed():
|
|||
request=request,
|
||||
api_key="Bearer sk-123456789",
|
||||
)
|
||||
|
||||
|
||||
## E2E Virtual Key + Secret Manager Tests #########################################
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_key_generate_with_secret_manager_call(prisma_client):
|
||||
"""
|
||||
Generate a key
|
||||
assert it exists in the secret manager
|
||||
|
||||
delete the key
|
||||
assert it is deleted from the secret manager
|
||||
"""
|
||||
from litellm.secret_managers.aws_secret_manager_v2 import AWSSecretsManagerV2
|
||||
from litellm.proxy._types import KeyManagementSystem, KeyManagementSettings
|
||||
|
||||
litellm.set_verbose = True
|
||||
|
||||
#### Test Setup ############################################################
|
||||
aws_secret_manager_client = AWSSecretsManagerV2()
|
||||
litellm.secret_manager_client = aws_secret_manager_client
|
||||
litellm._key_management_system = KeyManagementSystem.AWS_SECRET_MANAGER
|
||||
litellm._key_management_settings = KeyManagementSettings(
|
||||
store_virtual_keys=True,
|
||||
)
|
||||
general_settings = {
|
||||
"key_management_system": "aws_secret_manager",
|
||||
"key_management_settings": {
|
||||
"store_virtual_keys": True,
|
||||
},
|
||||
}
|
||||
|
||||
setattr(litellm.proxy.proxy_server, "general_settings", general_settings)
|
||||
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
|
||||
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
||||
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||
############################################################################
|
||||
|
||||
# generate new key
|
||||
key_alias = f"test_alias_secret_manager_key-{uuid.uuid4()}"
|
||||
spend = 100
|
||||
max_budget = 400
|
||||
models = ["fake-openai-endpoint"]
|
||||
new_key = await generate_key_fn(
|
||||
data=GenerateKeyRequest(
|
||||
key_alias=key_alias, spend=spend, max_budget=max_budget, models=models
|
||||
),
|
||||
user_api_key_dict=UserAPIKeyAuth(
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||
api_key="sk-1234",
|
||||
user_id="1234",
|
||||
),
|
||||
)
|
||||
|
||||
generated_key = new_key.key
|
||||
print(generated_key)
|
||||
|
||||
await asyncio.sleep(2)
|
||||
|
||||
# read from the secret manager
|
||||
result = await aws_secret_manager_client.async_read_secret(secret_name=key_alias)
|
||||
|
||||
# Assert the correct key is stored in the secret manager
|
||||
print("response from AWS Secret Manager")
|
||||
print(result)
|
||||
assert result == generated_key
|
||||
|
||||
# delete the key
|
||||
await delete_key_fn(
|
||||
data=KeyRequest(keys=[generated_key]),
|
||||
user_api_key_dict=UserAPIKeyAuth(
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN, api_key="sk-1234", user_id="1234"
|
||||
),
|
||||
)
|
||||
|
||||
await asyncio.sleep(2)
|
||||
|
||||
# Assert the key is deleted from the secret manager
|
||||
result = await aws_secret_manager_client.async_read_secret(secret_name=key_alias)
|
||||
assert result is None
|
||||
|
||||
# cleanup
|
||||
setattr(litellm.proxy.proxy_server, "general_settings", {})
|
||||
|
||||
|
||||
################################################################################
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue