Merge pull request #5358 from BerriAI/litellm_fix_retry_after
fix retry after - cooldown individual models based on their specific 'retry-after' header
commit 415abc86c6
12 changed files with 754 additions and 202 deletions
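The thrust of the change: provider error responses now carry their HTTP headers on the raised exception, so the router can cool down only the deployment that actually returned a 'retry-after', and for as long as it asked. A minimal sketch of that idea (hypothetical helper, not code from this diff; it only assumes the exception exposes a headers mapping, as the hunks below add):

    import email.utils
    import time
    from typing import Mapping, Optional

    def cooldown_seconds_from_headers(
        headers: Optional[Mapping[str, str]], default: float = 60.0, cap: float = 600.0
    ) -> float:
        """Translate a Retry-After header into a per-model cooldown duration."""
        if not headers:
            return default
        retry_after = headers.get("retry-after") or headers.get("Retry-After")
        if retry_after is None:
            return default
        try:
            seconds = float(retry_after)  # delta-seconds form, e.g. "Retry-After: 7"
        except ValueError:
            try:
                http_date = email.utils.parsedate_to_datetime(retry_after)  # HTTP-date form
            except (TypeError, ValueError):
                return default
            seconds = http_date.timestamp() - time.time()
        return max(0.0, min(seconds, cap))

A router could catch the provider error, call this on getattr(e, "headers", None), and apply the result as the cooldown for just that deployment instead of a fixed global value.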
@@ -75,9 +75,11 @@ class AzureOpenAIError(Exception):
         message,
         request: Optional[httpx.Request] = None,
         response: Optional[httpx.Response] = None,
+        headers: Optional[httpx.Headers] = None,
     ):
         self.status_code = status_code
         self.message = message
+        self.headers = headers
         if request:
             self.request = request
         else:
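The first hunk is the enabling change: the provider error type grows an optional headers field, so whatever the SDK returned (rate-limit counters, retry-after, request ids) survives the re-raise. A small self-contained sketch of the same pattern, with generic names rather than the classes in this file:

    from typing import Mapping, Optional

    class ProviderError(Exception):
        """Error that keeps the upstream response headers alongside the message."""

        def __init__(
            self,
            status_code: int,
            message: str,
            headers: Optional[Mapping[str, str]] = None,
        ):
            self.status_code = status_code
            self.message = message
            self.headers = headers
            super().__init__(message)

    try:
        raise ProviderError(429, "rate limited", headers={"retry-after": "17"})
    except ProviderError as e:
        # Downstream code reads the header defensively, exactly like the diff does.
        retry_after = (getattr(e, "headers", None) or {}).get("retry-after")
        print(retry_after)  # -> "17"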
@@ -593,7 +595,6 @@ class AzureChatCompletion(BaseLLM):
         client=None,
     ):
         super().completion()
-        exception_mapping_worked = False
         try:
             if model is None or messages is None:
                 raise AzureOpenAIError(
@@ -755,13 +756,13 @@ class AzureChatCompletion(BaseLLM):
                     convert_tool_call_to_json_mode=json_mode,
                 )
         except AzureOpenAIError as e:
-            exception_mapping_worked = True
             raise e
         except Exception as e:
-            if hasattr(e, "status_code"):
-                raise AzureOpenAIError(status_code=e.status_code, message=str(e))
-            else:
-                raise AzureOpenAIError(status_code=500, message=str(e))
+            status_code = getattr(e, "status_code", 500)
+            error_headers = getattr(e, "headers", None)
+            raise AzureOpenAIError(
+                status_code=status_code, message=str(e), headers=error_headers
+            )

     async def acompletion(
         self,
@@ -1005,10 +1006,11 @@ class AzureChatCompletion(BaseLLM):
             )
             return streamwrapper  ## DO NOT make this into an async for ... loop, it will yield an async generator, which won't raise errors if the response fails
         except Exception as e:
-            if hasattr(e, "status_code"):
-                raise AzureOpenAIError(status_code=e.status_code, message=str(e))
-            else:
-                raise AzureOpenAIError(status_code=500, message=str(e))
+            status_code = getattr(e, "status_code", 500)
+            error_headers = getattr(e, "headers", None)
+            raise AzureOpenAIError(
+                status_code=status_code, message=str(e), headers=error_headers
+            )

     async def aembedding(
         self,
@@ -1027,7 +1029,9 @@ class AzureChatCompletion(BaseLLM):
                 openai_aclient = AsyncAzureOpenAI(**azure_client_params)
             else:
                 openai_aclient = client
-            response = await openai_aclient.embeddings.create(**data, timeout=timeout)
+            response = await openai_aclient.embeddings.with_raw_response.create(
+                **data, timeout=timeout
+            )
             stringified_response = response.model_dump()
             ## LOGGING
             logging_obj.post_call(
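Several hunks swap plain .create(...) calls for .with_raw_response.create(...). In the OpenAI Python SDK (v1.x) that wrapper returns the HTTP response alongside the parsed object: .headers exposes the response headers and .parse() yields the usual typed result. A rough illustration of the surface being relied on (the key and model name are placeholders):

    from openai import OpenAI

    client = OpenAI(api_key="sk-...")  # placeholder key

    raw = client.chat.completions.with_raw_response.create(
        model="gpt-4o-mini",  # placeholder model
        messages=[{"role": "user", "content": "ping"}],
    )
    headers = dict(raw.headers)  # e.g. x-request-id and the x-ratelimit-* counters
    completion = raw.parse()     # the same ChatCompletion object .create() would return
    print(headers.get("x-request-id"), completion.choices[0].message.content)

On failures the SDK raises instead of returning, which is why the surrounding except blocks in this diff pull status_code and headers off the exception and re-wrap them.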
@@ -1067,7 +1071,6 @@ class AzureChatCompletion(BaseLLM):
         aembedding=None,
     ):
         super().embedding()
-        exception_mapping_worked = False
         if self._client_session is None:
             self._client_session = self.create_client_session()
         try:
@@ -1127,7 +1130,7 @@ class AzureChatCompletion(BaseLLM):
             else:
                 azure_client = client
             ## COMPLETION CALL
-            response = azure_client.embeddings.create(**data, timeout=timeout)  # type: ignore
+            response = azure_client.embeddings.with_raw_response.create(**data, timeout=timeout)  # type: ignore
             ## LOGGING
             logging_obj.post_call(
                 input=input,
@@ -1138,13 +1141,13 @@ class AzureChatCompletion(BaseLLM):
 
             return convert_to_model_response_object(response_object=response.model_dump(), model_response_object=model_response, response_type="embedding")  # type: ignore
         except AzureOpenAIError as e:
-            exception_mapping_worked = True
             raise e
         except Exception as e:
-            if hasattr(e, "status_code"):
-                raise AzureOpenAIError(status_code=e.status_code, message=str(e))
-            else:
-                raise AzureOpenAIError(status_code=500, message=str(e))
+            status_code = getattr(e, "status_code", 500)
+            error_headers = getattr(e, "headers", None)
+            raise AzureOpenAIError(
+                status_code=status_code, message=str(e), headers=error_headers
+            )

     async def make_async_azure_httpx_request(
         self,

@@ -33,9 +33,11 @@ class AzureOpenAIError(Exception):
         message,
         request: Optional[httpx.Request] = None,
         response: Optional[httpx.Response] = None,
+        headers: Optional[httpx.Headers] = None,
     ):
         self.status_code = status_code
         self.message = message
+        self.headers = headers
         if request:
             self.request = request
         else:
@@ -311,13 +313,13 @@ class AzureTextCompletion(BaseLLM):
                 )
             )
         except AzureOpenAIError as e:
-            exception_mapping_worked = True
             raise e
         except Exception as e:
-            if hasattr(e, "status_code"):
-                raise AzureOpenAIError(status_code=e.status_code, message=str(e))
-            else:
-                raise AzureOpenAIError(status_code=500, message=str(e))
+            status_code = getattr(e, "status_code", 500)
+            error_headers = getattr(e, "headers", None)
+            raise AzureOpenAIError(
+                status_code=status_code, message=str(e), headers=error_headers
+            )

     async def acompletion(
         self,
@@ -387,10 +389,11 @@ class AzureTextCompletion(BaseLLM):
             exception_mapping_worked = True
             raise e
         except Exception as e:
-            if hasattr(e, "status_code"):
-                raise e
-            else:
-                raise AzureOpenAIError(status_code=500, message=str(e))
+            status_code = getattr(e, "status_code", 500)
+            error_headers = getattr(e, "headers", None)
+            raise AzureOpenAIError(
+                status_code=status_code, message=str(e), headers=error_headers
+            )

     def streaming(
         self,
@@ -443,7 +446,9 @@ class AzureTextCompletion(BaseLLM):
                 "complete_input_dict": data,
             },
         )
-        response = azure_client.completions.create(**data, timeout=timeout)
+        response = azure_client.completions.with_raw_response.create(
+            **data, timeout=timeout
+        )
         streamwrapper = CustomStreamWrapper(
             completion_stream=response,
             model=model,
@@ -501,7 +506,9 @@ class AzureTextCompletion(BaseLLM):
                 "complete_input_dict": data,
             },
         )
-        response = await azure_client.completions.create(**data, timeout=timeout)
+        response = await azure_client.completions.with_raw_response.create(
+            **data, timeout=timeout
+        )
         # return response
         streamwrapper = CustomStreamWrapper(
             completion_stream=response,
@@ -511,7 +518,8 @@ class AzureTextCompletion(BaseLLM):
             )
             return streamwrapper  ## DO NOT make this into an async for ... loop, it will yield an async generator, which won't raise errors if the response fails
         except Exception as e:
-            if hasattr(e, "status_code"):
-                raise AzureOpenAIError(status_code=e.status_code, message=str(e))
-            else:
-                raise AzureOpenAIError(status_code=500, message=str(e))
+            status_code = getattr(e, "status_code", 500)
+            error_headers = getattr(e, "headers", None)
+            raise AzureOpenAIError(
+                status_code=status_code, message=str(e), headers=error_headers
+            )

@@ -50,9 +50,11 @@ class OpenAIError(Exception):
         message,
         request: Optional[httpx.Request] = None,
         response: Optional[httpx.Response] = None,
+        headers: Optional[httpx.Headers] = None,
     ):
         self.status_code = status_code
         self.message = message
+        self.headers = headers
         if request:
             self.request = request
         else:
@@ -113,7 +115,7 @@ class MistralConfig:
         random_seed: Optional[int] = None,
         safe_prompt: Optional[bool] = None,
         response_format: Optional[dict] = None,
-        stop: Optional[Union[str, list]] = None
+        stop: Optional[Union[str, list]] = None,
     ) -> None:
         locals_ = locals().copy()
         for key, value in locals_.items():
@@ -172,7 +174,7 @@ class MistralConfig:
            if param == "top_p":
                optional_params["top_p"] = value
            if param == "stop":
-               optional_params["stop"] = value
+               optional_params["stop"] = value
            if param == "tool_choice" and isinstance(value, str):
                optional_params["tool_choice"] = self._map_tool_choice(
                    tool_choice=value
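The MistralConfig hunks tidy the stop parameter (trailing comma in the signature, and its mapping into optional_params), so a standard OpenAI-style stop list flows through to Mistral requests. A small usage sketch against litellm's public completion API (the model string is just an example, and MISTRAL_API_KEY is assumed to be set in the environment):

    import litellm

    response = litellm.completion(
        model="mistral/mistral-small-latest",  # example model name
        messages=[{"role": "user", "content": "Count to ten."}],
        stop=["7"],  # mapped by MistralConfig into the provider's stop parameter
    )
    print(response.choices[0].message.content)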
@@ -768,7 +770,7 @@ class OpenAIChatCompletion(BaseLLM):
         openai_aclient: AsyncOpenAI,
         data: dict,
         timeout: Union[float, httpx.Timeout],
-    ):
+    ) -> Tuple[dict, BaseModel]:
         """
         Helper to:
         - call chat.completions.create.with_raw_response when litellm.return_response_headers is True
@@ -781,39 +783,51 @@ class OpenAIChatCompletion(BaseLLM):
                 )
             )
 
-            headers = dict(raw_response.headers)
+            if hasattr(raw_response, "headers"):
+                headers = dict(raw_response.headers)
+            else:
+                headers = {}
             response = raw_response.parse()
             return headers, response
-        except Exception as e:
+        except OpenAIError as e:
             raise e
+        except Exception as e:
+            status_code = getattr(e, "status_code", 500)
+            error_headers = getattr(e, "headers", None)
+            raise OpenAIError(
+                status_code=status_code, message=str(e), headers=error_headers
+            )
 
     def make_sync_openai_chat_completion_request(
         self,
         openai_client: OpenAI,
         data: dict,
         timeout: Union[float, httpx.Timeout],
-    ):
+    ) -> Tuple[dict, BaseModel]:
         """
         Helper to:
         - call chat.completions.create.with_raw_response when litellm.return_response_headers is True
         - call chat.completions.create by default
         """
         try:
-            if litellm.return_response_headers is True:
-                raw_response = openai_client.chat.completions.with_raw_response.create(
-                    **data, timeout=timeout
-                )
+            raw_response = openai_client.chat.completions.with_raw_response.create(
+                **data, timeout=timeout
+            )
+
             if hasattr(raw_response, "headers"):
                 headers = dict(raw_response.headers)
-                response = raw_response.parse()
-                return headers, response
             else:
-                response = openai_client.chat.completions.create(
-                    **data, timeout=timeout
-                )
-                return None, response
-        except Exception as e:
+                headers = {}
+            response = raw_response.parse()
+            return headers, response
+        except OpenAIError as e:
             raise e
+        except Exception as e:
+            status_code = getattr(e, "status_code", 500)
+            error_headers = getattr(e, "headers", None)
+            raise OpenAIError(
+                status_code=status_code, message=str(e), headers=error_headers
+            )
 
     def completion(
         self,
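The request helpers above now always go through with_raw_response and return a (headers, parsed_response) pair, with a hasattr() guard so a raw response that lacks a headers attribute (presumably a mocked client in tests) degrades to an empty dict instead of blowing up. A generic sketch of that contract, with illustrative names rather than the litellm functions:

    from typing import Any, Dict, Tuple

    def create_with_headers(client: Any, **params: Any) -> Tuple[Dict[str, str], Any]:
        """Call the raw-response variant and hand back headers plus the parsed object."""
        raw = client.chat.completions.with_raw_response.create(**params)
        headers = dict(raw.headers) if hasattr(raw, "headers") else {}
        return headers, raw.parse()

Callers then unpack headers, response = create_with_headers(...) and can thread the header dict into logging or cooldown logic without a second request.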
@@ -1260,6 +1274,8 @@ class OpenAIChatCompletion(BaseLLM):
         except (
             Exception
         ) as e:  # need to exception handle here. async exceptions don't get caught in sync functions.
+            if isinstance(e, OpenAIError):
+                raise e
             if response is not None and hasattr(response, "text"):
                 raise OpenAIError(
                     status_code=500,
@@ -1288,16 +1304,12 @@ class OpenAIChatCompletion(BaseLLM):
         - call embeddings.create by default
         """
         try:
-            if litellm.return_response_headers is True:
-                raw_response = await openai_aclient.embeddings.with_raw_response.create(
-                    **data, timeout=timeout
-                )  # type: ignore
-                headers = dict(raw_response.headers)
-                response = raw_response.parse()
-                return headers, response
-            else:
-                response = await openai_aclient.embeddings.create(**data, timeout=timeout)  # type: ignore
-                return None, response
+            raw_response = await openai_aclient.embeddings.with_raw_response.create(
+                **data, timeout=timeout
+            )  # type: ignore
+            headers = dict(raw_response.headers)
+            response = raw_response.parse()
+            return headers, response
         except Exception as e:
             raise e
 
@@ -1313,17 +1325,13 @@ class OpenAIChatCompletion(BaseLLM):
         - call embeddings.create by default
         """
         try:
-            if litellm.return_response_headers is True:
-                raw_response = openai_client.embeddings.with_raw_response.create(
-                    **data, timeout=timeout
-                )  # type: ignore
+            raw_response = openai_client.embeddings.with_raw_response.create(
+                **data, timeout=timeout
+            )  # type: ignore
 
-                headers = dict(raw_response.headers)
-                response = raw_response.parse()
-                return headers, response
-            else:
-                response = openai_client.embeddings.create(**data, timeout=timeout)  # type: ignore
-                return None, response
+            headers = dict(raw_response.headers)
+            response = raw_response.parse()
+            return headers, response
         except Exception as e:
             raise e
 
@@ -1367,14 +1375,14 @@ class OpenAIChatCompletion(BaseLLM):
                 response_type="embedding",
                 _response_headers=headers,
             )  # type: ignore
-        except Exception as e:
-            ## LOGGING
-            logging_obj.post_call(
-                input=input,
-                api_key=api_key,
-                original_response=str(e),
-            )
+        except OpenAIError as e:
             raise e
+        except Exception as e:
+            status_code = getattr(e, "status_code", 500)
+            error_headers = getattr(e, "headers", None)
+            raise OpenAIError(
+                status_code=status_code, message=str(e), headers=error_headers
+            )
 
     def embedding(
         self,
@@ -1448,13 +1456,13 @@ class OpenAIChatCompletion(BaseLLM):
                 response_type="embedding",
             )  # type: ignore
         except OpenAIError as e:
-            exception_mapping_worked = True
             raise e
         except Exception as e:
-            if hasattr(e, "status_code"):
-                raise OpenAIError(status_code=e.status_code, message=str(e))
-            else:
-                raise OpenAIError(status_code=500, message=str(e))
+            status_code = getattr(e, "status_code", 500)
+            error_headers = getattr(e, "headers", None)
+            raise OpenAIError(
+                status_code=status_code, message=str(e), headers=error_headers
+            )
 
     async def aimage_generation(
         self,
@@ -1975,7 +1983,7 @@ class OpenAITextCompletion(BaseLLM):
                 "complete_input_dict": data,
             },
         )
-        if acompletion == True:
+        if acompletion is True:
             if optional_params.get("stream", False):
                 return self.async_streaming(
                     logging_obj=logging_obj,
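One small correctness nit in the text-completion path: acompletion == True becomes acompletion is True. The identity check only passes for the actual boolean, whereas == True also matches truthy look-alikes such as 1 (the comparison style pycodestyle flags as E712). A two-line illustration:

    acompletion = 1              # a truthy non-bool that could sneak in via kwargs
    print(acompletion == True)   # True  -> the old check would take the async path
    print(acompletion is True)   # False -> the new check does not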
@@ -2019,7 +2027,7 @@ class OpenAITextCompletion(BaseLLM):
         else:
             openai_client = client
 
-            response = openai_client.completions.create(**data)  # type: ignore
+            response = openai_client.completions.with_raw_response.create(**data)  # type: ignore
 
             response_json = response.model_dump()
 
@@ -2067,7 +2075,7 @@ class OpenAITextCompletion(BaseLLM):
         else:
             openai_aclient = client
 
-        response = await openai_aclient.completions.create(**data)
+        response = await openai_aclient.completions.with_raw_response.create(**data)
         response_json = response.model_dump()
         ## LOGGING
         logging_obj.post_call(
@@ -2100,6 +2108,7 @@ class OpenAITextCompletion(BaseLLM):
         client=None,
+        organization=None,
     ):
 
         if client is None:
             openai_client = OpenAI(
                 api_key=api_key,
@@ -2111,7 +2120,15 @@ class OpenAITextCompletion(BaseLLM):
             )
         else:
             openai_client = client
-        response = openai_client.completions.create(**data)
+
+        try:
+            response = openai_client.completions.with_raw_response.create(**data)
+        except Exception as e:
+            status_code = getattr(e, "status_code", 500)
+            error_headers = getattr(e, "headers", None)
+            raise OpenAIError(
+                status_code=status_code, message=str(e), headers=error_headers
+            )
         streamwrapper = CustomStreamWrapper(
             completion_stream=response,
             model=model,
@@ -2149,7 +2166,7 @@ class OpenAITextCompletion(BaseLLM):
         else:
             openai_client = client
 
-        response = await openai_client.completions.create(**data)
+        response = await openai_client.completions.with_raw_response.create(**data)
 
         streamwrapper = CustomStreamWrapper(
             completion_stream=response,