forked from phoenix/litellm-mirror
router - use free paid tier routing
This commit is contained in:
parent
229b7a6493
commit
0e70b5df14
3 changed files with 106 additions and 4 deletions
|
@@ -47,6 +47,7 @@ from litellm.assistants.main import AssistantDeleted
|
||||||
from litellm.caching import DualCache, InMemoryCache, RedisCache
|
from litellm.caching import DualCache, InMemoryCache, RedisCache
|
||||||
from litellm.integrations.custom_logger import CustomLogger
|
from litellm.integrations.custom_logger import CustomLogger
|
||||||
from litellm.llms.azure import get_azure_ad_token_from_oidc
|
from litellm.llms.azure import get_azure_ad_token_from_oidc
|
||||||
|
from litellm.router_strategy.free_paid_tiers import get_deployments_for_tier
|
||||||
from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
|
from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
|
||||||
from litellm.router_strategy.lowest_cost import LowestCostLoggingHandler
|
from litellm.router_strategy.lowest_cost import LowestCostLoggingHandler
|
||||||
from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler
|
from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler
|
||||||
|
@@ -4481,6 +4482,12 @@ class Router:
|
||||||
request_kwargs=request_kwargs,
|
request_kwargs=request_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# check free / paid tier for each deployment
|
||||||
|
healthy_deployments = await get_deployments_for_tier(
|
||||||
|
request_kwargs=request_kwargs,
|
||||||
|
healthy_deployments=healthy_deployments,
|
||||||
|
)
|
||||||
|
|
||||||
if len(healthy_deployments) == 0:
|
if len(healthy_deployments) == 0:
|
||||||
if _allowed_model_region is None:
|
if _allowed_model_region is None:
|
||||||
_allowed_model_region = "n/a"
|
_allowed_model_region = "n/a"
|
||||||
|
|
|
@@ -17,14 +17,19 @@ class Deployment(TypedDict):
|
||||||
|
|
||||||
|
|
||||||
async def get_deployments_for_tier(
|
async def get_deployments_for_tier(
|
||||||
request_kwargs: dict,
|
request_kwargs: Optional[Dict[Any, Any]] = None,
|
||||||
healthy_deployments: Optional[
|
healthy_deployments: Optional[Union[List[Any], Dict[Any, Any]]] = None,
|
||||||
Union[List[DeploymentTypedDict], List[Dict[str, Any]]]
|
|
||||||
] = None,
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
if request_kwargs contains {"metadata": {"tier": "free"}} or {"metadata": {"tier": "paid"}}, then routes the request to free/paid tier models
|
if request_kwargs contains {"metadata": {"tier": "free"}} or {"metadata": {"tier": "paid"}}, then routes the request to free/paid tier models
|
||||||
"""
|
"""
|
||||||
|
if request_kwargs is None:
|
||||||
|
verbose_logger.debug(
|
||||||
|
"get_deployments_for_tier: request_kwargs is None returning healthy_deployments: %s",
|
||||||
|
healthy_deployments,
|
||||||
|
)
|
||||||
|
return healthy_deployments
|
||||||
|
|
||||||
verbose_logger.debug("request metadata: %s", request_kwargs.get("metadata"))
|
verbose_logger.debug("request metadata: %s", request_kwargs.get("metadata"))
|
||||||
if "metadata" in request_kwargs:
|
if "metadata" in request_kwargs:
|
||||||
metadata = request_kwargs["metadata"]
|
metadata = request_kwargs["metadata"]
|
||||||
|
|
90
litellm/tests/test_router_tiers.py
Normal file
90
litellm/tests/test_router_tiers.py
Normal file
|
@@ -0,0 +1,90 @@
|
||||||
|
#### What this tests ####
|
||||||
|
# This tests litellm router
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
import openai
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
sys.path.insert(
|
||||||
|
0, os.path.abspath("../..")
|
||||||
|
) # Adds the parent directory to the system path
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from collections import defaultdict
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
import litellm
|
||||||
|
from litellm import Router
|
||||||
|
from litellm._logging import verbose_logger
|
||||||
|
|
||||||
|
verbose_logger.setLevel(logging.DEBUG)
|
||||||
|
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio()
async def test_router_free_paid_tier():
    """
    Configure a router with two "gpt-4" deployments — one tagged
    {"tier": "paid"} and one tagged {"tier": "free"} in model_info —
    then send requests with metadata={"tier": "free"} and
    metadata={"tier": "paid"} and verify each request is routed to the
    deployment whose tier matches (checked via the model_id in
    response._hidden_params).
    """
    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-4",
                "litellm_params": {
                    "model": "gpt-4o",
                    "api_base": "https://exampleopenaiendpoint-production.up.railway.app/",
                },
                # paid-tier deployment: must be selected for {"tier": "paid"} requests
                "model_info": {"tier": "paid", "id": "very-expensive-model"},
            },
            {
                "model_name": "gpt-4",
                "litellm_params": {
                    "model": "gpt-4o-mini",
                    "api_base": "https://exampleopenaiendpoint-production.up.railway.app/",
                },
                # free-tier deployment: must be selected for {"tier": "free"} requests
                "model_info": {"tier": "free", "id": "very-cheap-model"},
            },
        ]
    )

    # repeat to guard against the router's load-balancing picking the
    # other deployment by chance on any single call
    for _ in range(5):
        # this should pick model with id == very-cheap-model
        response = await router.acompletion(
            model="gpt-4",
            messages=[{"role": "user", "content": "Tell me a joke."}],
            metadata={"tier": "free"},
        )

        print("Response: ", response)

        response_extra_info = response._hidden_params
        print("response_extra_info: ", response_extra_info)

        assert response_extra_info["model_id"] == "very-cheap-model"

    for _ in range(5):
        # this should pick model with id == very-expensive-model
        response = await router.acompletion(
            model="gpt-4",
            messages=[{"role": "user", "content": "Tell me a joke."}],
            metadata={"tier": "paid"},
        )

        print("Response: ", response)

        response_extra_info = response._hidden_params
        print("response_extra_info: ", response_extra_info)

        assert response_extra_info["model_id"] == "very-expensive-model"
|
Loading…
Add table
Add a link
Reference in a new issue