mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 19:24:27 +00:00
Litellm dev 12 20 2024 p3 (#7339)
* fix(proxy_track_cost_callback.py): log to db if only end user param given * fix: allows for jwt-auth based end user id spend tracking to work * fix(utils.py): fix 'get_end_user_id_for_cost_tracking' to use 'user_api_key_end_user_id' more stable - works with jwt-auth based end user tracking as well * test(test_jwt.py): add e2e unit test to confirm end user cost tracking works for spend logs * test: update test to use end_user api key hash param * fix(langfuse.py): support end user cost tracking via jwt auth + langfuse logs end user to langfuse if decoded from jwt token * fix: fix linting errors * test: fix test * test: fix test * fix: fix end user id extraction * fix: run test earlier
This commit is contained in:
parent
1b2ed0c344
commit
61b4c41c3c
13 changed files with 149 additions and 36 deletions
|
@ -458,6 +458,18 @@ class LangFuseLogger:
|
|||
|
||||
tags = metadata.pop("tags", []) if supports_tags else []
|
||||
|
||||
standard_logging_object: Optional[StandardLoggingPayload] = cast(
|
||||
Optional[StandardLoggingPayload],
|
||||
kwargs.get("standard_logging_object", None),
|
||||
)
|
||||
|
||||
if standard_logging_object is None:
|
||||
end_user_id = None
|
||||
else:
|
||||
end_user_id = standard_logging_object["metadata"].get(
|
||||
"user_api_key_end_user_id", None
|
||||
)
|
||||
|
||||
# Clean Metadata before logging - never log raw metadata
|
||||
# the raw metadata can contain circular references which leads to infinite recursion
|
||||
# we clean out all extra litellm metadata params before logging
|
||||
|
@ -541,7 +553,7 @@ class LangFuseLogger:
|
|||
"version": clean_metadata.pop(
|
||||
"trace_version", clean_metadata.get("version", None)
|
||||
), # If provided just version, it will applied to the trace as well, if applied a trace version it will take precedence
|
||||
"user_id": user_id,
|
||||
"user_id": end_user_id,
|
||||
}
|
||||
for key in list(
|
||||
filter(lambda key: key.startswith("trace_"), clean_metadata.keys())
|
||||
|
@ -567,10 +579,6 @@ class LangFuseLogger:
|
|||
cost = kwargs.get("response_cost", None)
|
||||
print_verbose(f"trace: {cost}")
|
||||
|
||||
standard_logging_object: Optional[StandardLoggingPayload] = kwargs.get(
|
||||
"standard_logging_object", None
|
||||
)
|
||||
|
||||
clean_metadata["litellm_response_cost"] = cost
|
||||
if standard_logging_object is not None:
|
||||
clean_metadata["hidden_params"] = standard_logging_object[
|
||||
|
|
|
@ -2619,6 +2619,7 @@ class StandardLoggingPayloadSetup:
|
|||
spend_logs_metadata=None,
|
||||
requester_ip_address=None,
|
||||
requester_metadata=None,
|
||||
user_api_key_end_user_id=None,
|
||||
)
|
||||
if isinstance(metadata, dict):
|
||||
# Filter the metadata dictionary to include only the specified keys
|
||||
|
@ -3075,6 +3076,7 @@ def get_standard_logging_metadata(
|
|||
spend_logs_metadata=None,
|
||||
requester_ip_address=None,
|
||||
requester_metadata=None,
|
||||
user_api_key_end_user_id=None,
|
||||
)
|
||||
if isinstance(metadata, dict):
|
||||
# Filter the metadata dictionary to include only the specified keys
|
||||
|
|
|
@ -20,9 +20,4 @@ model_list:
|
|||
api_version: "2024-05-01-preview"
|
||||
|
||||
litellm_settings:
|
||||
default_team_settings:
|
||||
- team_id: c91e32bb-0f2a-4aa1-86c4-307ca2e03ea3
|
||||
success_callback: ["langfuse"]
|
||||
failure_callback: ["langfuse"]
|
||||
langfuse_public_key: my-fake-key
|
||||
langfuse_secret: my-fake-secret
|
||||
|
|
|
@ -608,8 +608,12 @@ async def user_api_key_auth( # noqa: PLR0915
|
|||
end_user_params = {}
|
||||
if "user" in request_data:
|
||||
try:
|
||||
end_user_id = request_data["user"]
|
||||
end_user_params["end_user_id"] = end_user_id
|
||||
|
||||
# get end-user object
|
||||
_end_user_object = await get_end_user_object(
|
||||
end_user_id=request_data["user"],
|
||||
end_user_id=end_user_id,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
parent_otel_span=parent_otel_span,
|
||||
|
@ -621,7 +625,6 @@ async def user_api_key_auth( # noqa: PLR0915
|
|||
)
|
||||
if _end_user_object.litellm_budget_table is not None:
|
||||
budget_info = _end_user_object.litellm_budget_table
|
||||
end_user_params["end_user_id"] = _end_user_object.user_id
|
||||
if budget_info.tpm_limit is not None:
|
||||
end_user_params["end_user_tpm_limit"] = (
|
||||
budget_info.tpm_limit
|
||||
|
|
|
@ -60,7 +60,7 @@ def _safe_get_request_headers(request: Optional[Request]) -> dict:
|
|||
return {}
|
||||
return dict(request.headers)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(
|
||||
verbose_proxy_logger.debug(
|
||||
"Unexpected error reading request headers - {}".format(e)
|
||||
)
|
||||
return {}
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import asyncio
|
||||
import traceback
|
||||
from typing import Optional, Union
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
|
@ -36,10 +36,10 @@ async def _PROXY_track_cost_callback(
|
|||
litellm_params = kwargs.get("litellm_params", {}) or {}
|
||||
end_user_id = get_end_user_id_for_cost_tracking(litellm_params)
|
||||
metadata = get_litellm_metadata_from_kwargs(kwargs=kwargs)
|
||||
user_id = metadata.get("user_api_key_user_id", None)
|
||||
team_id = metadata.get("user_api_key_team_id", None)
|
||||
org_id = metadata.get("user_api_key_org_id", None)
|
||||
key_alias = metadata.get("user_api_key_alias", None)
|
||||
user_id = cast(Optional[str], metadata.get("user_api_key_user_id", None))
|
||||
team_id = cast(Optional[str], metadata.get("user_api_key_team_id", None))
|
||||
org_id = cast(Optional[str], metadata.get("user_api_key_org_id", None))
|
||||
key_alias = cast(Optional[str], metadata.get("user_api_key_alias", None))
|
||||
end_user_max_budget = metadata.get("user_api_end_user_max_budget", None)
|
||||
sl_object: Optional[StandardLoggingPayload] = kwargs.get(
|
||||
"standard_logging_object", None
|
||||
|
@ -61,7 +61,12 @@ async def _PROXY_track_cost_callback(
|
|||
verbose_proxy_logger.debug(
|
||||
f"user_api_key {user_api_key}, prisma_client: {prisma_client}"
|
||||
)
|
||||
if user_api_key is not None or user_id is not None or team_id is not None:
|
||||
if _should_track_cost_callback(
|
||||
user_api_key=user_api_key,
|
||||
user_id=user_id,
|
||||
team_id=team_id,
|
||||
end_user_id=end_user_id,
|
||||
):
|
||||
## UPDATE DATABASE
|
||||
await update_database(
|
||||
token=user_api_key,
|
||||
|
@ -128,3 +133,22 @@ async def _PROXY_track_cost_callback(
|
|||
)
|
||||
)
|
||||
verbose_proxy_logger.exception("Error in tracking cost callback - %s", str(e))
|
||||
|
||||
|
||||
def _should_track_cost_callback(
|
||||
user_api_key: Optional[str],
|
||||
user_id: Optional[str],
|
||||
team_id: Optional[str],
|
||||
end_user_id: Optional[str],
|
||||
) -> bool:
|
||||
"""
|
||||
Determine if the cost callback should be tracked based on the kwargs
|
||||
"""
|
||||
if (
|
||||
user_api_key is not None
|
||||
or user_id is not None
|
||||
or team_id is not None
|
||||
or end_user_id is not None
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
|
|
@ -72,8 +72,12 @@ def safe_add_api_version_from_query_params(data: dict, request: Request):
|
|||
query_params = dict(request.query_params)
|
||||
if "api-version" in query_params:
|
||||
data["api_version"] = query_params["api-version"]
|
||||
except KeyError:
|
||||
pass
|
||||
except Exception as e:
|
||||
verbose_logger.error("error checking api version in query params: %s", str(e))
|
||||
verbose_logger.exception(
|
||||
"error checking api version in query params: %s", str(e)
|
||||
)
|
||||
|
||||
|
||||
def convert_key_logging_metadata_to_callback(
|
||||
|
@ -266,6 +270,7 @@ class LiteLLMProxyRequestSetup:
|
|||
user_api_key_user_id=user_api_key_dict.user_id,
|
||||
user_api_key_org_id=user_api_key_dict.org_id,
|
||||
user_api_key_team_alias=user_api_key_dict.team_alias,
|
||||
user_api_key_end_user_id=user_api_key_dict.end_user_id,
|
||||
)
|
||||
return user_api_key_logged_metadata
|
||||
|
||||
|
|
|
@ -1420,6 +1420,7 @@ class StandardLoggingUserAPIKeyMetadata(TypedDict):
|
|||
user_api_key_team_id: Optional[str]
|
||||
user_api_key_user_id: Optional[str]
|
||||
user_api_key_team_alias: Optional[str]
|
||||
user_api_key_end_user_id: Optional[str]
|
||||
|
||||
|
||||
class StandardLoggingMetadata(StandardLoggingUserAPIKeyMetadata):
|
||||
|
|
|
@ -6269,7 +6269,13 @@ def get_end_user_id_for_cost_tracking(
|
|||
|
||||
service_type: "litellm_logging" or "prometheus" - used to allow prometheus only disable cost tracking.
|
||||
"""
|
||||
proxy_server_request = litellm_params.get("proxy_server_request") or {}
|
||||
_metadata = cast(dict, litellm_params.get("metadata", {}) or {})
|
||||
|
||||
end_user_id = cast(
|
||||
Optional[str],
|
||||
litellm_params.get("user_api_key_end_user_id")
|
||||
or _metadata.get("user_api_key_end_user_id"),
|
||||
)
|
||||
if litellm.disable_end_user_cost_tracking:
|
||||
return None
|
||||
if (
|
||||
|
@ -6277,7 +6283,7 @@ def get_end_user_id_for_cost_tracking(
|
|||
and litellm.disable_end_user_cost_tracking_prometheus_only
|
||||
):
|
||||
return None
|
||||
return proxy_server_request.get("body", {}).get("user", None)
|
||||
return end_user_id
|
||||
|
||||
|
||||
def is_prompt_caching_valid_prompt(
|
||||
|
|
|
@ -1113,8 +1113,8 @@ def test_models_by_provider():
|
|||
"litellm_params, disable_end_user_cost_tracking, expected_end_user_id",
|
||||
[
|
||||
({}, False, None),
|
||||
({"proxy_server_request": {"body": {"user": "123"}}}, False, "123"),
|
||||
({"proxy_server_request": {"body": {"user": "123"}}}, True, None),
|
||||
({"user_api_key_end_user_id": "123"}, False, "123"),
|
||||
({"user_api_key_end_user_id": "123"}, True, None),
|
||||
],
|
||||
)
|
||||
def test_get_end_user_id_for_cost_tracking(
|
||||
|
@ -1133,8 +1133,8 @@ def test_get_end_user_id_for_cost_tracking(
|
|||
"litellm_params, disable_end_user_cost_tracking_prometheus_only, expected_end_user_id",
|
||||
[
|
||||
({}, False, None),
|
||||
({"proxy_server_request": {"body": {"user": "123"}}}, False, "123"),
|
||||
({"proxy_server_request": {"body": {"user": "123"}}}, True, None),
|
||||
({"user_api_key_end_user_id": "123"}, False, "123"),
|
||||
({"user_api_key_end_user_id": "123"}, True, None),
|
||||
],
|
||||
)
|
||||
def test_get_end_user_id_for_cost_tracking_prometheus_only(
|
||||
|
|
|
@ -270,6 +270,7 @@ def validate_redacted_message_span_attributes(span):
|
|||
"metadata.user_api_key_alias",
|
||||
"metadata.user_api_key_user_id",
|
||||
"metadata.user_api_key_org_id",
|
||||
"metadata.user_api_key_end_user_id",
|
||||
]
|
||||
|
||||
_all_attributes = set(
|
||||
|
|
|
@ -22,7 +22,8 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
|||
|
||||
import pytest
|
||||
from fastapi import Request
|
||||
|
||||
from fastapi.routing import APIRoute
|
||||
from fastapi.responses import Response
|
||||
import litellm
|
||||
from litellm.caching.caching import DualCache
|
||||
from litellm.proxy._types import LiteLLM_JWTAuth, LiteLLM_UserTable, LiteLLMRoutes
|
||||
|
@ -1035,6 +1036,7 @@ async def test_end_user_jwt_auth(monkeypatch):
|
|||
from litellm.caching import DualCache
|
||||
from litellm.proxy._types import LiteLLM_JWTAuth
|
||||
from litellm.proxy.proxy_server import user_api_key_auth
|
||||
import json
|
||||
|
||||
monkeypatch.delenv("JWT_AUDIENCE", None)
|
||||
jwt_handler = JWTHandler()
|
||||
|
@ -1094,21 +1096,70 @@ async def test_end_user_jwt_auth(monkeypatch):
|
|||
|
||||
bearer_token = "Bearer " + token
|
||||
|
||||
request = Request(scope={"type": "http"})
|
||||
request._url = URL(url="/chat/completions")
|
||||
api_route = APIRoute(path="/chat/completions", endpoint=chat_completion)
|
||||
request = Request(
|
||||
{
|
||||
"type": "http",
|
||||
"route": api_route,
|
||||
"path": "/chat/completions",
|
||||
"headers": [(b"authorization", f"Bearer {bearer_token}".encode("latin-1"))],
|
||||
"method": "POST",
|
||||
}
|
||||
)
|
||||
|
||||
async def return_body():
|
||||
body_dict = {
|
||||
"model": "openai/gpt-3.5-turbo",
|
||||
"messages": [{"role": "user", "content": "Hello, how are you?"}],
|
||||
}
|
||||
# Serialize the dictionary to JSON and encode it to bytes
|
||||
return json.dumps(body_dict).encode("utf-8")
|
||||
|
||||
request.body = return_body
|
||||
|
||||
## 1. INITIAL TEAM CALL - should fail
|
||||
# use generated key to auth in
|
||||
setattr(
|
||||
litellm.proxy.proxy_server,
|
||||
"general_settings",
|
||||
{
|
||||
"enable_jwt_auth": True,
|
||||
},
|
||||
{"enable_jwt_auth": True, "pass_through_all_models": True},
|
||||
)
|
||||
setattr(
|
||||
litellm.proxy.proxy_server,
|
||||
"llm_router",
|
||||
MagicMock(),
|
||||
)
|
||||
setattr(litellm.proxy.proxy_server, "prisma_client", {})
|
||||
setattr(litellm.proxy.proxy_server, "jwt_handler", jwt_handler)
|
||||
from litellm.proxy.proxy_server import cost_tracking
|
||||
|
||||
cost_tracking()
|
||||
result = await user_api_key_auth(request=request, api_key=bearer_token)
|
||||
assert (
|
||||
result.end_user_id == "81b3e52a-67a6-4efb-9645-70527e101479"
|
||||
) # jwt token decoded sub value
|
||||
|
||||
temp_response = Response()
|
||||
from litellm.proxy.hooks.proxy_track_cost_callback import (
|
||||
_should_track_cost_callback,
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
litellm.proxy.hooks.proxy_track_cost_callback, "_should_track_cost_callback"
|
||||
) as mock_client:
|
||||
resp = await chat_completion(
|
||||
request=request,
|
||||
fastapi_response=temp_response,
|
||||
model="gpt-4o",
|
||||
user_api_key_dict=result,
|
||||
)
|
||||
|
||||
assert resp is not None
|
||||
|
||||
await asyncio.sleep(1)
|
||||
|
||||
mock_client.assert_called_once()
|
||||
|
||||
mock_client.call_args.kwargs[
|
||||
"end_user_id"
|
||||
] == "81b3e52a-67a6-4efb-9645-70527e101479"
|
||||
|
|
|
@ -588,8 +588,6 @@ def test_call_with_end_user_over_budget(prisma_client):
|
|||
request._url = URL(url="/chat/completions")
|
||||
bearer_token = "Bearer sk-1234"
|
||||
|
||||
result = await user_api_key_auth(request=request, api_key=bearer_token)
|
||||
|
||||
async def return_body():
|
||||
return_string = f'{{"model": "gemini-pro-vision", "user": "{user}"}}'
|
||||
# return string as bytes
|
||||
|
@ -597,6 +595,8 @@ def test_call_with_end_user_over_budget(prisma_client):
|
|||
|
||||
request.body = return_body
|
||||
|
||||
result = await user_api_key_auth(request=request, api_key=bearer_token)
|
||||
|
||||
# update spend using track_cost callback, make 2nd request, it should fail
|
||||
from litellm import Choices, Message, ModelResponse, Usage
|
||||
from litellm.proxy.proxy_server import (
|
||||
|
@ -624,7 +624,7 @@ def test_call_with_end_user_over_budget(prisma_client):
|
|||
"litellm_params": {
|
||||
"metadata": {
|
||||
"user_api_key": "sk-1234",
|
||||
"user_api_key_user_id": user,
|
||||
"user_api_key_end_user_id": user,
|
||||
},
|
||||
"proxy_server_request": {
|
||||
"body": {
|
||||
|
@ -653,6 +653,7 @@ def test_call_with_end_user_over_budget(prisma_client):
|
|||
|
||||
asyncio.run(test())
|
||||
except Exception as e:
|
||||
print(f"raised error: {e}, traceback: {traceback.format_exc()}")
|
||||
error_detail = e.message
|
||||
assert "Budget has been exceeded! Current" in error_detail
|
||||
assert isinstance(e, ProxyException)
|
||||
|
@ -3634,3 +3635,19 @@ async def test_enforce_unique_key_alias(prisma_client):
|
|||
except Exception as e:
|
||||
print("Unexpected error:", e)
|
||||
pytest.fail(f"An unexpected error occurred: {str(e)}")
|
||||
|
||||
|
||||
def test_should_track_cost_callback():
|
||||
"""
|
||||
Test that the should_track_cost_callback function works as expected
|
||||
"""
|
||||
from litellm.proxy.hooks.proxy_track_cost_callback import (
|
||||
_should_track_cost_callback,
|
||||
)
|
||||
|
||||
assert _should_track_cost_callback(
|
||||
user_api_key=None,
|
||||
user_id=None,
|
||||
team_id=None,
|
||||
end_user_id="1234",
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue