forked from phoenix/litellm-mirror
Merge pull request #3302 from BerriAI/litellm_default_router_retries
fix(router.py): fix default retry logic
Commit 7502cb1aa8
8 changed files with 73 additions and 36 deletions
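In short: retry handling moves into the Router itself. The Router's default num_retries changes from 0 to openai.DEFAULT_MAX_RETRIES, and the per-deployment OpenAI/Azure clients it builds now default to max_retries=0, so retries are not performed twice. A minimal usage sketch of the new default (the model name and API key below are placeholders, not taken from this diff):

import openai
from litellm import Router

# Placeholder deployment; any valid model_list entry works for this check.
router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {"model": "gpt-3.5-turbo", "api_key": "sk-placeholder"},
        }
    ]
)

# With no num_retries argument and no litellm.num_retries override,
# the router now falls back to the OpenAI SDK default instead of 0.
assert router.num_retries == openai.DEFAULT_MAX_RETRIES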
.gitignore (vendored): 1 addition

@@ -51,3 +51,4 @@ loadtest_kub.yaml
 litellm/proxy/_new_secret_config.yaml
 litellm/proxy/_new_secret_config.yaml
 litellm/proxy/_super_secret_config.yaml
+litellm/proxy/_super_secret_config.yaml
litellm/llms/openai.py

@@ -447,6 +447,7 @@ class OpenAIChatCompletion(BaseLLM):
                 )
             else:
                 openai_aclient = client
+
             ## LOGGING
             logging_obj.pre_call(
                 input=data["messages"],
litellm/proxy/_super_secret_config.yaml

@@ -1,23 +1,8 @@
 model_list:
-- model_name: text-embedding-3-small
-  litellm_params:
-    model: text-embedding-3-small
-- model_name: whisper
-  litellm_params:
-    model: azure/azure-whisper
-    api_version: 2024-02-15-preview
-    api_base: os.environ/AZURE_EUROPE_API_BASE
-    api_key: os.environ/AZURE_EUROPE_API_KEY
-  model_info:
-    mode: audio_transcription
 - litellm_params:
-    model: gpt-4
-  model_name: gpt-4
-- model_name: azure-mistral
-  litellm_params:
-    model: azure/mistral-large-latest
-    api_base: https://Mistral-large-nmefg-serverless.eastus2.inference.ai.azure.com
-    api_key: os.environ/AZURE_MISTRAL_API_KEY
-
-# litellm_settings:
-#   cache: True
+    api_base: http://0.0.0.0:8080
+    api_key: my-fake-key
+    model: openai/my-fake-model
+  model_name: fake-openai-endpoint
+router_settings:
+  num_retries: 0
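For the proxy, the router_settings block in this config is passed through to the Router constructor, so num_retries: 0 keeps the local test proxy from retrying against the fake endpoint. A rough Python equivalent of the resulting router (the router_settings-to-kwargs mapping is assumed here, not shown in this diff):

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "fake-openai-endpoint",
            "litellm_params": {
                "model": "openai/my-fake-model",
                "api_key": "my-fake-key",
                "api_base": "http://0.0.0.0:8080",
            },
        }
    ],
    num_retries=0,  # from router_settings in the YAML above
)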
litellm/router.py

@@ -50,7 +50,7 @@ class Router:
     model_names: List = []
     cache_responses: Optional[bool] = False
     default_cache_time_seconds: int = 1 * 60 * 60  # 1 hour
-    num_retries: int = 0
+    num_retries: int = openai.DEFAULT_MAX_RETRIES
     tenacity = None
     leastbusy_logger: Optional[LeastBusyLoggingHandler] = None
     lowesttpm_logger: Optional[LowestTPMLoggingHandler] = None
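openai.DEFAULT_MAX_RETRIES comes from the OpenAI Python SDK, where it is 2, so the class-level default moves from zero retries to the SDK's standard two. A quick check, assuming the openai v1.x package is installed:

import openai

print(openai.DEFAULT_MAX_RETRIES)  # 2 in openai-python v1.x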
@@ -70,7 +70,7 @@ class Router:
         ] = None,  # if you want to cache across model groups
         client_ttl: int = 3600,  # ttl for cached clients - will re-initialize after this time in seconds
         ## RELIABILITY ##
-        num_retries: int = 0,
+        num_retries: Optional[int] = None,
         timeout: Optional[float] = None,
         default_litellm_params={},  # default params for Router.chat.completion.create
         default_max_parallel_requests: Optional[int] = None,
@@ -229,7 +229,12 @@ class Router:
         self.failed_calls = (
             InMemoryCache()
         )  # cache to track failed call per deployment, if num failed calls within 1 minute > allowed fails, then add it to cooldown
-        self.num_retries = num_retries or litellm.num_retries or 0
+
+        if num_retries is not None:
+            self.num_retries = num_retries
+        elif litellm.num_retries is not None:
+            self.num_retries = litellm.num_retries
+
         self.timeout = timeout or litellm.request_timeout
 
         self.retry_after = retry_after
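The constructor now resolves retries with an explicit precedence: a num_retries argument wins, then the module-level litellm.num_retries, and otherwise the class default (openai.DEFAULT_MAX_RETRIES) is left in place. A small sketch of that precedence (placeholder model list and key; setting litellm.num_retries globally is for illustration only):

import litellm
from litellm import Router

MODEL_LIST = [
    {
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {"model": "gpt-3.5-turbo", "api_key": "sk-placeholder"},
    }
]

litellm.num_retries = 3                            # module-level fallback
r1 = Router(model_list=MODEL_LIST, num_retries=1)  # explicit argument wins -> 1
r2 = Router(model_list=MODEL_LIST)                 # falls back to litellm.num_retries -> 3
print(r1.num_retries, r2.num_retries)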
@@ -428,6 +433,7 @@ class Router:
             kwargs["messages"] = messages
             kwargs["original_function"] = self._acompletion
             kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
+
             timeout = kwargs.get("request_timeout", self.timeout)
             kwargs.setdefault("metadata", {}).update({"model_group": model})
 
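Since acompletion() reads num_retries from the call's kwargs before falling back to self.num_retries, callers can still override retries per request. A sketch (the model name and message are placeholders):

import asyncio
from litellm import Router

async def main(router: Router):
    # num_retries here overrides router.num_retries for this call only.
    return await router.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hello"}],
        num_retries=1,
    )

# asyncio.run(main(router))  # with a Router configured as elsewhere in this diff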
@@ -1415,10 +1421,12 @@ class Router:
         context_window_fallbacks = kwargs.pop(
             "context_window_fallbacks", self.context_window_fallbacks
         )
-        verbose_router_logger.debug(
-            f"async function w/ retries: original_function - {original_function}"
-        )
         num_retries = kwargs.pop("num_retries")
+
+        verbose_router_logger.debug(
+            f"async function w/ retries: original_function - {original_function}, num_retries - {num_retries}"
+        )
+
         try:
             # if the function call is successful, no exception will be raised and we'll break out of the loop
             response = await original_function(*args, **kwargs)
@@ -2004,7 +2012,9 @@ class Router:
                 stream_timeout = litellm.get_secret(stream_timeout_env_name)
                 litellm_params["stream_timeout"] = stream_timeout
 
-            max_retries = litellm_params.pop("max_retries", 2)
+            max_retries = litellm_params.pop(
+                "max_retries", 0
+            )  # router handles retry logic
             if isinstance(max_retries, str) and max_retries.startswith("os.environ/"):
                 max_retries_env_name = max_retries.replace("os.environ/", "")
                 max_retries = litellm.get_secret(max_retries_env_name)
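With this hunk, the clients built for each deployment get max_retries=0 by default, leaving retries to the router. A deployment can still opt into client-level retries explicitly through litellm_params, as the new test below exercises; a condensed sketch (placeholder names and key):

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "gpt-3.5-turbo",
                "api_key": "sk-placeholder",
                "max_retries": 4,  # client-level retries for this deployment only
            },
        }
    ],
    num_retries=0,  # optional: let the client retry instead of the router
)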
litellm/tests/test_router.py

@@ -1,7 +1,7 @@
 #### What this tests ####
 # This tests litellm router
 
-import sys, os, time
+import sys, os, time, openai
 import traceback, asyncio
 import pytest
 
@@ -19,6 +19,44 @@ import os, httpx
 load_dotenv()
 
 
+@pytest.mark.parametrize("num_retries", [None, 2])
+@pytest.mark.parametrize("max_retries", [None, 4])
+def test_router_num_retries_init(num_retries, max_retries):
+    """
+    - test when num_retries set v/s not
+    - test client value when max retries set v/s not
+    """
+    router = Router(
+        model_list=[
+            {
+                "model_name": "gpt-3.5-turbo",  # openai model name
+                "litellm_params": {  # params for litellm completion/embedding call
+                    "model": "azure/chatgpt-v-2",
+                    "api_key": "bad-key",
+                    "api_version": os.getenv("AZURE_API_VERSION"),
+                    "api_base": os.getenv("AZURE_API_BASE"),
+                    "max_retries": max_retries,
+                },
+                "model_info": {"id": 12345},
+            },
+        ],
+        num_retries=num_retries,
+    )
+
+    if num_retries is not None:
+        assert router.num_retries == num_retries
+    else:
+        assert router.num_retries == openai.DEFAULT_MAX_RETRIES
+
+    model_client = router._get_client(
+        {"model_info": {"id": 12345}}, client_type="async", kwargs={}
+    )
+
+    if max_retries is not None:
+        assert getattr(model_client, "max_retries") == max_retries
+    else:
+        assert getattr(model_client, "max_retries") == 0
+
 @pytest.mark.parametrize(
     "timeout", [10, 1.0, httpx.Timeout(timeout=300.0, connect=20.0)]
 )
litellm/tests/test_router_fallbacks.py

@@ -258,6 +258,7 @@ def test_sync_fallbacks_embeddings():
         model_list=model_list,
         fallbacks=[{"bad-azure-embedding-model": ["good-azure-embedding-model"]}],
         set_verbose=False,
+        num_retries=0,
     )
     customHandler = MyCustomHandler()
     litellm.callbacks = [customHandler]
@@ -393,7 +394,7 @@ def test_dynamic_fallbacks_sync():
         },
     ]
 
-    router = Router(model_list=model_list, set_verbose=True)
+    router = Router(model_list=model_list, set_verbose=True, num_retries=0)
     kwargs = {}
     kwargs["model"] = "azure/gpt-3.5-turbo"
     kwargs["messages"] = [{"role": "user", "content": "Hey, how's it going?"}]
litellm/tests/test_timeout.py

@@ -78,7 +78,8 @@ def test_hanging_request_azure():
                 "model_name": "openai-gpt",
                 "litellm_params": {"model": "gpt-3.5-turbo"},
             },
-        ]
+        ],
+        num_retries=0,
     )
 
     encoded = litellm.utils.encode(model="gpt-3.5-turbo", text="blue")[0]
@@ -131,7 +132,8 @@ def test_hanging_request_openai():
                 "model_name": "openai-gpt",
                 "litellm_params": {"model": "gpt-3.5-turbo"},
             },
-        ]
+        ],
+        num_retries=0,
     )
 
     encoded = litellm.utils.encode(model="gpt-3.5-turbo", text="blue")[0]
|
@ -189,6 +191,7 @@ def test_timeout_streaming():
|
||||||
# test_timeout_streaming()
|
# test_timeout_streaming()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="local test")
|
||||||
def test_timeout_ollama():
|
def test_timeout_ollama():
|
||||||
# this Will Raise a timeout
|
# this Will Raise a timeout
|
||||||
import litellm
|
import litellm
|
||||||
|
|
|
litellm/types/router.py

@@ -110,7 +110,7 @@ class LiteLLM_Params(BaseModel):
     stream_timeout: Optional[Union[float, str]] = (
         None  # timeout when making stream=True calls, if str, pass in as os.environ/
     )
-    max_retries: int = 2  # follows openai default of 2
+    max_retries: Optional[int] = None
     organization: Optional[str] = None  # for openai orgs
     ## VERTEX AI ##
     vertex_project: Optional[str] = None
@@ -148,9 +148,7 @@ class LiteLLM_Params(BaseModel):
         args.pop("self", None)
         args.pop("params", None)
         args.pop("__class__", None)
-        if max_retries is None:
-            max_retries = 2
-        elif isinstance(max_retries, str):
+        if max_retries is not None and isinstance(max_retries, str):
             max_retries = int(max_retries)  # cast to int
         super().__init__(max_retries=max_retries, **args, **params)
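After these two hunks, LiteLLM_Params no longer forces a default of 2: max_retries stays None unless provided, and string values (e.g. read from the environment) are still cast to int. A sketch of the resulting behavior (assumes LiteLLM_Params can be constructed directly from litellm.types.router with just a model):

from litellm.types.router import LiteLLM_Params

p_default = LiteLLM_Params(model="gpt-3.5-turbo")                    # max_retries stays None
p_from_str = LiteLLM_Params(model="gpt-3.5-turbo", max_retries="4")  # cast to int

print(p_default.max_retries, p_from_str.max_retries)  # None 4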