forked from phoenix/litellm-mirror
(fix) tag merging / aggregation logic (#6932)
* use 1 helper to merge tags + ensure unique ness * test_add_litellm_data_to_request_duplicate_tags * fix _merge_tags * fix proxy utils test
This commit is contained in:
parent
d6181b2c9f
commit
9393434d01
2 changed files with 139 additions and 12 deletions
|
@ -288,12 +288,12 @@ class LiteLLMProxyRequestSetup:
|
|||
|
||||
## KEY-LEVEL SPEND LOGS / TAGS
|
||||
if "tags" in key_metadata and key_metadata["tags"] is not None:
|
||||
if "tags" in data[_metadata_variable_name] and isinstance(
|
||||
data[_metadata_variable_name]["tags"], list
|
||||
):
|
||||
data[_metadata_variable_name]["tags"].extend(key_metadata["tags"])
|
||||
else:
|
||||
data[_metadata_variable_name]["tags"] = key_metadata["tags"]
|
||||
data[_metadata_variable_name]["tags"] = (
|
||||
LiteLLMProxyRequestSetup._merge_tags(
|
||||
request_tags=data[_metadata_variable_name].get("tags"),
|
||||
tags_to_add=key_metadata["tags"],
|
||||
)
|
||||
)
|
||||
if "spend_logs_metadata" in key_metadata and isinstance(
|
||||
key_metadata["spend_logs_metadata"], dict
|
||||
):
|
||||
|
@ -319,6 +319,30 @@ class LiteLLMProxyRequestSetup:
|
|||
data["disable_fallbacks"] = key_metadata["disable_fallbacks"]
|
||||
return data
|
||||
|
||||
@staticmethod
|
||||
def _merge_tags(request_tags: Optional[list], tags_to_add: Optional[list]) -> list:
|
||||
"""
|
||||
Helper function to merge two lists of tags, ensuring no duplicates.
|
||||
|
||||
Args:
|
||||
request_tags (Optional[list]): List of tags from the original request
|
||||
tags_to_add (Optional[list]): List of tags to add
|
||||
|
||||
Returns:
|
||||
list: Combined list of unique tags
|
||||
"""
|
||||
final_tags = []
|
||||
|
||||
if request_tags and isinstance(request_tags, list):
|
||||
final_tags.extend(request_tags)
|
||||
|
||||
if tags_to_add and isinstance(tags_to_add, list):
|
||||
for tag in tags_to_add:
|
||||
if tag not in final_tags:
|
||||
final_tags.append(tag)
|
||||
|
||||
return final_tags
|
||||
|
||||
|
||||
async def add_litellm_data_to_request( # noqa: PLR0915
|
||||
data: dict,
|
||||
|
@ -442,12 +466,10 @@ async def add_litellm_data_to_request( # noqa: PLR0915
|
|||
## TEAM-LEVEL SPEND LOGS/TAGS
|
||||
team_metadata = user_api_key_dict.team_metadata or {}
|
||||
if "tags" in team_metadata and team_metadata["tags"] is not None:
|
||||
if "tags" in data[_metadata_variable_name] and isinstance(
|
||||
data[_metadata_variable_name]["tags"], list
|
||||
):
|
||||
data[_metadata_variable_name]["tags"].extend(team_metadata["tags"])
|
||||
else:
|
||||
data[_metadata_variable_name]["tags"] = team_metadata["tags"]
|
||||
data[_metadata_variable_name]["tags"] = LiteLLMProxyRequestSetup._merge_tags(
|
||||
request_tags=data[_metadata_variable_name].get("tags"),
|
||||
tags_to_add=team_metadata["tags"],
|
||||
)
|
||||
if "spend_logs_metadata" in team_metadata and isinstance(
|
||||
team_metadata["spend_logs_metadata"], dict
|
||||
):
|
||||
|
|
|
@ -574,3 +574,108 @@ def test_get_docs_url(env_vars, expected_url):
|
|||
|
||||
result = _get_docs_url()
|
||||
assert result == expected_url
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"request_tags, tags_to_add, expected_tags",
|
||||
[
|
||||
(None, None, []), # both None
|
||||
(["tag1", "tag2"], None, ["tag1", "tag2"]), # tags_to_add is None
|
||||
(None, ["tag3", "tag4"], ["tag3", "tag4"]), # request_tags is None
|
||||
(
|
||||
["tag1", "tag2"],
|
||||
["tag3", "tag4"],
|
||||
["tag1", "tag2", "tag3", "tag4"],
|
||||
), # both have unique tags
|
||||
(
|
||||
["tag1", "tag2"],
|
||||
["tag2", "tag3"],
|
||||
["tag1", "tag2", "tag3"],
|
||||
), # overlapping tags
|
||||
([], [], []), # both empty lists
|
||||
("not_a_list", ["tag1"], ["tag1"]), # request_tags invalid type
|
||||
(["tag1"], "not_a_list", ["tag1"]), # tags_to_add invalid type
|
||||
(
|
||||
["tag1"],
|
||||
["tag1", "tag2"],
|
||||
["tag1", "tag2"],
|
||||
), # duplicate tags in inputs
|
||||
],
|
||||
)
|
||||
def test_merge_tags(request_tags, tags_to_add, expected_tags):
|
||||
from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup
|
||||
|
||||
result = LiteLLMProxyRequestSetup._merge_tags(
|
||||
request_tags=request_tags, tags_to_add=tags_to_add
|
||||
)
|
||||
|
||||
assert isinstance(result, list)
|
||||
assert sorted(result) == sorted(expected_tags)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"key_tags, request_tags, expected_tags",
|
||||
[
|
||||
# exact duplicates
|
||||
(["tag1", "tag2", "tag3"], ["tag1", "tag2", "tag3"], ["tag1", "tag2", "tag3"]),
|
||||
# partial duplicates
|
||||
(
|
||||
["tag1", "tag2", "tag3"],
|
||||
["tag2", "tag3", "tag4"],
|
||||
["tag1", "tag2", "tag3", "tag4"],
|
||||
),
|
||||
# duplicates within key tags
|
||||
(["tag1", "tag2"], ["tag3", "tag4"], ["tag1", "tag2", "tag3", "tag4"]),
|
||||
# duplicates within request tags
|
||||
(["tag1", "tag2"], ["tag2", "tag3", "tag4"], ["tag1", "tag2", "tag3", "tag4"]),
|
||||
# case sensitive duplicates
|
||||
(["Tag1", "TAG2"], ["tag1", "tag2"], ["Tag1", "TAG2", "tag1", "tag2"]),
|
||||
],
|
||||
)
|
||||
async def test_add_litellm_data_to_request_duplicate_tags(
|
||||
key_tags, request_tags, expected_tags
|
||||
):
|
||||
"""
|
||||
Test to verify duplicate tags between request and key metadata are handled correctly
|
||||
|
||||
|
||||
Aggregation logic when checking spend can be impacted if duplicate tags are not handled correctly.
|
||||
|
||||
User feedback:
|
||||
"If I register my key with tag1 and
|
||||
also pass the same tag1 when using the key
|
||||
then I see tag1 twice in the
|
||||
LiteLLM_SpendLogs table request_tags column. This can mess up aggregation logic"
|
||||
"""
|
||||
mock_request = Mock(spec=Request)
|
||||
mock_request.url.path = "/chat/completions"
|
||||
mock_request.query_params = {}
|
||||
mock_request.headers = {}
|
||||
|
||||
# Setup key with tags in metadata
|
||||
user_api_key_dict = UserAPIKeyAuth(
|
||||
api_key="test_api_key",
|
||||
user_id="test_user_id",
|
||||
org_id="test_org_id",
|
||||
metadata={"tags": key_tags},
|
||||
)
|
||||
|
||||
# Setup request data with tags
|
||||
data = {"metadata": {"tags": request_tags}}
|
||||
|
||||
# Process request
|
||||
proxy_config = Mock()
|
||||
result = await add_litellm_data_to_request(
|
||||
data=data,
|
||||
request=mock_request,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
proxy_config=proxy_config,
|
||||
)
|
||||
|
||||
# Verify results
|
||||
assert "metadata" in result
|
||||
assert "tags" in result["metadata"]
|
||||
assert sorted(result["metadata"]["tags"]) == sorted(
|
||||
expected_tags
|
||||
), f"Expected {expected_tags}, got {result['metadata']['tags']}"
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue