Litellm dev 12 20 2024 p3 (#7339)

* fix(proxy_track_cost_callback.py): log to db if only end user param given

* fix: allow jwt-auth based end user id spend tracking to work

* fix(utils.py): fix 'get_end_user_id_for_cost_tracking' to use 'user_api_key_end_user_id'

more stable - works with jwt-auth based end user tracking as well

* test(test_jwt.py): add e2e unit test to confirm end user cost tracking works for spend logs

* test: update test to use end_user api key hash param

* fix(langfuse.py): support end user cost tracking via jwt auth + langfuse

logs the end user to Langfuse if decoded from the JWT token

* fix: fix linting errors

* test: fix test

* test: fix test

* fix: fix end user id extraction

* fix: run test earlier
Krish Dholakia, 2024-12-20 21:13:32 -08:00 (committed by GitHub)
parent 1b2ed0c344
commit 61b4c41c3c
13 changed files with 149 additions and 36 deletions
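
The central change: end-user spend tracking no longer re-parses the raw request body (proxy_server_request["body"]["user"]) but reads the user_api_key_end_user_id that the auth layer now stashes in litellm_params / metadata, which is what makes jwt-auth end users trackable. A self-contained sketch of the new lookup order (the helper name and standalone form are illustrative, not the actual utils.py function, which also consults litellm.disable_end_user_cost_tracking):

    from typing import Optional

    def resolve_end_user_id(litellm_params: dict) -> Optional[str]:
        # Prefer the top-level key, fall back to the copy nested in metadata.
        metadata = litellm_params.get("metadata") or {}
        return litellm_params.get("user_api_key_end_user_id") or metadata.get(
            "user_api_key_end_user_id"
        )

    # Works whether auth set the id directly or nested it in metadata:
    assert resolve_end_user_id({"user_api_key_end_user_id": "123"}) == "123"
    assert resolve_end_user_id({"metadata": {"user_api_key_end_user_id": "123"}}) == "123"
    assert resolve_end_user_id({}) is None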


@@ -458,6 +458,18 @@ class LangFuseLogger:
         tags = metadata.pop("tags", []) if supports_tags else []
 
+        standard_logging_object: Optional[StandardLoggingPayload] = cast(
+            Optional[StandardLoggingPayload],
+            kwargs.get("standard_logging_object", None),
+        )
+        if standard_logging_object is None:
+            end_user_id = None
+        else:
+            end_user_id = standard_logging_object["metadata"].get(
+                "user_api_key_end_user_id", None
+            )
+
         # Clean Metadata before logging - never log raw metadata
         # the raw metadata can contain circular references which leads to infinite recursion
         # we clean out all extra litellm metadata params before logging
@@ -541,7 +553,7 @@ class LangFuseLogger:
             "version": clean_metadata.pop(
                 "trace_version", clean_metadata.get("version", None)
             ),  # If provided just version, it will applied to the trace as well, if applied a trace version it will take precedence
-            "user_id": user_id,
+            "user_id": end_user_id,
         }
         for key in list(
             filter(lambda key: key.startswith("trace_"), clean_metadata.keys())
@@ -567,10 +579,6 @@ class LangFuseLogger:
         cost = kwargs.get("response_cost", None)
         print_verbose(f"trace: {cost}")
 
-        standard_logging_object: Optional[StandardLoggingPayload] = kwargs.get(
-            "standard_logging_object", None
-        )
-
         clean_metadata["litellm_response_cost"] = cost
         if standard_logging_object is not None:
             clean_metadata["hidden_params"] = standard_logging_object[


@@ -2619,6 +2619,7 @@ class StandardLoggingPayloadSetup:
             spend_logs_metadata=None,
             requester_ip_address=None,
             requester_metadata=None,
+            user_api_key_end_user_id=None,
         )
         if isinstance(metadata, dict):
             # Filter the metadata dictionary to include only the specified keys
@@ -3075,6 +3076,7 @@ def get_standard_logging_metadata(
         spend_logs_metadata=None,
         requester_ip_address=None,
         requester_metadata=None,
+        user_api_key_end_user_id=None,
    )
    if isinstance(metadata, dict):
        # Filter the metadata dictionary to include only the specified keys


@@ -20,9 +20,4 @@ model_list:
       api_version: "2024-05-01-preview"
 
 litellm_settings:
-  default_team_settings:
-    - team_id: c91e32bb-0f2a-4aa1-86c4-307ca2e03ea3
-      success_callback: ["langfuse"]
-      failure_callback: ["langfuse"]
-      langfuse_public_key: my-fake-key
-      langfuse_secret: my-fake-secret
+  success_callback: ["langfuse"]


@@ -608,8 +608,12 @@ async def user_api_key_auth( # noqa: PLR0915
         end_user_params = {}
         if "user" in request_data:
             try:
+                end_user_id = request_data["user"]
+                end_user_params["end_user_id"] = end_user_id
+
+                # get end-user object
                 _end_user_object = await get_end_user_object(
-                    end_user_id=request_data["user"],
+                    end_user_id=end_user_id,
                     prisma_client=prisma_client,
                     user_api_key_cache=user_api_key_cache,
                     parent_otel_span=parent_otel_span,
@@ -621,7 +625,6 @@ async def user_api_key_auth( # noqa: PLR0915
                 )
                 if _end_user_object.litellm_budget_table is not None:
                     budget_info = _end_user_object.litellm_budget_table
-                    end_user_params["end_user_id"] = _end_user_object.user_id
                     if budget_info.tpm_limit is not None:
                         end_user_params["end_user_tpm_limit"] = (
                             budget_info.tpm_limit
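
The behavioral fix here: end_user_id used to be recorded only inside the litellm_budget_table branch, so end users without a configured budget were invisible to spend tracking. It is now captured as soon as "user" appears in the request body, as this toy reduction shows (the literal values are hypothetical; the control flow matches the diff):

    request_data = {"model": "gpt-4o", "user": "customer-42"}

    end_user_params: dict = {}
    if "user" in request_data:
        # set unconditionally now - no longer gated on a budget table existing
        end_user_params["end_user_id"] = request_data["user"]

    assert end_user_params == {"end_user_id": "customer-42"}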


@@ -60,7 +60,7 @@ def _safe_get_request_headers(request: Optional[Request]) -> dict:
             return {}
         return dict(request.headers)
     except Exception as e:
-        verbose_proxy_logger.exception(
+        verbose_proxy_logger.debug(
             "Unexpected error reading request headers - {}".format(e)
         )
         return {}


@@ -1,6 +1,6 @@
 import asyncio
 import traceback
-from typing import Optional, Union
+from typing import Optional, Union, cast
 
 import litellm
 from litellm._logging import verbose_proxy_logger
@@ -36,10 +36,10 @@ async def _PROXY_track_cost_callback(
         litellm_params = kwargs.get("litellm_params", {}) or {}
         end_user_id = get_end_user_id_for_cost_tracking(litellm_params)
         metadata = get_litellm_metadata_from_kwargs(kwargs=kwargs)
-        user_id = metadata.get("user_api_key_user_id", None)
-        team_id = metadata.get("user_api_key_team_id", None)
-        org_id = metadata.get("user_api_key_org_id", None)
-        key_alias = metadata.get("user_api_key_alias", None)
+        user_id = cast(Optional[str], metadata.get("user_api_key_user_id", None))
+        team_id = cast(Optional[str], metadata.get("user_api_key_team_id", None))
+        org_id = cast(Optional[str], metadata.get("user_api_key_org_id", None))
+        key_alias = cast(Optional[str], metadata.get("user_api_key_alias", None))
         end_user_max_budget = metadata.get("user_api_end_user_max_budget", None)
         sl_object: Optional[StandardLoggingPayload] = kwargs.get(
             "standard_logging_object", None
@@ -61,7 +61,12 @@ async def _PROXY_track_cost_callback(
             verbose_proxy_logger.debug(
                 f"user_api_key {user_api_key}, prisma_client: {prisma_client}"
             )
-            if user_api_key is not None or user_id is not None or team_id is not None:
+            if _should_track_cost_callback(
+                user_api_key=user_api_key,
+                user_id=user_id,
+                team_id=team_id,
+                end_user_id=end_user_id,
+            ):
                 ## UPDATE DATABASE
                 await update_database(
                     token=user_api_key,
@@ -128,3 +133,22 @@ async def _PROXY_track_cost_callback(
             )
         )
         verbose_proxy_logger.exception("Error in tracking cost callback - %s", str(e))
+
+
+def _should_track_cost_callback(
+    user_api_key: Optional[str],
+    user_id: Optional[str],
+    team_id: Optional[str],
+    end_user_id: Optional[str],
+) -> bool:
+    """
+    Determine if the cost callback should be tracked based on the kwargs
+    """
+    if (
+        user_api_key is not None
+        or user_id is not None
+        or team_id is not None
+        or end_user_id is not None
+    ):
+        return True
+    return False
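
The old guard (if user_api_key is not None or user_id is not None or team_id is not None) silently skipped the database write for requests identified only by an end user, e.g. a JWT "sub" claim. With end_user_id added to the check, such requests are logged. A usage sketch, assuming the litellm package is installed so the import path from this diff resolves:

    from litellm.proxy.hooks.proxy_track_cost_callback import _should_track_cost_callback

    # end-user-only requests (the jwt-auth case) now pass the guard:
    assert _should_track_cost_callback(
        user_api_key=None, user_id=None, team_id=None, end_user_id="jwt-sub-123"
    )
    # with no identifiers at all, nothing is written:
    assert not _should_track_cost_callback(
        user_api_key=None, user_id=None, team_id=None, end_user_id=None
    )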


@@ -72,8 +72,12 @@ def safe_add_api_version_from_query_params(data: dict, request: Request):
         query_params = dict(request.query_params)
         if "api-version" in query_params:
             data["api_version"] = query_params["api-version"]
+    except KeyError:
+        pass
     except Exception as e:
-        verbose_logger.error("error checking api version in query params: %s", str(e))
+        verbose_logger.exception(
+            "error checking api version in query params: %s", str(e)
+        )
 
 
 def convert_key_logging_metadata_to_callback(
@@ -266,6 +270,7 @@ class LiteLLMProxyRequestSetup:
             user_api_key_user_id=user_api_key_dict.user_id,
             user_api_key_org_id=user_api_key_dict.org_id,
             user_api_key_team_alias=user_api_key_dict.team_alias,
+            user_api_key_end_user_id=user_api_key_dict.end_user_id,
         )
         return user_api_key_logged_metadata


@@ -1420,6 +1420,7 @@ class StandardLoggingUserAPIKeyMetadata(TypedDict):
     user_api_key_team_id: Optional[str]
     user_api_key_user_id: Optional[str]
     user_api_key_team_alias: Optional[str]
+    user_api_key_end_user_id: Optional[str]
 
 
 class StandardLoggingMetadata(StandardLoggingUserAPIKeyMetadata):


@@ -6269,7 +6269,13 @@ def get_end_user_id_for_cost_tracking(
         service_type: "litellm_logging" or "prometheus" - used to allow prometheus only disable cost tracking.
     """
-    proxy_server_request = litellm_params.get("proxy_server_request") or {}
+    _metadata = cast(dict, litellm_params.get("metadata", {}) or {})
+    end_user_id = cast(
+        Optional[str],
+        litellm_params.get("user_api_key_end_user_id")
+        or _metadata.get("user_api_key_end_user_id"),
+    )
+
     if litellm.disable_end_user_cost_tracking:
         return None
     if (
@@ -6277,7 +6283,7 @@ def get_end_user_id_for_cost_tracking(
         and litellm.disable_end_user_cost_tracking_prometheus_only
     ):
         return None
-    return proxy_server_request.get("body", {}).get("user", None)
+    return end_user_id
 
 
 def is_prompt_caching_valid_prompt(


@@ -1113,8 +1113,8 @@ def test_models_by_provider():
     "litellm_params, disable_end_user_cost_tracking, expected_end_user_id",
     [
         ({}, False, None),
-        ({"proxy_server_request": {"body": {"user": "123"}}}, False, "123"),
-        ({"proxy_server_request": {"body": {"user": "123"}}}, True, None),
+        ({"user_api_key_end_user_id": "123"}, False, "123"),
+        ({"user_api_key_end_user_id": "123"}, True, None),
     ],
 )
 def test_get_end_user_id_for_cost_tracking(
@@ -1133,8 +1133,8 @@ def test_get_end_user_id_for_cost_tracking(
     "litellm_params, disable_end_user_cost_tracking_prometheus_only, expected_end_user_id",
     [
         ({}, False, None),
-        ({"proxy_server_request": {"body": {"user": "123"}}}, False, "123"),
-        ({"proxy_server_request": {"body": {"user": "123"}}}, True, None),
+        ({"user_api_key_end_user_id": "123"}, False, "123"),
+        ({"user_api_key_end_user_id": "123"}, True, None),
     ],
 )
 def test_get_end_user_id_for_cost_tracking_prometheus_only(


@@ -270,6 +270,7 @@ def validate_redacted_message_span_attributes(span):
         "metadata.user_api_key_alias",
         "metadata.user_api_key_user_id",
         "metadata.user_api_key_org_id",
+        "metadata.user_api_key_end_user_id",
     ]
 
     _all_attributes = set(


@@ -22,7 +22,8 @@ from unittest.mock import AsyncMock, MagicMock, patch
 
 import pytest
 from fastapi import Request
+from fastapi.routing import APIRoute
+from fastapi.responses import Response
 
 import litellm
 from litellm.caching.caching import DualCache
 from litellm.proxy._types import LiteLLM_JWTAuth, LiteLLM_UserTable, LiteLLMRoutes
@@ -1035,6 +1036,7 @@ async def test_end_user_jwt_auth(monkeypatch):
     from litellm.caching import DualCache
     from litellm.proxy._types import LiteLLM_JWTAuth
     from litellm.proxy.proxy_server import user_api_key_auth
+    import json
 
     monkeypatch.delenv("JWT_AUDIENCE", None)
     jwt_handler = JWTHandler()
@@ -1094,21 +1096,70 @@ async def test_end_user_jwt_auth(monkeypatch):
     bearer_token = "Bearer " + token
 
-    request = Request(scope={"type": "http"})
-    request._url = URL(url="/chat/completions")
+    api_route = APIRoute(path="/chat/completions", endpoint=chat_completion)
+    request = Request(
+        {
+            "type": "http",
+            "route": api_route,
+            "path": "/chat/completions",
+            "headers": [(b"authorization", f"Bearer {bearer_token}".encode("latin-1"))],
+            "method": "POST",
+        }
+    )
+
+    async def return_body():
+        body_dict = {
+            "model": "openai/gpt-3.5-turbo",
+            "messages": [{"role": "user", "content": "Hello, how are you?"}],
+        }
+
+        # Serialize the dictionary to JSON and encode it to bytes
+        return json.dumps(body_dict).encode("utf-8")
+
+    request.body = return_body
 
     ## 1. INITIAL TEAM CALL - should fail
     # use generated key to auth in
     setattr(
         litellm.proxy.proxy_server,
         "general_settings",
-        {
-            "enable_jwt_auth": True,
-        },
+        {"enable_jwt_auth": True, "pass_through_all_models": True},
+    )
+    setattr(
+        litellm.proxy.proxy_server,
+        "llm_router",
+        MagicMock(),
     )
     setattr(litellm.proxy.proxy_server, "prisma_client", {})
     setattr(litellm.proxy.proxy_server, "jwt_handler", jwt_handler)
+
+    from litellm.proxy.proxy_server import cost_tracking
+
+    cost_tracking()
+
     result = await user_api_key_auth(request=request, api_key=bearer_token)
     assert (
         result.end_user_id == "81b3e52a-67a6-4efb-9645-70527e101479"
     )  # jwt token decoded sub value
+
+    temp_response = Response()
+    from litellm.proxy.hooks.proxy_track_cost_callback import (
+        _should_track_cost_callback,
+    )
+
+    with patch.object(
+        litellm.proxy.hooks.proxy_track_cost_callback, "_should_track_cost_callback"
+    ) as mock_client:
+        resp = await chat_completion(
+            request=request,
+            fastapi_response=temp_response,
+            model="gpt-4o",
+            user_api_key_dict=result,
+        )
+
+        assert resp is not None
+
+        await asyncio.sleep(1)
+
+        mock_client.assert_called_once()
+
+        mock_client.call_args.kwargs[
+            "end_user_id"
+        ] == "81b3e52a-67a6-4efb-9645-70527e101479"


@@ -588,8 +588,6 @@ def test_call_with_end_user_over_budget(prisma_client):
         request._url = URL(url="/chat/completions")
         bearer_token = "Bearer sk-1234"
 
-        result = await user_api_key_auth(request=request, api_key=bearer_token)
-
         async def return_body():
             return_string = f'{{"model": "gemini-pro-vision", "user": "{user}"}}'
             # return string as bytes
@@ -597,6 +595,8 @@ def test_call_with_end_user_over_budget(prisma_client):
 
         request.body = return_body
 
+        result = await user_api_key_auth(request=request, api_key=bearer_token)
+
         # update spend using track_cost callback, make 2nd request, it should fail
         from litellm import Choices, Message, ModelResponse, Usage
         from litellm.proxy.proxy_server import (
@@ -624,7 +624,7 @@ def test_call_with_end_user_over_budget(prisma_client):
             "litellm_params": {
                 "metadata": {
                     "user_api_key": "sk-1234",
-                    "user_api_key_user_id": user,
+                    "user_api_key_end_user_id": user,
                 },
                 "proxy_server_request": {
                     "body": {
@@ -653,6 +653,7 @@ def test_call_with_end_user_over_budget(prisma_client):
         asyncio.run(test())
     except Exception as e:
+        print(f"raised error: {e}, traceback: {traceback.format_exc()}")
        error_detail = e.message
        assert "Budget has been exceeded! Current" in error_detail
        assert isinstance(e, ProxyException)
@@ -3634,3 +3635,19 @@ async def test_enforce_unique_key_alias(prisma_client):
     except Exception as e:
         print("Unexpected error:", e)
         pytest.fail(f"An unexpected error occurred: {str(e)}")
+
+
+def test_should_track_cost_callback():
+    """
+    Test that the should_track_cost_callback function works as expected
+    """
+    from litellm.proxy.hooks.proxy_track_cost_callback import (
+        _should_track_cost_callback,
+    )
+
+    assert _should_track_cost_callback(
+        user_api_key=None,
+        user_id=None,
+        team_id=None,
+        end_user_id="1234",
+    )