feat(openai.py): bubble all error information back to client

This commit is contained in:
Krrish Dholakia 2025-03-10 15:27:43 -07:00
parent c1ec82fbd5
commit 5f87dc229a
4 changed files with 18 additions and 5 deletions

View file

@ -118,6 +118,7 @@ class BadRequestError(openai.BadRequestError): # type: ignore
litellm_debug_info: Optional[str] = None, litellm_debug_info: Optional[str] = None,
max_retries: Optional[int] = None, max_retries: Optional[int] = None,
num_retries: Optional[int] = None, num_retries: Optional[int] = None,
body: Optional[dict] = None,
): ):
self.status_code = 400 self.status_code = 400
self.message = "litellm.BadRequestError: {}".format(message) self.message = "litellm.BadRequestError: {}".format(message)
@ -133,7 +134,7 @@ class BadRequestError(openai.BadRequestError): # type: ignore
self.max_retries = max_retries self.max_retries = max_retries
self.num_retries = num_retries self.num_retries = num_retries
super().__init__( super().__init__(
self.message, response=response, body=None self.message, response=response, body=body
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
def __str__(self): def __str__(self):

View file

@ -331,6 +331,7 @@ def exception_type( # type: ignore # noqa: PLR0915
model=model, model=model,
response=getattr(original_exception, "response", None), response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information, litellm_debug_info=extra_information,
body=getattr(original_exception, "body", None),
) )
elif ( elif (
"Web server is returning an unknown error" in error_str "Web server is returning an unknown error" in error_str

View file

@ -732,10 +732,14 @@ class OpenAIChatCompletion(BaseLLM):
error_headers = getattr(e, "headers", None) error_headers = getattr(e, "headers", None)
error_text = getattr(e, "text", str(e)) error_text = getattr(e, "text", str(e))
error_response = getattr(e, "response", None) error_response = getattr(e, "response", None)
error_body = getattr(e, "body", None)
if error_headers is None and error_response: if error_headers is None and error_response:
error_headers = getattr(error_response, "headers", None) error_headers = getattr(error_response, "headers", None)
raise OpenAIError( raise OpenAIError(
status_code=status_code, message=error_text, headers=error_headers status_code=status_code,
message=error_text,
headers=error_headers,
body=error_body,
) )
async def acompletion( async def acompletion(
@ -977,6 +981,7 @@ class OpenAIChatCompletion(BaseLLM):
error_headers = getattr(e, "headers", None) error_headers = getattr(e, "headers", None)
status_code = getattr(e, "status_code", 500) status_code = getattr(e, "status_code", 500)
error_response = getattr(e, "response", None) error_response = getattr(e, "response", None)
exception_body = getattr(e, "body", None)
if error_headers is None and error_response: if error_headers is None and error_response:
error_headers = getattr(error_response, "headers", None) error_headers = getattr(error_response, "headers", None)
if response is not None and hasattr(response, "text"): if response is not None and hasattr(response, "text"):
@ -984,6 +989,7 @@ class OpenAIChatCompletion(BaseLLM):
status_code=status_code, status_code=status_code,
message=f"{str(e)}\n\nOriginal Response: {response.text}", # type: ignore message=f"{str(e)}\n\nOriginal Response: {response.text}", # type: ignore
headers=error_headers, headers=error_headers,
body=exception_body,
) )
else: else:
if type(e).__name__ == "ReadTimeout": if type(e).__name__ == "ReadTimeout":
@ -991,16 +997,21 @@ class OpenAIChatCompletion(BaseLLM):
status_code=408, status_code=408,
message=f"{type(e).__name__}", message=f"{type(e).__name__}",
headers=error_headers, headers=error_headers,
body=exception_body,
) )
elif hasattr(e, "status_code"): elif hasattr(e, "status_code"):
raise OpenAIError( raise OpenAIError(
status_code=getattr(e, "status_code", 500), status_code=getattr(e, "status_code", 500),
message=str(e), message=str(e),
headers=error_headers, headers=error_headers,
body=exception_body,
) )
else: else:
raise OpenAIError( raise OpenAIError(
status_code=500, message=f"{str(e)}", headers=error_headers status_code=500,
message=f"{str(e)}",
headers=error_headers,
body=exception_body,
) )
def get_stream_options( def get_stream_options(

View file

@ -419,6 +419,6 @@ async def test_exception_bubbling_up(sync_mode, stream_mode):
sync_stream=sync_mode, sync_stream=sync_mode,
) )
assert exc_info.value.code == "invalid_request_error" assert exc_info.value.code == "invalid_value"
assert exc_info.value.param == "messages" assert exc_info.value.param is not None
assert exc_info.value.type == "invalid_request_error" assert exc_info.value.type == "invalid_request_error"