router - refactor to tag based routing

This commit is contained in:
Ishaan Jaff 2024-07-18 19:22:09 -07:00 committed by Krrish Dholakia
parent 38c50e674e
commit ad46e6a61f
5 changed files with 81 additions and 76 deletions

View file

@@ -47,12 +47,12 @@ from litellm.assistants.main import AssistantDeleted
from litellm.caching import DualCache, InMemoryCache, RedisCache from litellm.caching import DualCache, InMemoryCache, RedisCache
from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.azure import get_azure_ad_token_from_oidc from litellm.llms.azure import get_azure_ad_token_from_oidc
from litellm.router_strategy.free_paid_tiers import get_deployments_for_tier
from litellm.router_strategy.least_busy import LeastBusyLoggingHandler from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
from litellm.router_strategy.lowest_cost import LowestCostLoggingHandler from litellm.router_strategy.lowest_cost import LowestCostLoggingHandler
from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler
from litellm.router_strategy.lowest_tpm_rpm import LowestTPMLoggingHandler from litellm.router_strategy.lowest_tpm_rpm import LowestTPMLoggingHandler
from litellm.router_strategy.lowest_tpm_rpm_v2 import LowestTPMLoggingHandler_v2 from litellm.router_strategy.lowest_tpm_rpm_v2 import LowestTPMLoggingHandler_v2
from litellm.router_strategy.tag_based_routing import get_deployments_for_tag
from litellm.router_utils.client_initalization_utils import ( from litellm.router_utils.client_initalization_utils import (
set_client, set_client,
should_initialize_sync_client, should_initialize_sync_client,
@@ -4482,8 +4482,8 @@ class Router:
request_kwargs=request_kwargs, request_kwargs=request_kwargs,
) )
# check free / paid tier for each deployment # check if user wants to do tag based routing
healthy_deployments = await get_deployments_for_tier( healthy_deployments = await get_deployments_for_tag(
request_kwargs=request_kwargs, request_kwargs=request_kwargs,
healthy_deployments=healthy_deployments, healthy_deployments=healthy_deployments,
) )

View file

@@ -1,69 +0,0 @@
"""
Use this to route requests between free and paid tiers
"""
from typing import Any, Dict, List, Literal, Optional, TypedDict, Union, cast
from litellm._logging import verbose_logger
from litellm.types.router import DeploymentTypedDict
class ModelInfo(TypedDict):
    # Pricing tier of a deployment; routing selects deployments whose tier
    # matches the tier requested in the call's metadata.
    tier: Literal["free", "paid"]


class Deployment(TypedDict):
    # Minimal view of a router deployment dict used by tier-based routing:
    # only model_info["tier"] is inspected.
    model_info: ModelInfo
async def get_deployments_for_tier(
    request_kwargs: Optional[Dict[Any, Any]] = None,
    healthy_deployments: Optional[Union[List[Any], Dict[Any, Any]]] = None,
):
    """
    Filter deployments by the caller's pricing tier.

    If request_kwargs contains {"metadata": {"tier": "free"}} or
    {"metadata": {"tier": "paid"}}, return only the deployments whose
    model_info["tier"] matches the requested tier. Otherwise return
    healthy_deployments unchanged.

    Args:
        request_kwargs: per-request kwargs; the tier is read from
            request_kwargs["metadata"]["tier"].
        healthy_deployments: candidate deployments to filter.

    Returns:
        The filtered deployment list when a tier is requested; otherwise
        healthy_deployments unchanged (None if a tier is requested but
        healthy_deployments is None).
    """
    if request_kwargs is None:
        verbose_logger.debug(
            "get_deployments_for_tier: request_kwargs is None returning healthy_deployments: %s",
            healthy_deployments,
        )
        return healthy_deployments

    verbose_logger.debug("request metadata: %s", request_kwargs.get("metadata"))
    metadata = request_kwargs.get("metadata") or {}
    selected_tier = metadata.get("tier")

    if selected_tier in ("free", "paid"):
        if healthy_deployments is None:
            return None
        verbose_logger.debug(
            "Getting deployments in %s tier, all_deployments: %s",
            selected_tier,
            healthy_deployments,
        )
        # Single filtering pass covers both tiers; the original code duplicated
        # this loop once for "free" and once for "paid".
        tier_deployments: List[Any] = [
            deployment
            for deployment in healthy_deployments
            if cast(Deployment, deployment)["model_info"]["tier"] == selected_tier
        ]
        verbose_logger.debug("%s_deployments: %s", selected_tier, tier_deployments)
        return tier_deployments

    verbose_logger.debug(
        "no tier found in metadata, returning healthy_deployments: %s",
        healthy_deployments,
    )
    return healthy_deployments

View file

