forked from phoenix/litellm-mirror
refactor get_secret
This commit is contained in:
parent
1546a82f18
commit
b0178a85cf
20 changed files with 457 additions and 307 deletions
|
@ -834,7 +834,6 @@ from .utils import (
|
|||
decode,
|
||||
_calculate_retry_after,
|
||||
_should_retry,
|
||||
get_secret,
|
||||
get_supported_openai_params,
|
||||
get_api_base,
|
||||
get_first_chars_messages,
|
||||
|
|
|
@ -22,6 +22,7 @@ import litellm
|
|||
from litellm import client
|
||||
from litellm.llms.azure import AzureBatchesAPI
|
||||
from litellm.llms.openai import OpenAIBatchesAPI
|
||||
from litellm.secret_managers.main import get_secret
|
||||
from litellm.types.llms.openai import (
|
||||
Batch,
|
||||
CancelBatchRequest,
|
||||
|
@ -34,7 +35,7 @@ from litellm.types.llms.openai import (
|
|||
RetrieveBatchRequest,
|
||||
)
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.utils import get_secret, supports_httpx_timeout
|
||||
from litellm.utils import supports_httpx_timeout
|
||||
|
||||
####### ENVIRONMENT VARIABLES ###################
|
||||
openai_batches_instance = OpenAIBatchesAPI()
|
||||
|
|
|
@ -17,7 +17,6 @@ from typing import Any, Coroutine, Dict, Literal, Optional, Union
|
|||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm import get_secret
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.llms.fine_tuning_apis.azure import AzureOpenAIFineTuningAPI
|
||||
from litellm.llms.fine_tuning_apis.openai import (
|
||||
|
@ -26,6 +25,7 @@ from litellm.llms.fine_tuning_apis.openai import (
|
|||
OpenAIFineTuningAPI,
|
||||
)
|
||||
from litellm.llms.fine_tuning_apis.vertex_ai import VertexFineTuningAPI
|
||||
from litellm.secret_managers.main import get_secret
|
||||
from litellm.types.llms.openai import Hyperparameters
|
||||
from litellm.types.router import *
|
||||
from litellm.utils import supports_httpx_timeout
|
||||
|
|
|
@ -5,7 +5,7 @@ import httpx
|
|||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.caching import DualCache, InMemoryCache
|
||||
from litellm.utils import get_secret
|
||||
from litellm.secret_managers.main import get_secret
|
||||
|
||||
from .base import BaseLLM
|
||||
|
||||
|
|
|
@ -10,7 +10,7 @@ from typing import List, Optional, Union
|
|||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm import get_secret
|
||||
from litellm.secret_managers.main import get_secret
|
||||
|
||||
|
||||
class BedrockError(Exception):
|
||||
|
|
|
@ -11,7 +11,6 @@ from typing import Any, Callable, List, Literal, Optional, Tuple, Union
|
|||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm import get_secret
|
||||
from litellm.llms.cohere.embed import embedding as cohere_embedding
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
|
@ -19,6 +18,7 @@ from litellm.llms.custom_httpx.http_handler import (
|
|||
_get_async_httpx_client,
|
||||
_get_httpx_client,
|
||||
)
|
||||
from litellm.secret_managers.main import get_secret
|
||||
from litellm.types.llms.bedrock import AmazonEmbeddingRequest, CohereEmbeddingRequest
|
||||
from litellm.types.utils import Embedding, EmbeddingResponse, Usage
|
||||
|
||||
|
|
|
@ -1063,6 +1063,7 @@ class KeyManagementSystem(enum.Enum):
|
|||
GOOGLE_KMS = "google_kms"
|
||||
AZURE_KEY_VAULT = "azure_key_vault"
|
||||
AWS_SECRET_MANAGER = "aws_secret_manager"
|
||||
GOOGLE_SECRET_MANAGER = "google_secret_manager"
|
||||
LOCAL = "local"
|
||||
AWS_KMS = "aws_kms"
|
||||
|
||||
|
|
|
@ -14,7 +14,7 @@ def init_rds_client(
|
|||
aws_web_identity_token: Optional[str] = None,
|
||||
timeout: Optional[Union[float, httpx.Timeout]] = None,
|
||||
):
|
||||
from litellm import get_secret
|
||||
from litellm.secret_managers.main import get_secret
|
||||
|
||||
# check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client
|
||||
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
|
||||
|
|
|
@ -1,16 +1,17 @@
|
|||
from litellm.proxy.db.base_client import CustomDB
|
||||
from litellm.proxy._types import (
|
||||
DynamoDBArgs,
|
||||
LiteLLM_VerificationToken,
|
||||
LiteLLM_Config,
|
||||
LiteLLM_UserTable,
|
||||
)
|
||||
from litellm.proxy.utils import hash_token
|
||||
from litellm import get_secret
|
||||
from typing import Any, List, Literal, Optional, Union
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Any, List, Literal, Optional, Union
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import (
|
||||
DynamoDBArgs,
|
||||
LiteLLM_Config,
|
||||
LiteLLM_UserTable,
|
||||
LiteLLM_VerificationToken,
|
||||
)
|
||||
from litellm.proxy.db.base_client import CustomDB
|
||||
from litellm.proxy.utils import hash_token
|
||||
from litellm.secret_managers.main import get_secret
|
||||
|
||||
|
||||
class DynamoDBWrapper(CustomDB):
|
||||
|
@ -21,19 +22,19 @@ class DynamoDBWrapper(CustomDB):
|
|||
def __init__(self, database_arguments: DynamoDBArgs):
|
||||
from aiodynamo.client import Client
|
||||
from aiodynamo.credentials import Credentials, StaticCredentials
|
||||
from aiodynamo.expressions import F, UpdateExpression, Value
|
||||
from aiodynamo.http.aiohttp import AIOHTTP
|
||||
from aiodynamo.http.httpx import HTTPX
|
||||
from aiodynamo.models import (
|
||||
Throughput,
|
||||
KeySchema,
|
||||
KeySpec,
|
||||
KeyType,
|
||||
PayPerRequest,
|
||||
ReturnValues,
|
||||
Throughput,
|
||||
)
|
||||
from yarl import URL
|
||||
from aiodynamo.expressions import UpdateExpression, F, Value
|
||||
from aiodynamo.models import ReturnValues
|
||||
from aiodynamo.http.aiohttp import AIOHTTP
|
||||
from aiohttp import ClientSession
|
||||
from yarl import URL
|
||||
|
||||
self.throughput_type = None
|
||||
if database_arguments.billing_mode == "PAY_PER_REQUEST":
|
||||
|
@ -59,7 +60,9 @@ class DynamoDBWrapper(CustomDB):
|
|||
verbose_proxy_logger.debug(
|
||||
f"DynamoDB: setting env vars based on arn={self.database_arguments.aws_role_name}"
|
||||
)
|
||||
import boto3, os
|
||||
import os
|
||||
|
||||
import boto3
|
||||
|
||||
sts_client = boto3.client("sts")
|
||||
|
||||
|
@ -92,22 +95,22 @@ class DynamoDBWrapper(CustomDB):
|
|||
"""
|
||||
Connect to DB, and creating / updating any tables
|
||||
"""
|
||||
import aiohttp
|
||||
from aiodynamo.client import Client
|
||||
from aiodynamo.credentials import Credentials, StaticCredentials
|
||||
from aiodynamo.expressions import F, UpdateExpression, Value
|
||||
from aiodynamo.http.aiohttp import AIOHTTP
|
||||
from aiodynamo.http.httpx import HTTPX
|
||||
from aiodynamo.models import (
|
||||
Throughput,
|
||||
KeySchema,
|
||||
KeySpec,
|
||||
KeyType,
|
||||
PayPerRequest,
|
||||
ReturnValues,
|
||||
Throughput,
|
||||
)
|
||||
from yarl import URL
|
||||
from aiodynamo.expressions import UpdateExpression, F, Value
|
||||
from aiodynamo.models import ReturnValues
|
||||
from aiodynamo.http.aiohttp import AIOHTTP
|
||||
from aiohttp import ClientSession
|
||||
import aiohttp
|
||||
from yarl import URL
|
||||
|
||||
verbose_proxy_logger.debug("DynamoDB Wrapper - Attempting to connect")
|
||||
self.set_env_vars_based_on_arn()
|
||||
|
@ -192,22 +195,22 @@ class DynamoDBWrapper(CustomDB):
|
|||
async def insert_data(
|
||||
self, value: Any, table_name: Literal["user", "key", "config", "spend"]
|
||||
):
|
||||
import aiohttp
|
||||
from aiodynamo.client import Client
|
||||
from aiodynamo.credentials import Credentials, StaticCredentials
|
||||
from aiodynamo.expressions import F, UpdateExpression, Value
|
||||
from aiodynamo.http.aiohttp import AIOHTTP
|
||||
from aiodynamo.http.httpx import HTTPX
|
||||
from aiodynamo.models import (
|
||||
Throughput,
|
||||
KeySchema,
|
||||
KeySpec,
|
||||
KeyType,
|
||||
PayPerRequest,
|
||||
ReturnValues,
|
||||
Throughput,
|
||||
)
|
||||
from yarl import URL
|
||||
from aiodynamo.expressions import UpdateExpression, F, Value
|
||||
from aiodynamo.models import ReturnValues
|
||||
from aiodynamo.http.aiohttp import AIOHTTP
|
||||
from aiohttp import ClientSession
|
||||
import aiohttp
|
||||
from yarl import URL
|
||||
|
||||
self.set_env_vars_based_on_arn()
|
||||
|
||||
|
@ -237,22 +240,22 @@ class DynamoDBWrapper(CustomDB):
|
|||
return await table.put_item(item=value, return_values=ReturnValues.all_old)
|
||||
|
||||
async def get_data(self, key: str, table_name: Literal["user", "key", "config"]):
|
||||
import aiohttp
|
||||
from aiodynamo.client import Client
|
||||
from aiodynamo.credentials import Credentials, StaticCredentials
|
||||
from aiodynamo.expressions import F, UpdateExpression, Value
|
||||
from aiodynamo.http.aiohttp import AIOHTTP
|
||||
from aiodynamo.http.httpx import HTTPX
|
||||
from aiodynamo.models import (
|
||||
Throughput,
|
||||
KeySchema,
|
||||
KeySpec,
|
||||
KeyType,
|
||||
PayPerRequest,
|
||||
ReturnValues,
|
||||
Throughput,
|
||||
)
|
||||
from yarl import URL
|
||||
from aiodynamo.expressions import UpdateExpression, F, Value
|
||||
from aiodynamo.models import ReturnValues
|
||||
from aiodynamo.http.aiohttp import AIOHTTP
|
||||
from aiohttp import ClientSession
|
||||
import aiohttp
|
||||
from yarl import URL
|
||||
|
||||
self.set_env_vars_based_on_arn()
|
||||
|
||||
|
@ -311,22 +314,22 @@ class DynamoDBWrapper(CustomDB):
|
|||
self, key: str, value: dict, table_name: Literal["user", "key", "config"]
|
||||
):
|
||||
self.set_env_vars_based_on_arn()
|
||||
import aiohttp
|
||||
from aiodynamo.client import Client
|
||||
from aiodynamo.credentials import Credentials, StaticCredentials
|
||||
from aiodynamo.expressions import F, UpdateExpression, Value
|
||||
from aiodynamo.http.aiohttp import AIOHTTP
|
||||
from aiodynamo.http.httpx import HTTPX
|
||||
from aiodynamo.models import (
|
||||
Throughput,
|
||||
KeySchema,
|
||||
KeySpec,
|
||||
KeyType,
|
||||
PayPerRequest,
|
||||
ReturnValues,
|
||||
Throughput,
|
||||
)
|
||||
from yarl import URL
|
||||
from aiodynamo.expressions import UpdateExpression, F, Value
|
||||
from aiodynamo.models import ReturnValues
|
||||
from aiodynamo.http.aiohttp import AIOHTTP
|
||||
from aiohttp import ClientSession
|
||||
import aiohttp
|
||||
from yarl import URL
|
||||
|
||||
if self.database_arguments.ssl_verify == False:
|
||||
client_session = ClientSession(connector=aiohttp.TCPConnector(ssl=False))
|
||||
|
|
|
@ -24,7 +24,6 @@ import httpx
|
|||
from fastapi import HTTPException
|
||||
|
||||
import litellm
|
||||
from litellm import get_secret
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.caching import DualCache
|
||||
from litellm.integrations.custom_guardrail import CustomGuardrail
|
||||
|
@ -38,6 +37,7 @@ from litellm.llms.custom_httpx.http_handler import (
|
|||
)
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
|
||||
from litellm.secret_managers.main import get_secret
|
||||
from litellm.types.guardrails import (
|
||||
BedrockContentItem,
|
||||
BedrockRequest,
|
||||
|
|
|
@ -19,12 +19,12 @@ import httpx
|
|||
from fastapi import HTTPException
|
||||
|
||||
import litellm
|
||||
from litellm import get_secret
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.integrations.custom_guardrail import CustomGuardrail
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
|
||||
from litellm.secret_managers.main import get_secret
|
||||
from litellm.types.guardrails import (
|
||||
GuardrailItem,
|
||||
LakeraCategoryThresholds,
|
||||
|
|
|
@ -544,6 +544,15 @@ def run_server(
|
|||
load_aws_secret_manager(use_aws_secret_manager=True)
|
||||
elif key_management_system == KeyManagementSystem.AWS_KMS.value:
|
||||
load_aws_kms(use_aws_kms=True)
|
||||
elif (
|
||||
key_management_system
|
||||
== KeyManagementSystem.GOOGLE_SECRET_MANAGER.value
|
||||
):
|
||||
from litellm.proxy.secret_managers.google_secret_manager import (
|
||||
GoogleSecretManager,
|
||||
)
|
||||
|
||||
GoogleSecretManager()
|
||||
else:
|
||||
raise ValueError("Invalid Key Management System selected")
|
||||
key_management_settings = general_settings.get(
|
||||
|
@ -598,7 +607,7 @@ def run_server(
|
|||
or os.getenv("DIRECT_URL", None) is not None
|
||||
):
|
||||
try:
|
||||
from litellm import get_secret
|
||||
from litellm.secret_managers.main import get_secret
|
||||
|
||||
if os.getenv("DATABASE_URL", None) is not None:
|
||||
### add connection pool + pool timeout args
|
||||
|
|
|
@ -1,15 +1,20 @@
|
|||
model_list:
|
||||
- model_name: fake-openai-endpoint
|
||||
litellm_params:
|
||||
model: openai/fake
|
||||
api_key: fake-key
|
||||
model_info:
|
||||
id: "team-a-model" # used for identifying model in response headers
|
||||
- model_name: fake-openai-endpoint
|
||||
litellm_params:
|
||||
model: openai/fake
|
||||
api_base: https://exampleopenaiendpoint-production.up.railway.app/
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
- model_name: gpt-3.5-turbo-end-user-test
|
||||
litellm_params:
|
||||
model: azure/chatgpt-v-2
|
||||
api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
|
||||
api_version: "2023-05-15"
|
||||
api_key: os.environ/AZURE_API_KEY
|
||||
|
||||
litellm_settings:
|
||||
success_callback: ["prometheus"]
|
||||
failure_callback: ["prometheus"]
|
||||
|
||||
general_settings:
|
||||
master_key: sk-1234
|
||||
custom_auth: example_config_yaml.custom_auth_basic.user_api_key_auth
|
||||
general_settings:
|
||||
master_key: sk-1234
|
||||
key_management_system: "google_secret_manager"
|
|
@ -1765,6 +1765,15 @@ class ProxyConfig:
|
|||
load_aws_secret_manager(use_aws_secret_manager=True)
|
||||
elif key_management_system == KeyManagementSystem.AWS_KMS.value:
|
||||
load_aws_kms(use_aws_kms=True)
|
||||
elif (
|
||||
key_management_system
|
||||
== KeyManagementSystem.GOOGLE_SECRET_MANAGER.value
|
||||
):
|
||||
from litellm.proxy.secret_managers.google_secret_manager import (
|
||||
GoogleSecretManager,
|
||||
)
|
||||
|
||||
GoogleSecretManager()
|
||||
else:
|
||||
raise ValueError("Invalid Key Management System selected")
|
||||
key_management_settings = general_settings.get(
|
||||
|
|
89
litellm/proxy/secret_managers/google_secret_manager.py
Normal file
89
litellm/proxy/secret_managers/google_secret_manager.py
Normal file
|
@ -0,0 +1,89 @@
|
|||
import base64
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.caching import InMemoryCache
|
||||
from litellm.integrations.gcs_bucket_base import GCSBucketBase
|
||||
from litellm.llms.custom_httpx.http_handler import _get_httpx_client
|
||||
from litellm.proxy._types import KeyManagementSystem
|
||||
|
||||
|
||||
class GoogleSecretManager(GCSBucketBase):
    """Secret-manager client that reads secrets from Google Secret Manager.

    Instantiating the class registers the instance as
    ``litellm.secret_manager_client`` and sets
    ``litellm._key_management_system`` to
    ``KeyManagementSystem.GOOGLE_SECRET_MANAGER``. Retrieved values (and
    failed lookups) are kept in an in-memory cache.
    """

    def __init__(
        self,
        refresh_interval: Optional[int] = 86400,
        always_read_secret_manager: Optional[bool] = False,
    ) -> None:
        """
        Args:
            refresh_interval (int, optional): The refresh interval in seconds,
                used as the default TTL of the in-memory cache.
                Defaults to 86400. (24 hours)
            always_read_secret_manager (bool, optional): Whether to always read
                from the secret manager, skipping the cache lookup.
                Defaults to False, since we do want to cache values.
        """
        super().__init__()
        # The project ID used to be hard-coded. It can now be overridden with
        # the GOOGLE_SECRET_MANAGER_PROJECT_ID environment variable; the old
        # value is kept as the fallback so existing deployments are unaffected.
        # NOTE(review): consider making the project ID mandatory instead of
        # shipping a default.
        self.PROJECT_ID = os.environ.get(
            "GOOGLE_SECRET_MANAGER_PROJECT_ID", "adroit-crow-413218"
        )
        self.sync_httpx_client = _get_httpx_client()
        litellm.secret_manager_client = self
        litellm._key_management_system = KeyManagementSystem.GOOGLE_SECRET_MANAGER
        self.cache = InMemoryCache(
            default_ttl=refresh_interval
        )  # store in memory for `refresh_interval` seconds (1 day by default)
        # Honor the constructor argument (it was previously ignored and the
        # attribute was always set to False).
        self.always_read_secret_manager = always_read_secret_manager

    def get_secret_from_google_secret_manager(self, secret_name: str) -> Optional[str]:
        """
        Retrieve a secret from Google Secret Manager or cache.

        Args:
            secret_name (str): The name of the secret.

        Returns:
            str: The secret value. ``None`` is returned only when a previous
            lookup of the same name failed and that miss is still cached.

        Raises:
            ValueError: If the request does not return HTTP 200, or the
                response carries no ``payload.data`` field.
        """
        if self.always_read_secret_manager is not True:
            cached_secret = self.cache.get_cache(secret_name)
            if cached_secret is not None:
                return cached_secret
            if secret_name in self.cache.cache_dict:
                # A cached None means "known to be missing": skip the request.
                return cached_secret

        _secret_name = (
            f"projects/{self.PROJECT_ID}/secrets/{secret_name}/versions/latest"
        )
        headers = self.sync_construct_request_headers()
        url = f"https://secretmanager.googleapis.com/v1/{_secret_name}:access"

        # Send the GET request to retrieve the secret
        response = self.sync_httpx_client.get(url=url, headers=headers)

        if response.status_code != 200:
            verbose_logger.error(
                "Google Secret Manager retrieval error: %s", str(response.text)
            )
            self.cache.set_cache(
                secret_name, None
            )  # Cache that the secret was not found
            raise ValueError(
                f"secret {secret_name} not found in Google Secret Manager. Error: {response.text}"
            )

        verbose_logger.debug(
            "Google Secret Manager retrieval response status code: %s",
            response.status_code,
        )

        # Parse the JSON response and return the secret value
        secret_data = response.json()
        _base64_encoded_value = secret_data.get("payload", {}).get("data")

        # decode the base64 encoded value
        if _base64_encoded_value is not None:
            _decoded_value = base64.b64decode(_base64_encoded_value).decode("utf-8")
            self.cache.set_cache(
                secret_name, _decoded_value
            )  # Cache the retrieved secret
            return _decoded_value

        self.cache.set_cache(secret_name, None)  # Cache that the secret was not found
        raise ValueError(f"secret {secret_name} not found in Google Secret Manager")
|
|
@ -4,10 +4,10 @@ from functools import partial
|
|||
from typing import Any, Coroutine, Dict, List, Literal, Optional, Union
|
||||
|
||||
import litellm
|
||||
from litellm import get_secret
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.llms.cohere.rerank import CohereRerank
|
||||
from litellm.llms.togetherai.rerank import TogetherAIRerank
|
||||
from litellm.secret_managers.main import get_secret
|
||||
from litellm.types.router import *
|
||||
from litellm.utils import supports_httpx_timeout
|
||||
|
||||
|
|
3
litellm/secret_managers/Readme.md
Normal file
3
litellm/secret_managers/Readme.md
Normal file
|
@ -0,0 +1,3 @@
|
|||
## Supported Secret Managers to read credentials from
|
||||
|
||||
Example: read `OPENAI_API_KEY` and `AZURE_API_KEY` from a secret manager.
|
276
litellm/secret_managers/main.py
Normal file
276
litellm/secret_managers/main.py
Normal file
|
@ -0,0 +1,276 @@
|
|||
import ast
|
||||
import base64
|
||||
import binascii
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import httpx
|
||||
from dotenv import load_dotenv
|
||||
|
||||
import litellm
|
||||
from litellm._logging import print_verbose, verbose_logger
|
||||
from litellm.caching import DualCache
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
from litellm.proxy._types import KeyManagementSystem
|
||||
|
||||
oidc_cache = DualCache()
|
||||
|
||||
|
||||
######### Secret Manager ############################
|
||||
# checks if user has passed in a secret manager client
|
||||
# if passed in then checks the secret there
|
||||
def _is_base64(s):
|
||||
try:
|
||||
return base64.b64encode(base64.b64decode(s)).decode() == s
|
||||
except binascii.Error:
|
||||
return False
|
||||
|
||||
|
||||
def get_secret(
    secret_name: str,
    default_value: Optional[Union[str, bool]] = None,
):
    """Resolve a secret by name from OIDC, a secret manager, or the environment.

    Resolution order:

    1. A leading ``os.environ/`` prefix is stripped from ``secret_name``.
    2. Names of the form ``oidc/<provider>/<audience>`` are resolved to an
       OIDC token and returned immediately. Supported providers: ``google``,
       ``circleci``, ``circleci_v2``, ``github``, ``azure``, ``file``,
       ``env`` and ``env_path``.
    3. If ``litellm.secret_manager_client`` is set, the secret is read through
       the configured key management system. Any exception in that lookup is
       logged and the value falls back to ``os.getenv(secret_name)``.
    4. Otherwise the value is read from ``os.environ``.

    A value that is the Python literal of a boolean (``"True"`` / ``"False"``)
    is returned as ``bool``; everything else is returned unchanged.

    Args:
        secret_name: Name of the secret, optionally prefixed as described above.
        default_value: Fallback used on some failure paths (see the inline
            notes); it is NOT applied when the secret is simply absent.

    Returns:
        The secret as ``str`` or ``bool``, or ``None`` when it is not found.

    Raises:
        ValueError: For an unsupported OIDC provider, a missing OIDC
            environment variable, or a failed OIDC token request.
    """
    key_management_system = litellm._key_management_system
    key_management_settings = litellm._key_management_settings

    if secret_name.startswith("os.environ/"):
        # Strip only the leading prefix; str.replace would also remove any
        # later occurrence of the same text inside the name.
        secret_name = secret_name[len("os.environ/") :]

    # Example: oidc/google/https://bedrock-runtime.us-east-1.amazonaws.com/model/stability.stable-diffusion-xl-v1/invoke
    if secret_name.startswith("oidc/"):
        # Strip only the leading prefix so an audience that itself contains
        # "oidc/" (e.g. inside a URL path) is left intact.
        secret_name_split = secret_name[len("oidc/") :]
        oidc_provider, oidc_aud = secret_name_split.split("/", 1)
        # TODO: Add caching for HTTP requests
        if oidc_provider == "google":
            oidc_token = oidc_cache.get_cache(key=secret_name)
            if oidc_token is not None:
                return oidc_token

            oidc_client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
            # https://cloud.google.com/compute/docs/instances/verifying-instance-identity#request_signature
            response = oidc_client.get(
                "http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/identity",
                params={"audience": oidc_aud},
                headers={"Metadata-Flavor": "Google"},
            )
            if response.status_code == 200:
                oidc_token = response.text
                oidc_cache.set_cache(key=secret_name, value=oidc_token, ttl=3600 - 60)
                return oidc_token
            else:
                raise ValueError("Google OIDC provider failed")
        elif oidc_provider == "circleci":
            # https://circleci.com/docs/openid-connect-tokens/
            env_secret = os.getenv("CIRCLE_OIDC_TOKEN")
            if env_secret is None:
                raise ValueError("CIRCLE_OIDC_TOKEN not found in environment")
            return env_secret
        elif oidc_provider == "circleci_v2":
            # https://circleci.com/docs/openid-connect-tokens/
            env_secret = os.getenv("CIRCLE_OIDC_TOKEN_V2")
            if env_secret is None:
                raise ValueError("CIRCLE_OIDC_TOKEN_V2 not found in environment")
            return env_secret
        elif oidc_provider == "github":
            # https://docs.github.com/en/actions/deployment/security-hardening-your-deployments/configuring-openid-connect-in-cloud-providers#using-custom-actions
            actions_id_token_request_url = os.getenv("ACTIONS_ID_TOKEN_REQUEST_URL")
            actions_id_token_request_token = os.getenv("ACTIONS_ID_TOKEN_REQUEST_TOKEN")
            if (
                actions_id_token_request_url is None
                or actions_id_token_request_token is None
            ):
                raise ValueError(
                    "ACTIONS_ID_TOKEN_REQUEST_URL or ACTIONS_ID_TOKEN_REQUEST_TOKEN not found in environment"
                )

            oidc_token = oidc_cache.get_cache(key=secret_name)
            if oidc_token is not None:
                return oidc_token

            oidc_client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
            response = oidc_client.get(
                actions_id_token_request_url,
                params={"audience": oidc_aud},
                headers={
                    "Authorization": f"Bearer {actions_id_token_request_token}",
                    "Accept": "application/json; api-version=2.0",
                },
            )
            if response.status_code == 200:
                # The token is the "value" field of the JSON body. Indexing
                # response.text (a str) with "value" always raised TypeError.
                oidc_token = response.json().get("value")
                if oidc_token is None:
                    raise ValueError("Github OIDC provider failed")
                oidc_cache.set_cache(key=secret_name, value=oidc_token, ttl=300 - 5)
                return oidc_token
            else:
                raise ValueError("Github OIDC provider failed")
        elif oidc_provider == "azure":
            # https://azure.github.io/azure-workload-identity/docs/quick-start.html
            azure_federated_token_file = os.getenv("AZURE_FEDERATED_TOKEN_FILE")
            if azure_federated_token_file is None:
                raise ValueError("AZURE_FEDERATED_TOKEN_FILE not found in environment")
            with open(azure_federated_token_file, "r") as f:
                oidc_token = f.read()
                return oidc_token
        elif oidc_provider == "file":
            # Load token from a file
            with open(oidc_aud, "r") as f:
                oidc_token = f.read()
                return oidc_token
        elif oidc_provider == "env":
            # Load token directly from an environment variable
            oidc_token = os.getenv(oidc_aud)
            if oidc_token is None:
                raise ValueError(f"Environment variable {oidc_aud} not found")
            return oidc_token
        elif oidc_provider == "env_path":
            # Load token from a file path specified in an environment variable
            token_file_path = os.getenv(oidc_aud)
            if token_file_path is None:
                raise ValueError(f"Environment variable {oidc_aud} not found")
            with open(token_file_path, "r") as f:
                oidc_token = f.read()
                return oidc_token
        else:
            raise ValueError("Unsupported OIDC provider")

    try:
        if litellm.secret_manager_client is not None:
            try:
                client = litellm.secret_manager_client
                key_manager = "local"
                if key_management_system is not None:
                    key_manager = key_management_system.value

                if key_management_settings is not None:
                    if (
                        secret_name not in key_management_settings.hosted_keys
                    ):  # allow user to specify which keys to check in hosted key manager
                        key_manager = "local"

                if (
                    key_manager == KeyManagementSystem.AZURE_KEY_VAULT.value
                    or type(client).__module__ + "." + type(client).__name__
                    == "azure.keyvault.secrets._client.SecretClient"
                ):  # support Azure Secret Client - from azure.keyvault.secrets import SecretClient
                    secret = client.get_secret(secret_name).value
                elif (
                    key_manager == KeyManagementSystem.GOOGLE_KMS.value
                    or client.__class__.__name__ == "KeyManagementServiceClient"
                ):
                    encrypted_secret: Any = os.getenv(secret_name)
                    if encrypted_secret is None:
                        raise ValueError(
                            "Google KMS requires the encrypted secret to be in the environment!"
                        )
                    if _is_base64(encrypted_secret):  # passed in as encoded b64 string
                        ciphertext = base64.b64decode(encrypted_secret)
                    else:
                        raise ValueError(
                            "Google KMS requires the encrypted secret to be encoded in base64"
                        )  # fix for this vulnerability https://huntr.com/bounties/ae623c2f-b64b-4245-9ed4-f13a0a5824ce
                    response = client.decrypt(
                        request={
                            "name": litellm._google_kms_resource_name,
                            "ciphertext": ciphertext,
                        }
                    )
                    secret = response.plaintext.decode(
                        "utf-8"
                    )  # assumes the original value was encoded with utf-8
                elif key_manager == KeyManagementSystem.AWS_KMS.value:
                    # NOTE(review): the original note here said only tokens
                    # starting with 'aws_kms/' are checked, to avoid latency
                    # from checking all keys. No such prefix check is visible
                    # in this function - presumably it is enforced by the
                    # caller; confirm.
                    encrypted_value = os.getenv(secret_name, None)
                    if encrypted_value is None:
                        raise Exception(
                            "AWS KMS - Encrypted Value of Key={} is None".format(
                                secret_name
                            )
                        )
                    # Decode the base64 encoded ciphertext
                    ciphertext_blob = base64.b64decode(encrypted_value)

                    # Set up the parameters for the decrypt call
                    params = {"CiphertextBlob": ciphertext_blob}
                    # Perform the decryption
                    response = client.decrypt(**params)

                    # Extract and decode the plaintext
                    plaintext = response["Plaintext"]
                    secret = plaintext.decode("utf-8")
                    if isinstance(secret, str):
                        secret = secret.strip()
                elif key_manager == KeyManagementSystem.AWS_SECRET_MANAGER.value:
                    # NOTE(review): the print_verbose calls in this branch and
                    # the next emit secret values into verbose logs.
                    try:
                        get_secret_value_response = client.get_secret_value(
                            SecretId=secret_name
                        )
                        print_verbose(
                            f"get_secret_value_response: {get_secret_value_response}"
                        )
                    except Exception as e:
                        print_verbose(f"An error occurred - {str(e)}")
                        # For a list of exceptions thrown, see
                        # https://docs.aws.amazon.com/secretsmanager/latest/apireference/API_GetSecretValue.html
                        raise e

                    # assume there is 1 secret per secret_name
                    secret_dict = json.loads(get_secret_value_response["SecretString"])
                    print_verbose(f"secret_dict: {secret_dict}")
                    for k, v in secret_dict.items():
                        secret = v
                    print_verbose(f"secret: {secret}")
                # This must be `elif`. A plain `if` here split the dispatch
                # chain, so after a successful Azure / Google KMS / AWS lookup
                # the trailing `else` (Infisical) branch also ran and clobbered
                # the result.
                elif key_manager == KeyManagementSystem.GOOGLE_SECRET_MANAGER.value:
                    try:
                        secret = client.get_secret_from_google_secret_manager(
                            secret_name
                        )
                        print_verbose(f"secret from google secret manager: {secret}")
                        if secret is None:
                            raise ValueError(
                                f"No secret found in Google Secret Manager for {secret_name}"
                            )
                    except Exception as e:
                        print_verbose(f"An error occurred - {str(e)}")
                        raise e
                elif key_manager == "local":
                    secret = os.getenv(secret_name)
                else:  # assume the default is the Infisical client
                    secret = client.get_secret(secret_name).secret_value
            except Exception as e:  # check if it's in os.environ
                verbose_logger.error(
                    f"Defaulting to os.environ value for key={secret_name}. An exception occurred - {str(e)}.\n\n{traceback.format_exc()}"
                )
                secret = os.getenv(secret_name)
            try:
                secret_value_as_bool = ast.literal_eval(secret)
                if isinstance(secret_value_as_bool, bool):
                    return secret_value_as_bool
                else:
                    return secret
            except Exception:
                # Not a Python literal (or None): return the raw value.
                return secret
        else:
            secret = os.environ.get(secret_name)
            try:
                secret_value_as_bool = (
                    ast.literal_eval(secret) if secret is not None else None
                )
                if isinstance(secret_value_as_bool, bool):
                    return secret_value_as_bool
                else:
                    return secret
            except Exception:
                # NOTE(review): this path is reached when the env value is set
                # but is not a Python literal (e.g. "us-east-1"); a non-None
                # default_value then wins over the value that was actually
                # set. Behavior preserved as-is - confirm whether intended.
                if default_value is not None:
                    return default_value
                return secret
    except Exception:
        if default_value is not None:
            return default_value
        else:
            raise
|
|
@ -16,10 +16,10 @@ sys.path.insert(
|
|||
) # Adds the parent directory to the system path
|
||||
import pytest
|
||||
|
||||
from litellm import get_secret
|
||||
from litellm.llms.azure import get_azure_ad_token_from_oidc
|
||||
from litellm.llms.bedrock.chat import BedrockConverseLLM, BedrockLLM
|
||||
from litellm.proxy.secret_managers.aws_secret_manager import load_aws_secret_manager
|
||||
from litellm.secret_managers.main import get_secret
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="AWS Suspended Account")
|
||||
|
|
247
litellm/utils.py
247
litellm/utils.py
|
@ -68,6 +68,7 @@ from litellm.litellm_core_utils.redact_messages import (
|
|||
)
|
||||
from litellm.litellm_core_utils.token_counter import get_modified_max_tokens
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
from litellm.secret_managers.main import get_secret
|
||||
from litellm.types.llms.openai import (
|
||||
AllMessageValues,
|
||||
ChatCompletionNamedToolChoiceParam,
|
||||
|
@ -93,8 +94,6 @@ from litellm.types.utils import (
|
|||
Usage,
|
||||
)
|
||||
|
||||
oidc_cache = DualCache()
|
||||
|
||||
try:
|
||||
# New and recommended way to access resources
|
||||
from importlib import resources
|
||||
|
@ -8662,250 +8661,6 @@ def exception_type(
|
|||
raise raised_exc
|
||||
|
||||
|
||||
######### Secret Manager ############################
|
||||
# checks if user has passed in a secret manager client
|
||||
# if passed in then checks the secret there
|
||||
def _is_base64(s):
|
||||
try:
|
||||
return base64.b64encode(base64.b64decode(s)).decode() == s
|
||||
except binascii.Error:
|
||||
return False
|
||||
|
||||
|
||||
def get_secret(
    secret_name: str,
    default_value: Optional[Union[str, bool]] = None,
):
    """Resolve ``secret_name`` to its value.

    Resolution order:

    1. A leading ``os.environ/`` marker is stripped from the name.
    2. Names starting with ``oidc/<provider>/<audience>`` are resolved to an
       OIDC token for that provider (google, circleci, circleci_v2, github,
       azure, file, env, env_path).
    3. If ``litellm.secret_manager_client`` is set, the secret is read from the
       configured key manager (Azure Key Vault, Google KMS, AWS KMS, AWS
       Secrets Manager, local environment, or the default client). Any failure
       there is logged and the process environment is used instead.
    4. Otherwise the process environment is read directly.

    A value that parses as a Python bool literal (``"True"``/``"False"``) is
    returned as a ``bool``; anything else is returned unchanged.

    Args:
        secret_name: Name of the secret, optionally prefixed with
            ``os.environ/`` and/or ``oidc/<provider>/``.
        default_value: Returned when the lookup raises, or (in the
            environment-only path) when the value cannot be literal-evaluated.

    Returns:
        The secret as ``str`` or ``bool``; ``None`` when it is not found and
        no usable default applies.

    Raises:
        ValueError: For an unsupported OIDC provider, a failed OIDC request,
            or missing OIDC environment variables. These are raised before
            ``default_value`` handling applies.
        Exception: Re-raised from the lookup when ``default_value`` is None.
    """
    key_management_system = litellm._key_management_system
    key_management_settings = litellm._key_management_settings

    # Strip only the leading marker; a blanket replace() would also delete
    # later occurrences inside the name.
    if secret_name.startswith("os.environ/"):
        secret_name = secret_name[len("os.environ/") :]

    # Example: oidc/google/https://bedrock-runtime.us-east-1.amazonaws.com/model/stability.stable-diffusion-xl-v1/invoke
    if secret_name.startswith("oidc/"):
        # Prefix-only strip so an audience URL that itself contains "oidc/"
        # is passed through intact.
        oidc_provider, oidc_aud = secret_name[len("oidc/") :].split("/", 1)
        # TODO: Add caching for HTTP requests
        if oidc_provider == "google":
            oidc_token = oidc_cache.get_cache(key=secret_name)
            if oidc_token is not None:
                return oidc_token

            oidc_client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
            # https://cloud.google.com/compute/docs/instances/verifying-instance-identity#request_signature
            response = oidc_client.get(
                "http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/identity",
                params={"audience": oidc_aud},
                headers={"Metadata-Flavor": "Google"},
            )
            if response.status_code != 200:
                raise ValueError("Google OIDC provider failed")
            oidc_token = response.text
            oidc_cache.set_cache(key=secret_name, value=oidc_token, ttl=3600 - 60)
            return oidc_token
        elif oidc_provider == "circleci":
            # https://circleci.com/docs/openid-connect-tokens/
            env_secret = os.getenv("CIRCLE_OIDC_TOKEN")
            if env_secret is None:
                raise ValueError("CIRCLE_OIDC_TOKEN not found in environment")
            return env_secret
        elif oidc_provider == "circleci_v2":
            # https://circleci.com/docs/openid-connect-tokens/
            env_secret = os.getenv("CIRCLE_OIDC_TOKEN_V2")
            if env_secret is None:
                raise ValueError("CIRCLE_OIDC_TOKEN_V2 not found in environment")
            return env_secret
        elif oidc_provider == "github":
            # https://docs.github.com/en/actions/deployment/security-hardening-your-deployments/configuring-openid-connect-in-cloud-providers#using-custom-actions
            actions_id_token_request_url = os.getenv("ACTIONS_ID_TOKEN_REQUEST_URL")
            actions_id_token_request_token = os.getenv("ACTIONS_ID_TOKEN_REQUEST_TOKEN")
            if (
                actions_id_token_request_url is None
                or actions_id_token_request_token is None
            ):
                raise ValueError(
                    "ACTIONS_ID_TOKEN_REQUEST_URL or ACTIONS_ID_TOKEN_REQUEST_TOKEN not found in environment"
                )

            oidc_token = oidc_cache.get_cache(key=secret_name)
            if oidc_token is not None:
                return oidc_token

            oidc_client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
            response = oidc_client.get(
                actions_id_token_request_url,
                params={"audience": oidc_aud},
                headers={
                    "Authorization": f"Bearer {actions_id_token_request_token}",
                    "Accept": "application/json; api-version=2.0",
                },
            )
            if response.status_code != 200:
                raise ValueError("Github OIDC provider failed")
            # The token is in the "value" field of the JSON body;
            # response.text is a plain str and cannot be indexed by key.
            oidc_token = response.json()["value"]
            oidc_cache.set_cache(key=secret_name, value=oidc_token, ttl=300 - 5)
            return oidc_token
        elif oidc_provider == "azure":
            # https://azure.github.io/azure-workload-identity/docs/quick-start.html
            azure_federated_token_file = os.getenv("AZURE_FEDERATED_TOKEN_FILE")
            if azure_federated_token_file is None:
                raise ValueError("AZURE_FEDERATED_TOKEN_FILE not found in environment")
            with open(azure_federated_token_file, "r") as f:
                return f.read()
        elif oidc_provider == "file":
            # Load token from a file; the "audience" part is the file path
            with open(oidc_aud, "r") as f:
                return f.read()
        elif oidc_provider == "env":
            # Load token directly from an environment variable
            oidc_token = os.getenv(oidc_aud)
            if oidc_token is None:
                raise ValueError(f"Environment variable {oidc_aud} not found")
            return oidc_token
        elif oidc_provider == "env_path":
            # Load token from a file path specified in an environment variable
            token_file_path = os.getenv(oidc_aud)
            if token_file_path is None:
                raise ValueError(f"Environment variable {oidc_aud} not found")
            with open(token_file_path, "r") as f:
                return f.read()
        else:
            raise ValueError("Unsupported OIDC provider")

    try:
        if litellm.secret_manager_client is not None:
            try:
                client = litellm.secret_manager_client
                key_manager = "local"
                if key_management_system is not None:
                    key_manager = key_management_system.value

                if key_management_settings is not None:
                    if (
                        secret_name not in key_management_settings.hosted_keys
                    ):  # allow user to specify which keys to check in hosted key manager
                        key_manager = "local"

                if (
                    key_manager == KeyManagementSystem.AZURE_KEY_VAULT.value
                    or type(client).__module__ + "." + type(client).__name__
                    == "azure.keyvault.secrets._client.SecretClient"
                ):  # support Azure Secret Client - from azure.keyvault.secrets import SecretClient
                    secret = client.get_secret(secret_name).value
                elif (
                    key_manager == KeyManagementSystem.GOOGLE_KMS.value
                    or client.__class__.__name__ == "KeyManagementServiceClient"
                ):
                    encrypted_secret: Any = os.getenv(secret_name)
                    if encrypted_secret is None:
                        raise ValueError(
                            "Google KMS requires the encrypted secret to be in the environment!"
                        )
                    if not _is_base64(encrypted_secret):
                        raise ValueError(
                            "Google KMS requires the encrypted secret to be encoded in base64"
                        )  # fix for this vulnerability https://huntr.com/bounties/ae623c2f-b64b-4245-9ed4-f13a0a5824ce
                    ciphertext = base64.b64decode(encrypted_secret)
                    response = client.decrypt(
                        request={
                            "name": litellm._google_kms_resource_name,
                            "ciphertext": ciphertext,
                        }
                    )
                    secret = response.plaintext.decode(
                        "utf-8"
                    )  # assumes the original value was encoded with utf-8
                elif key_manager == KeyManagementSystem.AWS_KMS.value:
                    # Only check the tokens which start with 'aws_kms/'. This
                    # prevents latency impact caused by checking all keys.
                    encrypted_value = os.getenv(secret_name, None)
                    if encrypted_value is None:
                        raise Exception(
                            "AWS KMS - Encrypted Value of Key={} is None".format(
                                secret_name
                            )
                        )
                    # Decode the base64 encoded ciphertext
                    ciphertext_blob = base64.b64decode(encrypted_value)

                    # Perform the decryption
                    response = client.decrypt(CiphertextBlob=ciphertext_blob)

                    # Extract and decode the plaintext
                    secret = response["Plaintext"].decode("utf-8").strip()
                elif key_manager == KeyManagementSystem.AWS_SECRET_MANAGER.value:
                    try:
                        get_secret_value_response = client.get_secret_value(
                            SecretId=secret_name
                        )
                    except Exception as e:
                        print_verbose(f"An error occurred - {str(e)}")
                        # For a list of exceptions thrown, see
                        # https://docs.aws.amazon.com/secretsmanager/latest/apireference/API_GetSecretValue.html
                        raise

                    # Secret material is deliberately kept out of the logs.
                    print_verbose(
                        f"Fetched secret_name={secret_name} from AWS Secrets Manager"
                    )

                    # assume there is 1 secret per secret_name
                    secret_dict = json.loads(get_secret_value_response["SecretString"])
                    if not secret_dict:
                        # Handled by the except below, which falls back to
                        # the process environment.
                        raise ValueError(
                            f"AWS Secrets Manager returned no values for {secret_name}"
                        )
                    # If several keys are present, the last one wins.
                    secret = list(secret_dict.values())[-1]
                elif key_manager == "local":
                    secret = os.getenv(secret_name)
                else:  # assume the default is the Infisical client
                    secret = client.get_secret(secret_name).secret_value
            except Exception as e:  # check if it's in os.environ
                verbose_logger.error(
                    f"Defaulting to os.environ value for key={secret_name}. An exception occurred - {str(e)}.\n\n{traceback.format_exc()}"
                )
                secret = os.getenv(secret_name)
            try:
                secret_value_as_bool = ast.literal_eval(secret)
            except Exception:
                # Not a Python literal (or None): return the raw value.
                return secret
            if isinstance(secret_value_as_bool, bool):
                return secret_value_as_bool
            return secret
        else:
            secret = os.environ.get(secret_name)
            try:
                secret_value_as_bool = (
                    ast.literal_eval(secret) if secret is not None else None
                )
            except Exception:
                # Value is set but is not a Python literal.
                if default_value is not None:
                    return default_value
                return secret
            if isinstance(secret_value_as_bool, bool):
                return secret_value_as_bool
            return secret
    except Exception:
        if default_value is not None:
            return default_value
        raise
|
||||
|
||||
|
||||
######## Streaming Class ############################
|
||||
# wraps the completion stream to return the correct format for the model
|
||||
# replicate/anthropic/cohere
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue