(feat) completion: Azure allow users to pass client to router

ishaan-jaff 2023-11-28 15:56:52 -08:00
parent 1a0b683a8e
commit 400a268934
2 changed files with 33 additions and 14 deletions
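
Summary of the change: the `azure_client` keyword argument on `AzureChatCompletion` is renamed to `client` and threaded through every code path (`completion`, `acompletion`, `streaming`, `async_streaming`, and `embedding`). Each path now builds a fresh `AzureOpenAI`/`AsyncAzureOpenAI` client only when no caller-supplied client is given; otherwise the supplied client is used as-is, so a caller (for example litellm's Router) can construct one client and reuse it across requests.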

@@ -105,7 +105,7 @@ class AzureChatCompletion(BaseLLM):
         logger_fn,
         acompletion: bool = False,
         headers: Optional[dict]=None,
-        azure_client = None,
+        client = None,
     ):
         super().completion()
         exception_mapping_worked = False
@@ -135,11 +135,11 @@ class AzureChatCompletion(BaseLLM):
         )
         if acompletion is True:
             if optional_params.get("stream", False):
-                return self.async_streaming(logging_obj=logging_obj, api_base=api_base, data=data, model=model, api_key=api_key, api_version=api_version, azure_ad_token=azure_ad_token, timeout=timeout)
+                return self.async_streaming(logging_obj=logging_obj, api_base=api_base, data=data, model=model, api_key=api_key, api_version=api_version, azure_ad_token=azure_ad_token, timeout=timeout, client=client)
             else:
-                return self.acompletion(api_base=api_base, data=data, model_response=model_response, api_key=api_key, api_version=api_version, model=model, azure_ad_token=azure_ad_token, timeout=timeout, azure_client=azure_client)
+                return self.acompletion(api_base=api_base, data=data, model_response=model_response, api_key=api_key, api_version=api_version, model=model, azure_ad_token=azure_ad_token, timeout=timeout, client=client)
         elif "stream" in optional_params and optional_params["stream"] == True:
-            return self.streaming(logging_obj=logging_obj, api_base=api_base, data=data, model=model, api_key=api_key, api_version=api_version, azure_ad_token=azure_ad_token, timeout=timeout)
+            return self.streaming(logging_obj=logging_obj, api_base=api_base, data=data, model=model, api_key=api_key, api_version=api_version, azure_ad_token=azure_ad_token, timeout=timeout, client=client)
         else:
             max_retries = data.pop("max_retries", 2)
             if not isinstance(max_retries, int):
@@ -157,7 +157,10 @@ class AzureChatCompletion(BaseLLM):
                 azure_client_params["api_key"] = api_key
             elif azure_ad_token is not None:
                 azure_client_params["azure_ad_token"] = azure_ad_token
-            azure_client = AzureOpenAI(**azure_client_params)
+            if client is None:
+                azure_client = AzureOpenAI(**azure_client_params)
+            else:
+                azure_client = client
             response = azure_client.chat.completions.create(**data) # type: ignore
             response.model = "azure/" + str(response.model)
             return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
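
What this unlocks at the call site, as a minimal sketch: since `completion()` in the second file now pops `client` out of `optional_params`, a pre-built `AzureOpenAI` client can be handed straight through. The endpoint, key, and deployment name below are placeholders, not values from the commit.

    from openai import AzureOpenAI
    import litellm

    # Hypothetical credentials and deployment -- substitute your own.
    azure_client = AzureOpenAI(
        api_key="my-azure-api-key",
        api_version="2023-07-01-preview",
        azure_endpoint="https://my-endpoint.openai.azure.com",
    )

    # `client=` is the kwarg added by this commit; litellm skips building
    # its own AzureOpenAI client and calls this one instead.
    response = litellm.completion(
        model="azure/chatgpt-v-2",  # hypothetical Azure deployment name
        messages=[{"role": "user", "content": "Hey"}],
        client=azure_client,
    )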
@@ -176,7 +179,7 @@ class AzureChatCompletion(BaseLLM):
         timeout: Any,
         model_response: ModelResponse,
         azure_ad_token: Optional[str]=None,
-        azure_client = None, # this is the AsyncAzureOpenAI
+        client = None, # this is the AsyncAzureOpenAI
     ):
         response = None
         try:
@@ -196,8 +199,10 @@ class AzureChatCompletion(BaseLLM):
                 azure_client_params["api_key"] = api_key
             elif azure_ad_token is not None:
                 azure_client_params["azure_ad_token"] = azure_ad_token
-            if azure_client is None:
+            if client is None:
                 azure_client = AsyncAzureOpenAI(**azure_client_params)
+            else:
+                azure_client = client
             response = await azure_client.chat.completions.create(**data)
             return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
         except AzureOpenAIError as e:
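
The async path mirrors this: per the inline comment in the signature above, `acompletion` expects an `AsyncAzureOpenAI` instance. A sketch with placeholder credentials:

    import asyncio
    from openai import AsyncAzureOpenAI
    import litellm

    async def main():
        # Placeholder credentials; `client` here should be an
        # AsyncAzureOpenAI instance, matching the signature comment above.
        azure_client = AsyncAzureOpenAI(
            api_key="my-azure-api-key",
            api_version="2023-07-01-preview",
            azure_endpoint="https://my-endpoint.openai.azure.com",
        )
        response = await litellm.acompletion(
            model="azure/chatgpt-v-2",  # hypothetical deployment
            messages=[{"role": "user", "content": "Hey"}],
            client=azure_client,
        )
        print(response)

    asyncio.run(main())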
@@ -215,6 +220,7 @@ class AzureChatCompletion(BaseLLM):
         model: str,
         timeout: Any,
         azure_ad_token: Optional[str]=None,
+        client=None,
     ):
         max_retries = data.pop("max_retries", 2)
         if not isinstance(max_retries, int):
@@ -232,8 +238,11 @@ class AzureChatCompletion(BaseLLM):
             azure_client_params["api_key"] = api_key
         elif azure_ad_token is not None:
             azure_client_params["azure_ad_token"] = azure_ad_token
-        azure_client = AzureOpenAI(**azure_client_params)
-        response = azure_client.chat.completions.create(**data)
+        if client is None:
+            azure_client = AzureOpenAI(**azure_client_params)
+        else:
+            azure_client = client
+        response = azure_client.chat.completions.create(**data)
         streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="azure",logging_obj=logging_obj)
         for transformed_chunk in streamwrapper:
             yield transformed_chunk
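
Streaming takes the same kwarg; the injected client's response stream is wrapped in `CustomStreamWrapper` exactly as a freshly built one would be. Sketch, assuming the `azure_client` built in the earlier sketch:

    # Assumes `azure_client` (a sync AzureOpenAI instance) from the sketch above.
    response = litellm.completion(
        model="azure/chatgpt-v-2",  # hypothetical deployment
        messages=[{"role": "user", "content": "Hey"}],
        stream=True,
        client=azure_client,
    )
    for chunk in response:
        print(chunk)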
@@ -246,7 +255,9 @@ class AzureChatCompletion(BaseLLM):
         data: dict,
         model: str,
         timeout: Any,
-        azure_ad_token: Optional[str]=None):
+        azure_ad_token: Optional[str]=None,
+        client = None,
+    ):
         # init AzureOpenAI Client
         azure_client_params = {
             "api_version": api_version,
@@ -260,7 +271,10 @@ class AzureChatCompletion(BaseLLM):
             azure_client_params["api_key"] = api_key
         elif azure_ad_token is not None:
             azure_client_params["azure_ad_token"] = azure_ad_token
-        azure_client = AsyncAzureOpenAI(**azure_client_params)
+        if client is None:
+            azure_client = AsyncAzureOpenAI(**azure_client_params)
+        else:
+            azure_client = client
         response = await azure_client.chat.completions.create(**data)
         streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="azure",logging_obj=logging_obj)
         async for transformed_chunk in streamwrapper:
@@ -276,7 +290,9 @@ class AzureChatCompletion(BaseLLM):
         logging_obj=None,
         model_response=None,
         optional_params=None,
-        azure_ad_token: Optional[str]=None):
+        azure_ad_token: Optional[str]=None,
+        client = None
+    ):
         super().embedding()
         exception_mapping_worked = False
         if self._client_session is None:
@@ -304,7 +320,10 @@ class AzureChatCompletion(BaseLLM):
             azure_client_params["api_key"] = api_key
         elif azure_ad_token is not None:
             azure_client_params["azure_ad_token"] = azure_ad_token
-        azure_client = AzureOpenAI(**azure_client_params) # type: ignore
+        if client is None:
+            azure_client = AzureOpenAI(**azure_client_params) # type: ignore
+        else:
+            azure_client = client
         ## LOGGING
         logging_obj.pre_call(
             input=input,
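
`embedding()` gets the same injection point. A short sketch, again with a hypothetical deployment name and reusing the client from the first sketch:

    # Assumes `azure_client` from the first sketch above.
    response = litellm.embedding(
        model="azure/azure-embedding-model",  # hypothetical deployment
        input=["good morning from litellm"],
        client=azure_client,
    )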

@@ -498,7 +498,7 @@ def completion(
             logging_obj=logging,
             acompletion=acompletion,
             timeout=timeout,
-            azure_client=optional_params.pop("azure_client", None)
+            client=optional_params.pop("client", None)
         )
         ## LOGGING
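
Note the use of `pop` rather than `get`: the client object has to be consumed here so it is not forwarded to Azure as if it were a model parameter. The payoff, per the commit title, is that a caller such as the router can build one client and reuse it across many calls. A sketch of that pattern (placeholder values):

    from openai import AzureOpenAI
    import litellm

    # One shared client for many requests -- no per-call client construction.
    shared_client = AzureOpenAI(
        api_key="my-azure-api-key",
        api_version="2023-07-01-preview",
        azure_endpoint="https://my-endpoint.openai.azure.com",
    )

    for prompt in ["first question", "second question"]:
        litellm.completion(
            model="azure/chatgpt-v-2",  # hypothetical deployment
            messages=[{"role": "user", "content": prompt}],
            client=shared_client,
        )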