forked from phoenix/litellm-mirror
control using enable_tag_filtering
This commit is contained in:
parent
e298515034
commit
52d0f6a808
3 changed files with 15 additions and 61 deletions
|
@ -145,6 +145,7 @@ class Router:
|
|||
content_policy_fallbacks: List = [],
|
||||
model_group_alias: Optional[dict] = {},
|
||||
enable_pre_call_checks: bool = False,
|
||||
enable_tag_filtering: bool = False,
|
||||
retry_after: int = 0, # min time to wait before retrying a failed request
|
||||
retry_policy: Optional[
|
||||
RetryPolicy
|
||||
|
@ -246,6 +247,7 @@ class Router:
|
|||
self.set_verbose = set_verbose
|
||||
self.debug_level = debug_level
|
||||
self.enable_pre_call_checks = enable_pre_call_checks
|
||||
self.enable_tag_filtering = enable_tag_filtering
|
||||
if self.set_verbose == True:
|
||||
if debug_level == "INFO":
|
||||
verbose_router_logger.setLevel(logging.INFO)
|
||||
|
@ -4484,6 +4486,7 @@ class Router:
|
|||
|
||||
# check if user wants to do tag based routing
|
||||
healthy_deployments = await get_deployments_for_tag(
|
||||
llm_router_instance=self,
|
||||
request_kwargs=request_kwargs,
|
||||
healthy_deployments=healthy_deployments,
|
||||
)
|
||||
|
|
|
@ -2,19 +2,30 @@
|
|||
Use this to route requests between free and paid tiers
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Literal, Optional, TypedDict, Union, cast
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, TypedDict, Union
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.types.router import DeploymentTypedDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.router import Router as _Router
|
||||
|
||||
LitellmRouter = _Router
|
||||
else:
|
||||
LitellmRouter = Any
|
||||
|
||||
|
||||
async def get_deployments_for_tag(
|
||||
llm_router_instance: LitellmRouter,
|
||||
request_kwargs: Optional[Dict[Any, Any]] = None,
|
||||
healthy_deployments: Optional[Union[List[Any], Dict[Any, Any]]] = None,
|
||||
):
|
||||
"""
|
||||
if request_kwargs contains {"metadata": {"tier": "free"}} or {"metadata": {"tier": "paid"}}, then routes the request to free/paid tier models
|
||||
"""
|
||||
if llm_router_instance.enable_tag_filtering is not True:
|
||||
return healthy_deployments
|
||||
|
||||
if request_kwargs is None:
|
||||
verbose_logger.debug(
|
||||
"get_deployments_for_tier: request_kwargs is None returning healthy_deployments: %s",
|
||||
|
|
|
@ -1,60 +0,0 @@
|
|||
"""
|
||||
Tests litellm pre_call_utils
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import Request
|
||||
from fastapi.routing import APIRoute
|
||||
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
|
||||
from litellm.proxy.proxy_server import ProxyConfig, chat_completion
|
||||
|
||||
load_dotenv()
|
||||
import io
|
||||
import os
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
# this file is to test litellm/proxy
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
|
||||
|
||||
@pytest.mark.parametrize("tier", ["free", "paid"])
@pytest.mark.asyncio()
async def test_adding_key_tier_to_request_metadata(tier):
    """
    Check that the tier ("free" / "paid") carried in the API key metadata
    ends up under request metadata after add_litellm_data_to_request runs.
    """
    chat_route = APIRoute(path="/chat/completions", endpoint=chat_completion)

    # Minimal ASGI scope: a header-less POST aimed at the chat completions route.
    asgi_scope = {
        "type": "http",
        "method": "POST",
        "route": chat_route,
        "path": chat_route.path,
        "headers": [],
    }
    request = Request(asgi_scope)

    new_data = await add_litellm_data_to_request(
        data={},
        request=request,
        user_api_key_dict=UserAPIKeyAuth(metadata={"tier": tier}),
        proxy_config=ProxyConfig(),
    )

    print("new_data", new_data)

    # The key's tier must be visible on the outgoing request metadata.
    assert new_data["metadata"]["tier"] == tier
|
Loading…
Add table
Add a link
Reference in a new issue