fix(acompletion): support client side timeouts + raise exceptions correctly for async calls

Krrish Dholakia 2023-11-17 15:39:39 -08:00
parent c4e53aa77b
commit 0ab6b2451d
8 changed files with 142 additions and 81 deletions

@@ -1,4 +1,4 @@
-from typing import Optional, Union
+from typing import Optional, Union, Any
 import types, time, json
 import httpx
 from .base import BaseLLM
@@ -161,6 +161,7 @@ class OpenAIChatCompletion(BaseLLM):
     def completion(self,
                model_response: ModelResponse,
+               timeout: Any,
                model: Optional[str]=None,
                messages: Optional[list]=None,
                print_verbose: Optional[Callable]=None,
@@ -180,6 +181,9 @@
         if model is None or messages is None:
             raise OpenAIError(status_code=422, message=f"Missing model or messages")
+        if not isinstance(timeout, float):
+            raise OpenAIError(status_code=422, message=f"Timeout needs to be a float")
+
         for _ in range(2): # if call fails due to alternating messages, retry with reformatted message
             data = {
                 "model": model,
@@ -197,13 +201,13 @@
             try:
                 if acompletion is True:
                     if optional_params.get("stream", False):
-                        return self.async_streaming(logging_obj=logging_obj, data=data, model=model, api_base=api_base, api_key=api_key)
+                        return self.async_streaming(logging_obj=logging_obj, data=data, model=model, api_base=api_base, api_key=api_key, timeout=timeout)
                     else:
-                        return self.acompletion(data=data, model_response=model_response, api_base=api_base, api_key=api_key)
+                        return self.acompletion(data=data, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout)
                 elif optional_params.get("stream", False):
-                    return self.streaming(logging_obj=logging_obj, data=data, model=model, api_base=api_base, api_key=api_key)
+                    return self.streaming(logging_obj=logging_obj, data=data, model=model, api_base=api_base, api_key=api_key, timeout=timeout)
                 else:
-                    openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session)
+                    openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout)
                     response = openai_client.chat.completions.create(**data) # type: ignore
                     return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
             except Exception as e:
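
The routing hunk above threads the new timeout through every branch: async streaming, async completion, sync streaming, and the plain sync call, where it is handed to the OpenAI SDK client and enforced client side by httpx. A usage sketch of the result, assuming the top-level litellm.acompletion() entrypoint forwards timeout into this provider code (the model name and timeout value are illustrative):

import asyncio
import litellm

async def main():
    # A float timeout is required by the 422 check added above.
    response = await litellm.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hello"}],
        timeout=10.0,
    )
    print(response)

asyncio.run(main())

Note that the isinstance(timeout, float) check rejects an integer such as timeout=10 with a 422 error, so callers must pass a float.
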
@@ -234,27 +238,32 @@
     async def acompletion(self,
                           data: dict,
                           model_response: ModelResponse,
+                          timeout: float,
                           api_key: Optional[str]=None,
                           api_base: Optional[str]=None):
         response = None
         try:
-            openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.aclient_session)
+            openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.aclient_session, timeout=timeout)
             response = await openai_aclient.chat.completions.create(**data)
             return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
         except Exception as e:
             if response and hasattr(response, "text"):
                 raise OpenAIError(status_code=500, message=f"{str(e)}\n\nOriginal Response: {response.text}")
             else:
-                raise OpenAIError(status_code=500, message=f"{str(e)}")
+                if type(e).__name__ == "ReadTimeout":
+                    raise OpenAIError(status_code=408, message=f"{type(e).__name__}")
+                else:
+                    raise OpenAIError(status_code=500, message=f"{str(e)}")
 
     def streaming(self,
                   logging_obj,
+                  timeout: float,
                   data: dict,
                   model: str,
                   api_key: Optional[str]=None,
                   api_base: Optional[str]=None
     ):
-        openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session)
+        openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout)
         response = openai_client.chat.completions.create(**data)
         streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="openai",logging_obj=logging_obj)
         for transformed_chunk in streamwrapper:
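
The rewritten except block above inspects type(e).__name__ so that a client-side read timeout surfaces as a 408-style OpenAIError rather than a generic 500. A standalone sketch of the same pattern, assuming httpx (already imported at the top of this file); the function name and error strings are illustrative:

import httpx

def get_with_timeout(url: str, timeout: float) -> httpx.Response:
    try:
        with httpx.Client(timeout=timeout) as client:
            return client.get(url)
    except Exception as e:
        # Map a client-side timeout to a 408-style error;
        # everything else stays a generic 500-style failure.
        if type(e).__name__ == "ReadTimeout":
            raise RuntimeError(f"408: {type(e).__name__}") from e
        raise RuntimeError(f"500: {e}") from e

Matching on the class name rather than isinstance(e, httpx.ReadTimeout) avoids a hard dependency on the exception type, at the cost of also matching any other exception that happens to be named ReadTimeout.
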
@@ -262,16 +271,27 @@
     async def async_streaming(self,
                               logging_obj,
+                              timeout: float,
                               data: dict,
                               model: str,
                               api_key: Optional[str]=None,
                               api_base: Optional[str]=None):
-        openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.aclient_session)
-        response = await openai_aclient.chat.completions.create(**data)
-        streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="openai",logging_obj=logging_obj)
-        async for transformed_chunk in streamwrapper:
-            yield transformed_chunk
+        response = None
+        try:
+            openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.aclient_session, timeout=timeout)
+            response = await openai_aclient.chat.completions.create(**data)
+            streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="openai",logging_obj=logging_obj)
+            async for transformed_chunk in streamwrapper:
+                yield transformed_chunk
+        except Exception as e: # need to exception handle here. async exceptions don't get caught in sync functions.
+            if response is not None and hasattr(response, "text"):
+                raise OpenAIError(status_code=500, message=f"{str(e)}\n\nOriginal Response: {response.text}")
+            else:
+                if type(e).__name__ == "ReadTimeout":
+                    raise OpenAIError(status_code=408, message=f"{type(e).__name__}")
+                else:
+                    raise OpenAIError(status_code=500, message=f"{str(e)}")
 
     def embedding(self,
                   model: str,
                   input: list,
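
The inline comment in the final hunk ("async exceptions don't get caught in sync functions") is why the try/except moved inside async_streaming itself: an exception raised while an async generator runs surfaces at the consumer's async for, so it must be caught and re-raised as a typed error inside the generator. A runnable illustration of that behavior (all names are hypothetical):

import asyncio

async def stream_chunks():
    try:
        yield "chunk-1"
        raise TimeoutError("simulated ReadTimeout mid-stream")
    except Exception as e:
        # Without this handler the consumer would see the raw TimeoutError
        # with no provider-style status code attached.
        raise RuntimeError(f"408-style error: {type(e).__name__}") from e

async def main():
    try:
        async for chunk in stream_chunks():
            print(chunk)
    except RuntimeError as err:
        print(err)  # -> 408-style error: TimeoutError

asyncio.run(main())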