mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-24 10:14:26 +00:00
(feat) completion: Azure allow users to pass client to router
This commit is contained in:
parent
1a0b683a8e
commit
400a268934
2 changed files with 33 additions and 14 deletions
|
@ -105,7 +105,7 @@ class AzureChatCompletion(BaseLLM):
|
|||
logger_fn,
|
||||
acompletion: bool = False,
|
||||
headers: Optional[dict]=None,
|
||||
azure_client = None,
|
||||
client = None,
|
||||
):
|
||||
super().completion()
|
||||
exception_mapping_worked = False
|
||||
|
@ -135,11 +135,11 @@ class AzureChatCompletion(BaseLLM):
|
|||
)
|
||||
if acompletion is True:
|
||||
if optional_params.get("stream", False):
|
||||
return self.async_streaming(logging_obj=logging_obj, api_base=api_base, data=data, model=model, api_key=api_key, api_version=api_version, azure_ad_token=azure_ad_token, timeout=timeout)
|
||||
return self.async_streaming(logging_obj=logging_obj, api_base=api_base, data=data, model=model, api_key=api_key, api_version=api_version, azure_ad_token=azure_ad_token, timeout=timeout, client=client)
|
||||
else:
|
||||
return self.acompletion(api_base=api_base, data=data, model_response=model_response, api_key=api_key, api_version=api_version, model=model, azure_ad_token=azure_ad_token, timeout=timeout, azure_client=azure_client)
|
||||
return self.acompletion(api_base=api_base, data=data, model_response=model_response, api_key=api_key, api_version=api_version, model=model, azure_ad_token=azure_ad_token, timeout=timeout, client=client)
|
||||
elif "stream" in optional_params and optional_params["stream"] == True:
|
||||
return self.streaming(logging_obj=logging_obj, api_base=api_base, data=data, model=model, api_key=api_key, api_version=api_version, azure_ad_token=azure_ad_token, timeout=timeout)
|
||||
return self.streaming(logging_obj=logging_obj, api_base=api_base, data=data, model=model, api_key=api_key, api_version=api_version, azure_ad_token=azure_ad_token, timeout=timeout, client=client)
|
||||
else:
|
||||
max_retries = data.pop("max_retries", 2)
|
||||
if not isinstance(max_retries, int):
|
||||
|
@ -157,7 +157,10 @@ class AzureChatCompletion(BaseLLM):
|
|||
azure_client_params["api_key"] = api_key
|
||||
elif azure_ad_token is not None:
|
||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||
azure_client = AzureOpenAI(**azure_client_params)
|
||||
if client is None:
|
||||
azure_client = AzureOpenAI(**azure_client_params)
|
||||
else:
|
||||
azure_client = client
|
||||
response = azure_client.chat.completions.create(**data) # type: ignore
|
||||
response.model = "azure/" + str(response.model)
|
||||
return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
|
||||
|
@ -176,7 +179,7 @@ class AzureChatCompletion(BaseLLM):
|
|||
timeout: Any,
|
||||
model_response: ModelResponse,
|
||||
azure_ad_token: Optional[str]=None,
|
||||
azure_client = None, # this is the AsyncAzureOpenAI
|
||||
client = None, # this is the AsyncAzureOpenAI
|
||||
):
|
||||
response = None
|
||||
try:
|
||||
|
@ -196,8 +199,10 @@ class AzureChatCompletion(BaseLLM):
|
|||
azure_client_params["api_key"] = api_key
|
||||
elif azure_ad_token is not None:
|
||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||
if azure_client is None:
|
||||
if client is None:
|
||||
azure_client = AsyncAzureOpenAI(**azure_client_params)
|
||||
else:
|
||||
azure_client = client
|
||||
response = await azure_client.chat.completions.create(**data)
|
||||
return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
|
||||
except AzureOpenAIError as e:
|
||||
|
@ -215,6 +220,7 @@ class AzureChatCompletion(BaseLLM):
|
|||
model: str,
|
||||
timeout: Any,
|
||||
azure_ad_token: Optional[str]=None,
|
||||
client=None,
|
||||
):
|
||||
max_retries = data.pop("max_retries", 2)
|
||||
if not isinstance(max_retries, int):
|
||||
|
@ -232,8 +238,11 @@ class AzureChatCompletion(BaseLLM):
|
|||
azure_client_params["api_key"] = api_key
|
||||
elif azure_ad_token is not None:
|
||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||
azure_client = AzureOpenAI(**azure_client_params)
|
||||
response = azure_client.chat.completions.create(**data)
|
||||
if client is None:
|
||||
azure_client = AzureOpenAI(**azure_client_params)
|
||||
else:
|
||||
azure_client = client
|
||||
response = client.chat.completions.create(**data)
|
||||
streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="azure",logging_obj=logging_obj)
|
||||
for transformed_chunk in streamwrapper:
|
||||
yield transformed_chunk
|
||||
|
@ -246,7 +255,9 @@ class AzureChatCompletion(BaseLLM):
|
|||
data: dict,
|
||||
model: str,
|
||||
timeout: Any,
|
||||
azure_ad_token: Optional[str]=None):
|
||||
azure_ad_token: Optional[str]=None,
|
||||
client = None,
|
||||
):
|
||||
# init AzureOpenAI Client
|
||||
azure_client_params = {
|
||||
"api_version": api_version,
|
||||
|
@ -260,7 +271,10 @@ class AzureChatCompletion(BaseLLM):
|
|||
azure_client_params["api_key"] = api_key
|
||||
elif azure_ad_token is not None:
|
||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||
azure_client = AsyncAzureOpenAI(**azure_client_params)
|
||||
if client is None:
|
||||
azure_client = AsyncAzureOpenAI(**azure_client_params)
|
||||
else:
|
||||
azure_client = client
|
||||
response = await azure_client.chat.completions.create(**data)
|
||||
streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="azure",logging_obj=logging_obj)
|
||||
async for transformed_chunk in streamwrapper:
|
||||
|
@ -276,7 +290,9 @@ class AzureChatCompletion(BaseLLM):
|
|||
logging_obj=None,
|
||||
model_response=None,
|
||||
optional_params=None,
|
||||
azure_ad_token: Optional[str]=None):
|
||||
azure_ad_token: Optional[str]=None,
|
||||
client = None
|
||||
):
|
||||
super().embedding()
|
||||
exception_mapping_worked = False
|
||||
if self._client_session is None:
|
||||
|
@ -304,7 +320,10 @@ class AzureChatCompletion(BaseLLM):
|
|||
azure_client_params["api_key"] = api_key
|
||||
elif azure_ad_token is not None:
|
||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||
azure_client = AzureOpenAI(**azure_client_params) # type: ignore
|
||||
if client is None:
|
||||
azure_client = AzureOpenAI(**azure_client_params) # type: ignore
|
||||
else:
|
||||
azure_client = client
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=input,
|
||||
|
|
|
@ -498,7 +498,7 @@ def completion(
|
|||
logging_obj=logging,
|
||||
acompletion=acompletion,
|
||||
timeout=timeout,
|
||||
azure_client=optional_params.pop("azure_client", None)
|
||||
client=optional_params.pop("client", None)
|
||||
)
|
||||
|
||||
## LOGGING
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue