From 1bb1f70a478138ad3b6570ed0d9b79ad9cf4df69 Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Tue, 17 Sep 2024 20:24:28 -0700
Subject: [PATCH] [Fix] Router/ Proxy - Tag Based routing, raise correct error
 when no deployments found and tag filtering is on (#5745)

* fix tag routing - raise correct error when no model with tag based routing

* fix error string from tag based routing

* test router tag based routing

* raise 401 error when no tags available for deployment

* linting fix
---
 litellm/proxy/_types.py                      |  4 +-
 litellm/proxy/proxy_config.yaml              |  3 ++
 litellm/router.py                            |  1 +
 litellm/router_strategy/tag_based_routing.py | 13 ++++-
 litellm/tests/test_router_tag_routing.py     | 57 ++++++++++++++++++++
 litellm/types/router.py                      |  3 ++
 6 files changed, 79 insertions(+), 2 deletions(-)

diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py
index c75ea4760..adcd3f89d 100644
--- a/litellm/proxy/_types.py
+++ b/litellm/proxy/_types.py
@@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union
 from pydantic import BaseModel, ConfigDict, Extra, Field, Json, model_validator
 from typing_extensions import Annotated, TypedDict
 
-from litellm.types.router import UpdateRouterConfig
+from litellm.types.router import RouterErrors, UpdateRouterConfig
 from litellm.types.utils import ProviderField
 
 if TYPE_CHECKING:
@@ -1826,6 +1826,8 @@ class ProxyException(Exception):
             or "No deployments available" in self.message
         ):
             self.code = "429"
+        elif RouterErrors.no_deployments_with_tag_routing.value in self.message:
+            self.code = "401"
 
     def to_dict(self) -> dict:
         """Converts the ProxyException instance to a dictionary."""
diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml
index 1a2dc3cf8..9939e92a9 100644
--- a/litellm/proxy/proxy_config.yaml
+++ b/litellm/proxy/proxy_config.yaml
@@ -19,9 +19,12 @@ model_list:
       model: openai/429
       api_key: fake-key
       api_base: https://exampleopenaiendpoint-production.up.railway.app
+      tags: ["fake"]
 
 general_settings:
   master_key: sk-1234
 
+router_settings:
+  enable_tag_filtering: true
 
diff --git a/litellm/router.py b/litellm/router.py
index d31646203..ec6159da4 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -5180,6 +5180,7 @@ class Router:
             # check if user wants to do tag based routing
             healthy_deployments = await get_deployments_for_tag(  # type: ignore
                 llm_router_instance=self,
+                model=model,
                 request_kwargs=request_kwargs,
                 healthy_deployments=healthy_deployments,
             )
diff --git a/litellm/router_strategy/tag_based_routing.py b/litellm/router_strategy/tag_based_routing.py
index deda1bd77..9f8cd9ac5 100644
--- a/litellm/router_strategy/tag_based_routing.py
+++ b/litellm/router_strategy/tag_based_routing.py
@@ -9,7 +9,7 @@ Use this to route requests between Teams
 from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, TypedDict, Union
 
 from litellm._logging import verbose_logger
-from litellm.types.router import DeploymentTypedDict
+from litellm.types.router import DeploymentTypedDict, RouterErrors
 
 if TYPE_CHECKING:
     from litellm.router import Router as _Router
@@ -21,9 +21,15 @@ else:
 
 async def get_deployments_for_tag(
     llm_router_instance: LitellmRouter,
+    model: str,  # used to raise the correct error
     healthy_deployments: Union[List[Any], Dict[Any, Any]],
     request_kwargs: Optional[Dict[Any, Any]] = None,
 ):
+    """
+    Returns a list of deployments that match the requested model and tags in the request.
+
+    Executes tag based filtering based on the tags in request metadata and the tags on the deployments
+    """
     if llm_router_instance.enable_tag_filtering is not True:
         return healthy_deployments
 
@@ -80,6 +86,11 @@
                 )
                 new_healthy_deployments.append(deployment)
 
+        if len(new_healthy_deployments) == 0:
+            raise ValueError(
+                f"{RouterErrors.no_deployments_with_tag_routing.value}. Passed model={model} and tags={request_tags}"
+            )
+
         return new_healthy_deployments
 
     # for Untagged requests use default deployments if set
diff --git a/litellm/tests/test_router_tag_routing.py b/litellm/tests/test_router_tag_routing.py
index f71a9b762..4432db530 100644
--- a/litellm/tests/test_router_tag_routing.py
+++ b/litellm/tests/test_router_tag_routing.py
@@ -160,3 +160,60 @@ async def test_default_tagged_deployments():
     print("response_extra_info: ", response_extra_info)
 
     assert response_extra_info["model_id"] == "default-model"
+
+
+@pytest.mark.asyncio()
+async def test_error_from_tag_routing():
+    """
+    Tests the correct error raised when no deployments found for tag
+    """
+    import logging
+
+    from litellm._logging import verbose_logger
+
+    verbose_logger.setLevel(logging.DEBUG)
+    router = litellm.Router(
+        model_list=[
+            {
+                "model_name": "gpt-4",
+                "litellm_params": {
+                    "model": "gpt-4o",
+                    "api_base": "https://exampleopenaiendpoint-production.up.railway.app/",
+                },
+                "model_info": {"id": "default-model"},
+            },
+            {
+                "model_name": "gpt-4",
+                "litellm_params": {
+                    "model": "gpt-4o",
+                    "api_base": "https://exampleopenaiendpoint-production.up.railway.app/",
+                },
+                "model_info": {"id": "default-model-2"},
+            },
+            {
+                "model_name": "gpt-4",
+                "litellm_params": {
+                    "model": "gpt-4o-mini",
+                    "api_base": "https://exampleopenaiendpoint-production.up.railway.app/",
+                    "tags": ["teamA"],
+                },
+                "model_info": {"id": "very-expensive-model"},
+            },
+        ],
+        enable_tag_filtering=True,
+    )
+
+    try:
+        response = await router.acompletion(
+            model="gpt-4",
+            messages=[{"role": "user", "content": "Tell me a joke."}],
+            metadata={"tags": ["paid"]},
+        )
+
+        pytest.fail("this should have failed - expected it to fail")
+    except Exception as e:
+        from litellm.types.router import RouterErrors
+
+        assert RouterErrors.no_deployments_with_tag_routing.value in str(e)
+        print("got expected exception = ", e)
+        pass
diff --git a/litellm/types/router.py b/litellm/types/router.py
index 8c8c6a3aa..67870c313 100644
--- a/litellm/types/router.py
+++ b/litellm/types/router.py
@@ -414,6 +414,9 @@ class RouterErrors(enum.Enum):
 
     user_defined_ratelimit_error = "Deployment over user-defined ratelimit."
     no_deployments_available = "No deployments available for selected model"
+    no_deployments_with_tag_routing = (
+        "Not allowed to access model due to tags configuration"
+    )
 
 
 class AllowedFailsPolicy(BaseModel):
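For illustration, the behavior introduced by this patch can be exercised against the router SDK directly, outside the test suite. The sketch below is adapted from test_error_from_tag_routing above; the model names, tags, and api_base are placeholder values, and the request is expected to fail because no deployment carries the requested tag. On the proxy, the same error string is mapped to an HTTP 401 by the ProxyException change in litellm/proxy/_types.py.

import asyncio

import litellm
from litellm.types.router import RouterErrors

# A single tagged deployment; with tag filtering enabled, requests carrying
# any other tag should be rejected with the new RouterErrors message.
router = litellm.Router(
    model_list=[
        {
            "model_name": "gpt-4",
            "litellm_params": {
                "model": "gpt-4o-mini",
                "api_base": "https://exampleopenaiendpoint-production.up.railway.app/",
                "tags": ["teamA"],
            },
        },
    ],
    enable_tag_filtering=True,
)


async def main():
    try:
        await router.acompletion(
            model="gpt-4",
            messages=[{"role": "user", "content": "Tell me a joke."}],
            metadata={"tags": ["paid"]},  # no deployment carries the "paid" tag
        )
    except Exception as e:
        # get_deployments_for_tag() raises a ValueError containing the new
        # RouterErrors string; the proxy translates that message into a 401.
        assert RouterErrors.no_deployments_with_tag_routing.value in str(e)
        print("got expected exception:", e)


asyncio.run(main())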