mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 19:24:27 +00:00
fix+feat(router.py): Fix missing azure_ad_token, and allow use OIDC auth
This commit is contained in:
parent
b48f95d9ef
commit
5205d3913e
1 changed files with 9 additions and 0 deletions
|
@ -45,6 +45,7 @@ from litellm.types.router import (
|
||||||
RetryPolicy,
|
RetryPolicy,
|
||||||
)
|
)
|
||||||
from litellm.integrations.custom_logger import CustomLogger
|
from litellm.integrations.custom_logger import CustomLogger
|
||||||
|
from litellm.llms.azure import get_azure_ad_token_from_oidc
|
||||||
|
|
||||||
|
|
||||||
class Router:
|
class Router:
|
||||||
|
@ -2089,6 +2090,10 @@ class Router:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"api_base is required for Azure OpenAI. Set it on your config. Model - {model}"
|
f"api_base is required for Azure OpenAI. Set it on your config. Model - {model}"
|
||||||
)
|
)
|
||||||
|
azure_ad_token = litellm_params.get("azure_ad_token")
|
||||||
|
if azure_ad_token is not None:
|
||||||
|
if azure_ad_token.startswith("oidc/"):
|
||||||
|
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
||||||
if api_version is None:
|
if api_version is None:
|
||||||
api_version = "2023-07-01-preview"
|
api_version = "2023-07-01-preview"
|
||||||
if "gateway.ai.cloudflare.com" in api_base:
|
if "gateway.ai.cloudflare.com" in api_base:
|
||||||
|
@ -2099,6 +2104,7 @@ class Router:
|
||||||
cache_key = f"{model_id}_async_client"
|
cache_key = f"{model_id}_async_client"
|
||||||
_client = openai.AsyncAzureOpenAI(
|
_client = openai.AsyncAzureOpenAI(
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
|
azure_ad_token=azure_ad_token,
|
||||||
base_url=api_base,
|
base_url=api_base,
|
||||||
api_version=api_version,
|
api_version=api_version,
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
|
@ -2123,6 +2129,7 @@ class Router:
|
||||||
cache_key = f"{model_id}_client"
|
cache_key = f"{model_id}_client"
|
||||||
_client = openai.AzureOpenAI( # type: ignore
|
_client = openai.AzureOpenAI( # type: ignore
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
|
azure_ad_token=azure_ad_token,
|
||||||
base_url=api_base,
|
base_url=api_base,
|
||||||
api_version=api_version,
|
api_version=api_version,
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
|
@ -2147,6 +2154,7 @@ class Router:
|
||||||
cache_key = f"{model_id}_stream_async_client"
|
cache_key = f"{model_id}_stream_async_client"
|
||||||
_client = openai.AsyncAzureOpenAI( # type: ignore
|
_client = openai.AsyncAzureOpenAI( # type: ignore
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
|
azure_ad_token=azure_ad_token,
|
||||||
base_url=api_base,
|
base_url=api_base,
|
||||||
api_version=api_version,
|
api_version=api_version,
|
||||||
timeout=stream_timeout,
|
timeout=stream_timeout,
|
||||||
|
@ -2171,6 +2179,7 @@ class Router:
|
||||||
cache_key = f"{model_id}_stream_client"
|
cache_key = f"{model_id}_stream_client"
|
||||||
_client = openai.AzureOpenAI( # type: ignore
|
_client = openai.AzureOpenAI( # type: ignore
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
|
azure_ad_token=azure_ad_token,
|
||||||
base_url=api_base,
|
base_url=api_base,
|
||||||
api_version=api_version,
|
api_version=api_version,
|
||||||
timeout=stream_timeout,
|
timeout=stream_timeout,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue