Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-26 11:14:04 +00:00
LiteLLM Minor Fixes & Improvements (12/23/2024) - P2 (#7386)
* fix(main.py): support 'mock_timeout=true' param, which allows mock requests on the proxy to have a time delay, for testing
* fix(main.py): ensure mock timeouts raise a litellm.Timeout error, which triggers retries/fallbacks
* fix: fix fallback + mock timeout testing
* fix(router.py): always return remaining tpm/rpm limits if limits are known, so rate limit headers are guaranteed
* docs(timeout.md): add docs on mock_timeout=true
* fix(main.py): fix linting errors
* test: fix test
This commit is contained in:
parent a89b0d5c39
commit 51f9f75c85

7 changed files with 223 additions and 54 deletions
@@ -175,4 +175,24 @@ print(response)
 </Tabs>
 </TabItem>
 </Tabs>
 
+
+## Testing timeout handling
+
+To test if your retry/fallback logic can handle timeouts, you can set `mock_timeout=True` for testing.
+
+This is currently only supported on `/chat/completions` and `/completions` endpoints. Please [let us know](https://github.com/BerriAI/litellm/issues) if you need this for other endpoints.
+
+```bash
+curl -L -X POST 'http://0.0.0.0:4000/v1/chat/completions' \
+-H 'Content-Type: application/json' \
+-H 'Authorization: Bearer sk-1234' \
+--data-raw '{
+    "model": "gemini/gemini-1.5-flash",
+    "messages": [
+        {"role": "user", "content": "hi my email is ishaan@berri.ai"}
+    ],
+    "mock_timeout": true # 👈 KEY CHANGE
+}'
+```
+
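The same check can be run through the Python SDK; a minimal sketch based on the tests added in this commit (model and message are illustrative):

```python
import litellm

# With mock_timeout=True the call sleeps for `timeout` seconds and then raises
# litellm.Timeout, which is the error that retry/fallback logic reacts to.
try:
    litellm.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hey, I'm a mock request"}],
        timeout=3,
        mock_timeout=True,
    )
except litellm.Timeout as e:
    print(f"Caught expected mock timeout: {e}")
```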
137 litellm/main.py
@@ -527,6 +527,73 @@ async def _async_streaming(response, model, custom_llm_provider, args):
         )


+def _handle_mock_potential_exceptions(
+    mock_response: Union[str, Exception, dict],
+    model: str,
+    custom_llm_provider: Optional[str] = None,
+):
+    if isinstance(mock_response, Exception):
+        if isinstance(mock_response, openai.APIError):
+            raise mock_response
+        raise litellm.MockException(
+            status_code=getattr(mock_response, "status_code", 500),  # type: ignore
+            message=getattr(mock_response, "text", str(mock_response)),
+            llm_provider=getattr(
+                mock_response, "llm_provider", custom_llm_provider or "openai"
+            ),  # type: ignore
+            model=model,  # type: ignore
+            request=httpx.Request(method="POST", url="https://api.openai.com/v1/"),
+        )
+    elif isinstance(mock_response, str) and mock_response == "litellm.RateLimitError":
+        raise litellm.RateLimitError(
+            message="this is a mock rate limit error",
+            llm_provider=getattr(
+                mock_response, "llm_provider", custom_llm_provider or "openai"
+            ),  # type: ignore
+            model=model,
+        )
+    elif (
+        isinstance(mock_response, str)
+        and mock_response == "litellm.InternalServerError"
+    ):
+        raise litellm.InternalServerError(
+            message="this is a mock internal server error",
+            llm_provider=getattr(
+                mock_response, "llm_provider", custom_llm_provider or "openai"
+            ),  # type: ignore
+            model=model,
+        )
+    elif isinstance(mock_response, str) and mock_response.startswith(
+        "Exception: content_filter_policy"
+    ):
+        raise litellm.MockException(
+            status_code=400,
+            message=mock_response,
+            llm_provider="azure",
+            model=model,  # type: ignore
+            request=httpx.Request(method="POST", url="https://api.openai.com/v1/"),
+        )
+
+
+def _handle_mock_timeout(
+    mock_timeout: Optional[bool],
+    timeout: Optional[Union[float, str, httpx.Timeout]],
+    model: str,
+):
+    if mock_timeout is True and timeout is not None:
+        if isinstance(timeout, float):
+            time.sleep(timeout)
+        elif isinstance(timeout, str):
+            time.sleep(float(timeout))
+        elif isinstance(timeout, httpx.Timeout) and timeout.connect is not None:
+            time.sleep(timeout.connect)
+        raise litellm.Timeout(
+            message="This is a mock timeout error",
+            llm_provider="openai",
+            model=model,
+        )
+
+
 def mock_completion(
     model: str,
     messages: List,
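A rough sketch of how the sentinel strings handled by `_handle_mock_potential_exceptions` surface through `litellm.completion` (model and message are illustrative):

```python
import litellm

# Passing a recognized sentinel string as mock_response raises the matching mock
# exception instead of returning a mock completion.
try:
    litellm.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hi"}],
        mock_response="litellm.RateLimitError",
    )
except litellm.RateLimitError as e:
    print(f"Mock rate limit error raised: {e}")
```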
@ -534,8 +601,10 @@ def mock_completion(
|
||||||
n: Optional[int] = None,
|
n: Optional[int] = None,
|
||||||
mock_response: Union[str, Exception, dict] = "This is a mock request",
|
mock_response: Union[str, Exception, dict] = "This is a mock request",
|
||||||
mock_tool_calls: Optional[List] = None,
|
mock_tool_calls: Optional[List] = None,
|
||||||
|
mock_timeout: Optional[bool] = False,
|
||||||
logging=None,
|
logging=None,
|
||||||
custom_llm_provider=None,
|
custom_llm_provider=None,
|
||||||
|
timeout: Optional[Union[float, str, httpx.Timeout]] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
@ -548,6 +617,8 @@ def mock_completion(
|
||||||
messages (List): A list of message objects representing the conversation context.
|
messages (List): A list of message objects representing the conversation context.
|
||||||
stream (bool, optional): If True, returns a mock streaming response (default is False).
|
stream (bool, optional): If True, returns a mock streaming response (default is False).
|
||||||
mock_response (str, optional): The content of the mock response (default is "This is a mock request").
|
mock_response (str, optional): The content of the mock response (default is "This is a mock request").
|
||||||
|
mock_timeout (bool, optional): If True, the mock response will be a timeout error (default is False).
|
||||||
|
timeout (float, optional): The timeout value to use for the mock response (default is None).
|
||||||
**kwargs: Additional keyword arguments that can be used but are not required.
|
**kwargs: Additional keyword arguments that can be used but are not required.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
@@ -560,56 +631,28 @@ def mock_completion(
         - If 'stream' is True, it returns a response that mimics the behavior of a streaming completion.
     """
     try:
+        if mock_response is None:
+            mock_response = "This is a mock request"
+
+        _handle_mock_timeout(mock_timeout=mock_timeout, timeout=timeout, model=model)
+
         ## LOGGING
         if logging is not None:
             logging.pre_call(
                 input=messages,
                 api_key="mock-key",
             )
-        if isinstance(mock_response, Exception):
-            if isinstance(mock_response, openai.APIError):
-                raise mock_response
-            raise litellm.MockException(
-                status_code=getattr(mock_response, "status_code", 500),  # type: ignore
-                message=getattr(mock_response, "text", str(mock_response)),
-                llm_provider=getattr(
-                    mock_response, "llm_provider", custom_llm_provider or "openai"
-                ),  # type: ignore
-                model=model,  # type: ignore
-                request=httpx.Request(method="POST", url="https://api.openai.com/v1/"),
-            )
-        elif (
-            isinstance(mock_response, str) and mock_response == "litellm.RateLimitError"
-        ):
-            raise litellm.RateLimitError(
-                message="this is a mock rate limit error",
-                llm_provider=getattr(
-                    mock_response, "llm_provider", custom_llm_provider or "openai"
-                ),  # type: ignore
-                model=model,
-            )
-        elif (
-            isinstance(mock_response, str)
-            and mock_response == "litellm.InternalServerError"
-        ):
-            raise litellm.InternalServerError(
-                message="this is a mock internal server error",
-                llm_provider=getattr(
-                    mock_response, "llm_provider", custom_llm_provider or "openai"
-                ),  # type: ignore
-                model=model,
-            )
-        elif isinstance(mock_response, str) and mock_response.startswith(
-            "Exception: content_filter_policy"
-        ):
-            raise litellm.MockException(
-                status_code=400,
-                message=mock_response,
-                llm_provider="azure",
-                model=model,  # type: ignore
-                request=httpx.Request(method="POST", url="https://api.openai.com/v1/"),
-            )
-        elif isinstance(mock_response, str) and mock_response.startswith(
+        _handle_mock_potential_exceptions(
+            mock_response=mock_response,
+            model=model,
+            custom_llm_provider=custom_llm_provider,
+        )
+
+        mock_response = cast(
+            Union[str, dict], mock_response
+        )  # after this point, mock_response is a string or dict
+        if isinstance(mock_response, str) and mock_response.startswith(
             "Exception: mock_streaming_error"
         ):
             mock_response = litellm.MockException(
@@ -788,6 +831,7 @@ def completion( # type: ignore # noqa: PLR0915
     api_base = kwargs.get("api_base", None)
     mock_response = kwargs.get("mock_response", None)
     mock_tool_calls = kwargs.get("mock_tool_calls", None)
+    mock_timeout = cast(Optional[bool], kwargs.get("mock_timeout", None))
     force_timeout = kwargs.get("force_timeout", 600)  ## deprecated
     logger_fn = kwargs.get("logger_fn", None)
     verbose = kwargs.get("verbose", False)

@@ -1102,7 +1146,8 @@ def completion( # type: ignore # noqa: PLR0915
         litellm_params=litellm_params,
         custom_llm_provider=custom_llm_provider,
     )
-    if mock_response or mock_tool_calls:
+    if mock_response or mock_tool_calls or mock_timeout:
+        kwargs.pop("mock_timeout", None)  # remove for any fallbacks triggered
         return mock_completion(
             model,
             messages,

@@ -1114,6 +1159,8 @@ def completion( # type: ignore # noqa: PLR0915
             acompletion=acompletion,
             mock_delay=kwargs.get("mock_delay", None),
             custom_llm_provider=custom_llm_provider,
+            mock_timeout=mock_timeout,
+            timeout=timeout,
         )

     if custom_llm_provider == "azure":
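Because `completion()` now forwards `timeout` into `mock_completion`, the mock delay follows whichever form the caller used; `_handle_mock_timeout` sleeps on a float, a numeric string, or the `connect` value of an `httpx.Timeout`. A brief sketch of the `httpx.Timeout` case (values are illustrative):

```python
import httpx
import litellm

# The mock path sleeps for timeout.connect (2 seconds here) before raising litellm.Timeout.
try:
    litellm.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hi"}],
        timeout=httpx.Timeout(connect=2.0, read=10.0, write=10.0, pool=10.0),
        mock_timeout=True,
    )
except litellm.Timeout:
    print("Mock timeout raised after ~2 seconds")
```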
@@ -5,4 +5,4 @@ model_list:
       api_key: os.environ/OPENAI_API_KEY
     model_info:
       mode: audio_transcription
@@ -2613,6 +2613,8 @@ class Router:
             "content_policy_fallbacks", self.content_policy_fallbacks
         )

+        mock_timeout = kwargs.pop("mock_timeout", None)
+
         try:
             self._handle_mock_testing_fallbacks(
                 kwargs=kwargs,

@@ -2622,7 +2624,9 @@ class Router:
                 content_policy_fallbacks=content_policy_fallbacks,
             )

-            response = await self.async_function_with_retries(*args, **kwargs)
+            response = await self.async_function_with_retries(
+                *args, **kwargs, mock_timeout=mock_timeout
+            )
             verbose_router_logger.debug(f"Async Response: {response}")
             return response
         except Exception as e:

@@ -2993,7 +2997,9 @@ class Router:
             if inspect.iscoroutinefunction(response) or inspect.isawaitable(response):
                 response = await response
             ## PROCESS RESPONSE HEADERS
-            await self.set_response_headers(response=response, model_group=model_group)
+            response = await self.set_response_headers(
+                response=response, model_group=model_group
+            )

             return response

@@ -4567,11 +4573,15 @@ class Router:
             rpm_limit = None

         returned_dict = {}
-        if tpm_limit is not None and current_tpm is not None:
-            returned_dict["x-ratelimit-remaining-tokens"] = tpm_limit - current_tpm
+        if tpm_limit is not None:
+            returned_dict["x-ratelimit-remaining-tokens"] = tpm_limit - (
+                current_tpm or 0
+            )
             returned_dict["x-ratelimit-limit-tokens"] = tpm_limit
-        if rpm_limit is not None and current_rpm is not None:
-            returned_dict["x-ratelimit-remaining-requests"] = rpm_limit - current_rpm
+        if rpm_limit is not None:
+            returned_dict["x-ratelimit-remaining-requests"] = rpm_limit - (
+                current_rpm or 0
+            )
             returned_dict["x-ratelimit-limit-requests"] = rpm_limit

         return returned_dict
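These remaining-limit values are what end up in the rate limit headers on responses; a small sketch of reading them, mirroring the test updated below (model and tpm/rpm settings are illustrative):

```python
import asyncio
import litellm

router = litellm.Router(
    model_list=[
        {
            "model_name": "gemini/gemini-1.5-flash",
            "litellm_params": {
                "model": "gemini/gemini-1.5-flash",
                "tpm": 1000,
                "rpm": 10,
            },
        }
    ]
)

async def main():
    resp = await router.acompletion(
        model="gemini/gemini-1.5-flash",
        messages=[{"role": "user", "content": "Hello, how are you?"}],
        mock_response="Hello, I'm good.",
    )
    # With tpm/rpm limits configured on the deployment, the remaining-limit
    # headers are now always attached, even before usage has been recorded.
    headers = resp._hidden_params["additional_headers"]
    print(headers["x-ratelimit-remaining-tokens"])
    print(headers["x-ratelimit-remaining-requests"])

asyncio.run(main())
```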
@@ -1594,6 +1594,7 @@ all_litellm_params = [
     "text_completion",
     "caching",
     "mock_response",
+    "mock_timeout",
     "api_key",
     "api_version",
     "prompt_id",
@@ -11,6 +11,7 @@ sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
 import litellm
+import time


 def test_mock_request():

@@ -92,3 +93,86 @@ async def test_async_mock_streaming_request_n_greater_than_1():
     # assert (
     #     complete_response == "LiteLLM is awesome"
     # ), f"Unexpected response got {complete_response}"
+
+
+def test_mock_request_with_mock_timeout():
+    """
+    Allow user to set 'mock_timeout = True', this allows for testing if fallbacks/retries are working on timeouts.
+    """
+    start_time = time.time()
+    with pytest.raises(litellm.Timeout):
+        response = litellm.completion(
+            model="gpt-3.5-turbo",
+            messages=[{"role": "user", "content": "Hey, I'm a mock request"}],
+            timeout=3,
+            mock_timeout=True,
+        )
+    end_time = time.time()
+    assert end_time - start_time >= 3, f"Time taken: {end_time - start_time}"
+
+
+def test_router_mock_request_with_mock_timeout():
+    """
+    Allow user to set 'mock_timeout = True', this allows for testing if fallbacks/retries are working on timeouts.
+    """
+    start_time = time.time()
+    router = litellm.Router(
+        model_list=[
+            {
+                "model_name": "gpt-3.5-turbo",
+                "litellm_params": {
+                    "model": "gpt-3.5-turbo",
+                    "api_key": os.getenv("OPENAI_API_KEY"),
+                },
+            },
+        ],
+    )
+    with pytest.raises(litellm.Timeout):
+        response = router.completion(
+            model="gpt-3.5-turbo",
+            messages=[{"role": "user", "content": "Hey, I'm a mock request"}],
+            timeout=3,
+            mock_timeout=True,
+        )
+        print(response)
+    end_time = time.time()
+    assert end_time - start_time >= 3, f"Time taken: {end_time - start_time}"
+
+
+def test_router_mock_request_with_mock_timeout_with_fallbacks():
+    """
+    Allow user to set 'mock_timeout = True', this allows for testing if fallbacks/retries are working on timeouts.
+    """
+    litellm.set_verbose = True
+    start_time = time.time()
+    router = litellm.Router(
+        model_list=[
+            {
+                "model_name": "gpt-3.5-turbo",
+                "litellm_params": {
+                    "model": "gpt-3.5-turbo",
+                    "api_key": os.getenv("OPENAI_API_KEY"),
+                },
+            },
+            {
+                "model_name": "azure-gpt",
+                "litellm_params": {
+                    "model": "azure/chatgpt-v-2",
+                    "api_key": os.getenv("AZURE_API_KEY"),
+                    "api_base": os.getenv("AZURE_API_BASE"),
+                },
+            },
+        ],
+        fallbacks=[{"gpt-3.5-turbo": ["azure-gpt"]}],
+    )
+    response = router.completion(
+        model="gpt-3.5-turbo",
+        messages=[{"role": "user", "content": "Hey, I'm a mock request"}],
+        timeout=3,
+        num_retries=1,
+        mock_timeout=True,
+    )
+    print(response)
+    end_time = time.time()
+    assert end_time - start_time >= 3, f"Time taken: {end_time - start_time}"
+    assert "gpt-35-turbo" in response.model, "Model should be azure gpt-35-turbo"
@@ -344,11 +344,18 @@ async def test_get_remaining_model_group_usage():
         ]
     )
     for _ in range(2):
-        await router.acompletion(
+        resp = await router.acompletion(
             model="gemini/gemini-1.5-flash",
             messages=[{"role": "user", "content": "Hello, how are you?"}],
             mock_response="Hello, I'm good.",
         )
+        assert (
+            "x-ratelimit-remaining-tokens" in resp._hidden_params["additional_headers"]
+        )
+        assert (
+            "x-ratelimit-remaining-requests"
+            in resp._hidden_params["additional_headers"]
+        )
         await asyncio.sleep(1)

     remaining_usage = await router.get_remaining_model_group_usage(