From 9393434d01c65e048748bf5217a413d5e769c7f6 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Wed, 27 Nov 2024 18:40:33 -0800 Subject: [PATCH] (fix) tag merging / aggregation logic (#6932) * use 1 helper to merge tags + ensure unique ness * test_add_litellm_data_to_request_duplicate_tags * fix _merge_tags * fix proxy utils test --- litellm/proxy/litellm_pre_call_utils.py | 46 ++++++--- tests/proxy_unit_tests/test_proxy_utils.py | 105 +++++++++++++++++++++ 2 files changed, 139 insertions(+), 12 deletions(-) diff --git a/litellm/proxy/litellm_pre_call_utils.py b/litellm/proxy/litellm_pre_call_utils.py index 3d1d3b491..6ac792696 100644 --- a/litellm/proxy/litellm_pre_call_utils.py +++ b/litellm/proxy/litellm_pre_call_utils.py @@ -288,12 +288,12 @@ class LiteLLMProxyRequestSetup: ## KEY-LEVEL SPEND LOGS / TAGS if "tags" in key_metadata and key_metadata["tags"] is not None: - if "tags" in data[_metadata_variable_name] and isinstance( - data[_metadata_variable_name]["tags"], list - ): - data[_metadata_variable_name]["tags"].extend(key_metadata["tags"]) - else: - data[_metadata_variable_name]["tags"] = key_metadata["tags"] + data[_metadata_variable_name]["tags"] = ( + LiteLLMProxyRequestSetup._merge_tags( + request_tags=data[_metadata_variable_name].get("tags"), + tags_to_add=key_metadata["tags"], + ) + ) if "spend_logs_metadata" in key_metadata and isinstance( key_metadata["spend_logs_metadata"], dict ): @@ -319,6 +319,30 @@ class LiteLLMProxyRequestSetup: data["disable_fallbacks"] = key_metadata["disable_fallbacks"] return data + @staticmethod + def _merge_tags(request_tags: Optional[list], tags_to_add: Optional[list]) -> list: + """ + Helper function to merge two lists of tags, ensuring no duplicates. + + Args: + request_tags (Optional[list]): List of tags from the original request + tags_to_add (Optional[list]): List of tags to add + + Returns: + list: Combined list of unique tags + """ + final_tags = [] + + if request_tags and isinstance(request_tags, list): + final_tags.extend(request_tags) + + if tags_to_add and isinstance(tags_to_add, list): + for tag in tags_to_add: + if tag not in final_tags: + final_tags.append(tag) + + return final_tags + async def add_litellm_data_to_request( # noqa: PLR0915 data: dict, @@ -442,12 +466,10 @@ async def add_litellm_data_to_request( # noqa: PLR0915 ## TEAM-LEVEL SPEND LOGS/TAGS team_metadata = user_api_key_dict.team_metadata or {} if "tags" in team_metadata and team_metadata["tags"] is not None: - if "tags" in data[_metadata_variable_name] and isinstance( - data[_metadata_variable_name]["tags"], list - ): - data[_metadata_variable_name]["tags"].extend(team_metadata["tags"]) - else: - data[_metadata_variable_name]["tags"] = team_metadata["tags"] + data[_metadata_variable_name]["tags"] = LiteLLMProxyRequestSetup._merge_tags( + request_tags=data[_metadata_variable_name].get("tags"), + tags_to_add=team_metadata["tags"], + ) if "spend_logs_metadata" in team_metadata and isinstance( team_metadata["spend_logs_metadata"], dict ): diff --git a/tests/proxy_unit_tests/test_proxy_utils.py b/tests/proxy_unit_tests/test_proxy_utils.py index 607e54225..1df6b82ed 100644 --- a/tests/proxy_unit_tests/test_proxy_utils.py +++ b/tests/proxy_unit_tests/test_proxy_utils.py @@ -574,3 +574,108 @@ def test_get_docs_url(env_vars, expected_url): result = _get_docs_url() assert result == expected_url + + +@pytest.mark.parametrize( + "request_tags, tags_to_add, expected_tags", + [ + (None, None, []), # both None + (["tag1", "tag2"], None, ["tag1", "tag2"]), # tags_to_add is None + (None, ["tag3", "tag4"], ["tag3", "tag4"]), # request_tags is None + ( + ["tag1", "tag2"], + ["tag3", "tag4"], + ["tag1", "tag2", "tag3", "tag4"], + ), # both have unique tags + ( + ["tag1", "tag2"], + ["tag2", "tag3"], + ["tag1", "tag2", "tag3"], + ), # overlapping tags + ([], [], []), # both empty lists + ("not_a_list", ["tag1"], ["tag1"]), # request_tags invalid type + (["tag1"], "not_a_list", ["tag1"]), # tags_to_add invalid type + ( + ["tag1"], + ["tag1", "tag2"], + ["tag1", "tag2"], + ), # duplicate tags in inputs + ], +) +def test_merge_tags(request_tags, tags_to_add, expected_tags): + from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup + + result = LiteLLMProxyRequestSetup._merge_tags( + request_tags=request_tags, tags_to_add=tags_to_add + ) + + assert isinstance(result, list) + assert sorted(result) == sorted(expected_tags) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "key_tags, request_tags, expected_tags", + [ + # exact duplicates + (["tag1", "tag2", "tag3"], ["tag1", "tag2", "tag3"], ["tag1", "tag2", "tag3"]), + # partial duplicates + ( + ["tag1", "tag2", "tag3"], + ["tag2", "tag3", "tag4"], + ["tag1", "tag2", "tag3", "tag4"], + ), + # duplicates within key tags + (["tag1", "tag2"], ["tag3", "tag4"], ["tag1", "tag2", "tag3", "tag4"]), + # duplicates within request tags + (["tag1", "tag2"], ["tag2", "tag3", "tag4"], ["tag1", "tag2", "tag3", "tag4"]), + # case sensitive duplicates + (["Tag1", "TAG2"], ["tag1", "tag2"], ["Tag1", "TAG2", "tag1", "tag2"]), + ], +) +async def test_add_litellm_data_to_request_duplicate_tags( + key_tags, request_tags, expected_tags +): + """ + Test to verify duplicate tags between request and key metadata are handled correctly + + + Aggregation logic when checking spend can be impacted if duplicate tags are not handled correctly. + + User feedback: + "If I register my key with tag1 and + also pass the same tag1 when using the key + then I see tag1 twice in the + LiteLLM_SpendLogs table request_tags column. This can mess up aggregation logic" + """ + mock_request = Mock(spec=Request) + mock_request.url.path = "/chat/completions" + mock_request.query_params = {} + mock_request.headers = {} + + # Setup key with tags in metadata + user_api_key_dict = UserAPIKeyAuth( + api_key="test_api_key", + user_id="test_user_id", + org_id="test_org_id", + metadata={"tags": key_tags}, + ) + + # Setup request data with tags + data = {"metadata": {"tags": request_tags}} + + # Process request + proxy_config = Mock() + result = await add_litellm_data_to_request( + data=data, + request=mock_request, + user_api_key_dict=user_api_key_dict, + proxy_config=proxy_config, + ) + + # Verify results + assert "metadata" in result + assert "tags" in result["metadata"] + assert sorted(result["metadata"]["tags"]) == sorted( + expected_tags + ), f"Expected {expected_tags}, got {result['metadata']['tags']}"