forked from phoenix/litellm-mirror
router - refactor to tag based routing
This commit is contained in:
parent
81c77f33b8
commit
4d0fbfea83
5 changed files with 81 additions and 76 deletions
|
@ -47,12 +47,12 @@ from litellm.assistants.main import AssistantDeleted
|
||||||
from litellm.caching import DualCache, InMemoryCache, RedisCache
|
from litellm.caching import DualCache, InMemoryCache, RedisCache
|
||||||
from litellm.integrations.custom_logger import CustomLogger
|
from litellm.integrations.custom_logger import CustomLogger
|
||||||
from litellm.llms.azure import get_azure_ad_token_from_oidc
|
from litellm.llms.azure import get_azure_ad_token_from_oidc
|
||||||
from litellm.router_strategy.free_paid_tiers import get_deployments_for_tier
|
|
||||||
from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
|
from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
|
||||||
from litellm.router_strategy.lowest_cost import LowestCostLoggingHandler
|
from litellm.router_strategy.lowest_cost import LowestCostLoggingHandler
|
||||||
from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler
|
from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler
|
||||||
from litellm.router_strategy.lowest_tpm_rpm import LowestTPMLoggingHandler
|
from litellm.router_strategy.lowest_tpm_rpm import LowestTPMLoggingHandler
|
||||||
from litellm.router_strategy.lowest_tpm_rpm_v2 import LowestTPMLoggingHandler_v2
|
from litellm.router_strategy.lowest_tpm_rpm_v2 import LowestTPMLoggingHandler_v2
|
||||||
|
from litellm.router_strategy.tag_based_routing import get_deployments_for_tag
|
||||||
from litellm.router_utils.client_initalization_utils import (
|
from litellm.router_utils.client_initalization_utils import (
|
||||||
set_client,
|
set_client,
|
||||||
should_initialize_sync_client,
|
should_initialize_sync_client,
|
||||||
|
@ -4482,8 +4482,8 @@ class Router:
|
||||||
request_kwargs=request_kwargs,
|
request_kwargs=request_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# check free / paid tier for each deployment
|
# check if user wants to do tag based routing
|
||||||
healthy_deployments = await get_deployments_for_tier(
|
healthy_deployments = await get_deployments_for_tag(
|
||||||
request_kwargs=request_kwargs,
|
request_kwargs=request_kwargs,
|
||||||
healthy_deployments=healthy_deployments,
|
healthy_deployments=healthy_deployments,
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,69 +0,0 @@
|
||||||
"""
|
|
||||||
Use this to route requests between free and paid tiers
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Any, Dict, List, Literal, Optional, TypedDict, Union, cast
|
|
||||||
|
|
||||||
from litellm._logging import verbose_logger
|
|
||||||
from litellm.types.router import DeploymentTypedDict
|
|
||||||
|
|
||||||
|
|
||||||
class ModelInfo(TypedDict):
    """Per-deployment model metadata consumed by tier routing."""

    # Pricing tier this deployment serves.
    tier: Literal["free", "paid"]
|
|
||||||
|
|
||||||
|
|
||||||
class Deployment(TypedDict):
    """Shape of a deployment entry as read by tier routing."""

    # Metadata block carrying the deployment's tier.
    model_info: ModelInfo
|
|
||||||
|
|
||||||
|
|
||||||
async def get_deployments_for_tier(
    request_kwargs: Optional[Dict[Any, Any]] = None,
    healthy_deployments: Optional[Union[List[Any], Dict[Any, Any]]] = None,
):
    """
    Route a request to free- or paid-tier deployments.

    If request_kwargs contains {"metadata": {"tier": "free"}} or
    {"metadata": {"tier": "paid"}}, keep only the deployments whose
    model_info["tier"] matches the requested tier.

    Args:
        request_kwargs: per-request kwargs; the tier is read from
            request_kwargs["metadata"]["tier"].
        healthy_deployments: candidate deployments to filter.

    Returns:
        The filtered deployment list; healthy_deployments unchanged when no
        tier is requested; None when a tier is requested but
        healthy_deployments is None.
    """
    if request_kwargs is None:
        verbose_logger.debug(
            "get_deployments_for_tier: request_kwargs is None returning healthy_deployments: %s",
            healthy_deployments,
        )
        return healthy_deployments

    verbose_logger.debug("request metadata: %s", request_kwargs.get("metadata"))
    if "metadata" in request_kwargs:
        metadata = request_kwargs["metadata"]
        if "tier" in metadata:
            selected_tier: Literal["free", "paid"] = metadata["tier"]
            if healthy_deployments is None:
                return None

            # Single filter covers both tiers; the original free/paid
            # branches were copy-paste duplicates with inconsistent logging.
            if selected_tier in ("free", "paid"):
                verbose_logger.debug(
                    "Getting deployments in %s tier, all_deployments: %s",
                    selected_tier,
                    healthy_deployments,
                )
                matching_deployments: List[Any] = [
                    deployment
                    for deployment in healthy_deployments
                    if cast(Deployment, deployment)["model_info"]["tier"]
                    == selected_tier
                ]
                verbose_logger.debug(
                    "%s_deployments: %s", selected_tier, matching_deployments
                )
                return matching_deployments

    # No (known) tier requested: leave the candidate set untouched.
    verbose_logger.debug(
        "no tier found in metadata, returning healthy_deployments: %s",
        healthy_deployments,
    )
    return healthy_deployments
|
|
68
litellm/router_strategy/tag_based_routing.py
Normal file
68
litellm/router_strategy/tag_based_routing.py
Normal file
|
@ -0,0 +1,68 @@
|
||||||
|
"""
|
||||||
|
Use this to route requests to deployments based on request tags
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any, Dict, List, Literal, Optional, TypedDict, Union, cast
|
||||||
|
|
||||||
|
from litellm._logging import verbose_logger
|
||||||
|
from litellm.types.router import DeploymentTypedDict
|
||||||
|
|
||||||
|
|
||||||
|
async def get_deployments_for_tag(
    request_kwargs: Optional[Dict[Any, Any]] = None,
    healthy_deployments: Optional[Union[List[Any], Dict[Any, Any]]] = None,
):
    """
    Filter deployments by request tags (tag-based routing).

    If request_kwargs contains {"metadata": {"tags": [...]}}, keep only the
    deployments whose litellm_params["tags"] is a superset of the request
    tags.

    Args:
        request_kwargs: per-request kwargs; tags are read from
            request_kwargs["metadata"]["tags"].
        healthy_deployments: candidate deployments to filter.

    Returns:
        The tag-filtered deployment list, or healthy_deployments unchanged
        when no tags were supplied (or when either argument is None).
    """
    if request_kwargs is None:
        verbose_logger.debug(
            "get_deployments_for_tag: request_kwargs is None returning healthy_deployments: %s",
            healthy_deployments,
        )
        return healthy_deployments

    if healthy_deployments is None:
        verbose_logger.debug(
            "get_deployments_for_tag: healthy_deployments is None returning healthy_deployments"
        )
        return healthy_deployments

    verbose_logger.debug("request metadata: %s", request_kwargs.get("metadata"))
    if "metadata" in request_kwargs:
        metadata = request_kwargs["metadata"]
        request_tags = metadata.get("tags")

        new_healthy_deployments = []
        if request_tags:
            verbose_logger.debug("parameter routing: router_keys: %s", request_tags)
            # example: request_tags=["free", "custom"]
            # keep all deployments whose tags are a superset of these tags
            for deployment in healthy_deployments:
                # litellm_params may be absent on a deployment dict; treat
                # that the same as a deployment with no tags.
                deployment_litellm_params = deployment.get("litellm_params") or {}
                deployment_tags = deployment_litellm_params.get("tags")

                verbose_logger.debug(
                    "deployment: %s, deployment_router_keys: %s",
                    deployment,
                    deployment_tags,
                )

                if deployment_tags is None:
                    continue

                if set(request_tags).issubset(set(deployment_tags)):
                    verbose_logger.debug(
                        "adding deployment with tags: %s, request tags: %s",
                        deployment_tags,
                        request_tags,
                    )
                    new_healthy_deployments.append(deployment)

            return new_healthy_deployments

    # No tags requested: leave the candidate set untouched.
    verbose_logger.debug(
        "no tags found in metadata, returning healthy_deployments: %s",
        healthy_deployments,
    )
    return healthy_deployments
|
|
@ -45,16 +45,18 @@ async def test_router_free_paid_tier():
|
||||||
"litellm_params": {
|
"litellm_params": {
|
||||||
"model": "gpt-4o",
|
"model": "gpt-4o",
|
||||||
"api_base": "https://exampleopenaiendpoint-production.up.railway.app/",
|
"api_base": "https://exampleopenaiendpoint-production.up.railway.app/",
|
||||||
|
"tags": ["free"],
|
||||||
},
|
},
|
||||||
"model_info": {"tier": "paid", "id": "very-expensive-model"},
|
"model_info": {"id": "very-cheap-model"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"model_name": "gpt-4",
|
"model_name": "gpt-4",
|
||||||
"litellm_params": {
|
"litellm_params": {
|
||||||
"model": "gpt-4o-mini",
|
"model": "gpt-4o-mini",
|
||||||
"api_base": "https://exampleopenaiendpoint-production.up.railway.app/",
|
"api_base": "https://exampleopenaiendpoint-production.up.railway.app/",
|
||||||
|
"tags": ["paid"],
|
||||||
},
|
},
|
||||||
"model_info": {"tier": "free", "id": "very-cheap-model"},
|
"model_info": {"id": "very-expensive-model"},
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
@ -64,7 +66,7 @@ async def test_router_free_paid_tier():
|
||||||
response = await router.acompletion(
|
response = await router.acompletion(
|
||||||
model="gpt-4",
|
model="gpt-4",
|
||||||
messages=[{"role": "user", "content": "Tell me a joke."}],
|
messages=[{"role": "user", "content": "Tell me a joke."}],
|
||||||
metadata={"tier": "free"},
|
metadata={"tags": ["free"]},
|
||||||
)
|
)
|
||||||
|
|
||||||
print("Response: ", response)
|
print("Response: ", response)
|
||||||
|
@ -79,7 +81,7 @@ async def test_router_free_paid_tier():
|
||||||
response = await router.acompletion(
|
response = await router.acompletion(
|
||||||
model="gpt-4",
|
model="gpt-4",
|
||||||
messages=[{"role": "user", "content": "Tell me a joke."}],
|
messages=[{"role": "user", "content": "Tell me a joke."}],
|
||||||
metadata={"tier": "paid"},
|
metadata={"tags": ["paid"]},
|
||||||
)
|
)
|
||||||
|
|
||||||
print("Response: ", response)
|
print("Response: ", response)
|
|
@ -325,6 +325,10 @@ class LiteLLMParamsTypedDict(TypedDict, total=False):
|
||||||
## MOCK RESPONSES ##
|
## MOCK RESPONSES ##
|
||||||
mock_response: Optional[Union[str, ModelResponse, Exception]]
|
mock_response: Optional[Union[str, ModelResponse, Exception]]
|
||||||
|
|
||||||
|
# routing params
|
||||||
|
# use this for tag-based routing
|
||||||
|
tags: Optional[List[str]]
|
||||||
|
|
||||||
|
|
||||||
class DeploymentTypedDict(TypedDict):
|
class DeploymentTypedDict(TypedDict):
|
||||||
model_name: str
|
model_name: str
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue