mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 19:54:13 +00:00
fix(acompletion): support client side timeouts + raise exceptions correctly for async calls
This commit is contained in:
parent
43ee581d46
commit
02ed97d0b2
8 changed files with 142 additions and 81 deletions
|
@ -1,4 +1,4 @@
|
|||
from typing import Optional, Union
|
||||
from typing import Optional, Union, Any
|
||||
import types, requests
|
||||
from .base import BaseLLM
|
||||
from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper, convert_to_model_response_object
|
||||
|
@ -98,6 +98,7 @@ class AzureChatCompletion(BaseLLM):
|
|||
api_type: str,
|
||||
azure_ad_token: str,
|
||||
print_verbose: Callable,
|
||||
timeout,
|
||||
logging_obj,
|
||||
optional_params,
|
||||
litellm_params,
|
||||
|
@ -129,13 +130,13 @@ class AzureChatCompletion(BaseLLM):
|
|||
)
|
||||
if acompletion is True:
|
||||
if optional_params.get("stream", False):
|
||||
return self.async_streaming(logging_obj=logging_obj, api_base=api_base, data=data, model=model, api_key=api_key, api_version=api_version, azure_ad_token=azure_ad_token)
|
||||
return self.async_streaming(logging_obj=logging_obj, api_base=api_base, data=data, model=model, api_key=api_key, api_version=api_version, azure_ad_token=azure_ad_token, timeout=timeout)
|
||||
else:
|
||||
return self.acompletion(api_base=api_base, data=data, model_response=model_response, api_key=api_key, api_version=api_version, model=model, azure_ad_token=azure_ad_token)
|
||||
return self.acompletion(api_base=api_base, data=data, model_response=model_response, api_key=api_key, api_version=api_version, model=model, azure_ad_token=azure_ad_token, timeout=timeout)
|
||||
elif "stream" in optional_params and optional_params["stream"] == True:
|
||||
return self.streaming(logging_obj=logging_obj, api_base=api_base, data=data, model=model, api_key=api_key, api_version=api_version, azure_ad_token=azure_ad_token)
|
||||
return self.streaming(logging_obj=logging_obj, api_base=api_base, data=data, model=model, api_key=api_key, api_version=api_version, azure_ad_token=azure_ad_token, timeout=timeout)
|
||||
else:
|
||||
azure_client = AzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token, http_client=litellm.client_session)
|
||||
azure_client = AzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token, http_client=litellm.client_session, timeout=timeout)
|
||||
response = azure_client.chat.completions.create(**data) # type: ignore
|
||||
return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
|
||||
except AzureOpenAIError as e:
|
||||
|
@ -150,16 +151,18 @@ class AzureChatCompletion(BaseLLM):
|
|||
model: str,
|
||||
api_base: str,
|
||||
data: dict,
|
||||
timeout: Any,
|
||||
model_response: ModelResponse,
|
||||
azure_ad_token: Optional[str]=None, ):
|
||||
response = None
|
||||
try:
|
||||
azure_client = AsyncAzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token, http_client=litellm.aclient_session)
|
||||
azure_client = AsyncAzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token, http_client=litellm.aclient_session, timeout=timeout)
|
||||
response = await azure_client.chat.completions.create(**data)
|
||||
return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
|
||||
except Exception as e:
|
||||
if isinstance(e,httpx.TimeoutException):
|
||||
raise AzureOpenAIError(status_code=500, message="Request Timeout Error")
|
||||
elif response and hasattr(response, "text"):
|
||||
elif response is not None and hasattr(response, "text"):
|
||||
raise AzureOpenAIError(status_code=500, message=f"{str(e)}\n\nOriginal Response: {response.text}")
|
||||
else:
|
||||
raise AzureOpenAIError(status_code=500, message=f"{str(e)}")
|
||||
|
@ -171,9 +174,10 @@ class AzureChatCompletion(BaseLLM):
|
|||
api_version: str,
|
||||
data: dict,
|
||||
model: str,
|
||||
timeout: Any,
|
||||
azure_ad_token: Optional[str]=None,
|
||||
):
|
||||
azure_client = AzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token, http_client=litellm.client_session)
|
||||
azure_client = AzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token, http_client=litellm.client_session, timeout=timeout)
|
||||
response = azure_client.chat.completions.create(**data)
|
||||
streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="azure",logging_obj=logging_obj)
|
||||
for transformed_chunk in streamwrapper:
|
||||
|
@ -186,8 +190,9 @@ class AzureChatCompletion(BaseLLM):
|
|||
api_version: str,
|
||||
data: dict,
|
||||
model: str,
|
||||
timeout: Any,
|
||||
azure_ad_token: Optional[str]=None):
|
||||
azure_client = AsyncAzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token, http_client=litellm.aclient_session)
|
||||
azure_client = AsyncAzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token, http_client=litellm.aclient_session, timeout=timeout)
|
||||
response = await azure_client.chat.completions.create(**data)
|
||||
streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="azure",logging_obj=logging_obj)
|
||||
async for transformed_chunk in streamwrapper:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue