From 7914623fbca2e521ffdd9dff061048e1f388197f Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Tue, 28 Nov 2023 15:19:50 -0800 Subject: [PATCH] (feat) allow users to pass azure client for acompletion --- litellm/llms/azure.py | 13 +++++++++---- litellm/main.py | 3 ++- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/litellm/llms/azure.py b/litellm/llms/azure.py index 122c874adc..90030bc1dd 100644 --- a/litellm/llms/azure.py +++ b/litellm/llms/azure.py @@ -104,7 +104,9 @@ class AzureChatCompletion(BaseLLM): litellm_params, logger_fn, acompletion: bool = False, - headers: Optional[dict]=None): + headers: Optional[dict]=None, + azure_client = None, + ): super().completion() exception_mapping_worked = False try: @@ -135,7 +137,7 @@ class AzureChatCompletion(BaseLLM): if optional_params.get("stream", False): return self.async_streaming(logging_obj=logging_obj, api_base=api_base, data=data, model=model, api_key=api_key, api_version=api_version, azure_ad_token=azure_ad_token, timeout=timeout) else: - return self.acompletion(api_base=api_base, data=data, model_response=model_response, api_key=api_key, api_version=api_version, model=model, azure_ad_token=azure_ad_token, timeout=timeout) + return self.acompletion(api_base=api_base, data=data, model_response=model_response, api_key=api_key, api_version=api_version, model=model, azure_ad_token=azure_ad_token, timeout=timeout, azure_client=azure_client) elif "stream" in optional_params and optional_params["stream"] == True: return self.streaming(logging_obj=logging_obj, api_base=api_base, data=data, model=model, api_key=api_key, api_version=api_version, azure_ad_token=azure_ad_token, timeout=timeout) else: @@ -173,7 +175,9 @@ class AzureChatCompletion(BaseLLM): data: dict, timeout: Any, model_response: ModelResponse, - azure_ad_token: Optional[str]=None, ): + azure_ad_token: Optional[str]=None, + azure_client = None, # this is the AsyncAzureOpenAI + ): response = None try: max_retries = data.pop("max_retries", 2) 
@@ -192,7 +196,8 @@ class AzureChatCompletion(BaseLLM): azure_client_params["api_key"] = api_key elif azure_ad_token is not None: azure_client_params["azure_ad_token"] = azure_ad_token - azure_client = AsyncAzureOpenAI(**azure_client_params) + if azure_client is None: + azure_client = AsyncAzureOpenAI(**azure_client_params) response = await azure_client.chat.completions.create(**data) return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response) except AzureOpenAIError as e: diff --git a/litellm/main.py b/litellm/main.py index 3f096535d6..4656f7c98a 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -497,7 +497,8 @@ def completion( logger_fn=logger_fn, logging_obj=logging, acompletion=acompletion, - timeout=timeout + timeout=timeout, + azure_client=optional_params.pop("azure_client", None) ) ## LOGGING