fix(http_handler.py): correctly re-raise timeout exception

Krrish Dholakia 2024-07-31 14:51:16 -07:00
parent 4e7d9d2bb1
commit 6202f9bbb0
6 changed files with 43 additions and 23 deletions

View file

@@ -199,8 +199,12 @@ class Timeout(openai.APITimeoutError):  # type: ignore
         litellm_debug_info: Optional[str] = None,
         max_retries: Optional[int] = None,
         num_retries: Optional[int] = None,
+        headers: Optional[dict] = None,
     ):
-        request = httpx.Request(method="POST", url="https://api.openai.com/v1")
+        request = httpx.Request(
+            method="POST",
+            url="https://api.openai.com/v1",
+        )
         super().__init__(
             request=request
         )  # Call the base class constructor with the parameters it needs
@@ -211,6 +215,7 @@ class Timeout(openai.APITimeoutError):  # type: ignore
         self.litellm_debug_info = litellm_debug_info
         self.max_retries = max_retries
         self.num_retries = num_retries
+        self.headers = headers

     # custom function to convert to str
     def __str__(self):
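
The hunk above threads an optional headers dict through litellm's Timeout exception so downstream handlers can surface response headers (for example rate-limit information) alongside the timeout. A minimal sketch of the pattern, with a simplified constructor rather than the library's full signature:

    # Illustrative sketch only -- the real litellm.Timeout subclasses
    # openai.APITimeoutError and accepts several more parameters.
    from typing import Optional

    class Timeout(Exception):
        def __init__(self, message: str, headers: Optional[dict] = None):
            self.message = message
            # Stored so error handlers can read getattr(e, "headers", {}).
            self.headers = headers or {}
            super().__init__(message)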

View file

@@ -84,20 +84,17 @@ class AsyncHTTPHandler:
         stream: bool = False,
     ):
         try:
-            if timeout is not None:
-                req = self.client.build_request(
-                    "POST", url, data=data, json=json, params=params, headers=headers, timeout=timeout  # type: ignore
-                )
-            else:
-                req = self.client.build_request(
-                    "POST", url, data=data, json=json, params=params, headers=headers  # type: ignore
-                )
+            if timeout is None:
+                timeout = self.timeout
+            req = self.client.build_request(
+                "POST", url, data=data, json=json, params=params, headers=headers, timeout=timeout  # type: ignore
+            )
             response = await self.client.send(req, stream=stream)
             response.raise_for_status()
             return response
         except (httpx.RemoteProtocolError, httpx.ConnectError):
             # Retry the request with a new session if there is a connection error
-            new_client = self.create_client(timeout=self.timeout, concurrent_limit=1)
+            new_client = self.create_client(timeout=timeout, concurrent_limit=1)
             try:
                 return await self.single_connection_post_request(
                     url=url,
@@ -110,11 +107,17 @@ class AsyncHTTPHandler:
                 )
             finally:
                 await new_client.aclose()
-        except httpx.TimeoutException:
+        except httpx.TimeoutException as e:
+            headers = {}
+            if hasattr(e, "response") and e.response is not None:
+                for key, value in e.response.headers.items():
+                    headers["response_headers-{}".format(key)] = value
+
             raise litellm.Timeout(
                 message=f"Connection timed out after {timeout} seconds.",
                 model="default-model-name",
                 llm_provider="litellm-httpx-handler",
+                headers=headers,
             )
         except httpx.HTTPStatusError as e:
             setattr(e, "status_code", e.response.status_code)
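
With this change, a timeout inside AsyncHTTPHandler.post is re-raised as litellm.Timeout carrying any response headers under a "response_headers-" prefix, and the retry client now honors the per-call timeout instead of always falling back to self.timeout. A hedged sketch of how a caller might consume the re-raised exception (the URL and timeout value are illustrative):

    import asyncio

    import litellm
    from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler

    async def main():
        handler = AsyncHTTPHandler(timeout=5.0)
        try:
            await handler.post(url="https://example.com/v1/generate", json={})
        except litellm.Timeout as e:
            # Populated from the timed-out response, when one exists.
            print(e.headers)

    asyncio.run(main())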

View file

@@ -362,6 +362,15 @@ class PredibaseChatCompletion(BaseLLM):
             total_tokens=total_tokens,
         )
         model_response.usage = usage  # type: ignore
+
+        ## RESPONSE HEADERS
+        predibase_headers = response.headers
+        response_headers = {}
+        for k, v in predibase_headers.items():
+            if k.startswith("x-"):
+                response_headers["llm_provider-{}".format(k)] = v
+        model_response._hidden_params["additional_headers"] = response_headers
+
         return model_response

     def completion(
@@ -550,6 +559,9 @@ class PredibaseChatCompletion(BaseLLM):
                 ),
             )
         except Exception as e:
+            for exception in litellm.LITELLM_EXCEPTION_TYPES:
+                if isinstance(e, exception):
+                    raise e
             raise PredibaseError(
                 status_code=500, message="{}\n{}".format(str(e), traceback.format_exc())
             )
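
The first Predibase hunk copies provider "x-" response headers into model_response._hidden_params["additional_headers"] under a "llm_provider-" prefix; the second re-raises known litellm exception types instead of wrapping them in a generic PredibaseError. A hedged usage sketch for reading the forwarded headers (assumes PREDIBASE_API_KEY is set in the environment):

    import litellm

    response = litellm.completion(
        model="predibase/llama-3-8b-instruct",
        tenant_id="c4768f95",
        messages=[{"role": "user", "content": "What is the meaning of life?"}],
        max_tokens=10,
    )
    # Each forwarded entry looks like {"llm_provider-x-...": value}.
    print(response._hidden_params.get("additional_headers", {}))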

View file

@@ -1,5 +1,4 @@
 model_list:
-  - model_name: claude-3-haiku-20240307
+  - model_name: "*"
     litellm_params:
-      model: anthropic/claude-3-haiku-20240307
-      max_tokens: 4096
+      model: "*"

View file

@@ -3069,6 +3069,7 @@ async def chat_completion(
             type=getattr(e, "type", "None"),
             param=getattr(e, "param", "None"),
             code=getattr(e, "status_code", 500),
+            headers=getattr(e, "headers", {}),
         )
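
The proxy hunk uses getattr with defaults so exceptions that never had headers attached still produce a well-formed error response. A minimal sketch of the pattern (ProxyError here is an illustrative stand-in, not the proxy's actual exception class):

    class ProxyError(Exception):
        def __init__(self, message: str, code: int, headers: dict):
            self.message = message
            self.code = code
            self.headers = headers  # forwarded onto the HTTP error response
            super().__init__(message)

    def to_proxy_error(e: Exception) -> ProxyError:
        # getattr with a default keeps the conversion safe for exceptions
        # that carry no status_code or headers.
        return ProxyError(
            message=getattr(e, "message", str(e)),
            code=getattr(e, "status_code", 500),
            headers=getattr(e, "headers", {}),
        )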

View file

@@ -261,16 +261,16 @@ async def test_completion_predibase():
     try:
         litellm.set_verbose = True
-        with patch("requests.post", side_effect=predibase_mock_post):
-            response = completion(
-                model="predibase/llama-3-8b-instruct",
-                tenant_id="c4768f95",
-                api_key=os.getenv("PREDIBASE_API_KEY"),
-                messages=[{"role": "user", "content": "What is the meaning of life?"}],
-                max_tokens=10,
-            )
-            print(response)
+        # with patch("requests.post", side_effect=predibase_mock_post):
+        response = await litellm.acompletion(
+            model="predibase/llama-3-8b-instruct",
+            tenant_id="c4768f95",
+            api_key=os.getenv("PREDIBASE_API_KEY"),
+            messages=[{"role": "user", "content": "What is the meaning of life?"}],
+            max_tokens=10,
+        )
+        print(response)
     except litellm.Timeout as e:
         pass
     except Exception as e:
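
The reworked test calls the live endpoint through litellm.acompletion instead of mocking requests.post, and treats litellm.Timeout as an acceptable outcome. A hedged sketch of deliberately exercising that path (the tiny timeout value is illustrative):

    import litellm

    async def force_timeout():
        try:
            await litellm.acompletion(
                model="predibase/llama-3-8b-instruct",
                tenant_id="c4768f95",
                messages=[{"role": "user", "content": "hi"}],
                timeout=0.001,  # small enough to trip the timeout path
            )
        except litellm.Timeout as e:
            print(e.headers)  # populated by the http_handler change above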