Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 11:14:04 +00:00)
LiteLLM Minor Fixes & Improvements (12/23/2024) - P2 (#7386)
* fix(main.py): support 'mock_timeout=true' param - allows mock requests on proxy to have a time delay, for testing
* fix(main.py): ensure mock timeouts raise litellm.Timeout error - triggers retry/fallbacks
* fix: fix fallback + mock timeout testing
* fix(router.py): always return remaining tpm/rpm limits, if limits are known - allows rate limit headers to be guaranteed
* docs(timeout.md): add docs on mock timeout = true
* fix(main.py): fix linting errors
* test: fix test
This commit is contained in:
parent a89b0d5c39
commit 51f9f75c85

7 changed files with 223 additions and 54 deletions
@@ -176,3 +176,23 @@ print(response)
</TabItem>
</Tabs>

## Testing timeout handling

To test if your retry/fallback logic can handle timeouts, you can set `mock_timeout=True` for testing.

This is currently only supported on `/chat/completions` and `/completions` endpoints. Please [let us know](https://github.com/BerriAI/litellm/issues) if you need this for other endpoints.

```bash
curl -L -X POST 'http://0.0.0.0:4000/v1/chat/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer sk-1234' \
--data-raw '{
    "model": "gemini/gemini-1.5-flash",
    "messages": [
        {"role": "user", "content": "hi my email is ishaan@berri.ai"}
    ],
    "mock_timeout": true # 👈 KEY CHANGE
}'
```
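The same behavior can also be exercised through the Python SDK. The sketch below mirrors the tests added later in this commit; the model name and the 3-second timeout are illustrative:

```python
import litellm

# mock_timeout=True makes litellm sleep for the configured timeout and then
# raise litellm.Timeout, which is what retry/fallback logic reacts to.
try:
    litellm.completion(
        model="gpt-3.5-turbo",  # illustrative; any configured model works
        messages=[{"role": "user", "content": "Hey, I'm a mock request"}],
        timeout=3,
        mock_timeout=True,
    )
except litellm.Timeout as err:
    print(f"Got the expected mock timeout: {err}")
```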
litellm/main.py (131 changed lines)
@@ -527,45 +527,11 @@ async def _async_streaming(response, model, custom_llm_provider, args):
)


def mock_completion(
def _handle_mock_potential_exceptions(
mock_response: Union[str, Exception, dict],
model: str,
messages: List,
stream: Optional[bool] = False,
n: Optional[int] = None,
mock_response: Union[str, Exception, dict] = "This is a mock request",
mock_tool_calls: Optional[List] = None,
logging=None,
custom_llm_provider=None,
**kwargs,
custom_llm_provider: Optional[str] = None,
):
"""
Generate a mock completion response for testing or debugging purposes.

This is a helper function that simulates the response structure of the OpenAI completion API.

Parameters:
model (str): The name of the language model for which the mock response is generated.
messages (List): A list of message objects representing the conversation context.
stream (bool, optional): If True, returns a mock streaming response (default is False).
mock_response (str, optional): The content of the mock response (default is "This is a mock request").
**kwargs: Additional keyword arguments that can be used but are not required.

Returns:
litellm.ModelResponse: A ModelResponse simulating a completion response with the specified model, messages, and mock response.

Raises:
Exception: If an error occurs during the generation of the mock completion response.
Note:
- This function is intended for testing or debugging purposes to generate mock completion responses.
- If 'stream' is True, it returns a response that mimics the behavior of a streaming completion.
"""
try:
## LOGGING
if logging is not None:
logging.pre_call(
input=messages,
api_key="mock-key",
)
if isinstance(mock_response, Exception):
if isinstance(mock_response, openai.APIError):
raise mock_response

@@ -578,9 +544,7 @@ def mock_completion(
model=model, # type: ignore
request=httpx.Request(method="POST", url="https://api.openai.com/v1/"),
)
elif (
isinstance(mock_response, str) and mock_response == "litellm.RateLimitError"
):
elif isinstance(mock_response, str) and mock_response == "litellm.RateLimitError":
raise litellm.RateLimitError(
message="this is a mock rate limit error",
llm_provider=getattr(

@@ -609,7 +573,86 @@ def mock_completion(
model=model, # type: ignore
request=httpx.Request(method="POST", url="https://api.openai.com/v1/"),
)
elif isinstance(mock_response, str) and mock_response.startswith(


def _handle_mock_timeout(
mock_timeout: Optional[bool],
timeout: Optional[Union[float, str, httpx.Timeout]],
model: str,
):
if mock_timeout is True and timeout is not None:
if isinstance(timeout, float):
time.sleep(timeout)
elif isinstance(timeout, str):
time.sleep(float(timeout))
elif isinstance(timeout, httpx.Timeout) and timeout.connect is not None:
time.sleep(timeout.connect)
raise litellm.Timeout(
message="This is a mock timeout error",
llm_provider="openai",
model=model,
)


def mock_completion(
model: str,
messages: List,
stream: Optional[bool] = False,
n: Optional[int] = None,
mock_response: Union[str, Exception, dict] = "This is a mock request",
mock_tool_calls: Optional[List] = None,
mock_timeout: Optional[bool] = False,
logging=None,
custom_llm_provider=None,
timeout: Optional[Union[float, str, httpx.Timeout]] = None,
**kwargs,
):
"""
Generate a mock completion response for testing or debugging purposes.

This is a helper function that simulates the response structure of the OpenAI completion API.

Parameters:
model (str): The name of the language model for which the mock response is generated.
messages (List): A list of message objects representing the conversation context.
stream (bool, optional): If True, returns a mock streaming response (default is False).
mock_response (str, optional): The content of the mock response (default is "This is a mock request").
mock_timeout (bool, optional): If True, the mock response will be a timeout error (default is False).
timeout (float, optional): The timeout value to use for the mock response (default is None).
**kwargs: Additional keyword arguments that can be used but are not required.

Returns:
litellm.ModelResponse: A ModelResponse simulating a completion response with the specified model, messages, and mock response.

Raises:
Exception: If an error occurs during the generation of the mock completion response.
Note:
- This function is intended for testing or debugging purposes to generate mock completion responses.
- If 'stream' is True, it returns a response that mimics the behavior of a streaming completion.
"""
try:
if mock_response is None:
mock_response = "This is a mock request"

_handle_mock_timeout(mock_timeout=mock_timeout, timeout=timeout, model=model)

## LOGGING
if logging is not None:
logging.pre_call(
input=messages,
api_key="mock-key",
)

_handle_mock_potential_exceptions(
mock_response=mock_response,
model=model,
custom_llm_provider=custom_llm_provider,
)

mock_response = cast(
Union[str, dict], mock_response
) # after this point, mock_response is a string or dict
if isinstance(mock_response, str) and mock_response.startswith(
"Exception: mock_streaming_error"
):
mock_response = litellm.MockException(
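To make the new timeout handling concrete: with `mock_timeout=True`, the helper sleeps for the resolved timeout (a float, a numeric string, or `httpx.Timeout.connect`) and then raises `litellm.Timeout`. The sketch below calls the private helper directly purely for illustration and assumes a litellm version containing this commit; real callers go through `completion(..., mock_timeout=True)`:

```python
import httpx
import litellm
from litellm.main import _handle_mock_timeout  # private helper added in this hunk

try:
    # Sleeps for timeout.connect (~1.5s here), then raises litellm.Timeout.
    _handle_mock_timeout(
        mock_timeout=True,
        timeout=httpx.Timeout(connect=1.5, read=5.0, write=5.0, pool=5.0),
        model="gpt-3.5-turbo",  # illustrative model name
    )
except litellm.Timeout as err:
    print(f"mock timeout raised as expected: {err}")
```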
@@ -788,6 +831,7 @@ def completion( # type: ignore # noqa: PLR0915
api_base = kwargs.get("api_base", None)
mock_response = kwargs.get("mock_response", None)
mock_tool_calls = kwargs.get("mock_tool_calls", None)
mock_timeout = cast(Optional[bool], kwargs.get("mock_timeout", None))
force_timeout = kwargs.get("force_timeout", 600) ## deprecated
logger_fn = kwargs.get("logger_fn", None)
verbose = kwargs.get("verbose", False)

@@ -1102,7 +1146,8 @@ def completion( # type: ignore # noqa: PLR0915
litellm_params=litellm_params,
custom_llm_provider=custom_llm_provider,
)
if mock_response or mock_tool_calls:
if mock_response or mock_tool_calls or mock_timeout:
kwargs.pop("mock_timeout", None) # remove for any fallbacks triggered
return mock_completion(
model,
messages,

@@ -1114,6 +1159,8 @@ def completion( # type: ignore # noqa: PLR0915
acompletion=acompletion,
mock_delay=kwargs.get("mock_delay", None),
custom_llm_provider=custom_llm_provider,
mock_timeout=mock_timeout,
timeout=timeout,
)

if custom_llm_provider == "azure":
litellm/router.py

@@ -2613,6 +2613,8 @@ class Router:
"content_policy_fallbacks", self.content_policy_fallbacks
)

mock_timeout = kwargs.pop("mock_timeout", None)

try:
self._handle_mock_testing_fallbacks(
kwargs=kwargs,

@@ -2622,7 +2624,9 @@ class Router:
content_policy_fallbacks=content_policy_fallbacks,
)

response = await self.async_function_with_retries(*args, **kwargs)
response = await self.async_function_with_retries(
*args, **kwargs, mock_timeout=mock_timeout
)
verbose_router_logger.debug(f"Async Response: {response}")
return response
except Exception as e:
@@ -2993,7 +2997,9 @@ class Router:
if inspect.iscoroutinefunction(response) or inspect.isawaitable(response):
response = await response
## PROCESS RESPONSE HEADERS
await self.set_response_headers(response=response, model_group=model_group)
response = await self.set_response_headers(
response=response, model_group=model_group
)

return response

@@ -4567,11 +4573,15 @@ class Router:
rpm_limit = None

returned_dict = {}
if tpm_limit is not None and current_tpm is not None:
returned_dict["x-ratelimit-remaining-tokens"] = tpm_limit - current_tpm
if tpm_limit is not None:
returned_dict["x-ratelimit-remaining-tokens"] = tpm_limit - (
current_tpm or 0
)
returned_dict["x-ratelimit-limit-tokens"] = tpm_limit
if rpm_limit is not None and current_rpm is not None:
returned_dict["x-ratelimit-remaining-requests"] = rpm_limit - current_rpm
if rpm_limit is not None:
returned_dict["x-ratelimit-remaining-requests"] = rpm_limit - (
current_rpm or 0
)
returned_dict["x-ratelimit-limit-requests"] = rpm_limit

return returned_dict
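With this change the remaining tpm/rpm headers are populated whenever a limit is configured, even before any usage is recorded. A minimal sketch of reading them from a router response; the access pattern via `_hidden_params["additional_headers"]` follows the test added later in this commit, while the model and limit values are illustrative:

```python
import asyncio
import litellm

async def main():
    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                # tpm/rpm limits are illustrative; no real call is made because
                # mock_response short-circuits the request.
                "litellm_params": {"model": "gpt-3.5-turbo", "tpm": 100000, "rpm": 100},
            }
        ]
    )
    resp = await router.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hi"}],
        mock_response="Hello, I'm good.",
    )
    headers = resp._hidden_params["additional_headers"]
    print(headers.get("x-ratelimit-remaining-tokens"))
    print(headers.get("x-ratelimit-remaining-requests"))

asyncio.run(main())
```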
@@ -1594,6 +1594,7 @@ all_litellm_params = [
"text_completion",
"caching",
"mock_response",
"mock_timeout",
"api_key",
"api_version",
"prompt_id",
@@ -11,6 +11,7 @@ sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import litellm
import time


def test_mock_request():
@@ -92,3 +93,86 @@ async def test_async_mock_streaming_request_n_greater_than_1():
# assert (
# complete_response == "LiteLLM is awesome"
# ), f"Unexpected response got {complete_response}"


def test_mock_request_with_mock_timeout():
"""
Allow user to set 'mock_timeout = True', this allows for testing if fallbacks/retries are working on timeouts.
"""
start_time = time.time()
with pytest.raises(litellm.Timeout):
response = litellm.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hey, I'm a mock request"}],
timeout=3,
mock_timeout=True,
)
end_time = time.time()
assert end_time - start_time >= 3, f"Time taken: {end_time - start_time}"


def test_router_mock_request_with_mock_timeout():
"""
Allow user to set 'mock_timeout = True', this allows for testing if fallbacks/retries are working on timeouts.
"""
start_time = time.time()
router = litellm.Router(
model_list=[
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "gpt-3.5-turbo",
"api_key": os.getenv("OPENAI_API_KEY"),
},
},
],
)
with pytest.raises(litellm.Timeout):
response = router.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hey, I'm a mock request"}],
timeout=3,
mock_timeout=True,
)
print(response)
end_time = time.time()
assert end_time - start_time >= 3, f"Time taken: {end_time - start_time}"


def test_router_mock_request_with_mock_timeout_with_fallbacks():
"""
Allow user to set 'mock_timeout = True', this allows for testing if fallbacks/retries are working on timeouts.
"""
litellm.set_verbose = True
start_time = time.time()
router = litellm.Router(
model_list=[
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "gpt-3.5-turbo",
"api_key": os.getenv("OPENAI_API_KEY"),
},
},
{
"model_name": "azure-gpt",
"litellm_params": {
"model": "azure/chatgpt-v-2",
"api_key": os.getenv("AZURE_API_KEY"),
"api_base": os.getenv("AZURE_API_BASE"),
},
},
],
fallbacks=[{"gpt-3.5-turbo": ["azure-gpt"]}],
)
response = router.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hey, I'm a mock request"}],
timeout=3,
num_retries=1,
mock_timeout=True,
)
print(response)
end_time = time.time()
assert end_time - start_time >= 3, f"Time taken: {end_time - start_time}"
assert "gpt-35-turbo" in response.model, "Model should be azure gpt-35-turbo"
@@ -344,11 +344,18 @@ async def test_get_remaining_model_group_usage():
]
)
for _ in range(2):
await router.acompletion(
resp = await router.acompletion(
model="gemini/gemini-1.5-flash",
messages=[{"role": "user", "content": "Hello, how are you?"}],
mock_response="Hello, I'm good.",
)
assert (
"x-ratelimit-remaining-tokens" in resp._hidden_params["additional_headers"]
)
assert (
"x-ratelimit-remaining-requests"
in resp._hidden_params["additional_headers"]
)
await asyncio.sleep(1)

remaining_usage = await router.get_remaining_model_group_usage(