mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 02:34:29 +00:00
Pass router tags in request headers - x-litellm-tags
(#8609)
* feat(litellm_pre_call_utils.py): support `x-litellm-tags` request header allow tag based routing + spend tracking via request headers * docs(request_headers.md): document new `x-litellm-tags` for tag based routing and spend tracking * docs(tag_routing.md): add to docs * fix(utils.py): only pass str values for openai metadata param * fix(utils.py): drop non-str values for metadata param to openai preview-feature, otel span was being sent in
This commit is contained in:
parent
7bfd816d3b
commit
2340f1b31f
9 changed files with 122 additions and 22 deletions
|
@ -8,6 +8,8 @@ Special headers that are supported by LiteLLM.
|
|||
|
||||
`x-litellm-enable-message-redaction`: Optional[bool]: Don't log the message content to logging integrations. Just track spend. [Learn More](./logging#redact-messages-response-content)
|
||||
|
||||
`x-litellm-tags`: Optional[str]: A comma separated list (e.g. `tag1,tag2,tag3`) of tags to use for [tag-based routing](./tag_routing) **OR** [spend-tracking](./enterprise.md#tracking-spend-for-custom-tags).
|
||||
|
||||
## Anthropic Headers
|
||||
|
||||
`anthropic-version` Optional[str]: The version of the Anthropic API to use.
|
||||
|
|
|
@ -143,6 +143,26 @@ Response
|
|||
}
|
||||
```
|
||||
|
||||
## Calling via Request Header
|
||||
|
||||
You can also call via request header `x-litellm-tags`
|
||||
|
||||
```shell
|
||||
curl -L -X POST 'http://0.0.0.0:4000/v1/chat/completions' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-H 'Authorization: Bearer sk-1234' \
|
||||
-H 'x-litellm-tags: free,my-custom-tag' \
|
||||
-d '{
|
||||
"model": "gpt-4",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hey, how'\''s it going 123456?"
|
||||
}
|
||||
]
|
||||
}'
|
||||
```
|
||||
|
||||
## Setting Default Tags
|
||||
|
||||
Use this if you want all untagged requests to be routed to specific deployments
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
model_list:
|
||||
- model_name: azure-gpt-35-turbo
|
||||
- model_name: gpt-4
|
||||
litellm_params:
|
||||
model: azure/chatgpt-v-2
|
||||
api_key: os.environ/AZURE_API_KEY
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import asyncio
|
||||
import copy
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||
|
||||
from fastapi import Request
|
||||
from starlette.datastructures import Headers
|
||||
|
@ -17,6 +17,7 @@ from litellm.proxy._types import (
|
|||
TeamCallbackMetadata,
|
||||
UserAPIKeyAuth,
|
||||
)
|
||||
from litellm.router import Router
|
||||
from litellm.types.llms.anthropic import ANTHROPIC_API_HEADERS
|
||||
from litellm.types.services import ServiceTypes
|
||||
from litellm.types.utils import (
|
||||
|
@ -407,6 +408,28 @@ class LiteLLMProxyRequestSetup:
|
|||
callback_vars=callback_vars_dict,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def add_request_tag_to_metadata(
|
||||
llm_router: Optional[Router],
|
||||
headers: dict,
|
||||
data: dict,
|
||||
) -> Optional[List[str]]:
|
||||
tags = None
|
||||
|
||||
if llm_router and llm_router.enable_tag_filtering is True:
|
||||
# Check request headers for tags
|
||||
if "x-litellm-tags" in headers:
|
||||
if isinstance(headers["x-litellm-tags"], str):
|
||||
_tags = headers["x-litellm-tags"].split(",")
|
||||
tags = [tag.strip() for tag in _tags]
|
||||
elif isinstance(headers["x-litellm-tags"], list):
|
||||
tags = headers["x-litellm-tags"]
|
||||
# Check request body for tags
|
||||
if "tags" in data and isinstance(data["tags"], list):
|
||||
tags = data["tags"]
|
||||
|
||||
return tags
|
||||
|
||||
|
||||
async def add_litellm_data_to_request( # noqa: PLR0915
|
||||
data: dict,
|
||||
|
@ -611,10 +634,15 @@ async def add_litellm_data_to_request( # noqa: PLR0915
|
|||
requester_ip_address = request.client.host
|
||||
data[_metadata_variable_name]["requester_ip_address"] = requester_ip_address
|
||||
|
||||
# Enterprise Only - Check if using tag based routing
|
||||
if llm_router and llm_router.enable_tag_filtering is True:
|
||||
if "tags" in data:
|
||||
data[_metadata_variable_name]["tags"] = data["tags"]
|
||||
# Check if using tag based routing
|
||||
tags = LiteLLMProxyRequestSetup.add_request_tag_to_metadata(
|
||||
llm_router=llm_router,
|
||||
headers=dict(request.headers),
|
||||
data=data,
|
||||
)
|
||||
|
||||
if tags is not None:
|
||||
data[_metadata_variable_name]["tags"] = tags
|
||||
|
||||
# Team Callbacks controls
|
||||
callback_settings_obj = _get_dynamic_logging_metadata(
|
||||
|
|
|
@ -634,7 +634,6 @@ class Router:
|
|||
"""
|
||||
if fallback_param is None:
|
||||
return
|
||||
|
||||
for fallback_dict in fallback_param:
|
||||
if not isinstance(fallback_dict, dict):
|
||||
raise ValueError(f"Item '{fallback_dict}' is not a dictionary.")
|
||||
|
|
|
@ -19,6 +19,30 @@ else:
|
|||
LitellmRouter = Any
|
||||
|
||||
|
||||
def is_valid_deployment_tag(
|
||||
deployment_tags: List[str], request_tags: List[str]
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a tag is valid
|
||||
"""
|
||||
|
||||
if any(tag in deployment_tags for tag in request_tags):
|
||||
verbose_logger.debug(
|
||||
"adding deployment with tags: %s, request tags: %s",
|
||||
deployment_tags,
|
||||
request_tags,
|
||||
)
|
||||
return True
|
||||
elif "default" in deployment_tags:
|
||||
verbose_logger.debug(
|
||||
"adding default deployment with tags: %s, request tags: %s",
|
||||
deployment_tags,
|
||||
request_tags,
|
||||
)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
async def get_deployments_for_tag(
|
||||
llm_router_instance: LitellmRouter,
|
||||
model: str, # used to raise the correct error
|
||||
|
@ -71,19 +95,7 @@ async def get_deployments_for_tag(
|
|||
if deployment_tags is None:
|
||||
continue
|
||||
|
||||
if set(request_tags).issubset(set(deployment_tags)):
|
||||
verbose_logger.debug(
|
||||
"adding deployment with tags: %s, request tags: %s",
|
||||
deployment_tags,
|
||||
request_tags,
|
||||
)
|
||||
new_healthy_deployments.append(deployment)
|
||||
elif "default" in deployment_tags:
|
||||
verbose_logger.debug(
|
||||
"adding default deployment with tags: %s, request tags: %s",
|
||||
deployment_tags,
|
||||
request_tags,
|
||||
)
|
||||
if is_valid_deployment_tag(deployment_tags, request_tags):
|
||||
new_healthy_deployments.append(deployment)
|
||||
|
||||
if len(new_healthy_deployments) == 0:
|
||||
|
|
|
@ -6368,7 +6368,9 @@ def get_non_default_completion_params(kwargs: dict) -> dict:
|
|||
|
||||
def add_openai_metadata(metadata: dict) -> dict:
|
||||
"""
|
||||
Add metadata to openai optional parameters, excluding hidden params
|
||||
Add metadata to openai optional parameters, excluding hidden params.
|
||||
|
||||
OpenAI 'metadata' only supports string values.
|
||||
|
||||
Args:
|
||||
params (dict): Dictionary of API parameters
|
||||
|
@ -6380,5 +6382,10 @@ def add_openai_metadata(metadata: dict) -> dict:
|
|||
if metadata is None:
|
||||
return None
|
||||
# Only include non-hidden parameters
|
||||
visible_metadata = {k: v for k, v in metadata.items() if k != "hidden_params"}
|
||||
visible_metadata = {
|
||||
k: v
|
||||
for k, v in metadata.items()
|
||||
if k != "hidden_params" and isinstance(v, (str))
|
||||
}
|
||||
|
||||
return visible_metadata.copy()
|
||||
|
|
|
@ -1966,3 +1966,22 @@ def test_get_applied_guardrails(test_case):
|
|||
|
||||
# Assert
|
||||
assert sorted(result) == sorted(test_case["expected"])
|
||||
|
||||
|
||||
def test_add_openai_metadata():
|
||||
from litellm.utils import add_openai_metadata
|
||||
|
||||
metadata = {
|
||||
"user_api_key_end_user_id": "123",
|
||||
"hidden_params": {"api_key": "123"},
|
||||
"litellm_parent_otel_span": MagicMock(),
|
||||
"none-val": None,
|
||||
"int-val": 1,
|
||||
"dict-val": {"a": 1, "b": 2},
|
||||
}
|
||||
|
||||
result = add_openai_metadata(metadata)
|
||||
|
||||
assert result == {
|
||||
"user_api_key_end_user_id": "123",
|
||||
}
|
||||
|
|
|
@ -217,3 +217,16 @@ async def test_error_from_tag_routing():
|
|||
assert RouterErrors.no_deployments_with_tag_routing.value in str(e)
|
||||
print("got expected exception = ", e)
|
||||
pass
|
||||
|
||||
|
||||
def test_tag_routing_with_list_of_tags():
|
||||
"""
|
||||
Test that the router can handle a list of tags
|
||||
"""
|
||||
from litellm.router_strategy.tag_based_routing import is_valid_deployment_tag
|
||||
|
||||
assert is_valid_deployment_tag(["teamA", "teamB"], ["teamA"])
|
||||
assert is_valid_deployment_tag(["teamA", "teamB"], ["teamA", "teamB"])
|
||||
assert is_valid_deployment_tag(["teamA", "teamB"], ["teamA", "teamC"])
|
||||
assert not is_valid_deployment_tag(["teamA", "teamB"], ["teamC"])
|
||||
assert not is_valid_deployment_tag(["teamA", "teamB"], [])
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue