Merge pull request #4789 from BerriAI/litellm_router_refactor
[Feat-Router] - Tag based routing
commit f04397e19a
12 changed files with 256 additions and 258 deletions
docs/my-website/docs/proxy/free_paid_tier.md (deleted)

@@ -1,102 +0,0 @@
# 💸 Free, Paid Tier Routing

Route Virtual Keys on `free tier` to cheaper models

### 1. Define free, paid tier models on config.yaml

:::info

Requests with `model=gpt-4` will be routed to either `openai/fake` or `openai/gpt-4o`, depending on which tier the virtual key is on

:::

```yaml
model_list:
  - model_name: gpt-4
    litellm_params:
      model: openai/fake
      api_key: fake-key
      api_base: https://exampleopenaiendpoint-production.up.railway.app/
    model_info:
      tier: free # 👈 Key Change - set `tier` to `free` or `paid`
  - model_name: gpt-4
    litellm_params:
      model: openai/gpt-4o
      api_key: os.environ/OPENAI_API_KEY
    model_info:
      tier: paid # 👈 Key Change - set `tier` to `free` or `paid`

general_settings:
  master_key: sk-1234
```

### 2. Create Virtual Keys with pricing `tier=free`

```shell
curl --location 'http://0.0.0.0:4000/key/generate' \
  --header 'Authorization: Bearer sk-1234' \
  --header 'Content-Type: application/json' \
  --data '{
    "metadata": {"tier": "free"}
  }'
```

### 3. Make Request with Key on `Free Tier`

```shell
curl -i http://localhost:4000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -H "Authorization: Bearer sk-inxzoSurQsjog9gPrVOCcA" \
  -d '{
    "model": "gpt-4",
    "messages": [
      {"role": "user", "content": "Hello, Claude gm!"}
    ]
  }'
```

**Expected Response**

If this worked as expected then `x-litellm-model-api-base` should be `https://exampleopenaiendpoint-production.up.railway.app/` in the response headers

```shell
x-litellm-model-api-base: https://exampleopenaiendpoint-production.up.railway.app/

{"id":"chatcmpl-657b750f581240c1908679ed94b31bfe","choices":[{"finish_reason":"stop","index":0,"message":{"content":"\n\nHello there, how may I assist you today?","role":"assistant","tool_calls":null,"function_call":null}}],"created":1677652288,"model":"gpt-3.5-turbo-0125","object":"chat.completion","system_fingerprint":"fp_44709d6fcb","usage":{"completion_tokens":12,"prompt_tokens":9,"total_tokens":21}}
```

### 4. Create Virtual Keys with pricing `tier=paid`

```shell
curl --location 'http://0.0.0.0:4000/key/generate' \
  --header 'Authorization: Bearer sk-1234' \
  --header 'Content-Type: application/json' \
  --data '{
    "metadata": {"tier": "paid"}
  }'
```

### 5. Make Request with Key on `Paid Tier`

```shell
curl -i http://localhost:4000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -H "Authorization: Bearer sk-mnJoeSc6jFjzZr256q-iqA" \
  -d '{
    "model": "gpt-4",
    "messages": [
      {"role": "user", "content": "Hello, Claude gm!"}
    ]
  }'
```

**Expected Response**

If this worked as expected then `x-litellm-model-api-base` should be `https://api.openai.com` in the response headers

```shell
x-litellm-model-api-base: https://api.openai.com

{"id":"chatcmpl-9mW75EbJCgwmLcO0M5DmwxpiBgWdc","choices":[{"finish_reason":"stop","index":0,"message":{"content":"Good morning! How can I assist you today?","role":"assistant","tool_calls":null,"function_call":null}}],"created":1721350215,"model":"gpt-4o-2024-05-13","object":"chat.completion","system_fingerprint":"fp_c4e5b6fa31","usage":{"completion_tokens":10,"prompt_tokens":12,"total_tokens":22}}
```
docs/my-website/docs/proxy/tag_routing.md (new file, 133 lines)

@@ -0,0 +1,133 @@
# 💸 Tag Based Routing

Route requests based on tags.

This is useful for implementing free / paid tiers for users.

### 1. Define tags on config.yaml

- A request with `tags=["free"]` will get routed to `openai/fake`
- A request with `tags=["paid"]` will get routed to `openai/gpt-4o`

```yaml
model_list:
  - model_name: gpt-4
    litellm_params:
      model: openai/fake
      api_key: fake-key
      api_base: https://exampleopenaiendpoint-production.up.railway.app/
      tags: ["free"] # 👈 Key Change
  - model_name: gpt-4
    litellm_params:
      model: openai/gpt-4o
      api_key: os.environ/OPENAI_API_KEY
      tags: ["paid"] # 👈 Key Change

router_settings:
  enable_tag_filtering: True # 👈 Key Change
general_settings:
  master_key: sk-1234
```

### 2. Make Request with `tags=["free"]`

This request includes `"tags": ["free"]`, which routes it to `openai/fake`

```shell
curl -i http://localhost:4000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -H "Authorization: Bearer sk-1234" \
  -d '{
    "model": "gpt-4",
    "messages": [
      {"role": "user", "content": "Hello, Claude gm!"}
    ],
    "tags": ["free"]
  }'
```

**Expected Response**

Expect to see the following response header when this works

```shell
x-litellm-model-api-base: https://exampleopenaiendpoint-production.up.railway.app/
```

Response

```shell
{
  "id": "chatcmpl-33c534e3d70148218e2d62496b81270b",
  "choices": [
    {
      "finish_reason": "stop",
      "index": 0,
      "message": {
        "content": "\n\nHello there, how may I assist you today?",
        "role": "assistant",
        "tool_calls": null,
        "function_call": null
      }
    }
  ],
  "created": 1677652288,
  "model": "gpt-3.5-turbo-0125",
  "object": "chat.completion",
  "system_fingerprint": "fp_44709d6fcb",
  "usage": {
    "completion_tokens": 12,
    "prompt_tokens": 9,
    "total_tokens": 21
  }
}
```

### 3. Make Request with `tags=["paid"]`

This request includes `"tags": ["paid"]`, which routes it to `openai/gpt-4o`

```shell
curl -i http://localhost:4000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -H "Authorization: Bearer sk-1234" \
  -d '{
    "model": "gpt-4",
    "messages": [
      {"role": "user", "content": "Hello, Claude gm!"}
    ],
    "tags": ["paid"]
  }'
```

**Expected Response**

Expect to see the following response header when this works

```shell
x-litellm-model-api-base: https://api.openai.com
```

Response

```shell
{
  "id": "chatcmpl-9maCcqQYTqdJrtvfakIawMOIUbEZx",
  "choices": [
    {
      "finish_reason": "stop",
      "index": 0,
      "message": {
        "content": "Good morning! How can I assist you today?",
        "role": "assistant",
        "tool_calls": null,
        "function_call": null
      }
    }
  ],
  "created": 1721365934,
  "model": "gpt-4o-2024-05-13",
  "object": "chat.completion",
  "system_fingerprint": "fp_c4e5b6fa31",
  "usage": {
    "completion_tokens": 10,
    "prompt_tokens": 12,
    "total_tokens": 22
  }
}
```
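The new doc above only shows curl; Python clients can send the same request through the OpenAI SDK pointed at the proxy. A minimal sketch, assuming the proxy from step 1 is running on localhost:4000 with the master key above (`extra_body` is the standard OpenAI SDK mechanism for non-OpenAI fields; this snippet is illustrative, not part of the commit):

```python
from openai import OpenAI

# Point the OpenAI SDK at the LiteLLM proxy configured above.
client = OpenAI(api_key="sk-1234", base_url="http://localhost:4000/v1")

response = client.chat.completions.create(
    model="gpt-4",
    messages=[{"role": "user", "content": "Hello, Claude gm!"}],
    # "tags" is not an OpenAI parameter, so it rides in the request body via
    # extra_body; the proxy copies it into metadata and routes to a "free"-tagged deployment.
    extra_body={"tags": ["free"]},
)
print(response)
```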
docs/my-website/sidebars.js

@@ -44,7 +44,7 @@ const sidebars = {
        "proxy/cost_tracking",
        "proxy/self_serve",
        "proxy/virtual_keys",
        "proxy/free_paid_tier",
        "proxy/tag_routing",
        "proxy/users",
        "proxy/team_budgets",
        "proxy/customers",
litellm/main.py

@@ -735,6 +735,7 @@ def completion(
    ]
    litellm_params = [
        "metadata",
        "tags",
        "acompletion",
        "atext_completion",
        "text_completion",

@@ -3155,6 +3156,7 @@ def embedding(
        "allowed_model_region",
        "model_config",
        "cooldown_time",
        "tags",
    ]
    default_params = openai_params + litellm_params
    non_default_params = {

@@ -4384,6 +4386,8 @@ def transcription(
    proxy_server_request = kwargs.get("proxy_server_request", None)
    model_info = kwargs.get("model_info", None)
    metadata = kwargs.get("metadata", {})
    tags = kwargs.pop("tags", [])

    drop_params = kwargs.get("drop_params", None)
    client: Optional[
        Union[

@@ -4556,6 +4560,7 @@ def speech(
) -> HttpxBinaryResponseContent:

    model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base)  # type: ignore
    tags = kwargs.pop("tags", [])

    optional_params = {}
    if response_format is not None:
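A hedged sketch of what the `completion()` change above means in practice: `tags` is now listed as a litellm-internal param, so it can be passed as a top-level kwarg without being forwarded to the upstream provider. This only shows the param being accepted; actual tag-based routing happens in the Router, shown later. `mock_response` is used so the snippet runs without credentials:

```python
import litellm

# "tags" is now in completion()'s litellm_params list, so it is consumed by
# litellm itself rather than being sent to the upstream provider.
response = litellm.completion(
    model="gpt-4o",
    messages=[{"role": "user", "content": "Hello"}],
    tags=["paid"],
    mock_response="Hi there!",  # canned response, no real API call
)
print(response.choices[0].message.content)
```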
litellm/proxy/litellm_pre_call_utils.py

@@ -75,7 +75,7 @@ async def add_litellm_data_to_request(
        dict: The modified data dictionary.
    """
    from litellm.proxy.proxy_server import premium_user
    from litellm.proxy.proxy_server import llm_router, premium_user

    safe_add_api_version_from_query_params(data, request)

@@ -132,15 +132,6 @@ async def add_litellm_data_to_request(
            for k, v in key_metadata["cache"].items():
                if k in SupportedCacheControls:
                    data["cache"][k] = v
        if "tier" in key_metadata:
            if premium_user is not True:
                verbose_logger.warning(
                    "Trying to use free/paid tier feature. This will not be applied %s",
                    CommonProxyErrors.not_premium_user.value,
                )

            # add request tier to metadata
            data[_metadata_variable_name]["tier"] = key_metadata["tier"]

    # Team spend, budget - used by prometheus.py
    data[_metadata_variable_name][

@@ -175,7 +166,8 @@ async def add_litellm_data_to_request(
    if user_api_key_dict.allowed_model_region is not None:
        data["allowed_model_region"] = user_api_key_dict.allowed_model_region

    ## [Enterprise Only] Add User-IP Address
    ## [Enterprise Only]
    # Add User-IP Address
    requester_ip_address = ""
    if premium_user is True:
        # Only set the IP Address for Enterprise Users

@@ -188,6 +180,18 @@ async def add_litellm_data_to_request(
            requester_ip_address = request.client.host
        data[_metadata_variable_name]["requester_ip_address"] = requester_ip_address

    # Enterprise Only - Check if using tag based routing
    if llm_router and llm_router.enable_tag_filtering is True:
        if premium_user is not True:
            verbose_proxy_logger.warning(
                "router.enable_tag_filtering is on %s \n switched off router.enable_tag_filtering",
                CommonProxyErrors.not_premium_user.value,
            )
            llm_router.enable_tag_filtering = False
        else:
            if "tags" in data:
                data[_metadata_variable_name]["tags"] = data["tags"]

    ### TEAM-SPECIFIC PARAMS ###
    if user_api_key_dict.team_id is not None:
        team_config = await proxy_config.load_team_config(
@@ -4,16 +4,14 @@ model_list:
      model: openai/fake
      api_key: fake-key
      api_base: https://exampleopenaiendpoint-production.up.railway.app/
    model_info:
      tier: free # 👈 Key Change - set `tier`
      tags: ["free"] # 👈 Key Change
  - model_name: gpt-4
    litellm_params:
      model: openai/gpt-4o
      api_key: os.environ/OPENAI_API_KEY
    model_info:
      tier: paid # 👈 Key Change - set `tier`
      tags: ["paid"] # 👈 Key Change

router_settings:
  enable_tag_filtering: True # 👈 Key Change
general_settings:
  master_key: sk-1234
litellm/router.py

@@ -47,12 +47,12 @@ from litellm.assistants.main import AssistantDeleted
from litellm.caching import DualCache, InMemoryCache, RedisCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.azure import get_azure_ad_token_from_oidc
from litellm.router_strategy.free_paid_tiers import get_deployments_for_tier
from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
from litellm.router_strategy.lowest_cost import LowestCostLoggingHandler
from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler
from litellm.router_strategy.lowest_tpm_rpm import LowestTPMLoggingHandler
from litellm.router_strategy.lowest_tpm_rpm_v2 import LowestTPMLoggingHandler_v2
from litellm.router_strategy.tag_based_routing import get_deployments_for_tag
from litellm.router_utils.client_initalization_utils import (
    set_client,
    should_initialize_sync_client,

@@ -145,6 +145,7 @@ class Router:
        content_policy_fallbacks: List = [],
        model_group_alias: Optional[dict] = {},
        enable_pre_call_checks: bool = False,
        enable_tag_filtering: bool = False,
        retry_after: int = 0,  # min time to wait before retrying a failed request
        retry_policy: Optional[
            RetryPolicy

@@ -246,6 +247,7 @@ class Router:
        self.set_verbose = set_verbose
        self.debug_level = debug_level
        self.enable_pre_call_checks = enable_pre_call_checks
        self.enable_tag_filtering = enable_tag_filtering
        if self.set_verbose == True:
            if debug_level == "INFO":
                verbose_router_logger.setLevel(logging.INFO)

@@ -4482,8 +4484,9 @@ class Router:
                request_kwargs=request_kwargs,
            )

            # check free / paid tier for each deployment
            healthy_deployments = await get_deployments_for_tier(
            # check if user wants to do tag based routing
            healthy_deployments = await get_deployments_for_tag(
                llm_router_instance=self,
                request_kwargs=request_kwargs,
                healthy_deployments=healthy_deployments,
            )
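For orientation, a minimal sketch of a Router constructed with the new `enable_tag_filtering` flag, mirroring the updated test further down. The api_base is the placeholder endpoint used in litellm's tests, and `OPENAI_API_KEY` is assumed to be set in the environment:

```python
import asyncio

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-4",
            "litellm_params": {
                "model": "gpt-4o",
                "api_base": "https://exampleopenaiendpoint-production.up.railway.app/",
                "tags": ["free"],
            },
            "model_info": {"id": "very-cheap-model"},
        },
        {
            "model_name": "gpt-4",
            "litellm_params": {
                "model": "gpt-4o-mini",
                "api_base": "https://exampleopenaiendpoint-production.up.railway.app/",
                "tags": ["paid"],
            },
            "model_info": {"id": "very-expensive-model"},
        },
    ],
    enable_tag_filtering=True,  # 👈 new flag from this PR
)


async def main():
    # get_deployments_for_tag() should narrow the candidates to the "free"-tagged deployment.
    response = await router.acompletion(
        model="gpt-4",
        messages=[{"role": "user", "content": "Tell me a joke."}],
        metadata={"tags": ["free"]},
    )
    print("Response: ", response)


asyncio.run(main())
```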
litellm/router_strategy/free_paid_tiers.py (deleted)

@@ -1,69 +0,0 @@
"""
Use this to route requests between free and paid tiers
"""

from typing import Any, Dict, List, Literal, Optional, TypedDict, Union, cast

from litellm._logging import verbose_logger
from litellm.types.router import DeploymentTypedDict


class ModelInfo(TypedDict):
    tier: Literal["free", "paid"]


class Deployment(TypedDict):
    model_info: ModelInfo


async def get_deployments_for_tier(
    request_kwargs: Optional[Dict[Any, Any]] = None,
    healthy_deployments: Optional[Union[List[Any], Dict[Any, Any]]] = None,
):
    """
    if request_kwargs contains {"metadata": {"tier": "free"}} or {"metadata": {"tier": "paid"}}, then routes the request to free/paid tier models
    """
    if request_kwargs is None:
        verbose_logger.debug(
            "get_deployments_for_tier: request_kwargs is None returning healthy_deployments: %s",
            healthy_deployments,
        )
        return healthy_deployments

    verbose_logger.debug("request metadata: %s", request_kwargs.get("metadata"))
    if "metadata" in request_kwargs:
        metadata = request_kwargs["metadata"]
        if "tier" in metadata:
            selected_tier: Literal["free", "paid"] = metadata["tier"]
            if healthy_deployments is None:
                return None

            if selected_tier == "free":
                # get all deployments where model_info has tier = free
                free_deployments: List[Any] = []
                verbose_logger.debug(
                    "Getting deployments in free tier, all_deployments: %s",
                    healthy_deployments,
                )
                for deployment in healthy_deployments:
                    typed_deployment = cast(Deployment, deployment)
                    if typed_deployment["model_info"]["tier"] == "free":
                        free_deployments.append(deployment)
                verbose_logger.debug("free_deployments: %s", free_deployments)
                return free_deployments

            elif selected_tier == "paid":
                # get all deployments where model_info has tier = paid
                paid_deployments: List[Any] = []
                for deployment in healthy_deployments:
                    typed_deployment = cast(Deployment, deployment)
                    if typed_deployment["model_info"]["tier"] == "paid":
                        paid_deployments.append(deployment)
                verbose_logger.debug("paid_deployments: %s", paid_deployments)
                return paid_deployments

    verbose_logger.debug(
        "no tier found in metadata, returning healthy_deployments: %s",
        healthy_deployments,
    )
    return healthy_deployments
litellm/router_strategy/tag_based_routing.py (new file, 79 lines)

@@ -0,0 +1,79 @@
"""
|
||||
Use this to route requests between free and paid tiers
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, TypedDict, Union
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.types.router import DeploymentTypedDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.router import Router as _Router
|
||||
|
||||
LitellmRouter = _Router
|
||||
else:
|
||||
LitellmRouter = Any
|
||||
|
||||
|
||||
async def get_deployments_for_tag(
|
||||
llm_router_instance: LitellmRouter,
|
||||
request_kwargs: Optional[Dict[Any, Any]] = None,
|
||||
healthy_deployments: Optional[Union[List[Any], Dict[Any, Any]]] = None,
|
||||
):
|
||||
"""
|
||||
if request_kwargs contains {"metadata": {"tier": "free"}} or {"metadata": {"tier": "paid"}}, then routes the request to free/paid tier models
|
||||
"""
|
||||
if llm_router_instance.enable_tag_filtering is not True:
|
||||
return healthy_deployments
|
||||
|
||||
if request_kwargs is None:
|
||||
verbose_logger.debug(
|
||||
"get_deployments_for_tier: request_kwargs is None returning healthy_deployments: %s",
|
||||
healthy_deployments,
|
||||
)
|
||||
return healthy_deployments
|
||||
|
||||
if healthy_deployments is None:
|
||||
verbose_logger.debug(
|
||||
"get_deployments_for_tier: healthy_deployments is None returning healthy_deployments"
|
||||
)
|
||||
return healthy_deployments
|
||||
|
||||
verbose_logger.debug("request metadata: %s", request_kwargs.get("metadata"))
|
||||
if "metadata" in request_kwargs:
|
||||
metadata = request_kwargs["metadata"]
|
||||
request_tags = metadata.get("tags")
|
||||
|
||||
new_healthy_deployments = []
|
||||
if request_tags:
|
||||
verbose_logger.debug("parameter routing: router_keys: %s", request_tags)
|
||||
# example this can be router_keys=["free", "custom"]
|
||||
# get all deployments that have a superset of these router keys
|
||||
for deployment in healthy_deployments:
|
||||
deployment_litellm_params = deployment.get("litellm_params")
|
||||
deployment_tags = deployment_litellm_params.get("tags")
|
||||
|
||||
verbose_logger.debug(
|
||||
"deployment: %s, deployment_router_keys: %s",
|
||||
deployment,
|
||||
deployment_tags,
|
||||
)
|
||||
|
||||
if deployment_tags is None:
|
||||
continue
|
||||
|
||||
if set(request_tags).issubset(set(deployment_tags)):
|
||||
verbose_logger.debug(
|
||||
"adding deployment with tags: %s, request tags: %s",
|
||||
deployment_tags,
|
||||
request_tags,
|
||||
)
|
||||
new_healthy_deployments.append(deployment)
|
||||
|
||||
return new_healthy_deployments
|
||||
|
||||
verbose_logger.debug(
|
||||
"no tier found in metadata, returning healthy_deployments: %s",
|
||||
healthy_deployments,
|
||||
)
|
||||
return healthy_deployments
|
|
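The core of `get_deployments_for_tag` is the subset check on tags; a tiny standalone illustration (plain Python, made-up deployment dicts):

```python
request_tags = ["free"]

deployments = [
    {"litellm_params": {"model": "openai/fake", "tags": ["free"]}},
    {"litellm_params": {"model": "openai/gpt-4o", "tags": ["paid"]}},
    {"litellm_params": {"model": "openai/gpt-4o-mini"}},  # no tags -> skipped
]

# Keep only deployments whose tags are a superset of the request's tags.
matched = [
    d
    for d in deployments
    if d["litellm_params"].get("tags") is not None
    and set(request_tags).issubset(set(d["litellm_params"]["tags"]))
]
print(matched)  # only the deployment tagged ["free"] remains
```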
@@ -1,60 +0,0 @@
"""
Tests litellm pre_call_utils
"""

import os
import sys
import traceback
import uuid
from datetime import datetime

from dotenv import load_dotenv
from fastapi import Request
from fastapi.routing import APIRoute

from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
from litellm.proxy.proxy_server import ProxyConfig, chat_completion

load_dotenv()
import io
import os
import time

import pytest

# this file is to test litellm/proxy

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path


@pytest.mark.parametrize("tier", ["free", "paid"])
@pytest.mark.asyncio()
async def test_adding_key_tier_to_request_metadata(tier):
    """
    Tests if we can add tier: free/paid from key metadata to the request metadata
    """
    data = {}

    api_route = APIRoute(path="/chat/completions", endpoint=chat_completion)
    request = Request(
        {
            "type": "http",
            "method": "POST",
            "route": api_route,
            "path": api_route.path,
            "headers": [],
        }
    )
    new_data = await add_litellm_data_to_request(
        data=data,
        request=request,
        user_api_key_dict=UserAPIKeyAuth(metadata={"tier": tier}),
        proxy_config=ProxyConfig(),
    )

    print("new_data", new_data)

    assert new_data["metadata"]["tier"] == tier
@@ -45,18 +45,21 @@ async def test_router_free_paid_tier():
                "litellm_params": {
                    "model": "gpt-4o",
                    "api_base": "https://exampleopenaiendpoint-production.up.railway.app/",
                    "tags": ["free"],
                },
                "model_info": {"tier": "paid", "id": "very-expensive-model"},
                "model_info": {"id": "very-cheap-model"},
            },
            {
                "model_name": "gpt-4",
                "litellm_params": {
                    "model": "gpt-4o-mini",
                    "api_base": "https://exampleopenaiendpoint-production.up.railway.app/",
                    "tags": ["paid"],
                },
                "model_info": {"tier": "free", "id": "very-cheap-model"},
                "model_info": {"id": "very-expensive-model"},
            },
        ]
        ],
        enable_tag_filtering=True,
    )

    for _ in range(5):

@@ -64,7 +67,7 @@ async def test_router_free_paid_tier():
        response = await router.acompletion(
            model="gpt-4",
            messages=[{"role": "user", "content": "Tell me a joke."}],
            metadata={"tier": "free"},
            metadata={"tags": ["free"]},
        )

        print("Response: ", response)

@@ -79,7 +82,7 @@ async def test_router_free_paid_tier():
        response = await router.acompletion(
            model="gpt-4",
            messages=[{"role": "user", "content": "Tell me a joke."}],
            metadata={"tier": "paid"},
            metadata={"tags": ["paid"]},
        )

        print("Response: ", response)
litellm/types/router.py

@@ -325,6 +325,10 @@ class LiteLLMParamsTypedDict(TypedDict, total=False):
    ## MOCK RESPONSES ##
    mock_response: Optional[Union[str, ModelResponse, Exception]]

    # routing params
    # use this for tag-based routing
    tags: Optional[List[str]]


class DeploymentTypedDict(TypedDict):
    model_name: str
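A small sketch of how the new field can be used when building a deployment entry. The shape of `DeploymentTypedDict` beyond what this hunk shows is assumed, so the outer dict is left untyped:

```python
from litellm.types.router import LiteLLMParamsTypedDict

# litellm_params for a deployment, now optionally carrying routing tags.
params: LiteLLMParamsTypedDict = {
    "model": "openai/gpt-4o",
    "tags": ["paid"],  # 👈 new optional field from this PR
}

deployment = {
    "model_name": "gpt-4",
    "litellm_params": params,
}
print(deployment)
```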