diff --git a/litellm/__init__.py b/litellm/__init__.py index 2c68fa9af3..5af751c3fe 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -834,7 +834,6 @@ from .utils import ( decode, _calculate_retry_after, _should_retry, - get_secret, get_supported_openai_params, get_api_base, get_first_chars_messages, diff --git a/litellm/batches/main.py b/litellm/batches/main.py index 9489e09cbe..e927a18b66 100644 --- a/litellm/batches/main.py +++ b/litellm/batches/main.py @@ -22,6 +22,7 @@ import litellm from litellm import client from litellm.llms.azure import AzureBatchesAPI from litellm.llms.openai import OpenAIBatchesAPI +from litellm.secret_managers.main import get_secret from litellm.types.llms.openai import ( Batch, CancelBatchRequest, @@ -34,7 +35,7 @@ from litellm.types.llms.openai import ( RetrieveBatchRequest, ) from litellm.types.router import GenericLiteLLMParams -from litellm.utils import get_secret, supports_httpx_timeout +from litellm.utils import supports_httpx_timeout ####### ENVIRONMENT VARIABLES ################### openai_batches_instance = OpenAIBatchesAPI() diff --git a/litellm/fine_tuning/main.py b/litellm/fine_tuning/main.py index abf2828578..81b075f517 100644 --- a/litellm/fine_tuning/main.py +++ b/litellm/fine_tuning/main.py @@ -17,7 +17,6 @@ from typing import Any, Coroutine, Dict, Literal, Optional, Union import httpx import litellm -from litellm import get_secret from litellm._logging import verbose_logger from litellm.llms.fine_tuning_apis.azure import AzureOpenAIFineTuningAPI from litellm.llms.fine_tuning_apis.openai import ( @@ -26,6 +25,7 @@ from litellm.llms.fine_tuning_apis.openai import ( OpenAIFineTuningAPI, ) from litellm.llms.fine_tuning_apis.vertex_ai import VertexFineTuningAPI +from litellm.secret_managers.main import get_secret from litellm.types.llms.openai import Hyperparameters from litellm.types.router import * from litellm.utils import supports_httpx_timeout diff --git a/litellm/llms/base_aws_llm.py b/litellm/llms/base_aws_llm.py index 8de42eda73..7449dc2d7e 100644 --- a/litellm/llms/base_aws_llm.py +++ b/litellm/llms/base_aws_llm.py @@ -5,7 +5,7 @@ import httpx from litellm._logging import verbose_logger from litellm.caching import DualCache, InMemoryCache -from litellm.utils import get_secret +from litellm.secret_managers.main import get_secret from .base import BaseLLM diff --git a/litellm/llms/bedrock/common_utils.py b/litellm/llms/bedrock/common_utils.py index 19a4f09860..f2032d110b 100644 --- a/litellm/llms/bedrock/common_utils.py +++ b/litellm/llms/bedrock/common_utils.py @@ -10,7 +10,7 @@ from typing import List, Optional, Union import httpx import litellm -from litellm import get_secret +from litellm.secret_managers.main import get_secret class BedrockError(Exception): diff --git a/litellm/llms/bedrock/embed/embedding.py b/litellm/llms/bedrock/embed/embedding.py index 6585ec4f2c..6398c2c341 100644 --- a/litellm/llms/bedrock/embed/embedding.py +++ b/litellm/llms/bedrock/embed/embedding.py @@ -11,7 +11,6 @@ from typing import Any, Callable, List, Literal, Optional, Tuple, Union import httpx import litellm -from litellm import get_secret from litellm.llms.cohere.embed import embedding as cohere_embedding from litellm.llms.custom_httpx.http_handler import ( AsyncHTTPHandler, @@ -19,6 +18,7 @@ from litellm.llms.custom_httpx.http_handler import ( _get_async_httpx_client, _get_httpx_client, ) +from litellm.secret_managers.main import get_secret from litellm.types.llms.bedrock import AmazonEmbeddingRequest, CohereEmbeddingRequest from litellm.types.utils import Embedding, EmbeddingResponse, Usage diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 39f65ac2df..e8ce3311f7 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -1063,6 +1063,7 @@ class KeyManagementSystem(enum.Enum): GOOGLE_KMS = "google_kms" AZURE_KEY_VAULT = "azure_key_vault" AWS_SECRET_MANAGER = "aws_secret_manager" + GOOGLE_SECRET_MANAGER = "google_secret_manager" LOCAL = "local" AWS_KMS = "aws_kms" diff --git a/litellm/proxy/auth/rds_iam_token.py b/litellm/proxy/auth/rds_iam_token.py index ec3a424b9f..f836215845 100644 --- a/litellm/proxy/auth/rds_iam_token.py +++ b/litellm/proxy/auth/rds_iam_token.py @@ -14,7 +14,7 @@ def init_rds_client( aws_web_identity_token: Optional[str] = None, timeout: Optional[Union[float, httpx.Timeout]] = None, ): - from litellm import get_secret + from litellm.secret_managers.main import get_secret # check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client litellm_aws_region_name = get_secret("AWS_REGION_NAME", None) diff --git a/litellm/proxy/db/dynamo_db.py b/litellm/proxy/db/dynamo_db.py index 6056a61e25..09e14b4605 100644 --- a/litellm/proxy/db/dynamo_db.py +++ b/litellm/proxy/db/dynamo_db.py @@ -1,16 +1,17 @@ -from litellm.proxy.db.base_client import CustomDB -from litellm.proxy._types import ( - DynamoDBArgs, - LiteLLM_VerificationToken, - LiteLLM_Config, - LiteLLM_UserTable, -) -from litellm.proxy.utils import hash_token -from litellm import get_secret -from typing import Any, List, Literal, Optional, Union import json from datetime import datetime +from typing import Any, List, Literal, Optional, Union + from litellm._logging import verbose_proxy_logger +from litellm.proxy._types import ( + DynamoDBArgs, + LiteLLM_Config, + LiteLLM_UserTable, + LiteLLM_VerificationToken, +) +from litellm.proxy.db.base_client import CustomDB +from litellm.proxy.utils import hash_token +from litellm.secret_managers.main import get_secret class DynamoDBWrapper(CustomDB): @@ -21,19 +22,19 @@ class DynamoDBWrapper(CustomDB): def __init__(self, database_arguments: DynamoDBArgs): from aiodynamo.client import Client from aiodynamo.credentials import Credentials, StaticCredentials + from aiodynamo.expressions import F, UpdateExpression, Value + from aiodynamo.http.aiohttp import AIOHTTP from aiodynamo.http.httpx import HTTPX from aiodynamo.models import ( - Throughput, KeySchema, KeySpec, KeyType, PayPerRequest, + ReturnValues, + Throughput, ) - from yarl import URL - from aiodynamo.expressions import UpdateExpression, F, Value - from aiodynamo.models import ReturnValues - from aiodynamo.http.aiohttp import AIOHTTP from aiohttp import ClientSession + from yarl import URL self.throughput_type = None if database_arguments.billing_mode == "PAY_PER_REQUEST": @@ -59,7 +60,9 @@ class DynamoDBWrapper(CustomDB): verbose_proxy_logger.debug( f"DynamoDB: setting env vars based on arn={self.database_arguments.aws_role_name}" ) - import boto3, os + import os + + import boto3 sts_client = boto3.client("sts") @@ -92,22 +95,22 @@ class DynamoDBWrapper(CustomDB): """ Connect to DB, and creating / updating any tables """ + import aiohttp from aiodynamo.client import Client from aiodynamo.credentials import Credentials, StaticCredentials + from aiodynamo.expressions import F, UpdateExpression, Value + from aiodynamo.http.aiohttp import AIOHTTP from aiodynamo.http.httpx import HTTPX from aiodynamo.models import ( - Throughput, KeySchema, KeySpec, KeyType, PayPerRequest, + ReturnValues, + Throughput, ) - from yarl import URL - from aiodynamo.expressions import UpdateExpression, F, Value - from aiodynamo.models import ReturnValues - from aiodynamo.http.aiohttp import AIOHTTP from aiohttp import ClientSession - import aiohttp + from yarl import URL verbose_proxy_logger.debug("DynamoDB Wrapper - Attempting to connect") self.set_env_vars_based_on_arn() @@ -192,22 +195,22 @@ class DynamoDBWrapper(CustomDB): async def insert_data( self, value: Any, table_name: Literal["user", "key", "config", "spend"] ): + import aiohttp from aiodynamo.client import Client from aiodynamo.credentials import Credentials, StaticCredentials + from aiodynamo.expressions import F, UpdateExpression, Value + from aiodynamo.http.aiohttp import AIOHTTP from aiodynamo.http.httpx import HTTPX from aiodynamo.models import ( - Throughput, KeySchema, KeySpec, KeyType, PayPerRequest, + ReturnValues, + Throughput, ) - from yarl import URL - from aiodynamo.expressions import UpdateExpression, F, Value - from aiodynamo.models import ReturnValues - from aiodynamo.http.aiohttp import AIOHTTP from aiohttp import ClientSession - import aiohttp + from yarl import URL self.set_env_vars_based_on_arn() @@ -237,22 +240,22 @@ class DynamoDBWrapper(CustomDB): return await table.put_item(item=value, return_values=ReturnValues.all_old) async def get_data(self, key: str, table_name: Literal["user", "key", "config"]): + import aiohttp from aiodynamo.client import Client from aiodynamo.credentials import Credentials, StaticCredentials + from aiodynamo.expressions import F, UpdateExpression, Value + from aiodynamo.http.aiohttp import AIOHTTP from aiodynamo.http.httpx import HTTPX from aiodynamo.models import ( - Throughput, KeySchema, KeySpec, KeyType, PayPerRequest, + ReturnValues, + Throughput, ) - from yarl import URL - from aiodynamo.expressions import UpdateExpression, F, Value - from aiodynamo.models import ReturnValues - from aiodynamo.http.aiohttp import AIOHTTP from aiohttp import ClientSession - import aiohttp + from yarl import URL self.set_env_vars_based_on_arn() @@ -311,22 +314,22 @@ class DynamoDBWrapper(CustomDB): self, key: str, value: dict, table_name: Literal["user", "key", "config"] ): self.set_env_vars_based_on_arn() + import aiohttp from aiodynamo.client import Client from aiodynamo.credentials import Credentials, StaticCredentials + from aiodynamo.expressions import F, UpdateExpression, Value + from aiodynamo.http.aiohttp import AIOHTTP from aiodynamo.http.httpx import HTTPX from aiodynamo.models import ( - Throughput, KeySchema, KeySpec, KeyType, PayPerRequest, + ReturnValues, + Throughput, ) - from yarl import URL - from aiodynamo.expressions import UpdateExpression, F, Value - from aiodynamo.models import ReturnValues - from aiodynamo.http.aiohttp import AIOHTTP from aiohttp import ClientSession - import aiohttp + from yarl import URL if self.database_arguments.ssl_verify == False: client_session = ClientSession(connector=aiohttp.TCPConnector(ssl=False)) diff --git a/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py b/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py index 01433b5559..eee26bd428 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py +++ b/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py @@ -24,7 +24,6 @@ import httpx from fastapi import HTTPException import litellm -from litellm import get_secret from litellm._logging import verbose_proxy_logger from litellm.caching import DualCache from litellm.integrations.custom_guardrail import CustomGuardrail @@ -38,6 +37,7 @@ from litellm.llms.custom_httpx.http_handler import ( ) from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata +from litellm.secret_managers.main import get_secret from litellm.types.guardrails import ( BedrockContentItem, BedrockRequest, diff --git a/litellm/proxy/guardrails/guardrail_hooks/lakera_ai.py b/litellm/proxy/guardrails/guardrail_hooks/lakera_ai.py index 364bcb2227..5aebbcf3e7 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/lakera_ai.py +++ b/litellm/proxy/guardrails/guardrail_hooks/lakera_ai.py @@ -19,12 +19,12 @@ import httpx from fastapi import HTTPException import litellm -from litellm import get_secret from litellm._logging import verbose_proxy_logger from litellm.integrations.custom_guardrail import CustomGuardrail from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata +from litellm.secret_managers.main import get_secret from litellm.types.guardrails import ( GuardrailItem, LakeraCategoryThresholds, diff --git a/litellm/proxy/proxy_cli.py b/litellm/proxy/proxy_cli.py index cf2638f3c8..70f8e0e4af 100644 --- a/litellm/proxy/proxy_cli.py +++ b/litellm/proxy/proxy_cli.py @@ -544,6 +544,15 @@ def run_server( load_aws_secret_manager(use_aws_secret_manager=True) elif key_management_system == KeyManagementSystem.AWS_KMS.value: load_aws_kms(use_aws_kms=True) + elif ( + key_management_system + == KeyManagementSystem.GOOGLE_SECRET_MANAGER.value + ): + from litellm.proxy.secret_managers.google_secret_manager import ( + GoogleSecretManager, + ) + + GoogleSecretManager() else: raise ValueError("Invalid Key Management System selected") key_management_settings = general_settings.get( @@ -598,7 +607,7 @@ def run_server( or os.getenv("DIRECT_URL", None) is not None ): try: - from litellm import get_secret + from litellm.secret_managers.main import get_secret if os.getenv("DATABASE_URL", None) is not None: ### add connection pool + pool timeout args diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml index c32d7d755f..5a8a4cad41 100644 --- a/litellm/proxy/proxy_config.yaml +++ b/litellm/proxy/proxy_config.yaml @@ -1,15 +1,20 @@ model_list: - - model_name: fake-openai-endpoint - litellm_params: - model: openai/fake - api_key: fake-key - model_info: - id: "team-a-model" # used for identifying model in response headers + - model_name: fake-openai-endpoint + litellm_params: + model: openai/fake + api_base: https://exampleopenaiendpoint-production.up.railway.app/ + api_key: os.environ/OPENAI_API_KEY + - model_name: gpt-3.5-turbo-end-user-test + litellm_params: + model: azure/chatgpt-v-2 + api_base: https://openai-gpt-4-test-v-1.openai.azure.com/ + api_version: "2023-05-15" + api_key: os.environ/AZURE_API_KEY litellm_settings: success_callback: ["prometheus"] failure_callback: ["prometheus"] -general_settings: - master_key: sk-1234 - custom_auth: example_config_yaml.custom_auth_basic.user_api_key_auth \ No newline at end of file +general_settings: + master_key: sk-1234 + key_management_system: "google_secret_manager" \ No newline at end of file diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index dd6869c661..d4e6e1b736 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -1765,6 +1765,15 @@ class ProxyConfig: load_aws_secret_manager(use_aws_secret_manager=True) elif key_management_system == KeyManagementSystem.AWS_KMS.value: load_aws_kms(use_aws_kms=True) + elif ( + key_management_system + == KeyManagementSystem.GOOGLE_SECRET_MANAGER.value + ): + from litellm.proxy.secret_managers.google_secret_manager import ( + GoogleSecretManager, + ) + + GoogleSecretManager() else: raise ValueError("Invalid Key Management System selected") key_management_settings = general_settings.get( diff --git a/litellm/proxy/secret_managers/google_secret_manager.py b/litellm/proxy/secret_managers/google_secret_manager.py new file mode 100644 index 0000000000..f30505c8f8 --- /dev/null +++ b/litellm/proxy/secret_managers/google_secret_manager.py @@ -0,0 +1,89 @@ +import base64 +import os +from typing import Optional + +import litellm +from litellm._logging import verbose_logger +from litellm.caching import InMemoryCache +from litellm.integrations.gcs_bucket_base import GCSBucketBase +from litellm.llms.custom_httpx.http_handler import _get_httpx_client +from litellm.proxy._types import KeyManagementSystem + + +class GoogleSecretManager(GCSBucketBase): + def __init__( + self, + refresh_interval: Optional[int] = 86400, + always_read_secret_manager: Optional[bool] = False, + ) -> None: + """ + Args: + refresh_interval (int, optional): The refresh interval in seconds. Defaults to 86400. (24 hours) + always_read_secret_manager (bool, optional): Whether to always read from the secret manager. Defaults to False. Since we do want to cache values + """ + super().__init__() + self.PROJECT_ID = "adroit-crow-413218" + self.sync_httpx_client = _get_httpx_client() + litellm.secret_manager_client = self + litellm._key_management_system = KeyManagementSystem.GOOGLE_SECRET_MANAGER + self.cache = InMemoryCache( + default_ttl=refresh_interval + ) # store in memory for 1 day + self.always_read_secret_manager = False + + def get_secret_from_google_secret_manager(self, secret_name: str) -> Optional[str]: + """ + Retrieve a secret from Google Secret Manager or cache. + + Args: + secret_name (str): The name of the secret. + + Returns: + str: The secret value if successful, None otherwise. + """ + if self.always_read_secret_manager is not True: + cached_secret = self.cache.get_cache(secret_name) + if cached_secret is not None: + return cached_secret + if secret_name in self.cache.cache_dict: + return cached_secret + + _secret_name = ( + f"projects/{self.PROJECT_ID}/secrets/{secret_name}/versions/latest" + ) + headers = self.sync_construct_request_headers() + url = f"https://secretmanager.googleapis.com/v1/{_secret_name}:access" + + # Send the GET request to retrieve the secret + response = self.sync_httpx_client.get(url=url, headers=headers) + + if response.status_code != 200: + verbose_logger.error( + "Google Secret Manager retrieval error: %s", str(response.text) + ) + self.cache.set_cache( + secret_name, None + ) # Cache that the secret was not found + raise ValueError( + f"secret {secret_name} not found in Google Secret Manager. Error: {response.text}" + ) + + verbose_logger.debug( + "Google Secret Manager retrieval response status code: %s", + response.status_code, + ) + + # Parse the JSON response and return the secret value + secret_data = response.json() + _base64_encoded_value = secret_data.get("payload", {}).get("data") + + # decode the base64 encoded value + if _base64_encoded_value is not None: + _decoded_value = base64.b64decode(_base64_encoded_value).decode("utf-8") + self.cache.set_cache( + secret_name, _decoded_value + ) # Cache the retrieved secret + return _decoded_value + + self.cache.set_cache(secret_name, None) # Cache that the secret was not found + raise ValueError(f"secret {secret_name} not found in Google Secret Manager") diff --git a/litellm/rerank_api/main.py b/litellm/rerank_api/main.py index 2e69f28180..41de82ab66 100644 --- a/litellm/rerank_api/main.py +++ b/litellm/rerank_api/main.py @@ -4,10 +4,10 @@ from functools import partial from typing import Any, Coroutine, Dict, List, Literal, Optional, Union import litellm -from litellm import get_secret from litellm._logging import verbose_logger from litellm.llms.cohere.rerank import CohereRerank from litellm.llms.togetherai.rerank import TogetherAIRerank +from litellm.secret_managers.main import get_secret from litellm.types.router import * from litellm.utils import supports_httpx_timeout diff --git a/litellm/secret_managers/Readme.md b/litellm/secret_managers/Readme.md new file mode 100644 index 0000000000..9b22689059 --- /dev/null +++ b/litellm/secret_managers/Readme.md @@ -0,0 +1,3 @@ +## Supported Secret Managers to read credentials from + +Example read OPENAI_API_KEY, AZURE_API_KEY from a secret manager \ No newline at end of file diff --git a/litellm/secret_managers/main.py b/litellm/secret_managers/main.py new file mode 100644 index 0000000000..e136654c1c --- /dev/null +++ b/litellm/secret_managers/main.py @@ -0,0 +1,276 @@ +import ast +import base64 +import binascii +import json +import os +import sys +import traceback +from typing import Any, Optional, Union + +import httpx +from dotenv import load_dotenv + +import litellm +from litellm._logging import print_verbose, verbose_logger +from litellm.caching import DualCache +from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler +from litellm.proxy._types import KeyManagementSystem + +oidc_cache = DualCache() + + +######### Secret Manager ############################ +# checks if user has passed in a secret manager client +# if passed in then checks the secret there +def _is_base64(s): + try: + return base64.b64encode(base64.b64decode(s)).decode() == s + except binascii.Error: + return False + + +def get_secret( + secret_name: str, + default_value: Optional[Union[str, bool]] = None, +): + key_management_system = litellm._key_management_system + key_management_settings = litellm._key_management_settings + + if secret_name.startswith("os.environ/"): + secret_name = secret_name.replace("os.environ/", "") + + # Example: oidc/google/https://bedrock-runtime.us-east-1.amazonaws.com/model/stability.stable-diffusion-xl-v1/invoke + if secret_name.startswith("oidc/"): + secret_name_split = secret_name.replace("oidc/", "") + oidc_provider, oidc_aud = secret_name_split.split("/", 1) + # TODO: Add caching for HTTP requests + if oidc_provider == "google": + oidc_token = oidc_cache.get_cache(key=secret_name) + if oidc_token is not None: + return oidc_token + + oidc_client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0)) + # https://cloud.google.com/compute/docs/instances/verifying-instance-identity#request_signature + response = oidc_client.get( + "http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/identity", + params={"audience": oidc_aud}, + headers={"Metadata-Flavor": "Google"}, + ) + if response.status_code == 200: + oidc_token = response.text + oidc_cache.set_cache(key=secret_name, value=oidc_token, ttl=3600 - 60) + return oidc_token + else: + raise ValueError("Google OIDC provider failed") + elif oidc_provider == "circleci": + # https://circleci.com/docs/openid-connect-tokens/ + env_secret = os.getenv("CIRCLE_OIDC_TOKEN") + if env_secret is None: + raise ValueError("CIRCLE_OIDC_TOKEN not found in environment") + return env_secret + elif oidc_provider == "circleci_v2": + # https://circleci.com/docs/openid-connect-tokens/ + env_secret = os.getenv("CIRCLE_OIDC_TOKEN_V2") + if env_secret is None: + raise ValueError("CIRCLE_OIDC_TOKEN_V2 not found in environment") + return env_secret + elif oidc_provider == "github": + # https://docs.github.com/en/actions/deployment/security-hardening-your-deployments/configuring-openid-connect-in-cloud-providers#using-custom-actions + actions_id_token_request_url = os.getenv("ACTIONS_ID_TOKEN_REQUEST_URL") + actions_id_token_request_token = os.getenv("ACTIONS_ID_TOKEN_REQUEST_TOKEN") + if ( + actions_id_token_request_url is None + or actions_id_token_request_token is None + ): + raise ValueError( + "ACTIONS_ID_TOKEN_REQUEST_URL or ACTIONS_ID_TOKEN_REQUEST_TOKEN not found in environment" + ) + + oidc_token = oidc_cache.get_cache(key=secret_name) + if oidc_token is not None: + return oidc_token + + oidc_client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0)) + response = oidc_client.get( + actions_id_token_request_url, + params={"audience": oidc_aud}, + headers={ + "Authorization": f"Bearer {actions_id_token_request_token}", + "Accept": "application/json; api-version=2.0", + }, + ) + if response.status_code == 200: + oidc_token = response.text["value"] + oidc_cache.set_cache(key=secret_name, value=oidc_token, ttl=300 - 5) + return oidc_token + else: + raise ValueError("Github OIDC provider failed") + elif oidc_provider == "azure": + # https://azure.github.io/azure-workload-identity/docs/quick-start.html + azure_federated_token_file = os.getenv("AZURE_FEDERATED_TOKEN_FILE") + if azure_federated_token_file is None: + raise ValueError("AZURE_FEDERATED_TOKEN_FILE not found in environment") + with open(azure_federated_token_file, "r") as f: + oidc_token = f.read() + return oidc_token + elif oidc_provider == "file": + # Load token from a file + with open(oidc_aud, "r") as f: + oidc_token = f.read() + return oidc_token + elif oidc_provider == "env": + # Load token directly from an environment variable + oidc_token = os.getenv(oidc_aud) + if oidc_token is None: + raise ValueError(f"Environment variable {oidc_aud} not found") + return oidc_token + elif oidc_provider == "env_path": + # Load token from a file path specified in an environment variable + token_file_path = os.getenv(oidc_aud) + if token_file_path is None: + raise ValueError(f"Environment variable {oidc_aud} not found") + with open(token_file_path, "r") as f: + oidc_token = f.read() + return oidc_token + else: + raise ValueError("Unsupported OIDC provider") + + try: + if litellm.secret_manager_client is not None: + try: + client = litellm.secret_manager_client + key_manager = "local" + if key_management_system is not None: + key_manager = key_management_system.value + + if key_management_settings is not None: + if ( + secret_name not in key_management_settings.hosted_keys + ): # allow user to specify which keys to check in hosted key manager + key_manager = "local" + + if ( + key_manager == KeyManagementSystem.AZURE_KEY_VAULT.value + or type(client).__module__ + "." + type(client).__name__ + == "azure.keyvault.secrets._client.SecretClient" + ): # support Azure Secret Client - from azure.keyvault.secrets import SecretClient + secret = client.get_secret(secret_name).value + elif ( + key_manager == KeyManagementSystem.GOOGLE_KMS.value + or client.__class__.__name__ == "KeyManagementServiceClient" + ): + encrypted_secret: Any = os.getenv(secret_name) + if encrypted_secret is None: + raise ValueError( + f"Google KMS requires the encrypted secret to be in the environment!" + ) + b64_flag = _is_base64(encrypted_secret) + if b64_flag == True: # if passed in as encoded b64 string + encrypted_secret = base64.b64decode(encrypted_secret) + ciphertext = encrypted_secret + else: + raise ValueError( + f"Google KMS requires the encrypted secret to be encoded in base64" + ) # fix for this vulnerability https://huntr.com/bounties/ae623c2f-b64b-4245-9ed4-f13a0a5824ce + response = client.decrypt( + request={ + "name": litellm._google_kms_resource_name, + "ciphertext": ciphertext, + } + ) + secret = response.plaintext.decode( + "utf-8" + ) # assumes the original value was encoded with utf-8 + elif key_manager == KeyManagementSystem.AWS_KMS.value: + """ + Only check the tokens which start with 'aws_kms/'. This prevents latency impact caused by checking all keys. + """ + encrypted_value = os.getenv(secret_name, None) + if encrypted_value is None: + raise Exception( + "AWS KMS - Encrypted Value of Key={} is None".format( + secret_name + ) + ) + # Decode the base64 encoded ciphertext + ciphertext_blob = base64.b64decode(encrypted_value) + + # Set up the parameters for the decrypt call + params = {"CiphertextBlob": ciphertext_blob} + # Perform the decryption + response = client.decrypt(**params) + + # Extract and decode the plaintext + plaintext = response["Plaintext"] + secret = plaintext.decode("utf-8") + if isinstance(secret, str): + secret = secret.strip() + elif key_manager == KeyManagementSystem.AWS_SECRET_MANAGER.value: + try: + get_secret_value_response = client.get_secret_value( + SecretId=secret_name + ) + print_verbose( + f"get_secret_value_response: {get_secret_value_response}" + ) + except Exception as e: + print_verbose(f"An error occurred - {str(e)}") + # For a list of exceptions thrown, see + # https://docs.aws.amazon.com/secretsmanager/latest/apireference/API_GetSecretValue.html + raise e + + # assume there is 1 secret per secret_name + secret_dict = json.loads(get_secret_value_response["SecretString"]) + print_verbose(f"secret_dict: {secret_dict}") + for k, v in secret_dict.items(): + secret = v + print_verbose(f"secret: {secret}") + if key_manager == KeyManagementSystem.GOOGLE_SECRET_MANAGER.value: + try: + secret = client.get_secret_from_google_secret_manager( + secret_name + ) + print_verbose(f"secret from google secret manager: {secret}") + if secret is None: + raise ValueError( + f"No secret found in Google Secret Manager for {secret_name}" + ) + except Exception as e: + print_verbose(f"An error occurred - {str(e)}") + raise e + elif key_manager == "local": + secret = os.getenv(secret_name) + else: # assume the default is infisicial client + secret = client.get_secret(secret_name).secret_value + except Exception as e: # check if it's in os.environ + verbose_logger.error( + f"Defaulting to os.environ value for key={secret_name}. An exception occurred - {str(e)}.\n\n{traceback.format_exc()}" + ) + secret = os.getenv(secret_name) + try: + secret_value_as_bool = ast.literal_eval(secret) + if isinstance(secret_value_as_bool, bool): + return secret_value_as_bool + else: + return secret + except: + return secret + else: + secret = os.environ.get(secret_name) + try: + secret_value_as_bool = ( + ast.literal_eval(secret) if secret is not None else None + ) + if isinstance(secret_value_as_bool, bool): + return secret_value_as_bool + else: + return secret + except Exception: + if default_value is not None: + return default_value + return secret + except Exception as e: + if default_value is not None: + return default_value + else: + raise e diff --git a/litellm/tests/test_secret_manager.py b/litellm/tests/test_secret_manager.py index 1cf374148b..cd9b8a1466 100644 --- a/litellm/tests/test_secret_manager.py +++ b/litellm/tests/test_secret_manager.py @@ -16,10 +16,10 @@ sys.path.insert( ) # Adds the parent directory to the system path import pytest -from litellm import get_secret from litellm.llms.azure import get_azure_ad_token_from_oidc from litellm.llms.bedrock.chat import BedrockConverseLLM, BedrockLLM from litellm.proxy.secret_managers.aws_secret_manager import load_aws_secret_manager +from litellm.secret_managers.main import get_secret @pytest.mark.skip(reason="AWS Suspended Account") diff --git a/litellm/utils.py b/litellm/utils.py index 26bf993aad..d5b9fde2dd 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -68,6 +68,7 @@ from litellm.litellm_core_utils.redact_messages import ( ) from litellm.litellm_core_utils.token_counter import get_modified_max_tokens from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler +from litellm.secret_managers.main import get_secret from litellm.types.llms.openai import ( AllMessageValues, ChatCompletionNamedToolChoiceParam, @@ -93,8 +94,6 @@ from litellm.types.utils import ( Usage, ) -oidc_cache = DualCache() - try: # New and recommended way to access resources from importlib import resources @@ -8662,250 +8661,6 @@ def exception_type( raise raised_exc -######### Secret Manager ############################ -# checks if user has passed in a secret manager client -# if passed in then checks the secret there -def _is_base64(s): - try: - return base64.b64encode(base64.b64decode(s)).decode() == s - except binascii.Error: - return False - - -def get_secret( - secret_name: str, - default_value: Optional[Union[str, bool]] = None, -): - key_management_system = litellm._key_management_system - key_management_settings = litellm._key_management_settings - - if secret_name.startswith("os.environ/"): - secret_name = secret_name.replace("os.environ/", "") - - # Example: oidc/google/https://bedrock-runtime.us-east-1.amazonaws.com/model/stability.stable-diffusion-xl-v1/invoke - if secret_name.startswith("oidc/"): - secret_name_split = secret_name.replace("oidc/", "") - oidc_provider, oidc_aud = secret_name_split.split("/", 1) - # TODO: Add caching for HTTP requests - if oidc_provider == "google": - oidc_token = oidc_cache.get_cache(key=secret_name) - if oidc_token is not None: - return oidc_token - - oidc_client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0)) - # https://cloud.google.com/compute/docs/instances/verifying-instance-identity#request_signature - response = oidc_client.get( - "http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/identity", - params={"audience": oidc_aud}, - headers={"Metadata-Flavor": "Google"}, - ) - if response.status_code == 200: - oidc_token = response.text - oidc_cache.set_cache(key=secret_name, value=oidc_token, ttl=3600 - 60) - return oidc_token - else: - raise ValueError("Google OIDC provider failed") - elif oidc_provider == "circleci": - # https://circleci.com/docs/openid-connect-tokens/ - env_secret = os.getenv("CIRCLE_OIDC_TOKEN") - if env_secret is None: - raise ValueError("CIRCLE_OIDC_TOKEN not found in environment") - return env_secret - elif oidc_provider == "circleci_v2": - # https://circleci.com/docs/openid-connect-tokens/ - env_secret = os.getenv("CIRCLE_OIDC_TOKEN_V2") - if env_secret is None: - raise ValueError("CIRCLE_OIDC_TOKEN_V2 not found in environment") - return env_secret - elif oidc_provider == "github": - # https://docs.github.com/en/actions/deployment/security-hardening-your-deployments/configuring-openid-connect-in-cloud-providers#using-custom-actions - actions_id_token_request_url = os.getenv("ACTIONS_ID_TOKEN_REQUEST_URL") - actions_id_token_request_token = os.getenv("ACTIONS_ID_TOKEN_REQUEST_TOKEN") - if ( - actions_id_token_request_url is None - or actions_id_token_request_token is None - ): - raise ValueError( - "ACTIONS_ID_TOKEN_REQUEST_URL or ACTIONS_ID_TOKEN_REQUEST_TOKEN not found in environment" - ) - - oidc_token = oidc_cache.get_cache(key=secret_name) - if oidc_token is not None: - return oidc_token - - oidc_client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0)) - response = oidc_client.get( - actions_id_token_request_url, - params={"audience": oidc_aud}, - headers={ - "Authorization": f"Bearer {actions_id_token_request_token}", - "Accept": "application/json; api-version=2.0", - }, - ) - if response.status_code == 200: - oidc_token = response.text["value"] - oidc_cache.set_cache(key=secret_name, value=oidc_token, ttl=300 - 5) - return oidc_token - else: - raise ValueError("Github OIDC provider failed") - elif oidc_provider == "azure": - # https://azure.github.io/azure-workload-identity/docs/quick-start.html - azure_federated_token_file = os.getenv("AZURE_FEDERATED_TOKEN_FILE") - if azure_federated_token_file is None: - raise ValueError("AZURE_FEDERATED_TOKEN_FILE not found in environment") - with open(azure_federated_token_file, "r") as f: - oidc_token = f.read() - return oidc_token - elif oidc_provider == "file": - # Load token from a file - with open(oidc_aud, "r") as f: - oidc_token = f.read() - return oidc_token - elif oidc_provider == "env": - # Load token directly from an environment variable - oidc_token = os.getenv(oidc_aud) - if oidc_token is None: - raise ValueError(f"Environment variable {oidc_aud} not found") - return oidc_token - elif oidc_provider == "env_path": - # Load token from a file path specified in an environment variable - token_file_path = os.getenv(oidc_aud) - if token_file_path is None: - raise ValueError(f"Environment variable {oidc_aud} not found") - with open(token_file_path, "r") as f: - oidc_token = f.read() - return oidc_token - else: - raise ValueError("Unsupported OIDC provider") - - try: - if litellm.secret_manager_client is not None: - try: - client = litellm.secret_manager_client - key_manager = "local" - if key_management_system is not None: - key_manager = key_management_system.value - - if key_management_settings is not None: - if ( - secret_name not in key_management_settings.hosted_keys - ): # allow user to specify which keys to check in hosted key manager - key_manager = "local" - - if ( - key_manager == KeyManagementSystem.AZURE_KEY_VAULT.value - or type(client).__module__ + "." + type(client).__name__ - == "azure.keyvault.secrets._client.SecretClient" - ): # support Azure Secret Client - from azure.keyvault.secrets import SecretClient - secret = client.get_secret(secret_name).value - elif ( - key_manager == KeyManagementSystem.GOOGLE_KMS.value - or client.__class__.__name__ == "KeyManagementServiceClient" - ): - encrypted_secret: Any = os.getenv(secret_name) - if encrypted_secret is None: - raise ValueError( - f"Google KMS requires the encrypted secret to be in the environment!" - ) - b64_flag = _is_base64(encrypted_secret) - if b64_flag == True: # if passed in as encoded b64 string - encrypted_secret = base64.b64decode(encrypted_secret) - ciphertext = encrypted_secret - else: - raise ValueError( - f"Google KMS requires the encrypted secret to be encoded in base64" - ) # fix for this vulnerability https://huntr.com/bounties/ae623c2f-b64b-4245-9ed4-f13a0a5824ce - response = client.decrypt( - request={ - "name": litellm._google_kms_resource_name, - "ciphertext": ciphertext, - } - ) - secret = response.plaintext.decode( - "utf-8" - ) # assumes the original value was encoded with utf-8 - elif key_manager == KeyManagementSystem.AWS_KMS.value: - """ - Only check the tokens which start with 'aws_kms/'. This prevents latency impact caused by checking all keys. - """ - encrypted_value = os.getenv(secret_name, None) - if encrypted_value is None: - raise Exception( - "AWS KMS - Encrypted Value of Key={} is None".format( - secret_name - ) - ) - # Decode the base64 encoded ciphertext - ciphertext_blob = base64.b64decode(encrypted_value) - - # Set up the parameters for the decrypt call - params = {"CiphertextBlob": ciphertext_blob} - # Perform the decryption - response = client.decrypt(**params) - - # Extract and decode the plaintext - plaintext = response["Plaintext"] - secret = plaintext.decode("utf-8") - if isinstance(secret, str): - secret = secret.strip() - elif key_manager == KeyManagementSystem.AWS_SECRET_MANAGER.value: - try: - get_secret_value_response = client.get_secret_value( - SecretId=secret_name - ) - print_verbose( - f"get_secret_value_response: {get_secret_value_response}" - ) - except Exception as e: - print_verbose(f"An error occurred - {str(e)}") - # For a list of exceptions thrown, see - # https://docs.aws.amazon.com/secretsmanager/latest/apireference/API_GetSecretValue.html - raise e - - # assume there is 1 secret per secret_name - secret_dict = json.loads(get_secret_value_response["SecretString"]) - print_verbose(f"secret_dict: {secret_dict}") - for k, v in secret_dict.items(): - secret = v - print_verbose(f"secret: {secret}") - elif key_manager == "local": - secret = os.getenv(secret_name) - else: # assume the default is infisicial client - secret = client.get_secret(secret_name).secret_value - except Exception as e: # check if it's in os.environ - verbose_logger.error( - f"Defaulting to os.environ value for key={secret_name}. An exception occurred - {str(e)}.\n\n{traceback.format_exc()}" - ) - secret = os.getenv(secret_name) - try: - secret_value_as_bool = ast.literal_eval(secret) - if isinstance(secret_value_as_bool, bool): - return secret_value_as_bool - else: - return secret - except: - return secret - else: - secret = os.environ.get(secret_name) - try: - secret_value_as_bool = ( - ast.literal_eval(secret) if secret is not None else None - ) - if isinstance(secret_value_as_bool, bool): - return secret_value_as_bool - else: - return secret - except Exception: - if default_value is not None: - return default_value - return secret - except Exception as e: - if default_value is not None: - return default_value - else: - raise e - - ######## Streaming Class ############################ # wraps the completion stream to return the correct format for the model # replicate/anthropic/cohere