control using enable_tag_filtering

Ishaan Jaff 2024-07-18 19:39:04 -07:00
parent 071091fd8c
commit 08adda7091
3 changed files with 15 additions and 61 deletions


@@ -145,6 +145,7 @@ class Router:
         content_policy_fallbacks: List = [],
         model_group_alias: Optional[dict] = {},
         enable_pre_call_checks: bool = False,
+        enable_tag_filtering: bool = False,
         retry_after: int = 0,  # min time to wait before retrying a failed request
         retry_policy: Optional[
             RetryPolicy
@@ -246,6 +247,7 @@ class Router:
         self.set_verbose = set_verbose
         self.debug_level = debug_level
         self.enable_pre_call_checks = enable_pre_call_checks
+        self.enable_tag_filtering = enable_tag_filtering
         if self.set_verbose == True:
             if debug_level == "INFO":
                 verbose_router_logger.setLevel(logging.INFO)
@@ -4484,6 +4486,7 @@ class Router:
         # check if user wants to do tag based routing
         healthy_deployments = await get_deployments_for_tag(
+            llm_router_instance=self,
             request_kwargs=request_kwargs,
             healthy_deployments=healthy_deployments,
         )
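
For context, a minimal usage sketch (not part of the commit): after this change, tag filtering is opt-in via the Router constructor. The model_list entry below is an illustrative placeholder; how deployments are assigned to free/paid tiers is unchanged by this commit.

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",  # placeholder deployment for the sketch
            "litellm_params": {"model": "gpt-3.5-turbo"},
        }
    ],
    enable_tag_filtering=True,  # the new flag added here; defaults to False
)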


@@ -2,19 +2,30 @@
 Use this to route requests between free and paid tiers
 """

-from typing import Any, Dict, List, Literal, Optional, TypedDict, Union, cast
+from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, TypedDict, Union

 from litellm._logging import verbose_logger
 from litellm.types.router import DeploymentTypedDict

+if TYPE_CHECKING:
+    from litellm.router import Router as _Router
+
+    LitellmRouter = _Router
+else:
+    LitellmRouter = Any
+

 async def get_deployments_for_tag(
+    llm_router_instance: LitellmRouter,
     request_kwargs: Optional[Dict[Any, Any]] = None,
     healthy_deployments: Optional[Union[List[Any], Dict[Any, Any]]] = None,
 ):
     """
     if request_kwargs contains {"metadata": {"tier": "free"}} or {"metadata": {"tier": "paid"}}, then routes the request to free/paid tier models
     """
+    if llm_router_instance.enable_tag_filtering is not True:
+        return healthy_deployments
+
     if request_kwargs is None:
         verbose_logger.debug(
             "get_deployments_for_tier: request_kwargs is None returning healthy_deployments: %s",


@@ -1,60 +0,0 @@
-"""
-Tests litellm pre_call_utils
-"""
-
-import os
-import sys
-import traceback
-import uuid
-from datetime import datetime
-
-from dotenv import load_dotenv
-from fastapi import Request
-from fastapi.routing import APIRoute
-
-from litellm.proxy._types import UserAPIKeyAuth
-from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
-from litellm.proxy.proxy_server import ProxyConfig, chat_completion
-
-load_dotenv()
-import io
-import os
-import time
-
-import pytest
-
-# this file is to test litellm/proxy
-sys.path.insert(
-    0, os.path.abspath("../..")
-)  # Adds the parent directory to the system path
-
-
-@pytest.mark.parametrize("tier", ["free", "paid"])
-@pytest.mark.asyncio()
-async def test_adding_key_tier_to_request_metadata(tier):
-    """
-    Tests if we can add tier: free/paid from key metadata to the request metadata
-    """
-    data = {}
-
-    api_route = APIRoute(path="/chat/completions", endpoint=chat_completion)
-    request = Request(
-        {
-            "type": "http",
-            "method": "POST",
-            "route": api_route,
-            "path": api_route.path,
-            "headers": [],
-        }
-    )
-
-    new_data = await add_litellm_data_to_request(
-        data=data,
-        request=request,
-        user_api_key_dict=UserAPIKeyAuth(metadata={"tier": tier}),
-        proxy_config=ProxyConfig(),
-    )
-
-    print("new_data", new_data)
-
-    assert new_data["metadata"]["tier"] == tier