diff --git a/litellm/proxy/litellm_pre_call_utils.py b/litellm/proxy/litellm_pre_call_utils.py index 9166e1341..6c5d556c1 100644 --- a/litellm/proxy/litellm_pre_call_utils.py +++ b/litellm/proxy/litellm_pre_call_utils.py @@ -224,7 +224,7 @@ async def add_litellm_data_to_request( user_api_key_dict, "team_alias", None ) - ### KEY-LEVEL Contorls + ### KEY-LEVEL Controls key_metadata = user_api_key_dict.metadata if "cache" in key_metadata: data["cache"] = {} @@ -233,6 +233,51 @@ async def add_litellm_data_to_request( if k in SupportedCacheControls: data["cache"][k] = v + ## KEY-LEVEL SPEND LOGS / TAGS + if "tags" in key_metadata and key_metadata["tags"] is not None: + if "tags" in data[_metadata_variable_name] and isinstance( + data[_metadata_variable_name]["tags"], list + ): + data[_metadata_variable_name]["tags"].extend(key_metadata["tags"]) + else: + data[_metadata_variable_name]["tags"] = key_metadata["tags"] + if "spend_logs_metadata" in key_metadata and isinstance( + key_metadata["spend_logs_metadata"], dict + ): + if "spend_logs_metadata" in data[_metadata_variable_name] and isinstance( + data[_metadata_variable_name]["spend_logs_metadata"], dict + ): + data[_metadata_variable_name]["spend_logs_metadata"].update( + key_metadata["spend_logs_metadata"] + ) + else: + data[_metadata_variable_name]["spend_logs_metadata"] = key_metadata[ + "spend_logs_metadata" + ] + + ## TEAM-LEVEL SPEND LOGS/TAGS + team_metadata = user_api_key_dict.team_metadata or {} + if "tags" in team_metadata and team_metadata["tags"] is not None: + if "tags" in data[_metadata_variable_name] and isinstance( + data[_metadata_variable_name]["tags"], list + ): + data[_metadata_variable_name]["tags"].extend(team_metadata["tags"]) + else: + data[_metadata_variable_name]["tags"] = team_metadata["tags"] + if "spend_logs_metadata" in team_metadata and isinstance( + team_metadata["spend_logs_metadata"], dict + ): + if "spend_logs_metadata" in data[_metadata_variable_name] and isinstance( + data[_metadata_variable_name]["spend_logs_metadata"], dict + ): + data[_metadata_variable_name]["spend_logs_metadata"].update( + team_metadata["spend_logs_metadata"] + ) + else: + data[_metadata_variable_name]["spend_logs_metadata"] = team_metadata[ + "spend_logs_metadata" + ] + # Team spend, budget - used by prometheus.py data[_metadata_variable_name][ "user_api_key_team_max_budget" diff --git a/litellm/tests/test_proxy_utils.py b/litellm/tests/test_proxy_utils.py index c8f6916e2..63361b09a 100644 --- a/litellm/tests/test_proxy_utils.py +++ b/litellm/tests/test_proxy_utils.py @@ -104,3 +104,103 @@ async def test_traceparent_not_added_by_default(endpoint, mock_request): assert "traceparent" not in _extra_headers setattr(litellm.proxy.proxy_server, "open_telemetry_logger", None) + + +@pytest.mark.parametrize( + "request_tags", [None, ["request_tag1", "request_tag2", "request_tag3"]] +) +@pytest.mark.parametrize( + "request_sl_metadata", [None, {"request_key": "request_value"}] +) +@pytest.mark.parametrize("key_tags", [None, ["key_tag1", "key_tag2", "key_tag3"]]) +@pytest.mark.parametrize("key_sl_metadata", [None, {"key_key": "key_value"}]) +@pytest.mark.parametrize("team_tags", [None, ["team_tag1", "team_tag2", "team_tag3"]]) +@pytest.mark.parametrize("team_sl_metadata", [None, {"team_key": "team_value"}]) +@pytest.mark.asyncio +async def test_add_key_or_team_level_spend_logs_metadata_to_request( + mock_request, + request_tags, + request_sl_metadata, + team_tags, + key_sl_metadata, + team_sl_metadata, + key_tags, +): + ## COMPLETE LIST OF TAGS + all_tags = [] + if request_tags is not None: + print("Request Tags - {}".format(request_tags)) + all_tags.extend(request_tags) + if key_tags is not None: + print("Key Tags - {}".format(key_tags)) + all_tags.extend(key_tags) + if team_tags is not None: + print("Team Tags - {}".format(team_tags)) + all_tags.extend(team_tags) + + ## COMPLETE SPEND_LOGS METADATA + all_sl_metadata = {} + if request_sl_metadata is not None: + all_sl_metadata.update(request_sl_metadata) + if key_sl_metadata is not None: + all_sl_metadata.update(key_sl_metadata) + if team_sl_metadata is not None: + all_sl_metadata.update(team_sl_metadata) + + print(f"team_sl_metadata: {team_sl_metadata}") + mock_request.url.path = "/chat/completions" + key_metadata = { + "tags": key_tags, + "spend_logs_metadata": key_sl_metadata, + } + team_metadata = { + "tags": team_tags, + "spend_logs_metadata": team_sl_metadata, + } + user_api_key_dict = UserAPIKeyAuth( + api_key="test_api_key", + user_id="test_user_id", + org_id="test_org_id", + metadata=key_metadata, + team_metadata=team_metadata, + ) + proxy_config = Mock() + + data = {"metadata": {}} + if request_tags is not None: + data["metadata"]["tags"] = request_tags + if request_sl_metadata is not None: + data["metadata"]["spend_logs_metadata"] = request_sl_metadata + + print(data) + new_data = await add_litellm_data_to_request( + data, mock_request, user_api_key_dict, proxy_config + ) + + print("New Data: {}".format(new_data)) + print("all_tags: {}".format(all_tags)) + assert "metadata" in new_data + if len(all_tags) == 0: + assert "tags" not in new_data["metadata"], "Expected=No tags. Got={}".format( + new_data["metadata"]["tags"] + ) + else: + assert new_data["metadata"]["tags"] == all_tags, "Expected={}. Got={}".format( + all_tags, new_data["metadata"].get("tags", None) + ) + + if len(all_sl_metadata.keys()) == 0: + assert ( + "spend_logs_metadata" not in new_data["metadata"] + ), "Expected=No spend logs metadata. Got={}".format( + new_data["metadata"]["spend_logs_metadata"] + ) + else: + assert ( + new_data["metadata"]["spend_logs_metadata"] == all_sl_metadata + ), "Expected={}. Got={}".format( + all_sl_metadata, new_data["metadata"]["spend_logs_metadata"] + ) + # assert ( + # new_data["metadata"]["spend_logs_metadata"] == metadata["spend_logs_metadata"] + # )