diff --git a/litellm/integrations/langfuse/langfuse.py b/litellm/integrations/langfuse/langfuse.py index 888fcde887..2a67797785 100644 --- a/litellm/integrations/langfuse/langfuse.py +++ b/litellm/integrations/langfuse/langfuse.py @@ -458,6 +458,18 @@ class LangFuseLogger: tags = metadata.pop("tags", []) if supports_tags else [] + standard_logging_object: Optional[StandardLoggingPayload] = cast( + Optional[StandardLoggingPayload], + kwargs.get("standard_logging_object", None), + ) + + if standard_logging_object is None: + end_user_id = None + else: + end_user_id = standard_logging_object["metadata"].get( + "user_api_key_end_user_id", None + ) + # Clean Metadata before logging - never log raw metadata # the raw metadata can contain circular references which leads to infinite recursion # we clean out all extra litellm metadata params before logging @@ -541,7 +553,7 @@ class LangFuseLogger: "version": clean_metadata.pop( "trace_version", clean_metadata.get("version", None) ), # If provided just version, it will applied to the trace as well, if applied a trace version it will take precedence - "user_id": user_id, + "user_id": end_user_id, } for key in list( filter(lambda key: key.startswith("trace_"), clean_metadata.keys()) @@ -567,10 +579,6 @@ class LangFuseLogger: cost = kwargs.get("response_cost", None) print_verbose(f"trace: {cost}") - standard_logging_object: Optional[StandardLoggingPayload] = kwargs.get( - "standard_logging_object", None - ) - clean_metadata["litellm_response_cost"] = cost if standard_logging_object is not None: clean_metadata["hidden_params"] = standard_logging_object[ diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index a2fe21a680..1e3d374e15 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ b/litellm/litellm_core_utils/litellm_logging.py @@ -2619,6 +2619,7 @@ class StandardLoggingPayloadSetup: spend_logs_metadata=None, requester_ip_address=None, requester_metadata=None, + user_api_key_end_user_id=None, ) if isinstance(metadata, dict): # Filter the metadata dictionary to include only the specified keys @@ -3075,6 +3076,7 @@ def get_standard_logging_metadata( spend_logs_metadata=None, requester_ip_address=None, requester_metadata=None, + user_api_key_end_user_id=None, ) if isinstance(metadata, dict): # Filter the metadata dictionary to include only the specified keys diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index 4d9879c37a..a46e8d1eaf 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -20,9 +20,4 @@ model_list: api_version: "2024-05-01-preview" litellm_settings: - default_team_settings: - - team_id: c91e32bb-0f2a-4aa1-86c4-307ca2e03ea3 - success_callback: ["langfuse"] - failure_callback: ["langfuse"] - langfuse_public_key: my-fake-key - langfuse_secret: my-fake-secret + success_callback: ["langfuse"] \ No newline at end of file diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py index 8f82cdcf81..9bec7884e9 100644 --- a/litellm/proxy/auth/user_api_key_auth.py +++ b/litellm/proxy/auth/user_api_key_auth.py @@ -608,8 +608,12 @@ async def user_api_key_auth( # noqa: PLR0915 end_user_params = {} if "user" in request_data: try: + end_user_id = request_data["user"] + end_user_params["end_user_id"] = end_user_id + + # get end-user object _end_user_object = await get_end_user_object( - end_user_id=request_data["user"], + end_user_id=end_user_id, prisma_client=prisma_client, user_api_key_cache=user_api_key_cache, parent_otel_span=parent_otel_span, @@ -621,7 +625,6 @@ async def user_api_key_auth( # noqa: PLR0915 ) if _end_user_object.litellm_budget_table is not None: budget_info = _end_user_object.litellm_budget_table - end_user_params["end_user_id"] = _end_user_object.user_id if budget_info.tpm_limit is not None: end_user_params["end_user_tpm_limit"] = ( budget_info.tpm_limit diff --git a/litellm/proxy/common_utils/http_parsing_utils.py b/litellm/proxy/common_utils/http_parsing_utils.py index 16220a418b..1fb0da275e 100644 --- a/litellm/proxy/common_utils/http_parsing_utils.py +++ b/litellm/proxy/common_utils/http_parsing_utils.py @@ -60,7 +60,7 @@ def _safe_get_request_headers(request: Optional[Request]) -> dict: return {} return dict(request.headers) except Exception as e: - verbose_proxy_logger.exception( + verbose_proxy_logger.debug( "Unexpected error reading request headers - {}".format(e) ) return {} diff --git a/litellm/proxy/hooks/proxy_track_cost_callback.py b/litellm/proxy/hooks/proxy_track_cost_callback.py index 20a276ac62..33e9341706 100644 --- a/litellm/proxy/hooks/proxy_track_cost_callback.py +++ b/litellm/proxy/hooks/proxy_track_cost_callback.py @@ -1,6 +1,6 @@ import asyncio import traceback -from typing import Optional, Union +from typing import Optional, Union, cast import litellm from litellm._logging import verbose_proxy_logger @@ -36,10 +36,10 @@ async def _PROXY_track_cost_callback( litellm_params = kwargs.get("litellm_params", {}) or {} end_user_id = get_end_user_id_for_cost_tracking(litellm_params) metadata = get_litellm_metadata_from_kwargs(kwargs=kwargs) - user_id = metadata.get("user_api_key_user_id", None) - team_id = metadata.get("user_api_key_team_id", None) - org_id = metadata.get("user_api_key_org_id", None) - key_alias = metadata.get("user_api_key_alias", None) + user_id = cast(Optional[str], metadata.get("user_api_key_user_id", None)) + team_id = cast(Optional[str], metadata.get("user_api_key_team_id", None)) + org_id = cast(Optional[str], metadata.get("user_api_key_org_id", None)) + key_alias = cast(Optional[str], metadata.get("user_api_key_alias", None)) end_user_max_budget = metadata.get("user_api_end_user_max_budget", None) sl_object: Optional[StandardLoggingPayload] = kwargs.get( "standard_logging_object", None @@ -61,7 +61,12 @@ async def _PROXY_track_cost_callback( verbose_proxy_logger.debug( f"user_api_key {user_api_key}, prisma_client: {prisma_client}" ) - if user_api_key is not None or user_id is not None or team_id is not None: + if _should_track_cost_callback( + user_api_key=user_api_key, + user_id=user_id, + team_id=team_id, + end_user_id=end_user_id, + ): ## UPDATE DATABASE await update_database( token=user_api_key, @@ -128,3 +133,22 @@ async def _PROXY_track_cost_callback( ) ) verbose_proxy_logger.exception("Error in tracking cost callback - %s", str(e)) + + +def _should_track_cost_callback( + user_api_key: Optional[str], + user_id: Optional[str], + team_id: Optional[str], + end_user_id: Optional[str], +) -> bool: + """ + Determine if the cost callback should be tracked based on the kwargs + """ + if ( + user_api_key is not None + or user_id is not None + or team_id is not None + or end_user_id is not None + ): + return True + return False diff --git a/litellm/proxy/litellm_pre_call_utils.py b/litellm/proxy/litellm_pre_call_utils.py index b361eeeeab..aadeff0612 100644 --- a/litellm/proxy/litellm_pre_call_utils.py +++ b/litellm/proxy/litellm_pre_call_utils.py @@ -72,8 +72,12 @@ def safe_add_api_version_from_query_params(data: dict, request: Request): query_params = dict(request.query_params) if "api-version" in query_params: data["api_version"] = query_params["api-version"] + except KeyError: + pass except Exception as e: - verbose_logger.error("error checking api version in query params: %s", str(e)) + verbose_logger.exception( + "error checking api version in query params: %s", str(e) + ) def convert_key_logging_metadata_to_callback( @@ -266,6 +270,7 @@ class LiteLLMProxyRequestSetup: user_api_key_user_id=user_api_key_dict.user_id, user_api_key_org_id=user_api_key_dict.org_id, user_api_key_team_alias=user_api_key_dict.team_alias, + user_api_key_end_user_id=user_api_key_dict.end_user_id, ) return user_api_key_logged_metadata diff --git a/litellm/types/utils.py b/litellm/types/utils.py index 254395ea51..1d66c86d62 100644 --- a/litellm/types/utils.py +++ b/litellm/types/utils.py @@ -1420,6 +1420,7 @@ class StandardLoggingUserAPIKeyMetadata(TypedDict): user_api_key_team_id: Optional[str] user_api_key_user_id: Optional[str] user_api_key_team_alias: Optional[str] + user_api_key_end_user_id: Optional[str] class StandardLoggingMetadata(StandardLoggingUserAPIKeyMetadata): diff --git a/litellm/utils.py b/litellm/utils.py index 65e846cb3b..239a5e8bb9 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -6269,7 +6269,13 @@ def get_end_user_id_for_cost_tracking( service_type: "litellm_logging" or "prometheus" - used to allow prometheus only disable cost tracking. """ - proxy_server_request = litellm_params.get("proxy_server_request") or {} + _metadata = cast(dict, litellm_params.get("metadata", {}) or {}) + + end_user_id = cast( + Optional[str], + litellm_params.get("user_api_key_end_user_id") + or _metadata.get("user_api_key_end_user_id"), + ) if litellm.disable_end_user_cost_tracking: return None if ( @@ -6277,7 +6283,7 @@ def get_end_user_id_for_cost_tracking( and litellm.disable_end_user_cost_tracking_prometheus_only ): return None - return proxy_server_request.get("body", {}).get("user", None) + return end_user_id def is_prompt_caching_valid_prompt( diff --git a/tests/local_testing/test_utils.py b/tests/local_testing/test_utils.py index a14bc57061..7240d0dba4 100644 --- a/tests/local_testing/test_utils.py +++ b/tests/local_testing/test_utils.py @@ -1113,8 +1113,8 @@ def test_models_by_provider(): "litellm_params, disable_end_user_cost_tracking, expected_end_user_id", [ ({}, False, None), - ({"proxy_server_request": {"body": {"user": "123"}}}, False, "123"), - ({"proxy_server_request": {"body": {"user": "123"}}}, True, None), + ({"user_api_key_end_user_id": "123"}, False, "123"), + ({"user_api_key_end_user_id": "123"}, True, None), ], ) def test_get_end_user_id_for_cost_tracking( @@ -1133,8 +1133,8 @@ def test_get_end_user_id_for_cost_tracking( "litellm_params, disable_end_user_cost_tracking_prometheus_only, expected_end_user_id", [ ({}, False, None), - ({"proxy_server_request": {"body": {"user": "123"}}}, False, "123"), - ({"proxy_server_request": {"body": {"user": "123"}}}, True, None), + ({"user_api_key_end_user_id": "123"}, False, "123"), + ({"user_api_key_end_user_id": "123"}, True, None), ], ) def test_get_end_user_id_for_cost_tracking_prometheus_only( diff --git a/tests/logging_callback_tests/test_otel_logging.py b/tests/logging_callback_tests/test_otel_logging.py index b35f0fe9cb..a9a11cfba2 100644 --- a/tests/logging_callback_tests/test_otel_logging.py +++ b/tests/logging_callback_tests/test_otel_logging.py @@ -270,6 +270,7 @@ def validate_redacted_message_span_attributes(span): "metadata.user_api_key_alias", "metadata.user_api_key_user_id", "metadata.user_api_key_org_id", + "metadata.user_api_key_end_user_id", ] _all_attributes = set( diff --git a/tests/proxy_unit_tests/test_jwt.py b/tests/proxy_unit_tests/test_jwt.py index c68878dcbe..1c6ca72a1f 100644 --- a/tests/proxy_unit_tests/test_jwt.py +++ b/tests/proxy_unit_tests/test_jwt.py @@ -22,7 +22,8 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest from fastapi import Request - +from fastapi.routing import APIRoute +from fastapi.responses import Response import litellm from litellm.caching.caching import DualCache from litellm.proxy._types import LiteLLM_JWTAuth, LiteLLM_UserTable, LiteLLMRoutes @@ -1035,6 +1036,7 @@ async def test_end_user_jwt_auth(monkeypatch): from litellm.caching import DualCache from litellm.proxy._types import LiteLLM_JWTAuth from litellm.proxy.proxy_server import user_api_key_auth + import json monkeypatch.delenv("JWT_AUDIENCE", None) jwt_handler = JWTHandler() @@ -1094,21 +1096,70 @@ async def test_end_user_jwt_auth(monkeypatch): bearer_token = "Bearer " + token - request = Request(scope={"type": "http"}) - request._url = URL(url="/chat/completions") + api_route = APIRoute(path="/chat/completions", endpoint=chat_completion) + request = Request( + { + "type": "http", + "route": api_route, + "path": "/chat/completions", + "headers": [(b"authorization", f"Bearer {bearer_token}".encode("latin-1"))], + "method": "POST", + } + ) + + async def return_body(): + body_dict = { + "model": "openai/gpt-3.5-turbo", + "messages": [{"role": "user", "content": "Hello, how are you?"}], + } + # Serialize the dictionary to JSON and encode it to bytes + return json.dumps(body_dict).encode("utf-8") + + request.body = return_body ## 1. INITIAL TEAM CALL - should fail # use generated key to auth in setattr( litellm.proxy.proxy_server, "general_settings", - { - "enable_jwt_auth": True, - }, + {"enable_jwt_auth": True, "pass_through_all_models": True}, + ) + setattr( + litellm.proxy.proxy_server, + "llm_router", + MagicMock(), ) setattr(litellm.proxy.proxy_server, "prisma_client", {}) setattr(litellm.proxy.proxy_server, "jwt_handler", jwt_handler) + from litellm.proxy.proxy_server import cost_tracking + + cost_tracking() result = await user_api_key_auth(request=request, api_key=bearer_token) assert ( result.end_user_id == "81b3e52a-67a6-4efb-9645-70527e101479" ) # jwt token decoded sub value + + temp_response = Response() + from litellm.proxy.hooks.proxy_track_cost_callback import ( + _should_track_cost_callback, + ) + + with patch.object( + litellm.proxy.hooks.proxy_track_cost_callback, "_should_track_cost_callback" + ) as mock_client: + resp = await chat_completion( + request=request, + fastapi_response=temp_response, + model="gpt-4o", + user_api_key_dict=result, + ) + + assert resp is not None + + await asyncio.sleep(1) + + mock_client.assert_called_once() + + mock_client.call_args.kwargs[ + "end_user_id" + ] == "81b3e52a-67a6-4efb-9645-70527e101479" diff --git a/tests/proxy_unit_tests/test_key_generate_prisma.py b/tests/proxy_unit_tests/test_key_generate_prisma.py index c745f9dd96..d2b0c765d5 100644 --- a/tests/proxy_unit_tests/test_key_generate_prisma.py +++ b/tests/proxy_unit_tests/test_key_generate_prisma.py @@ -588,8 +588,6 @@ def test_call_with_end_user_over_budget(prisma_client): request._url = URL(url="/chat/completions") bearer_token = "Bearer sk-1234" - result = await user_api_key_auth(request=request, api_key=bearer_token) - async def return_body(): return_string = f'{{"model": "gemini-pro-vision", "user": "{user}"}}' # return string as bytes @@ -597,6 +595,8 @@ def test_call_with_end_user_over_budget(prisma_client): request.body = return_body + result = await user_api_key_auth(request=request, api_key=bearer_token) + # update spend using track_cost callback, make 2nd request, it should fail from litellm import Choices, Message, ModelResponse, Usage from litellm.proxy.proxy_server import ( @@ -624,7 +624,7 @@ def test_call_with_end_user_over_budget(prisma_client): "litellm_params": { "metadata": { "user_api_key": "sk-1234", - "user_api_key_user_id": user, + "user_api_key_end_user_id": user, }, "proxy_server_request": { "body": { @@ -653,6 +653,7 @@ def test_call_with_end_user_over_budget(prisma_client): asyncio.run(test()) except Exception as e: + print(f"raised error: {e}, traceback: {traceback.format_exc()}") error_detail = e.message assert "Budget has been exceeded! Current" in error_detail assert isinstance(e, ProxyException) @@ -3634,3 +3635,19 @@ async def test_enforce_unique_key_alias(prisma_client): except Exception as e: print("Unexpected error:", e) pytest.fail(f"An unexpected error occurred: {str(e)}") + + +def test_should_track_cost_callback(): + """ + Test that the should_track_cost_callback function works as expected + """ + from litellm.proxy.hooks.proxy_track_cost_callback import ( + _should_track_cost_callback, + ) + + assert _should_track_cost_callback( + user_api_key=None, + user_id=None, + team_id=None, + end_user_id="1234", + )