"""
Route requests between free-tier and paid-tier deployments.

A request opts in by carrying ``{"metadata": {"tier": "free"}}`` or
``{"metadata": {"tier": "paid"}}``; deployments declare the tier they belong
to under ``model_info["tier"]``.
"""

from typing import Any, Dict, List, Literal, Optional, TypedDict, Union, cast

from litellm._logging import verbose_logger
from litellm.types.router import DeploymentTypedDict


class ModelInfo(TypedDict):
    # Tier this deployment serves.
    tier: Literal["free", "paid"]


class Deployment(TypedDict):
    # Per-deployment info block that carries the tier label.
    model_info: ModelInfo


def _get_deployment_tier(deployment: Any) -> Optional[str]:
    """Return the tier declared in ``deployment["model_info"]``, or None if absent."""
    model_info = cast(Deployment, deployment).get("model_info")
    if not isinstance(model_info, dict):
        return None
    return model_info.get("tier")


async def get_deployments_for_tier(
    request_kwargs: dict,
    healthy_deployments: Optional[
        Union[List[DeploymentTypedDict], List[Dict[str, Any]]]
    ] = None,
) -> Optional[List[Any]]:
    """
    Filter ``healthy_deployments`` down to the tier requested in the metadata.

    if request_kwargs contains {"metadata": {"tier": "free"}} or
    {"metadata": {"tier": "paid"}}, then routes the request to free/paid tier
    models.

    Args:
        request_kwargs: Keyword arguments of the incoming request. Only the
            ``"metadata"`` entry is read.
        healthy_deployments: Candidate deployments to choose from.

    Returns:
        - ``healthy_deployments`` unchanged when the request carries no
          ``"free"``/``"paid"`` tier (missing metadata, ``metadata`` set to
          None, missing or unrecognised tier value).
        - ``None`` when ``healthy_deployments`` is None.
        - Otherwise a new list holding the deployments whose
          ``model_info["tier"]`` equals the requested tier. Deployments
          without a ``model_info`` or without a ``tier`` match neither tier
          and are skipped instead of raising ``KeyError``.
    """
    metadata = request_kwargs.get("metadata")
    verbose_logger.debug("request metadata: %s", metadata)

    # The "metadata" key may be present with a None value; treat that the
    # same as "no tier requested" rather than failing on `"tier" in None`.
    selected_tier = metadata.get("tier") if isinstance(metadata, dict) else None

    if selected_tier not in ("free", "paid"):
        verbose_logger.debug(
            "no tier found in metadata, returning healthy_deployments: %s",
            healthy_deployments,
        )
        return healthy_deployments

    if healthy_deployments is None:
        return None

    if selected_tier == "free":
        verbose_logger.debug(
            "Getting deployments in free tier, all_deployments: %s",
            healthy_deployments,
        )

    tier_deployments: List[Any] = [
        deployment
        for deployment in healthy_deployments
        if _get_deployment_tier(deployment) == selected_tier
    ]
    # Renders as "free_deployments: ..." / "paid_deployments: ...".
    verbose_logger.debug("%s_deployments: %s", selected_tier, tier_deployments)
    return tier_deployments