diff --git a/tests/proxy_unit_tests/test_proxy_utils.py b/tests/proxy_unit_tests/test_proxy_utils.py index 607e54225..77e78621d 100644 --- a/tests/proxy_unit_tests/test_proxy_utils.py +++ b/tests/proxy_unit_tests/test_proxy_utils.py @@ -574,3 +574,108 @@ def test_get_docs_url(env_vars, expected_url): result = _get_docs_url() assert result == expected_url + + +@pytest.mark.parametrize( + "request_tags, tags_to_add, expected_tags", + [ + (None, None, []), # both None + (["tag1", "tag2"], None, ["tag1", "tag2"]), # tags_to_add is None + (None, ["tag3", "tag4"], ["tag3", "tag4"]), # request_tags is None + ( + ["tag1", "tag2"], + ["tag3", "tag4"], + ["tag1", "tag2", "tag3", "tag4"], + ), # both have unique tags + ( + ["tag1", "tag2"], + ["tag2", "tag3"], + ["tag1", "tag2", "tag3"], + ), # overlapping tags + ([], [], []), # both empty lists + ("not_a_list", ["tag1"], ["tag1"]), # request_tags invalid type + (["tag1"], "not_a_list", ["tag1"]), # tags_to_add invalid type + ( + ["tag1", "tag1"], + ["tag1", "tag2"], + ["tag1", "tag2"], + ), # duplicate tags in inputs + ], +) +def test_merge_tags(request_tags, tags_to_add, expected_tags): + from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup + + result = LiteLLMProxyRequestSetup._merge_tags( + request_tags=request_tags, tags_to_add=tags_to_add + ) + + assert isinstance(result, list) + assert sorted(result) == sorted(expected_tags) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "key_tags, request_tags, expected_tags", + [ + # exact duplicates + (["tag1", "tag2", "tag3"], ["tag1", "tag2", "tag3"], ["tag1", "tag2", "tag3"]), + # partial duplicates + ( + ["tag1", "tag2", "tag3"], + ["tag2", "tag3", "tag4"], + ["tag1", "tag2", "tag3", "tag4"], + ), + # duplicates within key tags + (["tag1", "tag1", "tag2"], ["tag3", "tag4"], ["tag1", "tag2", "tag3", "tag4"]), + # duplicates within request tags + (["tag1", "tag2"], ["tag3", "tag3", "tag4"], ["tag1", "tag2", "tag3", "tag4"]), + # case sensitive duplicates + (["Tag1", "TAG2"], ["tag1", "tag2"], ["Tag1", "TAG2", "tag1", "tag2"]), + ], +) +async def test_add_litellm_data_to_request_duplicate_tags( + key_tags, request_tags, expected_tags +): + """ + Test to verify duplicate tags between request and key metadata are handled correctly + + + Aggregation logic when checking spend can be impacted if duplicate tags are not handled correctly. + + User feedback: + "If I register my key with tag1 and + also pass the same tag1 when using the key + then I see tag1 twice in the + LiteLLM_SpendLogs table request_tags column. This can mess up aggregation logic" + """ + mock_request = Mock(spec=Request) + mock_request.url.path = "/chat/completions" + mock_request.query_params = {} + mock_request.headers = {} + + # Setup key with tags in metadata + user_api_key_dict = UserAPIKeyAuth( + api_key="test_api_key", + user_id="test_user_id", + org_id="test_org_id", + metadata={"tags": key_tags}, + ) + + # Setup request data with tags + data = {"metadata": {"tags": request_tags}} + + # Process request + proxy_config = Mock() + result = await add_litellm_data_to_request( + data=data, + request=mock_request, + user_api_key_dict=user_api_key_dict, + proxy_config=proxy_config, + ) + + # Verify results + assert "metadata" in result + assert "tags" in result["metadata"] + assert sorted(result["metadata"]["tags"]) == sorted( + expected_tags + ), f"Expected {expected_tags}, got {result['metadata']['tags']}"