mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 19:54:13 +00:00
(feat) completion: Azure allow users to pass client to router
This commit is contained in:
parent
a0a6d5ceeb
commit
a2d7623d6e
2 changed files with 33 additions and 14 deletions
|
@ -105,7 +105,7 @@ class AzureChatCompletion(BaseLLM):
|
||||||
logger_fn,
|
logger_fn,
|
||||||
acompletion: bool = False,
|
acompletion: bool = False,
|
||||||
headers: Optional[dict]=None,
|
headers: Optional[dict]=None,
|
||||||
azure_client = None,
|
client = None,
|
||||||
):
|
):
|
||||||
super().completion()
|
super().completion()
|
||||||
exception_mapping_worked = False
|
exception_mapping_worked = False
|
||||||
|
@ -135,11 +135,11 @@ class AzureChatCompletion(BaseLLM):
|
||||||
)
|
)
|
||||||
if acompletion is True:
|
if acompletion is True:
|
||||||
if optional_params.get("stream", False):
|
if optional_params.get("stream", False):
|
||||||
return self.async_streaming(logging_obj=logging_obj, api_base=api_base, data=data, model=model, api_key=api_key, api_version=api_version, azure_ad_token=azure_ad_token, timeout=timeout)
|
return self.async_streaming(logging_obj=logging_obj, api_base=api_base, data=data, model=model, api_key=api_key, api_version=api_version, azure_ad_token=azure_ad_token, timeout=timeout, client=client)
|
||||||
else:
|
else:
|
||||||
return self.acompletion(api_base=api_base, data=data, model_response=model_response, api_key=api_key, api_version=api_version, model=model, azure_ad_token=azure_ad_token, timeout=timeout, azure_client=azure_client)
|
return self.acompletion(api_base=api_base, data=data, model_response=model_response, api_key=api_key, api_version=api_version, model=model, azure_ad_token=azure_ad_token, timeout=timeout, client=client)
|
||||||
elif "stream" in optional_params and optional_params["stream"] == True:
|
elif "stream" in optional_params and optional_params["stream"] == True:
|
||||||
return self.streaming(logging_obj=logging_obj, api_base=api_base, data=data, model=model, api_key=api_key, api_version=api_version, azure_ad_token=azure_ad_token, timeout=timeout)
|
return self.streaming(logging_obj=logging_obj, api_base=api_base, data=data, model=model, api_key=api_key, api_version=api_version, azure_ad_token=azure_ad_token, timeout=timeout, client=client)
|
||||||
else:
|
else:
|
||||||
max_retries = data.pop("max_retries", 2)
|
max_retries = data.pop("max_retries", 2)
|
||||||
if not isinstance(max_retries, int):
|
if not isinstance(max_retries, int):
|
||||||
|
@ -157,7 +157,10 @@ class AzureChatCompletion(BaseLLM):
|
||||||
azure_client_params["api_key"] = api_key
|
azure_client_params["api_key"] = api_key
|
||||||
elif azure_ad_token is not None:
|
elif azure_ad_token is not None:
|
||||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||||
azure_client = AzureOpenAI(**azure_client_params)
|
if client is None:
|
||||||
|
azure_client = AzureOpenAI(**azure_client_params)
|
||||||
|
else:
|
||||||
|
azure_client = client
|
||||||
response = azure_client.chat.completions.create(**data) # type: ignore
|
response = azure_client.chat.completions.create(**data) # type: ignore
|
||||||
response.model = "azure/" + str(response.model)
|
response.model = "azure/" + str(response.model)
|
||||||
return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
|
return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
|
||||||
|
@ -176,7 +179,7 @@ class AzureChatCompletion(BaseLLM):
|
||||||
timeout: Any,
|
timeout: Any,
|
||||||
model_response: ModelResponse,
|
model_response: ModelResponse,
|
||||||
azure_ad_token: Optional[str]=None,
|
azure_ad_token: Optional[str]=None,
|
||||||
azure_client = None, # this is the AsyncAzureOpenAI
|
client = None, # this is the AsyncAzureOpenAI
|
||||||
):
|
):
|
||||||
response = None
|
response = None
|
||||||
try:
|
try:
|
||||||
|
@ -196,8 +199,10 @@ class AzureChatCompletion(BaseLLM):
|
||||||
azure_client_params["api_key"] = api_key
|
azure_client_params["api_key"] = api_key
|
||||||
elif azure_ad_token is not None:
|
elif azure_ad_token is not None:
|
||||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||||
if azure_client is None:
|
if client is None:
|
||||||
azure_client = AsyncAzureOpenAI(**azure_client_params)
|
azure_client = AsyncAzureOpenAI(**azure_client_params)
|
||||||
|
else:
|
||||||
|
azure_client = client
|
||||||
response = await azure_client.chat.completions.create(**data)
|
response = await azure_client.chat.completions.create(**data)
|
||||||
return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
|
return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
|
||||||
except AzureOpenAIError as e:
|
except AzureOpenAIError as e:
|
||||||
|
@ -215,6 +220,7 @@ class AzureChatCompletion(BaseLLM):
|
||||||
model: str,
|
model: str,
|
||||||
timeout: Any,
|
timeout: Any,
|
||||||
azure_ad_token: Optional[str]=None,
|
azure_ad_token: Optional[str]=None,
|
||||||
|
client=None,
|
||||||
):
|
):
|
||||||
max_retries = data.pop("max_retries", 2)
|
max_retries = data.pop("max_retries", 2)
|
||||||
if not isinstance(max_retries, int):
|
if not isinstance(max_retries, int):
|
||||||
|
@ -232,8 +238,11 @@ class AzureChatCompletion(BaseLLM):
|
||||||
azure_client_params["api_key"] = api_key
|
azure_client_params["api_key"] = api_key
|
||||||
elif azure_ad_token is not None:
|
elif azure_ad_token is not None:
|
||||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||||
azure_client = AzureOpenAI(**azure_client_params)
|
if client is None:
|
||||||
response = azure_client.chat.completions.create(**data)
|
azure_client = AzureOpenAI(**azure_client_params)
|
||||||
|
else:
|
||||||
|
azure_client = client
|
||||||
|
response = client.chat.completions.create(**data)
|
||||||
streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="azure",logging_obj=logging_obj)
|
streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="azure",logging_obj=logging_obj)
|
||||||
for transformed_chunk in streamwrapper:
|
for transformed_chunk in streamwrapper:
|
||||||
yield transformed_chunk
|
yield transformed_chunk
|
||||||
|
@ -246,7 +255,9 @@ class AzureChatCompletion(BaseLLM):
|
||||||
data: dict,
|
data: dict,
|
||||||
model: str,
|
model: str,
|
||||||
timeout: Any,
|
timeout: Any,
|
||||||
azure_ad_token: Optional[str]=None):
|
azure_ad_token: Optional[str]=None,
|
||||||
|
client = None,
|
||||||
|
):
|
||||||
# init AzureOpenAI Client
|
# init AzureOpenAI Client
|
||||||
azure_client_params = {
|
azure_client_params = {
|
||||||
"api_version": api_version,
|
"api_version": api_version,
|
||||||
|
@ -260,7 +271,10 @@ class AzureChatCompletion(BaseLLM):
|
||||||
azure_client_params["api_key"] = api_key
|
azure_client_params["api_key"] = api_key
|
||||||
elif azure_ad_token is not None:
|
elif azure_ad_token is not None:
|
||||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||||
azure_client = AsyncAzureOpenAI(**azure_client_params)
|
if client is None:
|
||||||
|
azure_client = AsyncAzureOpenAI(**azure_client_params)
|
||||||
|
else:
|
||||||
|
azure_client = client
|
||||||
response = await azure_client.chat.completions.create(**data)
|
response = await azure_client.chat.completions.create(**data)
|
||||||
streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="azure",logging_obj=logging_obj)
|
streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="azure",logging_obj=logging_obj)
|
||||||
async for transformed_chunk in streamwrapper:
|
async for transformed_chunk in streamwrapper:
|
||||||
|
@ -276,7 +290,9 @@ class AzureChatCompletion(BaseLLM):
|
||||||
logging_obj=None,
|
logging_obj=None,
|
||||||
model_response=None,
|
model_response=None,
|
||||||
optional_params=None,
|
optional_params=None,
|
||||||
azure_ad_token: Optional[str]=None):
|
azure_ad_token: Optional[str]=None,
|
||||||
|
client = None
|
||||||
|
):
|
||||||
super().embedding()
|
super().embedding()
|
||||||
exception_mapping_worked = False
|
exception_mapping_worked = False
|
||||||
if self._client_session is None:
|
if self._client_session is None:
|
||||||
|
@ -304,7 +320,10 @@ class AzureChatCompletion(BaseLLM):
|
||||||
azure_client_params["api_key"] = api_key
|
azure_client_params["api_key"] = api_key
|
||||||
elif azure_ad_token is not None:
|
elif azure_ad_token is not None:
|
||||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||||
azure_client = AzureOpenAI(**azure_client_params) # type: ignore
|
if client is None:
|
||||||
|
azure_client = AzureOpenAI(**azure_client_params) # type: ignore
|
||||||
|
else:
|
||||||
|
azure_client = client
|
||||||
## LOGGING
|
## LOGGING
|
||||||
logging_obj.pre_call(
|
logging_obj.pre_call(
|
||||||
input=input,
|
input=input,
|
||||||
|
|
|
@ -498,7 +498,7 @@ def completion(
|
||||||
logging_obj=logging,
|
logging_obj=logging,
|
||||||
acompletion=acompletion,
|
acompletion=acompletion,
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
azure_client=optional_params.pop("azure_client", None)
|
client=optional_params.pop("client", None)
|
||||||
)
|
)
|
||||||
|
|
||||||
## LOGGING
|
## LOGGING
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue