diff --git a/litellm/router.py b/litellm/router.py
index 0015af4db..6ef75b76a 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -4050,6 +4050,12 @@ class Router:
         for idx in reversed(invalid_model_indices):
             _returned_deployments.pop(idx)
 
+        ## ORDER FILTERING ## -> if user set 'order' in deployments, return deployments with lowest order (e.g. order=1 > order=2)
+        if len(_returned_deployments) > 0:
+            _returned_deployments = litellm.utils._get_order_filtered_deployments(
+                _returned_deployments
+            )
+
         return _returned_deployments
 
     def _common_checks_available_deployment(
diff --git a/litellm/tests/test_router.py b/litellm/tests/test_router.py
index d76dec25c..02bf9a16b 100644
--- a/litellm/tests/test_router.py
+++ b/litellm/tests/test_router.py
@@ -38,6 +38,48 @@ def test_router_sensitive_keys():
         assert "special-key" not in str(e)
 
 
+def test_router_order():
+    """
+    Asserts for 2 models in a model group, model with order=1 always called first
+    """
+    router = Router(
+        model_list=[
+            {
+                "model_name": "gpt-3.5-turbo",
+                "litellm_params": {
+                    "model": "gpt-4o",
+                    "api_key": os.getenv("OPENAI_API_KEY"),
+                    "mock_response": "Hello world",
+                    "order": 1,
+                },
+                "model_info": {"id": "1"},
+            },
+            {
+                "model_name": "gpt-3.5-turbo",
+                "litellm_params": {
+                    "model": "gpt-4o",
+                    "api_key": "bad-key",
+                    "mock_response": Exception("this is a bad key"),
+                    "order": 2,
+                },
+                "model_info": {"id": "2"},
+            },
+        ],
+        num_retries=0,
+        allowed_fails=0,
+        enable_pre_call_checks=True,
+    )
+
+    for _ in range(100):
+        response = router.completion(
+            model="gpt-3.5-turbo",
+            messages=[{"role": "user", "content": "Hey, how's it going?"}],
+        )
+
+        assert isinstance(response, litellm.ModelResponse)
+        assert response._hidden_params["model_id"] == "1"
+
+
 @pytest.mark.parametrize("num_retries", [None, 2])
 @pytest.mark.parametrize("max_retries", [None, 4])
 def test_router_num_retries_init(num_retries, max_retries):
diff --git a/litellm/utils.py b/litellm/utils.py
index 2c3f884e9..88958c437 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -6197,6 +6197,27 @@ def calculate_max_parallel_requests(
         return None
 
 
+def _get_order_filtered_deployments(healthy_deployments: List[Dict]) -> List:
+    min_order = min(
+        (
+            deployment["litellm_params"]["order"]
+            for deployment in healthy_deployments
+            if "order" in deployment["litellm_params"]
+        ),
+        default=None,
+    )
+
+    if min_order is not None:
+        filtered_deployments = [
+            deployment
+            for deployment in healthy_deployments
+            if deployment["litellm_params"].get("order") == min_order
+        ]
+
+        return filtered_deployments
+    return healthy_deployments
+
+
 def _get_model_region(
     custom_llm_provider: str, litellm_params: LiteLLM_Params
 ) -> Optional[str]:
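
A minimal standalone sketch of the new order-filtering behavior, for illustration only (not part of the patch). It assumes deployment dicts follow the same shape used in test_router_order above; the function body mirrors the _get_order_filtered_deployments helper added to litellm/utils.py in the diff:

from typing import Dict, List

def _get_order_filtered_deployments(healthy_deployments: List[Dict]) -> List:
    # Find the lowest 'order' among deployments that set one (None if none do).
    min_order = min(
        (
            d["litellm_params"]["order"]
            for d in healthy_deployments
            if "order" in d["litellm_params"]
        ),
        default=None,
    )
    if min_order is not None:
        # Keep only the deployments tied for the lowest order; deployments
        # without an 'order' are also dropped in this branch.
        return [
            d
            for d in healthy_deployments
            if d["litellm_params"].get("order") == min_order
        ]
    # No deployment sets 'order' -> the list is returned unchanged.
    return healthy_deployments

deployments = [
    {"model_info": {"id": "1"}, "litellm_params": {"order": 1}},
    {"model_info": {"id": "2"}, "litellm_params": {"order": 2}},
    {"model_info": {"id": "3"}, "litellm_params": {}},
]
# Only the order=1 deployment survives the filter.
assert [d["model_info"]["id"] for d in _get_order_filtered_deployments(deployments)] == ["1"]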