refactor secret managers

This commit is contained in:
Ishaan Jaff 2024-09-03 10:58:02 -07:00
parent 150f3c2cfa
commit 3c898e23ea
11 changed files with 22 additions and 18 deletions

View file

@ -9,7 +9,7 @@ import time
sys.path.insert( sys.path.insert(
0, os.path.abspath("./") 0, os.path.abspath("./")
) # Adds the parent directory to the system path ) # Adds the parent directory to the system path
from litellm.proxy.secret_managers.aws_secret_manager import decrypt_env_var from litellm.secret_managers.aws_secret_manager import decrypt_env_var
if os.getenv("USE_AWS_KMS", None) is not None and os.getenv("USE_AWS_KMS") == "True": if os.getenv("USE_AWS_KMS", None) is not None and os.getenv("USE_AWS_KMS") == "True":
## V2 IMPLEMENTATION OF AWS KMS - USER WANTS TO DECRYPT MULTIPLE KEYS IN THEIR ENV ## V2 IMPLEMENTATION OF AWS KMS - USER WANTS TO DECRYPT MULTIPLE KEYS IN THEIR ENV

View file

@ -475,7 +475,7 @@ def run_server(
### DECRYPT ENV VAR ### ### DECRYPT ENV VAR ###
from litellm.proxy.secret_managers.aws_secret_manager import decrypt_env_var from litellm.secret_managers.aws_secret_manager import decrypt_env_var
if ( if (
os.getenv("USE_AWS_KMS", None) is not None os.getenv("USE_AWS_KMS", None) is not None
@ -548,7 +548,7 @@ def run_server(
key_management_system key_management_system
== KeyManagementSystem.GOOGLE_SECRET_MANAGER.value == KeyManagementSystem.GOOGLE_SECRET_MANAGER.value
): ):
from litellm.proxy.secret_managers.google_secret_manager import ( from litellm.secret_managers.google_secret_manager import (
GoogleSecretManager, GoogleSecretManager,
) )

View file

@ -212,11 +212,6 @@ from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
) )
from litellm.proxy.rerank_endpoints.endpoints import router as rerank_router from litellm.proxy.rerank_endpoints.endpoints import router as rerank_router
from litellm.proxy.route_llm_request import route_request from litellm.proxy.route_llm_request import route_request
from litellm.proxy.secret_managers.aws_secret_manager import (
load_aws_kms,
load_aws_secret_manager,
)
from litellm.proxy.secret_managers.google_kms import load_google_kms
from litellm.proxy.spend_tracking.spend_management_endpoints import ( from litellm.proxy.spend_tracking.spend_management_endpoints import (
router as spend_management_router, router as spend_management_router,
) )
@ -257,6 +252,11 @@ from litellm.router import (
from litellm.router import ModelInfo as RouterModelInfo from litellm.router import ModelInfo as RouterModelInfo
from litellm.router import updateDeployment from litellm.router import updateDeployment
from litellm.scheduler import DefaultPriorities, FlowItem, Scheduler from litellm.scheduler import DefaultPriorities, FlowItem, Scheduler
from litellm.secret_managers.aws_secret_manager import (
load_aws_kms,
load_aws_secret_manager,
)
from litellm.secret_managers.google_kms import load_google_kms
from litellm.types.llms.anthropic import ( from litellm.types.llms.anthropic import (
AnthropicMessagesRequest, AnthropicMessagesRequest,
AnthropicResponse, AnthropicResponse,
@ -1769,7 +1769,7 @@ class ProxyConfig:
key_management_system key_management_system
== KeyManagementSystem.GOOGLE_SECRET_MANAGER.value == KeyManagementSystem.GOOGLE_SECRET_MANAGER.value
): ):
from litellm.proxy.secret_managers.google_secret_manager import ( from litellm.secret_managers.google_secret_manager import (
GoogleSecretManager, GoogleSecretManager,
) )

View file

@ -9,7 +9,7 @@ import openai
import litellm import litellm
from litellm._logging import verbose_router_logger from litellm._logging import verbose_router_logger
from litellm.llms.azure import get_azure_ad_token_from_oidc from litellm.llms.azure import get_azure_ad_token_from_oidc
from litellm.proxy.secret_managers.get_azure_ad_token_provider import ( from litellm.secret_managers.get_azure_ad_token_provider import (
get_azure_ad_token_provider, get_azure_ad_token_provider,
) )
from litellm.utils import calculate_max_parallel_requests from litellm.utils import calculate_max_parallel_requests

View file

@ -14,8 +14,7 @@ def get_azure_ad_token_provider() -> Callable[[], str]:
Returns: Returns:
Callable that returns a temporary authentication token. Callable that returns a temporary authentication token.
""" """
from azure.identity import ClientSecretCredential from azure.identity import ClientSecretCredential, get_bearer_token_provider
from azure.identity import get_bearer_token_provider
try: try:
credential = ClientSecretCredential( credential = ClientSecretCredential(
@ -24,7 +23,9 @@ def get_azure_ad_token_provider() -> Callable[[], str]:
tenant_id=os.environ["AZURE_TENANT_ID"], tenant_id=os.environ["AZURE_TENANT_ID"],
) )
except KeyError as e: except KeyError as e:
raise ValueError("Missing environment variable required by Azure AD workflow.") from e raise ValueError(
"Missing environment variable required by Azure AD workflow."
) from e
return get_bearer_token_provider( return get_bearer_token_provider(
credential, credential,

View file

@ -7,8 +7,11 @@ Requires:
* `os.environ["GOOGLE_APPLICATION_CREDENTIALS"], os.environ["GOOGLE_KMS_RESOURCE_NAME"]` * `os.environ["GOOGLE_APPLICATION_CREDENTIALS"], os.environ["GOOGLE_KMS_RESOURCE_NAME"]`
* `pip install google-cloud-kms` * `pip install google-cloud-kms`
""" """
import litellm, os
import os
from typing import Optional from typing import Optional
import litellm
from litellm.proxy._types import KeyManagementSystem from litellm.proxy._types import KeyManagementSystem

View file

@ -83,7 +83,7 @@ async def test_router_init():
) )
@patch("litellm.proxy.secret_managers.get_azure_ad_token_provider.os") @patch("litellm.secret_managers.get_azure_ad_token_provider.os")
def test_router_init_with_neither_api_key_nor_azure_service_principal_with_secret( def test_router_init_with_neither_api_key_nor_azure_service_principal_with_secret(
mocked_os_lib: MagicMock, mocked_os_lib: MagicMock,
) -> None: ) -> None:
@ -128,7 +128,7 @@ def test_router_init_with_neither_api_key_nor_azure_service_principal_with_secre
@patch("azure.identity.get_bearer_token_provider") @patch("azure.identity.get_bearer_token_provider")
@patch("azure.identity.ClientSecretCredential") @patch("azure.identity.ClientSecretCredential")
@patch("litellm.proxy.secret_managers.get_azure_ad_token_provider.os") @patch("litellm.secret_managers.get_azure_ad_token_provider.os")
def test_router_init_azure_service_principal_with_secret_with_environment_variables( def test_router_init_azure_service_principal_with_secret_with_environment_variables(
mocked_os_lib: MagicMock, mocked_os_lib: MagicMock,
mocked_credential: MagicMock, mocked_credential: MagicMock,

View file

@ -18,7 +18,7 @@ import pytest
from litellm.llms.azure import get_azure_ad_token_from_oidc from litellm.llms.azure import get_azure_ad_token_from_oidc
from litellm.llms.bedrock.chat import BedrockConverseLLM, BedrockLLM from litellm.llms.bedrock.chat import BedrockConverseLLM, BedrockLLM
from litellm.proxy.secret_managers.aws_secret_manager import load_aws_secret_manager from litellm.secret_managers.aws_secret_manager import load_aws_secret_manager
from litellm.secret_managers.main import get_secret from litellm.secret_managers.main import get_secret

View file

@ -17,7 +17,7 @@ def test_decrypt_and_reset_env():
os.environ["DATABASE_URL"] = ( os.environ["DATABASE_URL"] = (
"aws_kms/AQICAHgwddjZ9xjVaZ9CNCG8smFU6FiQvfdrjL12DIqi9vUAQwHwF6U7caMgHQa6tK+TzaoMAAAAzjCBywYJKoZIhvcNAQcGoIG9MIG6AgEAMIG0BgkqhkiG9w0BBwEwHgYJYIZIAWUDBAEuMBEEDCmu+DVeKTm5tFZu6AIBEICBhnOFQYviL8JsciGk0bZsn9pfzeYWtNkVXEsl01AdgHBqT9UOZOI4ZC+T3wO/fXA7wdNF4o8ASPDbVZ34ZFdBs8xt4LKp9niufL30WYBkuuzz89ztly0jvE9pZ8L6BMw0ATTaMgIweVtVSDCeCzEb5PUPyxt4QayrlYHBGrNH5Aq/axFTe0La" "aws_kms/AQICAHgwddjZ9xjVaZ9CNCG8smFU6FiQvfdrjL12DIqi9vUAQwHwF6U7caMgHQa6tK+TzaoMAAAAzjCBywYJKoZIhvcNAQcGoIG9MIG6AgEAMIG0BgkqhkiG9w0BBwEwHgYJYIZIAWUDBAEuMBEEDCmu+DVeKTm5tFZu6AIBEICBhnOFQYviL8JsciGk0bZsn9pfzeYWtNkVXEsl01AdgHBqT9UOZOI4ZC+T3wO/fXA7wdNF4o8ASPDbVZ34ZFdBs8xt4LKp9niufL30WYBkuuzz89ztly0jvE9pZ8L6BMw0ATTaMgIweVtVSDCeCzEb5PUPyxt4QayrlYHBGrNH5Aq/axFTe0La"
) )
from litellm.proxy.secret_managers.aws_secret_manager import ( from litellm.secret_managers.aws_secret_manager import (
decrypt_and_reset_env_var, decrypt_and_reset_env_var,
) )