fix(utils.py): correctly re-raise the headers from an exception, if present

Fixes issue where retry after on router was not using azure / openai numbers
This commit is contained in:
Krrish Dholakia 2024-08-24 12:30:30 -07:00
parent 5a2c9d5121
commit 068aafdff9
6 changed files with 228 additions and 33 deletions

View file

@ -50,9 +50,11 @@ class OpenAIError(Exception):
message,
request: Optional[httpx.Request] = None,
response: Optional[httpx.Response] = None,
headers: Optional[httpx.Headers] = None,
):
self.status_code = status_code
self.message = message
self.headers = headers
if request:
self.request = request
else:
@ -113,7 +115,7 @@ class MistralConfig:
random_seed: Optional[int] = None,
safe_prompt: Optional[bool] = None,
response_format: Optional[dict] = None,
stop: Optional[Union[str, list]] = None
stop: Optional[Union[str, list]] = None,
) -> None:
locals_ = locals().copy()
for key, value in locals_.items():
@ -172,7 +174,7 @@ class MistralConfig:
if param == "top_p":
optional_params["top_p"] = value
if param == "stop":
optional_params["stop"] = value
optional_params["stop"] = value
if param == "tool_choice" and isinstance(value, str):
optional_params["tool_choice"] = self._map_tool_choice(
tool_choice=value
@ -1313,17 +1315,13 @@ class OpenAIChatCompletion(BaseLLM):
- call embeddings.create by default
"""
try:
if litellm.return_response_headers is True:
raw_response = openai_client.embeddings.with_raw_response.create(
**data, timeout=timeout
) # type: ignore
raw_response = openai_client.embeddings.with_raw_response.create(
**data, timeout=timeout
) # type: ignore
headers = dict(raw_response.headers)
response = raw_response.parse()
return headers, response
else:
response = openai_client.embeddings.create(**data, timeout=timeout) # type: ignore
return None, response
headers = dict(raw_response.headers)
response = raw_response.parse()
return headers, response
except Exception as e:
raise e
@ -1448,13 +1446,13 @@ class OpenAIChatCompletion(BaseLLM):
response_type="embedding",
) # type: ignore
except OpenAIError as e:
exception_mapping_worked = True
raise e
except Exception as e:
if hasattr(e, "status_code"):
raise OpenAIError(status_code=e.status_code, message=str(e))
else:
raise OpenAIError(status_code=500, message=str(e))
status_code = getattr(e, "status_code", 500)
error_headers = getattr(e, "headers", None)
raise OpenAIError(
status_code=status_code, message=str(e), headers=error_headers
)
async def aimage_generation(
self,