fix(http_handler.py): correctly re-raise timeout exception
parent 4e7d9d2bb1
commit 6202f9bbb0
6 changed files with 43 additions and 23 deletions
@@ -199,8 +199,12 @@ class Timeout(openai.APITimeoutError):  # type: ignore
         litellm_debug_info: Optional[str] = None,
         max_retries: Optional[int] = None,
         num_retries: Optional[int] = None,
+        headers: Optional[dict] = None,
     ):
-        request = httpx.Request(method="POST", url="https://api.openai.com/v1")
+        request = httpx.Request(
+            method="POST",
+            url="https://api.openai.com/v1",
+        )
         super().__init__(
             request=request
         )  # Call the base class constructor with the parameters it needs
@@ -211,6 +215,7 @@ class Timeout(openai.APITimeoutError):  # type: ignore
         self.litellm_debug_info = litellm_debug_info
         self.max_retries = max_retries
         self.num_retries = num_retries
+        self.headers = headers

     # custom function to convert to str
     def __str__(self):
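The two exception hunks above give litellm.Timeout an optional headers argument and store it on the instance. A minimal sketch of what that enables for a caller; the header value here is invented purely for illustration:

import litellm

# Hypothetical header dict; real values come from the httpx response (see the
# http_handler hunks below), where keys are prefixed with "response_headers-".
timeout_err = litellm.Timeout(
    message="Connection timed out after 10.0 seconds.",
    model="default-model-name",
    llm_provider="litellm-httpx-handler",
    headers={"response_headers-retry-after": "30"},
)
print(timeout_err.headers)  # {'response_headers-retry-after': '30'}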
@@ -84,20 +84,17 @@ class AsyncHTTPHandler:
         stream: bool = False,
     ):
         try:
-            if timeout is not None:
-                req = self.client.build_request(
-                    "POST", url, data=data, json=json, params=params, headers=headers, timeout=timeout  # type: ignore
-                )
-            else:
-                req = self.client.build_request(
-                    "POST", url, data=data, json=json, params=params, headers=headers  # type: ignore
-                )
+            if timeout is None:
+                timeout = self.timeout
+            req = self.client.build_request(
+                "POST", url, data=data, json=json, params=params, headers=headers, timeout=timeout  # type: ignore
+            )
             response = await self.client.send(req, stream=stream)
             response.raise_for_status()
             return response
         except (httpx.RemoteProtocolError, httpx.ConnectError):
             # Retry the request with a new session if there is a connection error
-            new_client = self.create_client(timeout=self.timeout, concurrent_limit=1)
+            new_client = self.create_client(timeout=timeout, concurrent_limit=1)
             try:
                 return await self.single_connection_post_request(
                     url=url,
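Collapsing the two build_request branches also changes defaulting: a post() call that leaves timeout as None now inherits the handler's own self.timeout rather than building the request without one. A tiny standalone illustration of that behavior (the helper class below is hypothetical, not litellm code):

# Isolated sketch of the new defaulting rule, independent of httpx.
class _TimeoutDefaulting:
    def __init__(self, timeout: float):
        self.timeout = timeout

    def effective_timeout(self, timeout=None):
        if timeout is None:
            timeout = self.timeout
        return timeout

handler = _TimeoutDefaulting(timeout=600.0)
print(handler.effective_timeout())     # 600.0 -> falls back to self.timeout
print(handler.effective_timeout(5.0))  # 5.0   -> explicit value wins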
@@ -110,11 +107,17 @@ class AsyncHTTPHandler:
                 )
             finally:
                 await new_client.aclose()
-        except httpx.TimeoutException:
+        except httpx.TimeoutException as e:
+            headers = {}
+            if hasattr(e, "response") and e.response is not None:
+                for key, value in e.response.headers.items():
+                    headers["response_headers-{}".format(key)] = value
+
             raise litellm.Timeout(
                 message=f"Connection timed out after {timeout} seconds.",
                 model="default-model-name",
                 llm_provider="litellm-httpx-handler",
+                headers=headers,
             )
         except httpx.HTTPStatusError as e:
             setattr(e, "status_code", e.response.status_code)
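With both handler hunks applied, an httpx.TimeoutException raised inside post() is re-raised as litellm.Timeout, carrying any response headers under response_headers-* keys. A rough usage sketch; the import path and the timeout constructor argument are assumptions based on how the class is used elsewhere in the repo:

import asyncio

import litellm
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler

async def call_with_timeout():
    client = AsyncHTTPHandler(timeout=1.0)  # assumed constructor signature
    try:
        # Placeholder URL/payload, only to show the exception flow.
        await client.post("https://example.com/v1/chat", json={"prompt": "hi"})
    except litellm.Timeout as e:
        # Headers copied off the httpx response, if one was attached.
        print(e.message, e.headers)

asyncio.run(call_with_timeout())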
@@ -362,6 +362,15 @@ class PredibaseChatCompletion(BaseLLM):
             total_tokens=total_tokens,
         )
         model_response.usage = usage  # type: ignore
+
+        ## RESPONSE HEADERS
+        predibase_headers = response.headers
+        response_headers = {}
+        for k, v in predibase_headers.items():
+            if k.startswith("x-"):
+                response_headers["llm_provider-{}".format(k)] = v
+
+        model_response._hidden_params["additional_headers"] = response_headers
         return model_response

     def completion(
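After the Predibase hunk above, any x-* headers returned by the provider are exposed on the response's hidden params as additional_headers, each key prefixed with llm_provider-. A short sketch of reading them back, using the same call shape as the test at the bottom of this diff:

import os

import litellm

response = litellm.completion(
    model="predibase/llama-3-8b-instruct",
    tenant_id="c4768f95",
    api_key=os.getenv("PREDIBASE_API_KEY"),
    messages=[{"role": "user", "content": "What is the meaning of life?"}],
    max_tokens=10,
)

# Keys look like "llm_provider-x-..." per the header loop in the hunk above.
extra_headers = response._hidden_params.get("additional_headers", {})
for name, value in extra_headers.items():
    print(name, value)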
@@ -550,6 +559,9 @@ class PredibaseChatCompletion(BaseLLM):
                 ),
             )
         except Exception as e:
+            for exception in litellm.LITELLM_EXCEPTION_TYPES:
+                if isinstance(e, exception):
+                    raise e
             raise PredibaseError(
                 status_code=500, message="{}\n{}".format(str(e), traceback.format_exc())
             )
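The except-block hunk above passes through exceptions litellm already maps before wrapping anything else in a PredibaseError. A minimal standalone sketch of the same pattern, assuming litellm.LITELLM_EXCEPTION_TYPES is an iterable of exception classes as the diff implies; RuntimeError stands in for PredibaseError here:

import traceback

import litellm

def reraise_or_wrap(e: Exception) -> None:
    # Known litellm exception types propagate unchanged...
    for exception in litellm.LITELLM_EXCEPTION_TYPES:
        if isinstance(e, exception):
            raise e
    # ...anything else is wrapped with its traceback for debugging.
    raise RuntimeError("{}\n{}".format(str(e), traceback.format_exc()))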
@@ -1,5 +1,4 @@
 model_list:
-  - model_name: claude-3-haiku-20240307
+  - model_name: "*"
     litellm_params:
-      model: anthropic/claude-3-haiku-20240307
-      max_tokens: 4096
+      model: "*"
@@ -3069,6 +3069,7 @@ async def chat_completion(
             type=getattr(e, "type", "None"),
             param=getattr(e, "param", "None"),
             code=getattr(e, "status_code", 500),
+            headers=getattr(e, "headers", {}),
         )
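On the proxy side, the added headers=getattr(e, "headers", {}) means a timeout surfaced by the handler reaches the error response with the upstream headers attached, while exceptions that never carried headers fall back to an empty dict. A hedged sketch of that getattr pattern in isolation (build_error_fields is a made-up helper, not proxy code):

import litellm

def build_error_fields(e: Exception) -> dict:
    # Mirrors the getattr(...) fallbacks used in the proxy hunk above.
    return {
        "message": getattr(e, "message", str(e)),
        "type": getattr(e, "type", "None"),
        "param": getattr(e, "param", "None"),
        "code": getattr(e, "status_code", 500),
        "headers": getattr(e, "headers", {}),
    }

err = litellm.Timeout(
    message="Connection timed out after 5.0 seconds.",
    model="default-model-name",
    llm_provider="litellm-httpx-handler",
    headers={"response_headers-retry-after": "30"},
)
print(build_error_fields(err))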
@@ -261,16 +261,16 @@ async def test_completion_predibase():
     try:
         litellm.set_verbose = True

-        with patch("requests.post", side_effect=predibase_mock_post):
-            response = completion(
+        # with patch("requests.post", side_effect=predibase_mock_post):
+        response = await litellm.acompletion(
             model="predibase/llama-3-8b-instruct",
             tenant_id="c4768f95",
             api_key=os.getenv("PREDIBASE_API_KEY"),
             messages=[{"role": "user", "content": "What is the meaning of life?"}],
             max_tokens=10,
         )

         print(response)
     except litellm.Timeout as e:
         pass
     except Exception as e: