refactor get_secret

This commit is contained in:
Ishaan Jaff 2024-09-03 10:42:12 -07:00
parent 1546a82f18
commit b0178a85cf
20 changed files with 457 additions and 307 deletions

View file

@ -834,7 +834,6 @@ from .utils import (
decode,
_calculate_retry_after,
_should_retry,
get_secret,
get_supported_openai_params,
get_api_base,
get_first_chars_messages,

View file

@ -22,6 +22,7 @@ import litellm
from litellm import client
from litellm.llms.azure import AzureBatchesAPI
from litellm.llms.openai import OpenAIBatchesAPI
from litellm.secret_managers.main import get_secret
from litellm.types.llms.openai import (
Batch,
CancelBatchRequest,
@ -34,7 +35,7 @@ from litellm.types.llms.openai import (
RetrieveBatchRequest,
)
from litellm.types.router import GenericLiteLLMParams
from litellm.utils import get_secret, supports_httpx_timeout
from litellm.utils import supports_httpx_timeout
####### ENVIRONMENT VARIABLES ###################
openai_batches_instance = OpenAIBatchesAPI()

View file

@ -17,7 +17,6 @@ from typing import Any, Coroutine, Dict, Literal, Optional, Union
import httpx
import litellm
from litellm import get_secret
from litellm._logging import verbose_logger
from litellm.llms.fine_tuning_apis.azure import AzureOpenAIFineTuningAPI
from litellm.llms.fine_tuning_apis.openai import (
@ -26,6 +25,7 @@ from litellm.llms.fine_tuning_apis.openai import (
OpenAIFineTuningAPI,
)
from litellm.llms.fine_tuning_apis.vertex_ai import VertexFineTuningAPI
from litellm.secret_managers.main import get_secret
from litellm.types.llms.openai import Hyperparameters
from litellm.types.router import *
from litellm.utils import supports_httpx_timeout

View file

@ -5,7 +5,7 @@ import httpx
from litellm._logging import verbose_logger
from litellm.caching import DualCache, InMemoryCache
from litellm.utils import get_secret
from litellm.secret_managers.main import get_secret
from .base import BaseLLM

View file

@ -10,7 +10,7 @@ from typing import List, Optional, Union
import httpx
import litellm
from litellm import get_secret
from litellm.secret_managers.main import get_secret
class BedrockError(Exception):

View file

@ -11,7 +11,6 @@ from typing import Any, Callable, List, Literal, Optional, Tuple, Union
import httpx
import litellm
from litellm import get_secret
from litellm.llms.cohere.embed import embedding as cohere_embedding
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
@ -19,6 +18,7 @@ from litellm.llms.custom_httpx.http_handler import (
_get_async_httpx_client,
_get_httpx_client,
)
from litellm.secret_managers.main import get_secret
from litellm.types.llms.bedrock import AmazonEmbeddingRequest, CohereEmbeddingRequest
from litellm.types.utils import Embedding, EmbeddingResponse, Usage

View file

@ -1063,6 +1063,7 @@ class KeyManagementSystem(enum.Enum):
GOOGLE_KMS = "google_kms"
AZURE_KEY_VAULT = "azure_key_vault"
AWS_SECRET_MANAGER = "aws_secret_manager"
GOOGLE_SECRET_MANAGER = "google_secret_manager"
LOCAL = "local"
AWS_KMS = "aws_kms"

View file

@ -14,7 +14,7 @@ def init_rds_client(
aws_web_identity_token: Optional[str] = None,
timeout: Optional[Union[float, httpx.Timeout]] = None,
):
from litellm import get_secret
from litellm.secret_managers.main import get_secret
# check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)

View file

@ -1,16 +1,17 @@
from litellm.proxy.db.base_client import CustomDB
from litellm.proxy._types import (
DynamoDBArgs,
LiteLLM_VerificationToken,
LiteLLM_Config,
LiteLLM_UserTable,
)
from litellm.proxy.utils import hash_token
from litellm import get_secret
from typing import Any, List, Literal, Optional, Union
import json
from datetime import datetime
from typing import Any, List, Literal, Optional, Union
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import (
DynamoDBArgs,
LiteLLM_Config,
LiteLLM_UserTable,
LiteLLM_VerificationToken,
)
from litellm.proxy.db.base_client import CustomDB
from litellm.proxy.utils import hash_token
from litellm.secret_managers.main import get_secret
class DynamoDBWrapper(CustomDB):
@ -21,19 +22,19 @@ class DynamoDBWrapper(CustomDB):
def __init__(self, database_arguments: DynamoDBArgs):
from aiodynamo.client import Client
from aiodynamo.credentials import Credentials, StaticCredentials
from aiodynamo.expressions import F, UpdateExpression, Value
from aiodynamo.http.aiohttp import AIOHTTP
from aiodynamo.http.httpx import HTTPX
from aiodynamo.models import (
Throughput,
KeySchema,
KeySpec,
KeyType,
PayPerRequest,
ReturnValues,
Throughput,
)
from yarl import URL
from aiodynamo.expressions import UpdateExpression, F, Value
from aiodynamo.models import ReturnValues
from aiodynamo.http.aiohttp import AIOHTTP
from aiohttp import ClientSession
from yarl import URL
self.throughput_type = None
if database_arguments.billing_mode == "PAY_PER_REQUEST":
@ -59,7 +60,9 @@ class DynamoDBWrapper(CustomDB):
verbose_proxy_logger.debug(
f"DynamoDB: setting env vars based on arn={self.database_arguments.aws_role_name}"
)
import boto3, os
import os
import boto3
sts_client = boto3.client("sts")
@ -92,22 +95,22 @@ class DynamoDBWrapper(CustomDB):
"""
Connect to DB, and creating / updating any tables
"""
import aiohttp
from aiodynamo.client import Client
from aiodynamo.credentials import Credentials, StaticCredentials
from aiodynamo.expressions import F, UpdateExpression, Value
from aiodynamo.http.aiohttp import AIOHTTP
from aiodynamo.http.httpx import HTTPX
from aiodynamo.models import (
Throughput,
KeySchema,
KeySpec,
KeyType,
PayPerRequest,
ReturnValues,
Throughput,
)
from yarl import URL
from aiodynamo.expressions import UpdateExpression, F, Value
from aiodynamo.models import ReturnValues
from aiodynamo.http.aiohttp import AIOHTTP
from aiohttp import ClientSession
import aiohttp
from yarl import URL
verbose_proxy_logger.debug("DynamoDB Wrapper - Attempting to connect")
self.set_env_vars_based_on_arn()
@ -192,22 +195,22 @@ class DynamoDBWrapper(CustomDB):
async def insert_data(
self, value: Any, table_name: Literal["user", "key", "config", "spend"]
):
import aiohttp
from aiodynamo.client import Client
from aiodynamo.credentials import Credentials, StaticCredentials
from aiodynamo.expressions import F, UpdateExpression, Value
from aiodynamo.http.aiohttp import AIOHTTP
from aiodynamo.http.httpx import HTTPX
from aiodynamo.models import (
Throughput,
KeySchema,
KeySpec,
KeyType,
PayPerRequest,
ReturnValues,
Throughput,
)
from yarl import URL
from aiodynamo.expressions import UpdateExpression, F, Value
from aiodynamo.models import ReturnValues
from aiodynamo.http.aiohttp import AIOHTTP
from aiohttp import ClientSession
import aiohttp
from yarl import URL
self.set_env_vars_based_on_arn()
@ -237,22 +240,22 @@ class DynamoDBWrapper(CustomDB):
return await table.put_item(item=value, return_values=ReturnValues.all_old)
async def get_data(self, key: str, table_name: Literal["user", "key", "config"]):
import aiohttp
from aiodynamo.client import Client
from aiodynamo.credentials import Credentials, StaticCredentials
from aiodynamo.expressions import F, UpdateExpression, Value
from aiodynamo.http.aiohttp import AIOHTTP
from aiodynamo.http.httpx import HTTPX
from aiodynamo.models import (
Throughput,
KeySchema,
KeySpec,
KeyType,
PayPerRequest,
ReturnValues,
Throughput,
)
from yarl import URL
from aiodynamo.expressions import UpdateExpression, F, Value
from aiodynamo.models import ReturnValues
from aiodynamo.http.aiohttp import AIOHTTP
from aiohttp import ClientSession
import aiohttp
from yarl import URL
self.set_env_vars_based_on_arn()
@ -311,22 +314,22 @@ class DynamoDBWrapper(CustomDB):
self, key: str, value: dict, table_name: Literal["user", "key", "config"]
):
self.set_env_vars_based_on_arn()
import aiohttp
from aiodynamo.client import Client
from aiodynamo.credentials import Credentials, StaticCredentials
from aiodynamo.expressions import F, UpdateExpression, Value
from aiodynamo.http.aiohttp import AIOHTTP
from aiodynamo.http.httpx import HTTPX
from aiodynamo.models import (
Throughput,
KeySchema,
KeySpec,
KeyType,
PayPerRequest,
ReturnValues,
Throughput,
)
from yarl import URL
from aiodynamo.expressions import UpdateExpression, F, Value
from aiodynamo.models import ReturnValues
from aiodynamo.http.aiohttp import AIOHTTP
from aiohttp import ClientSession
import aiohttp
from yarl import URL
if self.database_arguments.ssl_verify == False:
client_session = ClientSession(connector=aiohttp.TCPConnector(ssl=False))

View file

@ -24,7 +24,6 @@ import httpx
from fastapi import HTTPException
import litellm
from litellm import get_secret
from litellm._logging import verbose_proxy_logger
from litellm.caching import DualCache
from litellm.integrations.custom_guardrail import CustomGuardrail
@ -38,6 +37,7 @@ from litellm.llms.custom_httpx.http_handler import (
)
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
from litellm.secret_managers.main import get_secret
from litellm.types.guardrails import (
BedrockContentItem,
BedrockRequest,

View file

@ -19,12 +19,12 @@ import httpx
from fastapi import HTTPException
import litellm
from litellm import get_secret
from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
from litellm.secret_managers.main import get_secret
from litellm.types.guardrails import (
GuardrailItem,
LakeraCategoryThresholds,

View file

@ -544,6 +544,15 @@ def run_server(
load_aws_secret_manager(use_aws_secret_manager=True)
elif key_management_system == KeyManagementSystem.AWS_KMS.value:
load_aws_kms(use_aws_kms=True)
elif (
key_management_system
== KeyManagementSystem.GOOGLE_SECRET_MANAGER.value
):
from litellm.proxy.secret_managers.google_secret_manager import (
GoogleSecretManager,
)
GoogleSecretManager()
else:
raise ValueError("Invalid Key Management System selected")
key_management_settings = general_settings.get(
@ -598,7 +607,7 @@ def run_server(
or os.getenv("DIRECT_URL", None) is not None
):
try:
from litellm import get_secret
from litellm.secret_managers.main import get_secret
if os.getenv("DATABASE_URL", None) is not None:
### add connection pool + pool timeout args

View file

@ -1,15 +1,20 @@
model_list:
- model_name: fake-openai-endpoint
litellm_params:
model: openai/fake
api_key: fake-key
model_info:
id: "team-a-model" # used for identifying model in response headers
- model_name: fake-openai-endpoint
litellm_params:
model: openai/fake
api_base: https://exampleopenaiendpoint-production.up.railway.app/
api_key: os.environ/OPENAI_API_KEY
- model_name: gpt-3.5-turbo-end-user-test
litellm_params:
model: azure/chatgpt-v-2
api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
api_version: "2023-05-15"
api_key: os.environ/AZURE_API_KEY
litellm_settings:
success_callback: ["prometheus"]
failure_callback: ["prometheus"]
general_settings:
master_key: sk-1234
custom_auth: example_config_yaml.custom_auth_basic.user_api_key_auth
general_settings:
master_key: sk-1234
key_management_system: "google_secret_manager"

View file

@ -1765,6 +1765,15 @@ class ProxyConfig:
load_aws_secret_manager(use_aws_secret_manager=True)
elif key_management_system == KeyManagementSystem.AWS_KMS.value:
load_aws_kms(use_aws_kms=True)
elif (
key_management_system
== KeyManagementSystem.GOOGLE_SECRET_MANAGER.value
):
from litellm.proxy.secret_managers.google_secret_manager import (
GoogleSecretManager,
)
GoogleSecretManager()
else:
raise ValueError("Invalid Key Management System selected")
key_management_settings = general_settings.get(

View file

@ -0,0 +1,89 @@
import base64
import os
from typing import Optional
import litellm
from litellm._logging import verbose_logger
from litellm.caching import InMemoryCache
from litellm.integrations.gcs_bucket_base import GCSBucketBase
from litellm.llms.custom_httpx.http_handler import _get_httpx_client
from litellm.proxy._types import KeyManagementSystem
class GoogleSecretManager(GCSBucketBase):
def __init__(
self,
refresh_interval: Optional[int] = 86400,
always_read_secret_manager: Optional[bool] = False,
) -> None:
"""
Args:
refresh_interval (int, optional): The refresh interval in seconds. Defaults to 86400. (24 hours)
always_read_secret_manager (bool, optional): Whether to always read from the secret manager. Defaults to False. Since we do want to cache values
"""
super().__init__()
self.PROJECT_ID = "adroit-crow-413218"
self.sync_httpx_client = _get_httpx_client()
litellm.secret_manager_client = self
litellm._key_management_system = KeyManagementSystem.GOOGLE_SECRET_MANAGER
self.cache = InMemoryCache(
default_ttl=refresh_interval
) # store in memory for 1 day
self.always_read_secret_manager = False
def get_secret_from_google_secret_manager(self, secret_name: str) -> Optional[str]:
"""
Retrieve a secret from Google Secret Manager or cache.
Args:
secret_name (str): The name of the secret.
Returns:
str: The secret value if successful, None otherwise.
"""
if self.always_read_secret_manager is not True:
cached_secret = self.cache.get_cache(secret_name)
if cached_secret is not None:
return cached_secret
if secret_name in self.cache.cache_dict:
return cached_secret
_secret_name = (
f"projects/{self.PROJECT_ID}/secrets/{secret_name}/versions/latest"
)
headers = self.sync_construct_request_headers()
url = f"https://secretmanager.googleapis.com/v1/{_secret_name}:access"
# Send the GET request to retrieve the secret
response = self.sync_httpx_client.get(url=url, headers=headers)
if response.status_code != 200:
verbose_logger.error(
"Google Secret Manager retrieval error: %s", str(response.text)
)
self.cache.set_cache(
secret_name, None
) # Cache that the secret was not found
raise ValueError(
f"secret {secret_name} not found in Google Secret Manager. Error: {response.text}"
)
verbose_logger.debug(
"Google Secret Manager retrieval response status code: %s",
response.status_code,
)
# Parse the JSON response and return the secret value
secret_data = response.json()
_base64_encoded_value = secret_data.get("payload", {}).get("data")
# decode the base64 encoded value
if _base64_encoded_value is not None:
_decoded_value = base64.b64decode(_base64_encoded_value).decode("utf-8")
self.cache.set_cache(
secret_name, _decoded_value
) # Cache the retrieved secret
return _decoded_value
self.cache.set_cache(secret_name, None) # Cache that the secret was not found
raise ValueError(f"secret {secret_name} not found in Google Secret Manager")

View file

@ -4,10 +4,10 @@ from functools import partial
from typing import Any, Coroutine, Dict, List, Literal, Optional, Union
import litellm
from litellm import get_secret
from litellm._logging import verbose_logger
from litellm.llms.cohere.rerank import CohereRerank
from litellm.llms.togetherai.rerank import TogetherAIRerank
from litellm.secret_managers.main import get_secret
from litellm.types.router import *
from litellm.utils import supports_httpx_timeout

View file

@ -0,0 +1,3 @@
## Supported Secret Managers to read credentials from
Example read OPENAI_API_KEY, AZURE_API_KEY from a secret manager

View file

@ -0,0 +1,276 @@
import ast
import base64
import binascii
import json
import os
import sys
import traceback
from typing import Any, Optional, Union
import httpx
from dotenv import load_dotenv
import litellm
from litellm._logging import print_verbose, verbose_logger
from litellm.caching import DualCache
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.proxy._types import KeyManagementSystem
oidc_cache = DualCache()
######### Secret Manager ############################
# checks if user has passed in a secret manager client
# if passed in then checks the secret there
def _is_base64(s):
try:
return base64.b64encode(base64.b64decode(s)).decode() == s
except binascii.Error:
return False
def get_secret(
secret_name: str,
default_value: Optional[Union[str, bool]] = None,
):
key_management_system = litellm._key_management_system
key_management_settings = litellm._key_management_settings
if secret_name.startswith("os.environ/"):
secret_name = secret_name.replace("os.environ/", "")
# Example: oidc/google/https://bedrock-runtime.us-east-1.amazonaws.com/model/stability.stable-diffusion-xl-v1/invoke
if secret_name.startswith("oidc/"):
secret_name_split = secret_name.replace("oidc/", "")
oidc_provider, oidc_aud = secret_name_split.split("/", 1)
# TODO: Add caching for HTTP requests
if oidc_provider == "google":
oidc_token = oidc_cache.get_cache(key=secret_name)
if oidc_token is not None:
return oidc_token
oidc_client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
# https://cloud.google.com/compute/docs/instances/verifying-instance-identity#request_signature
response = oidc_client.get(
"http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/identity",
params={"audience": oidc_aud},
headers={"Metadata-Flavor": "Google"},
)
if response.status_code == 200:
oidc_token = response.text
oidc_cache.set_cache(key=secret_name, value=oidc_token, ttl=3600 - 60)
return oidc_token
else:
raise ValueError("Google OIDC provider failed")
elif oidc_provider == "circleci":
# https://circleci.com/docs/openid-connect-tokens/
env_secret = os.getenv("CIRCLE_OIDC_TOKEN")
if env_secret is None:
raise ValueError("CIRCLE_OIDC_TOKEN not found in environment")
return env_secret
elif oidc_provider == "circleci_v2":
# https://circleci.com/docs/openid-connect-tokens/
env_secret = os.getenv("CIRCLE_OIDC_TOKEN_V2")
if env_secret is None:
raise ValueError("CIRCLE_OIDC_TOKEN_V2 not found in environment")
return env_secret
elif oidc_provider == "github":
# https://docs.github.com/en/actions/deployment/security-hardening-your-deployments/configuring-openid-connect-in-cloud-providers#using-custom-actions
actions_id_token_request_url = os.getenv("ACTIONS_ID_TOKEN_REQUEST_URL")
actions_id_token_request_token = os.getenv("ACTIONS_ID_TOKEN_REQUEST_TOKEN")
if (
actions_id_token_request_url is None
or actions_id_token_request_token is None
):
raise ValueError(
"ACTIONS_ID_TOKEN_REQUEST_URL or ACTIONS_ID_TOKEN_REQUEST_TOKEN not found in environment"
)
oidc_token = oidc_cache.get_cache(key=secret_name)
if oidc_token is not None:
return oidc_token
oidc_client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
response = oidc_client.get(
actions_id_token_request_url,
params={"audience": oidc_aud},
headers={
"Authorization": f"Bearer {actions_id_token_request_token}",
"Accept": "application/json; api-version=2.0",
},
)
if response.status_code == 200:
oidc_token = response.text["value"]
oidc_cache.set_cache(key=secret_name, value=oidc_token, ttl=300 - 5)
return oidc_token
else:
raise ValueError("Github OIDC provider failed")
elif oidc_provider == "azure":
# https://azure.github.io/azure-workload-identity/docs/quick-start.html
azure_federated_token_file = os.getenv("AZURE_FEDERATED_TOKEN_FILE")
if azure_federated_token_file is None:
raise ValueError("AZURE_FEDERATED_TOKEN_FILE not found in environment")
with open(azure_federated_token_file, "r") as f:
oidc_token = f.read()
return oidc_token
elif oidc_provider == "file":
# Load token from a file
with open(oidc_aud, "r") as f:
oidc_token = f.read()
return oidc_token
elif oidc_provider == "env":
# Load token directly from an environment variable
oidc_token = os.getenv(oidc_aud)
if oidc_token is None:
raise ValueError(f"Environment variable {oidc_aud} not found")
return oidc_token
elif oidc_provider == "env_path":
# Load token from a file path specified in an environment variable
token_file_path = os.getenv(oidc_aud)
if token_file_path is None:
raise ValueError(f"Environment variable {oidc_aud} not found")
with open(token_file_path, "r") as f:
oidc_token = f.read()
return oidc_token
else:
raise ValueError("Unsupported OIDC provider")
try:
if litellm.secret_manager_client is not None:
try:
client = litellm.secret_manager_client
key_manager = "local"
if key_management_system is not None:
key_manager = key_management_system.value
if key_management_settings is not None:
if (
secret_name not in key_management_settings.hosted_keys
): # allow user to specify which keys to check in hosted key manager
key_manager = "local"
if (
key_manager == KeyManagementSystem.AZURE_KEY_VAULT.value
or type(client).__module__ + "." + type(client).__name__
== "azure.keyvault.secrets._client.SecretClient"
): # support Azure Secret Client - from azure.keyvault.secrets import SecretClient
secret = client.get_secret(secret_name).value
elif (
key_manager == KeyManagementSystem.GOOGLE_KMS.value
or client.__class__.__name__ == "KeyManagementServiceClient"
):
encrypted_secret: Any = os.getenv(secret_name)
if encrypted_secret is None:
raise ValueError(
f"Google KMS requires the encrypted secret to be in the environment!"
)
b64_flag = _is_base64(encrypted_secret)
if b64_flag == True: # if passed in as encoded b64 string
encrypted_secret = base64.b64decode(encrypted_secret)
ciphertext = encrypted_secret
else:
raise ValueError(
f"Google KMS requires the encrypted secret to be encoded in base64"
) # fix for this vulnerability https://huntr.com/bounties/ae623c2f-b64b-4245-9ed4-f13a0a5824ce
response = client.decrypt(
request={
"name": litellm._google_kms_resource_name,
"ciphertext": ciphertext,
}
)
secret = response.plaintext.decode(
"utf-8"
) # assumes the original value was encoded with utf-8
elif key_manager == KeyManagementSystem.AWS_KMS.value:
"""
Only check the tokens which start with 'aws_kms/'. This prevents latency impact caused by checking all keys.
"""
encrypted_value = os.getenv(secret_name, None)
if encrypted_value is None:
raise Exception(
"AWS KMS - Encrypted Value of Key={} is None".format(
secret_name
)
)
# Decode the base64 encoded ciphertext
ciphertext_blob = base64.b64decode(encrypted_value)
# Set up the parameters for the decrypt call
params = {"CiphertextBlob": ciphertext_blob}
# Perform the decryption
response = client.decrypt(**params)
# Extract and decode the plaintext
plaintext = response["Plaintext"]
secret = plaintext.decode("utf-8")
if isinstance(secret, str):
secret = secret.strip()
elif key_manager == KeyManagementSystem.AWS_SECRET_MANAGER.value:
try:
get_secret_value_response = client.get_secret_value(
SecretId=secret_name
)
print_verbose(
f"get_secret_value_response: {get_secret_value_response}"
)
except Exception as e:
print_verbose(f"An error occurred - {str(e)}")
# For a list of exceptions thrown, see
# https://docs.aws.amazon.com/secretsmanager/latest/apireference/API_GetSecretValue.html
raise e
# assume there is 1 secret per secret_name
secret_dict = json.loads(get_secret_value_response["SecretString"])
print_verbose(f"secret_dict: {secret_dict}")
for k, v in secret_dict.items():
secret = v
print_verbose(f"secret: {secret}")
if key_manager == KeyManagementSystem.GOOGLE_SECRET_MANAGER.value:
try:
secret = client.get_secret_from_google_secret_manager(
secret_name
)
print_verbose(f"secret from google secret manager: {secret}")
if secret is None:
raise ValueError(
f"No secret found in Google Secret Manager for {secret_name}"
)
except Exception as e:
print_verbose(f"An error occurred - {str(e)}")
raise e
elif key_manager == "local":
secret = os.getenv(secret_name)
else: # assume the default is infisicial client
secret = client.get_secret(secret_name).secret_value
except Exception as e: # check if it's in os.environ
verbose_logger.error(
f"Defaulting to os.environ value for key={secret_name}. An exception occurred - {str(e)}.\n\n{traceback.format_exc()}"
)
secret = os.getenv(secret_name)
try:
secret_value_as_bool = ast.literal_eval(secret)
if isinstance(secret_value_as_bool, bool):
return secret_value_as_bool
else:
return secret
except:
return secret
else:
secret = os.environ.get(secret_name)
try:
secret_value_as_bool = (
ast.literal_eval(secret) if secret is not None else None
)
if isinstance(secret_value_as_bool, bool):
return secret_value_as_bool
else:
return secret
except Exception:
if default_value is not None:
return default_value
return secret
except Exception as e:
if default_value is not None:
return default_value
else:
raise e

View file

@ -16,10 +16,10 @@ sys.path.insert(
) # Adds the parent directory to the system path
import pytest
from litellm import get_secret
from litellm.llms.azure import get_azure_ad_token_from_oidc
from litellm.llms.bedrock.chat import BedrockConverseLLM, BedrockLLM
from litellm.proxy.secret_managers.aws_secret_manager import load_aws_secret_manager
from litellm.secret_managers.main import get_secret
@pytest.mark.skip(reason="AWS Suspended Account")

View file

@ -68,6 +68,7 @@ from litellm.litellm_core_utils.redact_messages import (
)
from litellm.litellm_core_utils.token_counter import get_modified_max_tokens
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.secret_managers.main import get_secret
from litellm.types.llms.openai import (
AllMessageValues,
ChatCompletionNamedToolChoiceParam,
@ -93,8 +94,6 @@ from litellm.types.utils import (
Usage,
)
oidc_cache = DualCache()
try:
# New and recommended way to access resources
from importlib import resources
@ -8662,250 +8661,6 @@ def exception_type(
raise raised_exc
######### Secret Manager ############################
# checks if user has passed in a secret manager client
# if passed in then checks the secret there
def _is_base64(s):
try:
return base64.b64encode(base64.b64decode(s)).decode() == s
except binascii.Error:
return False
def get_secret(
secret_name: str,
default_value: Optional[Union[str, bool]] = None,
):
key_management_system = litellm._key_management_system
key_management_settings = litellm._key_management_settings
if secret_name.startswith("os.environ/"):
secret_name = secret_name.replace("os.environ/", "")
# Example: oidc/google/https://bedrock-runtime.us-east-1.amazonaws.com/model/stability.stable-diffusion-xl-v1/invoke
if secret_name.startswith("oidc/"):
secret_name_split = secret_name.replace("oidc/", "")
oidc_provider, oidc_aud = secret_name_split.split("/", 1)
# TODO: Add caching for HTTP requests
if oidc_provider == "google":
oidc_token = oidc_cache.get_cache(key=secret_name)
if oidc_token is not None:
return oidc_token
oidc_client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
# https://cloud.google.com/compute/docs/instances/verifying-instance-identity#request_signature
response = oidc_client.get(
"http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/identity",
params={"audience": oidc_aud},
headers={"Metadata-Flavor": "Google"},
)
if response.status_code == 200:
oidc_token = response.text
oidc_cache.set_cache(key=secret_name, value=oidc_token, ttl=3600 - 60)
return oidc_token
else:
raise ValueError("Google OIDC provider failed")
elif oidc_provider == "circleci":
# https://circleci.com/docs/openid-connect-tokens/
env_secret = os.getenv("CIRCLE_OIDC_TOKEN")
if env_secret is None:
raise ValueError("CIRCLE_OIDC_TOKEN not found in environment")
return env_secret
elif oidc_provider == "circleci_v2":
# https://circleci.com/docs/openid-connect-tokens/
env_secret = os.getenv("CIRCLE_OIDC_TOKEN_V2")
if env_secret is None:
raise ValueError("CIRCLE_OIDC_TOKEN_V2 not found in environment")
return env_secret
elif oidc_provider == "github":
# https://docs.github.com/en/actions/deployment/security-hardening-your-deployments/configuring-openid-connect-in-cloud-providers#using-custom-actions
actions_id_token_request_url = os.getenv("ACTIONS_ID_TOKEN_REQUEST_URL")
actions_id_token_request_token = os.getenv("ACTIONS_ID_TOKEN_REQUEST_TOKEN")
if (
actions_id_token_request_url is None
or actions_id_token_request_token is None
):
raise ValueError(
"ACTIONS_ID_TOKEN_REQUEST_URL or ACTIONS_ID_TOKEN_REQUEST_TOKEN not found in environment"
)
oidc_token = oidc_cache.get_cache(key=secret_name)
if oidc_token is not None:
return oidc_token
oidc_client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
response = oidc_client.get(
actions_id_token_request_url,
params={"audience": oidc_aud},
headers={
"Authorization": f"Bearer {actions_id_token_request_token}",
"Accept": "application/json; api-version=2.0",
},
)
if response.status_code == 200:
oidc_token = response.text["value"]
oidc_cache.set_cache(key=secret_name, value=oidc_token, ttl=300 - 5)
return oidc_token
else:
raise ValueError("Github OIDC provider failed")
elif oidc_provider == "azure":
# https://azure.github.io/azure-workload-identity/docs/quick-start.html
azure_federated_token_file = os.getenv("AZURE_FEDERATED_TOKEN_FILE")
if azure_federated_token_file is None:
raise ValueError("AZURE_FEDERATED_TOKEN_FILE not found in environment")
with open(azure_federated_token_file, "r") as f:
oidc_token = f.read()
return oidc_token
elif oidc_provider == "file":
# Load token from a file
with open(oidc_aud, "r") as f:
oidc_token = f.read()
return oidc_token
elif oidc_provider == "env":
# Load token directly from an environment variable
oidc_token = os.getenv(oidc_aud)
if oidc_token is None:
raise ValueError(f"Environment variable {oidc_aud} not found")
return oidc_token
elif oidc_provider == "env_path":
# Load token from a file path specified in an environment variable
token_file_path = os.getenv(oidc_aud)
if token_file_path is None:
raise ValueError(f"Environment variable {oidc_aud} not found")
with open(token_file_path, "r") as f:
oidc_token = f.read()
return oidc_token
else:
raise ValueError("Unsupported OIDC provider")
try:
if litellm.secret_manager_client is not None:
try:
client = litellm.secret_manager_client
key_manager = "local"
if key_management_system is not None:
key_manager = key_management_system.value
if key_management_settings is not None:
if (
secret_name not in key_management_settings.hosted_keys
): # allow user to specify which keys to check in hosted key manager
key_manager = "local"
if (
key_manager == KeyManagementSystem.AZURE_KEY_VAULT.value
or type(client).__module__ + "." + type(client).__name__
== "azure.keyvault.secrets._client.SecretClient"
): # support Azure Secret Client - from azure.keyvault.secrets import SecretClient
secret = client.get_secret(secret_name).value
elif (
key_manager == KeyManagementSystem.GOOGLE_KMS.value
or client.__class__.__name__ == "KeyManagementServiceClient"
):
encrypted_secret: Any = os.getenv(secret_name)
if encrypted_secret is None:
raise ValueError(
f"Google KMS requires the encrypted secret to be in the environment!"
)
b64_flag = _is_base64(encrypted_secret)
if b64_flag == True: # if passed in as encoded b64 string
encrypted_secret = base64.b64decode(encrypted_secret)
ciphertext = encrypted_secret
else:
raise ValueError(
f"Google KMS requires the encrypted secret to be encoded in base64"
) # fix for this vulnerability https://huntr.com/bounties/ae623c2f-b64b-4245-9ed4-f13a0a5824ce
response = client.decrypt(
request={
"name": litellm._google_kms_resource_name,
"ciphertext": ciphertext,
}
)
secret = response.plaintext.decode(
"utf-8"
) # assumes the original value was encoded with utf-8
elif key_manager == KeyManagementSystem.AWS_KMS.value:
"""
Only check the tokens which start with 'aws_kms/'. This prevents latency impact caused by checking all keys.
"""
encrypted_value = os.getenv(secret_name, None)
if encrypted_value is None:
raise Exception(
"AWS KMS - Encrypted Value of Key={} is None".format(
secret_name
)
)
# Decode the base64 encoded ciphertext
ciphertext_blob = base64.b64decode(encrypted_value)
# Set up the parameters for the decrypt call
params = {"CiphertextBlob": ciphertext_blob}
# Perform the decryption
response = client.decrypt(**params)
# Extract and decode the plaintext
plaintext = response["Plaintext"]
secret = plaintext.decode("utf-8")
if isinstance(secret, str):
secret = secret.strip()
elif key_manager == KeyManagementSystem.AWS_SECRET_MANAGER.value:
try:
get_secret_value_response = client.get_secret_value(
SecretId=secret_name
)
print_verbose(
f"get_secret_value_response: {get_secret_value_response}"
)
except Exception as e:
print_verbose(f"An error occurred - {str(e)}")
# For a list of exceptions thrown, see
# https://docs.aws.amazon.com/secretsmanager/latest/apireference/API_GetSecretValue.html
raise e
# assume there is 1 secret per secret_name
secret_dict = json.loads(get_secret_value_response["SecretString"])
print_verbose(f"secret_dict: {secret_dict}")
for k, v in secret_dict.items():
secret = v
print_verbose(f"secret: {secret}")
elif key_manager == "local":
secret = os.getenv(secret_name)
else: # assume the default is infisicial client
secret = client.get_secret(secret_name).secret_value
except Exception as e: # check if it's in os.environ
verbose_logger.error(
f"Defaulting to os.environ value for key={secret_name}. An exception occurred - {str(e)}.\n\n{traceback.format_exc()}"
)
secret = os.getenv(secret_name)
try:
secret_value_as_bool = ast.literal_eval(secret)
if isinstance(secret_value_as_bool, bool):
return secret_value_as_bool
else:
return secret
except:
return secret
else:
secret = os.environ.get(secret_name)
try:
secret_value_as_bool = (
ast.literal_eval(secret) if secret is not None else None
)
if isinstance(secret_value_as_bool, bool):
return secret_value_as_bool
else:
return secret
except Exception:
if default_value is not None:
return default_value
return secret
except Exception as e:
if default_value is not None:
return default_value
else:
raise e
######## Streaming Class ############################
# wraps the completion stream to return the correct format for the model
# replicate/anthropic/cohere