add new litellm params for client_id, tenant_id etc

This commit is contained in:
Ishaan Jaff 2024-08-22 11:37:30 -07:00
parent 8f657b40f5
commit 08fa3f346a
5 changed files with 168 additions and 9 deletions

View file

@ -173,10 +173,13 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
organization = litellm.get_secret(organization_env_name)
litellm_params["organization"] = organization
azure_ad_token_provider = litellm_params.get("azure_ad_token_provider", None)
if azure_ad_token_provider is not None:
if litellm_params.get("tenant_id"):
verbose_router_logger.debug("Using Azure AD Token Provider for Azure Auth")
azure_ad_token_provider = get_azure_ad_token_from_entrata_id()
azure_ad_token_provider = get_azure_ad_token_from_entrata_id(
tenant_id=litellm_params.get("tenant_id"),
client_id=litellm_params.get("client_id"),
client_secret=litellm_params.get("client_secret"),
)
if custom_llm_provider == "azure" or custom_llm_provider == "azure_text":
if api_base is None or not isinstance(api_base, str):
@ -507,13 +510,37 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
) # cache for 1 hr
def get_azure_ad_token_from_entrata_id() -> Callable[[], str]:
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
def get_azure_ad_token_from_entrata_id(
tenant_id: str, client_id: str, client_secret: str
) -> Callable[[], str]:
from azure.identity import (
ClientSecretCredential,
DefaultAzureCredential,
get_bearer_token_provider,
)
verbose_router_logger.debug("Getting Azure AD Token from Entrata ID")
if tenant_id.startswith("os.environ/"):
tenant_id = litellm.get_secret(tenant_id)
if client_id.startswith("os.environ/"):
client_id = litellm.get_secret(client_id)
if client_secret.startswith("os.environ/"):
client_secret = litellm.get_secret(client_secret)
verbose_router_logger.debug(
"tenant_id %s, client_id %s, client_secret %s",
tenant_id,
client_id,
client_secret,
)
credential = ClientSecretCredential(tenant_id, client_id, client_secret)
verbose_router_logger.debug("credential %s", credential)
token_provider = get_bearer_token_provider(
DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
credential, "https://cognitiveservices.azure.com/.default"
)
verbose_router_logger.debug("token_provider %s", token_provider)