Merge pull request #5358 from BerriAI/litellm_fix_retry_after

fix retry after - cooldown individual models based on their specific 'retry-after' header

commit 415abc86c6
12 changed files with 754 additions and 202 deletions
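In practice, the change means a provider 429 now carries its response headers through litellm's exceptions, and the router cools down only the affected deployment for the provider-specified duration instead of a blanket 60 seconds. A minimal sketch of the caller-visible behavior (not part of the commit; it assumes the same Router configuration the new tests use):

```
import litellm
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "text-embedding-ada-002",
            "litellm_params": {"model": "openai/text-embedding-ada-002"},
        }
    ]
)

try:
    router.embedding(model="text-embedding-ada-002", input="Hello world!")
except litellm.RateLimitError as e:
    # After this change the provider's headers ride along on the exception,
    # so the cooldown can honor e.g. "retry-after: 9" instead of defaulting to 60s.
    print(e.litellm_response_headers.get("retry-after"))
```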
@@ -1,12 +1,12 @@
 repos:
 - repo: local
   hooks:
-    - id: mypy
-      name: mypy
-      entry: python3 -m mypy --ignore-missing-imports
-      language: system
-      types: [python]
-      files: ^litellm/
+    # - id: mypy
+    #   name: mypy
+    #   entry: python3 -m mypy --ignore-missing-imports
+    #   language: system
+    #   types: [python]
+    #   files: ^litellm/
     - id: isort
       name: isort
       entry: isort
@@ -75,9 +75,11 @@ class AzureOpenAIError(Exception):
         message,
         request: Optional[httpx.Request] = None,
         response: Optional[httpx.Response] = None,
+        headers: Optional[httpx.Headers] = None,
     ):
         self.status_code = status_code
         self.message = message
+        self.headers = headers
         if request:
             self.request = request
         else:
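Both Azure adapters (and the OpenAI one further down) now accept a `headers` argument on their error classes, and every `except` block switches from `hasattr`/`if-else` branching to `getattr` with defaults so the provider's headers are never dropped. The pattern in isolation, as a hedged sketch built around a hypothetical stand-in class rather than the library's own:

```
from typing import Optional

import httpx


class SomeProviderError(Exception):
    """Hypothetical stand-in for the AzureOpenAIError / OpenAIError classes in this diff."""

    def __init__(
        self,
        status_code: int,
        message: str,
        headers: Optional[httpx.Headers] = None,
    ):
        self.status_code = status_code
        self.message = message
        self.headers = headers  # provider response headers now travel with the error
        super().__init__(message)


def wrap(e: Exception) -> SomeProviderError:
    # getattr keeps the old fallback (500, no headers) when the SDK exception
    # exposes neither status_code nor headers.
    status_code = getattr(e, "status_code", 500)
    error_headers = getattr(e, "headers", None)
    return SomeProviderError(status_code=status_code, message=str(e), headers=error_headers)
```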
@@ -593,7 +595,6 @@ class AzureChatCompletion(BaseLLM):
         client=None,
     ):
         super().completion()
-        exception_mapping_worked = False
         try:
             if model is None or messages is None:
                 raise AzureOpenAIError(
@@ -755,13 +756,13 @@ class AzureChatCompletion(BaseLLM):
                 convert_tool_call_to_json_mode=json_mode,
             )
         except AzureOpenAIError as e:
-            exception_mapping_worked = True
             raise e
         except Exception as e:
-            if hasattr(e, "status_code"):
-                raise AzureOpenAIError(status_code=e.status_code, message=str(e))
-            else:
-                raise AzureOpenAIError(status_code=500, message=str(e))
+            status_code = getattr(e, "status_code", 500)
+            error_headers = getattr(e, "headers", None)
+            raise AzureOpenAIError(
+                status_code=status_code, message=str(e), headers=error_headers
+            )

     async def acompletion(
         self,
@@ -1005,10 +1006,11 @@ class AzureChatCompletion(BaseLLM):
             )
             return streamwrapper  ## DO NOT make this into an async for ... loop, it will yield an async generator, which won't raise errors if the response fails
         except Exception as e:
-            if hasattr(e, "status_code"):
-                raise AzureOpenAIError(status_code=e.status_code, message=str(e))
-            else:
-                raise AzureOpenAIError(status_code=500, message=str(e))
+            status_code = getattr(e, "status_code", 500)
+            error_headers = getattr(e, "headers", None)
+            raise AzureOpenAIError(
+                status_code=status_code, message=str(e), headers=error_headers
+            )

     async def aembedding(
         self,
@@ -1027,7 +1029,9 @@ class AzureChatCompletion(BaseLLM):
                 openai_aclient = AsyncAzureOpenAI(**azure_client_params)
             else:
                 openai_aclient = client
-            response = await openai_aclient.embeddings.create(**data, timeout=timeout)
+            response = await openai_aclient.embeddings.with_raw_response.create(
+                **data, timeout=timeout
+            )
             stringified_response = response.model_dump()
             ## LOGGING
             logging_obj.post_call(
@@ -1067,7 +1071,6 @@ class AzureChatCompletion(BaseLLM):
         aembedding=None,
     ):
         super().embedding()
-        exception_mapping_worked = False
         if self._client_session is None:
             self._client_session = self.create_client_session()
         try:
@@ -1127,7 +1130,7 @@ class AzureChatCompletion(BaseLLM):
             else:
                 azure_client = client
             ## COMPLETION CALL
-            response = azure_client.embeddings.create(**data, timeout=timeout)  # type: ignore
+            response = azure_client.embeddings.with_raw_response.create(**data, timeout=timeout)  # type: ignore
             ## LOGGING
             logging_obj.post_call(
                 input=input,
@@ -1138,13 +1141,13 @@ class AzureChatCompletion(BaseLLM):

             return convert_to_model_response_object(response_object=response.model_dump(), model_response_object=model_response, response_type="embedding")  # type: ignore
         except AzureOpenAIError as e:
-            exception_mapping_worked = True
             raise e
         except Exception as e:
-            if hasattr(e, "status_code"):
-                raise AzureOpenAIError(status_code=e.status_code, message=str(e))
-            else:
-                raise AzureOpenAIError(status_code=500, message=str(e))
+            status_code = getattr(e, "status_code", 500)
+            error_headers = getattr(e, "headers", None)
+            raise AzureOpenAIError(
+                status_code=status_code, message=str(e), headers=error_headers
+            )

     async def make_async_azure_httpx_request(
         self,
@@ -33,9 +33,11 @@ class AzureOpenAIError(Exception):
         message,
         request: Optional[httpx.Request] = None,
         response: Optional[httpx.Response] = None,
+        headers: Optional[httpx.Headers] = None,
     ):
         self.status_code = status_code
         self.message = message
+        self.headers = headers
         if request:
             self.request = request
         else:
@@ -311,13 +313,13 @@ class AzureTextCompletion(BaseLLM):
                 )
             )
         except AzureOpenAIError as e:
-            exception_mapping_worked = True
             raise e
         except Exception as e:
-            if hasattr(e, "status_code"):
-                raise AzureOpenAIError(status_code=e.status_code, message=str(e))
-            else:
-                raise AzureOpenAIError(status_code=500, message=str(e))
+            status_code = getattr(e, "status_code", 500)
+            error_headers = getattr(e, "headers", None)
+            raise AzureOpenAIError(
+                status_code=status_code, message=str(e), headers=error_headers
+            )

     async def acompletion(
         self,
@@ -387,10 +389,11 @@ class AzureTextCompletion(BaseLLM):
             exception_mapping_worked = True
             raise e
         except Exception as e:
-            if hasattr(e, "status_code"):
-                raise e
-            else:
-                raise AzureOpenAIError(status_code=500, message=str(e))
+            status_code = getattr(e, "status_code", 500)
+            error_headers = getattr(e, "headers", None)
+            raise AzureOpenAIError(
+                status_code=status_code, message=str(e), headers=error_headers
+            )

     def streaming(
         self,
@@ -443,7 +446,9 @@ class AzureTextCompletion(BaseLLM):
                 "complete_input_dict": data,
             },
         )
-        response = azure_client.completions.create(**data, timeout=timeout)
+        response = azure_client.completions.with_raw_response.create(
+            **data, timeout=timeout
+        )
         streamwrapper = CustomStreamWrapper(
             completion_stream=response,
             model=model,
@@ -501,7 +506,9 @@ class AzureTextCompletion(BaseLLM):
                 "complete_input_dict": data,
            },
        )
-        response = await azure_client.completions.create(**data, timeout=timeout)
+        response = await azure_client.completions.with_raw_response.create(
+            **data, timeout=timeout
+        )
         # return response
         streamwrapper = CustomStreamWrapper(
             completion_stream=response,
@@ -511,7 +518,8 @@ class AzureTextCompletion(BaseLLM):
             )
             return streamwrapper  ## DO NOT make this into an async for ... loop, it will yield an async generator, which won't raise errors if the response fails
         except Exception as e:
-            if hasattr(e, "status_code"):
-                raise AzureOpenAIError(status_code=e.status_code, message=str(e))
-            else:
-                raise AzureOpenAIError(status_code=500, message=str(e))
+            status_code = getattr(e, "status_code", 500)
+            error_headers = getattr(e, "headers", None)
+            raise AzureOpenAIError(
+                status_code=status_code, message=str(e), headers=error_headers
+            )
@@ -50,9 +50,11 @@ class OpenAIError(Exception):
         message,
         request: Optional[httpx.Request] = None,
         response: Optional[httpx.Response] = None,
+        headers: Optional[httpx.Headers] = None,
     ):
         self.status_code = status_code
         self.message = message
+        self.headers = headers
         if request:
             self.request = request
         else:
@@ -113,7 +115,7 @@ class MistralConfig:
         random_seed: Optional[int] = None,
         safe_prompt: Optional[bool] = None,
         response_format: Optional[dict] = None,
-        stop: Optional[Union[str, list]] = None
+        stop: Optional[Union[str, list]] = None,
     ) -> None:
         locals_ = locals().copy()
         for key, value in locals_.items():
@@ -768,7 +770,7 @@ class OpenAIChatCompletion(BaseLLM):
         openai_aclient: AsyncOpenAI,
         data: dict,
         timeout: Union[float, httpx.Timeout],
-    ):
+    ) -> Tuple[dict, BaseModel]:
         """
         Helper to:
         - call chat.completions.create.with_raw_response when litellm.return_response_headers is True
@@ -781,39 +783,51 @@ class OpenAIChatCompletion(BaseLLM):
                 )
             )

-            headers = dict(raw_response.headers)
+            if hasattr(raw_response, "headers"):
+                headers = dict(raw_response.headers)
+            else:
+                headers = {}
             response = raw_response.parse()
             return headers, response
-        except Exception as e:
+        except OpenAIError as e:
             raise e
+        except Exception as e:
+            status_code = getattr(e, "status_code", 500)
+            error_headers = getattr(e, "headers", None)
+            raise OpenAIError(
+                status_code=status_code, message=str(e), headers=error_headers
+            )

     def make_sync_openai_chat_completion_request(
         self,
         openai_client: OpenAI,
         data: dict,
         timeout: Union[float, httpx.Timeout],
-    ):
+    ) -> Tuple[dict, BaseModel]:
         """
         Helper to:
         - call chat.completions.create.with_raw_response when litellm.return_response_headers is True
         - call chat.completions.create by default
         """
         try:
-            if litellm.return_response_headers is True:
-                raw_response = openai_client.chat.completions.with_raw_response.create(
-                    **data, timeout=timeout
-                )
+            raw_response = openai_client.chat.completions.with_raw_response.create(
+                **data, timeout=timeout
+            )

-                headers = dict(raw_response.headers)
-                response = raw_response.parse()
-                return headers, response
+            if hasattr(raw_response, "headers"):
+                headers = dict(raw_response.headers)
             else:
-                response = openai_client.chat.completions.create(
-                    **data, timeout=timeout
-                )
-                return None, response
-        except Exception as e:
+                headers = {}
+            response = raw_response.parse()
+            return headers, response
+        except OpenAIError as e:
             raise e
+        except Exception as e:
+            status_code = getattr(e, "status_code", 500)
+            error_headers = getattr(e, "headers", None)
+            raise OpenAIError(
+                status_code=status_code, message=str(e), headers=error_headers
+            )

     def completion(
         self,
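The request helpers now always go through the OpenAI SDK's `with_raw_response` wrapper, splitting the HTTP headers from the parsed body instead of branching on `litellm.return_response_headers`. Roughly, the SDK pattern they rely on looks like this (a sketch only, assuming a configured v1.x `OpenAI()` client):

```
from openai import OpenAI

client = OpenAI()

raw = client.chat.completions.with_raw_response.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi"}],
)
headers = dict(raw.headers)  # httpx headers: rate-limit info, retry-after on 429s, etc.
response = raw.parse()       # the usual ChatCompletion object
print(headers.get("x-request-id"), response.choices[0].message.content)
```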
@@ -1260,6 +1274,8 @@ class OpenAIChatCompletion(BaseLLM):
         except (
             Exception
         ) as e:  # need to exception handle here. async exceptions don't get caught in sync functions.
+            if isinstance(e, OpenAIError):
+                raise e
             if response is not None and hasattr(response, "text"):
                 raise OpenAIError(
                     status_code=500,
@@ -1288,16 +1304,12 @@ class OpenAIChatCompletion(BaseLLM):
         - call embeddings.create by default
         """
         try:
-            if litellm.return_response_headers is True:
-                raw_response = await openai_aclient.embeddings.with_raw_response.create(
-                    **data, timeout=timeout
-                )  # type: ignore
-                headers = dict(raw_response.headers)
-                response = raw_response.parse()
-                return headers, response
-            else:
-                response = await openai_aclient.embeddings.create(**data, timeout=timeout)  # type: ignore
-                return None, response
+            raw_response = await openai_aclient.embeddings.with_raw_response.create(
+                **data, timeout=timeout
+            )  # type: ignore
+            headers = dict(raw_response.headers)
+            response = raw_response.parse()
+            return headers, response
         except Exception as e:
             raise e
@@ -1313,17 +1325,13 @@ class OpenAIChatCompletion(BaseLLM):
         - call embeddings.create by default
         """
         try:
-            if litellm.return_response_headers is True:
-                raw_response = openai_client.embeddings.with_raw_response.create(
-                    **data, timeout=timeout
-                )  # type: ignore
-
-                headers = dict(raw_response.headers)
-                response = raw_response.parse()
-                return headers, response
-            else:
-                response = openai_client.embeddings.create(**data, timeout=timeout)  # type: ignore
-                return None, response
+            raw_response = openai_client.embeddings.with_raw_response.create(
+                **data, timeout=timeout
+            )  # type: ignore
+
+            headers = dict(raw_response.headers)
+            response = raw_response.parse()
+            return headers, response
         except Exception as e:
             raise e
@@ -1367,14 +1375,14 @@ class OpenAIChatCompletion(BaseLLM):
                 response_type="embedding",
                 _response_headers=headers,
             )  # type: ignore
-        except Exception as e:
-            ## LOGGING
-            logging_obj.post_call(
-                input=input,
-                api_key=api_key,
-                original_response=str(e),
-            )
+        except OpenAIError as e:
             raise e
+        except Exception as e:
+            status_code = getattr(e, "status_code", 500)
+            error_headers = getattr(e, "headers", None)
+            raise OpenAIError(
+                status_code=status_code, message=str(e), headers=error_headers
+            )

     def embedding(
         self,
@@ -1448,13 +1456,13 @@ class OpenAIChatCompletion(BaseLLM):
                 response_type="embedding",
             )  # type: ignore
         except OpenAIError as e:
-            exception_mapping_worked = True
             raise e
         except Exception as e:
-            if hasattr(e, "status_code"):
-                raise OpenAIError(status_code=e.status_code, message=str(e))
-            else:
-                raise OpenAIError(status_code=500, message=str(e))
+            status_code = getattr(e, "status_code", 500)
+            error_headers = getattr(e, "headers", None)
+            raise OpenAIError(
+                status_code=status_code, message=str(e), headers=error_headers
+            )

     async def aimage_generation(
         self,
@@ -1975,7 +1983,7 @@ class OpenAITextCompletion(BaseLLM):
                 "complete_input_dict": data,
             },
         )
-        if acompletion == True:
+        if acompletion is True:
             if optional_params.get("stream", False):
                 return self.async_streaming(
                     logging_obj=logging_obj,
@@ -2019,7 +2027,7 @@ class OpenAITextCompletion(BaseLLM):
             else:
                 openai_client = client

-            response = openai_client.completions.create(**data)  # type: ignore
+            response = openai_client.completions.with_raw_response.create(**data)  # type: ignore

             response_json = response.model_dump()

@@ -2067,7 +2075,7 @@ class OpenAITextCompletion(BaseLLM):
             else:
                 openai_aclient = client

-            response = await openai_aclient.completions.create(**data)
+            response = await openai_aclient.completions.with_raw_response.create(**data)
             response_json = response.model_dump()
             ## LOGGING
             logging_obj.post_call(
@@ -2100,6 +2108,7 @@ class OpenAITextCompletion(BaseLLM):
         client=None,
         organization=None,
     ):

         if client is None:
             openai_client = OpenAI(
                 api_key=api_key,
@@ -2111,7 +2120,15 @@ class OpenAITextCompletion(BaseLLM):
             )
         else:
             openai_client = client
-        response = openai_client.completions.create(**data)
+
+        try:
+            response = openai_client.completions.with_raw_response.create(**data)
+        except Exception as e:
+            status_code = getattr(e, "status_code", 500)
+            error_headers = getattr(e, "headers", None)
+            raise OpenAIError(
+                status_code=status_code, message=str(e), headers=error_headers
+            )
         streamwrapper = CustomStreamWrapper(
             completion_stream=response,
             model=model,
@@ -2149,7 +2166,7 @@ class OpenAITextCompletion(BaseLLM):
         else:
             openai_client = client

-        response = await openai_client.completions.create(**data)
+        response = await openai_client.completions.with_raw_response.create(**data)

         streamwrapper = CustomStreamWrapper(
             completion_stream=response,
@@ -452,7 +452,12 @@ async def _async_streaming(response, model, custom_llm_provider, args):
             print_verbose(f"line in async streaming: {line}")
             yield line
     except Exception as e:
-        raise e
+        custom_llm_provider = custom_llm_provider or "openai"
+        raise exception_type(
+            model=model,
+            custom_llm_provider=custom_llm_provider,
+            original_exception=e,
+        )


 def mock_completion(
@@ -3765,7 +3770,7 @@ async def atext_completion(
     else:
         # Call the synchronous function using run_in_executor
         response = await loop.run_in_executor(None, func_with_context)
-    if kwargs.get("stream", False) == True:  # return an async generator
+    if kwargs.get("stream", False) is True:  # return an async generator
         return TextCompletionStreamWrapper(
             completion_stream=_async_streaming(
                 response=response,
@@ -3774,6 +3779,7 @@ async def atext_completion(
                 args=args,
             ),
             model=model,
+            custom_llm_provider=custom_llm_provider,
         )
     else:
         transformed_logprobs = None
@@ -4047,11 +4053,14 @@ def text_completion(
         **kwargs,
         **optional_params,
     )
-    if kwargs.get("acompletion", False) == True:
+    if kwargs.get("acompletion", False) is True:
         return response
-    if stream == True or kwargs.get("stream", False) == True:
+    if stream is True or kwargs.get("stream", False) is True:
         response = TextCompletionStreamWrapper(
-            completion_stream=response, model=model, stream_options=stream_options
+            completion_stream=response,
+            model=model,
+            stream_options=stream_options,
+            custom_llm_provider=custom_llm_provider,
         )
         return response
     transformed_logprobs = None
@@ -58,6 +58,7 @@ from litellm.router_utils.client_initalization_utils import (
     set_client,
     should_initialize_sync_client,
 )
+from litellm.router_utils.cooldown_cache import CooldownCache
 from litellm.router_utils.cooldown_callbacks import router_cooldown_handler
 from litellm.router_utils.fallback_event_handlers import (
     log_failure_fallback_event,
@@ -90,6 +91,7 @@ from litellm.types.router import (
     RetryPolicy,
     RouterErrors,
     RouterGeneralSettings,
+    RouterRateLimitError,
     updateDeployment,
     updateLiteLLMParams,
 )
@@ -337,6 +339,9 @@ class Router:
         else:
             self.allowed_fails = litellm.allowed_fails
         self.cooldown_time = cooldown_time or 60
+        self.cooldown_cache = CooldownCache(
+            cache=self.cache, default_cooldown_time=self.cooldown_time
+        )
         self.disable_cooldowns = disable_cooldowns
         self.failed_calls = (
             InMemoryCache()
@@ -1939,6 +1944,7 @@ class Router:
             raise e

     def _embedding(self, input: Union[str, List], model: str, **kwargs):
+        model_name = None
         try:
             verbose_router_logger.debug(
                 f"Inside embedding()- model: {model}; kwargs: {kwargs}"
@@ -2813,19 +2819,27 @@ class Router:
         ):
             return 0

+        response_headers: Optional[httpx.Headers] = None
         if hasattr(e, "response") and hasattr(e.response, "headers"):
+            response_headers = e.response.headers
+        elif hasattr(e, "litellm_response_headers"):
+            response_headers = e.litellm_response_headers
+
+        if response_headers is not None:
             timeout = litellm._calculate_retry_after(
                 remaining_retries=remaining_retries,
                 max_retries=num_retries,
-                response_headers=e.response.headers,
+                response_headers=response_headers,
                 min_timeout=self.retry_after,
             )

         else:
             timeout = litellm._calculate_retry_after(
                 remaining_retries=remaining_retries,
                 max_retries=num_retries,
                 min_timeout=self.retry_after,
             )

         return timeout

     def function_with_retries(self, *args, **kwargs):
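With the headers attached, the router's retry timeout can now come from either `e.response.headers` or the new `litellm_response_headers` attribute. `litellm._calculate_retry_after` is the real implementation; the sketch below only illustrates the kind of parsing involved (numeric seconds or an HTTP date, clamped to a minimum) and is not the library's code:

```
import email.utils
import time
from typing import Mapping, Optional


def retry_after_seconds(headers: Optional[Mapping[str, str]], min_timeout: float) -> float:
    # Illustrative only: fall back to the configured minimum when no usable header exists.
    if headers is None:
        return min_timeout
    value = headers.get("retry-after")
    if value is None:
        return min_timeout
    try:
        seconds = float(value)  # e.g. "retry-after: 9"
    except ValueError:
        # "retry-after" may also be an HTTP-date
        dt = email.utils.parsedate_to_datetime(value)
        seconds = dt.timestamp() - time.time()
    return max(seconds, min_timeout)
```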
@@ -2997,8 +3011,9 @@ class Router:
         metadata = kwargs.get("litellm_params", {}).get("metadata", None)
         _model_info = kwargs.get("litellm_params", {}).get("model_info", {})

-        exception_response = getattr(exception, "response", {})
-        exception_headers = getattr(exception_response, "headers", None)
+        exception_headers = litellm.utils._get_litellm_response_headers(
+            original_exception=exception
+        )
         _time_to_cooldown = kwargs.get("litellm_params", {}).get(
             "cooldown_time", self.cooldown_time
         )
@@ -3232,52 +3247,14 @@ class Router:

         if updated_fails > allowed_fails or _should_retry is False:
             # get the current cooldown list for that minute
-            cooldown_key = f"{current_minute}:cooldown_models"  # group cooldown models by minute to reduce number of redis calls
-            cached_value = self.cache.get_cache(
-                key=cooldown_key
-            )  # [(deployment_id, {last_error_str, last_error_status_code})]
-
-            cached_value_deployment_ids = []
-            if (
-                cached_value is not None
-                and isinstance(cached_value, list)
-                and len(cached_value) > 0
-                and isinstance(cached_value[0], tuple)
-            ):
-                cached_value_deployment_ids = [cv[0] for cv in cached_value]
             verbose_router_logger.debug(f"adding {deployment} to cooldown models")
             # update value
-            if cached_value is not None and len(cached_value_deployment_ids) > 0:
-                if deployment in cached_value_deployment_ids:
-                    pass
-                else:
-                    cached_value = cached_value + [
-                        (
-                            deployment,
-                            {
-                                "Exception Received": str(original_exception),
-                                "Status Code": str(exception_status),
-                            },
-                        )
-                    ]
-                    # save updated value
-                    self.cache.set_cache(
-                        value=cached_value, key=cooldown_key, ttl=cooldown_time
-                    )
-            else:
-                cached_value = [
-                    (
-                        deployment,
-                        {
-                            "Exception Received": str(original_exception),
-                            "Status Code": str(exception_status),
-                        },
-                    )
-                ]
-                # save updated value
-                self.cache.set_cache(
-                    value=cached_value, key=cooldown_key, ttl=cooldown_time
-                )
+            self.cooldown_cache.add_deployment_to_cooldown(
+                model_id=deployment,
+                original_exception=original_exception,
+                exception_status=exception_status,
+                cooldown_time=cooldown_time,
+            )

             # Trigger cooldown handler
             asyncio.create_task(
@@ -3297,15 +3274,10 @@ class Router:
         """
         Async implementation of '_get_cooldown_deployments'
         """
-        dt = get_utc_datetime()
-        current_minute = dt.strftime("%H-%M")
-        # get the current cooldown list for that minute
-        cooldown_key = f"{current_minute}:cooldown_models"
-
-        # ----------------------
-        # Return cooldown models
-        # ----------------------
-        cooldown_models = await self.cache.async_get_cache(key=cooldown_key) or []
+        model_ids = self.get_model_ids()
+        cooldown_models = await self.cooldown_cache.async_get_active_cooldowns(
+            model_ids=model_ids
+        )

         cached_value_deployment_ids = []
         if (
@@ -3323,15 +3295,10 @@ class Router:
         """
         Async implementation of '_get_cooldown_deployments'
         """
-        dt = get_utc_datetime()
-        current_minute = dt.strftime("%H-%M")
-        # get the current cooldown list for that minute
-        cooldown_key = f"{current_minute}:cooldown_models"
-
-        # ----------------------
-        # Return cooldown models
-        # ----------------------
-        cooldown_models = await self.cache.async_get_cache(key=cooldown_key) or []
+        model_ids = self.get_model_ids()
+        cooldown_models = await self.cooldown_cache.async_get_active_cooldowns(
+            model_ids=model_ids
+        )

         verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}")
         return cooldown_models
@@ -3340,15 +3307,13 @@ class Router:
         """
         Get the list of models being cooled down for this minute
         """
-        dt = get_utc_datetime()
-        current_minute = dt.strftime("%H-%M")
         # get the current cooldown list for that minute
-        cooldown_key = f"{current_minute}:cooldown_models"

         # ----------------------
         # Return cooldown models
         # ----------------------
-        cooldown_models = self.cache.get_cache(key=cooldown_key) or []
+        model_ids = self.get_model_ids()
+        cooldown_models = self.cooldown_cache.get_active_cooldowns(model_ids=model_ids)

         cached_value_deployment_ids = []
         if (
@@ -3359,7 +3324,6 @@ class Router:
         ):
             cached_value_deployment_ids = [cv[0] for cv in cooldown_models]

-        verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}")
         return cached_value_deployment_ids

     def _get_healthy_deployments(self, model: str):
@@ -4050,15 +4014,20 @@ class Router:
                 rpm_usage += t
         return tpm_usage, rpm_usage

-    def get_model_ids(self) -> List[str]:
+    def get_model_ids(self, model_name: Optional[str] = None) -> List[str]:
         """
+        if 'model_name' is none, returns all.
+
         Returns list of model id's.
         """
         ids = []
         for model in self.model_list:
             if "model_info" in model and "id" in model["model_info"]:
                 id = model["model_info"]["id"]
-                ids.append(id)
+                if model_name is not None and model["model_name"] == model_name:
+                    ids.append(id)
+                elif model_name is None:
+                    ids.append(id)
         return ids

     def get_model_names(self) -> List[str]:
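`get_model_ids` gaining an optional `model_name` filter is what lets the rate-limit paths below report a cooldown scoped to just the affected model group. Usage sketch (assuming the `router` object from the earlier example):

```
all_ids = router.get_model_ids()                                      # every deployment id
ada_ids = router.get_model_ids(model_name="text-embedding-ada-002")   # ids for one model group
min_cooldown = router.cooldown_cache.get_min_cooldown(model_ids=ada_ids)  # smallest active cooldown, else the default
```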
@@ -4391,10 +4360,19 @@ class Router:
         - First check for rate limit errors (if this is true, it means the model passed the context window check but failed the rate limit check)
         """

-        if _rate_limit_error == True:  # allow generic fallback logic to take place
-            raise ValueError(
-                f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}. Try again in {self.cooldown_time} seconds."
+        if _rate_limit_error is True:  # allow generic fallback logic to take place
+            model_ids = self.get_model_ids(model_name=model)
+            cooldown_time = self.cooldown_cache.get_min_cooldown(
+                model_ids=model_ids
             )
+            cooldown_list = self._get_cooldown_deployments()
+            raise RouterRateLimitError(
+                model=model,
+                cooldown_time=cooldown_time,
+                enable_pre_call_checks=True,
+                cooldown_list=cooldown_list,
+            )

         elif _context_window_error is True:
             raise litellm.ContextWindowExceededError(
                 message="litellm._pre_call_checks: Context Window exceeded for given call. No models have context window large enough for this call.\n{}".format(
@@ -4503,8 +4481,14 @@ class Router:
         litellm.print_verbose(f"initial list of deployments: {healthy_deployments}")

         if len(healthy_deployments) == 0:
-            raise ValueError(
-                f"No healthy deployment available, passed model={model}. Try again in {self.cooldown_time} seconds"
+            model_ids = self.get_model_ids(model_name=model)
+            _cooldown_time = self.cooldown_cache.get_min_cooldown(model_ids=model_ids)
+            _cooldown_list = self._get_cooldown_deployments()
+            raise RouterRateLimitError(
+                model=model,
+                cooldown_time=_cooldown_time,
+                enable_pre_call_checks=self.enable_pre_call_checks,
+                cooldown_list=_cooldown_list,
             )

         if litellm.model_alias_map and model in litellm.model_alias_map:
@@ -4591,8 +4575,16 @@ class Router:
         if len(healthy_deployments) == 0:
             if _allowed_model_region is None:
                 _allowed_model_region = "n/a"
-            raise ValueError(
-                f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}. pre-call-checks={self.enable_pre_call_checks}, allowed_model_region={_allowed_model_region}, cooldown_list={await self._async_get_cooldown_deployments_with_debug_info()}"
+            model_ids = self.get_model_ids(model_name=model)
+            _cooldown_time = self.cooldown_cache.get_min_cooldown(
+                model_ids=model_ids
+            )
+            _cooldown_list = self._get_cooldown_deployments()
+            raise RouterRateLimitError(
+                model=model,
+                cooldown_time=_cooldown_time,
+                enable_pre_call_checks=self.enable_pre_call_checks,
+                cooldown_list=_cooldown_list,
             )

         if (
@@ -4671,8 +4663,16 @@ class Router:
             verbose_router_logger.info(
                 f"get_available_deployment for model: {model}, No deployment available"
             )
-            raise ValueError(
-                f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}"
+            model_ids = self.get_model_ids(model_name=model)
+            _cooldown_time = self.cooldown_cache.get_min_cooldown(
+                model_ids=model_ids
+            )
+            _cooldown_list = self._get_cooldown_deployments()
+            raise RouterRateLimitError(
+                model=model,
+                cooldown_time=_cooldown_time,
+                enable_pre_call_checks=self.enable_pre_call_checks,
+                cooldown_list=_cooldown_list,
            )
         verbose_router_logger.info(
             f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment)} for model: {model}"
@@ -4744,8 +4744,14 @@ class Router:
         )

         if len(healthy_deployments) == 0:
-            raise ValueError(
-                f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}. pre-call-checks={self.enable_pre_call_checks}, cooldown_list={self._get_cooldown_deployments()}"
+            model_ids = self.get_model_ids(model_name=model)
+            _cooldown_time = self.cooldown_cache.get_min_cooldown(model_ids=model_ids)
+            _cooldown_list = self._get_cooldown_deployments()
+            raise RouterRateLimitError(
+                model=model,
+                cooldown_time=_cooldown_time,
+                enable_pre_call_checks=self.enable_pre_call_checks,
+                cooldown_list=_cooldown_list,
             )

         if self.routing_strategy == "least-busy" and self.leastbusy_logger is not None:
@@ -4825,8 +4831,14 @@ class Router:
         verbose_router_logger.info(
             f"get_available_deployment for model: {model}, No deployment available"
         )
-        raise ValueError(
-            f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}"
+        model_ids = self.get_model_ids(model_name=model)
+        _cooldown_time = self.cooldown_cache.get_min_cooldown(model_ids=model_ids)
+        _cooldown_list = self._get_cooldown_deployments()
+        raise RouterRateLimitError(
+            model=model,
+            cooldown_time=_cooldown_time,
+            enable_pre_call_checks=self.enable_pre_call_checks,
+            cooldown_list=_cooldown_list,
         )
         verbose_router_logger.info(
             f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment)} for model: {model}"
litellm/router_utils/cooldown_cache.py (new file, 138 lines)
@@ -0,0 +1,138 @@
+"""
+Wrapper around router cache. Meant to handle model cooldown logic
+"""
+
+import json
+import time
+from typing import List, Optional, Tuple, TypedDict
+
+from litellm import verbose_logger
+from litellm.caching import DualCache
+
+
+class CooldownCacheValue(TypedDict):
+    exception_received: str
+    status_code: str
+    timestamp: float
+    cooldown_time: float
+
+
+class CooldownCache:
+    def __init__(self, cache: DualCache, default_cooldown_time: float):
+        self.cache = cache
+        self.default_cooldown_time = default_cooldown_time
+
+    def _common_add_cooldown_logic(
+        self, model_id: str, original_exception, exception_status, cooldown_time: float
+    ) -> Tuple[str, CooldownCacheValue]:
+        try:
+            current_time = time.time()
+            cooldown_key = f"deployment:{model_id}:cooldown"
+
+            # Store the cooldown information for the deployment separately
+            cooldown_data = CooldownCacheValue(
+                exception_received=str(original_exception),
+                status_code=str(exception_status),
+                timestamp=current_time,
+                cooldown_time=cooldown_time,
+            )
+
+            return cooldown_key, cooldown_data
+        except Exception as e:
+            verbose_logger.error(
+                "CooldownCache::_common_add_cooldown_logic - Exception occurred - {}".format(
+                    str(e)
+                )
+            )
+            raise e
+
+    def add_deployment_to_cooldown(
+        self,
+        model_id: str,
+        original_exception: Exception,
+        exception_status: int,
+        cooldown_time: Optional[float],
+    ):
+        try:
+            _cooldown_time = cooldown_time or self.default_cooldown_time
+            cooldown_key, cooldown_data = self._common_add_cooldown_logic(
+                model_id=model_id,
+                original_exception=original_exception,
+                exception_status=exception_status,
+                cooldown_time=_cooldown_time,
+            )
+
+            # Set the cache with a TTL equal to the cooldown time
+            self.cache.set_cache(
+                value=cooldown_data,
+                key=cooldown_key,
+                ttl=_cooldown_time,
+            )
+        except Exception as e:
+            verbose_logger.error(
+                "CooldownCache::add_deployment_to_cooldown - Exception occurred - {}".format(
+                    str(e)
+                )
+            )
+            raise e
+
+    async def async_get_active_cooldowns(
+        self, model_ids: List[str]
+    ) -> List[Tuple[str, CooldownCacheValue]]:
+        # Generate the keys for the deployments
+        keys = [f"deployment:{model_id}:cooldown" for model_id in model_ids]
+
+        # Retrieve the values for the keys using mget
+        results = await self.cache.async_batch_get_cache(keys=keys)
+
+        active_cooldowns = []
+        # Process the results
+        for model_id, result in zip(model_ids, results):
+            if result and isinstance(result, dict):
+                cooldown_cache_value = CooldownCacheValue(**result)  # type: ignore
+                active_cooldowns.append((model_id, cooldown_cache_value))
+
+        return active_cooldowns
+
+    def get_active_cooldowns(
+        self, model_ids: List[str]
+    ) -> List[Tuple[str, CooldownCacheValue]]:
+        # Generate the keys for the deployments
+        keys = [f"deployment:{model_id}:cooldown" for model_id in model_ids]
+
+        # Retrieve the values for the keys using mget
+        results = self.cache.batch_get_cache(keys=keys)
+
+        active_cooldowns = []
+        # Process the results
+        for model_id, result in zip(model_ids, results):
+            if result and isinstance(result, dict):
+                cooldown_cache_value = CooldownCacheValue(**result)  # type: ignore
+                active_cooldowns.append((model_id, cooldown_cache_value))
+
+        return active_cooldowns
+
+    def get_min_cooldown(self, model_ids: List[str]) -> float:
+        """Return min cooldown time required for a group of model id's."""
+
+        # Generate the keys for the deployments
+        keys = [f"deployment:{model_id}:cooldown" for model_id in model_ids]
+
+        # Retrieve the values for the keys using mget
+        results = self.cache.batch_get_cache(keys=keys)
+
+        min_cooldown_time = self.default_cooldown_time
+        # Process the results
+        for model_id, result in zip(model_ids, results):
+            if result and isinstance(result, dict):
+                cooldown_cache_value = CooldownCacheValue(**result)  # type: ignore
+                if cooldown_cache_value["cooldown_time"] < min_cooldown_time:
+                    min_cooldown_time = cooldown_cache_value["cooldown_time"]
+
+        return min_cooldown_time
+
+
+# Usage example:
+# cooldown_cache = CooldownCache(cache=your_cache_instance, cooldown_time=your_cooldown_time)
+# cooldown_cache.add_deployment_to_cooldown(deployment, original_exception, exception_status)
+# active_cooldowns = cooldown_cache.get_active_cooldowns()
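A usage sketch for the new `CooldownCache` on its own (the no-argument `DualCache()` construction is assumed to give an in-memory cache; Redis-backed setups would pass their own cache instance):

```
from litellm.caching import DualCache
from litellm.router_utils.cooldown_cache import CooldownCache

cooldown_cache = CooldownCache(cache=DualCache(), default_cooldown_time=60.0)

# Cool a single deployment down for the provider-requested 9 seconds.
cooldown_cache.add_deployment_to_cooldown(
    model_id="deployment-123",
    original_exception=Exception("429: retry after 9s"),
    exception_status=429,
    cooldown_time=9.0,
)

print(cooldown_cache.get_active_cooldowns(model_ids=["deployment-123"]))
# [('deployment-123', {'exception_received': '429: retry after 9s', 'status_code': '429',
#   'timestamp': ..., 'cooldown_time': 9.0})]
print(cooldown_cache.get_min_cooldown(model_ids=["deployment-123"]))  # 9.0
```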
@@ -1635,18 +1635,19 @@ def test_completion_perplexity_api():
     pydantic_obj = ChatCompletion(**response_object)

     def _return_pydantic_obj(*args, **kwargs):
-        return pydantic_obj
+        new_response = MagicMock()
+        new_response.headers = {"hello": "world"}

-    print(f"pydantic_obj: {pydantic_obj}")
+        new_response.parse.return_value = pydantic_obj
+        return new_response

     openai_client = OpenAI()

-    openai_client.chat.completions.create = MagicMock()
-
     with patch.object(
-        openai_client.chat.completions, "create", side_effect=_return_pydantic_obj
+        openai_client.chat.completions.with_raw_response,
+        "create",
+        side_effect=_return_pydantic_obj,
     ) as mock_client:
-        pass
         # litellm.set_verbose= True
         messages = [
             {"role": "system", "content": "You're a good bot"},
@@ -839,3 +839,138 @@ def test_anthropic_tool_calling_exception():
         )
     except litellm.BadRequestError:
         pass
+
+
+from typing import Optional, Union
+
+from openai import AsyncOpenAI, OpenAI
+
+
+def _pre_call_utils(
+    call_type: str,
+    data: dict,
+    client: Union[OpenAI, AsyncOpenAI],
+    sync_mode: bool,
+    streaming: Optional[bool],
+):
+    if call_type == "embedding":
+        data["input"] = "Hello world!"
+        mapped_target = client.embeddings.with_raw_response
+        if sync_mode:
+            original_function = litellm.embedding
+        else:
+            original_function = litellm.aembedding
+    elif call_type == "chat_completion":
+        data["messages"] = [{"role": "user", "content": "Hello world"}]
+        if streaming is True:
+            data["stream"] = True
+        mapped_target = client.chat.completions.with_raw_response
+        if sync_mode:
+            original_function = litellm.completion
+        else:
+            original_function = litellm.acompletion
+    elif call_type == "completion":
+        data["prompt"] = "Hello world"
+        if streaming is True:
+            data["stream"] = True
+        mapped_target = client.completions.with_raw_response
+        if sync_mode:
+            original_function = litellm.text_completion
+        else:
+            original_function = litellm.atext_completion
+
+    return data, original_function, mapped_target
+
+
+@pytest.mark.parametrize(
+    "sync_mode",
+    [True, False],
+)
+@pytest.mark.parametrize(
+    "provider, model, call_type, streaming",
+    [
+        ("openai", "text-embedding-ada-002", "embedding", None),
+        ("openai", "gpt-3.5-turbo", "chat_completion", False),
+        ("openai", "gpt-3.5-turbo", "chat_completion", True),
+        ("openai", "gpt-3.5-turbo-instruct", "completion", True),
+        ("azure", "azure/chatgpt-v-2", "chat_completion", True),
+        ("azure", "azure/text-embedding-ada-002", "embedding", True),
+        ("azure", "azure_text/gpt-3.5-turbo-instruct", "completion", True),
+    ],
+)
+@pytest.mark.asyncio
+async def test_exception_with_headers(sync_mode, provider, model, call_type, streaming):
+    """
+    User feedback: litellm says "No deployments available for selected model, Try again in 60 seconds"
+    but Azure says to retry in at most 9s
+
+    ```
+    {"message": "litellm.proxy.proxy_server.embeddings(): Exception occured - No deployments available for selected model, Try again in 60 seconds. Passed model=text-embedding-ada-002. pre-call-checks=False, allowed_model_region=n/a, cooldown_list=[('b49cbc9314273db7181fe69b1b19993f04efb88f2c1819947c538bac08097e4c', {'Exception Received': 'litellm.RateLimitError: AzureException RateLimitError - Requests to the Embeddings_Create Operation under Azure OpenAI API version 2023-09-01-preview have exceeded call rate limit of your current OpenAI S0 pricing tier. Please retry after 9 seconds. Please go here: https://aka.ms/oai/quotaincrease if you would like to further increase the default rate limit.', 'Status Code': '429'})]", "level": "ERROR", "timestamp": "2024-08-22T03:25:36.900476"}
+    ```
+    """
+    import openai
+
+    if sync_mode:
+        if provider == "openai":
+            openai_client = openai.OpenAI(api_key="")
+        elif provider == "azure":
+            openai_client = openai.AzureOpenAI(api_key="", base_url="")
+    else:
+        if provider == "openai":
+            openai_client = openai.AsyncOpenAI(api_key="")
+        elif provider == "azure":
+            openai_client = openai.AsyncAzureOpenAI(api_key="", base_url="")
+
+    data = {"model": model}
+    data, original_function, mapped_target = _pre_call_utils(
+        call_type=call_type,
+        data=data,
+        client=openai_client,
+        sync_mode=sync_mode,
+        streaming=streaming,
+    )
+
+    cooldown_time = 30.0
+
+    def _return_exception(*args, **kwargs):
+        from fastapi import HTTPException
+
+        raise HTTPException(
+            status_code=429,
+            detail="Rate Limited!",
+            headers={"retry-after": cooldown_time},  # type: ignore
+        )
+
+    with patch.object(
+        mapped_target,
+        "create",
+        side_effect=_return_exception,
+    ):
+        new_retry_after_mock_client = MagicMock(return_value=-1)
+
+        litellm.utils._get_retry_after_from_exception_header = (
+            new_retry_after_mock_client
+        )
+
+        exception_raised = False
+        try:
+            if sync_mode:
+                resp = original_function(**data, client=openai_client)
+                if streaming:
+                    for chunk in resp:
+                        continue
+            else:
+                resp = await original_function(**data, client=openai_client)
+
+                if streaming:
+                    async for chunk in resp:
+                        continue
+
+        except litellm.RateLimitError as e:
+            exception_raised = True
+            assert e.litellm_response_headers is not None
+            assert e.litellm_response_headers["retry-after"] == cooldown_time
+
+    if exception_raised is False:
+        print(resp)
+    assert exception_raised
@ -10,6 +10,9 @@ import traceback
|
||||||
import openai
|
import openai
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
import litellm.types
|
||||||
|
import litellm.types.router
|
||||||
|
|
||||||
sys.path.insert(
|
sys.path.insert(
|
||||||
0, os.path.abspath("../..")
|
0, os.path.abspath("../..")
|
||||||
) # Adds the parent directory to the system path
|
) # Adds the parent directory to the system path
|
||||||
|
@@ -2184,3 +2187,158 @@ def test_router_correctly_reraise_error():
         )
     except litellm.RateLimitError:
         pass
+
+
+def test_router_dynamic_cooldown_correct_retry_after_time(sync_mode):
+    """
+    User feedback: litellm says "No deployments available for selected model, Try again in 60 seconds"
+    but Azure says to retry in at most 9s
+
+    ```
+    {"message": "litellm.proxy.proxy_server.embeddings(): Exception occured - No deployments available for selected model, Try again in 60 seconds. Passed model=text-embedding-ada-002. pre-call-checks=False, allowed_model_region=n/a, cooldown_list=[('b49cbc9314273db7181fe69b1b19993f04efb88f2c1819947c538bac08097e4c', {'Exception Received': 'litellm.RateLimitError: AzureException RateLimitError - Requests to the Embeddings_Create Operation under Azure OpenAI API version 2023-09-01-preview have exceeded call rate limit of your current OpenAI S0 pricing tier. Please retry after 9 seconds. Please go here: https://aka.ms/oai/quotaincrease if you would like to further increase the default rate limit.', 'Status Code': '429'})]", "level": "ERROR", "timestamp": "2024-08-22T03:25:36.900476"}
+    ```
+    """
+    router = Router(
+        model_list=[
+            {
+                "model_name": "text-embedding-ada-002",
+                "litellm_params": {
+                    "model": "openai/text-embedding-ada-002",
+                },
+            }
+        ]
+    )
+
+    openai_client = openai.OpenAI(api_key="")
+
+    cooldown_time = 30.0
+
+    def _return_exception(*args, **kwargs):
+        from fastapi import HTTPException
+
+        raise HTTPException(
+            status_code=429,
+            detail="Rate Limited!",
+            headers={"retry-after": cooldown_time},  # type: ignore
+        )
+
+    with patch.object(
+        openai_client.embeddings.with_raw_response,
+        "create",
+        side_effect=_return_exception,
+    ):
+        new_retry_after_mock_client = MagicMock(return_value=-1)
+
+        litellm.utils._get_retry_after_from_exception_header = (
+            new_retry_after_mock_client
+        )
+
+        try:
+            router.embedding(
+                model="text-embedding-ada-002",
+                input="Hello world!",
+                client=openai_client,
+            )
+        except litellm.RateLimitError:
+            pass
+
+        new_retry_after_mock_client.assert_called()
+        print(
+            f"new_retry_after_mock_client.call_args.kwargs: {new_retry_after_mock_client.call_args.kwargs}"
+        )
+
+        response_headers: httpx.Headers = new_retry_after_mock_client.call_args.kwargs[
+            "response_headers"
+        ]
+        assert "retry-after" in response_headers
+        assert response_headers["retry-after"] == cooldown_time
+
+
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_router_dynamic_cooldown_message_retry_time(sync_mode):
+    """
+    User feedback: litellm says "No deployments available for selected model, Try again in 60 seconds"
+    but Azure says to retry in at most 9s
+
+    ```
+    {"message": "litellm.proxy.proxy_server.embeddings(): Exception occured - No deployments available for selected model, Try again in 60 seconds. Passed model=text-embedding-ada-002. pre-call-checks=False, allowed_model_region=n/a, cooldown_list=[('b49cbc9314273db7181fe69b1b19993f04efb88f2c1819947c538bac08097e4c', {'Exception Received': 'litellm.RateLimitError: AzureException RateLimitError - Requests to the Embeddings_Create Operation under Azure OpenAI API version 2023-09-01-preview have exceeded call rate limit of your current OpenAI S0 pricing tier. Please retry after 9 seconds. Please go here: https://aka.ms/oai/quotaincrease if you would like to further increase the default rate limit.', 'Status Code': '429'})]", "level": "ERROR", "timestamp": "2024-08-22T03:25:36.900476"}
+    ```
+    """
+    router = Router(
+        model_list=[
+            {
+                "model_name": "text-embedding-ada-002",
+                "litellm_params": {
+                    "model": "openai/text-embedding-ada-002",
+                },
+            }
+        ]
+    )
+
+    openai_client = openai.OpenAI(api_key="")
+
+    cooldown_time = 30.0
+
+    def _return_exception(*args, **kwargs):
+        from fastapi import HTTPException
+
+        raise HTTPException(
+            status_code=429,
+            detail="Rate Limited!",
+            headers={"retry-after": cooldown_time},
+        )
+
+    with patch.object(
+        openai_client.embeddings.with_raw_response,
+        "create",
+        side_effect=_return_exception,
+    ):
+        for _ in range(2):
+            try:
+                if sync_mode:
+                    router.embedding(
+                        model="text-embedding-ada-002",
+                        input="Hello world!",
+                        client=openai_client,
+                    )
+                else:
+                    await router.aembedding(
+                        model="text-embedding-ada-002",
+                        input="Hello world!",
+                        client=openai_client,
+                    )
+            except litellm.RateLimitError:
+                pass
+
+        if sync_mode:
+            cooldown_deployments = router._get_cooldown_deployments()
+        else:
+            cooldown_deployments = await router._async_get_cooldown_deployments()
+        print(
+            "Cooldown deployments - {}\n{}".format(
+                cooldown_deployments, len(cooldown_deployments)
+            )
+        )
+
+        assert len(cooldown_deployments) > 0
+        exception_raised = False
+        try:
+            if sync_mode:
+                router.embedding(
+                    model="text-embedding-ada-002",
+                    input="Hello world!",
+                    client=openai_client,
+                )
+            else:
+                await router.aembedding(
+                    model="text-embedding-ada-002",
+                    input="Hello world!",
+                    client=openai_client,
+                )
+        except litellm.types.router.RouterRateLimitError as e:
+            print(e)
+            exception_raised = True
+            assert e.cooldown_time == cooldown_time
+
+        assert exception_raised
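The second test inspects the router's cooldown bookkeeping through the underscore-prefixed helpers `_get_cooldown_deployments()` / `_async_get_cooldown_deployments()`. A rough standalone sketch of the same inspection, assuming the provider actually returns a 429 and that these internals keep their current names (they are not public API):

```python
import litellm
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "text-embedding-ada-002",
            "litellm_params": {"model": "openai/text-embedding-ada-002"},
        }
    ]
)

try:
    router.embedding(model="text-embedding-ada-002", input="Hello world!")
except litellm.RateLimitError:
    # After a 429 with a retry-after header, the offending deployment should
    # appear in the cooldown list for roughly that many seconds.
    print(router._get_cooldown_deployments())
```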
@@ -549,3 +549,19 @@ class RouterGeneralSettings(BaseModel):
     pass_through_all_models: bool = Field(
         default=False
     )  # if passed a model not llm_router model list, pass through the request to litellm.acompletion/embedding
+
+
+class RouterRateLimitError(ValueError):
+    def __init__(
+        self,
+        model: str,
+        cooldown_time: float,
+        enable_pre_call_checks: bool,
+        cooldown_list: List,
+    ):
+        self.model = model
+        self.cooldown_time = cooldown_time
+        self.enable_pre_call_checks = enable_pre_call_checks
+        self.cooldown_list = cooldown_list
+        _message = f"{RouterErrors.no_deployments_available.value}, Try again in {cooldown_time} seconds. Passed model={model}. pre-call-checks={enable_pre_call_checks}, cooldown_list={cooldown_list}"
+        super().__init__(_message)
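A possible consumer of the new error type, raised when every deployment for the requested alias is cooling down. Sleeping for the advertised cooldown is an illustrative policy rather than something this PR prescribes, and the model alias is made up for the sketch:

```python
import time

from litellm import Router
from litellm.types.router import RouterRateLimitError

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {"model": "openai/gpt-3.5-turbo"},
        }
    ]
)

try:
    router.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hi"}],
    )
except RouterRateLimitError as e:
    # cooldown_time now reflects the provider's retry-after header
    # (e.g. 9s from Azure) instead of a flat 60 seconds.
    time.sleep(e.cooldown_time)
```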
@@ -638,7 +638,10 @@ def client(original_function):
                 if is_coroutine is True:
                     pass
                 else:
-                    if isinstance(original_response, ModelResponse):
+                    if (
+                        isinstance(original_response, ModelResponse)
+                        and len(original_response.choices) > 0
+                    ):
                         model_response: Optional[str] = original_response.choices[
                             0
                         ].message.content  # type: ignore
@@ -6382,6 +6385,7 @@ def _get_retry_after_from_exception_header(
             retry_after = int(retry_date - time.time())
         else:
             retry_after = -1
+
         return retry_after

     except Exception as e:
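For context, `Retry-After` may arrive either as a number of seconds or as an HTTP-date, and the helper above falls back to `-1` when nothing usable is found. A self-contained sketch of that parsing rule (the function name is illustrative, not litellm's):

```python
import email.utils
import time
from typing import Optional


def parse_retry_after(value: Optional[str]) -> int:
    """Return seconds to wait, or -1 if the header is missing or unparseable."""
    if not value:
        return -1
    try:
        return int(float(value))  # plain seconds, e.g. "9"
    except ValueError:
        pass
    try:
        retry_date = email.utils.parsedate_to_datetime(value)  # HTTP-date form
    except (TypeError, ValueError):
        return -1
    return int(retry_date.timestamp() - time.time())
```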
@@ -6563,6 +6567,40 @@ def get_model_list():


 ####### EXCEPTION MAPPING ################
+def _get_litellm_response_headers(
+    original_exception: Exception,
+) -> Optional[httpx.Headers]:
+    """
+    Extract and return the response headers from a mapped exception, if present.
+
+    Used for accurate retry logic.
+    """
+    _response_headers: Optional[httpx.Headers] = None
+    try:
+        _response_headers = getattr(
+            original_exception, "litellm_response_headers", None
+        )
+    except Exception:
+        return None
+
+    return _response_headers
+
+
+def _get_response_headers(original_exception: Exception) -> Optional[httpx.Headers]:
+    """
+    Extract and return the response headers from an exception, if present.
+
+    Used for accurate retry logic.
+    """
+    _response_headers: Optional[httpx.Headers] = None
+    try:
+        _response_headers = getattr(original_exception, "headers", None)
+    except Exception:
+        return None
+
+    return _response_headers
+
+
 def exception_type(
     model,
     original_exception,
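`_get_response_headers` reads `.headers` off the raw provider exception, while `_get_litellm_response_headers` reads the attribute that `exception_type` attaches after mapping. A small sketch exercising the first helper with a stand-in error class; only the helper comes from this diff, and this assumes it stays importable from `litellm.utils`:

```python
import httpx

from litellm.utils import _get_response_headers


class FakeProviderError(Exception):
    """Stand-in for a provider SDK error that exposes `.headers`."""

    def __init__(self, headers: httpx.Headers):
        super().__init__("rate limited")
        self.headers = headers


err = FakeProviderError(httpx.Headers({"retry-after": "9"}))
headers = _get_response_headers(original_exception=err)
assert headers is not None and headers["retry-after"] == "9"
```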
@@ -6587,6 +6625,10 @@ def exception_type(
         "LiteLLM.Info: If you need to debug this error, use `litellm.set_verbose=True'."  # noqa
     )  # noqa
     print()  # noqa
+
+    litellm_response_headers = _get_response_headers(
+        original_exception=original_exception
+    )
     try:
         if model:
             if hasattr(original_exception, "message"):
@ -6841,7 +6883,7 @@ def exception_type(
|
||||||
message=f"{exception_provider} - {message}",
|
message=f"{exception_provider} - {message}",
|
||||||
model=model,
|
model=model,
|
||||||
llm_provider=custom_llm_provider,
|
llm_provider=custom_llm_provider,
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
litellm_debug_info=extra_information,
|
litellm_debug_info=extra_information,
|
||||||
)
|
)
|
||||||
elif original_exception.status_code == 429:
|
elif original_exception.status_code == 429:
|
||||||
|
@ -6850,7 +6892,7 @@ def exception_type(
|
||||||
message=f"RateLimitError: {exception_provider} - {message}",
|
message=f"RateLimitError: {exception_provider} - {message}",
|
||||||
model=model,
|
model=model,
|
||||||
llm_provider=custom_llm_provider,
|
llm_provider=custom_llm_provider,
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
litellm_debug_info=extra_information,
|
litellm_debug_info=extra_information,
|
||||||
)
|
)
|
||||||
elif original_exception.status_code == 503:
|
elif original_exception.status_code == 503:
|
||||||
|
@ -6859,7 +6901,7 @@ def exception_type(
|
||||||
message=f"ServiceUnavailableError: {exception_provider} - {message}",
|
message=f"ServiceUnavailableError: {exception_provider} - {message}",
|
||||||
model=model,
|
model=model,
|
||||||
llm_provider=custom_llm_provider,
|
llm_provider=custom_llm_provider,
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
litellm_debug_info=extra_information,
|
litellm_debug_info=extra_information,
|
||||||
)
|
)
|
||||||
elif original_exception.status_code == 504: # gateway timeout error
|
elif original_exception.status_code == 504: # gateway timeout error
|
||||||
|
@ -6877,7 +6919,7 @@ def exception_type(
|
||||||
message=f"APIError: {exception_provider} - {message}",
|
message=f"APIError: {exception_provider} - {message}",
|
||||||
llm_provider=custom_llm_provider,
|
llm_provider=custom_llm_provider,
|
||||||
model=model,
|
model=model,
|
||||||
request=original_exception.request,
|
request=getattr(original_exception, "request", None),
|
||||||
litellm_debug_info=extra_information,
|
litellm_debug_info=extra_information,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
@ -8165,7 +8207,7 @@ def exception_type(
|
||||||
model=model,
|
model=model,
|
||||||
request=original_exception.request,
|
request=original_exception.request,
|
||||||
)
|
)
|
||||||
elif custom_llm_provider == "azure":
|
elif custom_llm_provider == "azure" or custom_llm_provider == "azure_text":
|
||||||
message = get_error_message(error_obj=original_exception)
|
message = get_error_message(error_obj=original_exception)
|
||||||
if message is None:
|
if message is None:
|
||||||
if hasattr(original_exception, "message"):
|
if hasattr(original_exception, "message"):
|
||||||
|
@ -8469,20 +8511,20 @@ def exception_type(
|
||||||
threading.Thread(target=get_all_keys, args=(e.llm_provider,)).start()
|
threading.Thread(target=get_all_keys, args=(e.llm_provider,)).start()
|
||||||
# don't let an error with mapping interrupt the user from receiving an error from the llm api calls
|
# don't let an error with mapping interrupt the user from receiving an error from the llm api calls
|
||||||
if exception_mapping_worked:
|
if exception_mapping_worked:
|
||||||
|
setattr(e, "litellm_response_headers", litellm_response_headers)
|
||||||
raise e
|
raise e
|
||||||
else:
|
else:
|
||||||
for error_type in litellm.LITELLM_EXCEPTION_TYPES:
|
for error_type in litellm.LITELLM_EXCEPTION_TYPES:
|
||||||
if isinstance(e, error_type):
|
if isinstance(e, error_type):
|
||||||
|
setattr(e, "litellm_response_headers", litellm_response_headers)
|
||||||
raise e # it's already mapped
|
raise e # it's already mapped
|
||||||
raise APIConnectionError(
|
raised_exc = APIConnectionError(
|
||||||
message="{}\n{}".format(original_exception, traceback.format_exc()),
|
message="{}\n{}".format(original_exception, traceback.format_exc()),
|
||||||
llm_provider="",
|
llm_provider="",
|
||||||
model="",
|
model="",
|
||||||
request=httpx.Request(
|
|
||||||
method="POST",
|
|
||||||
url="https://www.litellm.ai/",
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
setattr(raised_exc, "litellm_response_headers", _response_headers)
|
||||||
|
raise raised_exc
|
||||||
|
|
||||||
|
|
||||||
######### Secret Manager ############################
|
######### Secret Manager ############################
|
||||||
|
@@ -10916,10 +10958,17 @@ class CustomStreamWrapper:


 class TextCompletionStreamWrapper:
-    def __init__(self, completion_stream, model, stream_options: Optional[dict] = None):
+    def __init__(
+        self,
+        completion_stream,
+        model,
+        stream_options: Optional[dict] = None,
+        custom_llm_provider: Optional[str] = None,
+    ):
         self.completion_stream = completion_stream
         self.model = model
         self.stream_options = stream_options
+        self.custom_llm_provider = custom_llm_provider

     def __iter__(self):
         return self
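Threading `custom_llm_provider` through the wrapper lets the streaming error path in the next hunk hand failures to `exception_type` with the right provider tag. A tiny, illustrative construction of this internal class; the empty upstream stream and model name are placeholders:

```python
from litellm.utils import TextCompletionStreamWrapper

# Wrap a (placeholder) upstream stream, tagging it with the provider so an
# error raised mid-stream is mapped to that provider's exception type rather
# than being printed and swallowed.
wrapper = TextCompletionStreamWrapper(
    completion_stream=iter([]),
    model="gpt-3.5-turbo-instruct",
    custom_llm_provider="openai",
)
for chunk in wrapper:
    print(chunk)
```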
@@ -10970,7 +11019,13 @@ class TextCompletionStreamWrapper:
         except StopIteration:
             raise StopIteration
         except Exception as e:
-            print(f"got exception {e}")  # noqa
+            raise exception_type(
+                model=self.model,
+                custom_llm_provider=self.custom_llm_provider or "",
+                original_exception=e,
+                completion_kwargs={},
+                extra_kwargs={},
+            )

     async def __anext__(self):
         try: