(Feat) Add support for storing virtual keys in AWS SecretManager (#6728)

* add SecretManager to httpxSpecialProvider

* fix importing AWSSecretsManagerV2

* add unit testing for writing keys to AWS secret manager

* use KeyManagementEventHooks for key/generated events

* us event hooks for key management endpoints

* working AWSSecretsManagerV2

* fix write secret to AWS secret manager on /key/generate

* fix KeyManagementSettings

* use tasks for key management hooks

* add async_delete_secret

* add test for async_delete_secret

* use _delete_virtual_keys_from_secret_manager

* fix test secret manager

* test_key_generate_with_secret_manager_call

* fix check for key_management_settings

* sync_read_secret

* test_aws_secret_manager

* fix sync_read_secret

* use helper to check when _should_read_secret_from_secret_manager

* test_get_secret_with_access_mode

* test - handle eol model claude-2, use claude-2.1 instead

* docs AWS secret manager

* fix test_read_nonexistent_secret

* fix test_supports_response_schema

* ci/cd run again
This commit is contained in:
Ishaan Jaff 2024-11-14 09:25:07 -08:00 committed by GitHub
parent da84056e59
commit f8e700064e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 1046 additions and 178 deletions

View file

@ -0,0 +1,139 @@
# What is this?
import asyncio
import os
import sys
import traceback
from dotenv import load_dotenv
import litellm.types
import litellm.types.utils
load_dotenv()
import io
import sys
import os
# Ensure the project root is in the Python path
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..")))
print("Python Path:", sys.path)
print("Current Working Directory:", os.getcwd())
from typing import Optional
from unittest.mock import MagicMock, patch
import pytest
import uuid
import json
from litellm.secret_managers.aws_secret_manager_v2 import AWSSecretsManagerV2
def check_aws_credentials():
"""Helper function to check if AWS credentials are set"""
required_vars = ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_REGION_NAME"]
missing_vars = [var for var in required_vars if not os.getenv(var)]
if missing_vars:
pytest.skip(f"Missing required AWS credentials: {', '.join(missing_vars)}")
@pytest.mark.asyncio
async def test_write_and_read_simple_secret():
"""Test writing and reading a simple string secret"""
check_aws_credentials()
secret_manager = AWSSecretsManagerV2()
test_secret_name = f"litellm_test_{uuid.uuid4().hex[:8]}"
test_secret_value = "test_value_123"
try:
# Write secret
write_response = await secret_manager.async_write_secret(
secret_name=test_secret_name,
secret_value=test_secret_value,
description="LiteLLM Test Secret",
)
print("Write Response:", write_response)
assert write_response is not None
assert "ARN" in write_response
assert "Name" in write_response
assert write_response["Name"] == test_secret_name
# Read secret back
read_value = await secret_manager.async_read_secret(
secret_name=test_secret_name
)
print("Read Value:", read_value)
assert read_value == test_secret_value
finally:
# Cleanup: Delete the secret
delete_response = await secret_manager.async_delete_secret(
secret_name=test_secret_name
)
print("Delete Response:", delete_response)
assert delete_response is not None
@pytest.mark.asyncio
async def test_write_and_read_json_secret():
"""Test writing and reading a JSON structured secret"""
check_aws_credentials()
secret_manager = AWSSecretsManagerV2()
test_secret_name = f"litellm_test_{uuid.uuid4().hex[:8]}_json"
test_secret_value = {
"api_key": "test_key",
"model": "gpt-4",
"temperature": 0.7,
"metadata": {"team": "ml", "project": "litellm"},
}
try:
# Write JSON secret
write_response = await secret_manager.async_write_secret(
secret_name=test_secret_name,
secret_value=json.dumps(test_secret_value),
description="LiteLLM JSON Test Secret",
)
print("Write Response:", write_response)
# Read and parse JSON secret
read_value = await secret_manager.async_read_secret(
secret_name=test_secret_name
)
parsed_value = json.loads(read_value)
print("Read Value:", read_value)
assert parsed_value == test_secret_value
assert parsed_value["api_key"] == "test_key"
assert parsed_value["metadata"]["team"] == "ml"
finally:
# Cleanup: Delete the secret
delete_response = await secret_manager.async_delete_secret(
secret_name=test_secret_name
)
print("Delete Response:", delete_response)
assert delete_response is not None
@pytest.mark.asyncio
async def test_read_nonexistent_secret():
"""Test reading a secret that doesn't exist"""
check_aws_credentials()
secret_manager = AWSSecretsManagerV2()
nonexistent_secret = f"litellm_nonexistent_{uuid.uuid4().hex}"
response = await secret_manager.async_read_secret(secret_name=nonexistent_secret)
assert response is None

View file

@ -24,7 +24,7 @@ from litellm import RateLimitError, Timeout, completion, completion_cost, embedd
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.llms.prompt_templates.factory import anthropic_messages_pt
# litellm.num_retries = 3
# litellm.num_retries=3
litellm.cache = None
litellm.success_callback = []

View file

@ -15,22 +15,29 @@ sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import pytest
import litellm
from litellm.llms.AzureOpenAI.azure import get_azure_ad_token_from_oidc
from litellm.llms.bedrock.chat import BedrockConverseLLM, BedrockLLM
from litellm.secret_managers.aws_secret_manager import load_aws_secret_manager
from litellm.secret_managers.main import get_secret
from litellm.secret_managers.aws_secret_manager_v2 import AWSSecretsManagerV2
from litellm.secret_managers.main import (
get_secret,
_should_read_secret_from_secret_manager,
)
@pytest.mark.skip(reason="AWS Suspended Account")
def test_aws_secret_manager():
load_aws_secret_manager(use_aws_secret_manager=True)
import json
AWSSecretsManagerV2.load_aws_secret_manager(use_aws_secret_manager=True)
secret_val = get_secret("litellm_master_key")
print(f"secret_val: {secret_val}")
assert secret_val == "sk-1234"
# cast json to dict
secret_val = json.loads(secret_val)
assert secret_val["litellm_master_key"] == "sk-1234"
def redact_oidc_signature(secret_val):
@ -240,3 +247,71 @@ def test_google_secret_manager_read_in_memory():
)
print("secret_val: {}".format(secret_val))
assert secret_val == "lite-llm"
def test_should_read_secret_from_secret_manager():
"""
Test that _should_read_secret_from_secret_manager returns correct values based on access mode
"""
from litellm.proxy._types import KeyManagementSettings
# Test when secret manager client is None
litellm.secret_manager_client = None
litellm._key_management_settings = KeyManagementSettings()
assert _should_read_secret_from_secret_manager() is False
# Test with secret manager client and read_only access
litellm.secret_manager_client = "dummy_client"
litellm._key_management_settings = KeyManagementSettings(access_mode="read_only")
assert _should_read_secret_from_secret_manager() is True
# Test with secret manager client and read_and_write access
litellm._key_management_settings = KeyManagementSettings(
access_mode="read_and_write"
)
assert _should_read_secret_from_secret_manager() is True
# Test with secret manager client and write_only access
litellm._key_management_settings = KeyManagementSettings(access_mode="write_only")
assert _should_read_secret_from_secret_manager() is False
# Reset global variables
litellm.secret_manager_client = None
litellm._key_management_settings = KeyManagementSettings()
def test_get_secret_with_access_mode():
"""
Test that get_secret respects access mode settings
"""
from litellm.proxy._types import KeyManagementSettings
# Set up test environment
test_secret_name = "TEST_SECRET_KEY"
test_secret_value = "test_secret_value"
os.environ[test_secret_name] = test_secret_value
# Test with write_only access (should read from os.environ)
litellm.secret_manager_client = "dummy_client"
litellm._key_management_settings = KeyManagementSettings(access_mode="write_only")
assert get_secret(test_secret_name) == test_secret_value
# Test with no KeyManagementSettings but secret_manager_client set
litellm.secret_manager_client = "dummy_client"
litellm._key_management_settings = KeyManagementSettings()
assert _should_read_secret_from_secret_manager() is True
# Test with read_only access
litellm._key_management_settings = KeyManagementSettings(access_mode="read_only")
assert _should_read_secret_from_secret_manager() is True
# Test with read_and_write access
litellm._key_management_settings = KeyManagementSettings(
access_mode="read_and_write"
)
assert _should_read_secret_from_secret_manager() is True
# Reset global variables
litellm.secret_manager_client = None
litellm._key_management_settings = KeyManagementSettings()
del os.environ[test_secret_name]