From 0e70b5df14c5dc25c29a2424923cb7d801a1500d Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Thu, 18 Jul 2024 17:09:42 -0700 Subject: [PATCH] router - use free paid tier routing --- litellm/router.py | 7 ++ litellm/router_strategy/free_paid_tiers.py | 13 +++- litellm/tests/test_router_tiers.py | 90 ++++++++++++++++++++++ 3 files changed, 106 insertions(+), 4 deletions(-) create mode 100644 litellm/tests/test_router_tiers.py diff --git a/litellm/router.py b/litellm/router.py index 2f72b8142..487d5fd6a 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -47,6 +47,7 @@ from litellm.assistants.main import AssistantDeleted from litellm.caching import DualCache, InMemoryCache, RedisCache from litellm.integrations.custom_logger import CustomLogger from litellm.llms.azure import get_azure_ad_token_from_oidc +from litellm.router_strategy.free_paid_tiers import get_deployments_for_tier from litellm.router_strategy.least_busy import LeastBusyLoggingHandler from litellm.router_strategy.lowest_cost import LowestCostLoggingHandler from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler @@ -4481,6 +4482,12 @@ class Router: request_kwargs=request_kwargs, ) + # check free / paid tier for each deployment + healthy_deployments = await get_deployments_for_tier( + request_kwargs=request_kwargs, + healthy_deployments=healthy_deployments, + ) + if len(healthy_deployments) == 0: if _allowed_model_region is None: _allowed_model_region = "n/a" diff --git a/litellm/router_strategy/free_paid_tiers.py b/litellm/router_strategy/free_paid_tiers.py index 4328bd84c..82e38b4f5 100644 --- a/litellm/router_strategy/free_paid_tiers.py +++ b/litellm/router_strategy/free_paid_tiers.py @@ -17,14 +17,19 @@ class Deployment(TypedDict): async def get_deployments_for_tier( - request_kwargs: dict, - healthy_deployments: Optional[ - Union[List[DeploymentTypedDict], List[Dict[str, Any]]] - ] = None, + request_kwargs: Optional[Dict[Any, Any]] = None, + healthy_deployments: 
Optional[Union[List[Any], Dict[Any, Any]]] = None, ): """ if request_kwargs contains {"metadata": {"tier": "free"}} or {"metadata": {"tier": "paid"}}, then routes the request to free/paid tier models """ + if request_kwargs is None: + verbose_logger.debug( + "get_deployments_for_tier: request_kwargs is None returning healthy_deployments: %s", + healthy_deployments, + ) + return healthy_deployments + verbose_logger.debug("request metadata: %s", request_kwargs.get("metadata")) if "metadata" in request_kwargs: metadata = request_kwargs["metadata"] diff --git a/litellm/tests/test_router_tiers.py b/litellm/tests/test_router_tiers.py new file mode 100644 index 000000000..54e67ded3 --- /dev/null +++ b/litellm/tests/test_router_tiers.py @@ -0,0 +1,90 @@ +#### What this tests #### +# This tests litellm router + +import asyncio +import os +import sys +import time +import traceback + +import openai +import pytest + +sys.path.insert( + 0, os.path.abspath("../..") ) # Adds the parent directory to the system path +import logging +import os +from collections import defaultdict +from concurrent.futures import ThreadPoolExecutor +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx +from dotenv import load_dotenv + +import litellm +from litellm import Router +from litellm._logging import verbose_logger + +verbose_logger.setLevel(logging.DEBUG) + + load_dotenv() + + @pytest.mark.asyncio() +async def test_router_free_paid_tier(): + """ + Requests with metadata {"tier": "free"} or {"tier": "paid"} should be + routed only to the deployment whose model_info["tier"] matches + """ + router = litellm.Router( + model_list=[ + { + "model_name": "gpt-4", + "litellm_params": { + "model": "gpt-4o", + "api_base": "https://exampleopenaiendpoint-production.up.railway.app/", + }, + "model_info": {"tier": "paid", "id": "very-expensive-model"}, + }, + { + "model_name": "gpt-4", + "litellm_params": { + "model": "gpt-4o-mini", + "api_base": "https://exampleopenaiendpoint-production.up.railway.app/", + }, + "model_info": 
{"tier": "free", "id": "very-cheap-model"}, + }, + ] + ) + + for _ in range(5): + # this should pick model with id == very-cheap-model + response = await router.acompletion( + model="gpt-4", + messages=[{"role": "user", "content": "Tell me a joke."}], + metadata={"tier": "free"}, + ) + + print("Response: ", response) + + response_extra_info = response._hidden_params + print("response_extra_info: ", response_extra_info) + + assert response_extra_info["model_id"] == "very-cheap-model" + + for _ in range(5): + # this should pick model with id == very-expensive-model + response = await router.acompletion( + model="gpt-4", + messages=[{"role": "user", "content": "Tell me a joke."}], + metadata={"tier": "paid"}, + ) + + print("Response: ", response) + + response_extra_info = response._hidden_params + print("response_extra_info: ", response_extra_info) + + assert response_extra_info["model_id"] == "very-expensive-model"