Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-26 03:04:13 +00:00
LiteLLM Minor Fixes & Improvements (12/23/2024) - P2 (#7386)
* fix(main.py): support 'mock_timeout=true' param, allowing mock requests on the proxy to have a time delay, for testing
* fix(main.py): ensure mock timeouts raise a litellm.Timeout error, which triggers retries/fallbacks
* fix: fix fallback + mock timeout testing
* fix(router.py): always return remaining tpm/rpm limits when limits are known, so rate limit headers are guaranteed
* docs(timeout.md): add docs on mock_timeout=true
* fix(main.py): fix linting errors
* test: fix test
This commit is contained in:
  parent a89b0d5c39
  commit 51f9f75c85

7 changed files with 223 additions and 54 deletions
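For context, a minimal sketch of how the new mock_timeout flag is exercised from the SDK side, per the diff below (the model name and timeout value here are illustrative, not from this commit):

import litellm

# With mock_timeout=True the mock request sleeps for the configured timeout,
# then raises litellm.Timeout, so retry/fallback logic sees a real timeout.
try:
    litellm.completion(
        model="gpt-3.5-turbo",  # illustrative model name
        messages=[{"role": "user", "content": "hi"}],
        mock_timeout=True,
        timeout=0.5,  # the mock sleeps ~0.5s before raising
    )
except litellm.Timeout as e:
    print(f"mock timeout raised as expected: {e}")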
litellm/main.py (137 changed lines)
@@ -527,6 +527,73 @@ async def _async_streaming(response, model, custom_llm_provider, args):
     )
 
 
+def _handle_mock_potential_exceptions(
+    mock_response: Union[str, Exception, dict],
+    model: str,
+    custom_llm_provider: Optional[str] = None,
+):
+    if isinstance(mock_response, Exception):
+        if isinstance(mock_response, openai.APIError):
+            raise mock_response
+        raise litellm.MockException(
+            status_code=getattr(mock_response, "status_code", 500),  # type: ignore
+            message=getattr(mock_response, "text", str(mock_response)),
+            llm_provider=getattr(
+                mock_response, "llm_provider", custom_llm_provider or "openai"
+            ),  # type: ignore
+            model=model,  # type: ignore
+            request=httpx.Request(method="POST", url="https://api.openai.com/v1/"),
+        )
+    elif isinstance(mock_response, str) and mock_response == "litellm.RateLimitError":
+        raise litellm.RateLimitError(
+            message="this is a mock rate limit error",
+            llm_provider=getattr(
+                mock_response, "llm_provider", custom_llm_provider or "openai"
+            ),  # type: ignore
+            model=model,
+        )
+    elif (
+        isinstance(mock_response, str)
+        and mock_response == "litellm.InternalServerError"
+    ):
+        raise litellm.InternalServerError(
+            message="this is a mock internal server error",
+            llm_provider=getattr(
+                mock_response, "llm_provider", custom_llm_provider or "openai"
+            ),  # type: ignore
+            model=model,
+        )
+    elif isinstance(mock_response, str) and mock_response.startswith(
+        "Exception: content_filter_policy"
+    ):
+        raise litellm.MockException(
+            status_code=400,
+            message=mock_response,
+            llm_provider="azure",
+            model=model,  # type: ignore
+            request=httpx.Request(method="POST", url="https://api.openai.com/v1/"),
+        )
+
+
+def _handle_mock_timeout(
+    mock_timeout: Optional[bool],
+    timeout: Optional[Union[float, str, httpx.Timeout]],
+    model: str,
+):
+    if mock_timeout is True and timeout is not None:
+        if isinstance(timeout, float):
+            time.sleep(timeout)
+        elif isinstance(timeout, str):
+            time.sleep(float(timeout))
+        elif isinstance(timeout, httpx.Timeout) and timeout.connect is not None:
+            time.sleep(timeout.connect)
+        raise litellm.Timeout(
+            message="This is a mock timeout error",
+            llm_provider="openai",
+            model=model,
+        )
+
+
 def mock_completion(
     model: str,
     messages: List,
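A small sketch of the delay semantics _handle_mock_timeout implements above: float and str timeouts sleep for that many seconds, while an httpx.Timeout sleeps for its connect value, before litellm.Timeout is raised. The harness below is illustrative:

import time

import litellm

start = time.monotonic()
try:
    litellm.completion(
        model="gpt-3.5-turbo",  # illustrative
        messages=[{"role": "user", "content": "hi"}],
        mock_timeout=True,
        timeout="0.2",  # string timeouts are coerced via float()
    )
except litellm.Timeout:
    assert time.monotonic() - start >= 0.2  # the mock delay actually elapsed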
@@ -534,8 +601,10 @@ def mock_completion(
     n: Optional[int] = None,
     mock_response: Union[str, Exception, dict] = "This is a mock request",
     mock_tool_calls: Optional[List] = None,
+    mock_timeout: Optional[bool] = False,
     logging=None,
     custom_llm_provider=None,
+    timeout: Optional[Union[float, str, httpx.Timeout]] = None,
     **kwargs,
 ):
     """
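Since the signature above now accepts mock_timeout and timeout directly, a direct call would look roughly like the following. This is a sketch: mock_completion is normally reached through litellm.completion, and the response access assumes the usual ModelResponse shape.

from litellm.main import mock_completion

resp = mock_completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi"}],
    mock_response="This is a mock request",
    mock_timeout=False,  # new param: no artificial delay
    timeout=None,        # new param: delay duration used when mock_timeout=True
)
print(resp.choices[0].message.content)  # expected to echo the mock_response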
@@ -548,6 +617,8 @@ def mock_completion(
         messages (List): A list of message objects representing the conversation context.
         stream (bool, optional): If True, returns a mock streaming response (default is False).
         mock_response (str, optional): The content of the mock response (default is "This is a mock request").
+        mock_timeout (bool, optional): If True, the mock response will be a timeout error (default is False).
+        timeout (float, optional): The timeout value to use for the mock response (default is None).
         **kwargs: Additional keyword arguments that can be used but are not required.
 
     Returns:
@@ -560,56 +631,28 @@ def mock_completion(
         - If 'stream' is True, it returns a response that mimics the behavior of a streaming completion.
     """
     try:
         if mock_response is None:
             mock_response = "This is a mock request"
 
+        _handle_mock_timeout(mock_timeout=mock_timeout, timeout=timeout, model=model)
+
         ## LOGGING
         if logging is not None:
             logging.pre_call(
                 input=messages,
                 api_key="mock-key",
             )
-        if isinstance(mock_response, Exception):
-            if isinstance(mock_response, openai.APIError):
-                raise mock_response
-            raise litellm.MockException(
-                status_code=getattr(mock_response, "status_code", 500),  # type: ignore
-                message=getattr(mock_response, "text", str(mock_response)),
-                llm_provider=getattr(
-                    mock_response, "llm_provider", custom_llm_provider or "openai"
-                ),  # type: ignore
-                model=model,  # type: ignore
-                request=httpx.Request(method="POST", url="https://api.openai.com/v1/"),
-            )
-        elif (
-            isinstance(mock_response, str) and mock_response == "litellm.RateLimitError"
-        ):
-            raise litellm.RateLimitError(
-                message="this is a mock rate limit error",
-                llm_provider=getattr(
-                    mock_response, "llm_provider", custom_llm_provider or "openai"
-                ),  # type: ignore
-                model=model,
-            )
-        elif (
-            isinstance(mock_response, str)
-            and mock_response == "litellm.InternalServerError"
-        ):
-            raise litellm.InternalServerError(
-                message="this is a mock internal server error",
-                llm_provider=getattr(
-                    mock_response, "llm_provider", custom_llm_provider or "openai"
-                ),  # type: ignore
-                model=model,
-            )
-        elif isinstance(mock_response, str) and mock_response.startswith(
-            "Exception: content_filter_policy"
-        ):
-            raise litellm.MockException(
-                status_code=400,
-                message=mock_response,
-                llm_provider="azure",
-                model=model,  # type: ignore
-                request=httpx.Request(method="POST", url="https://api.openai.com/v1/"),
-            )
-        elif isinstance(mock_response, str) and mock_response.startswith(
+        _handle_mock_potential_exceptions(
+            mock_response=mock_response,
+            model=model,
+            custom_llm_provider=custom_llm_provider,
+        )
+
+        mock_response = cast(
+            Union[str, dict], mock_response
+        )  # after this point, mock_response is a string or dict
+        if isinstance(mock_response, str) and mock_response.startswith(
             "Exception: mock_streaming_error"
         ):
             mock_response = litellm.MockException(
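The refactor above moves the sentinel-string handling into _handle_mock_potential_exceptions without changing behavior. A sketch of that mapping as seen from the public API (the test harness itself is illustrative):

import litellm

for sentinel, expected in [
    ("litellm.RateLimitError", litellm.RateLimitError),
    ("litellm.InternalServerError", litellm.InternalServerError),
]:
    try:
        litellm.completion(
            model="gpt-3.5-turbo",  # illustrative
            messages=[{"role": "user", "content": "hi"}],
            mock_response=sentinel,
        )
    except expected:
        pass  # each sentinel string raises its matching litellm exception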
@@ -788,6 +831,7 @@ def completion(  # type: ignore # noqa: PLR0915
     api_base = kwargs.get("api_base", None)
     mock_response = kwargs.get("mock_response", None)
     mock_tool_calls = kwargs.get("mock_tool_calls", None)
+    mock_timeout = cast(Optional[bool], kwargs.get("mock_timeout", None))
     force_timeout = kwargs.get("force_timeout", 600)  ## deprecated
     logger_fn = kwargs.get("logger_fn", None)
     verbose = kwargs.get("verbose", False)
@@ -1102,7 +1146,8 @@ def completion(  # type: ignore # noqa: PLR0915
             litellm_params=litellm_params,
             custom_llm_provider=custom_llm_provider,
         )
-        if mock_response or mock_tool_calls:
+        if mock_response or mock_tool_calls or mock_timeout:
+            kwargs.pop("mock_timeout", None)  # remove for any fallbacks triggered
             return mock_completion(
                 model,
                 messages,
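A hedged sketch of why the kwargs.pop("mock_timeout", None) above matters: the flag is stripped before fallbacks run, so the first deployment mocks a timeout while the fallback serves a normal response. Model names, group names, and the fallback's mock_response value are all illustrative:

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "primary",
            "litellm_params": {"model": "gpt-3.5-turbo"},  # illustrative
        },
        {
            "model_name": "backup",
            "litellm_params": {
                "model": "gpt-4o-mini",  # illustrative
                "mock_response": "hello from the fallback",
            },
        },
    ],
    fallbacks=[{"primary": ["backup"]}],
)

# "primary" times out via the mock; mock_timeout is popped before the
# fallback call, so "backup" answers normally.
response = router.completion(
    model="primary",
    messages=[{"role": "user", "content": "hi"}],
    mock_timeout=True,
    timeout=0.1,
)
print(response.choices[0].message.content)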
@@ -1114,6 +1159,8 @@ def completion(  # type: ignore # noqa: PLR0915
                 acompletion=acompletion,
                 mock_delay=kwargs.get("mock_delay", None),
                 custom_llm_provider=custom_llm_provider,
+                mock_timeout=mock_timeout,
+                timeout=timeout,
             )
 
         if custom_llm_provider == "azure":
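One more sketch, covering the content-filter path in _handle_mock_potential_exceptions above: any mock_response starting with "Exception: content_filter_policy" raises a 400 MockException attributed to the "azure" provider. The deployment name below is illustrative:

import litellm

try:
    litellm.completion(
        model="azure/my-deployment",  # illustrative deployment name
        messages=[{"role": "user", "content": "hi"}],
        mock_response="Exception: content_filter_policy was triggered",
    )
except litellm.MockException as e:
    assert e.status_code == 400
    assert e.llm_provider == "azure"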