diff --git a/litellm/llms/azure/azure.py b/litellm/llms/azure/azure.py
index 0f155af427..caeef460d2 100644
--- a/litellm/llms/azure/azure.py
+++ b/litellm/llms/azure/azure.py
@@ -261,6 +261,13 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
             max_retries = DEFAULT_MAX_RETRIES
         json_mode: Optional[bool] = optional_params.pop("json_mode", False)
 
+        azure_client_params = self.initialize_azure_sdk_client(
+            litellm_params=litellm_params or {},
+            api_key=api_key,
+            api_base=api_base,
+            model_name=model,
+            api_version=api_version,
+        )
         ### CHECK IF CLOUDFLARE AI GATEWAY ###
         ### if so - set the model as part of the base url
         if "gateway.ai.cloudflare.com" in api_base:
@@ -321,6 +328,7 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
                         timeout=timeout,
                         client=client,
                         max_retries=max_retries,
+                        azure_client_params=azure_client_params,
                     )
                 else:
                     return self.acompletion(
@@ -338,7 +346,7 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
                         logging_obj=logging_obj,
                         max_retries=max_retries,
                         convert_tool_call_to_json_mode=json_mode,
-                        litellm_params=litellm_params,
+                        azure_client_params=azure_client_params,
                     )
             elif "stream" in optional_params and optional_params["stream"] is True:
                 return self.streaming(
@@ -375,28 +383,6 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
                         status_code=422, message="max retries must be an int"
                     )
                 # init AzureOpenAI Client
-                azure_client_params = {
-                    "api_version": api_version,
-                    "azure_endpoint": api_base,
-                    "azure_deployment": model,
-                    "http_client": litellm.client_session,
-                    "max_retries": max_retries,
-                    "timeout": timeout,
-                }
-                azure_client_params = select_azure_base_url_or_endpoint(
-                    azure_client_params=azure_client_params
-                )
-                if api_key is not None:
-                    azure_client_params["api_key"] = api_key
-                elif azure_ad_token is not None:
-                    if azure_ad_token.startswith("oidc/"):
-                        azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
-                    azure_client_params["azure_ad_token"] = azure_ad_token
-                elif azure_ad_token_provider is not None:
-                    azure_client_params["azure_ad_token_provider"] = (
-                        azure_ad_token_provider
-                    )
-
                 if (
                     client is None
                     or not isinstance(client, AzureOpenAI)
@@ -467,19 +453,10 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
         azure_ad_token_provider: Optional[Callable] = None,
         convert_tool_call_to_json_mode: Optional[bool] = None,
         client=None,  # this is the AsyncAzureOpenAI
-        litellm_params: Optional[dict] = None,
+        azure_client_params: dict = {},
     ):
         response = None
         try:
-            # init AzureOpenAI Client
-            azure_client_params = self.initialize_azure_sdk_client(
-                litellm_params=litellm_params or {},
-                api_key=api_key,
-                api_base=api_base,
-                model_name=model,
-                api_version=api_version,
-            )
-
             # setting Azure client
             if client is None or dynamic_params:
                 azure_client = AsyncAzureOpenAI(**azure_client_params)
@@ -636,28 +613,9 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
         azure_ad_token: Optional[str] = None,
         azure_ad_token_provider: Optional[Callable] = None,
         client=None,
+        azure_client_params: dict = {},
     ):
         try:
-            # init AzureOpenAI Client
-            azure_client_params = {
-                "api_version": api_version,
-                "azure_endpoint": api_base,
-                "azure_deployment": model,
-                "http_client": litellm.aclient_session,
-                "max_retries": max_retries,
-                "timeout": timeout,
-            }
-            azure_client_params = select_azure_base_url_or_endpoint(
-                azure_client_params=azure_client_params
-            )
-            if api_key is not None:
-                azure_client_params["api_key"] = api_key
-            elif azure_ad_token is not None:
-                if azure_ad_token.startswith("oidc/"):
-                    azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
-                azure_client_params["azure_ad_token"] = azure_ad_token
-            elif azure_ad_token_provider is not None:
-                azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
             if client is None or dynamic_params:
                 azure_client = AsyncAzureOpenAI(**azure_client_params)
             else:
diff --git a/litellm/router.py b/litellm/router.py
index f573bf65a6..70ad60f450 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -5353,36 +5353,12 @@ class Router:
                 client = self.cache.get_cache(
                     key=cache_key, local_only=True, parent_otel_span=parent_otel_span
                 )
-                if client is None:
-                    """
-                    Re-initialize the client
-                    """
-                    InitalizeOpenAISDKClient.set_client(
-                        litellm_router_instance=self, model=deployment
-                    )
-                    client = self.cache.get_cache(
-                        key=cache_key,
-                        local_only=True,
-                        parent_otel_span=parent_otel_span,
-                    )
                 return client
             else:
                 cache_key = f"{model_id}_async_client"
                 client = self.cache.get_cache(
                     key=cache_key, local_only=True, parent_otel_span=parent_otel_span
                 )
-                # if client is None:
-                #     """
-                #     Re-initialize the client
-                #     """
-                #     InitalizeOpenAISDKClient.set_client(
-                #         litellm_router_instance=self, model=deployment
-                #     )
-                #     client = self.cache.get_cache(
-                #         key=cache_key,
-                #         local_only=True,
-                #         parent_otel_span=parent_otel_span,
-                #     )
                 return client
         else:
             if kwargs.get("stream") is True:
@@ -5390,32 +5366,12 @@ class Router:
                 client = self.cache.get_cache(
                     key=cache_key, parent_otel_span=parent_otel_span
                 )
-                if client is None:
-                    """
-                    Re-initialize the client
-                    """
-                    InitalizeOpenAISDKClient.set_client(
-                        litellm_router_instance=self, model=deployment
-                    )
-                    client = self.cache.get_cache(
-                        key=cache_key, parent_otel_span=parent_otel_span
-                    )
                 return client
             else:
                 cache_key = f"{model_id}_client"
                 client = self.cache.get_cache(
                     key=cache_key, parent_otel_span=parent_otel_span
                 )
-                if client is None:
-                    """
-                    Re-initialize the client
-                    """
-                    InitalizeOpenAISDKClient.set_client(
-                        litellm_router_instance=self, model=deployment
-                    )
-                    client = self.cache.get_cache(
-                        key=cache_key, parent_otel_span=parent_otel_span
-                    )
                 return client
 
     def _pre_call_checks(  # noqa: PLR0915
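
For reviewers: a minimal sketch of the pattern this diff converges on, namely building the Azure client kwargs once and unpacking them at every `AzureOpenAI` / `AsyncAzureOpenAI` construction site, instead of rebuilding the dict inline in each method. `build_azure_client_params` below is a hypothetical stand-in for litellm's `initialize_azure_sdk_client` (which additionally routes through `select_azure_base_url_or_endpoint` and handles `oidc/` token exchange); only the `openai`-package classes and keyword arguments are real.

```python
from typing import Callable, Optional

from openai import AsyncAzureOpenAI, AzureOpenAI


def build_azure_client_params(
    api_key: Optional[str],
    api_base: str,
    model_name: str,
    api_version: str,
    azure_ad_token_provider: Optional[Callable[[], str]] = None,
    max_retries: int = 2,
    timeout: float = 600.0,
) -> dict:
    """Hypothetical stand-in for initialize_azure_sdk_client:
    assemble the SDK kwargs once, in one place."""
    params: dict = {
        "api_version": api_version,
        "azure_endpoint": api_base,
        "azure_deployment": model_name,
        "max_retries": max_retries,
        "timeout": timeout,
    }
    # Exactly one credential style ends up in the dict.
    if api_key is not None:
        params["api_key"] = api_key
    elif azure_ad_token_provider is not None:
        params["azure_ad_token_provider"] = azure_ad_token_provider
    return params


azure_client_params = build_azure_client_params(
    api_key="sk-placeholder",
    api_base="https://my-resource.openai.azure.com",
    model_name="my-gpt-4o-deployment",
    api_version="2024-06-01",
)
# Sync and async/streaming paths all reuse the same dict.
sync_client = AzureOpenAI(**azure_client_params)
async_client = AsyncAzureOpenAI(**azure_client_params)
```

The router hunks fit the same design: with client construction centralized in the Azure handler, `_get_client` no longer needs the `InitalizeOpenAISDKClient.set_client` fallback to re-initialize a client on a cache miss, so it simply returns whatever the cache holds.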