fix+feat(router.py): Fix missing azure_ad_token, and allow use OIDC auth

This commit is contained in:
David Manouchehri 2024-05-07 19:43:18 +00:00
parent e268354acc
commit 9a0bb36865
No known key found for this signature in database

View file

@ -45,6 +45,7 @@ from litellm.types.router import (
RetryPolicy,
)
from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.azure import get_azure_ad_token_from_oidc
class Router:
@ -2089,6 +2090,10 @@ class Router:
raise ValueError(
f"api_base is required for Azure OpenAI. Set it on your config. Model - {model}"
)
azure_ad_token = litellm_params.get("azure_ad_token")
if azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
if api_version is None:
api_version = "2023-07-01-preview"
if "gateway.ai.cloudflare.com" in api_base:
@ -2099,6 +2104,7 @@ class Router:
cache_key = f"{model_id}_async_client"
_client = openai.AsyncAzureOpenAI(
api_key=api_key,
azure_ad_token=azure_ad_token,
base_url=api_base,
api_version=api_version,
timeout=timeout,
@ -2123,6 +2129,7 @@ class Router:
cache_key = f"{model_id}_client"
_client = openai.AzureOpenAI( # type: ignore
api_key=api_key,
azure_ad_token=azure_ad_token,
base_url=api_base,
api_version=api_version,
timeout=timeout,
@ -2147,6 +2154,7 @@ class Router:
cache_key = f"{model_id}_stream_async_client"
_client = openai.AsyncAzureOpenAI( # type: ignore
api_key=api_key,
azure_ad_token=azure_ad_token,
base_url=api_base,
api_version=api_version,
timeout=stream_timeout,
@ -2171,6 +2179,7 @@ class Router:
cache_key = f"{model_id}_stream_client"
_client = openai.AzureOpenAI( # type: ignore
api_key=api_key,
azure_ad_token=azure_ad_token,
base_url=api_base,
api_version=api_version,
timeout=stream_timeout,