diff --git a/litellm/llms/azure/azure.py b/litellm/llms/azure/azure.py index 84e02bbf95..d0875412f6 100644 --- a/litellm/llms/azure/azure.py +++ b/litellm/llms/azure/azure.py @@ -1152,6 +1152,7 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM): azure_ad_token_provider: Optional[Callable] = None, client=None, aimg_generation=None, + litellm_params: Optional[dict] = None, ) -> ImageResponse: try: if model and len(model) > 0: @@ -1176,25 +1177,13 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM): ) # init AzureOpenAI Client - azure_client_params: Dict[str, Any] = { - "api_version": api_version, - "azure_endpoint": api_base, - "azure_deployment": model, - "max_retries": max_retries, - "timeout": timeout, - } - azure_client_params = select_azure_base_url_or_endpoint( - azure_client_params=azure_client_params + azure_client_params: Dict[str, Any] = self.initialize_azure_sdk_client( + litellm_params=litellm_params or {}, + api_key=api_key, + model_name=model or "", + api_version=api_version, + api_base=api_base, ) - if api_key is not None: - azure_client_params["api_key"] = api_key - elif azure_ad_token is not None: - if azure_ad_token.startswith("oidc/"): - azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token) - azure_client_params["azure_ad_token"] = azure_ad_token - elif azure_ad_token_provider is not None: - azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider - if aimg_generation is True: return self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_key=api_key, client=client, azure_client_params=azure_client_params, timeout=timeout, headers=headers) # type: ignore diff --git a/litellm/main.py b/litellm/main.py index 997c1ae75d..b0a4268106 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -4544,6 +4544,8 @@ def image_generation( # noqa: PLR0915 **non_default_params, ) + litellm_params_dict = get_litellm_params(**kwargs) + logging: Logging = litellm_logging_obj logging.update_environment_variables( model=model, @@ -4614,6 +4616,7 @@ def image_generation( # noqa: PLR0915 aimg_generation=aimg_generation, client=client, headers=headers, + litellm_params=litellm_params_dict, ) elif ( custom_llm_provider == "openai"