control using enable_tag_filtering

This commit is contained in:
Ishaan Jaff 2024-07-18 19:39:04 -07:00
parent 071091fd8c
commit 08adda7091
3 changed files with 15 additions and 61 deletions

View file

@@ -145,6 +145,7 @@ class Router:
content_policy_fallbacks: List = [],
model_group_alias: Optional[dict] = {},
enable_pre_call_checks: bool = False,
enable_tag_filtering: bool = False,
retry_after: int = 0, # min time to wait before retrying a failed request
retry_policy: Optional[
RetryPolicy
@@ -246,6 +247,7 @@ class Router:
self.set_verbose = set_verbose
self.debug_level = debug_level
self.enable_pre_call_checks = enable_pre_call_checks
self.enable_tag_filtering = enable_tag_filtering
if self.set_verbose == True:
if debug_level == "INFO":
verbose_router_logger.setLevel(logging.INFO)
@@ -4484,6 +4486,7 @@ class Router:
# check if user wants to do tag based routing
healthy_deployments = await get_deployments_for_tag(
llm_router_instance=self,
request_kwargs=request_kwargs,
healthy_deployments=healthy_deployments,
)

View file

@@ -2,19 +2,30 @@
Use this to route requests between free and paid tiers
"""
from typing import Any, Dict, List, Literal, Optional, TypedDict, Union, cast
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, TypedDict, Union
from litellm._logging import verbose_logger
from litellm.types.router import DeploymentTypedDict
if TYPE_CHECKING:
from litellm.router import Router as _Router
LitellmRouter = _Router
else:
LitellmRouter = Any
async def get_deployments_for_tag(
llm_router_instance: LitellmRouter,
request_kwargs: Optional[Dict[Any, Any]] = None,
healthy_deployments: Optional[Union[List[Any], Dict[Any, Any]]] = None,
):
"""
if request_kwargs contains {"metadata": {"tier": "free"}} or {"metadata": {"tier": "paid"}}, then routes the request to free/paid tier models
"""
if llm_router_instance.enable_tag_filtering is not True:
return healthy_deployments
if request_kwargs is None:
verbose_logger.debug(
"get_deployments_for_tier: request_kwargs is None returning healthy_deployments: %s",

View file

@@ -1,60 +0,0 @@
"""
Tests litellm pre_call_utils
"""
import os
import sys
import traceback
import uuid
from datetime import datetime
from dotenv import load_dotenv
from fastapi import Request
from fastapi.routing import APIRoute
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
from litellm.proxy.proxy_server import ProxyConfig, chat_completion
load_dotenv()
import io
import os
import time
import pytest
# this file is to test litellm/proxy
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
@pytest.mark.parametrize("tier", ["free", "paid"])
@pytest.mark.asyncio()
async def test_adding_key_tier_to_request_metadata(tier):
    """
    Tests if we can add tier: free/paid from key metadata to the request metadata
    """
    # Build a minimal ASGI-scope Request routed at the chat completions endpoint.
    chat_route = APIRoute(path="/chat/completions", endpoint=chat_completion)
    scope = {
        "type": "http",
        "method": "POST",
        "route": chat_route,
        "path": chat_route.path,
        "headers": [],
    }
    fake_request = Request(scope)

    # A key whose metadata carries the tier should have that tier propagated
    # into the per-request metadata by add_litellm_data_to_request.
    enriched = await add_litellm_data_to_request(
        data={},
        request=fake_request,
        user_api_key_dict=UserAPIKeyAuth(metadata={"tier": tier}),
        proxy_config=ProxyConfig(),
    )
    print("new_data", enriched)
    assert enriched["metadata"]["tier"] == tier