@@ -0,0 +1,68 @@
"""
Use this to route requests between free and paid tiers
"""
from typing import Any, Dict, List, Literal, Optional, TypedDict, Union, cast
from litellm._logging import verbose_logger
from litellm.types.router import DeploymentTypedDict
async def get_deployments_for_tag(
    request_kwargs: Optional[Dict[Any, Any]] = None,
    healthy_deployments: Optional[Union[List[Any], Dict[Any, Any]]] = None,
):
    """
    Filter deployments by request tags (tag-based routing).

    If request_kwargs contains {"metadata": {"tags": [...]}}, return only the
    deployments whose litellm_params["tags"] is a superset of the requested
    tags. Otherwise return healthy_deployments unchanged.

    Args:
        request_kwargs: per-request kwargs; tags are read from
            request_kwargs["metadata"]["tags"].
        healthy_deployments: candidate deployments to filter.

    Returns:
        The (possibly empty) list of deployments matching every requested
        tag, or healthy_deployments unchanged when no tags are requested.
    """
    if request_kwargs is None:
        # Log message fixed: previously said "get_deployments_for_tier".
        verbose_logger.debug(
            "get_deployments_for_tag: request_kwargs is None returning healthy_deployments: %s",
            healthy_deployments,
        )
        return healthy_deployments
    if healthy_deployments is None:
        verbose_logger.debug(
            "get_deployments_for_tag: healthy_deployments is None returning healthy_deployments"
        )
        return healthy_deployments
    verbose_logger.debug("request metadata: %s", request_kwargs.get("metadata"))
    if "metadata" in request_kwargs:
        metadata = request_kwargs["metadata"]
        request_tags = metadata.get("tags")
        if request_tags:
            verbose_logger.debug("parameter routing: router_keys: %s", request_tags)
            # example: request_tags=["free", "custom"]
            # keep every deployment whose tags are a superset of request_tags
            new_healthy_deployments = []
            required_tags = set(request_tags)  # hoisted out of the loop
            for deployment in healthy_deployments:
                # guard: a malformed deployment may lack "litellm_params"
                deployment_litellm_params = deployment.get("litellm_params") or {}
                deployment_tags = deployment_litellm_params.get("tags")
                verbose_logger.debug(
                    "deployment: %s, deployment_router_keys: %s",
                    deployment,
                    deployment_tags,
                )
                if deployment_tags is None:
                    continue
                # set.issubset accepts any iterable, no need to build a set
                if required_tags.issubset(deployment_tags):
                    verbose_logger.debug(
                        "adding deployment with tags: %s, request tags: %s",
                        deployment_tags,
                        request_tags,
                    )
                    new_healthy_deployments.append(deployment)
            return new_healthy_deployments
    # Log message fixed: previously said "no tier found" although this
    # function routes on tags, not tiers.
    verbose_logger.debug(
        "no tags found in metadata, returning healthy_deployments: %s",
        healthy_deployments,
    )
    return healthy_deployments

View file

@@ -45,16 +45,18 @@ async def test_router_free_paid_tier():
"litellm_params": { "litellm_params": {
"model": "gpt-4o", "model": "gpt-4o",
"api_base": "https://exampleopenaiendpoint-production.up.railway.app/", "api_base": "https://exampleopenaiendpoint-production.up.railway.app/",
"tags": ["free"],
}, },
"model_info": {"tier": "paid", "id": "very-expensive-model"}, "model_info": {"id": "very-cheap-model"},
}, },
{ {
"model_name": "gpt-4", "model_name": "gpt-4",
"litellm_params": { "litellm_params": {
"model": "gpt-4o-mini", "model": "gpt-4o-mini",
"api_base": "https://exampleopenaiendpoint-production.up.railway.app/", "api_base": "https://exampleopenaiendpoint-production.up.railway.app/",
"tags": ["paid"],
}, },
"model_info": {"tier": "free", "id": "very-cheap-model"}, "model_info": {"id": "very-expensive-model"},
}, },
] ]
) )
@@ -64,7 +66,7 @@ async def test_router_free_paid_tier():
response = await router.acompletion( response = await router.acompletion(
model="gpt-4", model="gpt-4",
messages=[{"role": "user", "content": "Tell me a joke."}], messages=[{"role": "user", "content": "Tell me a joke."}],
metadata={"tier": "free"}, metadata={"tags": ["free"]},
) )
print("Response: ", response) print("Response: ", response)
@@ -79,7 +81,7 @@ async def test_router_free_paid_tier():
response = await router.acompletion( response = await router.acompletion(
model="gpt-4", model="gpt-4",
messages=[{"role": "user", "content": "Tell me a joke."}], messages=[{"role": "user", "content": "Tell me a joke."}],
metadata={"tier": "paid"}, metadata={"tags": ["paid"]},
) )
print("Response: ", response) print("Response: ", response)

View file

@@ -325,6 +325,10 @@ class LiteLLMParamsTypedDict(TypedDict, total=False):
## MOCK RESPONSES ## ## MOCK RESPONSES ##
mock_response: Optional[Union[str, ModelResponse, Exception]] mock_response: Optional[Union[str, ModelResponse, Exception]]
# routing params
# use this for tag-based routing
tags: Optional[List[str]]
class DeploymentTypedDict(TypedDict): class DeploymentTypedDict(TypedDict):
model_name: str model_name: str