Merge pull request #4046 from BerriAI/litellm_router_order
feat(router.py): enable setting 'order' for a deployment in model list
Commit 1742141fb6

3 changed files with 69 additions and 0 deletions
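For context before the diff: the change adds an optional 'order' key under litellm_params. A minimal sketch of how a caller might use it (model names and the use of enable_pre_call_checks mirror the test added in this PR; the API keys are placeholders, not part of this diff):

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "gpt-4o",
                "api_key": "sk-primary-placeholder",  # placeholder, not from this diff
                "order": 1,  # preferred deployment
            },
        },
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "gpt-4o",
                "api_key": "sk-backup-placeholder",  # placeholder, not from this diff
                "order": 2,  # used when no order=1 deployment is available
            },
        },
    ],
    enable_pre_call_checks=True,  # the test below enables pre-call checks, where the new filter runs
)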
@@ -4050,6 +4050,12 @@ class Router:
        for idx in reversed(invalid_model_indices):
            _returned_deployments.pop(idx)

        ## ORDER FILTERING ## -> if user set 'order' in deployments, return deployments with lowest order (e.g. order=1 > order=2)
        if len(_returned_deployments) > 0:
            _returned_deployments = litellm.utils._get_order_filtered_deployments(
                _returned_deployments
            )

        return _returned_deployments

    def _common_checks_available_deployment(
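An illustrative sketch of what the new block does to _returned_deployments at this point; the deployment dicts below are toy data, and the helper is the one added to litellm.utils in the last hunk of this diff:

import litellm

_returned_deployments = [
    {"litellm_params": {"model": "gpt-4o", "order": 2}, "model_info": {"id": "2"}},
    {"litellm_params": {"model": "gpt-4o", "order": 1}, "model_info": {"id": "1"}},
]

# Only the deployment(s) with the lowest 'order' survive the filter.
filtered = litellm.utils._get_order_filtered_deployments(_returned_deployments)
assert [d["model_info"]["id"] for d in filtered] == ["1"]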
@@ -38,6 +38,48 @@ def test_router_sensitive_keys():
        assert "special-key" not in str(e)


def test_router_order():
    """
    Asserts for 2 models in a model group, model with order=1 always called first
    """
    router = Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {
                    "model": "gpt-4o",
                    "api_key": os.getenv("OPENAI_API_KEY"),
                    "mock_response": "Hello world",
                    "order": 1,
                },
                "model_info": {"id": "1"},
            },
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {
                    "model": "gpt-4o",
                    "api_key": "bad-key",
                    "mock_response": Exception("this is a bad key"),
                    "order": 2,
                },
                "model_info": {"id": "2"},
            },
        ],
        num_retries=0,
        allowed_fails=0,
        enable_pre_call_checks=True,
    )

    for _ in range(100):
        response = router.completion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hey, how's it going?"}],
        )

        assert isinstance(response, litellm.ModelResponse)
        assert response._hidden_params["model_id"] == "1"


@pytest.mark.parametrize("num_retries", [None, 2])
@pytest.mark.parametrize("max_retries", [None, 4])
def test_router_num_retries_init(num_retries, max_retries):
@@ -6197,6 +6197,27 @@ def calculate_max_parallel_requests(
    return None


def _get_order_filtered_deployments(healthy_deployments: List[Dict]) -> List:
    min_order = min(
        (
            deployment["litellm_params"]["order"]
            for deployment in healthy_deployments
            if "order" in deployment["litellm_params"]
        ),
        default=None,
    )

    if min_order is not None:
        filtered_deployments = [
            deployment
            for deployment in healthy_deployments
            if deployment["litellm_params"].get("order") == min_order
        ]

        return filtered_deployments
    return healthy_deployments


def _get_model_region(
    custom_llm_provider: str, litellm_params: LiteLLM_Params
) -> Optional[str]:
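Two edge cases of the helper above are worth noting; this is a sketch with toy deployment dicts, and the behavior follows directly from the code in this hunk:

import litellm

# No deployment sets 'order' -> min_order is None and the list is returned unchanged.
no_order = [{"litellm_params": {"model": "gpt-4o"}}]
assert litellm.utils._get_order_filtered_deployments(no_order) == no_order

# Mixed: once any deployment sets 'order', deployments without it are filtered out,
# since .get("order") returns None, which never equals min_order.
mixed = [
    {"litellm_params": {"model": "gpt-4o", "order": 1}},
    {"litellm_params": {"model": "gpt-4o"}},
]
assert len(litellm.utils._get_order_filtered_deployments(mixed)) == 1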