LiteLLM Minor Fixes & Improvements (12/23/2024) - P2 (#7386)

* fix(main.py): support 'mock_timeout=true' param

allows mock requests on the proxy to have a time delay, for testing

* fix(main.py): ensure mock timeouts raise litellm.Timeout error

this triggers retry/fallback handling
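
A minimal usage sketch of the new behavior, based on the diff below (the model name and timeout value are illustrative): with mock_timeout=True and a timeout set, the call sleeps for the timeout and then raises litellm.Timeout.

    import litellm

    try:
        litellm.completion(
            model="gpt-3.5-turbo",  # illustrative model name
            messages=[{"role": "user", "content": "hi"}],
            mock_timeout=True,      # new param added in this commit
            timeout=0.5,            # the mock request sleeps this long before raising
        )
    except litellm.Timeout as e:
        print("mock timeout raised:", e)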

* fix: fix fallback + mock timeout testing

* fix(router.py): always return remaining tpm/rpm limits, if limits are known

allows rate limit headers to be guaranteed
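
A hedged sketch of checking those headers against a locally running proxy; the proxy URL, key, and the x-ratelimit-* header names are assumptions, not taken from this commit's diff.

    import httpx

    resp = httpx.post(
        "http://localhost:4000/v1/chat/completions",  # assumed local proxy address
        headers={"Authorization": "Bearer sk-1234"},   # assumed test key
        json={"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "hi"}]},
    )
    # Remaining request/token budget, populated whenever the deployment's limits are known
    print(resp.headers.get("x-ratelimit-remaining-requests"))
    print(resp.headers.get("x-ratelimit-remaining-tokens"))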

* docs(timeout.md): add docs on mock_timeout=true

* fix(main.py): fix linting errors

* test: fix test
Krish Dholakia 2024-12-23 17:41:27 -08:00 committed by GitHub
parent db59e08958
commit 48316520f4
7 changed files with 223 additions and 54 deletions

@@ -527,6 +527,73 @@ async def _async_streaming(response, model, custom_llm_provider, args):
)
def _handle_mock_potential_exceptions(
mock_response: Union[str, Exception, dict],
model: str,
custom_llm_provider: Optional[str] = None,
):
if isinstance(mock_response, Exception):
if isinstance(mock_response, openai.APIError):
raise mock_response
raise litellm.MockException(
status_code=getattr(mock_response, "status_code", 500), # type: ignore
message=getattr(mock_response, "text", str(mock_response)),
llm_provider=getattr(
mock_response, "llm_provider", custom_llm_provider or "openai"
), # type: ignore
model=model, # type: ignore
request=httpx.Request(method="POST", url="https://api.openai.com/v1/"),
)
elif isinstance(mock_response, str) and mock_response == "litellm.RateLimitError":
raise litellm.RateLimitError(
message="this is a mock rate limit error",
llm_provider=getattr(
mock_response, "llm_provider", custom_llm_provider or "openai"
), # type: ignore
model=model,
)
elif (
isinstance(mock_response, str)
and mock_response == "litellm.InternalServerError"
):
raise litellm.InternalServerError(
message="this is a mock internal server error",
llm_provider=getattr(
mock_response, "llm_provider", custom_llm_provider or "openai"
), # type: ignore
model=model,
)
elif isinstance(mock_response, str) and mock_response.startswith(
"Exception: content_filter_policy"
):
raise litellm.MockException(
status_code=400,
message=mock_response,
llm_provider="azure",
model=model, # type: ignore
request=httpx.Request(method="POST", url="https://api.openai.com/v1/"),
)
def _handle_mock_timeout(
mock_timeout: Optional[bool],
timeout: Optional[Union[float, str, httpx.Timeout]],
model: str,
):
if mock_timeout is True and timeout is not None:
if isinstance(timeout, float):
time.sleep(timeout)
elif isinstance(timeout, str):
time.sleep(float(timeout))
elif isinstance(timeout, httpx.Timeout) and timeout.connect is not None:
time.sleep(timeout.connect)
raise litellm.Timeout(
message="This is a mock timeout error",
llm_provider="openai",
model=model,
)
def mock_completion(
model: str,
messages: List,
@@ -534,8 +601,10 @@ def mock_completion(
n: Optional[int] = None,
mock_response: Union[str, Exception, dict] = "This is a mock request",
mock_tool_calls: Optional[List] = None,
mock_timeout: Optional[bool] = False,
logging=None,
custom_llm_provider=None,
timeout: Optional[Union[float, str, httpx.Timeout]] = None,
**kwargs,
):
"""
@@ -548,6 +617,8 @@ def mock_completion(
messages (List): A list of message objects representing the conversation context.
stream (bool, optional): If True, returns a mock streaming response (default is False).
mock_response (str, optional): The content of the mock response (default is "This is a mock request").
mock_timeout (bool, optional): If True, the mock response will be a timeout error (default is False).
timeout (float, optional): The timeout value to use for the mock response (default is None).
**kwargs: Additional keyword arguments that can be used but are not required.
Returns:
@@ -560,56 +631,28 @@
- If 'stream' is True, it returns a response that mimics the behavior of a streaming completion.
"""
try:
if mock_response is None:
mock_response = "This is a mock request"
_handle_mock_timeout(mock_timeout=mock_timeout, timeout=timeout, model=model)
## LOGGING
if logging is not None:
logging.pre_call(
input=messages,
api_key="mock-key",
)
if isinstance(mock_response, Exception):
if isinstance(mock_response, openai.APIError):
raise mock_response
raise litellm.MockException(
status_code=getattr(mock_response, "status_code", 500), # type: ignore
message=getattr(mock_response, "text", str(mock_response)),
llm_provider=getattr(
mock_response, "llm_provider", custom_llm_provider or "openai"
), # type: ignore
model=model, # type: ignore
request=httpx.Request(method="POST", url="https://api.openai.com/v1/"),
)
elif (
isinstance(mock_response, str) and mock_response == "litellm.RateLimitError"
):
raise litellm.RateLimitError(
message="this is a mock rate limit error",
llm_provider=getattr(
mock_response, "llm_provider", custom_llm_provider or "openai"
), # type: ignore
model=model,
)
elif (
isinstance(mock_response, str)
and mock_response == "litellm.InternalServerError"
):
raise litellm.InternalServerError(
message="this is a mock internal server error",
llm_provider=getattr(
mock_response, "llm_provider", custom_llm_provider or "openai"
), # type: ignore
model=model,
)
elif isinstance(mock_response, str) and mock_response.startswith(
"Exception: content_filter_policy"
):
raise litellm.MockException(
status_code=400,
message=mock_response,
llm_provider="azure",
model=model, # type: ignore
request=httpx.Request(method="POST", url="https://api.openai.com/v1/"),
)
elif isinstance(mock_response, str) and mock_response.startswith(
_handle_mock_potential_exceptions(
mock_response=mock_response,
model=model,
custom_llm_provider=custom_llm_provider,
)
mock_response = cast(
Union[str, dict], mock_response
) # after this point, mock_response is a string or dict
if isinstance(mock_response, str) and mock_response.startswith(
"Exception: mock_streaming_error"
):
mock_response = litellm.MockException(
@@ -788,6 +831,7 @@ def completion( # type: ignore # noqa: PLR0915
api_base = kwargs.get("api_base", None)
mock_response = kwargs.get("mock_response", None)
mock_tool_calls = kwargs.get("mock_tool_calls", None)
mock_timeout = cast(Optional[bool], kwargs.get("mock_timeout", None))
force_timeout = kwargs.get("force_timeout", 600) ## deprecated
logger_fn = kwargs.get("logger_fn", None)
verbose = kwargs.get("verbose", False)
@@ -1102,7 +1146,8 @@ def completion( # type: ignore # noqa: PLR0915
litellm_params=litellm_params,
custom_llm_provider=custom_llm_provider,
)
if mock_response or mock_tool_calls:
if mock_response or mock_tool_calls or mock_timeout:
kwargs.pop("mock_timeout", None) # remove for any fallbacks triggered
return mock_completion(
model,
messages,
@@ -1114,6 +1159,8 @@ def completion( # type: ignore # noqa: PLR0915
acompletion=acompletion,
mock_delay=kwargs.get("mock_delay", None),
custom_llm_provider=custom_llm_provider,
mock_timeout=mock_timeout,
timeout=timeout,
)
if custom_llm_provider == "azure":
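
The kwargs.pop("mock_timeout", None) above is there so a mocked timeout on the primary deployment does not also time out the fallback it triggers. A rough test sketch under that assumption (model names, the fallback mapping, and the extra mock_response are illustrative, not from this commit):

    from litellm import Router

    router = Router(
        model_list=[
            # dummy keys; every call below stays mocked and never reaches a real API
            {"model_name": "primary", "litellm_params": {"model": "gpt-3.5-turbo", "api_key": "sk-test"}},
            {"model_name": "backup", "litellm_params": {"model": "gpt-4o-mini", "api_key": "sk-test"}},
        ],
        fallbacks=[{"primary": ["backup"]}],
    )

    # The primary call sleeps `timeout` seconds and raises litellm.Timeout; mock_timeout
    # is then dropped, so the fallback should answer with the mock response instead of
    # timing out as well.
    resp = router.completion(
        model="primary",
        messages=[{"role": "user", "content": "hi"}],
        mock_timeout=True,
        timeout=0.5,
        mock_response="fallback response",
    )
    print(resp.choices[0].message.content)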