forked from phoenix/litellm-mirror
control using enable_tag_filtering
This commit is contained in:
parent
e298515034
commit
52d0f6a808
3 changed files with 15 additions and 61 deletions
|
@ -145,6 +145,7 @@ class Router:
|
||||||
content_policy_fallbacks: List = [],
|
content_policy_fallbacks: List = [],
|
||||||
model_group_alias: Optional[dict] = {},
|
model_group_alias: Optional[dict] = {},
|
||||||
enable_pre_call_checks: bool = False,
|
enable_pre_call_checks: bool = False,
|
||||||
|
enable_tag_filtering: bool = False,
|
||||||
retry_after: int = 0, # min time to wait before retrying a failed request
|
retry_after: int = 0, # min time to wait before retrying a failed request
|
||||||
retry_policy: Optional[
|
retry_policy: Optional[
|
||||||
RetryPolicy
|
RetryPolicy
|
||||||
|
@ -246,6 +247,7 @@ class Router:
|
||||||
self.set_verbose = set_verbose
|
self.set_verbose = set_verbose
|
||||||
self.debug_level = debug_level
|
self.debug_level = debug_level
|
||||||
self.enable_pre_call_checks = enable_pre_call_checks
|
self.enable_pre_call_checks = enable_pre_call_checks
|
||||||
|
self.enable_tag_filtering = enable_tag_filtering
|
||||||
if self.set_verbose == True:
|
if self.set_verbose == True:
|
||||||
if debug_level == "INFO":
|
if debug_level == "INFO":
|
||||||
verbose_router_logger.setLevel(logging.INFO)
|
verbose_router_logger.setLevel(logging.INFO)
|
||||||
|
@ -4484,6 +4486,7 @@ class Router:
|
||||||
|
|
||||||
# check if user wants to do tag based routing
|
# check if user wants to do tag based routing
|
||||||
healthy_deployments = await get_deployments_for_tag(
|
healthy_deployments = await get_deployments_for_tag(
|
||||||
|
llm_router_instance=self,
|
||||||
request_kwargs=request_kwargs,
|
request_kwargs=request_kwargs,
|
||||||
healthy_deployments=healthy_deployments,
|
healthy_deployments=healthy_deployments,
|
||||||
)
|
)
|
||||||
|
|
|
@ -2,19 +2,30 @@
|
||||||
Use this to route requests between free and paid tiers
|
Use this to route requests between free and paid tiers
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Dict, List, Literal, Optional, TypedDict, Union, cast
|
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, TypedDict, Union
|
||||||
|
|
||||||
from litellm._logging import verbose_logger
|
from litellm._logging import verbose_logger
|
||||||
from litellm.types.router import DeploymentTypedDict
|
from litellm.types.router import DeploymentTypedDict
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from litellm.router import Router as _Router
|
||||||
|
|
||||||
|
LitellmRouter = _Router
|
||||||
|
else:
|
||||||
|
LitellmRouter = Any
|
||||||
|
|
||||||
|
|
||||||
async def get_deployments_for_tag(
|
async def get_deployments_for_tag(
|
||||||
|
llm_router_instance: LitellmRouter,
|
||||||
request_kwargs: Optional[Dict[Any, Any]] = None,
|
request_kwargs: Optional[Dict[Any, Any]] = None,
|
||||||
healthy_deployments: Optional[Union[List[Any], Dict[Any, Any]]] = None,
|
healthy_deployments: Optional[Union[List[Any], Dict[Any, Any]]] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
if request_kwargs contains {"metadata": {"tier": "free"}} or {"metadata": {"tier": "paid"}}, then routes the request to free/paid tier models
|
if request_kwargs contains {"metadata": {"tier": "free"}} or {"metadata": {"tier": "paid"}}, then routes the request to free/paid tier models
|
||||||
"""
|
"""
|
||||||
|
if llm_router_instance.enable_tag_filtering is not True:
|
||||||
|
return healthy_deployments
|
||||||
|
|
||||||
if request_kwargs is None:
|
if request_kwargs is None:
|
||||||
verbose_logger.debug(
|
verbose_logger.debug(
|
||||||
"get_deployments_for_tier: request_kwargs is None returning healthy_deployments: %s",
|
"get_deployments_for_tier: request_kwargs is None returning healthy_deployments: %s",
|
||||||
|
|
|
@ -1,60 +0,0 @@
|
||||||
"""
|
|
||||||
Tests litellm pre_call_utils
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import traceback
|
|
||||||
import uuid
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
from fastapi import Request
|
|
||||||
from fastapi.routing import APIRoute
|
|
||||||
|
|
||||||
from litellm.proxy._types import UserAPIKeyAuth
|
|
||||||
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
|
|
||||||
from litellm.proxy.proxy_server import ProxyConfig, chat_completion
|
|
||||||
|
|
||||||
load_dotenv()
|
|
||||||
import io
|
|
||||||
import os
|
|
||||||
import time
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
# this file is to test litellm/proxy
|
|
||||||
|
|
||||||
sys.path.insert(
|
|
||||||
0, os.path.abspath("../..")
|
|
||||||
) # Adds the parent directory to the system path
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("tier", ["free", "paid"])
|
|
||||||
@pytest.mark.asyncio()
|
|
||||||
async def test_adding_key_tier_to_request_metadata(tier):
|
|
||||||
"""
|
|
||||||
Tests if we can add tier: free/paid from key metadata to the request metadata
|
|
||||||
"""
|
|
||||||
data = {}
|
|
||||||
|
|
||||||
api_route = APIRoute(path="/chat/completions", endpoint=chat_completion)
|
|
||||||
request = Request(
|
|
||||||
{
|
|
||||||
"type": "http",
|
|
||||||
"method": "POST",
|
|
||||||
"route": api_route,
|
|
||||||
"path": api_route.path,
|
|
||||||
"headers": [],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
new_data = await add_litellm_data_to_request(
|
|
||||||
data=data,
|
|
||||||
request=request,
|
|
||||||
user_api_key_dict=UserAPIKeyAuth(metadata={"tier": tier}),
|
|
||||||
proxy_config=ProxyConfig(),
|
|
||||||
)
|
|
||||||
|
|
||||||
print("new_data", new_data)
|
|
||||||
|
|
||||||
assert new_data["metadata"]["tier"] == tier
|
|
Loading…
Add table
Add a link
Reference in a new issue