From 4b655d8b3343be0d5a537be59e58c12f29820c89 Mon Sep 17 00:00:00 2001 From: David Manouchehri Date: Tue, 7 May 2024 15:02:37 +0000 Subject: [PATCH 1/7] feat(util.py): Add OIDC support. --- litellm/tests/test_secret_manager.py | 33 +++++++++++++++++ litellm/utils.py | 54 ++++++++++++++++++++++++++++ 2 files changed, 87 insertions(+) diff --git a/litellm/tests/test_secret_manager.py b/litellm/tests/test_secret_manager.py index 3ea38f806..892a8831c 100644 --- a/litellm/tests/test_secret_manager.py +++ b/litellm/tests/test_secret_manager.py @@ -23,3 +23,36 @@ def test_aws_secret_manager(): print(f"secret_val: {secret_val}") assert secret_val == "sk-1234" + + +def redact_oidc_signature(secret_val): + # remove the last part of `.` and replace it with "SIGNATURE_REMOVED" + return secret_val.split(".")[:-1] + ["SIGNATURE_REMOVED"] + + +@pytest.mark.skipif(os.environ.get('K_SERVICE') is None, reason="Cannot run without being in GCP Cloud Run") +def test_oidc_google(): + secret_val = get_secret("oidc/google/https://bedrock-runtime.us-east-1.amazonaws.com/model/amazon.titan-text-express-v1/invoke") + + print(f"secret_val: {redact_oidc_signature(secret_val)}") + + +@pytest.mark.skipif(os.environ.get('ACTIONS_ID_TOKEN_REQUEST_TOKEN') is None, reason="Cannot run without being in GitHub Actions") +def test_oidc_github(): + secret_val = get_secret("oidc/github/https://bedrock-runtime.us-east-1.amazonaws.com/model/amazon.titan-text-express-v1/invoke") + + print(f"secret_val: {redact_oidc_signature(secret_val)}") + + +@pytest.mark.skipif(os.environ.get('CIRCLE_OIDC_TOKEN') is None, reason="Cannot run without being in a CircleCI Runner") +def test_oidc_circleci(): + secret_val = get_secret("oidc/circleci/https://bedrock-runtime.us-east-1.amazonaws.com/model/amazon.titan-text-express-v1/invoke") + + print(f"secret_val: {redact_oidc_signature(secret_val)}") + + +@pytest.mark.skipif(os.environ.get('CIRCLE_OIDC_TOKEN_V2') is None, reason="Cannot run without being in a CircleCI Runner") +def test_oidc_circleci_v2(): + secret_val = get_secret("oidc/circleci_v2/https://bedrock-runtime.us-east-1.amazonaws.com/model/amazon.titan-text-express-v1/invoke") + + print(f"secret_val: {redact_oidc_signature(secret_val)}") diff --git a/litellm/utils.py b/litellm/utils.py index a938a0ba8..652406f64 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -33,6 +33,7 @@ from dataclasses import ( ) import litellm._service_logger # for storing API inputs, outputs, and metadata +from litellm.llms.custom_httpx.http_handler import HTTPHandler try: # this works in python 3.8 @@ -9288,6 +9289,59 @@ def get_secret( if secret_name.startswith("os.environ/"): secret_name = secret_name.replace("os.environ/", "") + # Example: oidc/google/https://bedrock-runtime.us-east-1.amazonaws.com/model/stability.stable-diffusion-xl-v1/invoke + if secret_name.startswith("oidc/"): + secret_name = secret_name.replace("oidc/", "") + oidc_provider, oidc_aud = secret_name.split("/", 1) + # TODO: Add caching for HTTP requests + match oidc_provider: + case "google": + client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0)) + # https://cloud.google.com/compute/docs/instances/verifying-instance-identity#request_signature + response = client.get( + "http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/identity", + params={"audience": oidc_aud}, + headers={"Metadata-Flavor": "Google"}, + ) + if response.status_code == 200: + return response.text + else: + raise ValueError("Google OIDC provider failed") + case "circleci": + # https://circleci.com/docs/openid-connect-tokens/ + env_secret = os.getenv("CIRCLE_OIDC_TOKEN") + if env_secret is None: + raise ValueError("CIRCLE_OIDC_TOKEN not found in environment") + return env_secret + case "circleci_v2": + # https://circleci.com/docs/openid-connect-tokens/ + env_secret = os.getenv("CIRCLE_OIDC_TOKEN_V2") + if env_secret is None: + raise ValueError("CIRCLE_OIDC_TOKEN_V2 not found in environment") + return env_secret + case "github": + # https://docs.github.com/en/actions/deployment/security-hardening-your-deployments/configuring-openid-connect-in-cloud-providers#using-custom-actions + actions_id_token_request_url = os.getenv("ACTIONS_ID_TOKEN_REQUEST_URL") + actions_id_token_request_token = os.getenv("ACTIONS_ID_TOKEN_REQUEST_TOKEN") + if actions_id_token_request_url is None or actions_id_token_request_token is None: + raise ValueError("ACTIONS_ID_TOKEN_REQUEST_URL or ACTIONS_ID_TOKEN_REQUEST_TOKEN not found in environment") + client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0)) + response = client.get( + actions_id_token_request_url, + params={"audience": oidc_aud}, + headers={ + "Authorization": f"Bearer {actions_id_token_request_token}", + "Accept": "application/json; api-version=2.0", + }, + ) + if response.status_code == 200: + return response.text['value'] + else: + raise ValueError("Github OIDC provider failed") + case _: + raise ValueError("Unsupported OIDC provider") + + try: if litellm.secret_manager_client is not None: try: From 3ee0328b044e84e6bd7c3085681a468e69c3c646 Mon Sep 17 00:00:00 2001 From: David Manouchehri Date: Tue, 7 May 2024 15:01:46 +0000 Subject: [PATCH 2/7] feat(bedrock.py): Support using OIDC tokens. --- litellm/llms/bedrock.py | 42 +++++++++++++++++++++++- litellm/tests/test_bedrock_completion.py | 29 ++++++++++++++++ 2 files changed, 70 insertions(+), 1 deletion(-) diff --git a/litellm/llms/bedrock.py b/litellm/llms/bedrock.py index 2f26ae4a9..fe0aca44e 100644 --- a/litellm/llms/bedrock.py +++ b/litellm/llms/bedrock.py @@ -550,6 +550,7 @@ def init_bedrock_client( aws_session_name: Optional[str] = None, aws_profile_name: Optional[str] = None, aws_role_name: Optional[str] = None, + aws_web_identity_token: Optional[str] = None, extra_headers: Optional[dict] = None, timeout: Optional[Union[float, httpx.Timeout]] = None, ): @@ -566,6 +567,7 @@ def init_bedrock_client( aws_session_name, aws_profile_name, aws_role_name, + aws_web_identity_token, ] # Iterate over parameters and update if needed @@ -581,6 +583,7 @@ def init_bedrock_client( aws_session_name, aws_profile_name, aws_role_name, + aws_web_identity_token, ) = params_to_check ### SET REGION NAME @@ -619,7 +622,38 @@ def init_bedrock_client( config = boto3.session.Config() ### CHECK STS ### - if aws_role_name is not None and aws_session_name is not None: + if aws_web_identity_token is not None and aws_role_name is not None and aws_session_name is not None: + oidc_token = get_secret(aws_web_identity_token) + + if oidc_token is None: + raise BedrockError( + message="OIDC token could not be retrieved from secret manager.", + status_code=401, + ) + + sts_client = boto3.client( + "sts" + ) + + # https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html + # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html + sts_response = sts_client.assume_role_with_web_identity( + RoleArn=aws_role_name, + RoleSessionName=aws_session_name, + WebIdentityToken=oidc_token, + DurationSeconds=3600, + ) + + client = boto3.client( + service_name="bedrock-runtime", + aws_access_key_id=sts_response["Credentials"]["AccessKeyId"], + aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"], + aws_session_token=sts_response["Credentials"]["SessionToken"], + region_name=region_name, + endpoint_url=endpoint_url, + config=config, + ) + elif aws_role_name is not None and aws_session_name is not None: # use sts if role name passed in sts_client = boto3.client( "sts", @@ -752,6 +786,7 @@ def completion( aws_bedrock_runtime_endpoint = optional_params.pop( "aws_bedrock_runtime_endpoint", None ) + aws_web_identity_token = optional_params.pop("aws_web_identity_token", None) # use passed in BedrockRuntime.Client if provided, otherwise create a new one client = optional_params.pop("aws_bedrock_client", None) @@ -766,6 +801,7 @@ def completion( aws_role_name=aws_role_name, aws_session_name=aws_session_name, aws_profile_name=aws_profile_name, + aws_web_identity_token=aws_web_identity_token, extra_headers=extra_headers, timeout=timeout, ) @@ -1288,6 +1324,7 @@ def embedding( aws_bedrock_runtime_endpoint = optional_params.pop( "aws_bedrock_runtime_endpoint", None ) + aws_web_identity_token = optional_params.pop("aws_web_identity_token", None) # use passed in BedrockRuntime.Client if provided, otherwise create a new one client = init_bedrock_client( @@ -1295,6 +1332,7 @@ def embedding( aws_secret_access_key=aws_secret_access_key, aws_region_name=aws_region_name, aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint, + aws_web_identity_token=aws_web_identity_token, aws_role_name=aws_role_name, aws_session_name=aws_session_name, ) @@ -1377,6 +1415,7 @@ def image_generation( aws_bedrock_runtime_endpoint = optional_params.pop( "aws_bedrock_runtime_endpoint", None ) + aws_web_identity_token = optional_params.pop("aws_web_identity_token", None) # use passed in BedrockRuntime.Client if provided, otherwise create a new one client = init_bedrock_client( @@ -1384,6 +1423,7 @@ def image_generation( aws_secret_access_key=aws_secret_access_key, aws_region_name=aws_region_name, aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint, + aws_web_identity_token=aws_web_identity_token, aws_role_name=aws_role_name, aws_session_name=aws_session_name, timeout=timeout, diff --git a/litellm/tests/test_bedrock_completion.py b/litellm/tests/test_bedrock_completion.py index 3f5c831d7..ef6774fd2 100644 --- a/litellm/tests/test_bedrock_completion.py +++ b/litellm/tests/test_bedrock_completion.py @@ -206,6 +206,35 @@ def test_completion_bedrock_claude_sts_client_auth(): # test_completion_bedrock_claude_sts_client_auth() +@pytest.mark.skipif(os.environ.get('CIRCLE_OIDC_TOKEN_V2') is None, reason="CIRCLE_OIDC_TOKEN_V2 is not set") +def test_completion_bedrock_claude_sts_oidc_auth(): + print("\ncalling bedrock claude with oidc auth") + import os + + aws_web_identity_token = "oidc/circleci_v2/" + aws_region_name = os.environ["AWS_REGION_NAME"] + aws_role_name = os.environ["AWS_TEMP_ROLE_NAME"] + + try: + litellm.set_verbose = True + + response = completion( + model="bedrock/anthropic.claude-instant-v1", + messages=messages, + max_tokens=10, + temperature=0.1, + aws_region_name=aws_region_name, + aws_web_identity_token=aws_web_identity_token, + aws_role_name=aws_role_name, + aws_session_name="my-test-session", + ) + # Add any assertions here to check the response + print(response) + except RateLimitError: + pass + except Exception as e: + pytest.fail(f"Error occurred: {e}") + def test_bedrock_extra_headers(): try: From e268354acc3dbdbf0a2313aa7fd87fd485774b34 Mon Sep 17 00:00:00 2001 From: David Manouchehri Date: Tue, 7 May 2024 19:18:28 +0000 Subject: [PATCH 3/7] feat(azure.py): Support OIDC auth --- litellm/llms/azure.py | 66 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) diff --git a/litellm/llms/azure.py b/litellm/llms/azure.py index e7af9d43b..c2bbe54c1 100644 --- a/litellm/llms/azure.py +++ b/litellm/llms/azure.py @@ -8,6 +8,7 @@ from litellm.utils import ( CustomStreamWrapper, convert_to_model_response_object, TranscriptionResponse, + get_secret, ) from typing import Callable, Optional, BinaryIO from litellm import OpenAIConfig @@ -16,6 +17,7 @@ import httpx from .custom_httpx.azure_dall_e_2 import CustomHTTPTransport, AsyncCustomHTTPTransport from openai import AzureOpenAI, AsyncAzureOpenAI import uuid +import os class AzureOpenAIError(Exception): @@ -126,6 +128,51 @@ def select_azure_base_url_or_endpoint(azure_client_params: dict): return azure_client_params +def get_azure_ad_token_from_oidc(azure_ad_token: str): + azure_client_id = os.getenv("AZURE_CLIENT_ID", None) + azure_tenant = os.getenv("AZURE_TENANT_ID", None) + + if azure_client_id is None or azure_tenant is None: + raise AzureOpenAIError( + status_code=422, + message="AZURE_CLIENT_ID and AZURE_TENANT_ID must be set", + ) + + oidc_token = get_secret(azure_ad_token) + + if oidc_token is None: + raise AzureOpenAIError( + status_code=401, + message="OIDC token could not be retrieved from secret manager.", + ) + + req_token = httpx.get( + f"https://login.microsoftonline.com/{azure_tenant}/oauth2/v2.0/token", + data={ + "client_id": azure_client_id, + "grant_type": "client_credentials", + "scope": "https://cognitiveservices.azure.com/.default", + "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer", + "client_assertion": oidc_token, + }, + ) + + if req_token.status_code != 200: + raise AzureOpenAIError( + status_code=req_token.status_code, + message=req_token.text, + ) + + possible_azure_ad_token = req_token.json().get("access_token", None) + + if possible_azure_ad_token is None: + raise AzureOpenAIError( + status_code=422, message="Azure AD Token not returned" + ) + + return possible_azure_ad_token + + class AzureChatCompletion(BaseLLM): def __init__(self) -> None: super().__init__() @@ -137,6 +184,8 @@ class AzureChatCompletion(BaseLLM): if api_key is not None: headers["api-key"] = api_key elif azure_ad_token is not None: + if azure_ad_token.startswith("oidc/"): + azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token) headers["Authorization"] = f"Bearer {azure_ad_token}" return headers @@ -189,6 +238,9 @@ class AzureChatCompletion(BaseLLM): if api_key is not None: azure_client_params["api_key"] = api_key elif azure_ad_token is not None: + if azure_ad_token.startswith("oidc/"): + azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token) + azure_client_params["azure_ad_token"] = azure_ad_token if acompletion is True: @@ -276,6 +328,8 @@ class AzureChatCompletion(BaseLLM): if api_key is not None: azure_client_params["api_key"] = api_key elif azure_ad_token is not None: + if azure_ad_token.startswith("oidc/"): + azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token) azure_client_params["azure_ad_token"] = azure_ad_token if client is None: azure_client = AzureOpenAI(**azure_client_params) @@ -351,6 +405,8 @@ class AzureChatCompletion(BaseLLM): if api_key is not None: azure_client_params["api_key"] = api_key elif azure_ad_token is not None: + if azure_ad_token.startswith("oidc/"): + azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token) azure_client_params["azure_ad_token"] = azure_ad_token # setting Azure client @@ -422,6 +478,8 @@ class AzureChatCompletion(BaseLLM): if api_key is not None: azure_client_params["api_key"] = api_key elif azure_ad_token is not None: + if azure_ad_token.startswith("oidc/"): + azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token) azure_client_params["azure_ad_token"] = azure_ad_token if client is None: azure_client = AzureOpenAI(**azure_client_params) @@ -478,6 +536,8 @@ class AzureChatCompletion(BaseLLM): if api_key is not None: azure_client_params["api_key"] = api_key elif azure_ad_token is not None: + if azure_ad_token.startswith("oidc/"): + azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token) azure_client_params["azure_ad_token"] = azure_ad_token if client is None: azure_client = AsyncAzureOpenAI(**azure_client_params) @@ -599,6 +659,8 @@ class AzureChatCompletion(BaseLLM): if api_key is not None: azure_client_params["api_key"] = api_key elif azure_ad_token is not None: + if azure_ad_token.startswith("oidc/"): + azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token) azure_client_params["azure_ad_token"] = azure_ad_token ## LOGGING @@ -755,6 +817,8 @@ class AzureChatCompletion(BaseLLM): if api_key is not None: azure_client_params["api_key"] = api_key elif azure_ad_token is not None: + if azure_ad_token.startswith("oidc/"): + azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token) azure_client_params["azure_ad_token"] = azure_ad_token if aimg_generation == True: @@ -833,6 +897,8 @@ class AzureChatCompletion(BaseLLM): if api_key is not None: azure_client_params["api_key"] = api_key elif azure_ad_token is not None: + if azure_ad_token.startswith("oidc/"): + azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token) azure_client_params["azure_ad_token"] = azure_ad_token if max_retries is not None: From 9a0bb36865a3ed10325eab67f5d74169fd152891 Mon Sep 17 00:00:00 2001 From: David Manouchehri Date: Tue, 7 May 2024 19:43:18 +0000 Subject: [PATCH 4/7] fix+feat(router.py): Fix missing azure_ad_token, and allow use OIDC auth --- litellm/router.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/litellm/router.py b/litellm/router.py index 4353da804..24a926e5d 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -45,6 +45,7 @@ from litellm.types.router import ( RetryPolicy, ) from litellm.integrations.custom_logger import CustomLogger +from litellm.llms.azure import get_azure_ad_token_from_oidc class Router: @@ -2089,6 +2090,10 @@ class Router: raise ValueError( f"api_base is required for Azure OpenAI. Set it on your config. Model - {model}" ) + azure_ad_token = litellm_params.get("azure_ad_token") + if azure_ad_token is not None: + if azure_ad_token.startswith("oidc/"): + azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token) if api_version is None: api_version = "2023-07-01-preview" if "gateway.ai.cloudflare.com" in api_base: @@ -2099,6 +2104,7 @@ class Router: cache_key = f"{model_id}_async_client" _client = openai.AsyncAzureOpenAI( api_key=api_key, + azure_ad_token=azure_ad_token, base_url=api_base, api_version=api_version, timeout=timeout, @@ -2123,6 +2129,7 @@ class Router: cache_key = f"{model_id}_client" _client = openai.AzureOpenAI( # type: ignore api_key=api_key, + azure_ad_token=azure_ad_token, base_url=api_base, api_version=api_version, timeout=timeout, @@ -2147,6 +2154,7 @@ class Router: cache_key = f"{model_id}_stream_async_client" _client = openai.AsyncAzureOpenAI( # type: ignore api_key=api_key, + azure_ad_token=azure_ad_token, base_url=api_base, api_version=api_version, timeout=stream_timeout, @@ -2171,6 +2179,7 @@ class Router: cache_key = f"{model_id}_stream_client" _client = openai.AzureOpenAI( # type: ignore api_key=api_key, + azure_ad_token=azure_ad_token, base_url=api_base, api_version=api_version, timeout=stream_timeout, From cb49fb004d7bba866d229760931256bf96971f8e Mon Sep 17 00:00:00 2001 From: David Manouchehri Date: Tue, 7 May 2024 19:51:57 +0000 Subject: [PATCH 5/7] fix(azure.py): Correct invalid .get to a .post for OIDC --- litellm/llms/azure.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litellm/llms/azure.py b/litellm/llms/azure.py index c2bbe54c1..1e807b5e7 100644 --- a/litellm/llms/azure.py +++ b/litellm/llms/azure.py @@ -146,7 +146,7 @@ def get_azure_ad_token_from_oidc(azure_ad_token: str): message="OIDC token could not be retrieved from secret manager.", ) - req_token = httpx.get( + req_token = httpx.post( f"https://login.microsoftonline.com/{azure_tenant}/oauth2/v2.0/token", data={ "client_id": azure_client_id, From 44b1b219115f9d4804ca9dec2b82441059a3c128 Mon Sep 17 00:00:00 2001 From: David Manouchehri Date: Tue, 7 May 2024 21:20:15 +0000 Subject: [PATCH 6/7] feat(utils.py) - Add OIDC caching for Google Cloud Run and GitHub Actions. --- litellm/utils.py | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/litellm/utils.py b/litellm/utils.py index 652406f64..f5d3b974b 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -34,6 +34,8 @@ from dataclasses import ( import litellm._service_logger # for storing API inputs, outputs, and metadata from litellm.llms.custom_httpx.http_handler import HTTPHandler +from litellm.caching import DualCache +oidc_cache = DualCache() try: # this works in python 3.8 @@ -9291,11 +9293,15 @@ def get_secret( # Example: oidc/google/https://bedrock-runtime.us-east-1.amazonaws.com/model/stability.stable-diffusion-xl-v1/invoke if secret_name.startswith("oidc/"): - secret_name = secret_name.replace("oidc/", "") - oidc_provider, oidc_aud = secret_name.split("/", 1) + secret_name_split = secret_name.replace("oidc/", "") + oidc_provider, oidc_aud = secret_name_split.split("/", 1) # TODO: Add caching for HTTP requests match oidc_provider: case "google": + oidc_token = oidc_cache.get_cache(key=secret_name) + if oidc_token is not None: + return oidc_token + client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0)) # https://cloud.google.com/compute/docs/instances/verifying-instance-identity#request_signature response = client.get( @@ -9304,7 +9310,9 @@ def get_secret( headers={"Metadata-Flavor": "Google"}, ) if response.status_code == 200: - return response.text + oidc_token = response.text + oidc_cache.set_cache(key=secret_name, value=oidc_token, ttl=3600 - 60) + return oidc_token else: raise ValueError("Google OIDC provider failed") case "circleci": @@ -9325,6 +9333,11 @@ def get_secret( actions_id_token_request_token = os.getenv("ACTIONS_ID_TOKEN_REQUEST_TOKEN") if actions_id_token_request_url is None or actions_id_token_request_token is None: raise ValueError("ACTIONS_ID_TOKEN_REQUEST_URL or ACTIONS_ID_TOKEN_REQUEST_TOKEN not found in environment") + + oidc_token = oidc_cache.get_cache(key=secret_name) + if oidc_token is not None: + return oidc_token + client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0)) response = client.get( actions_id_token_request_url, @@ -9335,7 +9348,9 @@ def get_secret( }, ) if response.status_code == 200: - return response.text['value'] + oidc_token = response.text['value'] + oidc_cache.set_cache(key=secret_name, value=oidc_token, ttl=300 - 5) + return oidc_token else: raise ValueError("Github OIDC provider failed") case _: From 77b3acb396269ac31dc1916c17ff2b9e430d1585 Mon Sep 17 00:00:00 2001 From: David Manouchehri Date: Wed, 8 May 2024 14:34:45 +0000 Subject: [PATCH 7/7] fix(router.py): Add missing azure_ad_token param. --- litellm/router.py | 1 + 1 file changed, 1 insertion(+) diff --git a/litellm/router.py b/litellm/router.py index 24a926e5d..ea8d17286 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -2212,6 +2212,7 @@ class Router: "api_key": api_key, "azure_endpoint": api_base, "api_version": api_version, + "azure_ad_token": azure_ad_token, } from litellm.llms.azure import select_azure_base_url_or_endpoint