mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 10:44:24 +00:00
router - refactor to tag based routing
This commit is contained in:
parent
81c77f33b8
commit
4d0fbfea83
5 changed files with 81 additions and 76 deletions
|
@ -47,12 +47,12 @@ from litellm.assistants.main import AssistantDeleted
|
|||
from litellm.caching import DualCache, InMemoryCache, RedisCache
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.llms.azure import get_azure_ad_token_from_oidc
|
||||
from litellm.router_strategy.free_paid_tiers import get_deployments_for_tier
|
||||
from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
|
||||
from litellm.router_strategy.lowest_cost import LowestCostLoggingHandler
|
||||
from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler
|
||||
from litellm.router_strategy.lowest_tpm_rpm import LowestTPMLoggingHandler
|
||||
from litellm.router_strategy.lowest_tpm_rpm_v2 import LowestTPMLoggingHandler_v2
|
||||
from litellm.router_strategy.tag_based_routing import get_deployments_for_tag
|
||||
from litellm.router_utils.client_initalization_utils import (
|
||||
set_client,
|
||||
should_initialize_sync_client,
|
||||
|
@ -4482,8 +4482,8 @@ class Router:
|
|||
request_kwargs=request_kwargs,
|
||||
)
|
||||
|
||||
# check free / paid tier for each deployment
|
||||
healthy_deployments = await get_deployments_for_tier(
|
||||
# check if user wants to do tag based routing
|
||||
healthy_deployments = await get_deployments_for_tag(
|
||||
request_kwargs=request_kwargs,
|
||||
healthy_deployments=healthy_deployments,
|
||||
)
|
||||
|
|
|
@ -1,69 +0,0 @@
|
|||
"""
|
||||
Use this to route requests between free and paid tiers
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Literal, Optional, TypedDict, Union, cast
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.types.router import DeploymentTypedDict
|
||||
|
||||
|
||||
# Typed view of a deployment's `model_info` mapping; `tier` marks the
# deployment as belonging to the free or the paid tier.
ModelInfo = TypedDict("ModelInfo", {"tier": Literal["free", "paid"]})
|
||||
|
||||
|
||||
class Deployment(TypedDict):
    """Minimal deployment shape used when filtering by tier."""

    # forward reference keeps the annotation lazy; resolves to ModelInfo above
    model_info: "ModelInfo"
|
||||
|
||||
|
||||
async def get_deployments_for_tier(
    request_kwargs: Optional[Dict[Any, Any]] = None,
    healthy_deployments: Optional[Union[List[Any], Dict[Any, Any]]] = None,
):
    """
    Filter deployments by the caller's tier.

    If request_kwargs contains {"metadata": {"tier": "free"}} or
    {"metadata": {"tier": "paid"}}, return only the deployments whose
    model_info["tier"] matches the requested tier. Otherwise return
    healthy_deployments unchanged.

    Args:
        request_kwargs: per-request kwargs; the tier is read from
            request_kwargs["metadata"]["tier"].
        healthy_deployments: candidate deployments to filter.

    Returns:
        The filtered deployment list for a recognised tier; None when a
        tier was requested but healthy_deployments is None; otherwise
        healthy_deployments unchanged.
    """
    if request_kwargs is None:
        verbose_logger.debug(
            "get_deployments_for_tier: request_kwargs is None returning healthy_deployments: %s",
            healthy_deployments,
        )
        return healthy_deployments

    verbose_logger.debug("request metadata: %s", request_kwargs.get("metadata"))
    if "metadata" in request_kwargs:
        metadata = request_kwargs["metadata"]
        if "tier" in metadata:
            selected_tier: Literal["free", "paid"] = metadata["tier"]
            if healthy_deployments is None:
                return None

            if selected_tier in ("free", "paid"):
                # Single filtering pass replaces the duplicated, literal-only
                # differing free/paid branches of the original implementation.
                verbose_logger.debug(
                    "Getting deployments in %s tier, all_deployments: %s",
                    selected_tier,
                    healthy_deployments,
                )
                tier_deployments: List[Any] = [
                    deployment
                    for deployment in healthy_deployments
                    if cast(Deployment, deployment)["model_info"]["tier"]
                    == selected_tier
                ]
                verbose_logger.debug(
                    "%s_deployments: %s", selected_tier, tier_deployments
                )
                return tier_deployments

    verbose_logger.debug(
        "no tier found in metadata, returning healthy_deployments: %s",
        healthy_deployments,
    )
    return healthy_deployments
|
68
litellm/router_strategy/tag_based_routing.py
Normal file
68
litellm/router_strategy/tag_based_routing.py
Normal file
|
@ -0,0 +1,68 @@
|
|||
"""
|
||||
Use this to route requests between free and paid tiers
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Literal, Optional, TypedDict, Union, cast
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.types.router import DeploymentTypedDict
|
||||
|
||||
|
||||
async def get_deployments_for_tag(
|
||||
request_kwargs: Optional[Dict[Any, Any]] = None,
|
||||
healthy_deployments: Optional[Union[List[Any], Dict[Any, Any]]] = None,
|
||||
):
|
||||
"""
|
||||
if request_kwargs contains {"metadata": {"tier": "free"}} or {"metadata": {"tier": "paid"}}, then routes the request to free/paid tier models
|
||||
"""
|
||||
if request_kwargs is None:
|
||||
verbose_logger.debug(
|
||||
"get_deployments_for_tier: request_kwargs is None returning healthy_deployments: %s",
|
||||
healthy_deployments,
|
||||
)
|
||||
return healthy_deployments
|
||||
|
||||
if healthy_deployments is None:
|
||||
verbose_logger.debug(
|
||||
"get_deployments_for_tier: healthy_deployments is None returning healthy_deployments"
|
||||
)
|
||||
return healthy_deployments
|
||||
|
||||
verbose_logger.debug("request metadata: %s", request_kwargs.get("metadata"))
|
||||
if "metadata" in request_kwargs:
|
||||
metadata = request_kwargs["metadata"]
|
||||
request_tags = metadata.get("tags")
|
||||
|
||||
new_healthy_deployments = []
|
||||
if request_tags:
|
||||
verbose_logger.debug("parameter routing: router_keys: %s", request_tags)
|
||||
# example this can be router_keys=["free", "custom"]
|
||||
# get all deployments that have a superset of these router keys
|
||||
for deployment in healthy_deployments:
|
||||
deployment_litellm_params = deployment.get("litellm_params")
|
||||
deployment_tags = deployment_litellm_params.get("tags")
|
||||
|
||||
verbose_logger.debug(
|
||||
"deployment: %s, deployment_router_keys: %s",
|
||||
deployment,
|
||||
deployment_tags,
|
||||
)
|
||||
|
||||
if deployment_tags is None:
|
||||
continue
|
||||
|
||||
if set(request_tags).issubset(set(deployment_tags)):
|
||||
verbose_logger.debug(
|
||||
"adding deployment with tags: %s, request tags: %s",
|
||||
deployment_tags,
|
||||
request_tags,
|
||||
)
|
||||
new_healthy_deployments.append(deployment)
|
||||
|
||||
return new_healthy_deployments
|
||||
|
||||
verbose_logger.debug(
|
||||
"no tier found in metadata, returning healthy_deployments: %s",
|
||||
healthy_deployments,
|
||||
)
|
||||
return healthy_deployments
|
|
@ -45,16 +45,18 @@ async def test_router_free_paid_tier():
|
|||
"litellm_params": {
|
||||
"model": "gpt-4o",
|
||||
"api_base": "https://exampleopenaiendpoint-production.up.railway.app/",
|
||||
"tags": ["free"],
|
||||
},
|
||||
"model_info": {"tier": "paid", "id": "very-expensive-model"},
|
||||
"model_info": {"id": "very-cheap-model"},
|
||||
},
|
||||
{
|
||||
"model_name": "gpt-4",
|
||||
"litellm_params": {
|
||||
"model": "gpt-4o-mini",
|
||||
"api_base": "https://exampleopenaiendpoint-production.up.railway.app/",
|
||||
"tags": ["paid"],
|
||||
},
|
||||
"model_info": {"tier": "free", "id": "very-cheap-model"},
|
||||
"model_info": {"id": "very-expensive-model"},
|
||||
},
|
||||
]
|
||||
)
|
||||
|
@ -64,7 +66,7 @@ async def test_router_free_paid_tier():
|
|||
response = await router.acompletion(
|
||||
model="gpt-4",
|
||||
messages=[{"role": "user", "content": "Tell me a joke."}],
|
||||
metadata={"tier": "free"},
|
||||
metadata={"tags": ["free"]},
|
||||
)
|
||||
|
||||
print("Response: ", response)
|
||||
|
@ -79,7 +81,7 @@ async def test_router_free_paid_tier():
|
|||
response = await router.acompletion(
|
||||
model="gpt-4",
|
||||
messages=[{"role": "user", "content": "Tell me a joke."}],
|
||||
metadata={"tier": "paid"},
|
||||
metadata={"tags": ["paid"]},
|
||||
)
|
||||
|
||||
print("Response: ", response)
|
|
@ -325,6 +325,10 @@ class LiteLLMParamsTypedDict(TypedDict, total=False):
|
|||
## MOCK RESPONSES ##
|
||||
mock_response: Optional[Union[str, ModelResponse, Exception]]
|
||||
|
||||
# routing params
|
||||
# use this for tag-based routing
|
||||
tags: Optional[List[str]]
|
||||
|
||||
|
||||
class DeploymentTypedDict(TypedDict):
|
||||
model_name: str
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue