diff --git a/litellm/router.py b/litellm/router.py index 44c02f126..0e693e188 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -145,6 +145,7 @@ class Router: content_policy_fallbacks: List = [], model_group_alias: Optional[dict] = {}, enable_pre_call_checks: bool = False, + enable_tag_filtering: bool = False, retry_after: int = 0, # min time to wait before retrying a failed request retry_policy: Optional[ RetryPolicy @@ -246,6 +247,7 @@ class Router: self.set_verbose = set_verbose self.debug_level = debug_level self.enable_pre_call_checks = enable_pre_call_checks + self.enable_tag_filtering = enable_tag_filtering if self.set_verbose == True: if debug_level == "INFO": verbose_router_logger.setLevel(logging.INFO) @@ -4484,6 +4486,7 @@ class Router: # check if user wants to do tag based routing healthy_deployments = await get_deployments_for_tag( + llm_router_instance=self, request_kwargs=request_kwargs, healthy_deployments=healthy_deployments, ) diff --git a/litellm/router_strategy/tag_based_routing.py b/litellm/router_strategy/tag_based_routing.py index 11bad19a3..2dbc5cb93 100644 --- a/litellm/router_strategy/tag_based_routing.py +++ b/litellm/router_strategy/tag_based_routing.py @@ -2,19 +2,30 @@ Use this to route requests between free and paid tiers """ -from typing import Any, Dict, List, Literal, Optional, TypedDict, Union, cast +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, TypedDict, Union from litellm._logging import verbose_logger from litellm.types.router import DeploymentTypedDict +if TYPE_CHECKING: + from litellm.router import Router as _Router + + LitellmRouter = _Router +else: + LitellmRouter = Any + async def get_deployments_for_tag( + llm_router_instance: LitellmRouter, request_kwargs: Optional[Dict[Any, Any]] = None, healthy_deployments: Optional[Union[List[Any], Dict[Any, Any]]] = None, ): """ if request_kwargs contains {"metadata": {"tier": "free"}} or {"metadata": {"tier": "paid"}}, then routes the request to free/paid tier models """ + if llm_router_instance.enable_tag_filtering is not True: + return healthy_deployments + if request_kwargs is None: verbose_logger.debug( "get_deployments_for_tier: request_kwargs is None returning healthy_deployments: %s", diff --git a/litellm/tests/test_litellm_pre_call_utils.py b/litellm/tests/test_litellm_pre_call_utils.py deleted file mode 100644 index 7f56d693d..000000000 --- a/litellm/tests/test_litellm_pre_call_utils.py +++ /dev/null @@ -1,60 +0,0 @@ -""" -Tests litellm pre_call_utils -""" - -import os -import sys -import traceback -import uuid -from datetime import datetime - -from dotenv import load_dotenv -from fastapi import Request -from fastapi.routing import APIRoute - -from litellm.proxy._types import UserAPIKeyAuth -from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request -from litellm.proxy.proxy_server import ProxyConfig, chat_completion - -load_dotenv() -import io -import os -import time - -import pytest - -# this file is to test litellm/proxy - -sys.path.insert( - 0, os.path.abspath("../..") -) # Adds the parent directory to the system path - - -@pytest.mark.parametrize("tier", ["free", "paid"]) -@pytest.mark.asyncio() -async def test_adding_key_tier_to_request_metadata(tier): - """ - Tests if we can add tier: free/paid from key metadata to the request metadata - """ - data = {} - - api_route = APIRoute(path="/chat/completions", endpoint=chat_completion) - request = Request( - { - "type": "http", - "method": "POST", - "route": api_route, - "path": api_route.path, - "headers": [], - } - ) - new_data = await add_litellm_data_to_request( - data=data, - request=request, - user_api_key_dict=UserAPIKeyAuth(metadata={"tier": tier}), - proxy_config=ProxyConfig(), - ) - - print("new_data", new_data) - - assert new_data["metadata"]["tier"] == tier