diff --git a/litellm/llms/azure/azure.py b/litellm/llms/azure/azure.py index c42657d87b..6613ce57ba 100644 --- a/litellm/llms/azure/azure.py +++ b/litellm/llms/azure/azure.py @@ -141,41 +141,6 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM): return headers - def _get_azure_openai_client( - self, - api_version: Optional[str], - api_base: Optional[str], - api_key: Optional[str], - azure_ad_token: Optional[str], - azure_ad_token_provider: Optional[Callable], - model: str, - max_retries: Optional[int], - timeout: Optional[Union[float, httpx.Timeout]], - client: Optional[Any], - client_type: Literal["sync", "async"], - litellm_params: Optional[dict] = None, - ): - # init AzureOpenAI Client - azure_client_params: Dict[str, Any] = self.initialize_azure_sdk_client( - litellm_params=litellm_params or {}, - api_key=api_key, - model_name=model, - api_version=api_version, - api_base=api_base, - ) - if client is None: - if client_type == "sync": - azure_client = AzureOpenAI(**azure_client_params) # type: ignore - elif client_type == "async": - azure_client = AsyncAzureOpenAI(**azure_client_params) # type: ignore - else: - azure_client = client - if api_version is not None and isinstance(azure_client._custom_query, dict): - # set api_version to version passed by user - azure_client._custom_query.setdefault("api-version", api_version) - - return azure_client - def make_sync_azure_openai_chat_completion_request( self, azure_client: AzureOpenAI, @@ -388,17 +353,13 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM): status_code=422, message="max retries must be an int" ) # init AzureOpenAI Client - azure_client = self._get_azure_openai_client( + azure_client = self.get_azure_openai_client( api_version=api_version, api_base=api_base, api_key=api_key, - azure_ad_token=azure_ad_token, - azure_ad_token_provider=azure_ad_token_provider, model=model, - max_retries=max_retries, - timeout=timeout, client=client, - client_type="sync", + _is_async=False, litellm_params=litellm_params, ) if not isinstance(azure_client, AzureOpenAI): @@ -466,17 +427,13 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM): response = None try: # setting Azure client - azure_client = self._get_azure_openai_client( + azure_client = self.get_azure_openai_client( api_version=api_version, api_base=api_base, api_key=api_key, - azure_ad_token=azure_ad_token, - azure_ad_token_provider=azure_ad_token_provider, model=model, - max_retries=max_retries, - timeout=timeout, client=client, - client_type="async", + _is_async=True, litellm_params=litellm_params, ) if not isinstance(azure_client, AsyncAzureOpenAI): @@ -589,17 +546,13 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM): elif azure_ad_token_provider is not None: azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider - azure_client = self._get_azure_openai_client( + azure_client = self.get_azure_openai_client( api_version=api_version, api_base=api_base, api_key=api_key, - azure_ad_token=azure_ad_token, - azure_ad_token_provider=azure_ad_token_provider, model=model, - max_retries=max_retries, - timeout=timeout, client=client, - client_type="sync", + _is_async=False, litellm_params=litellm_params, ) if not isinstance(azure_client, AzureOpenAI): @@ -652,17 +605,13 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM): litellm_params: Optional[dict] = {}, ): try: - azure_client = self._get_azure_openai_client( + azure_client = self.get_azure_openai_client( api_version=api_version, api_base=api_base, api_key=api_key, - azure_ad_token=azure_ad_token, - azure_ad_token_provider=azure_ad_token_provider, model=model, - max_retries=max_retries, - timeout=timeout, client=client, - client_type="async", + _is_async=True, litellm_params=litellm_params, ) if not isinstance(azure_client, AsyncAzureOpenAI): @@ -737,17 +686,13 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM): response = None try: - openai_aclient = self._get_azure_openai_client( + openai_aclient = self.get_azure_openai_client( api_version=api_version, api_base=api_base, api_key=api_key, - azure_ad_token=azure_ad_token, - azure_ad_token_provider=azure_ad_token_provider, model=model, - max_retries=max_retries, - timeout=timeout, + _is_async=True, client=client, - client_type="async", litellm_params=litellm_params, ) if not isinstance(openai_aclient, AsyncAzureOpenAI): @@ -846,17 +791,13 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM): client=client, litellm_params=litellm_params, ) - azure_client = self._get_azure_openai_client( + azure_client = self.get_azure_openai_client( api_version=api_version, api_base=api_base, api_key=api_key, - azure_ad_token=azure_ad_token, - azure_ad_token_provider=azure_ad_token_provider, model=model, - max_retries=max_retries, - timeout=timeout, + _is_async=False, client=client, - client_type="sync", litellm_params=litellm_params, ) if not isinstance(azure_client, AzureOpenAI): @@ -1315,17 +1256,13 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM): litellm_params=litellm_params, ) # type: ignore - azure_client: AzureOpenAI = self._get_azure_openai_client( + azure_client: AzureOpenAI = self.get_azure_openai_client( api_base=api_base, api_version=api_version, api_key=api_key, - azure_ad_token=azure_ad_token, - azure_ad_token_provider=azure_ad_token_provider, model=model, - max_retries=max_retries, - timeout=timeout, + _is_async=False, client=client, - client_type="sync", litellm_params=litellm_params, ) # type: ignore @@ -1354,17 +1291,13 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM): litellm_params: Optional[dict] = None, ) -> HttpxBinaryResponseContent: - azure_client: AsyncAzureOpenAI = self._get_azure_openai_client( + azure_client: AsyncAzureOpenAI = self.get_azure_openai_client( api_base=api_base, api_version=api_version, api_key=api_key, - azure_ad_token=azure_ad_token, - azure_ad_token_provider=azure_ad_token_provider, model=model, - max_retries=max_retries, - timeout=timeout, + _is_async=True, client=client, - client_type="async", litellm_params=litellm_params, ) # type: ignore diff --git a/litellm/llms/azure/common_utils.py b/litellm/llms/azure/common_utils.py index 909fcd88a5..24eb758653 100644 --- a/litellm/llms/azure/common_utils.py +++ b/litellm/llms/azure/common_utils.py @@ -247,20 +247,21 @@ def select_azure_base_url_or_endpoint(azure_client_params: dict): class BaseAzureLLM: def get_azure_openai_client( self, - litellm_params: dict, api_key: Optional[str], api_base: Optional[str], api_version: Optional[str] = None, client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None, + litellm_params: Optional[dict] = None, _is_async: bool = False, + model: Optional[str] = None, ) -> Optional[Union[AzureOpenAI, AsyncAzureOpenAI]]: openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None if client is None: azure_client_params = self.initialize_azure_sdk_client( - litellm_params=litellm_params, + litellm_params=litellm_params or {}, api_key=api_key, api_base=api_base, - model_name="", + model_name=model, api_version=api_version, ) if _is_async is True: @@ -269,6 +270,11 @@ class BaseAzureLLM: openai_client = AzureOpenAI(**azure_client_params) # type: ignore else: openai_client = client + if api_version is not None and isinstance( + openai_client._custom_query, dict + ): + # set api_version to version passed by user + openai_client._custom_query.setdefault("api-version", api_version) return openai_client @@ -277,7 +283,7 @@ class BaseAzureLLM: litellm_params: dict, api_key: Optional[str], api_base: Optional[str], - model_name: str, + model_name: Optional[str], api_version: Optional[str], ) -> dict: