From 400a2689342736a2a5c5b5fad650759365b1840b Mon Sep 17 00:00:00 2001
From: ishaan-jaff
Date: Tue, 28 Nov 2023 15:56:52 -0800
Subject: [PATCH] (feat) completion: Azure allow users to pass client to router

---
 litellm/llms/azure.py | 45 ++++++++++++++++++++++++++++++-------------
 litellm/main.py       |  2 +-
 2 files changed, 33 insertions(+), 14 deletions(-)

diff --git a/litellm/llms/azure.py b/litellm/llms/azure.py
index 90030bc1dd..1066bca3f2 100644
--- a/litellm/llms/azure.py
+++ b/litellm/llms/azure.py
@@ -105,7 +105,7 @@ class AzureChatCompletion(BaseLLM):
                logger_fn,
                acompletion: bool = False,
                headers: Optional[dict]=None,
-               azure_client = None,
+               client = None,
                ):
         super().completion()
         exception_mapping_worked = False
@@ -135,11 +135,11 @@ class AzureChatCompletion(BaseLLM):
             )
             if acompletion is True:
                 if optional_params.get("stream", False):
-                    return self.async_streaming(logging_obj=logging_obj, api_base=api_base, data=data, model=model, api_key=api_key, api_version=api_version, azure_ad_token=azure_ad_token, timeout=timeout)
+                    return self.async_streaming(logging_obj=logging_obj, api_base=api_base, data=data, model=model, api_key=api_key, api_version=api_version, azure_ad_token=azure_ad_token, timeout=timeout, client=client)
                 else:
-                    return self.acompletion(api_base=api_base, data=data, model_response=model_response, api_key=api_key, api_version=api_version, model=model, azure_ad_token=azure_ad_token, timeout=timeout, azure_client=azure_client)
+                    return self.acompletion(api_base=api_base, data=data, model_response=model_response, api_key=api_key, api_version=api_version, model=model, azure_ad_token=azure_ad_token, timeout=timeout, client=client)
             elif "stream" in optional_params and optional_params["stream"] == True:
-                return self.streaming(logging_obj=logging_obj, api_base=api_base, data=data, model=model, api_key=api_key, api_version=api_version, azure_ad_token=azure_ad_token, timeout=timeout)
+                return self.streaming(logging_obj=logging_obj, api_base=api_base, data=data, model=model, api_key=api_key, api_version=api_version, azure_ad_token=azure_ad_token, timeout=timeout, client=client)
             else:
                 max_retries = data.pop("max_retries", 2)
                 if not isinstance(max_retries, int):
@@ -157,7 +157,10 @@ class AzureChatCompletion(BaseLLM):
                     azure_client_params["api_key"] = api_key
                 elif azure_ad_token is not None:
                     azure_client_params["azure_ad_token"] = azure_ad_token
-                azure_client = AzureOpenAI(**azure_client_params)
+                if client is None:
+                    azure_client = AzureOpenAI(**azure_client_params)
+                else:
+                    azure_client = client
                 response = azure_client.chat.completions.create(**data) # type: ignore
                 response.model = "azure/" + str(response.model)
                 return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
@@ -176,7 +179,7 @@ class AzureChatCompletion(BaseLLM):
                           timeout: Any,
                           model_response: ModelResponse,
                           azure_ad_token: Optional[str]=None,
-                          azure_client = None, # this is the AsyncAzureOpenAI
+                          client = None, # this is the AsyncAzureOpenAI
                           ):
         response = None
         try:
@@ -196,8 +199,10 @@ class AzureChatCompletion(BaseLLM):
                 azure_client_params["api_key"] = api_key
             elif azure_ad_token is not None:
                 azure_client_params["azure_ad_token"] = azure_ad_token
-            if azure_client is None:
+            if client is None:
                 azure_client = AsyncAzureOpenAI(**azure_client_params)
+            else:
+                azure_client = client
             response = await azure_client.chat.completions.create(**data)
             return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
         except AzureOpenAIError as e:
@@ -215,6 +220,7 @@ class AzureChatCompletion(BaseLLM):
                   model: str,
                   timeout: Any,
                   azure_ad_token: Optional[str]=None,
+                  client=None,
                   ):
         max_retries = data.pop("max_retries", 2)
         if not isinstance(max_retries, int):
@@ -232,8 +238,11 @@ class AzureChatCompletion(BaseLLM):
             azure_client_params["api_key"] = api_key
         elif azure_ad_token is not None:
             azure_client_params["azure_ad_token"] = azure_ad_token
-        azure_client = AzureOpenAI(**azure_client_params)
-        response = azure_client.chat.completions.create(**data)
+        if client is None:
+            azure_client = AzureOpenAI(**azure_client_params)
+        else:
+            azure_client = client
+        response = azure_client.chat.completions.create(**data)
         streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="azure",logging_obj=logging_obj)
         for transformed_chunk in streamwrapper:
             yield transformed_chunk
@@ -246,7 +255,9 @@ class AzureChatCompletion(BaseLLM):
                         data: dict,
                         model: str,
                         timeout: Any,
-                        azure_ad_token: Optional[str]=None):
+                        azure_ad_token: Optional[str]=None,
+                        client = None,
+                        ):
         # init AzureOpenAI Client
         azure_client_params = {
             "api_version": api_version,
@@ -260,7 +271,10 @@ class AzureChatCompletion(BaseLLM):
             azure_client_params["api_key"] = api_key
         elif azure_ad_token is not None:
             azure_client_params["azure_ad_token"] = azure_ad_token
-        azure_client = AsyncAzureOpenAI(**azure_client_params)
+        if client is None:
+            azure_client = AsyncAzureOpenAI(**azure_client_params)
+        else:
+            azure_client = client
         response = await azure_client.chat.completions.create(**data)
         streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="azure",logging_obj=logging_obj)
         async for transformed_chunk in streamwrapper:
@@ -276,7 +290,9 @@ class AzureChatCompletion(BaseLLM):
                  logging_obj=None,
                  model_response=None,
                  optional_params=None,
-                 azure_ad_token: Optional[str]=None):
+                 azure_ad_token: Optional[str]=None,
+                 client = None
+                 ):
         super().embedding()
         exception_mapping_worked = False
         if self._client_session is None:
@@ -304,7 +320,10 @@ class AzureChatCompletion(BaseLLM):
             azure_client_params["api_key"] = api_key
         elif azure_ad_token is not None:
             azure_client_params["azure_ad_token"] = azure_ad_token
-        azure_client = AzureOpenAI(**azure_client_params) # type: ignore
+        if client is None:
+            azure_client = AzureOpenAI(**azure_client_params) # type: ignore
+        else:
+            azure_client = client
         ## LOGGING
         logging_obj.pre_call(
             input=input,
diff --git a/litellm/main.py b/litellm/main.py
index 4656f7c98a..2ccd19683b 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -498,7 +498,7 @@ def completion(
                 logging_obj=logging,
                 acompletion=acompletion,
                 timeout=timeout,
-                azure_client=optional_params.pop("azure_client", None)
+                client=optional_params.pop("client", None)
             )
             ## LOGGING
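
With this change, a caller can construct an Azure client once and hand it to litellm.completion() via the client keyword; main.py pops it out of optional_params and the Azure handler reuses it instead of building a fresh AzureOpenAI instance per call. A minimal sketch of the calling pattern, assuming the patched code above is installed and that extra keyword arguments reach optional_params as shown in the main.py hunk; the endpoint, API key, and deployment name are placeholder values, not real ones:

    from openai import AzureOpenAI
    import litellm

    # Construct the client once, e.g. to reuse connections or pin retry
    # behavior. All credential values below are placeholders.
    azure_client = AzureOpenAI(
        api_key="my-azure-api-key",
        api_version="2023-07-01-preview",
        azure_endpoint="https://my-endpoint.openai.azure.com",
        max_retries=2,
    )

    # The patched completion() forwards this client down to
    # AzureChatCompletion.completion(), which skips client construction
    # whenever client is not None.
    response = litellm.completion(
        model="azure/my-deployment",  # placeholder deployment name
        messages=[{"role": "user", "content": "Hello!"}],
        client=azure_client,
    )
    print(response)

For async or async-streaming calls, an AsyncAzureOpenAI instance would be passed the same way, since acompletion() and async_streaming() now accept the same client parameter.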