Litellm dev 12 20 2024 p3 (#7339)

* fix(proxy_track_cost_callback.py): log to db if only end user param given

* fix: allow jwt-auth based end user id spend tracking to work

* fix(utils.py): fix 'get_end_user_id_for_cost_tracking' to use 'user_api_key_end_user_id'

more stable - works with jwt-auth based end user tracking as well (see the sketch below)

* test(test_jwt.py): add e2e unit test to confirm end user cost tracking works for spend logs

* test: update test to use end_user api key hash param

* fix(langfuse.py): support end user cost tracking via jwt auth + langfuse

logs end user to langfuse if decoded from jwt token

* fix: fix linting errors

* test: fix test

* test: fix test

* fix: fix end user id extraction

* fix: run test earlier
Krish Dholakia, 2024-12-20 21:13:32 -08:00, committed by GitHub
parent 1b2ed0c344
commit 61b4c41c3c
13 changed files with 149 additions and 36 deletions
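
How the pieces fit together: user_api_key_auth now records the end user id (the request body's "user" field, or the JWT "sub" claim) on the auth object, LiteLLMProxyRequestSetup copies it into the logging metadata as user_api_key_end_user_id, and cost tracking reads it from there instead of re-parsing the raw request body. A minimal sketch of the new lookup, simplified from the get_end_user_id_for_cost_tracking hunk below (the prometheus-only disable flag is elided, and resolve_end_user_id is an illustrative name, not a function in the repo):

from typing import Optional

import litellm


def resolve_end_user_id(litellm_params: dict) -> Optional[str]:
    # Same precedence as the real helper: top-level key first, then metadata.
    # user_api_key_auth writes this field for both body- and jwt-based end users.
    _metadata = litellm_params.get("metadata", {}) or {}
    end_user_id = litellm_params.get("user_api_key_end_user_id") or _metadata.get(
        "user_api_key_end_user_id"
    )
    if litellm.disable_end_user_cost_tracking:
        return None
    return end_user_id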


@@ -458,6 +458,18 @@ class LangFuseLogger:
         tags = metadata.pop("tags", []) if supports_tags else []
+
+        standard_logging_object: Optional[StandardLoggingPayload] = cast(
+            Optional[StandardLoggingPayload],
+            kwargs.get("standard_logging_object", None),
+        )
+        if standard_logging_object is None:
+            end_user_id = None
+        else:
+            end_user_id = standard_logging_object["metadata"].get(
+                "user_api_key_end_user_id", None
+            )
+
         # Clean Metadata before logging - never log raw metadata
         # the raw metadata can contain circular references which leads to infinite recursion
         # we clean out all extra litellm metadata params before logging
@@ -541,7 +553,7 @@ class LangFuseLogger:
                 "version": clean_metadata.pop(
                     "trace_version", clean_metadata.get("version", None)
                 ),  # If provided just version, it will applied to the trace as well, if applied a trace version it will take precedence
-                "user_id": user_id,
+                "user_id": end_user_id,
             }
             for key in list(
                 filter(lambda key: key.startswith("trace_"), clean_metadata.keys())
@@ -567,10 +579,6 @@ class LangFuseLogger:
             cost = kwargs.get("response_cost", None)
             print_verbose(f"trace: {cost}")
-
-            standard_logging_object: Optional[StandardLoggingPayload] = kwargs.get(
-                "standard_logging_object", None
-            )
             clean_metadata["litellm_response_cost"] = cost
             if standard_logging_object is not None:
                 clean_metadata["hidden_params"] = standard_logging_object[
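
Net effect of the three Langfuse hunks above: the trace-level user_id sent to Langfuse is now the end user carried on the standard logging payload, so end users decoded from a JWT show up in Langfuse as well. A sketch of the extraction (extract_end_user_id is an illustrative name; the payload shape is taken from the hunk above):

from typing import Optional


def extract_end_user_id(kwargs: dict) -> Optional[str]:
    # kwargs is the callback payload litellm hands the logger; the standard
    # logging object's metadata carries the proxy auth fields, including the
    # new user_api_key_end_user_id.
    slo = kwargs.get("standard_logging_object")
    if slo is None:
        return None
    return slo["metadata"].get("user_api_key_end_user_id")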


@@ -2619,6 +2619,7 @@ class StandardLoggingPayloadSetup:
             spend_logs_metadata=None,
             requester_ip_address=None,
             requester_metadata=None,
+            user_api_key_end_user_id=None,
         )
         if isinstance(metadata, dict):
             # Filter the metadata dictionary to include only the specified keys
@@ -3075,6 +3076,7 @@ def get_standard_logging_metadata(
         spend_logs_metadata=None,
         requester_ip_address=None,
         requester_metadata=None,
+        user_api_key_end_user_id=None,
     )
     if isinstance(metadata, dict):
         # Filter the metadata dictionary to include only the specified keys


@@ -20,9 +20,4 @@ model_list:
       api_version: "2024-05-01-preview"

 litellm_settings:
   default_team_settings:
-    - team_id: c91e32bb-0f2a-4aa1-86c4-307ca2e03ea3
-      success_callback: ["langfuse"]
-      failure_callback: ["langfuse"]
-      langfuse_public_key: my-fake-key
-      langfuse_secret: my-fake-secret


@@ -608,8 +608,12 @@ async def user_api_key_auth( # noqa: PLR0915
         end_user_params = {}
         if "user" in request_data:
             try:
+                end_user_id = request_data["user"]
+                end_user_params["end_user_id"] = end_user_id
+
+                # get end-user object
                 _end_user_object = await get_end_user_object(
-                    end_user_id=request_data["user"],
+                    end_user_id=end_user_id,
                     prisma_client=prisma_client,
                     user_api_key_cache=user_api_key_cache,
                     parent_otel_span=parent_otel_span,
@@ -621,7 +625,6 @@
                 )
                 if _end_user_object.litellm_budget_table is not None:
                     budget_info = _end_user_object.litellm_budget_table
-                    end_user_params["end_user_id"] = _end_user_object.user_id
                     if budget_info.tpm_limit is not None:
                         end_user_params["end_user_tpm_limit"] = (
                             budget_info.tpm_limit


@@ -60,7 +60,7 @@ def _safe_get_request_headers(request: Optional[Request]) -> dict:
             return {}
         return dict(request.headers)
     except Exception as e:
-        verbose_proxy_logger.exception(
+        verbose_proxy_logger.debug(
             "Unexpected error reading request headers - {}".format(e)
         )
         return {}


@@ -1,6 +1,6 @@
 import asyncio
 import traceback
-from typing import Optional, Union
+from typing import Optional, Union, cast

 import litellm
 from litellm._logging import verbose_proxy_logger
@@ -36,10 +36,10 @@ async def _PROXY_track_cost_callback(
         litellm_params = kwargs.get("litellm_params", {}) or {}
         end_user_id = get_end_user_id_for_cost_tracking(litellm_params)
         metadata = get_litellm_metadata_from_kwargs(kwargs=kwargs)
-        user_id = metadata.get("user_api_key_user_id", None)
-        team_id = metadata.get("user_api_key_team_id", None)
-        org_id = metadata.get("user_api_key_org_id", None)
-        key_alias = metadata.get("user_api_key_alias", None)
+        user_id = cast(Optional[str], metadata.get("user_api_key_user_id", None))
+        team_id = cast(Optional[str], metadata.get("user_api_key_team_id", None))
+        org_id = cast(Optional[str], metadata.get("user_api_key_org_id", None))
+        key_alias = cast(Optional[str], metadata.get("user_api_key_alias", None))
         end_user_max_budget = metadata.get("user_api_end_user_max_budget", None)
         sl_object: Optional[StandardLoggingPayload] = kwargs.get(
             "standard_logging_object", None
@@ -61,7 +61,12 @@
             verbose_proxy_logger.debug(
                 f"user_api_key {user_api_key}, prisma_client: {prisma_client}"
             )
-            if user_api_key is not None or user_id is not None or team_id is not None:
+            if _should_track_cost_callback(
+                user_api_key=user_api_key,
+                user_id=user_id,
+                team_id=team_id,
+                end_user_id=end_user_id,
+            ):
                 ## UPDATE DATABASE
                 await update_database(
                     token=user_api_key,
@@ -128,3 +133,22 @@
             )
         )
         verbose_proxy_logger.exception("Error in tracking cost callback - %s", str(e))
+
+
+def _should_track_cost_callback(
+    user_api_key: Optional[str],
+    user_id: Optional[str],
+    team_id: Optional[str],
+    end_user_id: Optional[str],
+) -> bool:
+    """
+    Determine if the cost callback should be tracked based on the kwargs
+    """
+    if (
+        user_api_key is not None
+        or user_id is not None
+        or team_id is not None
+        or end_user_id is not None
+    ):
+        return True
+    return False
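
The new gate replaces the old user_api_key/user_id/team_id check, so a request identified only by an end user id (the jwt-auth case) now reaches update_database and gets a spend log. A quick illustration of the boundary behavior, using the helper added above:

from litellm.proxy.hooks.proxy_track_cost_callback import _should_track_cost_callback

# End-user-only request (e.g. id decoded from a JWT): now tracked.
assert _should_track_cost_callback(
    user_api_key=None, user_id=None, team_id=None, end_user_id="1234"
)

# No identifiers at all: nothing to attribute spend to, so the DB write is skipped.
assert not _should_track_cost_callback(
    user_api_key=None, user_id=None, team_id=None, end_user_id=None
)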


@@ -72,8 +72,12 @@ def safe_add_api_version_from_query_params(data: dict, request: Request):
         query_params = dict(request.query_params)
         if "api-version" in query_params:
             data["api_version"] = query_params["api-version"]
+    except KeyError:
+        pass
     except Exception as e:
-        verbose_logger.error("error checking api version in query params: %s", str(e))
+        verbose_logger.exception(
+            "error checking api version in query params: %s", str(e)
+        )


 def convert_key_logging_metadata_to_callback(
@@ -266,6 +270,7 @@ class LiteLLMProxyRequestSetup:
             user_api_key_user_id=user_api_key_dict.user_id,
             user_api_key_org_id=user_api_key_dict.org_id,
             user_api_key_team_alias=user_api_key_dict.team_alias,
+            user_api_key_end_user_id=user_api_key_dict.end_user_id,
         )
         return user_api_key_logged_metadata


@@ -1420,6 +1420,7 @@ class StandardLoggingUserAPIKeyMetadata(TypedDict):
     user_api_key_team_id: Optional[str]
     user_api_key_user_id: Optional[str]
     user_api_key_team_alias: Optional[str]
+    user_api_key_end_user_id: Optional[str]


 class StandardLoggingMetadata(StandardLoggingUserAPIKeyMetadata):


@@ -6269,7 +6269,13 @@ def get_end_user_id_for_cost_tracking(
         service_type: "litellm_logging" or "prometheus" - used to allow prometheus only disable cost tracking.
     """
     proxy_server_request = litellm_params.get("proxy_server_request") or {}
+    _metadata = cast(dict, litellm_params.get("metadata", {}) or {})
+    end_user_id = cast(
+        Optional[str],
+        litellm_params.get("user_api_key_end_user_id")
+        or _metadata.get("user_api_key_end_user_id"),
+    )
     if litellm.disable_end_user_cost_tracking:
         return None
     if (
@@ -6277,7 +6283,7 @@
         and litellm.disable_end_user_cost_tracking_prometheus_only
     ):
         return None
-    return proxy_server_request.get("body", {}).get("user", None)
+    return end_user_id


 def is_prompt_caching_valid_prompt(


@@ -1113,8 +1113,8 @@ def test_models_by_provider():
     "litellm_params, disable_end_user_cost_tracking, expected_end_user_id",
     [
         ({}, False, None),
-        ({"proxy_server_request": {"body": {"user": "123"}}}, False, "123"),
-        ({"proxy_server_request": {"body": {"user": "123"}}}, True, None),
+        ({"user_api_key_end_user_id": "123"}, False, "123"),
+        ({"user_api_key_end_user_id": "123"}, True, None),
     ],
 )
 def test_get_end_user_id_for_cost_tracking(
@@ -1133,8 +1133,8 @@ def test_get_end_user_id_for_cost_tracking(
     "litellm_params, disable_end_user_cost_tracking_prometheus_only, expected_end_user_id",
     [
         ({}, False, None),
-        ({"proxy_server_request": {"body": {"user": "123"}}}, False, "123"),
-        ({"proxy_server_request": {"body": {"user": "123"}}}, True, None),
+        ({"user_api_key_end_user_id": "123"}, False, "123"),
+        ({"user_api_key_end_user_id": "123"}, True, None),
     ],
 )
 def test_get_end_user_id_for_cost_tracking_prometheus_only(


@@ -270,6 +270,7 @@ def validate_redacted_message_span_attributes(span):
         "metadata.user_api_key_alias",
         "metadata.user_api_key_user_id",
         "metadata.user_api_key_org_id",
+        "metadata.user_api_key_end_user_id",
     ]

     _all_attributes = set(


@@ -22,7 +22,8 @@ from unittest.mock import AsyncMock, MagicMock, patch
 import pytest
 from fastapi import Request
 from fastapi.routing import APIRoute
+from fastapi.responses import Response

 import litellm
 from litellm.caching.caching import DualCache
 from litellm.proxy._types import LiteLLM_JWTAuth, LiteLLM_UserTable, LiteLLMRoutes
@@ -1035,6 +1036,7 @@ async def test_end_user_jwt_auth(monkeypatch):
     from litellm.caching import DualCache
     from litellm.proxy._types import LiteLLM_JWTAuth
     from litellm.proxy.proxy_server import user_api_key_auth
+    import json

     monkeypatch.delenv("JWT_AUDIENCE", None)
     jwt_handler = JWTHandler()
@@ -1094,21 +1096,70 @@
     bearer_token = "Bearer " + token

-    request = Request(scope={"type": "http"})
-    request._url = URL(url="/chat/completions")
+    api_route = APIRoute(path="/chat/completions", endpoint=chat_completion)
+    request = Request(
+        {
+            "type": "http",
+            "route": api_route,
+            "path": "/chat/completions",
+            "headers": [(b"authorization", f"Bearer {bearer_token}".encode("latin-1"))],
+            "method": "POST",
+        }
+    )
+
+    async def return_body():
+        body_dict = {
+            "model": "openai/gpt-3.5-turbo",
+            "messages": [{"role": "user", "content": "Hello, how are you?"}],
+        }
+        # Serialize the dictionary to JSON and encode it to bytes
+        return json.dumps(body_dict).encode("utf-8")
+
+    request.body = return_body

     ## 1. INITIAL TEAM CALL - should fail
     # use generated key to auth in
     setattr(
         litellm.proxy.proxy_server,
         "general_settings",
-        {
-            "enable_jwt_auth": True,
-        },
+        {"enable_jwt_auth": True, "pass_through_all_models": True},
     )
+    setattr(
+        litellm.proxy.proxy_server,
+        "llm_router",
+        MagicMock(),
+    )
+    setattr(litellm.proxy.proxy_server, "prisma_client", {})
     setattr(litellm.proxy.proxy_server, "jwt_handler", jwt_handler)
+
+    from litellm.proxy.proxy_server import cost_tracking
+
+    cost_tracking()
+
     result = await user_api_key_auth(request=request, api_key=bearer_token)
     assert (
         result.end_user_id == "81b3e52a-67a6-4efb-9645-70527e101479"
     )  # jwt token decoded sub value
+
+    temp_response = Response()
+    from litellm.proxy.hooks.proxy_track_cost_callback import (
+        _should_track_cost_callback,
+    )
+
+    with patch.object(
+        litellm.proxy.hooks.proxy_track_cost_callback, "_should_track_cost_callback"
+    ) as mock_client:
+        resp = await chat_completion(
+            request=request,
+            fastapi_response=temp_response,
+            model="gpt-4o",
+            user_api_key_dict=result,
+        )
+
+        assert resp is not None
+
+        await asyncio.sleep(1)
+
+        mock_client.assert_called_once()
+
+        mock_client.call_args.kwargs[
+            "end_user_id"
+        ] == "81b3e52a-67a6-4efb-9645-70527e101479"


@@ -588,8 +588,6 @@ def test_call_with_end_user_over_budget(prisma_client):
         request._url = URL(url="/chat/completions")
         bearer_token = "Bearer sk-1234"

-        result = await user_api_key_auth(request=request, api_key=bearer_token)
-
         async def return_body():
             return_string = f'{{"model": "gemini-pro-vision", "user": "{user}"}}'
             # return string as bytes
@@ -597,6 +595,8 @@ def test_call_with_end_user_over_budget(prisma_client):
         request.body = return_body
+
+        result = await user_api_key_auth(request=request, api_key=bearer_token)

         # update spend using track_cost callback, make 2nd request, it should fail
         from litellm import Choices, Message, ModelResponse, Usage
         from litellm.proxy.proxy_server import (
@@ -624,7 +624,7 @@
                 "litellm_params": {
                     "metadata": {
                         "user_api_key": "sk-1234",
-                        "user_api_key_user_id": user,
+                        "user_api_key_end_user_id": user,
                     },
                     "proxy_server_request": {
                         "body": {
@@ -653,6 +653,7 @@
         asyncio.run(test())
     except Exception as e:
         print(f"raised error: {e}, traceback: {traceback.format_exc()}")
+        error_detail = e.message
         assert "Budget has been exceeded! Current" in error_detail
         assert isinstance(e, ProxyException)
@@ -3634,3 +3635,19 @@ async def test_enforce_unique_key_alias(prisma_client):
     except Exception as e:
         print("Unexpected error:", e)
         pytest.fail(f"An unexpected error occurred: {str(e)}")
+
+
+def test_should_track_cost_callback():
+    """
+    Test that the should_track_cost_callback function works as expected
+    """
+    from litellm.proxy.hooks.proxy_track_cost_callback import (
+        _should_track_cost_callback,
+    )
+
+    assert _should_track_cost_callback(
+        user_api_key=None,
+        user_id=None,
+        team_id=None,
+        end_user_id="1234",
+    )