feat: support for azure key vault
parent e8efde5a83
commit 284fb64f4d

6 changed files with 70 additions and 17 deletions
litellm/__init__.py

@@ -58,6 +58,8 @@ num_retries: Optional[int] = None
 fallbacks: Optional[List] = None
 context_window_fallbacks: Optional[List] = None
 allowed_fails: int = 0
+####### SECRET MANAGERS #####################
+secret_manager_client = None # list of instantiated key management clients - e.g. azure kv, infisical, etc.
 #############################################
 
 def get_model_cost_map(url: str):
@@ -95,8 +97,6 @@ headers = None
 api_version = None
 organization = None
 config_path = None
-####### Secret Manager #####################
-secret_manager_client = None
 ####### COMPLETION MODELS ###################
 open_ai_chat_completion_models: List = []
 open_ai_text_completion_models: List = []
@@ -366,7 +366,8 @@ from .utils import (
     encode,
     decode,
     _calculate_retry_after,
-    _should_retry
+    _should_retry,
+    get_secret
 )
 from .llms.huggingface_restapi import HuggingfaceConfig
 from .llms.anthropic import AnthropicConfig
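The new module-level `secret_manager_client` slot is duck-typed: anything exposing a `get_secret(name)` method can be assigned to it. A minimal sketch of that contract (the `DummySecretManager` class and its contents are invented for illustration, mimicking Infisical's `.secret_value` return shape):

import litellm
from types import SimpleNamespace

class DummySecretManager:
    # Stand-in for an Infisical-style client; only get_secret(name) is required.
    def __init__(self, store):
        self.store = store

    def get_secret(self, name):
        # Mimic Infisical's return shape, which exposes .secret_value
        return SimpleNamespace(secret_value=self.store.get(name))

litellm.secret_manager_client = DummySecretManager({"OPENAI_API_KEY": "sk-hypothetical"})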
litellm/proxy/proxy_server.py

@@ -286,6 +286,44 @@ def celery_setup(use_queue: bool):
         async_result = AsyncResult
         celery_app_conn = celery_app
 
+def load_from_azure_key_vault(use_azure_key_vault: bool = False):
+    if use_azure_key_vault is False:
+        return
+
+    try:
+        from azure.keyvault.secrets import SecretClient
+        from azure.identity import ClientSecretCredential
+
+        # Set your Azure Key Vault URI
+        KVUri = os.getenv("AZURE_KEY_VAULT_URI")
+
+        # Set your Azure AD application/client ID, client secret, and tenant ID
+        client_id = os.getenv("AZURE_CLIENT_ID")
+        client_secret = os.getenv("AZURE_CLIENT_SECRET")
+        tenant_id = os.getenv("AZURE_TENANT_ID")
+
+        # Initialize the ClientSecretCredential
+        credential = ClientSecretCredential(client_id=client_id, client_secret=client_secret, tenant_id=tenant_id)
+
+        # Create the SecretClient using the credential
+        client = SecretClient(vault_url=KVUri, credential=credential)
+
+        litellm.secret_manager_client = client
+        # # Retrieve all secrets
+        # secrets = client.get_secrets()
+
+        # # Load secrets into environment variables
+        # for secret in secrets:
+        #     secret_name = secret.name
+        #     secret_value = client.get_secret(secret_name).value
+        #     os.environ[secret_name] = secret_value
+
+        print(f"test key - : {litellm.get_secret('test-3')}")
+    except Exception as e:
+        print(e)
+        print("Error when loading keys from Azure Key Vault. Ensure you run `pip install azure-identity azure-keyvault-secrets`")
+
+
 def cost_tracking():
     global prisma_client, master_key
     if prisma_client is not None and master_key is not None:
@@ -412,6 +450,9 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
     ### START REDIS QUEUE ###
     use_queue = general_settings.get("use_queue", False)
     celery_setup(use_queue=use_queue)
+    ### LOAD FROM AZURE KEY VAULT ###
+    use_azure_key_vault = general_settings.get("use_azure_key_vault", False)
+    load_from_azure_key_vault(use_azure_key_vault=use_azure_key_vault)
 
     ## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..)
     litellm_settings = config.get('litellm_settings', None)
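A minimal sketch of exercising the new hook directly, assuming `azure-identity` and `azure-keyvault-secrets` are installed; the vault URI, credential values, and secret name below are placeholders:

import os
import litellm
from litellm.proxy.proxy_server import load_from_azure_key_vault

# Placeholder credentials for the ClientSecretCredential flow used above.
os.environ["AZURE_KEY_VAULT_URI"] = "https://example-vault.vault.azure.net/"
os.environ["AZURE_CLIENT_ID"] = "00000000-0000-0000-0000-000000000000"
os.environ["AZURE_CLIENT_SECRET"] = "placeholder-client-secret"
os.environ["AZURE_TENANT_ID"] = "00000000-0000-0000-0000-000000000000"

# Equivalent to setting `use_azure_key_vault: true` under general_settings in
# the proxy config, which load_router_config now forwards here.
load_from_azure_key_vault(use_azure_key_vault=True)

# Secrets in the vault can now be resolved by name.
api_key = litellm.get_secret("openai-api-key")  # hypothetical secret name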
litellm/router.py

@@ -844,19 +844,19 @@ class Router:
             api_key = litellm_params.get("api_key")
             if api_key and api_key.startswith("os.environ/"):
                 api_key_env_name = api_key.replace("os.environ/", "")
-                api_key = os.getenv(api_key_env_name)
+                api_key = litellm.get_secret(api_key_env_name)
 
             api_base = litellm_params.get("api_base")
             base_url = litellm_params.get("base_url")
             api_base = api_base or base_url # allow users to pass in `api_base` or `base_url` for azure
             if api_base and api_base.startswith("os.environ/"):
                 api_base_env_name = api_base.replace("os.environ/", "")
-                api_base = os.getenv(api_base_env_name)
+                api_base = litellm.get_secret(api_base_env_name)
 
             api_version = litellm_params.get("api_version")
             if api_version and api_version.startswith("os.environ/"):
                 api_version_env_name = api_version.replace("os.environ/", "")
-                api_version = os.getenv(api_version_env_name)
+                api_version = litellm.get_secret(api_version_env_name)
             self.print_verbose(f"Initializing OpenAI Client for {model_name}, {str(api_base)}")
             if "azure" in model_name:
                 if api_version is None:
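The effect, sketched against a hypothetical deployment: `os.environ/`-prefixed values in `litellm_params` now go through `litellm.get_secret`, so with a Key Vault client configured they can be vault entries (Key Vault secret names allow dashes but not underscores) rather than process environment variables:

import litellm

router = litellm.Router(model_list=[{
    "model_name": "azure-gpt",  # hypothetical alias
    "litellm_params": {
        "model": "azure/my-deployment",           # hypothetical deployment
        "api_key": "os.environ/AZURE-API-KEY",    # resolved via litellm.get_secret
        "api_base": "os.environ/AZURE-API-BASE",
        "api_version": "os.environ/AZURE-API-VERSION",
    },
}])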
litellm/utils.py

@@ -2421,7 +2421,7 @@ def get_llm_provider(model: str, custom_llm_provider: Optional[str] = None, api_
 
     if api_key and api_key.startswith("os.environ/"):
         api_key_env_name = api_key.replace("os.environ/", "")
-        dynamic_api_key = os.getenv(api_key_env_name)
+        dynamic_api_key = get_secret(api_key_env_name)
     # check if llm provider part of model name
     if model.split("/",1)[0] in litellm.provider_list and model.split("/",1)[0] not in litellm.model_list:
         custom_llm_provider = model.split("/", 1)[0]
@@ -2429,15 +2429,15 @@ def get_llm_provider(model: str, custom_llm_provider: Optional[str] = None, api_
         if custom_llm_provider == "perplexity":
             # perplexity is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.perplexity.ai
             api_base = "https://api.perplexity.ai"
-            dynamic_api_key = os.getenv("PERPLEXITYAI_API_KEY")
+            dynamic_api_key = get_secret("PERPLEXITYAI_API_KEY")
         elif custom_llm_provider == "anyscale":
             # anyscale is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.endpoints.anyscale.com/v1
             api_base = "https://api.endpoints.anyscale.com/v1"
-            dynamic_api_key = os.getenv("ANYSCALE_API_KEY")
+            dynamic_api_key = get_secret("ANYSCALE_API_KEY")
         elif custom_llm_provider == "deepinfra":
             # deepinfra is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.endpoints.anyscale.com/v1
             api_base = "https://api.deepinfra.com/v1/openai"
-            dynamic_api_key = os.getenv("DEEPINFRA_API_KEY")
+            dynamic_api_key = get_secret("DEEPINFRA_API_KEY")
         return model, custom_llm_provider, dynamic_api_key, api_base
 
     # check if api base is a known openai compatible endpoint
@@ -2446,13 +2446,13 @@ def get_llm_provider(model: str, custom_llm_provider: Optional[str] = None, api_
         if endpoint in api_base:
             if endpoint == "api.perplexity.ai":
                 custom_llm_provider = "perplexity"
-                dynamic_api_key = os.getenv("PERPLEXITYAI_API_KEY")
+                dynamic_api_key = get_secret("PERPLEXITYAI_API_KEY")
             elif endpoint == "api.endpoints.anyscale.com/v1":
                 custom_llm_provider = "anyscale"
-                dynamic_api_key = os.getenv("ANYSCALE_API_KEY")
+                dynamic_api_key = get_secret("ANYSCALE_API_KEY")
             elif endpoint == "api.deepinfra.com/v1/openai":
                 custom_llm_provider = "deepinfra"
-                dynamic_api_key = os.getenv("DEEPINFRA_API_KEY")
+                dynamic_api_key = get_secret("DEEPINFRA_API_KEY")
             return model, custom_llm_provider, dynamic_api_key, api_base
 
     # check if model in known model provider list -> for huggingface models, raise exception as they don't have a fixed provider (can be togetherai, anyscale, baseten, runpod, et.)
@@ -4715,13 +4715,17 @@ def litellm_telemetry(data):
 # checks if user has passed in a secret manager client
 # if passed in then checks the secret there
 def get_secret(secret_name):
-    if litellm.secret_manager_client != None:
+    if litellm.secret_manager_client is not None:
         # TODO: check which secret manager is being used
         # currently only supports Infisical
         try:
-            secret = litellm.secret_manager_client.get_secret(secret_name).secret_value
-        except:
-            secret = None
+            client = litellm.secret_manager_client
+            if type(client).__module__ + '.' + type(client).__name__ == 'azure.keyvault.secrets._client.SecretClient': # support Azure Secret Client - from azure.keyvault.secrets import SecretClient
+                secret = retrieved_secret = client.get_secret(secret_name).value
+            else: # assume the default is infisicial client
+                secret = client.get_secret(secret_name).secret_value
+        except: # check if it's in os.environ
+            secret = os.environ.get(secret_name)
         return secret
     else:
         return os.environ.get(secret_name)
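`get_secret` now resolves in three steps: an Azure `SecretClient` (detected by its fully qualified class name, which avoids importing the azure packages when they are absent, at the cost of coupling to the private module path `azure.keyvault.secrets._client`) via `.value`; any other client, assumed to be Infisical, via `.secret_value`; and `os.environ` as the fallback on error or when no client is set. A sketch of the fallback path:

import os
from litellm.utils import get_secret

# With no secret_manager_client configured, lookups fall through to os.environ.
os.environ["FALLBACK_KEY"] = "sk-from-env"  # hypothetical key/value
assert get_secret("FALLBACK_KEY") == "sk-from-env"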
pyproject.toml

@@ -33,6 +33,12 @@ proxy = [
     "orjson",
 ]
 
+extra_proxy = [
+    "prisma",
+    "azure-identity",
+    "azure-keyvault-secrets"
+]
+
 [tool.poetry.scripts]
 litellm = 'litellm:run_server'
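Note: `extra_proxy` is defined here as a plain list; if it is also registered under `[tool.poetry.extras]`, these dependencies become installable as `pip install litellm[extra_proxy]`, matching the `pip install azure-identity azure-keyvault-secrets` hint in the error message above.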
requirements.txt

@@ -1,3 +1,4 @@
+# LITELLM PROXY DEPENDENCIES #
 litellm
 openai
 fastapi