forked from phoenix/litellm-mirror

Merge pull request #4789 from BerriAI/litellm_router_refactor

[Feat-Router] - Tag based routing

Commit f04397e19a — 12 changed files with 256 additions and 258 deletions

docs/my-website/docs/proxy/free_paid_tier.md (deleted)
@@ -1,102 +0,0 @@
# 💸 Free, Paid Tier Routing

Route Virtual Keys on `free tier` to cheaper models

### 1. Define free, paid tier models on config.yaml

:::info

Requests with `model=gpt-4` will be routed to either `openai/fake` or `openai/gpt-4o` depending on which tier the virtual key is on

:::

```yaml
model_list:
  - model_name: gpt-4
    litellm_params:
      model: openai/fake
      api_key: fake-key
      api_base: https://exampleopenaiendpoint-production.up.railway.app/
    model_info:
      tier: free # 👈 Key Change - set `tier` to `paid` or `free`
  - model_name: gpt-4
    litellm_params:
      model: openai/gpt-4o
      api_key: os.environ/OPENAI_API_KEY
    model_info:
      tier: paid # 👈 Key Change - set `tier` to `paid` or `free`

general_settings:
  master_key: sk-1234
```

### 2. Create Virtual Keys with pricing `tier=free`

```shell
curl --location 'http://0.0.0.0:4000/key/generate' \
  --header 'Authorization: Bearer sk-1234' \
  --header 'Content-Type: application/json' \
  --data '{
    "metadata": {"tier": "free"}
  }'
```

### 3. Make Request with Key on `Free Tier`

```shell
curl -i http://localhost:4000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -H "Authorization: Bearer sk-inxzoSurQsjog9gPrVOCcA" \
  -d '{
    "model": "gpt-4",
    "messages": [
      {"role": "user", "content": "Hello, Claude gm!"}
    ]
  }'
```

**Expected Response**

If this worked as expected then `x-litellm-model-api-base` should be `https://exampleopenaiendpoint-production.up.railway.app/` in the response headers

```shell
x-litellm-model-api-base: https://exampleopenaiendpoint-production.up.railway.app/

{"id":"chatcmpl-657b750f581240c1908679ed94b31bfe","choices":[{"finish_reason":"stop","index":0,"message":{"content":"\n\nHello there, how may I assist you today?","role":"assistant","tool_calls":null,"function_call":null}}],"created":1677652288,"model":"gpt-3.5-turbo-0125","object":"chat.completion","system_fingerprint":"fp_44709d6fcb","usage":{"completion_tokens":12,"prompt_tokens":9,"total_tokens":21}}
```

### 4. Create Virtual Keys with pricing `tier=paid`

```shell
curl --location 'http://0.0.0.0:4000/key/generate' \
  --header 'Authorization: Bearer sk-1234' \
  --header 'Content-Type: application/json' \
  --data '{
    "metadata": {"tier": "paid"}
  }'
```

### 5. Make Request with Key on `Paid Tier`

```shell
curl -i http://localhost:4000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -H "Authorization: Bearer sk-mnJoeSc6jFjzZr256q-iqA" \
  -d '{
    "model": "gpt-4",
    "messages": [
      {"role": "user", "content": "Hello, Claude gm!"}
    ]
  }'
```

**Expected Response**

If this worked as expected then `x-litellm-model-api-base` should be `https://api.openai.com` in the response headers

```shell
x-litellm-model-api-base: https://api.openai.com

{"id":"chatcmpl-9mW75EbJCgwmLcO0M5DmwxpiBgWdc","choices":[{"finish_reason":"stop","index":0,"message":{"content":"Good morning! How can I assist you today?","role":"assistant","tool_calls":null,"function_call":null}}],"created":1721350215,"model":"gpt-4o-2024-05-13","object":"chat.completion","system_fingerprint":"fp_c4e5b6fa31","usage":{"completion_tokens":10,"prompt_tokens":12,"total_tokens":22}}
```
docs/my-website/docs/proxy/tag_routing.md (new file, 133 lines)
@@ -0,0 +1,133 @@
# 💸 Tag Based Routing

Route requests based on tags.
This is useful for implementing free / paid tiers for users.

### 1. Define tags on config.yaml

- A request with `tags=["free"]` will get routed to `openai/fake`
- A request with `tags=["paid"]` will get routed to `openai/gpt-4o`

```yaml
model_list:
  - model_name: gpt-4
    litellm_params:
      model: openai/fake
      api_key: fake-key
      api_base: https://exampleopenaiendpoint-production.up.railway.app/
      tags: ["free"] # 👈 Key Change
  - model_name: gpt-4
    litellm_params:
      model: openai/gpt-4o
      api_key: os.environ/OPENAI_API_KEY
      tags: ["paid"] # 👈 Key Change

router_settings:
  enable_tag_filtering: True # 👈 Key Change

general_settings:
  master_key: sk-1234
```

### 2. Make Request with `tags=["free"]`

This request includes `"tags": ["free"]`, which routes it to `openai/fake`

```shell
curl -i http://localhost:4000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -H "Authorization: Bearer sk-1234" \
  -d '{
    "model": "gpt-4",
    "messages": [
      {"role": "user", "content": "Hello, Claude gm!"}
    ],
    "tags": ["free"]
  }'
```

**Expected Response**

Expect to see the following response header when this works

```shell
x-litellm-model-api-base: https://exampleopenaiendpoint-production.up.railway.app/
```

Response

```shell
{
  "id": "chatcmpl-33c534e3d70148218e2d62496b81270b",
  "choices": [
    {
      "finish_reason": "stop",
      "index": 0,
      "message": {
        "content": "\n\nHello there, how may I assist you today?",
        "role": "assistant",
        "tool_calls": null,
        "function_call": null
      }
    }
  ],
  "created": 1677652288,
  "model": "gpt-3.5-turbo-0125",
  "object": "chat.completion",
  "system_fingerprint": "fp_44709d6fcb",
  "usage": {
    "completion_tokens": 12,
    "prompt_tokens": 9,
    "total_tokens": 21
  }
}
```

### 3. Make Request with `tags=["paid"]`

This request includes `"tags": ["paid"]`, which routes it to `openai/gpt-4o`

```shell
curl -i http://localhost:4000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -H "Authorization: Bearer sk-1234" \
  -d '{
    "model": "gpt-4",
    "messages": [
      {"role": "user", "content": "Hello, Claude gm!"}
    ],
    "tags": ["paid"]
  }'
```

**Expected Response**

Expect to see the following response header when this works

```shell
x-litellm-model-api-base: https://api.openai.com
```

Response

```shell
{
  "id": "chatcmpl-9maCcqQYTqdJrtvfakIawMOIUbEZx",
  "choices": [
    {
      "finish_reason": "stop",
      "index": 0,
      "message": {
        "content": "Good morning! How can I assist you today?",
        "role": "assistant",
        "tool_calls": null,
        "function_call": null
      }
    }
  ],
  "created": 1721365934,
  "model": "gpt-4o-2024-05-13",
  "object": "chat.completion",
  "system_fingerprint": "fp_c4e5b6fa31",
  "usage": {
    "completion_tokens": 10,
    "prompt_tokens": 12,
    "total_tokens": 22
  }
}
```
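The proxy is OpenAI-compatible, so the curl requests in the new doc can also be issued from the official `openai` Python SDK. A minimal sketch, assuming the proxy from the example config is running on `localhost:4000`; `extra_body` is the SDK's standard mechanism for passing non-OpenAI fields such as `tags` through to the proxy:

```python
# Sketch (assumption, not part of this diff): same request as the curl
# example, sent via the openai SDK against the LiteLLM proxy.
from openai import OpenAI

client = OpenAI(
    api_key="sk-1234",                    # master key from the example config
    base_url="http://localhost:4000/v1",  # LiteLLM proxy, OpenAI-compatible
)

response = client.chat.completions.create(
    model="gpt-4",
    messages=[{"role": "user", "content": "Hello, Claude gm!"}],
    extra_body={"tags": ["free"]},  # routes to the deployment tagged "free"
)
print(response.choices[0].message.content)
```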
@@ -44,7 +44,7 @@ const sidebars = {
        "proxy/cost_tracking",
        "proxy/self_serve",
        "proxy/virtual_keys",
-       "proxy/free_paid_tier",
+       "proxy/tag_routing",
        "proxy/users",
        "proxy/team_budgets",
        "proxy/customers",
litellm/main.py
@@ -735,6 +735,7 @@ def completion(
     ]
     litellm_params = [
         "metadata",
+        "tags",
         "acompletion",
         "atext_completion",
         "text_completion",

@@ -3155,6 +3156,7 @@ def embedding(
         "allowed_model_region",
         "model_config",
         "cooldown_time",
+        "tags",
     ]
     default_params = openai_params + litellm_params
     non_default_params = {
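Why `"tags"` is added to these lists: anything in `openai_params + litellm_params` counts as a default/control parameter and is excluded from `non_default_params`, the extra kwargs that get forwarded to the provider. A simplified illustration (the helper and the trimmed lists are mine, not litellm's):

```python
# Simplified sketch of the filtering above; the real lists in litellm/main.py
# are much longer.
openai_params = ["model", "messages", "temperature"]  # illustrative subset
litellm_params = ["metadata", "tags", "acompletion"]  # now includes "tags"

def get_non_default_params(kwargs: dict) -> dict:
    # mirrors: default_params = openai_params + litellm_params
    default_params = openai_params + litellm_params
    return {k: v for k, v in kwargs.items() if k not in default_params}

# "tags" is consumed by the router and no longer leaks to the upstream API:
print(get_non_default_params({"tags": ["free"], "top_k": 5}))  # {'top_k': 5}
```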
@@ -4384,6 +4386,8 @@ def transcription(
     proxy_server_request = kwargs.get("proxy_server_request", None)
     model_info = kwargs.get("model_info", None)
     metadata = kwargs.get("metadata", {})
+    tags = kwargs.pop("tags", [])
+
     drop_params = kwargs.get("drop_params", None)
     client: Optional[
         Union[

@@ -4556,6 +4560,7 @@ def speech(
 ) -> HttpxBinaryResponseContent:

     model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base)  # type: ignore
+    tags = kwargs.pop("tags", [])

     optional_params = {}
     if response_format is not None:
litellm/proxy/litellm_pre_call_utils.py
@@ -75,7 +75,7 @@ async def add_litellm_data_to_request(
         dict: The modified data dictionary.

     """
-    from litellm.proxy.proxy_server import premium_user
+    from litellm.proxy.proxy_server import llm_router, premium_user

     safe_add_api_version_from_query_params(data, request)

@@ -132,15 +132,6 @@ async def add_litellm_data_to_request(
             for k, v in key_metadata["cache"].items():
                 if k in SupportedCacheControls:
                     data["cache"][k] = v
-    if "tier" in key_metadata:
-        if premium_user is not True:
-            verbose_logger.warning(
-                "Trying to use free/paid tier feature. This will not be applied %s",
-                CommonProxyErrors.not_premium_user.value,
-            )
-
-        # add request tier to metadata
-        data[_metadata_variable_name]["tier"] = key_metadata["tier"]

     # Team spend, budget - used by prometheus.py
     data[_metadata_variable_name][

@@ -175,7 +166,8 @@ async def add_litellm_data_to_request(
     if user_api_key_dict.allowed_model_region is not None:
         data["allowed_model_region"] = user_api_key_dict.allowed_model_region

-    ## [Enterprise Only] Add User-IP Address
+    ## [Enterprise Only]
+    # Add User-IP Address
     requester_ip_address = ""
     if premium_user is True:
         # Only set the IP Address for Enterprise Users

@@ -188,6 +180,18 @@ async def add_litellm_data_to_request(
             requester_ip_address = request.client.host
         data[_metadata_variable_name]["requester_ip_address"] = requester_ip_address

+    # Enterprise Only - Check if using tag based routing
+    if llm_router and llm_router.enable_tag_filtering is True:
+        if premium_user is not True:
+            verbose_proxy_logger.warning(
+                "router.enable_tag_filtering is on %s \n switched off router.enable_tag_filtering",
+                CommonProxyErrors.not_premium_user.value,
+            )
+            llm_router.enable_tag_filtering = False
+        else:
+            if "tags" in data:
+                data[_metadata_variable_name]["tags"] = data["tags"]
+
     ### TEAM-SPECIFIC PARAMS ###
     if user_api_key_dict.team_id is not None:
         team_config = await proxy_config.load_team_config(
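Condensed view of the new pre-call behavior (the function name and flat signature below are mine; the control flow follows the hunk above): tag filtering is an enterprise feature, so for non-premium users the router flag is switched off rather than erroring, while premium requests get their body-level tags copied into metadata for the router to read.

```python
# Condensed sketch of the hunk above; not the actual proxy code.
def apply_tag_routing(data: dict, metadata_key: str,
                      enable_tag_filtering: bool, premium_user: bool) -> bool:
    """Returns the (possibly downgraded) enable_tag_filtering flag."""
    if enable_tag_filtering:
        if not premium_user:
            # Non-premium: warn and disable the feature instead of raising.
            return False
        if "tags" in data:
            data.setdefault(metadata_key, {})["tags"] = data["tags"]
    return enable_tag_filtering

data = {"model": "gpt-4", "tags": ["free"]}
apply_tag_routing(data, "metadata", enable_tag_filtering=True, premium_user=True)
print(data["metadata"])  # {'tags': ['free']}
```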
@@ -4,16 +4,14 @@ model_list:
       model: openai/fake
       api_key: fake-key
       api_base: https://exampleopenaiendpoint-production.up.railway.app/
-    model_info:
-      tier: free # 👈 Key Change - set `tier`
+      tags: ["free"] # 👈 Key Change
   - model_name: gpt-4
     litellm_params:
       model: openai/gpt-4o
       api_key: os.environ/OPENAI_API_KEY
-    model_info:
-      tier: paid # 👈 Key Change - set `tier`
+      tags: ["paid"] # 👈 Key Change

+router_settings:
+  enable_tag_filtering: True # 👈 Key Change
 general_settings:
   master_key: sk-1234
litellm/router.py
@@ -47,12 +47,12 @@ from litellm.assistants.main import AssistantDeleted
 from litellm.caching import DualCache, InMemoryCache, RedisCache
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.llms.azure import get_azure_ad_token_from_oidc
-from litellm.router_strategy.free_paid_tiers import get_deployments_for_tier
 from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
 from litellm.router_strategy.lowest_cost import LowestCostLoggingHandler
 from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler
 from litellm.router_strategy.lowest_tpm_rpm import LowestTPMLoggingHandler
 from litellm.router_strategy.lowest_tpm_rpm_v2 import LowestTPMLoggingHandler_v2
+from litellm.router_strategy.tag_based_routing import get_deployments_for_tag
 from litellm.router_utils.client_initalization_utils import (
     set_client,
     should_initialize_sync_client,

@@ -145,6 +145,7 @@ class Router:
         content_policy_fallbacks: List = [],
         model_group_alias: Optional[dict] = {},
         enable_pre_call_checks: bool = False,
+        enable_tag_filtering: bool = False,
         retry_after: int = 0,  # min time to wait before retrying a failed request
         retry_policy: Optional[
             RetryPolicy

@@ -246,6 +247,7 @@ class Router:
         self.set_verbose = set_verbose
         self.debug_level = debug_level
         self.enable_pre_call_checks = enable_pre_call_checks
+        self.enable_tag_filtering = enable_tag_filtering
         if self.set_verbose == True:
             if debug_level == "INFO":
                 verbose_router_logger.setLevel(logging.INFO)

@@ -4482,8 +4484,9 @@ class Router:
             request_kwargs=request_kwargs,
         )

-        # check free / paid tier for each deployment
-        healthy_deployments = await get_deployments_for_tier(
+        # check if user wants to do tag based routing
+        healthy_deployments = await get_deployments_for_tag(
+            llm_router_instance=self,
             request_kwargs=request_kwargs,
             healthy_deployments=healthy_deployments,
         )
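For SDK-level (non-proxy) usage, the new constructor argument combines with tagged deployments roughly as the updated test further down exercises it; a runnable sketch mirroring that test (same mock endpoint the test uses):

```python
# Sketch based on the updated test in this PR: a Router with tag filtering on,
# selecting deployments whose litellm_params tags match the request metadata.
import asyncio
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-4",
            "litellm_params": {
                "model": "gpt-4o",
                "api_base": "https://exampleopenaiendpoint-production.up.railway.app/",
                "tags": ["free"],
            },
            "model_info": {"id": "very-cheap-model"},
        },
    ],
    enable_tag_filtering=True,
)

async def main():
    response = await router.acompletion(
        model="gpt-4",
        messages=[{"role": "user", "content": "Tell me a joke."}],
        metadata={"tags": ["free"]},  # selects the deployment tagged "free"
    )
    print(response)

asyncio.run(main())
```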
litellm/router_strategy/free_paid_tiers.py (deleted)
@@ -1,69 +0,0 @@
"""
Use this to route requests between free and paid tiers
"""

from typing import Any, Dict, List, Literal, Optional, TypedDict, Union, cast

from litellm._logging import verbose_logger
from litellm.types.router import DeploymentTypedDict


class ModelInfo(TypedDict):
    tier: Literal["free", "paid"]


class Deployment(TypedDict):
    model_info: ModelInfo


async def get_deployments_for_tier(
    request_kwargs: Optional[Dict[Any, Any]] = None,
    healthy_deployments: Optional[Union[List[Any], Dict[Any, Any]]] = None,
):
    """
    if request_kwargs contains {"metadata": {"tier": "free"}} or {"metadata": {"tier": "paid"}}, then routes the request to free/paid tier models
    """
    if request_kwargs is None:
        verbose_logger.debug(
            "get_deployments_for_tier: request_kwargs is None returning healthy_deployments: %s",
            healthy_deployments,
        )
        return healthy_deployments

    verbose_logger.debug("request metadata: %s", request_kwargs.get("metadata"))
    if "metadata" in request_kwargs:
        metadata = request_kwargs["metadata"]
        if "tier" in metadata:
            selected_tier: Literal["free", "paid"] = metadata["tier"]
            if healthy_deployments is None:
                return None

            if selected_tier == "free":
                # get all deployments where model_info has tier = free
                free_deployments: List[Any] = []
                verbose_logger.debug(
                    "Getting deployments in free tier, all_deployments: %s",
                    healthy_deployments,
                )
                for deployment in healthy_deployments:
                    typed_deployment = cast(Deployment, deployment)
                    if typed_deployment["model_info"]["tier"] == "free":
                        free_deployments.append(deployment)
                verbose_logger.debug("free_deployments: %s", free_deployments)
                return free_deployments

            elif selected_tier == "paid":
                # get all deployments where model_info has tier = paid
                paid_deployments: List[Any] = []
                for deployment in healthy_deployments:
                    typed_deployment = cast(Deployment, deployment)
                    if typed_deployment["model_info"]["tier"] == "paid":
                        paid_deployments.append(deployment)
                verbose_logger.debug("paid_deployments: %s", paid_deployments)
                return paid_deployments

    verbose_logger.debug(
        "no tier found in metadata, returning healthy_deployments: %s",
        healthy_deployments,
    )
    return healthy_deployments
litellm/router_strategy/tag_based_routing.py (new file, 79 lines)
@@ -0,0 +1,79 @@
"""
|
||||||
|
Use this to route requests between free and paid tiers
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, TypedDict, Union
|
||||||
|
|
||||||
|
from litellm._logging import verbose_logger
|
||||||
|
from litellm.types.router import DeploymentTypedDict
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from litellm.router import Router as _Router
|
||||||
|
|
||||||
|
LitellmRouter = _Router
|
||||||
|
else:
|
||||||
|
LitellmRouter = Any
|
||||||
|
|
||||||
|
|
||||||
|
async def get_deployments_for_tag(
|
||||||
|
llm_router_instance: LitellmRouter,
|
||||||
|
request_kwargs: Optional[Dict[Any, Any]] = None,
|
||||||
|
healthy_deployments: Optional[Union[List[Any], Dict[Any, Any]]] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
if request_kwargs contains {"metadata": {"tier": "free"}} or {"metadata": {"tier": "paid"}}, then routes the request to free/paid tier models
|
||||||
|
"""
|
||||||
|
if llm_router_instance.enable_tag_filtering is not True:
|
||||||
|
return healthy_deployments
|
||||||
|
|
||||||
|
if request_kwargs is None:
|
||||||
|
verbose_logger.debug(
|
||||||
|
"get_deployments_for_tier: request_kwargs is None returning healthy_deployments: %s",
|
||||||
|
healthy_deployments,
|
||||||
|
)
|
||||||
|
return healthy_deployments
|
||||||
|
|
||||||
|
if healthy_deployments is None:
|
||||||
|
verbose_logger.debug(
|
||||||
|
"get_deployments_for_tier: healthy_deployments is None returning healthy_deployments"
|
||||||
|
)
|
||||||
|
return healthy_deployments
|
||||||
|
|
||||||
|
verbose_logger.debug("request metadata: %s", request_kwargs.get("metadata"))
|
||||||
|
if "metadata" in request_kwargs:
|
||||||
|
metadata = request_kwargs["metadata"]
|
||||||
|
request_tags = metadata.get("tags")
|
||||||
|
|
||||||
|
new_healthy_deployments = []
|
||||||
|
if request_tags:
|
||||||
|
verbose_logger.debug("parameter routing: router_keys: %s", request_tags)
|
||||||
|
# example this can be router_keys=["free", "custom"]
|
||||||
|
# get all deployments that have a superset of these router keys
|
||||||
|
for deployment in healthy_deployments:
|
||||||
|
deployment_litellm_params = deployment.get("litellm_params")
|
||||||
|
deployment_tags = deployment_litellm_params.get("tags")
|
||||||
|
|
||||||
|
verbose_logger.debug(
|
||||||
|
"deployment: %s, deployment_router_keys: %s",
|
||||||
|
deployment,
|
||||||
|
deployment_tags,
|
||||||
|
)
|
||||||
|
|
||||||
|
if deployment_tags is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if set(request_tags).issubset(set(deployment_tags)):
|
||||||
|
verbose_logger.debug(
|
||||||
|
"adding deployment with tags: %s, request tags: %s",
|
||||||
|
deployment_tags,
|
||||||
|
request_tags,
|
||||||
|
)
|
||||||
|
new_healthy_deployments.append(deployment)
|
||||||
|
|
||||||
|
return new_healthy_deployments
|
||||||
|
|
||||||
|
verbose_logger.debug(
|
||||||
|
"no tier found in metadata, returning healthy_deployments: %s",
|
||||||
|
healthy_deployments,
|
||||||
|
)
|
||||||
|
return healthy_deployments
|
|
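The selection rule worth noting: a deployment is kept when the request's tags are a subset of the deployment's tags, so a deployment tagged `["free", "custom"]` also serves requests tagged just `["free"]`. A toy check (illustrative data, not litellm code):

```python
# Toy demonstration of the issubset() match used above.
deployments = [
    {"litellm_params": {"model": "openai/fake", "tags": ["free", "custom"]}},
    {"litellm_params": {"model": "openai/gpt-4o", "tags": ["paid"]}},
]
request_tags = ["free"]

matched = [
    d for d in deployments
    if d["litellm_params"].get("tags") is not None
    and set(request_tags).issubset(set(d["litellm_params"]["tags"]))
]
print([d["litellm_params"]["model"] for d in matched])  # ['openai/fake']
```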
(deleted test file)
@@ -1,60 +0,0 @@
"""
Tests litellm pre_call_utils
"""

import os
import sys
import traceback
import uuid
from datetime import datetime

from dotenv import load_dotenv
from fastapi import Request
from fastapi.routing import APIRoute

from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
from litellm.proxy.proxy_server import ProxyConfig, chat_completion

load_dotenv()
import io
import os
import time

import pytest

# this file is to test litellm/proxy

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path


@pytest.mark.parametrize("tier", ["free", "paid"])
@pytest.mark.asyncio()
async def test_adding_key_tier_to_request_metadata(tier):
    """
    Tests if we can add tier: free/paid from key metadata to the request metadata
    """
    data = {}

    api_route = APIRoute(path="/chat/completions", endpoint=chat_completion)
    request = Request(
        {
            "type": "http",
            "method": "POST",
            "route": api_route,
            "path": api_route.path,
            "headers": [],
        }
    )
    new_data = await add_litellm_data_to_request(
        data=data,
        request=request,
        user_api_key_dict=UserAPIKeyAuth(metadata={"tier": tier}),
        proxy_config=ProxyConfig(),
    )

    print("new_data", new_data)

    assert new_data["metadata"]["tier"] == tier
@@ -45,18 +45,21 @@ async def test_router_free_paid_tier():
                 "litellm_params": {
                     "model": "gpt-4o",
                     "api_base": "https://exampleopenaiendpoint-production.up.railway.app/",
+                    "tags": ["free"],
                 },
-                "model_info": {"tier": "paid", "id": "very-expensive-model"},
+                "model_info": {"id": "very-cheap-model"},
             },
             {
                 "model_name": "gpt-4",
                 "litellm_params": {
                     "model": "gpt-4o-mini",
                     "api_base": "https://exampleopenaiendpoint-production.up.railway.app/",
+                    "tags": ["paid"],
                 },
-                "model_info": {"tier": "free", "id": "very-cheap-model"},
+                "model_info": {"id": "very-expensive-model"},
             },
-        ]
+        ],
+        enable_tag_filtering=True,
     )

     for _ in range(5):

@@ -64,7 +67,7 @@ async def test_router_free_paid_tier():
         response = await router.acompletion(
             model="gpt-4",
             messages=[{"role": "user", "content": "Tell me a joke."}],
-            metadata={"tier": "free"},
+            metadata={"tags": ["free"]},
         )

         print("Response: ", response)

@@ -79,7 +82,7 @@ async def test_router_free_paid_tier():
         response = await router.acompletion(
             model="gpt-4",
             messages=[{"role": "user", "content": "Tell me a joke."}],
-            metadata={"tier": "paid"},
+            metadata={"tags": ["paid"]},
         )

         print("Response: ", response)
litellm/types/router.py
@@ -325,6 +325,10 @@ class LiteLLMParamsTypedDict(TypedDict, total=False):
     ## MOCK RESPONSES ##
     mock_response: Optional[Union[str, ModelResponse, Exception]]

+    # routing params
+    # use this for tag-based routing
+    tags: Optional[List[str]]
+

 class DeploymentTypedDict(TypedDict):
     model_name: str
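A trimmed re-declaration showing where the new field sits; only a couple of the real fields are reproduced here, and the `litellm_params` member on `DeploymentTypedDict` is inferred from the deployment dicts used in the test above:

```python
# Trimmed sketch of the typed dicts above; the real classes carry many more
# fields. Shows a deployment entry type-checking with the new `tags` key.
from typing import List, Optional, TypedDict

class LiteLLMParamsTypedDict(TypedDict, total=False):
    model: str
    tags: Optional[List[str]]  # new in this PR: used for tag-based routing

class DeploymentTypedDict(TypedDict):
    model_name: str
    litellm_params: LiteLLMParamsTypedDict  # inferred member, illustrative

deployment: DeploymentTypedDict = {
    "model_name": "gpt-4",
    "litellm_params": {"model": "openai/gpt-4o", "tags": ["paid"]},
}
```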