Litellm dev 01 20 2025 p3 (#7890)

* fix(router.py): pass stream timeout correctly for non openai / azure models

Fixes https://github.com/BerriAI/litellm/issues/7870

* test(test_router_timeout.py): add test for streaming

* test(test_router_timeout.py): add unit testing for new router functions

* docs(ollama.md): link to section on calling ollama within docker container

* test: remove redundant test

* test: fix test to include timeout value

* docs(config_settings.md): document new router settings param
This commit is contained in:
Krish Dholakia 2025-01-20 21:46:36 -08:00 committed by GitHub
parent 4c1d4acabc
commit 94c9f76767
6 changed files with 197 additions and 9 deletions

View file

@ -356,8 +356,6 @@ for chunk in response:
} }
``` ```
## Support / talk with founders ## Calling Docker Container (host.docker.internal)
- [Schedule Demo 👋](https://calendly.com/d/4mp-gd3-k5k/berriai-1-1-onboarding-litellm-hosted-version)
- [Community Discord 💭](https://discord.gg/wuPM9dRgDw) [Follow these instructions](https://github.com/BerriAI/litellm/issues/1517#issuecomment-1922022209/)
- Our numbers 📞 +1 (770) 8783-106 / +1 (412) 618-6238
- Our emails ✉️ ishaan@berri.ai / krrish@berri.ai

View file

@ -266,7 +266,8 @@ router_settings:
| polling_interval | (Optional[float]) | frequency of polling queue. Only for '.scheduler_acompletion()'. Default is 3ms. | | polling_interval | (Optional[float]) | frequency of polling queue. Only for '.scheduler_acompletion()'. Default is 3ms. |
| max_fallbacks | Optional[int] | The maximum number of fallbacks to try before exiting the call. Defaults to 5. | | max_fallbacks | Optional[int] | The maximum number of fallbacks to try before exiting the call. Defaults to 5. |
| default_litellm_params | Optional[dict] | The default litellm parameters to add to all requests (e.g. `temperature`, `max_tokens`). | | default_litellm_params | Optional[dict] | The default litellm parameters to add to all requests (e.g. `temperature`, `max_tokens`). |
| timeout | Optional[float] | The default timeout for a request. | | timeout | Optional[float] | The default timeout for a request. Default is 10 minutes. |
| stream_timeout | Optional[float] | The default timeout for a streaming request. If not set, the 'timeout' value is used. |
| debug_level | Literal["DEBUG", "INFO"] | The debug level for the logging library in the router. Defaults to "INFO". | | debug_level | Literal["DEBUG", "INFO"] | The debug level for the logging library in the router. Defaults to "INFO". |
| client_ttl | int | Time-to-live for cached clients in seconds. Defaults to 3600. | | client_ttl | int | Time-to-live for cached clients in seconds. Defaults to 3600. |
| cache_kwargs | dict | Additional keyword arguments for the cache initialization. | | cache_kwargs | dict | Additional keyword arguments for the cache initialization. |

View file

@ -73,11 +73,14 @@ def make_sync_call(
logging_obj, logging_obj,
streaming_decoder: Optional[CustomStreamingDecoder] = None, streaming_decoder: Optional[CustomStreamingDecoder] = None,
fake_stream: bool = False, fake_stream: bool = False,
timeout: Optional[Union[float, httpx.Timeout]] = None,
): ):
if client is None: if client is None:
client = litellm.module_level_client # Create a new client if none provided client = litellm.module_level_client # Create a new client if none provided
response = client.post(api_base, headers=headers, data=data, stream=not fake_stream) response = client.post(
api_base, headers=headers, data=data, stream=not fake_stream, timeout=timeout
)
if response.status_code != 200: if response.status_code != 200:
raise OpenAILikeError(status_code=response.status_code, message=response.read()) raise OpenAILikeError(status_code=response.status_code, message=response.read())
@ -352,6 +355,7 @@ class OpenAILikeChatHandler(OpenAILikeBase):
logging_obj=logging_obj, logging_obj=logging_obj,
streaming_decoder=streaming_decoder, streaming_decoder=streaming_decoder,
fake_stream=fake_stream, fake_stream=fake_stream,
timeout=timeout,
) )
# completion_stream.__iter__() # completion_stream.__iter__()
return CustomStreamWrapper( return CustomStreamWrapper(

View file

@ -177,6 +177,7 @@ class Router:
int int
] = None, # max fallbacks to try before exiting the call. Defaults to 5. ] = None, # max fallbacks to try before exiting the call. Defaults to 5.
timeout: Optional[float] = None, timeout: Optional[float] = None,
stream_timeout: Optional[float] = None,
default_litellm_params: Optional[ default_litellm_params: Optional[
dict dict
] = None, # default params for Router.chat.completion.create ] = None, # default params for Router.chat.completion.create
@ -402,6 +403,7 @@ class Router:
self.max_fallbacks = litellm.ROUTER_MAX_FALLBACKS self.max_fallbacks = litellm.ROUTER_MAX_FALLBACKS
self.timeout = timeout or litellm.request_timeout self.timeout = timeout or litellm.request_timeout
self.stream_timeout = stream_timeout
self.retry_after = retry_after self.retry_after = retry_after
self.routing_strategy = routing_strategy self.routing_strategy = routing_strategy
@ -1045,8 +1047,23 @@ class Router:
return model_client return model_client
def _get_timeout(self, kwargs: dict, data: dict) -> Optional[Union[float, int]]: def _get_stream_timeout(
"""Helper to get timeout from kwargs or deployment params""" self, kwargs: dict, data: dict
) -> Optional[Union[float, int]]:
"""Helper to get stream timeout from kwargs or deployment params"""
return (
kwargs.get("stream_timeout", None) # the params dynamically set by user
or data.get(
"stream_timeout", None
) # timeout set on litellm_params for this deployment
or self.stream_timeout # timeout set on router
or self.default_litellm_params.get("stream_timeout", None)
)
def _get_non_stream_timeout(
self, kwargs: dict, data: dict
) -> Optional[Union[float, int]]:
"""Helper to get non-stream timeout from kwargs or deployment params"""
timeout = ( timeout = (
kwargs.get("timeout", None) # the params dynamically set by user kwargs.get("timeout", None) # the params dynamically set by user
or kwargs.get("request_timeout", None) # the params dynamically set by user or kwargs.get("request_timeout", None) # the params dynamically set by user
@ -1059,7 +1076,17 @@ class Router:
or self.timeout # timeout set on router or self.timeout # timeout set on router
or self.default_litellm_params.get("timeout", None) or self.default_litellm_params.get("timeout", None)
) )
return timeout
def _get_timeout(self, kwargs: dict, data: dict) -> Optional[Union[float, int]]:
"""Helper to get timeout from kwargs or deployment params"""
timeout: Optional[Union[float, int]] = None
if kwargs.get("stream", False):
timeout = self._get_stream_timeout(kwargs=kwargs, data=data)
if timeout is None:
timeout = self._get_non_stream_timeout(
kwargs=kwargs, data=data
) # default to this if no stream specific timeout set
return timeout return timeout
async def abatch_completion( async def abatch_completion(

View file

@ -381,6 +381,7 @@ def test_completions_streaming_with_sync_http_handler(monkeypatch):
}, },
data=ANY, data=ANY,
stream=True, stream=True,
timeout=ANY,
) )
actual_data = json.loads( actual_data = json.loads(

View file

@ -186,3 +186,160 @@ def test_router_timeout_with_retries_anthropic_model(num_retries, expected_call_
pass pass
assert mock_client.call_count == expected_call_count assert mock_client.call_count == expected_call_count
@pytest.mark.parametrize(
"model",
[
"llama3",
"bedrock-anthropic",
],
)
def test_router_stream_timeout(model):
import os
from dotenv import load_dotenv
import litellm
from litellm.router import Router, RetryPolicy, AllowedFailsPolicy
litellm.set_verbose = True
model_list = [
{
"model_name": "llama3",
"litellm_params": {
"model": "watsonx/meta-llama/llama-3-1-8b-instruct",
"api_base": os.getenv("WATSONX_URL_US_SOUTH"),
"api_key": os.getenv("WATSONX_API_KEY"),
"project_id": os.getenv("WATSONX_PROJECT_ID_US_SOUTH"),
"timeout": 0.01,
"stream_timeout": 0.0000001,
},
},
{
"model_name": "bedrock-anthropic",
"litellm_params": {
"model": "bedrock/anthropic.claude-3-5-haiku-20241022-v1:0",
"timeout": 0.01,
"stream_timeout": 0.0000001,
},
},
{
"model_name": "llama3-fallback",
"litellm_params": {
"model": "gpt-3.5-turbo",
"api_key": os.getenv("OPENAI_API_KEY"),
},
},
]
# Initialize router with retry and timeout settings
router = Router(
model_list=model_list,
fallbacks=[
{"llama3": ["llama3-fallback"]},
{"bedrock-anthropic": ["llama3-fallback"]},
],
routing_strategy="latency-based-routing", # 👈 set routing strategy
retry_policy=RetryPolicy(
TimeoutErrorRetries=1, # Number of retries for timeout errors
RateLimitErrorRetries=3,
BadRequestErrorRetries=2,
),
allowed_fails_policy=AllowedFailsPolicy(
TimeoutErrorAllowedFails=2, # Number of timeouts allowed before cooldown
RateLimitErrorAllowedFails=2,
),
cooldown_time=120, # Cooldown time in seconds,
set_verbose=True,
routing_strategy_args={"lowest_latency_buffer": 0.5},
)
print("this fall back does NOT work:")
response = router.completion(
model=model,
messages=[
{"role": "system", "content": "You are a helpful assistant"},
{"role": "user", "content": "write a 100 word story about a cat"},
],
temperature=0.6,
max_tokens=500,
stream=True,
)
t = 0
for chunk in response:
assert "llama" not in chunk.model
chunk_text = chunk.choices[0].delta.content or ""
print(chunk_text)
t += 1
if t > 10:
break
@pytest.mark.parametrize(
"stream",
[
True,
False,
],
)
def test_unit_test_streaming_timeout(stream):
import os
from dotenv import load_dotenv
import litellm
from litellm.router import Router, RetryPolicy, AllowedFailsPolicy
litellm.set_verbose = True
model_list = [
{
"model_name": "llama3",
"litellm_params": {
"model": "watsonx/meta-llama/llama-3-1-8b-instruct",
"api_base": os.getenv("WATSONX_URL_US_SOUTH"),
"api_key": os.getenv("WATSONX_API_KEY"),
"project_id": os.getenv("WATSONX_PROJECT_ID_US_SOUTH"),
"timeout": 0.01,
"stream_timeout": 0.0000001,
},
},
{
"model_name": "bedrock-anthropic",
"litellm_params": {
"model": "bedrock/anthropic.claude-3-5-haiku-20241022-v1:0",
"timeout": 0.01,
"stream_timeout": 0.0000001,
},
},
{
"model_name": "llama3-fallback",
"litellm_params": {
"model": "gpt-3.5-turbo",
"api_key": os.getenv("OPENAI_API_KEY"),
},
},
]
router = Router(model_list=model_list)
stream_timeout = 0.0000001
normal_timeout = 0.01
args = {
"kwargs": {"stream": stream},
"data": {"timeout": normal_timeout, "stream_timeout": stream_timeout},
}
assert router._get_stream_timeout(**args) == stream_timeout
assert router._get_non_stream_timeout(**args) == normal_timeout
stream_timeout_val = router._get_timeout(
kwargs={"stream": stream},
data={"timeout": normal_timeout, "stream_timeout": stream_timeout},
)
if stream:
assert stream_timeout_val == stream_timeout
else:
assert stream_timeout_val == normal_timeout