fix(http_handler.py): add retry logic on httpx.remoteprotocolerror

Krrish Dholakia 2024-06-13 14:05:29 -07:00
parent d45db9a5a2
commit 46d57526c4
3 changed files with 29 additions and 32 deletions


@@ -26,7 +26,7 @@ class AuthenticationError(openai.AuthenticationError): # type: ignore
         num_retries: Optional[int] = None,
     ):
         self.status_code = 401
-        self.message = message
+        self.message = "litellm.AuthenticationError: {}".format(message)
         self.llm_provider = llm_provider
         self.model = model
         self.litellm_debug_info = litellm_debug_info
@@ -72,7 +72,7 @@ class NotFoundError(openai.NotFoundError): # type: ignore
         num_retries: Optional[int] = None,
     ):
         self.status_code = 404
-        self.message = message
+        self.message = "litellm.NotFoundError: {}".format(message)
         self.model = model
         self.llm_provider = llm_provider
         self.litellm_debug_info = litellm_debug_info
@@ -117,7 +117,7 @@ class BadRequestError(openai.BadRequestError): # type: ignore
         num_retries: Optional[int] = None,
     ):
         self.status_code = 400
-        self.message = message
+        self.message = "litellm.BadRequestError: {}".format(message)
         self.model = model
         self.llm_provider = llm_provider
         self.litellm_debug_info = litellm_debug_info
@@ -162,7 +162,7 @@ class UnprocessableEntityError(openai.UnprocessableEntityError): # type: ignore
         num_retries: Optional[int] = None,
     ):
         self.status_code = 422
-        self.message = message
+        self.message = "litellm.UnprocessableEntityError: {}".format(message)
         self.model = model
         self.llm_provider = llm_provider
         self.litellm_debug_info = litellm_debug_info
@@ -204,7 +204,7 @@ class Timeout(openai.APITimeoutError): # type: ignore
             request=request
         )  # Call the base class constructor with the parameters it needs
         self.status_code = 408
-        self.message = message
+        self.message = "litellm.Timeout: {}".format(message)
         self.model = model
         self.llm_provider = llm_provider
         self.litellm_debug_info = litellm_debug_info
@@ -241,7 +241,7 @@ class PermissionDeniedError(openai.PermissionDeniedError): # type:ignore
         num_retries: Optional[int] = None,
     ):
         self.status_code = 403
-        self.message = message
+        self.message = "litellm.PermissionDeniedError: {}".format(message)
         self.llm_provider = llm_provider
         self.model = model
         self.litellm_debug_info = litellm_debug_info
@@ -280,7 +280,7 @@ class RateLimitError(openai.RateLimitError): # type: ignore
         num_retries: Optional[int] = None,
     ):
         self.status_code = 429
-        self.message = message
+        self.message = "litellm.RateLimitError: {}".format(message)
        self.llm_provider = llm_provider
        self.model = model
        self.litellm_debug_info = litellm_debug_info
@@ -328,7 +328,7 @@ class ContextWindowExceededError(BadRequestError): # type: ignore
         litellm_debug_info: Optional[str] = None,
     ):
         self.status_code = 400
-        self.message = message
+        self.message = "litellm.ContextWindowExceededError: {}".format(message)
         self.model = model
         self.llm_provider = llm_provider
         self.litellm_debug_info = litellm_debug_info
@@ -368,7 +368,7 @@ class RejectedRequestError(BadRequestError): # type: ignore
         litellm_debug_info: Optional[str] = None,
     ):
         self.status_code = 400
-        self.message = message
+        self.message = "litellm.RejectedRequestError: {}".format(message)
         self.model = model
         self.llm_provider = llm_provider
         self.litellm_debug_info = litellm_debug_info
@@ -411,7 +411,7 @@ class ContentPolicyViolationError(BadRequestError): # type: ignore
         litellm_debug_info: Optional[str] = None,
     ):
         self.status_code = 400
-        self.message = message
+        self.message = "litellm.ContentPolicyViolationError: {}".format(message)
         self.model = model
         self.llm_provider = llm_provider
         self.litellm_debug_info = litellm_debug_info
@@ -452,7 +452,7 @@ class ServiceUnavailableError(openai.APIStatusError): # type: ignore
         num_retries: Optional[int] = None,
     ):
         self.status_code = 503
-        self.message = message
+        self.message = "litellm.ServiceUnavailableError: {}".format(message)
         self.llm_provider = llm_provider
         self.model = model
         self.litellm_debug_info = litellm_debug_info
@@ -501,7 +501,7 @@ class InternalServerError(openai.InternalServerError): # type: ignore
         num_retries: Optional[int] = None,
     ):
         self.status_code = 500
-        self.message = message
+        self.message = "litellm.InternalServerError: {}".format(message)
         self.llm_provider = llm_provider
         self.model = model
         self.litellm_debug_info = litellm_debug_info
@@ -552,7 +552,7 @@ class APIError(openai.APIError): # type: ignore
         num_retries: Optional[int] = None,
     ):
         self.status_code = status_code
-        self.message = message
+        self.message = "litellm.APIError: {}".format(message)
         self.llm_provider = llm_provider
         self.model = model
         self.litellm_debug_info = litellm_debug_info
@@ -589,7 +589,7 @@ class APIConnectionError(openai.APIConnectionError): # type: ignore
         max_retries: Optional[int] = None,
         num_retries: Optional[int] = None,
     ):
-        self.message = message
+        self.message = "litellm.APIConnectionError: {}".format(message)
         self.llm_provider = llm_provider
         self.model = model
         self.status_code = 500
@@ -626,7 +626,7 @@ class APIResponseValidationError(openai.APIResponseValidationError): # type: ig
         max_retries: Optional[int] = None,
         num_retries: Optional[int] = None,
     ):
-        self.message = message
+        self.message = "litellm.APIResponseValidationError: {}".format(message)
         self.llm_provider = llm_provider
         self.model = model
         request = httpx.Request(method="POST", url="https://api.openai.com/v1")
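Note: every hunk above applies the same pattern to one of the mapped exception classes — the stored message is prefixed with the fully qualified litellm exception name, so logs and callers can see which mapped exception wrapped the original provider error. A minimal, self-contained sketch of that pattern (the class names below are hypothetical stand-ins, not the real litellm classes, which subclass the openai SDK exceptions and hard-code each prefix):

class MappedError(Exception):
    """Hypothetical stand-in for the mapped exception classes in the diff above."""

    def __init__(self, message: str):
        # Same idea as the commit: prefix the stored message with a
        # "litellm.<ExceptionName>: " marker before passing it upward.
        self.message = "litellm.{}: {}".format(type(self).__name__, message)
        super().__init__(self.message)


class AuthenticationError(MappedError):
    status_code = 401


try:
    raise AuthenticationError("invalid api key")
except AuthenticationError as err:
    print(err.message)  # litellm.AuthenticationError: invalid api key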


@@ -5,8 +5,6 @@ from typing import Optional, Union, Mapping, Any

 # https://www.python-httpx.org/advanced/timeouts
 _DEFAULT_TIMEOUT = httpx.Timeout(timeout=5.0, connect=5.0)
-_DEFAULT_CONNECTION_RETRIES = 2  # retries if error connecting to api
-

 class AsyncHTTPHandler:
     def __init__(
@@ -95,22 +93,21 @@ class AsyncHTTPHandler:
             response = await self.client.send(req, stream=stream)
             response.raise_for_status()
             return response
-        except httpx.ConnectError:
+        except httpx.RemoteProtocolError:
             # Retry the request with a new session if there is a connection error
             new_client = self.create_client(timeout=self.timeout, concurrent_limit=1)
-            for _ in range(_DEFAULT_CONNECTION_RETRIES):
-                try:
-                    return await self.single_connection_post_request(
-                        url=url,
-                        client=new_client,
-                        data=data,
-                        json=json,
-                        params=params,
-                        headers=headers,
-                        stream=stream,
-                    )
-                except httpx.ConnectError:
-                    pass
+            try:
+                return await self.single_connection_post_request(
+                    url=url,
+                    client=new_client,
+                    data=data,
+                    json=json,
+                    params=params,
+                    headers=headers,
+                    stream=stream,
+                )
+            finally:
+                await new_client.aclose()
         except httpx.HTTPStatusError as e:
             raise e
         except Exception as e:
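For context, the hunk above narrows the retry path in AsyncHTTPHandler.post: instead of looping _DEFAULT_CONNECTION_RETRIES times on httpx.ConnectError, it retries once on httpx.RemoteProtocolError (the exception httpx raises when the server drops the connection mid-exchange) and now closes the temporary client in a finally block. A rough standalone sketch of the same behaviour, using hypothetical names and plain httpx calls in place of the handler's create_client / single_connection_post_request helpers:

import httpx


async def post_with_retry(url: str, payload: dict) -> httpx.Response:
    client = httpx.AsyncClient(timeout=5.0)
    try:
        try:
            response = await client.post(url, json=payload)
            response.raise_for_status()
            return response
        except httpx.RemoteProtocolError:
            # The server closed the connection mid-request; retry once on a
            # fresh client so a stale pooled connection is not reused.
            retry_client = httpx.AsyncClient(timeout=5.0)
            try:
                response = await retry_client.post(url, json=payload)
                response.raise_for_status()
                return response
            finally:
                # Mirrors the `finally: await new_client.aclose()` added above.
                await retry_client.aclose()
    finally:
        await client.aclose()


# e.g. asyncio.run(post_with_retry("http://localhost:8080/health", {}))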


@@ -11,7 +11,7 @@ model_list:
   - model_name: fake-openai-endpoint
     litellm_params:
       model: predibase/llama-3-8b-instruct
-      api_base: "http://0.0.0.0:8081"
+      api_base: "http://0.0.0.0:8000"
       api_key: os.environ/PREDIBASE_API_KEY
       tenant_id: os.environ/PREDIBASE_TENANT_ID
       max_retries: 0
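The config change only moves the Predibase-backed fake-openai-endpoint entry from port 8081 to 8000; the api_base is where the proxy forwards requests, not where clients connect. A hedged usage sketch for exercising that model entry through the proxy — the config filename is hypothetical, and the listen port assumes litellm's recent default of 4000 (adjust if your proxy is started differently):

import openai

# Hypothetical smoke test against a locally running litellm proxy started
# with the config above, e.g. `litellm --config proxy_config.yaml`.
# Assumes the proxy listens on its default port 4000 and that "sk-1234"
# is an accepted key for this deployment.
client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

response = client.chat.completions.create(
    model="fake-openai-endpoint",
    messages=[{"role": "user", "content": "ping"}],
)
print(response.choices[0].message.content)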