feat(litellm_pre_call_utils.py): support passing tags/spend logs metadata from keys/team metadata to request

This commit is contained in:
Krrish Dholakia 2024-08-21 08:12:45 -07:00
parent 09aacfc6a8
commit 04fc0bd7b3
2 changed files with 146 additions and 1 deletions

View file

@ -224,7 +224,7 @@ async def add_litellm_data_to_request(
user_api_key_dict, "team_alias", None user_api_key_dict, "team_alias", None
) )
### KEY-LEVEL Contorls ### KEY-LEVEL Controls
key_metadata = user_api_key_dict.metadata key_metadata = user_api_key_dict.metadata
if "cache" in key_metadata: if "cache" in key_metadata:
data["cache"] = {} data["cache"] = {}
@ -233,6 +233,51 @@ async def add_litellm_data_to_request(
if k in SupportedCacheControls: if k in SupportedCacheControls:
data["cache"][k] = v data["cache"][k] = v
## KEY-LEVEL SPEND LOGS / TAGS
if "tags" in key_metadata and key_metadata["tags"] is not None:
if "tags" in data[_metadata_variable_name] and isinstance(
data[_metadata_variable_name]["tags"], list
):
data[_metadata_variable_name]["tags"].extend(key_metadata["tags"])
else:
data[_metadata_variable_name]["tags"] = key_metadata["tags"]
if "spend_logs_metadata" in key_metadata and isinstance(
key_metadata["spend_logs_metadata"], dict
):
if "spend_logs_metadata" in data[_metadata_variable_name] and isinstance(
data[_metadata_variable_name]["spend_logs_metadata"], dict
):
data[_metadata_variable_name]["spend_logs_metadata"].update(
key_metadata["spend_logs_metadata"]
)
else:
data[_metadata_variable_name]["spend_logs_metadata"] = key_metadata[
"spend_logs_metadata"
]
## TEAM-LEVEL SPEND LOGS/TAGS
team_metadata = user_api_key_dict.team_metadata or {}
if "tags" in team_metadata and team_metadata["tags"] is not None:
if "tags" in data[_metadata_variable_name] and isinstance(
data[_metadata_variable_name]["tags"], list
):
data[_metadata_variable_name]["tags"].extend(team_metadata["tags"])
else:
data[_metadata_variable_name]["tags"] = team_metadata["tags"]
if "spend_logs_metadata" in team_metadata and isinstance(
team_metadata["spend_logs_metadata"], dict
):
if "spend_logs_metadata" in data[_metadata_variable_name] and isinstance(
data[_metadata_variable_name]["spend_logs_metadata"], dict
):
data[_metadata_variable_name]["spend_logs_metadata"].update(
team_metadata["spend_logs_metadata"]
)
else:
data[_metadata_variable_name]["spend_logs_metadata"] = team_metadata[
"spend_logs_metadata"
]
# Team spend, budget - used by prometheus.py # Team spend, budget - used by prometheus.py
data[_metadata_variable_name][ data[_metadata_variable_name][
"user_api_key_team_max_budget" "user_api_key_team_max_budget"

View file

@ -104,3 +104,103 @@ async def test_traceparent_not_added_by_default(endpoint, mock_request):
assert "traceparent" not in _extra_headers assert "traceparent" not in _extra_headers
setattr(litellm.proxy.proxy_server, "open_telemetry_logger", None) setattr(litellm.proxy.proxy_server, "open_telemetry_logger", None)
@pytest.mark.parametrize(
"request_tags", [None, ["request_tag1", "request_tag2", "request_tag3"]]
)
@pytest.mark.parametrize(
"request_sl_metadata", [None, {"request_key": "request_value"}]
)
@pytest.mark.parametrize("key_tags", [None, ["key_tag1", "key_tag2", "key_tag3"]])
@pytest.mark.parametrize("key_sl_metadata", [None, {"key_key": "key_value"}])
@pytest.mark.parametrize("team_tags", [None, ["team_tag1", "team_tag2", "team_tag3"]])
@pytest.mark.parametrize("team_sl_metadata", [None, {"team_key": "team_value"}])
@pytest.mark.asyncio
async def test_add_key_or_team_level_spend_logs_metadata_to_request(
mock_request,
request_tags,
request_sl_metadata,
team_tags,
key_sl_metadata,
team_sl_metadata,
key_tags,
):
## COMPLETE LIST OF TAGS
all_tags = []
if request_tags is not None:
print("Request Tags - {}".format(request_tags))
all_tags.extend(request_tags)
if key_tags is not None:
print("Key Tags - {}".format(key_tags))
all_tags.extend(key_tags)
if team_tags is not None:
print("Team Tags - {}".format(team_tags))
all_tags.extend(team_tags)
## COMPLETE SPEND_LOGS METADATA
all_sl_metadata = {}
if request_sl_metadata is not None:
all_sl_metadata.update(request_sl_metadata)
if key_sl_metadata is not None:
all_sl_metadata.update(key_sl_metadata)
if team_sl_metadata is not None:
all_sl_metadata.update(team_sl_metadata)
print(f"team_sl_metadata: {team_sl_metadata}")
mock_request.url.path = "/chat/completions"
key_metadata = {
"tags": key_tags,
"spend_logs_metadata": key_sl_metadata,
}
team_metadata = {
"tags": team_tags,
"spend_logs_metadata": team_sl_metadata,
}
user_api_key_dict = UserAPIKeyAuth(
api_key="test_api_key",
user_id="test_user_id",
org_id="test_org_id",
metadata=key_metadata,
team_metadata=team_metadata,
)
proxy_config = Mock()
data = {"metadata": {}}
if request_tags is not None:
data["metadata"]["tags"] = request_tags
if request_sl_metadata is not None:
data["metadata"]["spend_logs_metadata"] = request_sl_metadata
print(data)
new_data = await add_litellm_data_to_request(
data, mock_request, user_api_key_dict, proxy_config
)
print("New Data: {}".format(new_data))
print("all_tags: {}".format(all_tags))
assert "metadata" in new_data
if len(all_tags) == 0:
assert "tags" not in new_data["metadata"], "Expected=No tags. Got={}".format(
new_data["metadata"]["tags"]
)
else:
assert new_data["metadata"]["tags"] == all_tags, "Expected={}. Got={}".format(
all_tags, new_data["metadata"].get("tags", None)
)
if len(all_sl_metadata.keys()) == 0:
assert (
"spend_logs_metadata" not in new_data["metadata"]
), "Expected=No spend logs metadata. Got={}".format(
new_data["metadata"]["spend_logs_metadata"]
)
else:
assert (
new_data["metadata"]["spend_logs_metadata"] == all_sl_metadata
), "Expected={}. Got={}".format(
all_sl_metadata, new_data["metadata"]["spend_logs_metadata"]
)
# assert (
# new_data["metadata"]["spend_logs_metadata"] == metadata["spend_logs_metadata"]
# )