diff --git a/litellm/proxy/hooks/proxy_track_cost_callback.py b/litellm/proxy/hooks/proxy_track_cost_callback.py
new file mode 100644
index 000000000..28518a68a
--- /dev/null
+++ b/litellm/proxy/hooks/proxy_track_cost_callback.py
@@ -0,0 +1,136 @@
+"""
+Proxy Success Callback - handles storing the cost of a request in the LiteLLM DB.
+
+Updates cost for the following in the LiteLLM DB:
+    - spend logs
+    - virtual key spend
+    - internal user, team, external user spend
+"""
+
+import asyncio
+import traceback
+from typing import Optional
+
+import litellm
+from litellm._logging import verbose_proxy_logger
+from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
+from litellm.proxy.utils import log_db_metrics
+from litellm.types.utils import StandardLoggingPayload
+
+
+@log_db_metrics
+async def _PROXY_track_cost_callback(
+    kwargs,  # kwargs to completion
+    completion_response: litellm.ModelResponse,  # response from completion
+    start_time=None,
+    end_time=None,  # start/end time for completion
+):
+    """
+    Callback that handles storing the cost of a request in the LiteLLM DB.
+
+    Updates cost for the following in the LiteLLM DB:
+        - spend logs
+        - virtual key spend
+        - internal user, team, external user spend
+    """
+    from litellm.proxy.proxy_server import (
+        prisma_client,
+        proxy_logging_obj,
+        update_cache,
+        update_database,
+    )
+
+    verbose_proxy_logger.debug("INSIDE _PROXY_track_cost_callback")
+    try:
+        # check if it has collected an entire stream response
+        verbose_proxy_logger.debug(
+            "Proxy: In track_cost_callback for: kwargs=%s and completion_response: %s",
+            kwargs,
+            completion_response,
+        )
+        parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs=kwargs)
+        standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
+            "standard_logging_object", None
+        )
+        if standard_logging_payload is None:
+            raise ValueError(
+                "standard_logging_payload is None in kwargs, cannot track cost without it"
+            )
+        end_user_id = standard_logging_payload.get("end_user")
+        user_api_key = standard_logging_payload.get("metadata", {}).get(
+            "user_api_key_hash"
+        )
+        user_id = standard_logging_payload.get("metadata", {}).get(
+            "user_api_key_user_id"
+        )
+        team_id = standard_logging_payload.get("metadata", {}).get(
+            "user_api_key_team_id"
+        )
+        org_id = standard_logging_payload.get("metadata", {}).get("user_api_key_org_id")
+        key_alias = standard_logging_payload.get("metadata", {}).get(
+            "user_api_key_alias"
+        )
+        end_user_max_budget = standard_logging_payload.get("metadata", {}).get(
+            "user_api_end_user_max_budget"
+        )
+        response_cost: Optional[float] = standard_logging_payload.get("response_cost")
+
+        if response_cost is not None:
+            if user_api_key is not None or user_id is not None or team_id is not None:
+                ## UPDATE DATABASE
+                await update_database(
+                    token=user_api_key,
+                    response_cost=response_cost,
+                    user_id=user_id,
+                    end_user_id=end_user_id,
+                    team_id=team_id,
+                    kwargs=kwargs,
+                    completion_response=completion_response,
+                    start_time=start_time,
+                    end_time=end_time,
+                    org_id=org_id,
+                )
+
+                # update cache
+                asyncio.create_task(
+                    update_cache(
+                        token=user_api_key,
+                        user_id=user_id,
+                        end_user_id=end_user_id,
+                        response_cost=response_cost,
+                        team_id=team_id,
+                        parent_otel_span=parent_otel_span,
+                    )
+                )
+
+                await proxy_logging_obj.slack_alerting_instance.customer_spend_alert(
+                    token=user_api_key,
+                    key_alias=key_alias,
+                    end_user_id=end_user_id,
+                    response_cost=response_cost,
+                    max_budget=end_user_max_budget,
+                )
+            else:
+                raise Exception(
+                    "User API key, team id, and user id missing from custom callback."
+                )
+        else:
+            cost_tracking_failure_debug_info = standard_logging_payload.get(
+                "response_cost_failure_debug_info"
+            )
+            model = kwargs.get("model")
+            raise ValueError(
+                f"Failed to write cost to DB for model={model}.\nDebug info - {cost_tracking_failure_debug_info}\nAdd custom pricing - https://docs.litellm.ai/docs/proxy/custom_pricing"
+            )
+    except Exception as e:
+        error_msg = f"Error in tracking cost callback - {str(e)}\n Traceback:{traceback.format_exc()}"
+        model = kwargs.get("model", "")
+        metadata = kwargs.get("litellm_params", {}).get("metadata", {})
+        error_msg += f"\n Args to _PROXY_track_cost_callback\n model: {model}\n metadata: {metadata}\n"
+        asyncio.create_task(
+            proxy_logging_obj.failed_tracking_alert(
+                error_message=error_msg,
+                failing_model=model,
+            )
+        )
+        verbose_proxy_logger.debug("error in tracking cost callback - %s", e)
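Review note, not part of the patch: a minimal sketch of driving the relocated hook in isolation. The payload values are illustrative assumptions, and the proxy's DB, cache, and alerting calls are patched out with AsyncMock so nothing real is written; the sketch also assumes the log_db_metrics decorator tolerates a plain kwargs dict.

    import asyncio
    from unittest.mock import AsyncMock, patch

    import litellm
    from litellm.proxy.hooks.proxy_track_cost_callback import _PROXY_track_cost_callback


    async def demo() -> None:
        # hypothetical identifiers, for illustration only
        kwargs = {
            "model": "gpt-3.5-turbo",
            "standard_logging_object": {
                "response_cost": 0.00002,
                "end_user": "end-user-1",
                "metadata": {
                    "user_api_key_hash": "hashed-virtual-key",
                    "user_api_key_user_id": "user-1",
                    "user_api_key_team_id": "team-1",
                },
            },
        }
        # the hook imports these from proxy_server at call time, so patching the
        # module attributes is enough to intercept them
        with patch(
            "litellm.proxy.proxy_server.update_database", new=AsyncMock()
        ) as mock_db, patch(
            "litellm.proxy.proxy_server.update_cache", new=AsyncMock()
        ), patch(
            "litellm.proxy.proxy_server.proxy_logging_obj.slack_alerting_instance.customer_spend_alert",
            new=AsyncMock(),
        ):
            await _PROXY_track_cost_callback(
                kwargs=kwargs,
                completion_response=litellm.ModelResponse(),
                start_time=None,
                end_time=None,
            )
            mock_db.assert_awaited_once()  # spend was routed to update_database


    asyncio.run(demo())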
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index c9c6af77f..374dae8ff 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -303,6 +303,8 @@ from fastapi.security import OAuth2PasswordBearer
 from fastapi.security.api_key import APIKeyHeader
 from fastapi.staticfiles import StaticFiles
 
+from litellm.proxy.hooks.proxy_track_cost_callback import _PROXY_track_cost_callback
+
 # import enterprise folder
 try:
     # when using litellm cli
@@ -747,118 +749,6 @@ async def _PROXY_failure_handler(
         pass
 
 
-@log_db_metrics
-async def _PROXY_track_cost_callback(
-    kwargs,  # kwargs to completion
-    completion_response: litellm.ModelResponse,  # response from completion
-    start_time=None,
-    end_time=None,  # start/end time for completion
-):
-    verbose_proxy_logger.debug("INSIDE _PROXY_track_cost_callback")
-    global prisma_client
-    try:
-        verbose_proxy_logger.debug(
-            f"kwargs stream: {kwargs.get('stream', None)} + complete streaming response: {kwargs.get('complete_streaming_response', None)}"
-        )
-        parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs=kwargs)
-        litellm_params = kwargs.get("litellm_params", {}) or {}
-        proxy_server_request = litellm_params.get("proxy_server_request") or {}
-        end_user_id = proxy_server_request.get("body", {}).get("user", None)
-        metadata = get_litellm_metadata_from_kwargs(kwargs=kwargs)
-        user_id = metadata.get("user_api_key_user_id", None)
-        team_id = metadata.get("user_api_key_team_id", None)
-        org_id = metadata.get("user_api_key_org_id", None)
-        key_alias = metadata.get("user_api_key_alias", None)
-        end_user_max_budget = metadata.get("user_api_end_user_max_budget", None)
-        sl_object: Optional[StandardLoggingPayload] = kwargs.get(
-            "standard_logging_object", None
-        )
-        response_cost = (
-            sl_object.get("response_cost", None)
-            if sl_object is not None
-            else kwargs.get("response_cost", None)
-        )
-
-        if response_cost is not None:
-            user_api_key = metadata.get("user_api_key", None)
-            if kwargs.get("cache_hit", False) is True:
-                response_cost = 0.0
-                verbose_proxy_logger.info(
-                    f"Cache Hit: response_cost {response_cost}, for user_id {user_id}"
-                )
-
-            verbose_proxy_logger.debug(
-                f"user_api_key {user_api_key}, prisma_client: {prisma_client}"
-            )
-            if user_api_key is not None or user_id is not None or team_id is not None:
-                ## UPDATE DATABASE
-                await update_database(
-                    token=user_api_key,
-                    response_cost=response_cost,
-                    user_id=user_id,
-                    end_user_id=end_user_id,
-                    team_id=team_id,
-                    kwargs=kwargs,
-                    completion_response=completion_response,
-                    start_time=start_time,
-                    end_time=end_time,
-                    org_id=org_id,
-                )
-
-                # update cache
-                asyncio.create_task(
-                    update_cache(
-                        token=user_api_key,
-                        user_id=user_id,
-                        end_user_id=end_user_id,
-                        response_cost=response_cost,
-                        team_id=team_id,
-                        parent_otel_span=parent_otel_span,
-                    )
-                )
-
-                await proxy_logging_obj.slack_alerting_instance.customer_spend_alert(
-                    token=user_api_key,
-                    key_alias=key_alias,
-                    end_user_id=end_user_id,
-                    response_cost=response_cost,
-                    max_budget=end_user_max_budget,
-                )
-            else:
-                raise Exception(
-                    "User API key and team id and user id missing from custom callback."
-                )
-        else:
-            if kwargs["stream"] is not True or (
-                kwargs["stream"] is True and "complete_streaming_response" in kwargs
-            ):
-                if sl_object is not None:
-                    cost_tracking_failure_debug_info: Union[dict, str] = (
-                        sl_object["response_cost_failure_debug_info"]  # type: ignore
-                        or "response_cost_failure_debug_info is None in standard_logging_object"
-                    )
-                else:
-                    cost_tracking_failure_debug_info = (
-                        "standard_logging_object not found"
-                    )
-                model = kwargs.get("model")
-                raise Exception(
-                    f"Cost tracking failed for model={model}.\nDebug info - {cost_tracking_failure_debug_info}\nAdd custom pricing - https://docs.litellm.ai/docs/proxy/custom_pricing"
-                )
-    except Exception as e:
-        error_msg = f"Error in tracking cost callback - {str(e)}\n Traceback:{traceback.format_exc()}"
-        model = kwargs.get("model", "")
-        metadata = kwargs.get("litellm_params", {}).get("metadata", {})
-        error_msg += f"\n Args to _PROXY_track_cost_callback\n model: {model}\n metadata: {metadata}\n"
-        asyncio.create_task(
-            proxy_logging_obj.failed_tracking_alert(
-                error_message=error_msg,
-                failing_model=model,
-            )
-        )
-        verbose_proxy_logger.debug(error_msg)
-
-
 def error_tracking():
     global prisma_client
     if prisma_client is not None:
diff --git a/litellm/types/utils.py b/litellm/types/utils.py
index e3df357be..124826003 100644
--- a/litellm/types/utils.py
+++ b/litellm/types/utils.py
@@ -1460,7 +1460,28 @@ class AdapterCompletionStreamWrapper:
         raise StopAsyncIteration
 
 
+class StandardLoggingBudgetMetadata(TypedDict, total=False):
+    """
+    Store budget-related metadata for Team, Internal User, End User, etc.
+    """
+
+    user_api_end_user_max_budget: Optional[float]
+
+
 class StandardLoggingUserAPIKeyMetadata(TypedDict):
+    """
+    Store User API Key related metadata to identify the request
+
+    Example:
+        user_api_key_hash: "88dc28d0f030c55ed4ab77ed8faf098196cb1c05df778539800c9f1243fe6b4b"
+        user_api_key_alias: "litellm-key-123"
+        user_api_key_org_id: "123"
+        user_api_key_team_id: "456"
+        user_api_key_user_id: "789"
+        user_api_key_team_alias: "litellm-team-123"
+
+    """
+
     user_api_key_hash: Optional[str]  # hash of the litellm virtual key used
     user_api_key_alias: Optional[str]
     user_api_key_org_id: Optional[str]
@@ -1469,7 +1490,9 @@ class StandardLoggingUserAPIKeyMetadata(TypedDict):
     user_api_key_team_alias: Optional[str]
 
 
-class StandardLoggingMetadata(StandardLoggingUserAPIKeyMetadata):
+class StandardLoggingMetadata(
+    StandardLoggingUserAPIKeyMetadata, StandardLoggingBudgetMetadata
+):
     """
     Specific metadata k,v pairs logged to integration for easier cost tracking
     """
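Review note, not part of the patch: with the multiple inheritance above, StandardLoggingMetadata keeps the user-API-key fields as required keys while the budget field stays optional, since its parent is declared total=False. A quick sketch of the resulting shape; field values are illustrative only:

    from litellm.types.utils import StandardLoggingMetadata

    metadata = StandardLoggingMetadata(
        user_api_key_hash="88dc28d0f030c55ed4ab77ed8faf098196cb1c05df778539800c9f1243fe6b4b",
        user_api_key_alias="litellm-key-123",
        user_api_key_org_id="123",
        user_api_key_team_id="456",
        user_api_key_user_id="789",
        user_api_key_team_alias="litellm-team-123",
        spend_logs_metadata=None,
        requester_ip_address=None,
        requester_metadata=None,
        user_api_end_user_max_budget=10.0,  # optional key from StandardLoggingBudgetMetadata
    )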
diff --git a/tests/local_testing/test_key_generate_prisma.py b/tests/local_testing/test_key_generate_prisma.py
index a1e136313..a83f7b4bc 100644
--- a/tests/local_testing/test_key_generate_prisma.py
+++ b/tests/local_testing/test_key_generate_prisma.py
@@ -107,6 +107,12 @@ from litellm.proxy._types import (
     UpdateUserRequest,
     UserAPIKeyAuth,
 )
+from litellm.types.utils import (
+    StandardLoggingPayload,
+    StandardLoggingModelInformation,
+    StandardLoggingMetadata,
+    StandardLoggingHiddenParams,
+)
 
 proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache())
 
@@ -143,6 +149,58 @@ def prisma_client():
     return prisma_client
 
 
+def create_simple_standard_logging_payload() -> StandardLoggingPayload:
+
+    return StandardLoggingPayload(
+        id="test_id",
+        call_type="completion",
+        response_cost=0.1,
+        response_cost_failure_debug_info=None,
+        status="success",
+        total_tokens=30,
+        prompt_tokens=20,
+        completion_tokens=10,
+        startTime=1234567890.0,
+        endTime=1234567891.0,
+        completionStartTime=1234567890.5,
+        model_map_information=StandardLoggingModelInformation(
+            model_map_key="gpt-3.5-turbo", model_map_value=None
+        ),
+        model="gpt-3.5-turbo",
+        model_id="model-123",
+        model_group="openai-gpt",
+        api_base="https://api.openai.com",
+        metadata=StandardLoggingMetadata(
+            user_api_key_hash="test_hash",
+            user_api_key_org_id=None,
+            user_api_key_alias="test_alias",
+            user_api_key_team_id="test_team",
+            user_api_key_user_id="test_user",
+            user_api_key_team_alias="test_team_alias",
+            spend_logs_metadata=None,
+            requester_ip_address="127.0.0.1",
+            requester_metadata=None,
+        ),
+        cache_hit=False,
+        cache_key=None,
+        saved_cache_cost=0.0,
+        request_tags=[],
+        end_user=None,
+        requester_ip_address="127.0.0.1",
+        messages=[{"role": "user", "content": "Hello, world!"}],
+        response={"choices": [{"message": {"content": "Hi there!"}}]},
+        error_str=None,
+        model_parameters={"stream": True},
+        hidden_params=StandardLoggingHiddenParams(
+            model_id="model-123",
+            cache_key=None,
+            api_base="https://api.openai.com",
+            response_cost="0.1",
+            additional_headers=None,
+        ),
+    )
+
+
 @pytest.mark.asyncio()
 @pytest.mark.flaky(retries=6, delay=1)
 async def test_new_user_response(prisma_client):
@@ -521,16 +579,17 @@ def test_call_with_user_over_budget(prisma_client):
             model="gpt-35-turbo",  # azure always has model written like this
             usage=Usage(prompt_tokens=210, completion_tokens=200, total_tokens=410),
         )
+        standard_logging_payload = create_simple_standard_logging_payload()
+        standard_logging_payload["response_cost"] = 0.00002
+        standard_logging_payload["metadata"]["user_api_key_hash"] = hash_token(
+            token=generated_key
+        )
+        standard_logging_payload["metadata"]["user_api_key_user_id"] = user_id
+
        await track_cost_callback(
            kwargs={
                "stream": False,
-                "litellm_params": {
-                    "metadata": {
-                        "user_api_key": generated_key,
-                        "user_api_key_user_id": user_id,
-                    }
-                },
-                "response_cost": 0.00002,
+                "standard_logging_object": standard_logging_payload,
            },
            completion_response=resp,
            start_time=datetime.now(),
@@ -618,21 +677,17 @@ def test_call_with_end_user_over_budget(prisma_client):
             model="gpt-35-turbo",  # azure always has model written like this
             usage=Usage(prompt_tokens=210, completion_tokens=200, total_tokens=410),
         )
+        standard_logging_payload = create_simple_standard_logging_payload()
+        standard_logging_payload["response_cost"] = 10
+        standard_logging_payload["metadata"]["user_api_key_hash"] = hash_token(
+            token="sk-1234"
+        )
+        standard_logging_payload["metadata"]["user_api_key_user_id"] = user
+        standard_logging_payload["end_user"] = user
        await track_cost_callback(
            kwargs={
                "stream": False,
-                "litellm_params": {
-                    "metadata": {
-                        "user_api_key": "sk-1234",
-                        "user_api_key_user_id": user,
-                    },
-                    "proxy_server_request": {
-                        "body": {
-                            "user": user,
-                        }
-                    },
-                },
-                "response_cost": 10,
+                "standard_logging_object": standard_logging_payload,
            },
            completion_response=resp,
            start_time=datetime.now(),
@@ -724,16 +779,16 @@ def test_call_with_proxy_over_budget(prisma_client):
             model="gpt-35-turbo",  # azure always has model written like this
             usage=Usage(prompt_tokens=210, completion_tokens=200, total_tokens=410),
         )
+        standard_logging_payload = create_simple_standard_logging_payload()
+        standard_logging_payload["response_cost"] = 0.00002
+        standard_logging_payload["metadata"]["user_api_key_hash"] = hash_token(
+            token=generated_key
+        )
+        standard_logging_payload["metadata"]["user_api_key_user_id"] = user_id
        await track_cost_callback(
            kwargs={
                "stream": False,
-                "litellm_params": {
-                    "metadata": {
-                        "user_api_key": generated_key,
-                        "user_api_key_user_id": user_id,
-                    }
-                },
-                "response_cost": 0.00002,
+                "standard_logging_object": standard_logging_payload,
            },
            completion_response=resp,
            start_time=datetime.now(),
@@ -815,17 +870,17 @@ def test_call_with_user_over_budget_stream(prisma_client):
             model="gpt-35-turbo",  # azure always has model written like this
             usage=Usage(prompt_tokens=210, completion_tokens=200, total_tokens=410),
         )
+        standard_logging_payload = create_simple_standard_logging_payload()
+        standard_logging_payload["response_cost"] = 0.00002
+        standard_logging_payload["metadata"]["user_api_key_hash"] = hash_token(
+            token=generated_key
+        )
+        standard_logging_payload["metadata"]["user_api_key_user_id"] = user_id
        await track_cost_callback(
            kwargs={
                "stream": True,
                "complete_streaming_response": resp,
-                "litellm_params": {
-                    "metadata": {
-                        "user_api_key": generated_key,
-                        "user_api_key_user_id": user_id,
-                    }
-                },
-                "response_cost": 0.00002,
+                "standard_logging_object": standard_logging_payload,
            },
            completion_response=ModelResponse(),
            start_time=datetime.now(),
@@ -921,17 +976,17 @@ def test_call_with_proxy_over_budget_stream(prisma_client):
             model="gpt-35-turbo",  # azure always has model written like this
             usage=Usage(prompt_tokens=210, completion_tokens=200, total_tokens=410),
         )
+        standard_logging_payload = create_simple_standard_logging_payload()
+        standard_logging_payload["response_cost"] = 0.00002
+        standard_logging_payload["metadata"]["user_api_key_hash"] = hash_token(
+            token=generated_key
+        )
+        standard_logging_payload["metadata"]["user_api_key_user_id"] = user_id
        await track_cost_callback(
            kwargs={
                "stream": True,
                "complete_streaming_response": resp,
-                "litellm_params": {
-                    "metadata": {
-                        "user_api_key": generated_key,
-                        "user_api_key_user_id": user_id,
-                    }
-                },
-                "response_cost": 0.00002,
+                "standard_logging_object": standard_logging_payload,
            },
            completion_response=ModelResponse(),
            start_time=datetime.now(),
@@ -1493,17 +1548,17 @@ def test_call_with_key_over_budget(prisma_client):
             model="gpt-35-turbo",  # azure always has model written like this
             usage=Usage(prompt_tokens=210, completion_tokens=200, total_tokens=410),
         )
+        standard_logging_payload = create_simple_standard_logging_payload()
+        standard_logging_payload["response_cost"] = 0.00002
+        standard_logging_payload["metadata"]["user_api_key_hash"] = hash_token(
+            token=generated_key
+        )
+        standard_logging_payload["metadata"]["user_api_key_user_id"] = user_id
        await track_cost_callback(
            kwargs={
                "model": "chatgpt-v-2",
                "stream": False,
-                "litellm_params": {
-                    "metadata": {
-                        "user_api_key": hash_token(generated_key),
-                        "user_api_key_user_id": user_id,
-                    }
-                },
-                "response_cost": 0.00002,
+                "standard_logging_object": standard_logging_payload,
            },
            completion_response=resp,
            start_time=datetime.now(),
@@ -1610,17 +1665,17 @@ def test_call_with_key_over_budget_no_cache(prisma_client):
             model="gpt-35-turbo",  # azure always has model written like this
             usage=Usage(prompt_tokens=210, completion_tokens=200, total_tokens=410),
         )
+        standard_logging_payload = create_simple_standard_logging_payload()
+        standard_logging_payload["response_cost"] = 0.00002
+        standard_logging_payload["metadata"]["user_api_key_hash"] = hash_token(
+            token=generated_key
+        )
+        standard_logging_payload["metadata"]["user_api_key_user_id"] = user_id
        await track_cost_callback(
            kwargs={
                "model": "chatgpt-v-2",
                "stream": False,
-                "litellm_params": {
-                    "metadata": {
-                        "user_api_key": hash_token(generated_key),
-                        "user_api_key_user_id": user_id,
-                    }
-                },
-                "response_cost": 0.00002,
+                "standard_logging_object": standard_logging_payload,
            },
            completion_response=resp,
            start_time=datetime.now(),
@@ -1734,10 +1789,17 @@ def test_call_with_key_over_model_budget(prisma_client):
             model="gpt-35-turbo",  # azure always has model written like this
             usage=Usage(prompt_tokens=210, completion_tokens=200, total_tokens=410),
         )
+        standard_logging_payload = create_simple_standard_logging_payload()
+        standard_logging_payload["response_cost"] = 0.00002
+        standard_logging_payload["metadata"]["user_api_key_hash"] = hash_token(
+            token=generated_key
+        )
+        standard_logging_payload["metadata"]["user_api_key_user_id"] = user_id
        await track_cost_callback(
            kwargs={
                "model": "chatgpt-v-2",
                "stream": False,
+                "standard_logging_object": standard_logging_payload,
                "litellm_params": {
                    "metadata": {
                        "user_api_key": hash_token(generated_key),
@@ -1840,17 +1902,17 @@ async def test_call_with_key_never_over_budget(prisma_client):
             usage=Usage(
                 prompt_tokens=210000, completion_tokens=200000, total_tokens=41000
             ),
         )
+        standard_logging_payload = create_simple_standard_logging_payload()
+        standard_logging_payload["response_cost"] = 200000
+        standard_logging_payload["metadata"]["user_api_key_hash"] = hash_token(
+            token=generated_key
+        )
+        standard_logging_payload["metadata"]["user_api_key_user_id"] = user_id
        await track_cost_callback(
            kwargs={
                "model": "chatgpt-v-2",
                "stream": False,
-                "litellm_params": {
-                    "metadata": {
-                        "user_api_key": hash_token(generated_key),
-                        "user_api_key_user_id": user_id,
-                    }
-                },
-                "response_cost": 200000,
+                "standard_logging_object": standard_logging_payload,
            },
            completion_response=resp,
            start_time=datetime.now(),
@@ -1921,19 +1983,19 @@ async def test_call_with_key_over_budget_stream(prisma_client):
             model="gpt-35-turbo",  # azure always has model written like this
             usage=Usage(prompt_tokens=210, completion_tokens=200, total_tokens=410),
         )
+        standard_logging_payload = create_simple_standard_logging_payload()
+        standard_logging_payload["response_cost"] = 0.00005
+        standard_logging_payload["metadata"]["user_api_key_hash"] = hash_token(
+            token=generated_key
+        )
+        standard_logging_payload["metadata"]["user_api_key_user_id"] = user_id
        await track_cost_callback(
            kwargs={
                "call_type": "acompletion",
                "model": "sagemaker-chatgpt-v-2",
                "stream": True,
                "complete_streaming_response": resp,
-                "litellm_params": {
-                    "metadata": {
-                        "user_api_key": hash_token(generated_key),
-                        "user_api_key_user_id": user_id,
-                    }
-                },
-                "response_cost": 0.00005,
+                "standard_logging_object": standard_logging_payload,
            },
            completion_response=resp,
            start_time=datetime.now(),
@@ -2329,19 +2391,19 @@ async def track_cost_callback_helper_fn(generated_key: str, user_id: str):
         model="gpt-35-turbo",  # azure always has model written like this
         usage=Usage(prompt_tokens=210, completion_tokens=200, total_tokens=410),
     )
+    standard_logging_payload = create_simple_standard_logging_payload()
+    standard_logging_payload["response_cost"] = 0.00005
+    standard_logging_payload["metadata"]["user_api_key_hash"] = hash_token(
+        token=generated_key
+    )
+    standard_logging_payload["metadata"]["user_api_key_user_id"] = user_id
    await track_cost_callback(
        kwargs={
            "call_type": "acompletion",
            "model": "sagemaker-chatgpt-v-2",
            "stream": True,
            "complete_streaming_response": resp,
-            "litellm_params": {
-                "metadata": {
-                    "user_api_key": hash_token(generated_key),
-                    "user_api_key_user_id": user_id,
-                }
-            },
-            "response_cost": 0.00005,
+            "standard_logging_object": standard_logging_payload,
        },
        completion_response=resp,
        start_time=datetime.now(),
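Review note, not part of the patch: each test hunk above repeats the same three-step mutation of the helper payload. If more tests land on this path, a small hypothetical wrapper around create_simple_standard_logging_payload could collapse that boilerplate; hash_token is the helper the test module already uses:

    def payload_for_key(
        generated_key: str, user_id: str, response_cost: float
    ) -> StandardLoggingPayload:
        # build the baseline payload, then override the fields each test varies
        payload = create_simple_standard_logging_payload()
        payload["response_cost"] = response_cost
        payload["metadata"]["user_api_key_hash"] = hash_token(token=generated_key)
        payload["metadata"]["user_api_key_user_id"] = user_id
        return payload

    # usage inside a test, mirroring the hunks above:
    # standard_logging_payload = payload_for_key(generated_key, user_id, 0.00002)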