From 368fee224ed799a12782fa7d1493935b4e1b0c37 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Fri, 1 Dec 2023 19:36:06 -0800 Subject: [PATCH] feat: support for azure key vault --- litellm/__init__.py | 7 +++--- litellm/proxy/proxy_server.py | 41 +++++++++++++++++++++++++++++++++++ litellm/router.py | 6 ++--- litellm/utils.py | 26 ++++++++++++---------- pyproject.toml | 6 +++++ requirements.txt | 1 + 6 files changed, 70 insertions(+), 17 deletions(-) diff --git a/litellm/__init__.py b/litellm/__init__.py index 121acda72..5fd54401b 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -58,6 +58,8 @@ num_retries: Optional[int] = None fallbacks: Optional[List] = None context_window_fallbacks: Optional[List] = None allowed_fails: int = 0 +####### SECRET MANAGERS ##################### +secret_manager_client = None # list of instantiated key management clients - e.g. azure kv, infisical, etc. ############################################# def get_model_cost_map(url: str): @@ -95,8 +97,6 @@ headers = None api_version = None organization = None config_path = None -####### Secret Manager ##################### -secret_manager_client = None ####### COMPLETION MODELS ################### open_ai_chat_completion_models: List = [] open_ai_text_completion_models: List = [] @@ -366,7 +366,8 @@ from .utils import ( encode, decode, _calculate_retry_after, - _should_retry + _should_retry, + get_secret ) from .llms.huggingface_restapi import HuggingfaceConfig from .llms.anthropic import AnthropicConfig diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 0e97c3189..b5b770295 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -286,6 +286,44 @@ def celery_setup(use_queue: bool): async_result = AsyncResult celery_app_conn = celery_app +def load_from_azure_key_vault(use_azure_key_vault: bool = False): + if use_azure_key_vault is False: + return + + try: + from azure.keyvault.secrets import SecretClient + from 
azure.identity import ClientSecretCredential + + # Set your Azure Key Vault URI + KVUri = os.getenv("AZURE_KEY_VAULT_URI") + + # Set your Azure AD application/client ID, client secret, and tenant ID + client_id = os.getenv("AZURE_CLIENT_ID") + client_secret = os.getenv("AZURE_CLIENT_SECRET") + tenant_id = os.getenv("AZURE_TENANT_ID") + + # Initialize the ClientSecretCredential + credential = ClientSecretCredential(client_id=client_id, client_secret=client_secret, tenant_id=tenant_id) + + # Create the SecretClient using the credential + client = SecretClient(vault_url=KVUri, credential=credential) + + litellm.secret_manager_client = client + # # Retrieve all secrets + # secrets = client.get_secrets() + + # # Load secrets into environment variables + # for secret in secrets: + # secret_name = secret.name + # secret_value = client.get_secret(secret_name).value + # os.environ[secret_name] = secret_value + + print(f"test key - : {litellm.get_secret('test-3')}") + except Exception as e: + print(e) + print("Error when loading keys from Azure Key Vault. Ensure you run `pip install azure-identity azure-keyvault-secrets`") + + def cost_tracking(): global prisma_client, master_key if prisma_client is not None and master_key is not None: @@ -412,6 +450,9 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str): ### START REDIS QUEUE ### use_queue = general_settings.get("use_queue", False) celery_setup(use_queue=use_queue) + ### LOAD FROM AZURE KEY VAULT ### + use_azure_key_vault = general_settings.get("use_azure_key_vault", False) + load_from_azure_key_vault(use_azure_key_vault=use_azure_key_vault) ## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..) 
litellm_settings = config.get('litellm_settings', None) diff --git a/litellm/router.py b/litellm/router.py index 04d6ead90..fad67cf6d 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -844,19 +844,19 @@ class Router: api_key = litellm_params.get("api_key") if api_key and api_key.startswith("os.environ/"): api_key_env_name = api_key.replace("os.environ/", "") - api_key = os.getenv(api_key_env_name) + api_key = litellm.get_secret(api_key_env_name) api_base = litellm_params.get("api_base") base_url = litellm_params.get("base_url") api_base = api_base or base_url # allow users to pass in `api_base` or `base_url` for azure if api_base and api_base.startswith("os.environ/"): api_base_env_name = api_base.replace("os.environ/", "") - api_base = os.getenv(api_base_env_name) + api_base = litellm.get_secret(api_base_env_name) api_version = litellm_params.get("api_version") if api_version and api_version.startswith("os.environ/"): api_version_env_name = api_version.replace("os.environ/", "") - api_version = os.getenv(api_version_env_name) + api_version = litellm.get_secret(api_version_env_name) self.print_verbose(f"Initializing OpenAI Client for {model_name}, {str(api_base)}") if "azure" in model_name: if api_version is None: diff --git a/litellm/utils.py b/litellm/utils.py index 6fb68d481..4ae862314 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -2421,7 +2421,7 @@ def get_llm_provider(model: str, custom_llm_provider: Optional[str] = None, api_ if api_key and api_key.startswith("os.environ/"): api_key_env_name = api_key.replace("os.environ/", "") - dynamic_api_key = os.getenv(api_key_env_name) + dynamic_api_key = get_secret(api_key_env_name) # check if llm provider part of model name if model.split("/",1)[0] in litellm.provider_list and model.split("/",1)[0] not in litellm.model_list: custom_llm_provider = model.split("/", 1)[0] @@ -2429,15 +2429,15 @@ def get_llm_provider(model: str, custom_llm_provider: Optional[str] = None, api_ if custom_llm_provider == 
"perplexity": # perplexity is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.perplexity.ai api_base = "https://api.perplexity.ai" - dynamic_api_key = os.getenv("PERPLEXITYAI_API_KEY") + dynamic_api_key = get_secret("PERPLEXITYAI_API_KEY") elif custom_llm_provider == "anyscale": # anyscale is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.endpoints.anyscale.com/v1 api_base = "https://api.endpoints.anyscale.com/v1" - dynamic_api_key = os.getenv("ANYSCALE_API_KEY") + dynamic_api_key = get_secret("ANYSCALE_API_KEY") elif custom_llm_provider == "deepinfra": # deepinfra is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.endpoints.anyscale.com/v1 api_base = "https://api.deepinfra.com/v1/openai" - dynamic_api_key = os.getenv("DEEPINFRA_API_KEY") + dynamic_api_key = get_secret("DEEPINFRA_API_KEY") return model, custom_llm_provider, dynamic_api_key, api_base # check if api base is a known openai compatible endpoint @@ -2446,13 +2446,13 @@ def get_llm_provider(model: str, custom_llm_provider: Optional[str] = None, api_ if endpoint in api_base: if endpoint == "api.perplexity.ai": custom_llm_provider = "perplexity" - dynamic_api_key = os.getenv("PERPLEXITYAI_API_KEY") + dynamic_api_key = get_secret("PERPLEXITYAI_API_KEY") elif endpoint == "api.endpoints.anyscale.com/v1": custom_llm_provider = "anyscale" - dynamic_api_key = os.getenv("ANYSCALE_API_KEY") + dynamic_api_key = get_secret("ANYSCALE_API_KEY") elif endpoint == "api.deepinfra.com/v1/openai": custom_llm_provider = "deepinfra" - dynamic_api_key = os.getenv("DEEPINFRA_API_KEY") + dynamic_api_key = get_secret("DEEPINFRA_API_KEY") return model, custom_llm_provider, dynamic_api_key, api_base # check if model in known model provider list -> for huggingface models, raise exception as they don't have a fixed provider (can be togetherai, anyscale, baseten, runpod, et.) 
@@ -4715,13 +4715,17 @@ def litellm_telemetry(data): # checks if user has passed in a secret manager client # if passed in then checks the secret there def get_secret(secret_name): - if litellm.secret_manager_client != None: + if litellm.secret_manager_client is not None: # TODO: check which secret manager is being used # currently only supports Infisical try: - secret = litellm.secret_manager_client.get_secret(secret_name).secret_value - except: - secret = None + client = litellm.secret_manager_client + if type(client).__module__ + '.' + type(client).__name__ == 'azure.keyvault.secrets._client.SecretClient': # support Azure Secret Client - from azure.keyvault.secrets import SecretClient + secret = retrieved_secret = client.get_secret(secret_name).value + else: # assume the default is infisical client + secret = client.get_secret(secret_name).secret_value + except: # check if it's in os.environ + secret = os.environ.get(secret_name) return secret else: return os.environ.get(secret_name) diff --git a/pyproject.toml b/pyproject.toml index 2a413ae55..93c66e22f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,12 @@ proxy = [ "orjson", ] +extra_proxy = [ + "prisma", + "azure-identity", + "azure-keyvault-secrets" +] + [tool.poetry.scripts] litellm = 'litellm:run_server' diff --git a/requirements.txt b/requirements.txt index 7797f8f36..f153c58f3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ +# LITELLM PROXY DEPENDENCIES # litellm openai fastapi