diff --git a/litellm/router.py b/litellm/router.py
index 487d5fd6a..44c02f126 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -47,12 +47,12 @@ from litellm.assistants.main import AssistantDeleted
 from litellm.caching import DualCache, InMemoryCache, RedisCache
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.llms.azure import get_azure_ad_token_from_oidc
-from litellm.router_strategy.free_paid_tiers import get_deployments_for_tier
 from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
 from litellm.router_strategy.lowest_cost import LowestCostLoggingHandler
 from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler
 from litellm.router_strategy.lowest_tpm_rpm import LowestTPMLoggingHandler
 from litellm.router_strategy.lowest_tpm_rpm_v2 import LowestTPMLoggingHandler_v2
+from litellm.router_strategy.tag_based_routing import get_deployments_for_tag
 from litellm.router_utils.client_initalization_utils import (
     set_client,
     should_initialize_sync_client,
@@ -4482,8 +4482,8 @@ class Router:
             request_kwargs=request_kwargs,
         )
 
-        # check free / paid tier for each deployment
-        healthy_deployments = await get_deployments_for_tier(
+        # check if user wants to do tag based routing
+        healthy_deployments = await get_deployments_for_tag(
             request_kwargs=request_kwargs,
             healthy_deployments=healthy_deployments,
         )
diff --git a/litellm/router_strategy/free_paid_tiers.py b/litellm/router_strategy/free_paid_tiers.py
deleted file mode 100644
index 82e38b4f5..000000000
--- a/litellm/router_strategy/free_paid_tiers.py
+++ /dev/null
@@ -1,69 +0,0 @@
-"""
-Use this to route requests between free and paid tiers
-"""
-
-from typing import Any, Dict, List, Literal, Optional, TypedDict, Union, cast
-
-from litellm._logging import verbose_logger
-from litellm.types.router import DeploymentTypedDict
-
-
-class ModelInfo(TypedDict):
-    tier: Literal["free", "paid"]
-
-
-class Deployment(TypedDict):
-    model_info: ModelInfo
-
-
-async def get_deployments_for_tier(
-    request_kwargs: Optional[Dict[Any, Any]] = None,
-    healthy_deployments: Optional[Union[List[Any], Dict[Any, Any]]] = None,
-):
-    """
-    if request_kwargs contains {"metadata": {"tier": "free"}} or {"metadata": {"tier": "paid"}}, then routes the request to free/paid tier models
-    """
-    if request_kwargs is None:
-        verbose_logger.debug(
-            "get_deployments_for_tier: request_kwargs is None returning healthy_deployments: %s",
-            healthy_deployments,
-        )
-        return healthy_deployments
-
-    verbose_logger.debug("request metadata: %s", request_kwargs.get("metadata"))
-    if "metadata" in request_kwargs:
-        metadata = request_kwargs["metadata"]
-        if "tier" in metadata:
-            selected_tier: Literal["free", "paid"] = metadata["tier"]
-            if healthy_deployments is None:
-                return None
-
-            if selected_tier == "free":
-                # get all deployments where model_info has tier = free
-                free_deployments: List[Any] = []
-                verbose_logger.debug(
-                    "Getting deployments in free tier, all_deployments: %s",
-                    healthy_deployments,
-                )
-                for deployment in healthy_deployments:
-                    typed_deployment = cast(Deployment, deployment)
-                    if typed_deployment["model_info"]["tier"] == "free":
-                        free_deployments.append(deployment)
-                verbose_logger.debug("free_deployments: %s", free_deployments)
-                return free_deployments
-
-            elif selected_tier == "paid":
-                # get all deployments where model_info has tier = paid
-                paid_deployments: List[Any] = []
-                for deployment in healthy_deployments:
-                    typed_deployment = cast(Deployment, deployment)
-                    if typed_deployment["model_info"]["tier"] == "paid":
-                        paid_deployments.append(deployment)
-                verbose_logger.debug("paid_deployments: %s", paid_deployments)
-                return paid_deployments
-
-    verbose_logger.debug(
-        "no tier found in metadata, returning healthy_deployments: %s",
-        healthy_deployments,
-    )
-    return healthy_deployments
diff --git a/litellm/router_strategy/tag_based_routing.py b/litellm/router_strategy/tag_based_routing.py
new file mode 100644
index 000000000..11bad19a3
--- /dev/null
+++ b/litellm/router_strategy/tag_based_routing.py
@@ -0,0 +1,68 @@
+"""
+Use this to route requests between free and paid tiers
+"""
+
+from typing import Any, Dict, List, Literal, Optional, TypedDict, Union, cast
+
+from litellm._logging import verbose_logger
+from litellm.types.router import DeploymentTypedDict
+
+
+async def get_deployments_for_tag(
+    request_kwargs: Optional[Dict[Any, Any]] = None,
+    healthy_deployments: Optional[Union[List[Any], Dict[Any, Any]]] = None,
+):
+    """
+    if request_kwargs contains {"metadata": {"tier": "free"}} or {"metadata": {"tier": "paid"}}, then routes the request to free/paid tier models
+    """
+    if request_kwargs is None:
+        verbose_logger.debug(
+            "get_deployments_for_tier: request_kwargs is None returning healthy_deployments: %s",
+            healthy_deployments,
+        )
+        return healthy_deployments
+
+    if healthy_deployments is None:
+        verbose_logger.debug(
+            "get_deployments_for_tier: healthy_deployments is None returning healthy_deployments"
+        )
+        return healthy_deployments
+
+    verbose_logger.debug("request metadata: %s", request_kwargs.get("metadata"))
+    if "metadata" in request_kwargs:
+        metadata = request_kwargs["metadata"]
+        request_tags = metadata.get("tags")
+
+        new_healthy_deployments = []
+        if request_tags:
+            verbose_logger.debug("parameter routing: router_keys: %s", request_tags)
+            # example this can be router_keys=["free", "custom"]
+            # get all deployments that have a superset of these router keys
+            for deployment in healthy_deployments:
+                deployment_litellm_params = deployment.get("litellm_params")
+                deployment_tags = deployment_litellm_params.get("tags")
+
+                verbose_logger.debug(
+                    "deployment: %s, deployment_router_keys: %s",
+                    deployment,
+                    deployment_tags,
+                )
+
+                if deployment_tags is None:
+                    continue
+
+                if set(request_tags).issubset(set(deployment_tags)):
+                    verbose_logger.debug(
+                        "adding deployment with tags: %s, request tags: %s",
+                        deployment_tags,
+                        request_tags,
+                    )
+                    new_healthy_deployments.append(deployment)
+
+            return new_healthy_deployments
+
+    verbose_logger.debug(
+        "no tier found in metadata, returning healthy_deployments: %s",
+        healthy_deployments,
+    )
+    return healthy_deployments
diff --git a/litellm/tests/test_router_tiers.py b/litellm/tests/test_router_tag_routing.py
similarity index 89%
rename from litellm/tests/test_router_tiers.py
rename to litellm/tests/test_router_tag_routing.py
index 54e67ded3..feb67c0e9 100644
--- a/litellm/tests/test_router_tiers.py
+++ b/litellm/tests/test_router_tag_routing.py
@@ -45,16 +45,18 @@ async def test_router_free_paid_tier():
                 "litellm_params": {
                     "model": "gpt-4o",
                     "api_base": "https://exampleopenaiendpoint-production.up.railway.app/",
+                    "tags": ["free"],
                 },
-                "model_info": {"tier": "paid", "id": "very-expensive-model"},
+                "model_info": {"id": "very-cheap-model"},
             },
             {
                 "model_name": "gpt-4",
                 "litellm_params": {
                     "model": "gpt-4o-mini",
                     "api_base": "https://exampleopenaiendpoint-production.up.railway.app/",
+                    "tags": ["paid"],
                 },
-                "model_info": {"tier": "free", "id": "very-cheap-model"},
+                "model_info": {"id": "very-expensive-model"},
             },
         ]
     )
@@ -64,7 +66,7 @@ async def test_router_free_paid_tier():
     response = await router.acompletion(
         model="gpt-4",
         messages=[{"role": "user", "content": "Tell me a joke."}],
-        metadata={"tier": "free"},
+        metadata={"tags": ["free"]},
     )
 
     print("Response: ", response)
@@ -79,7 +81,7 @@ async def test_router_free_paid_tier():
     response = await router.acompletion(
         model="gpt-4",
         messages=[{"role": "user", "content": "Tell me a joke."}],
-        metadata={"tier": "paid"},
+        metadata={"tags": ["paid"]},
    )
 
     print("Response: ", response)
diff --git a/litellm/types/router.py b/litellm/types/router.py
index df9947c26..78dfbc4c1 100644
--- a/litellm/types/router.py
+++ b/litellm/types/router.py
@@ -325,6 +325,10 @@ class LiteLLMParamsTypedDict(TypedDict, total=False):
     ## MOCK RESPONSES ##
     mock_response: Optional[Union[str, ModelResponse, Exception]]
 
+    # routing params
+    # use this for tag-based routing
+    tags: Optional[List[str]]
+
 
 class DeploymentTypedDict(TypedDict):
     model_name: str