forked from phoenix/litellm-mirror
Merge pull request #5489 from BerriAI/litellm_Add_secret_managers
[Feat] Add Google Secret Manager Support
This commit is contained in:
commit
19dbfff620
32 changed files with 633 additions and 327 deletions
|
@ -23,7 +23,7 @@ This covers:
|
|||
- ✅ [Audit Logs with retention policy](./proxy/enterprise#audit-logs)
|
||||
- ✅ [JWT-Auth](../docs/proxy/token_auth.md)
|
||||
- ✅ [Control available public, private routes](./proxy/enterprise#control-available-public-private-routes)
|
||||
- ✅ [[BETA] AWS Key Manager v2 - Key Decryption](./proxy/enterprise#beta-aws-key-manager---key-decryption)
|
||||
- ✅ [**Secret Managers** AWS Key Manager, Google Secret Manager, Azure Key Vault](./secret)
|
||||
- ✅ IP address‑based access control lists
|
||||
- ✅ Track Request IP Address
|
||||
- ✅ [Use LiteLLM keys/authentication on Pass Through Endpoints](./proxy/pass_through#✨-enterprise---use-litellm-keysauthentication-on-pass-through-endpoints)
|
||||
|
|
|
@ -17,7 +17,7 @@ Features:
|
|||
- ✅ [Audit Logs with retention policy](#audit-logs)
|
||||
- ✅ [JWT-Auth](../docs/proxy/token_auth.md)
|
||||
- ✅ [Control available public, private routes](#control-available-public-private-routes)
|
||||
- ✅ [[BETA] AWS Key Manager v2 - Key Decryption](#beta-aws-key-manager---key-decryption)
|
||||
- ✅ [**Secret Managers** AWS Key Manager, Google Secret Manager, Azure Key Vault](../secret)
|
||||
- ✅ IP address‑based access control lists
|
||||
- ✅ Track Request IP Address
|
||||
- ✅ [Use LiteLLM keys/authentication on Pass Through Endpoints](pass_through#✨-enterprise---use-litellm-keysauthentication-on-pass-through-endpoints)
|
||||
|
|
|
@ -1,9 +1,22 @@
|
|||
# Secret Manager
|
||||
LiteLLM supports reading secrets from Azure Key Vault and Infisical
|
||||
LiteLLM supports reading secrets from secret managers such as Azure Key Vault and Google Secret Manager.
|
||||
|
||||
- AWS Key Managemenet Service
|
||||
:::info
|
||||
|
||||
✨ **This is an Enterprise Feature**
|
||||
|
||||
[Enterprise Pricing](https://www.litellm.ai/#pricing)
|
||||
|
||||
[Contact us here to get a free trial](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat)
|
||||
|
||||
:::
|
||||
|
||||
## Supported Secret Managers
|
||||
|
||||
- AWS Key Management Service
|
||||
- AWS Secret Manager
|
||||
- [Azure Key Vault](#azure-key-vault)
|
||||
- [Google Secret Manager](#google-secret-manager)
|
||||
- Google Key Management Service
|
||||
- [Infisical Secret Manager](#infisical-secret-manager)
|
||||
- [.env Files](#env-files)
|
||||
|
@ -125,6 +138,45 @@ litellm --config /path/to/config.yaml
|
|||
|
||||
[Quick Test Proxy](./proxy/quick_start#using-litellm-proxy---curl-request-openai-package-langchain-langchain-js)
|
||||
|
||||
## Google Secret Manager
|
||||
|
||||
Support for [Google Secret Manager](https://cloud.google.com/security/products/secret-manager)
|
||||
|
||||
|
||||
1. Save Google Secret Manager details in your environment
|
||||
|
||||
```shell
|
||||
GOOGLE_SECRET_MANAGER_PROJECT_ID="your-project-id-on-gcp" # example: adroit-crow-413218
|
||||
```
|
||||
|
||||
Optional Params
|
||||
|
||||
```shell
|
||||
export GOOGLE_SECRET_MANAGER_REFRESH_INTERVAL="" # (int) defaults to 86400
|
||||
export GOOGLE_SECRET_MANAGER_ALWAYS_READ_SECRET_MANAGER="" # (str) set to "true" if you want to always read from Google Secret Manager without using in-memory caching. NOT RECOMMENDED in PROD
|
||||
```
|
||||
|
||||
2. Add to proxy config.yaml
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: fake-openai-endpoint
|
||||
litellm_params:
|
||||
model: openai/fake
|
||||
api_base: https://exampleopenaiendpoint-production.up.railway.app/
|
||||
api_key: os.environ/OPENAI_API_KEY # this will be read from Google Secret Manager
|
||||
|
||||
general_settings:
|
||||
key_management_system: "google_secret_manager"
|
||||
```
|
||||
|
||||
You can now test this by starting your proxy:
|
||||
```bash
|
||||
litellm --config /path/to/config.yaml
|
||||
```
|
||||
|
||||
[Quick Test Proxy](./proxy/quick_start#using-litellm-proxy---curl-request-openai-package-langchain-langchain-js)
|
||||
|
||||
|
||||
## Google Key Management Service
|
||||
|
||||
Use encrypted keys from Google KMS on the proxy
|
||||
|
|
|
@ -834,7 +834,6 @@ from .utils import (
|
|||
decode,
|
||||
_calculate_retry_after,
|
||||
_should_retry,
|
||||
get_secret,
|
||||
get_supported_openai_params,
|
||||
get_api_base,
|
||||
get_first_chars_messages,
|
||||
|
|
|
@ -22,6 +22,7 @@ import litellm
|
|||
from litellm import client
|
||||
from litellm.llms.azure import AzureBatchesAPI
|
||||
from litellm.llms.openai import OpenAIBatchesAPI
|
||||
from litellm.secret_managers.main import get_secret
|
||||
from litellm.types.llms.openai import (
|
||||
Batch,
|
||||
CancelBatchRequest,
|
||||
|
@ -34,7 +35,7 @@ from litellm.types.llms.openai import (
|
|||
RetrieveBatchRequest,
|
||||
)
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.utils import get_secret, supports_httpx_timeout
|
||||
from litellm.utils import supports_httpx_timeout
|
||||
|
||||
####### ENVIRONMENT VARIABLES ###################
|
||||
openai_batches_instance = OpenAIBatchesAPI()
|
||||
|
|
|
@ -17,7 +17,6 @@ from typing import Any, Coroutine, Dict, Literal, Optional, Union
|
|||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm import get_secret
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.llms.fine_tuning_apis.azure import AzureOpenAIFineTuningAPI
|
||||
from litellm.llms.fine_tuning_apis.openai import (
|
||||
|
@ -26,6 +25,7 @@ from litellm.llms.fine_tuning_apis.openai import (
|
|||
OpenAIFineTuningAPI,
|
||||
)
|
||||
from litellm.llms.fine_tuning_apis.vertex_ai import VertexFineTuningAPI
|
||||
from litellm.secret_managers.main import get_secret
|
||||
from litellm.types.llms.openai import Hyperparameters
|
||||
from litellm.types.router import *
|
||||
from litellm.utils import supports_httpx_timeout
|
||||
|
|
|
@ -52,6 +52,27 @@ class GCSBucketBase(CustomLogger):
|
|||
|
||||
return headers
|
||||
|
||||
def sync_construct_request_headers(self) -> Dict[str, str]:
    """Build the HTTP headers for a synchronous, authenticated request.

    A token is obtained from ``vertex_chat_completion._get_token_and_url``
    using ``self.path_service_account_json`` as the credentials, and is
    returned as a ``Bearer`` ``Authorization`` header together with a JSON
    ``Content-Type`` header.

    Returns:
        Dict[str, str]: the ``Authorization`` and ``Content-Type`` headers.
    """
    # Imported inside the method, as in the original implementation.
    from litellm import vertex_chat_completion

    # Only the token is needed here; the URL half of the tuple is discarded.
    auth_header, _ = vertex_chat_completion._get_token_and_url(
        model="gcs-bucket",
        vertex_credentials=self.path_service_account_json,
        vertex_project=None,
        vertex_location=None,
        gemini_api_key=None,
        stream=None,
        custom_llm_provider="vertex_ai",
        api_base=None,
    )

    # Do not write the bearer token itself to the log: anyone with access to
    # debug logs could replay it against Google APIs until it expires.
    verbose_logger.debug("constructed auth_header (token redacted)")

    return {
        "Authorization": f"Bearer {auth_header}",
        "Content-Type": "application/json",
    }
|
||||
|
||||
async def download_gcs_object(self, object_name):
|
||||
"""
|
||||
Download an object from GCS.
|
||||
|
|
|
@ -5,7 +5,7 @@ import httpx
|
|||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.caching import DualCache, InMemoryCache
|
||||
from litellm.utils import get_secret
|
||||
from litellm.secret_managers.main import get_secret
|
||||
|
||||
from .base import BaseLLM
|
||||
|
||||
|
|
|
@ -10,7 +10,7 @@ from typing import List, Optional, Union
|
|||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm import get_secret
|
||||
from litellm.secret_managers.main import get_secret
|
||||
|
||||
|
||||
class BedrockError(Exception):
|
||||
|
|
|
@ -11,7 +11,6 @@ from typing import Any, Callable, List, Literal, Optional, Tuple, Union
|
|||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm import get_secret
|
||||
from litellm.llms.cohere.embed import embedding as cohere_embedding
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
|
@ -19,6 +18,7 @@ from litellm.llms.custom_httpx.http_handler import (
|
|||
_get_async_httpx_client,
|
||||
_get_httpx_client,
|
||||
)
|
||||
from litellm.secret_managers.main import get_secret
|
||||
from litellm.types.llms.bedrock import AmazonEmbeddingRequest, CohereEmbeddingRequest
|
||||
from litellm.types.utils import Embedding, EmbeddingResponse, Usage
|
||||
|
||||
|
|
|
@ -1063,6 +1063,7 @@ class KeyManagementSystem(enum.Enum):
|
|||
GOOGLE_KMS = "google_kms"
|
||||
AZURE_KEY_VAULT = "azure_key_vault"
|
||||
AWS_SECRET_MANAGER = "aws_secret_manager"
|
||||
GOOGLE_SECRET_MANAGER = "google_secret_manager"
|
||||
LOCAL = "local"
|
||||
AWS_KMS = "aws_kms"
|
||||
|
||||
|
|
|
@ -14,7 +14,7 @@ def init_rds_client(
|
|||
aws_web_identity_token: Optional[str] = None,
|
||||
timeout: Optional[Union[float, httpx.Timeout]] = None,
|
||||
):
|
||||
from litellm import get_secret
|
||||
from litellm.secret_managers.main import get_secret
|
||||
|
||||
# check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client
|
||||
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
|
||||
|
|
|
@ -1,16 +1,17 @@
|
|||
from litellm.proxy.db.base_client import CustomDB
|
||||
from litellm.proxy._types import (
|
||||
DynamoDBArgs,
|
||||
LiteLLM_VerificationToken,
|
||||
LiteLLM_Config,
|
||||
LiteLLM_UserTable,
|
||||
)
|
||||
from litellm.proxy.utils import hash_token
|
||||
from litellm import get_secret
|
||||
from typing import Any, List, Literal, Optional, Union
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Any, List, Literal, Optional, Union
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import (
|
||||
DynamoDBArgs,
|
||||
LiteLLM_Config,
|
||||
LiteLLM_UserTable,
|
||||
LiteLLM_VerificationToken,
|
||||
)
|
||||
from litellm.proxy.db.base_client import CustomDB
|
||||
from litellm.proxy.utils import hash_token
|
||||
from litellm.secret_managers.main import get_secret
|
||||
|
||||
|
||||
class DynamoDBWrapper(CustomDB):
|
||||
|
@ -21,19 +22,19 @@ class DynamoDBWrapper(CustomDB):
|
|||
def __init__(self, database_arguments: DynamoDBArgs):
|
||||
from aiodynamo.client import Client
|
||||
from aiodynamo.credentials import Credentials, StaticCredentials
|
||||
from aiodynamo.expressions import F, UpdateExpression, Value
|
||||
from aiodynamo.http.aiohttp import AIOHTTP
|
||||
from aiodynamo.http.httpx import HTTPX
|
||||
from aiodynamo.models import (
|
||||
Throughput,
|
||||
KeySchema,
|
||||
KeySpec,
|
||||
KeyType,
|
||||
PayPerRequest,
|
||||
ReturnValues,
|
||||
Throughput,
|
||||
)
|
||||
from yarl import URL
|
||||
from aiodynamo.expressions import UpdateExpression, F, Value
|
||||
from aiodynamo.models import ReturnValues
|
||||
from aiodynamo.http.aiohttp import AIOHTTP
|
||||
from aiohttp import ClientSession
|
||||
from yarl import URL
|
||||
|
||||
self.throughput_type = None
|
||||
if database_arguments.billing_mode == "PAY_PER_REQUEST":
|
||||
|
@ -59,7 +60,9 @@ class DynamoDBWrapper(CustomDB):
|
|||
verbose_proxy_logger.debug(
|
||||
f"DynamoDB: setting env vars based on arn={self.database_arguments.aws_role_name}"
|
||||
)
|
||||
import boto3, os
|
||||
import os
|
||||
|
||||
import boto3
|
||||
|
||||
sts_client = boto3.client("sts")
|
||||
|
||||
|
@ -92,22 +95,22 @@ class DynamoDBWrapper(CustomDB):
|
|||
"""
|
||||
Connect to DB, and creating / updating any tables
|
||||
"""
|
||||
import aiohttp
|
||||
from aiodynamo.client import Client
|
||||
from aiodynamo.credentials import Credentials, StaticCredentials
|
||||
from aiodynamo.expressions import F, UpdateExpression, Value
|
||||
from aiodynamo.http.aiohttp import AIOHTTP
|
||||
from aiodynamo.http.httpx import HTTPX
|
||||
from aiodynamo.models import (
|
||||
Throughput,
|
||||
KeySchema,
|
||||
KeySpec,
|
||||
KeyType,
|
||||
PayPerRequest,
|
||||
ReturnValues,
|
||||
Throughput,
|
||||
)
|
||||
from yarl import URL
|
||||
from aiodynamo.expressions import UpdateExpression, F, Value
|
||||
from aiodynamo.models import ReturnValues
|
||||
from aiodynamo.http.aiohttp import AIOHTTP
|
||||
from aiohttp import ClientSession
|
||||
import aiohttp
|
||||
from yarl import URL
|
||||
|
||||
verbose_proxy_logger.debug("DynamoDB Wrapper - Attempting to connect")
|
||||
self.set_env_vars_based_on_arn()
|
||||
|
@ -192,22 +195,22 @@ class DynamoDBWrapper(CustomDB):
|
|||
async def insert_data(
|
||||
self, value: Any, table_name: Literal["user", "key", "config", "spend"]
|
||||
):
|
||||
import aiohttp
|
||||
from aiodynamo.client import Client
|
||||
from aiodynamo.credentials import Credentials, StaticCredentials
|
||||
from aiodynamo.expressions import F, UpdateExpression, Value
|
||||
from aiodynamo.http.aiohttp import AIOHTTP
|
||||
from aiodynamo.http.httpx import HTTPX
|
||||
from aiodynamo.models import (
|
||||
Throughput,
|
||||
KeySchema,
|
||||
KeySpec,
|
||||
KeyType,
|
||||
PayPerRequest,
|
||||
ReturnValues,
|
||||
Throughput,
|
||||
)
|
||||
from yarl import URL
|
||||
from aiodynamo.expressions import UpdateExpression, F, Value
|
||||
from aiodynamo.models import ReturnValues
|
||||
from aiodynamo.http.aiohttp import AIOHTTP
|
||||
from aiohttp import ClientSession
|
||||
import aiohttp
|
||||
from yarl import URL
|
||||
|
||||
self.set_env_vars_based_on_arn()
|
||||
|
||||
|
@ -237,22 +240,22 @@ class DynamoDBWrapper(CustomDB):
|
|||
return await table.put_item(item=value, return_values=ReturnValues.all_old)
|
||||
|
||||
async def get_data(self, key: str, table_name: Literal["user", "key", "config"]):
|
||||
import aiohttp
|
||||
from aiodynamo.client import Client
|
||||
from aiodynamo.credentials import Credentials, StaticCredentials
|
||||
from aiodynamo.expressions import F, UpdateExpression, Value
|
||||
from aiodynamo.http.aiohttp import AIOHTTP
|
||||
from aiodynamo.http.httpx import HTTPX
|
||||
from aiodynamo.models import (
|
||||
Throughput,
|
||||
KeySchema,
|
||||
KeySpec,
|
||||
KeyType,
|
||||
PayPerRequest,
|
||||
ReturnValues,
|
||||
Throughput,
|
||||
)
|
||||
from yarl import URL
|
||||
from aiodynamo.expressions import UpdateExpression, F, Value
|
||||
from aiodynamo.models import ReturnValues
|
||||
from aiodynamo.http.aiohttp import AIOHTTP
|
||||
from aiohttp import ClientSession
|
||||
import aiohttp
|
||||
from yarl import URL
|
||||
|
||||
self.set_env_vars_based_on_arn()
|
||||
|
||||
|
@ -311,22 +314,22 @@ class DynamoDBWrapper(CustomDB):
|
|||
self, key: str, value: dict, table_name: Literal["user", "key", "config"]
|
||||
):
|
||||
self.set_env_vars_based_on_arn()
|
||||
import aiohttp
|
||||
from aiodynamo.client import Client
|
||||
from aiodynamo.credentials import Credentials, StaticCredentials
|
||||
from aiodynamo.expressions import F, UpdateExpression, Value
|
||||
from aiodynamo.http.aiohttp import AIOHTTP
|
||||
from aiodynamo.http.httpx import HTTPX
|
||||
from aiodynamo.models import (
|
||||
Throughput,
|
||||
KeySchema,
|
||||
KeySpec,
|
||||
KeyType,
|
||||
PayPerRequest,
|
||||
ReturnValues,
|
||||
Throughput,
|
||||
)
|
||||
from yarl import URL
|
||||
from aiodynamo.expressions import UpdateExpression, F, Value
|
||||
from aiodynamo.models import ReturnValues
|
||||
from aiodynamo.http.aiohttp import AIOHTTP
|
||||
from aiohttp import ClientSession
|
||||
import aiohttp
|
||||
from yarl import URL
|
||||
|
||||
if self.database_arguments.ssl_verify == False:
|
||||
client_session = ClientSession(connector=aiohttp.TCPConnector(ssl=False))
|
||||
|
|
|
@ -24,7 +24,6 @@ import httpx
|
|||
from fastapi import HTTPException
|
||||
|
||||
import litellm
|
||||
from litellm import get_secret
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.caching import DualCache
|
||||
from litellm.integrations.custom_guardrail import CustomGuardrail
|
||||
|
@ -38,6 +37,7 @@ from litellm.llms.custom_httpx.http_handler import (
|
|||
)
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
|
||||
from litellm.secret_managers.main import get_secret
|
||||
from litellm.types.guardrails import (
|
||||
BedrockContentItem,
|
||||
BedrockRequest,
|
||||
|
|
|
@ -19,12 +19,12 @@ import httpx
|
|||
from fastapi import HTTPException
|
||||
|
||||
import litellm
|
||||
from litellm import get_secret
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.integrations.custom_guardrail import CustomGuardrail
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
|
||||
from litellm.secret_managers.main import get_secret
|
||||
from litellm.types.guardrails import (
|
||||
GuardrailItem,
|
||||
LakeraCategoryThresholds,
|
||||
|
|
|
@ -9,7 +9,7 @@ import time
|
|||
sys.path.insert(
|
||||
0, os.path.abspath("./")
|
||||
) # Adds the parent directory to the system path
|
||||
from litellm.proxy.secret_managers.aws_secret_manager import decrypt_env_var
|
||||
from litellm.secret_managers.aws_secret_manager import decrypt_env_var
|
||||
|
||||
if os.getenv("USE_AWS_KMS", None) is not None and os.getenv("USE_AWS_KMS") == "True":
|
||||
## V2 IMPLEMENTATION OF AWS KMS - USER WANTS TO DECRYPT MULTIPLE KEYS IN THEIR ENV
|
||||
|
|
|
@ -475,7 +475,7 @@ def run_server(
|
|||
|
||||
### DECRYPT ENV VAR ###
|
||||
|
||||
from litellm.proxy.secret_managers.aws_secret_manager import decrypt_env_var
|
||||
from litellm.secret_managers.aws_secret_manager import decrypt_env_var
|
||||
|
||||
if (
|
||||
os.getenv("USE_AWS_KMS", None) is not None
|
||||
|
@ -544,6 +544,15 @@ def run_server(
|
|||
load_aws_secret_manager(use_aws_secret_manager=True)
|
||||
elif key_management_system == KeyManagementSystem.AWS_KMS.value:
|
||||
load_aws_kms(use_aws_kms=True)
|
||||
elif (
|
||||
key_management_system
|
||||
== KeyManagementSystem.GOOGLE_SECRET_MANAGER.value
|
||||
):
|
||||
from litellm.secret_managers.google_secret_manager import (
|
||||
GoogleSecretManager,
|
||||
)
|
||||
|
||||
GoogleSecretManager()
|
||||
else:
|
||||
raise ValueError("Invalid Key Management System selected")
|
||||
key_management_settings = general_settings.get(
|
||||
|
@ -598,7 +607,7 @@ def run_server(
|
|||
or os.getenv("DIRECT_URL", None) is not None
|
||||
):
|
||||
try:
|
||||
from litellm import get_secret
|
||||
from litellm.secret_managers.main import get_secret
|
||||
|
||||
if os.getenv("DATABASE_URL", None) is not None:
|
||||
### add connection pool + pool timeout args
|
||||
|
|
|
@ -1,15 +1,20 @@
|
|||
model_list:
|
||||
- model_name: fake-openai-endpoint
|
||||
litellm_params:
|
||||
model: openai/fake
|
||||
api_key: fake-key
|
||||
model_info:
|
||||
id: "team-a-model" # used for identifying model in response headers
|
||||
- model_name: fake-openai-endpoint
|
||||
litellm_params:
|
||||
model: openai/fake
|
||||
api_base: https://exampleopenaiendpoint-production.up.railway.app/
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
- model_name: gpt-3.5-turbo-end-user-test
|
||||
litellm_params:
|
||||
model: azure/chatgpt-v-2
|
||||
api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
|
||||
api_version: "2023-05-15"
|
||||
api_key: os.environ/AZURE_API_KEY
|
||||
|
||||
litellm_settings:
|
||||
success_callback: ["prometheus"]
|
||||
failure_callback: ["prometheus"]
|
||||
|
||||
general_settings:
|
||||
master_key: sk-1234
|
||||
custom_auth: example_config_yaml.custom_auth_basic.user_api_key_auth
|
||||
general_settings:
|
||||
master_key: sk-1234
|
||||
key_management_system: "google_secret_manager"
|
|
@ -212,11 +212,6 @@ from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
|
|||
)
|
||||
from litellm.proxy.rerank_endpoints.endpoints import router as rerank_router
|
||||
from litellm.proxy.route_llm_request import route_request
|
||||
from litellm.proxy.secret_managers.aws_secret_manager import (
|
||||
load_aws_kms,
|
||||
load_aws_secret_manager,
|
||||
)
|
||||
from litellm.proxy.secret_managers.google_kms import load_google_kms
|
||||
from litellm.proxy.spend_tracking.spend_management_endpoints import (
|
||||
router as spend_management_router,
|
||||
)
|
||||
|
@ -257,6 +252,11 @@ from litellm.router import (
|
|||
from litellm.router import ModelInfo as RouterModelInfo
|
||||
from litellm.router import updateDeployment
|
||||
from litellm.scheduler import DefaultPriorities, FlowItem, Scheduler
|
||||
from litellm.secret_managers.aws_secret_manager import (
|
||||
load_aws_kms,
|
||||
load_aws_secret_manager,
|
||||
)
|
||||
from litellm.secret_managers.google_kms import load_google_kms
|
||||
from litellm.types.llms.anthropic import (
|
||||
AnthropicMessagesRequest,
|
||||
AnthropicResponse,
|
||||
|
@ -1765,6 +1765,15 @@ class ProxyConfig:
|
|||
load_aws_secret_manager(use_aws_secret_manager=True)
|
||||
elif key_management_system == KeyManagementSystem.AWS_KMS.value:
|
||||
load_aws_kms(use_aws_kms=True)
|
||||
elif (
|
||||
key_management_system
|
||||
== KeyManagementSystem.GOOGLE_SECRET_MANAGER.value
|
||||
):
|
||||
from litellm.secret_managers.google_secret_manager import (
|
||||
GoogleSecretManager,
|
||||
)
|
||||
|
||||
GoogleSecretManager()
|
||||
else:
|
||||
raise ValueError("Invalid Key Management System selected")
|
||||
key_management_settings = general_settings.get(
|
||||
|
|
|
@ -4,10 +4,10 @@ from functools import partial
|
|||
from typing import Any, Coroutine, Dict, List, Literal, Optional, Union
|
||||
|
||||
import litellm
|
||||
from litellm import get_secret
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.llms.cohere.rerank import CohereRerank
|
||||
from litellm.llms.togetherai.rerank import TogetherAIRerank
|
||||
from litellm.secret_managers.main import get_secret
|
||||
from litellm.types.router import *
|
||||
from litellm.utils import supports_httpx_timeout
|
||||
|
||||
|
|
|
@ -9,7 +9,7 @@ import openai
|
|||
import litellm
|
||||
from litellm._logging import verbose_router_logger
|
||||
from litellm.llms.azure import get_azure_ad_token_from_oidc
|
||||
from litellm.proxy.secret_managers.get_azure_ad_token_provider import (
|
||||
from litellm.secret_managers.get_azure_ad_token_provider import (
|
||||
get_azure_ad_token_provider,
|
||||
)
|
||||
from litellm.utils import calculate_max_parallel_requests
|
||||
|
|
3
litellm/secret_managers/Readme.md
Normal file
3
litellm/secret_managers/Readme.md
Normal file
|
@ -0,0 +1,3 @@
|
|||
## Supported Secret Managers to read credentials from
|
||||
|
||||
Example: read `OPENAI_API_KEY` and `AZURE_API_KEY` from a secret manager.
|
|
@ -14,8 +14,7 @@ def get_azure_ad_token_provider() -> Callable[[], str]:
|
|||
Returns:
|
||||
Callable that returns a temporary authentication token.
|
||||
"""
|
||||
from azure.identity import ClientSecretCredential
|
||||
from azure.identity import get_bearer_token_provider
|
||||
from azure.identity import ClientSecretCredential, get_bearer_token_provider
|
||||
|
||||
try:
|
||||
credential = ClientSecretCredential(
|
||||
|
@ -24,7 +23,9 @@ def get_azure_ad_token_provider() -> Callable[[], str]:
|
|||
tenant_id=os.environ["AZURE_TENANT_ID"],
|
||||
)
|
||||
except KeyError as e:
|
||||
raise ValueError("Missing environment variable required by Azure AD workflow.") from e
|
||||
raise ValueError(
|
||||
"Missing environment variable required by Azure AD workflow."
|
||||
) from e
|
||||
|
||||
return get_bearer_token_provider(
|
||||
credential,
|
|
@ -7,8 +7,11 @@ Requires:
|
|||
* `os.environ["GOOGLE_APPLICATION_CREDENTIALS"], os.environ["GOOGLE_KMS_RESOURCE_NAME"]`
|
||||
* `pip install google-cloud-kms`
|
||||
"""
|
||||
import litellm, os
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import litellm
|
||||
from litellm.proxy._types import KeyManagementSystem
|
||||
|
||||
|
116
litellm/secret_managers/google_secret_manager.py
Normal file
116
litellm/secret_managers/google_secret_manager.py
Normal file
|
@ -0,0 +1,116 @@
|
|||
import base64
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.caching import InMemoryCache
|
||||
from litellm.integrations.gcs_bucket_base import GCSBucketBase
|
||||
from litellm.llms.custom_httpx.http_handler import _get_httpx_client
|
||||
from litellm.proxy._types import CommonProxyErrors, KeyManagementSystem
|
||||
|
||||
|
||||
class GoogleSecretManager(GCSBucketBase):
    """Client that reads secrets from Google Secret Manager over its REST API.

    Constructing an instance registers it as ``litellm.secret_manager_client``
    and sets ``litellm._key_management_system`` to
    ``KeyManagementSystem.GOOGLE_SECRET_MANAGER``.
    """

    def __init__(
        self,
        refresh_interval: Optional[int] = 86400,
        always_read_secret_manager: Optional[bool] = False,
    ) -> None:
        """
        Set up the HTTP client, the in-memory cache and the litellm globals.

        Args:
            refresh_interval (int, optional): Cache TTL in seconds. Defaults to
                86400 (24 hours). A non-empty
                ``GOOGLE_SECRET_MANAGER_REFRESH_INTERVAL`` env var takes
                precedence.
            always_read_secret_manager (bool, optional): When True, the cache
                is bypassed on every lookup. Defaults to False. Setting
                ``GOOGLE_SECRET_MANAGER_ALWAYS_READ_SECRET_MANAGER`` to "true"
                (case-insensitive) forces this on.

        Raises:
            ValueError: if ``premium_user`` is not True, or if
                ``GOOGLE_SECRET_MANAGER_PROJECT_ID`` is not set.
        """
        from litellm.proxy.proxy_server import premium_user

        if premium_user is not True:
            raise ValueError(
                f"Google Secret Manager requires an Enterprise License {CommonProxyErrors.not_premium_user.value}"
            )
        super().__init__()

        self.PROJECT_ID = os.environ.get("GOOGLE_SECRET_MANAGER_PROJECT_ID", None)
        if self.PROJECT_ID is None:
            raise ValueError(
                "Google Secret Manager requires a project ID, please set 'GOOGLE_SECRET_MANAGER_PROJECT_ID' in your .env"
            )

        self.sync_httpx_client = _get_httpx_client()

        # Register this instance as the process-wide secret manager.
        litellm.secret_manager_client = self
        litellm._key_management_system = KeyManagementSystem.GOOGLE_SECRET_MANAGER

        # The env var wins over the argument; an empty/falsy setting falls
        # back to the argument value.
        ttl_setting = os.environ.get(
            "GOOGLE_SECRET_MANAGER_REFRESH_INTERVAL", refresh_interval
        )
        if ttl_setting:
            cache_ttl = int(ttl_setting)
        else:
            cache_ttl = refresh_interval
        self.cache = InMemoryCache(default_ttl=cache_ttl)

        env_flag = os.environ.get(
            "GOOGLE_SECRET_MANAGER_ALWAYS_READ_SECRET_MANAGER",
        )
        forced_by_env = bool(env_flag) and env_flag.lower() == "true"
        if forced_by_env:
            self.always_read_secret_manager = True
        else:
            # Caching is the default; reading the secret manager on every
            # request is opt-in only.
            self.always_read_secret_manager = always_read_secret_manager or False

    def get_secret_from_google_secret_manager(self, secret_name: str) -> Optional[str]:
        """
        Look up ``secret_name`` in the cache or in Google Secret Manager.

        Args:
            secret_name (str): The name of the secret.

        Returns:
            Optional[str]: the decoded secret value. ``None`` is returned only
            when an earlier failed lookup for the same name is still cached.

        Raises:
            ValueError: if the API responds with a non-200 status, or the
                response carries no ``payload.data`` field.
        """
        use_cache = self.always_read_secret_manager is not True
        if use_cache:
            hit = self.cache.get_cache(secret_name)
            # A key that is present with a None value is a cached miss, so it
            # is returned as None rather than triggering another request.
            if hit is not None or secret_name in self.cache.cache_dict:
                return hit

        resource_path = (
            f"projects/{self.PROJECT_ID}/secrets/{secret_name}/versions/latest"
        )
        request_headers = self.sync_construct_request_headers()
        endpoint = f"https://secretmanager.googleapis.com/v1/{resource_path}:access"

        response = self.sync_httpx_client.get(url=endpoint, headers=request_headers)

        if response.status_code != 200:
            verbose_logger.error(
                "Google Secret Manager retrieval error: %s", str(response.text)
            )
            # Remember the failure so later lookups are served from the cache.
            self.cache.set_cache(secret_name, None)
            raise ValueError(
                f"secret {secret_name} not found in Google Secret Manager. Error: {response.text}"
            )

        verbose_logger.debug(
            "Google Secret Manager retrieval response status code: %s",
            response.status_code,
        )

        # The secret is read from the base64-encoded payload.data field.
        encoded_payload = response.json().get("payload", {}).get("data")
        if encoded_payload is None:
            self.cache.set_cache(secret_name, None)
            raise ValueError(f"secret {secret_name} not found in Google Secret Manager")

        secret_value = base64.b64decode(encoded_payload).decode("utf-8")
        self.cache.set_cache(secret_name, secret_value)
        return secret_value
|
276
litellm/secret_managers/main.py
Normal file
276
litellm/secret_managers/main.py
Normal file
|
@ -0,0 +1,276 @@
|
|||
import ast
|
||||
import base64
|
||||
import binascii
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import httpx
|
||||
from dotenv import load_dotenv
|
||||
|
||||
import litellm
|
||||
from litellm._logging import print_verbose, verbose_logger
|
||||
from litellm.caching import DualCache
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
from litellm.proxy._types import KeyManagementSystem
|
||||
|
||||
oidc_cache = DualCache()
|
||||
|
||||
|
||||
######### Secret Manager ############################
|
||||
# checks if user has passed in a secret manager client
|
||||
# if passed in then checks the secret there
|
||||
def _is_base64(s):
|
||||
try:
|
||||
return base64.b64encode(base64.b64decode(s)).decode() == s
|
||||
except binascii.Error:
|
||||
return False
|
||||
|
||||
|
||||
def get_secret(
    secret_name: str,
    default_value: Optional[Union[str, bool]] = None,
):
    """Resolve ``secret_name`` to its value.

    Resolution steps, in order:

    1. A leading ``os.environ/`` marker is stripped from the name.
    2. Names of the form ``oidc/<provider>/<audience>`` are resolved to an
       OIDC token and returned directly. Supported providers: ``google``,
       ``circleci``, ``circleci_v2``, ``github``, ``azure``, ``file``,
       ``env`` and ``env_path``. Tokens fetched over HTTP (``google``,
       ``github``) are cached in ``oidc_cache``.
    3. If ``litellm.secret_manager_client`` is set, the secret is read from
       the configured key manager (Azure Key Vault, Google KMS, AWS KMS,
       AWS Secrets Manager, Google Secret Manager, local environment, or -
       as the default - an Infisical-style client). Any failure in that
       lookup is logged and the value is read from ``os.environ`` instead.
    4. Otherwise the value is read from ``os.environ``.

    In steps 3 and 4 a value that parses as a Python ``bool`` literal
    (``"True"`` / ``"False"``) is returned as a ``bool``; everything else is
    returned unchanged.

    Args:
        secret_name: Name of the secret, optionally prefixed with
            ``os.environ/`` or ``oidc/<provider>/``.
        default_value: Returned when resolution in steps 3/4 raises. On the
            plain-environment path (step 4) it is also returned when the
            value cannot be parsed as a Python literal.

    Returns:
        The secret as ``str`` or ``bool``, ``default_value``, or ``None``
        when nothing was found.

    Raises:
        ValueError: For an unsupported OIDC provider, a failed OIDC HTTP
            request, or a missing environment variable required by the
            chosen OIDC provider.
        OSError: If an OIDC token file cannot be opened.
        Exception: Any error from steps 3/4 when ``default_value`` is None.
    """
    key_management_system = litellm._key_management_system
    key_management_settings = litellm._key_management_settings

    if secret_name.startswith("os.environ/"):
        # Strip only the leading marker; str.replace would also delete any
        # later occurrence of the same text inside the name.
        secret_name = secret_name[len("os.environ/") :]

    # Example: oidc/google/https://bedrock-runtime.us-east-1.amazonaws.com/model/stability.stable-diffusion-xl-v1/invoke
    if secret_name.startswith("oidc/"):
        # Slice off the prefix instead of replacing it everywhere, so an
        # audience or path that itself contains "oidc/" stays intact.
        secret_name_split = secret_name[len("oidc/") :]
        oidc_provider, oidc_aud = secret_name_split.split("/", 1)
        # TODO: Add caching for HTTP requests
        if oidc_provider == "google":
            oidc_token = oidc_cache.get_cache(key=secret_name)
            if oidc_token is not None:
                return oidc_token

            oidc_client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
            # https://cloud.google.com/compute/docs/instances/verifying-instance-identity#request_signature
            response = oidc_client.get(
                "http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/identity",
                params={"audience": oidc_aud},
                headers={"Metadata-Flavor": "Google"},
            )
            if response.status_code == 200:
                oidc_token = response.text
                oidc_cache.set_cache(key=secret_name, value=oidc_token, ttl=3600 - 60)
                return oidc_token
            else:
                raise ValueError("Google OIDC provider failed")
        elif oidc_provider == "circleci":
            # https://circleci.com/docs/openid-connect-tokens/
            env_secret = os.getenv("CIRCLE_OIDC_TOKEN")
            if env_secret is None:
                raise ValueError("CIRCLE_OIDC_TOKEN not found in environment")
            return env_secret
        elif oidc_provider == "circleci_v2":
            # https://circleci.com/docs/openid-connect-tokens/
            env_secret = os.getenv("CIRCLE_OIDC_TOKEN_V2")
            if env_secret is None:
                raise ValueError("CIRCLE_OIDC_TOKEN_V2 not found in environment")
            return env_secret
        elif oidc_provider == "github":
            # https://docs.github.com/en/actions/deployment/security-hardening-your-deployments/configuring-openid-connect-in-cloud-providers#using-custom-actions
            actions_id_token_request_url = os.getenv("ACTIONS_ID_TOKEN_REQUEST_URL")
            actions_id_token_request_token = os.getenv("ACTIONS_ID_TOKEN_REQUEST_TOKEN")
            if (
                actions_id_token_request_url is None
                or actions_id_token_request_token is None
            ):
                raise ValueError(
                    "ACTIONS_ID_TOKEN_REQUEST_URL or ACTIONS_ID_TOKEN_REQUEST_TOKEN not found in environment"
                )

            oidc_token = oidc_cache.get_cache(key=secret_name)
            if oidc_token is not None:
                return oidc_token

            oidc_client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
            response = oidc_client.get(
                actions_id_token_request_url,
                params={"audience": oidc_aud},
                headers={
                    "Authorization": f"Bearer {actions_id_token_request_token}",
                    "Accept": "application/json; api-version=2.0",
                },
            )
            if response.status_code == 200:
                # The endpoint answers with a JSON document; the token is in
                # its "value" field. response.text is a str and cannot be
                # indexed by key.
                oidc_token = response.json()["value"]
                oidc_cache.set_cache(key=secret_name, value=oidc_token, ttl=300 - 5)
                return oidc_token
            else:
                raise ValueError("Github OIDC provider failed")
        elif oidc_provider == "azure":
            # https://azure.github.io/azure-workload-identity/docs/quick-start.html
            azure_federated_token_file = os.getenv("AZURE_FEDERATED_TOKEN_FILE")
            if azure_federated_token_file is None:
                raise ValueError("AZURE_FEDERATED_TOKEN_FILE not found in environment")
            with open(azure_federated_token_file, "r") as f:
                oidc_token = f.read()
                return oidc_token
        elif oidc_provider == "file":
            # Load token from a file
            with open(oidc_aud, "r") as f:
                oidc_token = f.read()
                return oidc_token
        elif oidc_provider == "env":
            # Load token directly from an environment variable
            oidc_token = os.getenv(oidc_aud)
            if oidc_token is None:
                raise ValueError(f"Environment variable {oidc_aud} not found")
            return oidc_token
        elif oidc_provider == "env_path":
            # Load token from a file path specified in an environment variable
            token_file_path = os.getenv(oidc_aud)
            if token_file_path is None:
                raise ValueError(f"Environment variable {oidc_aud} not found")
            with open(token_file_path, "r") as f:
                oidc_token = f.read()
                return oidc_token
        else:
            raise ValueError("Unsupported OIDC provider")

    try:
        if litellm.secret_manager_client is not None:
            try:
                client = litellm.secret_manager_client
                key_manager = "local"
                if key_management_system is not None:
                    key_manager = key_management_system.value

                if (
                    key_management_settings is not None
                    and secret_name not in key_management_settings.hosted_keys
                ):  # allow user to specify which keys to check in hosted key manager
                    key_manager = "local"

                # NOTE(review): several print_verbose calls below emit secret
                # material into verbose output - consider redacting.
                if (
                    key_manager == KeyManagementSystem.AZURE_KEY_VAULT.value
                    or type(client).__module__ + "." + type(client).__name__
                    == "azure.keyvault.secrets._client.SecretClient"
                ):  # support Azure Secret Client - from azure.keyvault.secrets import SecretClient
                    secret = client.get_secret(secret_name).value
                elif (
                    key_manager == KeyManagementSystem.GOOGLE_KMS.value
                    or client.__class__.__name__ == "KeyManagementServiceClient"
                ):
                    encrypted_secret: Any = os.getenv(secret_name)
                    if encrypted_secret is None:
                        raise ValueError(
                            "Google KMS requires the encrypted secret to be in the environment!"
                        )
                    if not _is_base64(encrypted_secret):
                        raise ValueError(
                            "Google KMS requires the encrypted secret to be encoded in base64"
                        )  # fix for this vulnerability https://huntr.com/bounties/ae623c2f-b64b-4245-9ed4-f13a0a5824ce
                    ciphertext = base64.b64decode(encrypted_secret)
                    response = client.decrypt(
                        request={
                            "name": litellm._google_kms_resource_name,
                            "ciphertext": ciphertext,
                        }
                    )
                    secret = response.plaintext.decode(
                        "utf-8"
                    )  # assumes the original value was encoded with utf-8
                elif key_manager == KeyManagementSystem.AWS_KMS.value:
                    # Only check the tokens which start with 'aws_kms/'. This
                    # prevents latency impact caused by checking all keys.
                    encrypted_value = os.getenv(secret_name, None)
                    if encrypted_value is None:
                        raise Exception(
                            "AWS KMS - Encrypted Value of Key={} is None".format(
                                secret_name
                            )
                        )
                    # Decode the base64 encoded ciphertext
                    ciphertext_blob = base64.b64decode(encrypted_value)

                    # Set up the parameters for the decrypt call
                    params = {"CiphertextBlob": ciphertext_blob}
                    # Perform the decryption
                    response = client.decrypt(**params)

                    # Extract and decode the plaintext
                    plaintext = response["Plaintext"]
                    secret = plaintext.decode("utf-8")
                    if isinstance(secret, str):
                        secret = secret.strip()
                elif key_manager == KeyManagementSystem.AWS_SECRET_MANAGER.value:
                    try:
                        get_secret_value_response = client.get_secret_value(
                            SecretId=secret_name
                        )
                        print_verbose(
                            f"get_secret_value_response: {get_secret_value_response}"
                        )
                    except Exception as e:
                        print_verbose(f"An error occurred - {str(e)}")
                        # For a list of exceptions thrown, see
                        # https://docs.aws.amazon.com/secretsmanager/latest/apireference/API_GetSecretValue.html
                        raise e

                    # assume there is 1 secret per secret_name
                    secret_dict = json.loads(get_secret_value_response["SecretString"])
                    print_verbose(f"secret_dict: {secret_dict}")
                    for v in secret_dict.values():
                        secret = v
                    print_verbose(f"secret: {secret}")
                # Must be `elif`: a separate `if` here starts a second chain,
                # so after any branch above succeeded the trailing `else`
                # (Infisical) branch would run as well and clobber the result.
                elif key_manager == KeyManagementSystem.GOOGLE_SECRET_MANAGER.value:
                    try:
                        secret = client.get_secret_from_google_secret_manager(
                            secret_name
                        )
                        print_verbose(f"secret from google secret manager: {secret}")
                        if secret is None:
                            raise ValueError(
                                f"No secret found in Google Secret Manager for {secret_name}"
                            )
                    except Exception as e:
                        print_verbose(f"An error occurred - {str(e)}")
                        raise e
                elif key_manager == "local":
                    secret = os.getenv(secret_name)
                else:  # assume the default is Infisical client
                    secret = client.get_secret(secret_name).secret_value
            except Exception as e:  # check if it's in os.environ
                verbose_logger.error(
                    f"Defaulting to os.environ value for key={secret_name}. An exception occurred - {str(e)}.\n\n{traceback.format_exc()}"
                )
                secret = os.getenv(secret_name)
            try:
                secret_value_as_bool = ast.literal_eval(secret)
                if isinstance(secret_value_as_bool, bool):
                    return secret_value_as_bool
                else:
                    return secret
            except Exception:
                # Not a Python literal (or None): hand the raw value back.
                return secret
        else:
            secret = os.environ.get(secret_name)
            try:
                secret_value_as_bool = (
                    ast.literal_eval(secret) if secret is not None else None
                )
                if isinstance(secret_value_as_bool, bool):
                    return secret_value_as_bool
                else:
                    return secret
            except Exception:
                # NOTE(review): a set-but-non-literal value (e.g. an API key)
                # lands here, so default_value wins over the real value when a
                # default is supplied - confirm this is intended.
                if default_value is not None:
                    return default_value
                return secret
    except Exception as e:
        if default_value is not None:
            return default_value
        else:
            raise e
|
|
@ -83,7 +83,7 @@ async def test_router_init():
|
|||
)
|
||||
|
||||
|
||||
@patch("litellm.proxy.secret_managers.get_azure_ad_token_provider.os")
|
||||
@patch("litellm.secret_managers.get_azure_ad_token_provider.os")
|
||||
def test_router_init_with_neither_api_key_nor_azure_service_principal_with_secret(
|
||||
mocked_os_lib: MagicMock,
|
||||
) -> None:
|
||||
|
@ -128,7 +128,7 @@ def test_router_init_with_neither_api_key_nor_azure_service_principal_with_secre
|
|||
|
||||
@patch("azure.identity.get_bearer_token_provider")
|
||||
@patch("azure.identity.ClientSecretCredential")
|
||||
@patch("litellm.proxy.secret_managers.get_azure_ad_token_provider.os")
|
||||
@patch("litellm.secret_managers.get_azure_ad_token_provider.os")
|
||||
def test_router_init_azure_service_principal_with_secret_with_environment_variables(
|
||||
mocked_os_lib: MagicMock,
|
||||
mocked_credential: MagicMock,
|
||||
|
|
|
@ -16,10 +16,10 @@ sys.path.insert(
|
|||
) # Adds the parent directory to the system path
|
||||
import pytest
|
||||
|
||||
from litellm import get_secret
|
||||
from litellm.llms.azure import get_azure_ad_token_from_oidc
|
||||
from litellm.llms.bedrock.chat import BedrockConverseLLM, BedrockLLM
|
||||
from litellm.proxy.secret_managers.aws_secret_manager import load_aws_secret_manager
|
||||
from litellm.secret_managers.aws_secret_manager import load_aws_secret_manager
|
||||
from litellm.secret_managers.main import get_secret
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="AWS Suspended Account")
|
||||
|
@ -189,3 +189,54 @@ def test_oidc_env_path():
|
|||
assert secret_val == secret_value
|
||||
|
||||
del os.environ[env_var_name]
|
||||
|
||||
|
||||
def test_google_secret_manager():
    """
    Fetch ``OPENAI_API_KEY`` through ``GoogleSecretManager`` and check that
    the returned value is the expected one.
    """
    os.environ["GOOGLE_SECRET_MANAGER_PROJECT_ID"] = "adroit-crow-413218"
    from test_amazing_vertex_completion import load_vertex_ai_credentials

    from litellm.secret_managers.google_secret_manager import GoogleSecretManager

    # Credential loading is currently disabled:
    # load_vertex_ai_credentials()
    manager = GoogleSecretManager()

    fetched = manager.get_secret_from_google_secret_manager(
        secret_name="OPENAI_API_KEY"
    )
    print("secret_val: {}".format(fetched))

    failure_message = (
        "did not get expected secret value. expect 'anything', got '{}'".format(
            fetched
        )
    )
    assert fetched == "anything", failure_message
|
||||
|
||||
|
||||
def test_google_secret_manager_read_in_memory():
    """
    Test that Google Secret Manager returns the in-memory value when it exists.

    Two entries are written straight into the manager's ``cache.cache_dict``
    (one ``None``, one string) and both must come back unchanged from
    ``get_secret_from_google_secret_manager``.
    """
    from test_amazing_vertex_completion import load_vertex_ai_credentials

    from litellm.secret_managers.google_secret_manager import GoogleSecretManager

    # load_vertex_ai_credentials()
    os.environ["GOOGLE_SECRET_MANAGER_PROJECT_ID"] = "adroit-crow-413218"
    secret_manager = GoogleSecretManager()
    secret_manager.cache.cache_dict["UNIQUE_KEY"] = None
    secret_manager.cache.cache_dict["UNIQUE_KEY_2"] = "lite-llm"

    secret_val = secret_manager.get_secret_from_google_secret_manager(
        secret_name="UNIQUE_KEY"
    )
    print("secret_val: {}".format(secret_val))
    # Identity check: only the None singleton itself should satisfy this.
    assert secret_val is None

    secret_val = secret_manager.get_secret_from_google_secret_manager(
        secret_name="UNIQUE_KEY_2"
    )
    print("secret_val: {}".format(secret_val))
    assert secret_val == "lite-llm"
|
||||
|
|
|
@ -3712,6 +3712,7 @@ def test_unit_test_custom_stream_wrapper_function_call():
|
|||
"vertex_ai/claude-3-5-sonnet@20240620",
|
||||
],
|
||||
)
|
||||
@pytest.mark.flaky(retries=3, delay=1)
|
||||
def test_streaming_tool_calls_valid_json_str(model):
|
||||
if "vertex_ai" in model:
|
||||
from litellm.tests.test_amazing_vertex_completion import (
|
||||
|
|
247
litellm/utils.py
247
litellm/utils.py
|
@ -68,6 +68,7 @@ from litellm.litellm_core_utils.redact_messages import (
|
|||
)
|
||||
from litellm.litellm_core_utils.token_counter import get_modified_max_tokens
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
from litellm.secret_managers.main import get_secret
|
||||
from litellm.types.llms.openai import (
|
||||
AllMessageValues,
|
||||
ChatCompletionNamedToolChoiceParam,
|
||||
|
@ -93,8 +94,6 @@ from litellm.types.utils import (
|
|||
Usage,
|
||||
)
|
||||
|
||||
oidc_cache = DualCache()
|
||||
|
||||
try:
|
||||
# New and recommended way to access resources
|
||||
from importlib import resources
|
||||
|
@ -8662,250 +8661,6 @@ def exception_type(
|
|||
raise raised_exc
|
||||
|
||||
|
||||
######### Secret Manager ############################
|
||||
# checks if user has passed in a secret manager client
|
||||
# if passed in then checks the secret there
|
||||
def _is_base64(s):
|
||||
try:
|
||||
return base64.b64encode(base64.b64decode(s)).decode() == s
|
||||
except binascii.Error:
|
||||
return False
|
||||
|
||||
|
||||
def get_secret(
|
||||
secret_name: str,
|
||||
default_value: Optional[Union[str, bool]] = None,
|
||||
):
|
||||
key_management_system = litellm._key_management_system
|
||||
key_management_settings = litellm._key_management_settings
|
||||
|
||||
if secret_name.startswith("os.environ/"):
|
||||
secret_name = secret_name.replace("os.environ/", "")
|
||||
|
||||
# Example: oidc/google/https://bedrock-runtime.us-east-1.amazonaws.com/model/stability.stable-diffusion-xl-v1/invoke
|
||||
if secret_name.startswith("oidc/"):
|
||||
secret_name_split = secret_name.replace("oidc/", "")
|
||||
oidc_provider, oidc_aud = secret_name_split.split("/", 1)
|
||||
# TODO: Add caching for HTTP requests
|
||||
if oidc_provider == "google":
|
||||
oidc_token = oidc_cache.get_cache(key=secret_name)
|
||||
if oidc_token is not None:
|
||||
return oidc_token
|
||||
|
||||
oidc_client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
|
||||
# https://cloud.google.com/compute/docs/instances/verifying-instance-identity#request_signature
|
||||
response = oidc_client.get(
|
||||
"http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/identity",
|
||||
params={"audience": oidc_aud},
|
||||
headers={"Metadata-Flavor": "Google"},
|
||||
)
|
||||
if response.status_code == 200:
|
||||
oidc_token = response.text
|
||||
oidc_cache.set_cache(key=secret_name, value=oidc_token, ttl=3600 - 60)
|
||||
return oidc_token
|
||||
else:
|
||||
raise ValueError("Google OIDC provider failed")
|
||||
elif oidc_provider == "circleci":
|
||||
# https://circleci.com/docs/openid-connect-tokens/
|
||||
env_secret = os.getenv("CIRCLE_OIDC_TOKEN")
|
||||
if env_secret is None:
|
||||
raise ValueError("CIRCLE_OIDC_TOKEN not found in environment")
|
||||
return env_secret
|
||||
elif oidc_provider == "circleci_v2":
|
||||
# https://circleci.com/docs/openid-connect-tokens/
|
||||
env_secret = os.getenv("CIRCLE_OIDC_TOKEN_V2")
|
||||
if env_secret is None:
|
||||
raise ValueError("CIRCLE_OIDC_TOKEN_V2 not found in environment")
|
||||
return env_secret
|
||||
elif oidc_provider == "github":
|
||||
# https://docs.github.com/en/actions/deployment/security-hardening-your-deployments/configuring-openid-connect-in-cloud-providers#using-custom-actions
|
||||
actions_id_token_request_url = os.getenv("ACTIONS_ID_TOKEN_REQUEST_URL")
|
||||
actions_id_token_request_token = os.getenv("ACTIONS_ID_TOKEN_REQUEST_TOKEN")
|
||||
if (
|
||||
actions_id_token_request_url is None
|
||||
or actions_id_token_request_token is None
|
||||
):
|
||||
raise ValueError(
|
||||
"ACTIONS_ID_TOKEN_REQUEST_URL or ACTIONS_ID_TOKEN_REQUEST_TOKEN not found in environment"
|
||||
)
|
||||
|
||||
oidc_token = oidc_cache.get_cache(key=secret_name)
|
||||
if oidc_token is not None:
|
||||
return oidc_token
|
||||
|
||||
oidc_client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
|
||||
response = oidc_client.get(
|
||||
actions_id_token_request_url,
|
||||
params={"audience": oidc_aud},
|
||||
headers={
|
||||
"Authorization": f"Bearer {actions_id_token_request_token}",
|
||||
"Accept": "application/json; api-version=2.0",
|
||||
},
|
||||
)
|
||||
if response.status_code == 200:
|
||||
oidc_token = response.text["value"]
|
||||
oidc_cache.set_cache(key=secret_name, value=oidc_token, ttl=300 - 5)
|
||||
return oidc_token
|
||||
else:
|
||||
raise ValueError("Github OIDC provider failed")
|
||||
elif oidc_provider == "azure":
|
||||
# https://azure.github.io/azure-workload-identity/docs/quick-start.html
|
||||
azure_federated_token_file = os.getenv("AZURE_FEDERATED_TOKEN_FILE")
|
||||
if azure_federated_token_file is None:
|
||||
raise ValueError("AZURE_FEDERATED_TOKEN_FILE not found in environment")
|
||||
with open(azure_federated_token_file, "r") as f:
|
||||
oidc_token = f.read()
|
||||
return oidc_token
|
||||
elif oidc_provider == "file":
|
||||
# Load token from a file
|
||||
with open(oidc_aud, "r") as f:
|
||||
oidc_token = f.read()
|
||||
return oidc_token
|
||||
elif oidc_provider == "env":
|
||||
# Load token directly from an environment variable
|
||||
oidc_token = os.getenv(oidc_aud)
|
||||
if oidc_token is None:
|
||||
raise ValueError(f"Environment variable {oidc_aud} not found")
|
||||
return oidc_token
|
||||
elif oidc_provider == "env_path":
|
||||
# Load token from a file path specified in an environment variable
|
||||
token_file_path = os.getenv(oidc_aud)
|
||||
if token_file_path is None:
|
||||
raise ValueError(f"Environment variable {oidc_aud} not found")
|
||||
with open(token_file_path, "r") as f:
|
||||
oidc_token = f.read()
|
||||
return oidc_token
|
||||
else:
|
||||
raise ValueError("Unsupported OIDC provider")
|
||||
|
||||
try:
|
||||
if litellm.secret_manager_client is not None:
|
||||
try:
|
||||
client = litellm.secret_manager_client
|
||||
key_manager = "local"
|
||||
if key_management_system is not None:
|
||||
key_manager = key_management_system.value
|
||||
|
||||
if key_management_settings is not None:
|
||||
if (
|
||||
secret_name not in key_management_settings.hosted_keys
|
||||
): # allow user to specify which keys to check in hosted key manager
|
||||
key_manager = "local"
|
||||
|
||||
if (
|
||||
key_manager == KeyManagementSystem.AZURE_KEY_VAULT.value
|
||||
or type(client).__module__ + "." + type(client).__name__
|
||||
== "azure.keyvault.secrets._client.SecretClient"
|
||||
): # support Azure Secret Client - from azure.keyvault.secrets import SecretClient
|
||||
secret = client.get_secret(secret_name).value
|
||||
elif (
|
||||
key_manager == KeyManagementSystem.GOOGLE_KMS.value
|
||||
or client.__class__.__name__ == "KeyManagementServiceClient"
|
||||
):
|
||||
encrypted_secret: Any = os.getenv(secret_name)
|
||||
if encrypted_secret is None:
|
||||
raise ValueError(
|
||||
f"Google KMS requires the encrypted secret to be in the environment!"
|
||||
)
|
||||
b64_flag = _is_base64(encrypted_secret)
|
||||
if b64_flag == True: # if passed in as encoded b64 string
|
||||
encrypted_secret = base64.b64decode(encrypted_secret)
|
||||
ciphertext = encrypted_secret
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Google KMS requires the encrypted secret to be encoded in base64"
|
||||
) # fix for this vulnerability https://huntr.com/bounties/ae623c2f-b64b-4245-9ed4-f13a0a5824ce
|
||||
response = client.decrypt(
|
||||
request={
|
||||
"name": litellm._google_kms_resource_name,
|
||||
"ciphertext": ciphertext,
|
||||
}
|
||||
)
|
||||
secret = response.plaintext.decode(
|
||||
"utf-8"
|
||||
) # assumes the original value was encoded with utf-8
|
||||
elif key_manager == KeyManagementSystem.AWS_KMS.value:
|
||||
"""
|
||||
Only check the tokens which start with 'aws_kms/'. This prevents latency impact caused by checking all keys.
|
||||
"""
|
||||
encrypted_value = os.getenv(secret_name, None)
|
||||
if encrypted_value is None:
|
||||
raise Exception(
|
||||
"AWS KMS - Encrypted Value of Key={} is None".format(
|
||||
secret_name
|
||||
)
|
||||
)
|
||||
# Decode the base64 encoded ciphertext
|
||||
ciphertext_blob = base64.b64decode(encrypted_value)
|
||||
|
||||
# Set up the parameters for the decrypt call
|
||||
params = {"CiphertextBlob": ciphertext_blob}
|
||||
# Perform the decryption
|
||||
response = client.decrypt(**params)
|
||||
|
||||
# Extract and decode the plaintext
|
||||
plaintext = response["Plaintext"]
|
||||
secret = plaintext.decode("utf-8")
|
||||
if isinstance(secret, str):
|
||||
secret = secret.strip()
|
||||
elif key_manager == KeyManagementSystem.AWS_SECRET_MANAGER.value:
|
||||
try:
|
||||
get_secret_value_response = client.get_secret_value(
|
||||
SecretId=secret_name
|
||||
)
|
||||
print_verbose(
|
||||
f"get_secret_value_response: {get_secret_value_response}"
|
||||
)
|
||||
except Exception as e:
|
||||
print_verbose(f"An error occurred - {str(e)}")
|
||||
# For a list of exceptions thrown, see
|
||||
# https://docs.aws.amazon.com/secretsmanager/latest/apireference/API_GetSecretValue.html
|
||||
raise e
|
||||
|
||||
# assume there is 1 secret per secret_name
|
||||
secret_dict = json.loads(get_secret_value_response["SecretString"])
|
||||
print_verbose(f"secret_dict: {secret_dict}")
|
||||
for k, v in secret_dict.items():
|
||||
secret = v
|
||||
print_verbose(f"secret: {secret}")
|
||||
elif key_manager == "local":
|
||||
secret = os.getenv(secret_name)
|
||||
else: # assume the default is infisicial client
|
||||
secret = client.get_secret(secret_name).secret_value
|
||||
except Exception as e: # check if it's in os.environ
|
||||
verbose_logger.error(
|
||||
f"Defaulting to os.environ value for key={secret_name}. An exception occurred - {str(e)}.\n\n{traceback.format_exc()}"
|
||||
)
|
||||
secret = os.getenv(secret_name)
|
||||
try:
|
||||
secret_value_as_bool = ast.literal_eval(secret)
|
||||
if isinstance(secret_value_as_bool, bool):
|
||||
return secret_value_as_bool
|
||||
else:
|
||||
return secret
|
||||
except:
|
||||
return secret
|
||||
else:
|
||||
secret = os.environ.get(secret_name)
|
||||
try:
|
||||
secret_value_as_bool = (
|
||||
ast.literal_eval(secret) if secret is not None else None
|
||||
)
|
||||
if isinstance(secret_value_as_bool, bool):
|
||||
return secret_value_as_bool
|
||||
else:
|
||||
return secret
|
||||
except Exception:
|
||||
if default_value is not None:
|
||||
return default_value
|
||||
return secret
|
||||
except Exception as e:
|
||||
if default_value is not None:
|
||||
return default_value
|
||||
else:
|
||||
raise e
|
||||
|
||||
|
||||
######## Streaming Class ############################
|
||||
# wraps the completion stream to return the correct format for the model
|
||||
# replicate/anthropic/cohere
|
||||
|
|
|
@ -17,7 +17,7 @@ def test_decrypt_and_reset_env():
|
|||
os.environ["DATABASE_URL"] = (
|
||||
"aws_kms/AQICAHgwddjZ9xjVaZ9CNCG8smFU6FiQvfdrjL12DIqi9vUAQwHwF6U7caMgHQa6tK+TzaoMAAAAzjCBywYJKoZIhvcNAQcGoIG9MIG6AgEAMIG0BgkqhkiG9w0BBwEwHgYJYIZIAWUDBAEuMBEEDCmu+DVeKTm5tFZu6AIBEICBhnOFQYviL8JsciGk0bZsn9pfzeYWtNkVXEsl01AdgHBqT9UOZOI4ZC+T3wO/fXA7wdNF4o8ASPDbVZ34ZFdBs8xt4LKp9niufL30WYBkuuzz89ztly0jvE9pZ8L6BMw0ATTaMgIweVtVSDCeCzEb5PUPyxt4QayrlYHBGrNH5Aq/axFTe0La"
|
||||
)
|
||||
from litellm.proxy.secret_managers.aws_secret_manager import (
|
||||
from litellm.secret_managers.aws_secret_manager import (
|
||||
decrypt_and_reset_env_var,
|
||||
)
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue