LiteLLM Minor Fixes & Improvements (12/23/2024) - P2 (#7386)

* fix(main.py): support 'mock_timeout=true' param

allows mock requests on the proxy to have a time delay, for testing

* fix(main.py): ensure mock timeouts raise litellm.Timeout error

triggers retry/fallbacks

* fix: fix fallback + mock timeout testing

* fix(router.py): always return remaining tpm/rpm limits, if limits are known

guarantees rate limit headers are returned whenever limits are known

* docs(timeout.md): add docs on `mock_timeout=true`

* fix(main.py): fix linting errors

* test: fix test
Krish Dholakia, 2024-12-23 17:41:27 -08:00, committed by GitHub
parent a89b0d5c39
commit 51f9f75c85
7 changed files with 223 additions and 54 deletions

@@ -175,4 +175,24 @@ print(response)
</Tabs>
</TabItem>
</Tabs>
</Tabs>
## Testing timeout handling
To test whether your retry/fallback logic can handle timeouts, set `mock_timeout=True` on the request.
This is currently only supported on `/chat/completions` and `/completions` endpoints. Please [let us know](https://github.com/BerriAI/litellm/issues) if you need this for other endpoints.
```bash
curl -L -X POST 'http://0.0.0.0:4000/v1/chat/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer sk-1234' \
--data-raw '{
    "model": "gemini/gemini-1.5-flash",
    "messages": [
        {"role": "user", "content": "hi my email is ishaan@berri.ai"}
    ],
    "mock_timeout": true # 👈 KEY CHANGE
}'
```
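
The same flag can be exercised directly from the Python SDK. A minimal sketch mirroring the tests added in this commit; the call sleeps for roughly the `timeout` value and then raises `litellm.Timeout`:

```python
import litellm

# Expect ~3s of delay, then a litellm.Timeout error that retry/fallback
# logic can act on.
try:
    litellm.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hi"}],
        timeout=3,
        mock_timeout=True,  # 👈 KEY CHANGE
    )
except litellm.Timeout as e:
    print(f"mock timeout raised: {e}")
```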

@@ -527,6 +527,73 @@ async def _async_streaming(response, model, custom_llm_provider, args):
)
def _handle_mock_potential_exceptions(
    mock_response: Union[str, Exception, dict],
    model: str,
    custom_llm_provider: Optional[str] = None,
):
    if isinstance(mock_response, Exception):
        if isinstance(mock_response, openai.APIError):
            raise mock_response
        raise litellm.MockException(
            status_code=getattr(mock_response, "status_code", 500),  # type: ignore
            message=getattr(mock_response, "text", str(mock_response)),
            llm_provider=getattr(
                mock_response, "llm_provider", custom_llm_provider or "openai"
            ),  # type: ignore
            model=model,  # type: ignore
            request=httpx.Request(method="POST", url="https://api.openai.com/v1/"),
        )
    elif isinstance(mock_response, str) and mock_response == "litellm.RateLimitError":
        raise litellm.RateLimitError(
            message="this is a mock rate limit error",
            llm_provider=getattr(
                mock_response, "llm_provider", custom_llm_provider or "openai"
            ),  # type: ignore
            model=model,
        )
    elif (
        isinstance(mock_response, str)
        and mock_response == "litellm.InternalServerError"
    ):
        raise litellm.InternalServerError(
            message="this is a mock internal server error",
            llm_provider=getattr(
                mock_response, "llm_provider", custom_llm_provider or "openai"
            ),  # type: ignore
            model=model,
        )
    elif isinstance(mock_response, str) and mock_response.startswith(
        "Exception: content_filter_policy"
    ):
        raise litellm.MockException(
            status_code=400,
            message=mock_response,
            llm_provider="azure",
            model=model,  # type: ignore
            request=httpx.Request(method="POST", url="https://api.openai.com/v1/"),
        )
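
For context, this extracted helper is what powers error injection via `mock_response` sentinel strings. A short sketch of that existing behavior through the standard `litellm.completion` entry point:

```python
import litellm

# A recognized sentinel string raises the corresponding exception
# instead of returning a mock reply.
try:
    litellm.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hi"}],
        mock_response="litellm.RateLimitError",
    )
except litellm.RateLimitError as e:
    print(f"mock rate limit error: {e}")
```
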
def _handle_mock_timeout(
    mock_timeout: Optional[bool],
    timeout: Optional[Union[float, str, httpx.Timeout]],
    model: str,
):
    if mock_timeout is True and timeout is not None:
        if isinstance(timeout, float):
            time.sleep(timeout)
        elif isinstance(timeout, str):
            time.sleep(float(timeout))
        elif isinstance(timeout, httpx.Timeout) and timeout.connect is not None:
            time.sleep(timeout.connect)
        raise litellm.Timeout(
            message="This is a mock timeout error",
            llm_provider="openai",
            model=model,
        )
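
The mock delay follows the `timeout` argument: a float or string sleeps for that many seconds, while an `httpx.Timeout` sleeps for its `connect` value before `litellm.Timeout` is raised. A quick sketch of the `httpx.Timeout` case (values are illustrative):

```python
import httpx
import litellm

# The mock delay tracks the connect timeout (~2s here), after which
# litellm.Timeout is raised.
try:
    litellm.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hi"}],
        timeout=httpx.Timeout(5.0, connect=2.0),
        mock_timeout=True,
    )
except litellm.Timeout:
    print("mock timeout raised after ~2s")
```
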
def mock_completion(
model: str,
messages: List,
@@ -534,8 +601,10 @@ def mock_completion(
n: Optional[int] = None,
mock_response: Union[str, Exception, dict] = "This is a mock request",
mock_tool_calls: Optional[List] = None,
mock_timeout: Optional[bool] = False,
logging=None,
custom_llm_provider=None,
timeout: Optional[Union[float, str, httpx.Timeout]] = None,
**kwargs,
):
"""
@@ -548,6 +617,8 @@ def mock_completion(
messages (List): A list of message objects representing the conversation context.
stream (bool, optional): If True, returns a mock streaming response (default is False).
mock_response (str, optional): The content of the mock response (default is "This is a mock request").
mock_timeout (bool, optional): If True, the mock response will be a timeout error (default is False).
timeout (float, optional): The timeout value to use for the mock response (default is None).
**kwargs: Additional keyword arguments that can be used but are not required.
Returns:
@@ -560,56 +631,28 @@ def mock_completion(
- If 'stream' is True, it returns a response that mimics the behavior of a streaming completion.
"""
try:
if mock_response is None:
mock_response = "This is a mock request"
_handle_mock_timeout(mock_timeout=mock_timeout, timeout=timeout, model=model)
## LOGGING
if logging is not None:
logging.pre_call(
input=messages,
api_key="mock-key",
)
if isinstance(mock_response, Exception):
if isinstance(mock_response, openai.APIError):
raise mock_response
raise litellm.MockException(
status_code=getattr(mock_response, "status_code", 500), # type: ignore
message=getattr(mock_response, "text", str(mock_response)),
llm_provider=getattr(
mock_response, "llm_provider", custom_llm_provider or "openai"
), # type: ignore
model=model, # type: ignore
request=httpx.Request(method="POST", url="https://api.openai.com/v1/"),
)
elif (
isinstance(mock_response, str) and mock_response == "litellm.RateLimitError"
):
raise litellm.RateLimitError(
message="this is a mock rate limit error",
llm_provider=getattr(
mock_response, "llm_provider", custom_llm_provider or "openai"
), # type: ignore
model=model,
)
elif (
isinstance(mock_response, str)
and mock_response == "litellm.InternalServerError"
):
raise litellm.InternalServerError(
message="this is a mock internal server error",
llm_provider=getattr(
mock_response, "llm_provider", custom_llm_provider or "openai"
), # type: ignore
model=model,
)
elif isinstance(mock_response, str) and mock_response.startswith(
"Exception: content_filter_policy"
):
raise litellm.MockException(
status_code=400,
message=mock_response,
llm_provider="azure",
model=model, # type: ignore
request=httpx.Request(method="POST", url="https://api.openai.com/v1/"),
)
elif isinstance(mock_response, str) and mock_response.startswith(
_handle_mock_potential_exceptions(
mock_response=mock_response,
model=model,
custom_llm_provider=custom_llm_provider,
)
mock_response = cast(
Union[str, dict], mock_response
) # after this point, mock_response is a string or dict
if isinstance(mock_response, str) and mock_response.startswith(
"Exception: mock_streaming_error"
):
mock_response = litellm.MockException(
@@ -788,6 +831,7 @@ def completion( # type: ignore # noqa: PLR0915
api_base = kwargs.get("api_base", None)
mock_response = kwargs.get("mock_response", None)
mock_tool_calls = kwargs.get("mock_tool_calls", None)
mock_timeout = cast(Optional[bool], kwargs.get("mock_timeout", None))
force_timeout = kwargs.get("force_timeout", 600) ## deprecated
logger_fn = kwargs.get("logger_fn", None)
verbose = kwargs.get("verbose", False)
@@ -1102,7 +1146,8 @@ def completion( # type: ignore # noqa: PLR0915
litellm_params=litellm_params,
custom_llm_provider=custom_llm_provider,
)
if mock_response or mock_tool_calls:
if mock_response or mock_tool_calls or mock_timeout:
kwargs.pop("mock_timeout", None) # remove for any fallbacks triggered
return mock_completion(
model,
messages,
@@ -1114,6 +1159,8 @@ def completion( # type: ignore # noqa: PLR0915
acompletion=acompletion,
mock_delay=kwargs.get("mock_delay", None),
custom_llm_provider=custom_llm_provider,
mock_timeout=mock_timeout,
timeout=timeout,
)
if custom_llm_provider == "azure":

@@ -5,4 +5,4 @@ model_list:
api_key: os.environ/OPENAI_API_KEY
model_info:
mode: audio_transcription

@@ -2613,6 +2613,8 @@ class Router:
"content_policy_fallbacks", self.content_policy_fallbacks
)
mock_timeout = kwargs.pop("mock_timeout", None)
try:
self._handle_mock_testing_fallbacks(
kwargs=kwargs,
@@ -2622,7 +2624,9 @@ class Router:
content_policy_fallbacks=content_policy_fallbacks,
)
response = await self.async_function_with_retries(*args, **kwargs)
response = await self.async_function_with_retries(
*args, **kwargs, mock_timeout=mock_timeout
)
verbose_router_logger.debug(f"Async Response: {response}")
return response
except Exception as e:
@@ -2993,7 +2997,9 @@ class Router:
if inspect.iscoroutinefunction(response) or inspect.isawaitable(response):
response = await response
## PROCESS RESPONSE HEADERS
await self.set_response_headers(response=response, model_group=model_group)
response = await self.set_response_headers(
response=response, model_group=model_group
)
return response
@@ -4567,11 +4573,15 @@ class Router:
rpm_limit = None
returned_dict = {}
if tpm_limit is not None and current_tpm is not None:
returned_dict["x-ratelimit-remaining-tokens"] = tpm_limit - current_tpm
if tpm_limit is not None:
returned_dict["x-ratelimit-remaining-tokens"] = tpm_limit - (
current_tpm or 0
)
returned_dict["x-ratelimit-limit-tokens"] = tpm_limit
if rpm_limit is not None and current_rpm is not None:
returned_dict["x-ratelimit-remaining-requests"] = rpm_limit - current_rpm
if rpm_limit is not None:
returned_dict["x-ratelimit-remaining-requests"] = rpm_limit - (
current_rpm or 0
)
returned_dict["x-ratelimit-limit-requests"] = rpm_limit
return returned_dict
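
With this change, the `x-ratelimit-*` headers are returned whenever `tpm`/`rpm` limits are known for the model group, even before any usage has been recorded. They surface on the response's hidden params, as the router test further down checks. A minimal sketch, assuming a deployment configured with `tpm`/`rpm` limits and a mock response:

```python
import asyncio
import litellm


async def main():
    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {
                    "model": "gpt-3.5-turbo",
                    "mock_response": "hi",
                    "tpm": 100000,
                    "rpm": 1000,
                },
            }
        ]
    )
    resp = await router.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hello"}],
    )
    # Remaining-limit headers are guaranteed because the limits are known.
    headers = resp._hidden_params["additional_headers"]
    print(headers["x-ratelimit-remaining-tokens"])
    print(headers["x-ratelimit-remaining-requests"])


asyncio.run(main())
```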

@@ -1594,6 +1594,7 @@ all_litellm_params = [
"text_completion",
"caching",
"mock_response",
"mock_timeout",
"api_key",
"api_version",
"prompt_id",

@@ -11,6 +11,7 @@ sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import litellm
import time
def test_mock_request():
@@ -92,3 +93,86 @@ async def test_async_mock_streaming_request_n_greater_than_1():
# assert (
# complete_response == "LiteLLM is awesome"
# ), f"Unexpected response got {complete_response}"
def test_mock_request_with_mock_timeout():
    """
    Allow user to set 'mock_timeout = True', this allows for testing if fallbacks/retries are working on timeouts.
    """
    start_time = time.time()
    with pytest.raises(litellm.Timeout):
        response = litellm.completion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hey, I'm a mock request"}],
            timeout=3,
            mock_timeout=True,
        )
    end_time = time.time()
    assert end_time - start_time >= 3, f"Time taken: {end_time - start_time}"


def test_router_mock_request_with_mock_timeout():
    """
    Allow user to set 'mock_timeout = True', this allows for testing if fallbacks/retries are working on timeouts.
    """
    start_time = time.time()
    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {
                    "model": "gpt-3.5-turbo",
                    "api_key": os.getenv("OPENAI_API_KEY"),
                },
            },
        ],
    )
    with pytest.raises(litellm.Timeout):
        response = router.completion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hey, I'm a mock request"}],
            timeout=3,
            mock_timeout=True,
        )
        print(response)
    end_time = time.time()
    assert end_time - start_time >= 3, f"Time taken: {end_time - start_time}"


def test_router_mock_request_with_mock_timeout_with_fallbacks():
    """
    Allow user to set 'mock_timeout = True', this allows for testing if fallbacks/retries are working on timeouts.
    """
    litellm.set_verbose = True
    start_time = time.time()
    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {
                    "model": "gpt-3.5-turbo",
                    "api_key": os.getenv("OPENAI_API_KEY"),
                },
            },
            {
                "model_name": "azure-gpt",
                "litellm_params": {
                    "model": "azure/chatgpt-v-2",
                    "api_key": os.getenv("AZURE_API_KEY"),
                    "api_base": os.getenv("AZURE_API_BASE"),
                },
            },
        ],
        fallbacks=[{"gpt-3.5-turbo": ["azure-gpt"]}],
    )
    response = router.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hey, I'm a mock request"}],
        timeout=3,
        num_retries=1,
        mock_timeout=True,
    )
    print(response)
    end_time = time.time()
    assert end_time - start_time >= 3, f"Time taken: {end_time - start_time}"
    assert "gpt-35-turbo" in response.model, "Model should be azure gpt-35-turbo"

@@ -344,11 +344,18 @@ async def test_get_remaining_model_group_usage():
]
)
for _ in range(2):
await router.acompletion(
resp = await router.acompletion(
model="gemini/gemini-1.5-flash",
messages=[{"role": "user", "content": "Hello, how are you?"}],
mock_response="Hello, I'm good.",
)
assert (
"x-ratelimit-remaining-tokens" in resp._hidden_params["additional_headers"]
)
assert (
"x-ratelimit-remaining-requests"
in resp._hidden_params["additional_headers"]
)
await asyncio.sleep(1)
remaining_usage = await router.get_remaining_model_group_usage(