mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
add new litellm params for client_id, tenant_id etc
This commit is contained in:
parent
8f657b40f5
commit
08fa3f346a
5 changed files with 168 additions and 9 deletions
|
@ -173,10 +173,13 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
|
|||
organization = litellm.get_secret(organization_env_name)
|
||||
litellm_params["organization"] = organization
|
||||
|
||||
azure_ad_token_provider = litellm_params.get("azure_ad_token_provider", None)
|
||||
if azure_ad_token_provider is not None:
|
||||
if litellm_params.get("tenant_id"):
|
||||
verbose_router_logger.debug("Using Azure AD Token Provider for Azure Auth")
|
||||
azure_ad_token_provider = get_azure_ad_token_from_entrata_id()
|
||||
azure_ad_token_provider = get_azure_ad_token_from_entrata_id(
|
||||
tenant_id=litellm_params.get("tenant_id"),
|
||||
client_id=litellm_params.get("client_id"),
|
||||
client_secret=litellm_params.get("client_secret"),
|
||||
)
|
||||
|
||||
if custom_llm_provider == "azure" or custom_llm_provider == "azure_text":
|
||||
if api_base is None or not isinstance(api_base, str):
|
||||
|
@ -507,13 +510,37 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
|
|||
) # cache for 1 hr
|
||||
|
||||
|
||||
def get_azure_ad_token_from_entrata_id() -> Callable[[], str]:
|
||||
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
|
||||
def get_azure_ad_token_from_entrata_id(
|
||||
tenant_id: str, client_id: str, client_secret: str
|
||||
) -> Callable[[], str]:
|
||||
from azure.identity import (
|
||||
ClientSecretCredential,
|
||||
DefaultAzureCredential,
|
||||
get_bearer_token_provider,
|
||||
)
|
||||
|
||||
verbose_router_logger.debug("Getting Azure AD Token from Entrata ID")
|
||||
|
||||
if tenant_id.startswith("os.environ/"):
|
||||
tenant_id = litellm.get_secret(tenant_id)
|
||||
|
||||
if client_id.startswith("os.environ/"):
|
||||
client_id = litellm.get_secret(client_id)
|
||||
|
||||
if client_secret.startswith("os.environ/"):
|
||||
client_secret = litellm.get_secret(client_secret)
|
||||
verbose_router_logger.debug(
|
||||
"tenant_id %s, client_id %s, client_secret %s",
|
||||
tenant_id,
|
||||
client_id,
|
||||
client_secret,
|
||||
)
|
||||
credential = ClientSecretCredential(tenant_id, client_id, client_secret)
|
||||
|
||||
verbose_router_logger.debug("credential %s", credential)
|
||||
|
||||
token_provider = get_bearer_token_provider(
|
||||
DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
|
||||
credential, "https://cognitiveservices.azure.com/.default"
|
||||
)
|
||||
|
||||
verbose_router_logger.debug("token_provider %s", token_provider)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue