From 2639c1971da73a53213ea8af601bfb7305553ee0 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 4 Nov 2024 10:44:13 -0800 Subject: [PATCH 1/7] use separate file for _PROXY_track_cost_callback --- .../proxy/hooks/proxy_track_cost_callback.py | 137 ++++++++++++++++++ litellm/proxy/proxy_server.py | 106 +------------- 2 files changed, 139 insertions(+), 104 deletions(-) create mode 100644 litellm/proxy/hooks/proxy_track_cost_callback.py diff --git a/litellm/proxy/hooks/proxy_track_cost_callback.py b/litellm/proxy/hooks/proxy_track_cost_callback.py new file mode 100644 index 000000000..40e339516 --- /dev/null +++ b/litellm/proxy/hooks/proxy_track_cost_callback.py @@ -0,0 +1,137 @@ +""" +Proxy Success Callback - handles storing cost of a request in LiteLLM DB. + +Updates cost for the following in LiteLLM DB: + - spend logs + - virtual key spend + - internal user, team, external user spend +""" + +import asyncio +import traceback + +import litellm +from litellm._logging import verbose_proxy_logger +from litellm.proxy.utils import ( + _get_parent_otel_span_from_kwargs, + get_litellm_metadata_from_kwargs, + log_to_opentelemetry, +) + + +@log_to_opentelemetry +async def _PROXY_track_cost_callback( + kwargs, # kwargs to completion + completion_response: litellm.ModelResponse, # response from completion + start_time=None, + end_time=None, # start/end time for completion +): + """ + Callback handles storing cost of a request in LiteLLM DB. + + Updates cost for the following in LiteLLM DB: + - spend logs + - virtual key spend + - internal user, team, external user spend + """ + from litellm.proxy.proxy_server import ( + prisma_client, + proxy_logging_obj, + update_cache, + update_database, + ) + + verbose_proxy_logger.debug("INSIDE _PROXY_track_cost_callback") + try: + # check if it has collected an entire stream response + verbose_proxy_logger.debug( + "Proxy: In track_cost_callback for: kwargs=%s and completion_response: %s", + kwargs, + completion_response, + ) + verbose_proxy_logger.debug( + f"kwargs stream: {kwargs.get('stream', None)} + complete streaming response: {kwargs.get('complete_streaming_response', None)}" + ) + parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs=kwargs) + litellm_params = kwargs.get("litellm_params", {}) or {} + proxy_server_request = litellm_params.get("proxy_server_request") or {} + end_user_id = proxy_server_request.get("body", {}).get("user", None) + metadata = get_litellm_metadata_from_kwargs(kwargs=kwargs) + user_id = metadata.get("user_api_key_user_id", None) + team_id = metadata.get("user_api_key_team_id", None) + org_id = metadata.get("user_api_key_org_id", None) + key_alias = metadata.get("user_api_key_alias", None) + end_user_max_budget = metadata.get("user_api_end_user_max_budget", None) + if kwargs.get("response_cost", None) is not None: + response_cost = kwargs["response_cost"] + user_api_key = metadata.get("user_api_key", None) + if kwargs.get("cache_hit", False) is True: + response_cost = 0.0 + verbose_proxy_logger.info( + f"Cache Hit: response_cost {response_cost}, for user_id {user_id}" + ) + + verbose_proxy_logger.debug( + f"user_api_key {user_api_key}, prisma_client: {prisma_client}" + ) + if user_api_key is not None or user_id is not None or team_id is not None: + ## UPDATE DATABASE + await update_database( + token=user_api_key, + response_cost=response_cost, + user_id=user_id, + end_user_id=end_user_id, + team_id=team_id, + kwargs=kwargs, + completion_response=completion_response, + start_time=start_time, + end_time=end_time, + 
org_id=org_id, + ) + + # update cache + asyncio.create_task( + update_cache( + token=user_api_key, + user_id=user_id, + end_user_id=end_user_id, + response_cost=response_cost, + team_id=team_id, + parent_otel_span=parent_otel_span, + ) + ) + + await proxy_logging_obj.slack_alerting_instance.customer_spend_alert( + token=user_api_key, + key_alias=key_alias, + end_user_id=end_user_id, + response_cost=response_cost, + max_budget=end_user_max_budget, + ) + else: + raise Exception( + "User API key and team id and user id missing from custom callback." + ) + else: + if kwargs["stream"] is not True or ( + kwargs["stream"] is True and "complete_streaming_response" in kwargs + ): + cost_tracking_failure_debug_info = kwargs.get( + "response_cost_failure_debug_information" + ) + model = kwargs.get("model") + raise Exception( + f"Cost tracking failed for model={model}.\nDebug info - {cost_tracking_failure_debug_info}\nAdd custom pricing - https://docs.litellm.ai/docs/proxy/custom_pricing" + ) + except Exception as e: + error_msg = f"Error in tracking cost callback - {str(e)}\n Traceback:{traceback.format_exc()}" + model = kwargs.get("model", "") + metadata = kwargs.get("litellm_params", {}).get("metadata", {}) + error_msg += f"\n Args to _PROXY_track_cost_callback\n model: {model}\n metadata: {metadata}\n" + asyncio.create_task( + proxy_logging_obj.failed_tracking_alert( + error_message=error_msg, + failing_model=model, + ) + ) + verbose_proxy_logger.debug("error in tracking cost callback - %s", e) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index ca6befef6..9db33a5a6 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -303,6 +303,8 @@ from fastapi.security import OAuth2PasswordBearer from fastapi.security.api_key import APIKeyHeader from fastapi.staticfiles import StaticFiles +from litellm.proxy.hooks.proxy_track_cost_callback import _PROXY_track_cost_callback + # import enterprise folder try: # when using litellm cli @@ -747,110 +749,6 @@ async def _PROXY_failure_handler( pass -@log_to_opentelemetry -async def _PROXY_track_cost_callback( - kwargs, # kwargs to completion - completion_response: litellm.ModelResponse, # response from completion - start_time=None, - end_time=None, # start/end time for completion -): - verbose_proxy_logger.debug("INSIDE _PROXY_track_cost_callback") - global prisma_client - try: - # check if it has collected an entire stream response - verbose_proxy_logger.debug( - "Proxy: In track_cost_callback for: kwargs=%s and completion_response: %s", - kwargs, - completion_response, - ) - verbose_proxy_logger.debug( - f"kwargs stream: {kwargs.get('stream', None)} + complete streaming response: {kwargs.get('complete_streaming_response', None)}" - ) - parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs=kwargs) - litellm_params = kwargs.get("litellm_params", {}) or {} - proxy_server_request = litellm_params.get("proxy_server_request") or {} - end_user_id = proxy_server_request.get("body", {}).get("user", None) - metadata = get_litellm_metadata_from_kwargs(kwargs=kwargs) - user_id = metadata.get("user_api_key_user_id", None) - team_id = metadata.get("user_api_key_team_id", None) - org_id = metadata.get("user_api_key_org_id", None) - key_alias = metadata.get("user_api_key_alias", None) - end_user_max_budget = metadata.get("user_api_end_user_max_budget", None) - if kwargs.get("response_cost", None) is not None: - response_cost = kwargs["response_cost"] - user_api_key = metadata.get("user_api_key", None) - if 
kwargs.get("cache_hit", False) is True: - response_cost = 0.0 - verbose_proxy_logger.info( - f"Cache Hit: response_cost {response_cost}, for user_id {user_id}" - ) - - verbose_proxy_logger.debug( - f"user_api_key {user_api_key}, prisma_client: {prisma_client}" - ) - if user_api_key is not None or user_id is not None or team_id is not None: - ## UPDATE DATABASE - await update_database( - token=user_api_key, - response_cost=response_cost, - user_id=user_id, - end_user_id=end_user_id, - team_id=team_id, - kwargs=kwargs, - completion_response=completion_response, - start_time=start_time, - end_time=end_time, - org_id=org_id, - ) - - # update cache - asyncio.create_task( - update_cache( - token=user_api_key, - user_id=user_id, - end_user_id=end_user_id, - response_cost=response_cost, - team_id=team_id, - parent_otel_span=parent_otel_span, - ) - ) - - await proxy_logging_obj.slack_alerting_instance.customer_spend_alert( - token=user_api_key, - key_alias=key_alias, - end_user_id=end_user_id, - response_cost=response_cost, - max_budget=end_user_max_budget, - ) - else: - raise Exception( - "User API key and team id and user id missing from custom callback." - ) - else: - if kwargs["stream"] is not True or ( - kwargs["stream"] is True and "complete_streaming_response" in kwargs - ): - cost_tracking_failure_debug_info = kwargs.get( - "response_cost_failure_debug_information" - ) - model = kwargs.get("model") - raise Exception( - f"Cost tracking failed for model={model}.\nDebug info - {cost_tracking_failure_debug_info}\nAdd custom pricing - https://docs.litellm.ai/docs/proxy/custom_pricing" - ) - except Exception as e: - error_msg = f"Error in tracking cost callback - {str(e)}\n Traceback:{traceback.format_exc()}" - model = kwargs.get("model", "") - metadata = kwargs.get("litellm_params", {}).get("metadata", {}) - error_msg += f"\n Args to _PROXY_track_cost_callback\n model: {model}\n metadata: {metadata}\n" - asyncio.create_task( - proxy_logging_obj.failed_tracking_alert( - error_message=error_msg, - failing_model=model, - ) - ) - verbose_proxy_logger.debug("error in tracking cost callback - %s", e) - - def error_tracking(): global prisma_client if prisma_client is not None: From 9864459f4debc8236e3873f2be046c042187f5ab Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 4 Nov 2024 11:01:58 -0800 Subject: [PATCH 2/7] fix use standard_logging_payload for track cost callback --- .../proxy/hooks/proxy_track_cost_callback.py | 70 ++++++++++--------- 1 file changed, 37 insertions(+), 33 deletions(-) diff --git a/litellm/proxy/hooks/proxy_track_cost_callback.py b/litellm/proxy/hooks/proxy_track_cost_callback.py index 40e339516..6d5b441cd 100644 --- a/litellm/proxy/hooks/proxy_track_cost_callback.py +++ b/litellm/proxy/hooks/proxy_track_cost_callback.py @@ -9,6 +9,7 @@ Updates cost for the following in LiteLLM DB: import asyncio import traceback +from typing import Optional import litellm from litellm._logging import verbose_proxy_logger @@ -17,6 +18,7 @@ from litellm.proxy.utils import ( get_litellm_metadata_from_kwargs, log_to_opentelemetry, ) +from litellm.types.utils import StandardLoggingPayload @log_to_opentelemetry @@ -49,31 +51,36 @@ async def _PROXY_track_cost_callback( kwargs, completion_response, ) - verbose_proxy_logger.debug( - f"kwargs stream: {kwargs.get('stream', None)} + complete streaming response: {kwargs.get('complete_streaming_response', None)}" - ) parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs=kwargs) - litellm_params = kwargs.get("litellm_params", {}) or {} - 
proxy_server_request = litellm_params.get("proxy_server_request") or {} - end_user_id = proxy_server_request.get("body", {}).get("user", None) - metadata = get_litellm_metadata_from_kwargs(kwargs=kwargs) - user_id = metadata.get("user_api_key_user_id", None) - team_id = metadata.get("user_api_key_team_id", None) - org_id = metadata.get("user_api_key_org_id", None) - key_alias = metadata.get("user_api_key_alias", None) - end_user_max_budget = metadata.get("user_api_end_user_max_budget", None) - if kwargs.get("response_cost", None) is not None: - response_cost = kwargs["response_cost"] - user_api_key = metadata.get("user_api_key", None) - if kwargs.get("cache_hit", False) is True: - response_cost = 0.0 - verbose_proxy_logger.info( - f"Cache Hit: response_cost {response_cost}, for user_id {user_id}" - ) - - verbose_proxy_logger.debug( - f"user_api_key {user_api_key}, prisma_client: {prisma_client}" + standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get( + "standard_logging_object", None + ) + if standard_logging_payload is None: + raise ValueError( + "standard_logging_payload is none in kwargs, cannot track cost without it" ) + end_user_id = standard_logging_payload.get("end_user") + user_api_key = standard_logging_payload.get("metadata", {}).get( + "user_api_key_hash" + ) + user_id = standard_logging_payload.get("metadata", {}).get( + "user_api_key_user_id" + ) + team_id = standard_logging_payload.get("metadata", {}).get( + "user_api_key_team_id" + ) + org_id = standard_logging_payload.get("metadata", {}).get("user_api_key_org_id") + key_alias = standard_logging_payload.get("metadata", {}).get( + "user_api_key_alias" + ) + response_cost: Optional[float] = standard_logging_payload.get("response_cost") + + # end_user_max_budget = metadata.get("user_api_end_user_max_budget", None) + end_user_max_budget = standard_logging_payload.get("metadata", {}).get( + "user_api_end_user_max_budget" + ) + + if response_cost is not None: if user_api_key is not None or user_id is not None or team_id is not None: ## UPDATE DATABASE await update_database( @@ -113,16 +120,13 @@ async def _PROXY_track_cost_callback( "User API key and team id and user id missing from custom callback." 
) else: - if kwargs["stream"] is not True or ( - kwargs["stream"] is True and "complete_streaming_response" in kwargs - ): - cost_tracking_failure_debug_info = kwargs.get( - "response_cost_failure_debug_information" - ) - model = kwargs.get("model") - raise Exception( - f"Cost tracking failed for model={model}.\nDebug info - {cost_tracking_failure_debug_info}\nAdd custom pricing - https://docs.litellm.ai/docs/proxy/custom_pricing" - ) + cost_tracking_failure_debug_info = standard_logging_payload.get( + "response_cost_failure_debug_info" + ) + model = kwargs.get("model") + raise ValueError( + f"Failed to write cost to DB, for model={model}.\nDebug info - {cost_tracking_failure_debug_info}\nAdd custom pricing - https://docs.litellm.ai/docs/proxy/custom_pricing" + ) except Exception as e: error_msg = f"Error in tracking cost callback - {str(e)}\n Traceback:{traceback.format_exc()}" model = kwargs.get("model", "") From 02cf18be83c7efb06dafe8feb7a5205a056158ea Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 4 Nov 2024 11:11:09 -0800 Subject: [PATCH 3/7] StandardLoggingBudgetMetadata --- .../proxy/hooks/proxy_track_cost_callback.py | 4 +-- litellm/types/utils.py | 25 ++++++++++++++++++- 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/litellm/proxy/hooks/proxy_track_cost_callback.py b/litellm/proxy/hooks/proxy_track_cost_callback.py index 6d5b441cd..2de2b0673 100644 --- a/litellm/proxy/hooks/proxy_track_cost_callback.py +++ b/litellm/proxy/hooks/proxy_track_cost_callback.py @@ -73,12 +73,10 @@ async def _PROXY_track_cost_callback( key_alias = standard_logging_payload.get("metadata", {}).get( "user_api_key_alias" ) - response_cost: Optional[float] = standard_logging_payload.get("response_cost") - - # end_user_max_budget = metadata.get("user_api_end_user_max_budget", None) end_user_max_budget = standard_logging_payload.get("metadata", {}).get( "user_api_end_user_max_budget" ) + response_cost: Optional[float] = standard_logging_payload.get("response_cost") if response_cost is not None: if user_api_key is not None or user_id is not None or team_id is not None: diff --git a/litellm/types/utils.py b/litellm/types/utils.py index 2d0e262fe..2b5a1cdfd 100644 --- a/litellm/types/utils.py +++ b/litellm/types/utils.py @@ -1416,7 +1416,28 @@ class AdapterCompletionStreamWrapper: raise StopAsyncIteration +class StandardLoggingBudgetMetadata(TypedDict, total=False): + """ + Store Budget related metadata for Team, Internal User, End User etc + """ + + user_api_end_user_max_budget: Optional[float] + + class StandardLoggingUserAPIKeyMetadata(TypedDict): + """ + Store User API Key related metadata to identify the request + + Example: + user_api_key_hash: "88dc28d0f030c55ed4ab77ed8faf098196cb1c05df778539800c9f1243fe6b4b" + user_api_key_alias: "litellm-key-123" + user_api_key_org_id: "123" + user_api_key_team_id: "456" + user_api_key_user_id: "789" + user_api_key_team_alias: "litellm-team-123" + + """ + user_api_key_hash: Optional[str] # hash of the litellm virtual key used user_api_key_alias: Optional[str] user_api_key_org_id: Optional[str] @@ -1425,7 +1446,9 @@ class StandardLoggingUserAPIKeyMetadata(TypedDict): user_api_key_team_alias: Optional[str] -class StandardLoggingMetadata(StandardLoggingUserAPIKeyMetadata): +class StandardLoggingMetadata( + StandardLoggingUserAPIKeyMetadata, StandardLoggingBudgetMetadata +): """ Specific metadata k,v pairs logged to integration for easier cost tracking """ From 9116c09386314f5a746e5b9649290b82b4de0330 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: 
Mon, 4 Nov 2024 11:33:27 -0800 Subject: [PATCH 4/7] fix test key gen prisma --- .../local_testing/test_key_generate_prisma.py | 219 +++++++++++------- 1 file changed, 137 insertions(+), 82 deletions(-) diff --git a/tests/local_testing/test_key_generate_prisma.py b/tests/local_testing/test_key_generate_prisma.py index 74182c09f..0cb9659dc 100644 --- a/tests/local_testing/test_key_generate_prisma.py +++ b/tests/local_testing/test_key_generate_prisma.py @@ -105,6 +105,12 @@ from litellm.proxy._types import ( UpdateUserRequest, UserAPIKeyAuth, ) +from litellm.types.utils import ( + StandardLoggingPayload, + StandardLoggingModelInformation, + StandardLoggingMetadata, + StandardLoggingHiddenParams, +) proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache()) @@ -141,6 +147,58 @@ def prisma_client(): return prisma_client +def create_simple_standard_logging_payload() -> StandardLoggingPayload: + + return StandardLoggingPayload( + id="test_id", + call_type="completion", + response_cost=0.1, + response_cost_failure_debug_info=None, + status="success", + total_tokens=30, + prompt_tokens=20, + completion_tokens=10, + startTime=1234567890.0, + endTime=1234567891.0, + completionStartTime=1234567890.5, + model_map_information=StandardLoggingModelInformation( + model_map_key="gpt-3.5-turbo", model_map_value=None + ), + model="gpt-3.5-turbo", + model_id="model-123", + model_group="openai-gpt", + api_base="https://api.openai.com", + metadata=StandardLoggingMetadata( + user_api_key_hash="test_hash", + user_api_key_org_id=None, + user_api_key_alias="test_alias", + user_api_key_team_id="test_team", + user_api_key_user_id="test_user", + user_api_key_team_alias="test_team_alias", + spend_logs_metadata=None, + requester_ip_address="127.0.0.1", + requester_metadata=None, + ), + cache_hit=False, + cache_key=None, + saved_cache_cost=0.0, + request_tags=[], + end_user=None, + requester_ip_address="127.0.0.1", + messages=[{"role": "user", "content": "Hello, world!"}], + response={"choices": [{"message": {"content": "Hi there!"}}]}, + error_str=None, + model_parameters={"stream": True}, + hidden_params=StandardLoggingHiddenParams( + model_id="model-123", + cache_key=None, + api_base="https://api.openai.com", + response_cost="0.1", + additional_headers=None, + ), + ) + + @pytest.mark.asyncio() @pytest.mark.flaky(retries=6, delay=1) async def test_new_user_response(prisma_client): @@ -514,16 +572,17 @@ def test_call_with_user_over_budget(prisma_client): model="gpt-35-turbo", # azure always has model written like this usage=Usage(prompt_tokens=210, completion_tokens=200, total_tokens=410), ) + standard_logging_payload = create_simple_standard_logging_payload() + standard_logging_payload["response_cost"] = 0.00002 + standard_logging_payload["metadata"]["user_api_key_hash"] = hash_token( + token=generated_key + ) + standard_logging_payload["metadata"]["user_api_key_user_id"] = user_id + await track_cost_callback( kwargs={ "stream": False, - "litellm_params": { - "metadata": { - "user_api_key": generated_key, - "user_api_key_user_id": user_id, - } - }, - "response_cost": 0.00002, + "standard_logging_object": standard_logging_payload, }, completion_response=resp, start_time=datetime.now(), @@ -611,21 +670,17 @@ def test_call_with_end_user_over_budget(prisma_client): model="gpt-35-turbo", # azure always has model written like this usage=Usage(prompt_tokens=210, completion_tokens=200, total_tokens=410), ) + standard_logging_payload = create_simple_standard_logging_payload() + standard_logging_payload["response_cost"] = 
10 + standard_logging_payload["metadata"]["user_api_key_hash"] = hash_token( + token="sk-1234" + ) + standard_logging_payload["metadata"]["user_api_key_user_id"] = user + standard_logging_payload["end_user"] = user await track_cost_callback( kwargs={ "stream": False, - "litellm_params": { - "metadata": { - "user_api_key": "sk-1234", - "user_api_key_user_id": user, - }, - "proxy_server_request": { - "body": { - "user": user, - } - }, - }, - "response_cost": 10, + "standard_logging_object": standard_logging_payload, }, completion_response=resp, start_time=datetime.now(), @@ -717,16 +772,16 @@ def test_call_with_proxy_over_budget(prisma_client): model="gpt-35-turbo", # azure always has model written like this usage=Usage(prompt_tokens=210, completion_tokens=200, total_tokens=410), ) + standard_logging_payload = create_simple_standard_logging_payload() + standard_logging_payload["response_cost"] = 0.00002 + standard_logging_payload["metadata"]["user_api_key_hash"] = hash_token( + token=generated_key + ) + standard_logging_payload["metadata"]["user_api_key_user_id"] = user_id await track_cost_callback( kwargs={ "stream": False, - "litellm_params": { - "metadata": { - "user_api_key": generated_key, - "user_api_key_user_id": user_id, - } - }, - "response_cost": 0.00002, + "standard_logging_object": standard_logging_payload, }, completion_response=resp, start_time=datetime.now(), @@ -808,17 +863,17 @@ def test_call_with_user_over_budget_stream(prisma_client): model="gpt-35-turbo", # azure always has model written like this usage=Usage(prompt_tokens=210, completion_tokens=200, total_tokens=410), ) + standard_logging_payload = create_simple_standard_logging_payload() + standard_logging_payload["response_cost"] = 0.00002 + standard_logging_payload["metadata"]["user_api_key_hash"] = hash_token( + token=generated_key + ) + standard_logging_payload["metadata"]["user_api_key_user_id"] = user_id await track_cost_callback( kwargs={ "stream": True, "complete_streaming_response": resp, - "litellm_params": { - "metadata": { - "user_api_key": generated_key, - "user_api_key_user_id": user_id, - } - }, - "response_cost": 0.00002, + "standard_logging_object": standard_logging_payload, }, completion_response=ModelResponse(), start_time=datetime.now(), @@ -914,17 +969,17 @@ def test_call_with_proxy_over_budget_stream(prisma_client): model="gpt-35-turbo", # azure always has model written like this usage=Usage(prompt_tokens=210, completion_tokens=200, total_tokens=410), ) + standard_logging_payload = create_simple_standard_logging_payload() + standard_logging_payload["response_cost"] = 0.00002 + standard_logging_payload["metadata"]["user_api_key_hash"] = hash_token( + token=generated_key + ) + standard_logging_payload["metadata"]["user_api_key_user_id"] = user_id await track_cost_callback( kwargs={ "stream": True, "complete_streaming_response": resp, - "litellm_params": { - "metadata": { - "user_api_key": generated_key, - "user_api_key_user_id": user_id, - } - }, - "response_cost": 0.00002, + "standard_logging_object": standard_logging_payload, }, completion_response=ModelResponse(), start_time=datetime.now(), @@ -1471,17 +1526,17 @@ def test_call_with_key_over_budget(prisma_client): model="gpt-35-turbo", # azure always has model written like this usage=Usage(prompt_tokens=210, completion_tokens=200, total_tokens=410), ) + standard_logging_payload = create_simple_standard_logging_payload() + standard_logging_payload["response_cost"] = 0.00002 + standard_logging_payload["metadata"]["user_api_key_hash"] = hash_token( + 
token=generated_key + ) + standard_logging_payload["metadata"]["user_api_key_user_id"] = user_id await track_cost_callback( kwargs={ "model": "chatgpt-v-2", "stream": False, - "litellm_params": { - "metadata": { - "user_api_key": hash_token(generated_key), - "user_api_key_user_id": user_id, - } - }, - "response_cost": 0.00002, + "standard_logging_object": standard_logging_payload, }, completion_response=resp, start_time=datetime.now(), @@ -1588,17 +1643,17 @@ def test_call_with_key_over_budget_no_cache(prisma_client): model="gpt-35-turbo", # azure always has model written like this usage=Usage(prompt_tokens=210, completion_tokens=200, total_tokens=410), ) + standard_logging_payload = create_simple_standard_logging_payload() + standard_logging_payload["response_cost"] = 0.00002 + standard_logging_payload["metadata"]["user_api_key_hash"] = hash_token( + token=generated_key + ) + standard_logging_payload["metadata"]["user_api_key_user_id"] = user_id await track_cost_callback( kwargs={ "model": "chatgpt-v-2", "stream": False, - "litellm_params": { - "metadata": { - "user_api_key": hash_token(generated_key), - "user_api_key_user_id": user_id, - } - }, - "response_cost": 0.00002, + "standard_logging_object": standard_logging_payload, }, completion_response=resp, start_time=datetime.now(), @@ -1712,17 +1767,17 @@ def test_call_with_key_over_model_budget(prisma_client): model="gpt-35-turbo", # azure always has model written like this usage=Usage(prompt_tokens=210, completion_tokens=200, total_tokens=410), ) + standard_logging_payload = create_simple_standard_logging_payload() + standard_logging_payload["response_cost"] = 0.00002 + standard_logging_payload["metadata"]["user_api_key_hash"] = hash_token( + token=generated_key + ) + standard_logging_payload["metadata"]["user_api_key_user_id"] = user_id await track_cost_callback( kwargs={ "model": "chatgpt-v-2", "stream": False, - "litellm_params": { - "metadata": { - "user_api_key": hash_token(generated_key), - "user_api_key_user_id": user_id, - } - }, - "response_cost": 0.00002, + "standard_logging_object": standard_logging_payload, }, completion_response=resp, start_time=datetime.now(), @@ -1818,17 +1873,17 @@ async def test_call_with_key_never_over_budget(prisma_client): prompt_tokens=210000, completion_tokens=200000, total_tokens=41000 ), ) + standard_logging_payload = create_simple_standard_logging_payload() + standard_logging_payload["response_cost"] = 200000 + standard_logging_payload["metadata"]["user_api_key_hash"] = hash_token( + token=generated_key + ) + standard_logging_payload["metadata"]["user_api_key_user_id"] = user_id await track_cost_callback( kwargs={ "model": "chatgpt-v-2", "stream": False, - "litellm_params": { - "metadata": { - "user_api_key": hash_token(generated_key), - "user_api_key_user_id": user_id, - } - }, - "response_cost": 200000, + "standard_logging_object": standard_logging_payload, }, completion_response=resp, start_time=datetime.now(), @@ -1899,19 +1954,19 @@ async def test_call_with_key_over_budget_stream(prisma_client): model="gpt-35-turbo", # azure always has model written like this usage=Usage(prompt_tokens=210, completion_tokens=200, total_tokens=410), ) + standard_logging_payload = create_simple_standard_logging_payload() + standard_logging_payload["response_cost"] = 0.00005 + standard_logging_payload["metadata"]["user_api_key_hash"] = hash_token( + token=generated_key + ) + standard_logging_payload["metadata"]["user_api_key_user_id"] = user_id await track_cost_callback( kwargs={ "call_type": "acompletion", 
"model": "sagemaker-chatgpt-v-2", "stream": True, "complete_streaming_response": resp, - "litellm_params": { - "metadata": { - "user_api_key": hash_token(generated_key), - "user_api_key_user_id": user_id, - } - }, - "response_cost": 0.00005, + "standard_logging_object": standard_logging_payload, }, completion_response=resp, start_time=datetime.now(), @@ -2292,19 +2347,19 @@ async def track_cost_callback_helper_fn(generated_key: str, user_id: str): model="gpt-35-turbo", # azure always has model written like this usage=Usage(prompt_tokens=210, completion_tokens=200, total_tokens=410), ) + standard_logging_payload = create_simple_standard_logging_payload() + standard_logging_payload["response_cost"] = 0.00005 + standard_logging_payload["metadata"]["user_api_key_hash"] = hash_token( + token=generated_key + ) + standard_logging_payload["metadata"]["user_api_key_user_id"] = user_id await track_cost_callback( kwargs={ "call_type": "acompletion", "model": "sagemaker-chatgpt-v-2", "stream": True, "complete_streaming_response": resp, - "litellm_params": { - "metadata": { - "user_api_key": hash_token(generated_key), - "user_api_key_user_id": user_id, - } - }, - "response_cost": 0.00005, + "standard_logging_object": standard_logging_payload, }, completion_response=resp, start_time=datetime.now(), From 501bf6961fd552f74d7dec2bd03559bdeb3f6fe7 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 4 Nov 2024 12:01:46 -0800 Subject: [PATCH 5/7] fix test_call_with_key_over_model_budget --- tests/local_testing/test_key_generate_prisma.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/local_testing/test_key_generate_prisma.py b/tests/local_testing/test_key_generate_prisma.py index 0cb9659dc..fa9392136 100644 --- a/tests/local_testing/test_key_generate_prisma.py +++ b/tests/local_testing/test_key_generate_prisma.py @@ -1778,6 +1778,13 @@ def test_call_with_key_over_model_budget(prisma_client): "model": "chatgpt-v-2", "stream": False, "standard_logging_object": standard_logging_payload, + "litellm_params": { + "metadata": { + "user_api_key": hash_token(generated_key), + "user_api_key_user_id": user_id, + } + }, + "response_cost": 0.00002, }, completion_response=resp, start_time=datetime.now(), From 69aa10d53620c38d8516988441b7e8a884955993 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 4 Nov 2024 13:44:30 -0800 Subject: [PATCH 6/7] fix test_check_num_callbacks_on_lowest_latency --- tests/test_callbacks_on_proxy.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_callbacks_on_proxy.py b/tests/test_callbacks_on_proxy.py index 42665c35b..3677a99ad 100644 --- a/tests/test_callbacks_on_proxy.py +++ b/tests/test_callbacks_on_proxy.py @@ -147,6 +147,7 @@ async def test_check_num_callbacks(): @pytest.mark.asyncio @pytest.mark.order2 +@pytest.mark.skip(reason="skipping this test for now") async def test_check_num_callbacks_on_lowest_latency(): """ Test 1: num callbacks should NOT increase over time From ae23c02b2f6dcd4ae20b5fd53a3b3046b9666b0b Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Tue, 12 Nov 2024 10:34:08 -0800 Subject: [PATCH 7/7] fix merge updates --- litellm/proxy/hooks/proxy_track_cost_callback.py | 9 +++------ tests/test_callbacks_on_proxy.py | 1 - 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/litellm/proxy/hooks/proxy_track_cost_callback.py b/litellm/proxy/hooks/proxy_track_cost_callback.py index 2de2b0673..28518a68a 100644 --- a/litellm/proxy/hooks/proxy_track_cost_callback.py +++ b/litellm/proxy/hooks/proxy_track_cost_callback.py @@ -13,15 +13,12 @@ 
from typing import Optional import litellm from litellm._logging import verbose_proxy_logger -from litellm.proxy.utils import ( - _get_parent_otel_span_from_kwargs, - get_litellm_metadata_from_kwargs, - log_to_opentelemetry, -) +from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs +from litellm.proxy.utils import log_db_metrics from litellm.types.utils import StandardLoggingPayload -@log_to_opentelemetry +@log_db_metrics async def _PROXY_track_cost_callback( kwargs, # kwargs to completion completion_response: litellm.ModelResponse, # response from completion diff --git a/tests/test_callbacks_on_proxy.py b/tests/test_callbacks_on_proxy.py index 3677a99ad..42665c35b 100644 --- a/tests/test_callbacks_on_proxy.py +++ b/tests/test_callbacks_on_proxy.py @@ -147,7 +147,6 @@ async def test_check_num_callbacks(): @pytest.mark.asyncio @pytest.mark.order2 -@pytest.mark.skip(reason="skipping this test for now") async def test_check_num_callbacks_on_lowest_latency(): """ Test 1: num callbacks should NOT increase over time
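Notes on the series: taken together, patches 2 and 3 move every identifying field the cost hook needs onto the `standard_logging_object` that LiteLLM attaches to callback kwargs — key hash, user/team/org ids, key alias, and the end-user budget — with the budget key split into its own `total=False` TypedDict so it stays optional. The sketch below is a minimal, self-contained model of that read path, using plain dicts and stand-in TypedDict names (`BudgetMetadata`, `UserAPIKeyMetadata`, `Metadata`, `extract_cost_fields` are illustrative stand-ins, not the real proxy types or wiring):

from typing import Optional, TypedDict


class BudgetMetadata(TypedDict, total=False):
    # total=False: the budget key may be absent entirely.
    user_api_end_user_max_budget: Optional[float]


class UserAPIKeyMetadata(TypedDict):
    user_api_key_hash: Optional[str]
    user_api_key_user_id: Optional[str]
    user_api_key_team_id: Optional[str]
    user_api_key_org_id: Optional[str]
    user_api_key_alias: Optional[str]


class Metadata(UserAPIKeyMetadata, BudgetMetadata):
    # Multiple TypedDict inheritance merges the key sets, mirroring
    # StandardLoggingMetadata(StandardLoggingUserAPIKeyMetadata,
    # StandardLoggingBudgetMetadata) in patch 3.
    pass


def extract_cost_fields(payload: dict) -> dict:
    # Mirrors the .get() chains in _PROXY_track_cost_callback: every field
    # is optional, so a missing key degrades to None instead of raising.
    metadata = payload.get("metadata", {}) or {}
    return {
        "end_user_id": payload.get("end_user"),
        "user_api_key": metadata.get("user_api_key_hash"),
        "user_id": metadata.get("user_api_key_user_id"),
        "team_id": metadata.get("user_api_key_team_id"),
        "org_id": metadata.get("user_api_key_org_id"),
        "key_alias": metadata.get("user_api_key_alias"),
        "end_user_max_budget": metadata.get("user_api_end_user_max_budget"),
        "response_cost": payload.get("response_cost"),
    }


fields = extract_cost_fields(
    {
        "response_cost": 0.00002,
        "end_user": "end-user-1",
        "metadata": {"user_api_key_hash": "hash-abc"},
    }
)
assert fields["response_cost"] == 0.00002
assert fields["team_id"] is None  # absent keys come back as None

This is also why the tests in patch 4 build a complete payload once (`create_simple_standard_logging_payload`) and then override individual keys per test: with every field read via `.get()`, a partially populated payload exercises the same code paths as a full one.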
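Patch 1's callback also fixes a control-flow contract worth noting: the database write is awaited, so any failure lands in the `except` block and triggers `failed_tracking_alert`, while the cache refresh and alerting are scheduled with `asyncio.create_task` so they can never block or fail the request path. Below is a stripped-down sketch of that pattern under stated assumptions — `update_database`, `update_cache`, and `failed_tracking_alert` here are placeholder coroutines, not the proxy's real functions:

import asyncio
import traceback


async def update_database(**fields) -> None:
    # Stand-in for the real awaited DB write.
    print("db write:", fields)


async def update_cache(**fields) -> None:
    # Stand-in for the real cache refresh.
    print("cache update:", fields)


async def failed_tracking_alert(error_message: str) -> None:
    # Stand-in for the real Slack alert.
    print("ALERT:", error_message.splitlines()[0])


async def track_cost(kwargs: dict) -> None:
    try:
        response_cost = kwargs.get("response_cost")
        if response_cost is None:
            raise ValueError("Failed to write cost to DB: no response_cost")
        # Awaited: a failed DB write must reach the except block below.
        await update_database(response_cost=response_cost)
        # Fire-and-forget: the cache refresh never blocks the request path.
        asyncio.create_task(update_cache(response_cost=response_cost))
    except Exception as e:
        error_msg = (
            f"Error in tracking cost callback - {e}\n{traceback.format_exc()}"
        )
        # Alerting is also fire-and-forget; the callback itself never re-raises.
        asyncio.create_task(failed_tracking_alert(error_message=error_msg))


async def main() -> None:
    await track_cost({"response_cost": 0.00002})  # happy path
    await track_cost({})                          # alert path
    await asyncio.sleep(0)                        # let scheduled tasks run


asyncio.run(main())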