diff --git a/.circleci/config.yml b/.circleci/config.yml index 032f697c78..f59cbef5a5 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -671,6 +671,51 @@ jobs: paths: - batches_coverage.xml - batches_coverage + secret_manager_testing: + docker: + - image: cimg/python:3.11 + auth: + username: ${DOCKERHUB_USERNAME} + password: ${DOCKERHUB_PASSWORD} + working_directory: ~/project + + steps: + - checkout + - run: + name: Install Dependencies + command: | + python -m pip install --upgrade pip + python -m pip install -r requirements.txt + pip install "respx==0.21.1" + pip install "pytest==7.3.1" + pip install "pytest-retry==1.6.3" + pip install "pytest-asyncio==0.21.1" + pip install "pytest-cov==5.0.0" + pip install "google-generativeai==0.3.2" + pip install "google-cloud-aiplatform==1.43.0" + # Run pytest and generate JUnit XML report + - run: + name: Run tests + command: | + pwd + ls + python -m pytest -vv tests/secret_manager_tests --cov=litellm --cov-report=xml -x -s -v --junitxml=test-results/junit.xml --durations=5 + no_output_timeout: 120m + - run: + name: Rename the coverage files + command: | + mv coverage.xml secret_manager_coverage.xml + mv .coverage secret_manager_coverage + + # Store test results + - store_test_results: + path: test-results + - persist_to_workspace: + root: . + paths: + - secret_manager_coverage.xml + - secret_manager_coverage + pass_through_unit_testing: docker: - image: cimg/python:3.11 @@ -1767,6 +1812,12 @@ workflows: only: - main - /litellm_.*/ + - secret_manager_testing: + filters: + branches: + only: + - main + - /litellm_.*/ - pass_through_unit_testing: filters: branches: @@ -1789,6 +1840,7 @@ workflows: requires: - llm_translation_testing - batches_testing + - secret_manager_testing - pass_through_unit_testing - image_gen_testing - logging_testing @@ -1838,6 +1890,7 @@ workflows: - test_bad_database_url - llm_translation_testing - batches_testing + - secret_manager_testing - pass_through_unit_testing - image_gen_testing - logging_testing diff --git a/docs/my-website/docs/proxy/config_settings.md b/docs/my-website/docs/proxy/config_settings.md index 4c54aebc66..4b323602a6 100644 --- a/docs/my-website/docs/proxy/config_settings.md +++ b/docs/my-website/docs/proxy/config_settings.md @@ -390,6 +390,9 @@ router_settings: | GOOGLE_CLIENT_SECRET | Client secret for Google OAuth | GOOGLE_KMS_RESOURCE_NAME | Name of the resource in Google KMS | HF_API_BASE | Base URL for Hugging Face API +| HCP_VAULT_ADDR | Address for [Hashicorp Vault Secret Manager](../secret.md#hashicorp-vault) +| HCP_VAULT_NAMESPACE | Namespace for [Hashicorp Vault Secret Manager](../secret.md#hashicorp-vault) +| HCP_VAULT_TOKEN | Token for [Hashicorp Vault Secret Manager](../secret.md#hashicorp-vault) | HELICONE_API_KEY | API key for Helicone service | HOSTNAME | Hostname for the server, this will be [emitted to `datadog` logs](https://docs.litellm.ai/docs/proxy/logging#datadog) | HUGGINGFACE_API_BASE | Base URL for Hugging Face API diff --git a/docs/my-website/docs/secret.md b/docs/my-website/docs/secret.md index 113a11750b..3317c320eb 100644 --- a/docs/my-website/docs/secret.md +++ b/docs/my-website/docs/secret.md @@ -1,5 +1,6 @@ import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; +import Image from '@theme/IdealImage'; # Secret Manager LiteLLM supports reading secrets from Azure Key Vault, Google Secret Manager @@ -21,6 +22,7 @@ LiteLLM supports reading secrets from Azure Key Vault, Google Secret Manager - [Azure Key Vault](#azure-key-vault) - [Google Secret Manager](#google-secret-manager) - Google Key Management Service +- [Hashicorp Vault](#hashicorp-vault) - [Infisical Secret Manager](#infisical-secret-manager) - [.env Files](#env-files) @@ -52,7 +54,7 @@ general_settings: Store your proxy keys in AWS Secret Manager. -### Proxy Usage +#### Proxy Usage 1. Save AWS Credentials in your environment ```bash @@ -128,7 +130,7 @@ litellm.secret_manager = client litellm.get_secret("your-test-key") ``` --> -### Usage with LiteLLM Proxy Server +#### Usage with LiteLLM Proxy Server 1. Install Proxy dependencies ```bash @@ -233,12 +235,73 @@ And in another terminal $ litellm --test ``` -[Quick Test Proxy](./proxy/quick_start#using-litellm-proxy---curl-request-openai-package-langchain-langchain-js) - +[Quick Test Proxy](./proxy/user_keys) +## Hashicorp Vault + +Read secrets from [Hashicorp Vault](https://developer.hashicorp.com/vault/docs/secrets/kv/kv-v2) + +Step 1. Add Hashicorp Vault details in your environment + +```bash +HCP_VAULT_ADDR="https://test-cluster-public-vault-0f98180c.e98296b2.z1.hashicorp.cloud:8200" +HCP_VAULT_NAMESPACE="admin" +HCP_VAULT_TOKEN="hvs.CAESIG52gL6ljBSdmq*****" + +# OPTIONAL +HCP_VAULT_REFRESH_INTERVAL="86400" # defaults to 86400, frequency of cache refresh for Hashicorp Vault +``` + +Step 2. Add to proxy config.yaml + +```yaml +general_settings: + key_management_system: "hashicorp_vault" +``` + +Step 3. Start + test proxy + +``` +$ litellm --config /path/to/config.yaml +``` + +[Quick Test Proxy](./proxy/user_keys) + + +#### How it works + +LiteLLM reads secrets from Hashicorp Vault's KV v2 engine using the following URL format: +``` +{VAULT_ADDR}/v1/{NAMESPACE}/secret/data/{SECRET_NAME} +``` + +For example, if you have: +- `HCP_VAULT_ADDR="https://vault.example.com:8200"` +- `HCP_VAULT_NAMESPACE="admin"` +- Secret name: `AZURE_API_KEY` + + +LiteLLM will look up: +``` +https://vault.example.com:8200/v1/admin/secret/data/AZURE_API_KEY +``` + +#### Expected Secret Format +LiteLLM expects all secrets to be stored as a JSON object with a `key` field containing the secret value. + +For example, for `AZURE_API_KEY`, the secret should be stored as: + +```json +{ + "key": "sk-1234" +} +``` + + + ## All Secret Manager Settings diff --git a/docs/my-website/img/hcorp.png b/docs/my-website/img/hcorp.png new file mode 100644 index 0000000000..6d8b309d75 Binary files /dev/null and b/docs/my-website/img/hcorp.png differ diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index a9a9db294c..3a6ad4a922 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -1146,6 +1146,7 @@ class KeyManagementSystem(enum.Enum): AZURE_KEY_VAULT = "azure_key_vault" AWS_SECRET_MANAGER = "aws_secret_manager" GOOGLE_SECRET_MANAGER = "google_secret_manager" + HASHICORP_VAULT = "hashicorp_vault" LOCAL = "local" AWS_KMS = "aws_kms" diff --git a/litellm/proxy/proxy_cli.py b/litellm/proxy/proxy_cli.py index 5c4b04fb70..f210bd2dc8 100644 --- a/litellm/proxy/proxy_cli.py +++ b/litellm/proxy/proxy_cli.py @@ -257,24 +257,16 @@ def run_server( # noqa: PLR0915 if local: from proxy_server import ( KeyManagementSettings, - KeyManagementSystem, ProxyConfig, app, - load_aws_kms, - load_from_azure_key_vault, - load_google_kms, save_worker_config, ) else: try: from .proxy_server import ( KeyManagementSettings, - KeyManagementSystem, ProxyConfig, app, - load_aws_kms, - load_from_azure_key_vault, - load_google_kms, save_worker_config, ) except ImportError as e: @@ -285,12 +277,8 @@ def run_server( # noqa: PLR0915 # this is just a local/relative import error, user git cloned litellm from proxy_server import ( KeyManagementSettings, - KeyManagementSystem, ProxyConfig, app, - load_aws_kms, - load_from_azure_key_vault, - load_google_kms, save_worker_config, ) if version is True: @@ -537,41 +525,7 @@ def run_server( # noqa: PLR0915 key_management_system = general_settings.get( "key_management_system", None ) - if key_management_system is not None: - if ( - key_management_system - == KeyManagementSystem.AZURE_KEY_VAULT.value - ): - ### LOAD FROM AZURE KEY VAULT ### - load_from_azure_key_vault(use_azure_key_vault=True) - elif key_management_system == KeyManagementSystem.GOOGLE_KMS.value: - ### LOAD FROM GOOGLE KMS ### - load_google_kms(use_google_kms=True) - elif ( - key_management_system - == KeyManagementSystem.AWS_SECRET_MANAGER.value # noqa: F405 - ): - from litellm.secret_managers.aws_secret_manager_v2 import ( - AWSSecretsManagerV2, - ) - - ### LOAD FROM AWS SECRET MANAGER ### - AWSSecretsManagerV2.load_aws_secret_manager( - use_aws_secret_manager=True - ) - elif key_management_system == KeyManagementSystem.AWS_KMS.value: - load_aws_kms(use_aws_kms=True) - elif ( - key_management_system - == KeyManagementSystem.GOOGLE_SECRET_MANAGER.value - ): - from litellm.secret_managers.google_secret_manager import ( - GoogleSecretManager, - ) - - GoogleSecretManager() - else: - raise ValueError("Invalid Key Management System selected") + proxy_config.initialize_secret_manager(key_management_system) key_management_settings = general_settings.get( "key_management_settings", None ) diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml index ea0120178c..6c375752dd 100644 --- a/litellm/proxy/proxy_config.yaml +++ b/litellm/proxy/proxy_config.yaml @@ -1,4 +1,8 @@ model_list: + - model_name: "openai/*" + litellm_params: + model: "openai/*" + api_key: os.environ/OPENAI_API_KEY - model_name: "azure/*" litellm_params: model: azure/chatgpt-v-2 @@ -19,7 +23,7 @@ general_settings: # Requests hanging threshold hanging_threshold_seconds: 0.0000001 # Number of seconds of waiting for a response before a request is considered hanging hanging_threshold_window_seconds: 10 # Window in seconds - + key_management_system: "hashicorp_vault" # For /fine_tuning/jobs endpoints finetune_settings: diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index d0d6621cca..2760ad9f7a 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -1894,37 +1894,7 @@ class ProxyConfig: if general_settings: ### LOAD SECRET MANAGER ### key_management_system = general_settings.get("key_management_system", None) - if key_management_system is not None: - if key_management_system == KeyManagementSystem.AZURE_KEY_VAULT.value: - ### LOAD FROM AZURE KEY VAULT ### - load_from_azure_key_vault(use_azure_key_vault=True) - elif key_management_system == KeyManagementSystem.GOOGLE_KMS.value: - ### LOAD FROM GOOGLE KMS ### - load_google_kms(use_google_kms=True) - elif ( - key_management_system - == KeyManagementSystem.AWS_SECRET_MANAGER.value # noqa: F405 - ): - from litellm.secret_managers.aws_secret_manager_v2 import ( - AWSSecretsManagerV2, - ) - - AWSSecretsManagerV2.load_aws_secret_manager( - use_aws_secret_manager=True - ) - elif key_management_system == KeyManagementSystem.AWS_KMS.value: - load_aws_kms(use_aws_kms=True) - elif ( - key_management_system - == KeyManagementSystem.GOOGLE_SECRET_MANAGER.value - ): - from litellm.secret_managers.google_secret_manager import ( - GoogleSecretManager, - ) - - GoogleSecretManager() - else: - raise ValueError("Invalid Key Management System selected") + self.initialize_secret_manager(key_management_system=key_management_system) key_management_settings = general_settings.get( "key_management_settings", None ) @@ -2167,6 +2137,45 @@ class ProxyConfig: litellm.callbacks.append(_logger) pass + def initialize_secret_manager(self, key_management_system: Optional[str]): + """ + Initialize the relevant secret manager if `key_management_system` is provided + """ + if key_management_system is not None: + if key_management_system == KeyManagementSystem.AZURE_KEY_VAULT.value: + ### LOAD FROM AZURE KEY VAULT ### + load_from_azure_key_vault(use_azure_key_vault=True) + elif key_management_system == KeyManagementSystem.GOOGLE_KMS.value: + ### LOAD FROM GOOGLE KMS ### + load_google_kms(use_google_kms=True) + elif ( + key_management_system + == KeyManagementSystem.AWS_SECRET_MANAGER.value # noqa: F405 + ): + from litellm.secret_managers.aws_secret_manager_v2 import ( + AWSSecretsManagerV2, + ) + + AWSSecretsManagerV2.load_aws_secret_manager(use_aws_secret_manager=True) + elif key_management_system == KeyManagementSystem.AWS_KMS.value: + load_aws_kms(use_aws_kms=True) + elif ( + key_management_system == KeyManagementSystem.GOOGLE_SECRET_MANAGER.value + ): + from litellm.secret_managers.google_secret_manager import ( + GoogleSecretManager, + ) + + GoogleSecretManager() + elif key_management_system == KeyManagementSystem.HASHICORP_VAULT.value: + from litellm.secret_managers.hashicorp_secret_manager import ( + HashicorpSecretManager, + ) + + HashicorpSecretManager() + else: + raise ValueError("Invalid Key Management System selected") + def get_model_info_with_id(self, model, db_model=False) -> RouterModelInfo: """ Common logic across add + delete router models diff --git a/litellm/secret_managers/hashicorp_secret_manager.py b/litellm/secret_managers/hashicorp_secret_manager.py new file mode 100644 index 0000000000..d97752da29 --- /dev/null +++ b/litellm/secret_managers/hashicorp_secret_manager.py @@ -0,0 +1,138 @@ +import os +from typing import Optional + +import litellm +from litellm._logging import verbose_logger +from litellm.caching import InMemoryCache +from litellm.llms.custom_httpx.http_handler import ( + _get_httpx_client, + get_async_httpx_client, + httpxSpecialProvider, +) +from litellm.proxy._types import KeyManagementSystem + + +class HashicorpSecretManager: + def __init__(self): + from litellm.proxy.proxy_server import CommonProxyErrors, premium_user + + # Vault-specific config + self.vault_addr = os.getenv("HCP_VAULT_ADDR", "http://127.0.0.1:8200") + self.vault_token = os.getenv("HCP_VAULT_TOKEN", "") + # If your KV engine is mounted somewhere other than "secret", adjust here: + self.vault_namespace = os.getenv("HCP_VAULT_NAMESPACE", None) + + # Validate environment + if not self.vault_token: + raise ValueError( + "Missing Vault token. Please set VAULT_TOKEN in your environment." + ) + + litellm.secret_manager_client = self + litellm._key_management_system = KeyManagementSystem.HASHICORP_VAULT + _refresh_interval = os.environ.get("HCP_VAULT_REFRESH_INTERVAL", 86400) + _refresh_interval = int(_refresh_interval) if _refresh_interval else 86400 + self.cache = InMemoryCache( + default_ttl=_refresh_interval + ) # store in memory for 1 day + + if premium_user is not True: + raise ValueError( + f"Hashicorp secret manager is only available for premium users. {CommonProxyErrors.not_premium_user.value}" + ) + + def get_url(self, secret_name: str) -> str: + _url = f"{self.vault_addr}/v1/" + if self.vault_namespace: + _url += f"{self.vault_namespace}/" + _url += f"secret/data/{secret_name}" + return _url + + async def async_read_secret(self, secret_name: str) -> Optional[str]: + """ + Reads a secret from Vault KV v2 using an async HTTPX client. + secret_name is just the path inside the KV mount (e.g., 'myapp/config'). + Returns the entire data dict from data.data, or None on failure. + """ + if self.cache.get_cache(secret_name) is not None: + return self.cache.get_cache(secret_name) + async_client = get_async_httpx_client( + llm_provider=httpxSpecialProvider.SecretManager, + ) + try: + # For KV v2: /v1//data/ + # Example: http://127.0.0.1:8200/v1/secret/data/myapp/config + _url = self.get_url(secret_name) + url = _url + + response = await async_client.get( + url, headers={"X-Vault-Token": self.vault_token} + ) + response.raise_for_status() + + # For KV v2, the secret is in response.json()["data"]["data"] + json_resp = response.json() + _value = self._get_secret_value_from_json_response(json_resp) + self.cache.set_cache(secret_name, _value) + return _value + + except Exception as e: + verbose_logger.exception(f"Error reading secret from Hashicorp Vault: {e}") + return None + + def read_secret(self, secret_name: str) -> Optional[str]: + """ + Reads a secret from Vault KV v2 using a sync HTTPX client. + secret_name is just the path inside the KV mount (e.g., 'myapp/config'). + Returns the entire data dict from data.data, or None on failure. + """ + if self.cache.get_cache(secret_name) is not None: + return self.cache.get_cache(secret_name) + sync_client = _get_httpx_client() + try: + # For KV v2: /v1//data/ + url = self.get_url(secret_name) + + response = sync_client.get(url, headers={"X-Vault-Token": self.vault_token}) + response.raise_for_status() + + # For KV v2, the secret is in response.json()["data"]["data"] + json_resp = response.json() + verbose_logger.debug(f"Hashicorp secret manager response: {json_resp}") + _value = self._get_secret_value_from_json_response(json_resp) + self.cache.set_cache(secret_name, _value) + return _value + + except Exception as e: + verbose_logger.exception(f"Error reading secret from Hashicorp Vault: {e}") + return None + + def _get_secret_value_from_json_response( + self, json_resp: Optional[dict] + ) -> Optional[str]: + """ + Get the secret value from the JSON response + + Json response from hashicorp vault is of the form: + + { + "request_id":"036ba77c-018b-31dd-047b-323bcd0cd332", + "lease_id":"", + "renewable":false, + "lease_duration":0, + "data": + {"data": + {"key":"Vault Is The Way"}, + "metadata":{"created_time":"2025-01-01T22:13:50.93942388Z","custom_metadata":null,"deletion_time":"","destroyed":false,"version":1} + }, + "wrap_info":null, + "warnings":null, + "auth":null, + "mount_type":"kv" + } + + Note: LiteLLM assumes that all secrets are stored as under the key "key" + """ + if json_resp is None: + return None + return json_resp.get("data", {}).get("data", {}).get("key", None) diff --git a/litellm/secret_managers/main.py b/litellm/secret_managers/main.py index 2b89aedadd..738332209a 100644 --- a/litellm/secret_managers/main.py +++ b/litellm/secret_managers/main.py @@ -289,6 +289,19 @@ def get_secret( # noqa: PLR0915 except Exception as e: print_verbose(f"An error occurred - {str(e)}") raise e + elif key_manager == KeyManagementSystem.HASHICORP_VAULT.value: + try: + secret = client.read_secret(secret_name) + print_verbose( + f"secret from hashicorp secret manager: {secret}" + ) + if secret is None: + raise ValueError( + f"No secret found in Hashicorp Secret Manager for {secret_name}" + ) + except Exception as e: + print_verbose(f"An error occurred - {str(e)}") + raise e elif key_manager == "local": secret = os.getenv(secret_name) else: # assume the default is infisicial client diff --git a/tests/secret_manager_tests/conftest.py b/tests/secret_manager_tests/conftest.py new file mode 100644 index 0000000000..eca0bc431a --- /dev/null +++ b/tests/secret_manager_tests/conftest.py @@ -0,0 +1,54 @@ +# conftest.py + +import importlib +import os +import sys + +import pytest + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path +import litellm + + +@pytest.fixture(scope="function", autouse=True) +def setup_and_teardown(): + """ + This fixture reloads litellm before every function. To speed up testing by removing callbacks being chained. + """ + curr_dir = os.getcwd() # Get the current working directory + sys.path.insert( + 0, os.path.abspath("../..") + ) # Adds the project directory to the system path + + import litellm + from litellm import Router + + importlib.reload(litellm) + import asyncio + + loop = asyncio.get_event_loop_policy().new_event_loop() + asyncio.set_event_loop(loop) + print(litellm) + # from litellm import Router, completion, aembedding, acompletion, embedding + yield + + # Teardown code (executes after the yield point) + loop.close() # Close the loop created earlier + asyncio.set_event_loop(None) # Remove the reference to the loop + + +def pytest_collection_modifyitems(config, items): + # Separate tests in 'test_amazing_proxy_custom_logger.py' and other tests + custom_logger_tests = [ + item for item in items if "custom_logger" in item.parent.name + ] + other_tests = [item for item in items if "custom_logger" not in item.parent.name] + + # Sort tests based on their names + custom_logger_tests.sort(key=lambda x: x.name) + other_tests.sort(key=lambda x: x.name) + + # Reorder the items list + items[:] = custom_logger_tests + other_tests diff --git a/tests/local_testing/test_aws_secret_manager.py b/tests/secret_manager_tests/test_aws_secret_manager.py similarity index 100% rename from tests/local_testing/test_aws_secret_manager.py rename to tests/secret_manager_tests/test_aws_secret_manager.py diff --git a/tests/local_testing/test_get_secret.py b/tests/secret_manager_tests/test_get_secret.py similarity index 100% rename from tests/local_testing/test_get_secret.py rename to tests/secret_manager_tests/test_get_secret.py diff --git a/tests/secret_manager_tests/test_hashicorp.py b/tests/secret_manager_tests/test_hashicorp.py new file mode 100644 index 0000000000..3d3f96bd2f --- /dev/null +++ b/tests/secret_manager_tests/test_hashicorp.py @@ -0,0 +1,67 @@ +import os +import sys +import pytest +from dotenv import load_dotenv + +load_dotenv() +import os +import httpx + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path +from unittest.mock import patch, MagicMock +import logging +from litellm._logging import verbose_logger + +verbose_logger.setLevel(logging.DEBUG) + +from litellm.secret_managers.hashicorp_secret_manager import HashicorpSecretManager + +hashicorp_secret_manager = HashicorpSecretManager() + + +mock_vault_response = { + "request_id": "80fafb6a-e96a-4c5b-29fa-ff505ac72201", + "lease_id": "", + "renewable": False, + "lease_duration": 0, + "data": { + "data": {"key": "value-mock"}, + "metadata": { + "created_time": "2025-01-01T22:13:50.93942388Z", + "custom_metadata": None, + "deletion_time": "", + "destroyed": False, + "version": 1, + }, + }, + "wrap_info": None, + "warnings": None, + "auth": None, + "mount_type": "kv", +} + + +def test_hashicorp_secret_manager_get_secret(): + with patch("litellm.llms.custom_httpx.http_handler.HTTPHandler.get") as mock_get: + # Configure the mock response using MagicMock + mock_response = MagicMock() + mock_response.json.return_value = mock_vault_response + mock_response.raise_for_status.return_value = None + mock_get.return_value = mock_response + + # Test the secret manager + secret = hashicorp_secret_manager.read_secret("sample-secret-mock") + assert secret == "value-mock" + + # Verify the request was made with correct parameters + mock_get.assert_called_once() + called_url = mock_get.call_args[0][0] + assert "sample-secret-mock" in called_url + + assert ( + called_url + == "https://test-cluster-public-vault-0f98180c.e98296b2.z1.hashicorp.cloud:8200/v1/admin/secret/data/sample-secret-mock" + ) + assert "X-Vault-Token" in mock_get.call_args.kwargs["headers"] diff --git a/tests/local_testing/test_secret_manager.py b/tests/secret_manager_tests/test_secret_manager.py similarity index 85% rename from tests/local_testing/test_secret_manager.py rename to tests/secret_manager_tests/test_secret_manager.py index f4fb1b450a..6d5e3a1aff 100644 --- a/tests/local_testing/test_secret_manager.py +++ b/tests/secret_manager_tests/test_secret_manager.py @@ -5,6 +5,7 @@ import traceback import uuid from dotenv import load_dotenv +import json load_dotenv() import os @@ -25,6 +26,46 @@ from litellm.secret_managers.main import ( ) +def load_vertex_ai_credentials(): + # Define the path to the vertex_key.json file + print("loading vertex ai credentials") + filepath = os.path.dirname(os.path.abspath(__file__)) + vertex_key_path = filepath + "/vertex_key.json" + + # Read the existing content of the file or create an empty dictionary + try: + with open(vertex_key_path, "r") as file: + # Read the file content + print("Read vertexai file path") + content = file.read() + + # If the file is empty or not valid JSON, create an empty dictionary + if not content or not content.strip(): + service_account_key_data = {} + else: + # Attempt to load the existing JSON content + file.seek(0) + service_account_key_data = json.load(file) + except FileNotFoundError: + # If the file doesn't exist, create an empty dictionary + service_account_key_data = {} + + # Update the service_account_key_data with environment variables + private_key_id = os.environ.get("VERTEX_AI_PRIVATE_KEY_ID", "") + private_key = os.environ.get("VERTEX_AI_PRIVATE_KEY", "") + private_key = private_key.replace("\\n", "\n") + service_account_key_data["private_key_id"] = private_key_id + service_account_key_data["private_key"] = private_key + + # Create a temporary file + with tempfile.NamedTemporaryFile(mode="w+", delete=False) as temp_file: + # Write the updated content to the temporary files + json.dump(service_account_key_data, temp_file, indent=2) + + # Export the temporary file as GOOGLE_APPLICATION_CREDENTIALS + os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = os.path.abspath(temp_file.name) + + def test_aws_secret_manager(): import json @@ -204,7 +245,6 @@ def test_google_secret_manager(): Test that we can get a secret from Google Secret Manager """ os.environ["GOOGLE_SECRET_MANAGER_PROJECT_ID"] = "adroit-crow-413218" - from test_amazing_vertex_completion import load_vertex_ai_credentials from litellm.secret_managers.google_secret_manager import GoogleSecretManager @@ -227,8 +267,6 @@ def test_google_secret_manager_read_in_memory(): """ Test that Google Secret manager returs in memory value when it exists """ - from test_amazing_vertex_completion import load_vertex_ai_credentials - from litellm.secret_managers.google_secret_manager import GoogleSecretManager load_vertex_ai_credentials() diff --git a/tests/secret_manager_tests/vertex_key.json b/tests/secret_manager_tests/vertex_key.json new file mode 100644 index 0000000000..e2fd8512b1 --- /dev/null +++ b/tests/secret_manager_tests/vertex_key.json @@ -0,0 +1,13 @@ +{ + "type": "service_account", + "project_id": "adroit-crow-413218", + "private_key_id": "", + "private_key": "", + "client_email": "test-adroit-crow@adroit-crow-413218.iam.gserviceaccount.com", + "client_id": "104886546564708740969", + "auth_uri": "https://accounts.google.com/o/oauth2/auth", + "token_uri": "https://oauth2.googleapis.com/token", + "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", + "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/test-adroit-crow%40adroit-crow-413218.iam.gserviceaccount.com", + "universe_domain": "googleapis.com" +}