From e917d0eee69ba55a98e87ce8af1025a47aecda41 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Mon, 22 Jan 2024 15:53:04 -0800
Subject: [PATCH] feat(utils.py): emit response cost as part of logs

---
 litellm/proxy/proxy_server.py               | 34 ++-------------------
 litellm/tests/test_custom_callback_input.py |  8 +++--
 litellm/utils.py                            |  7 +++++
 3 files changed, 15 insertions(+), 34 deletions(-)

diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 9ef9a81581..8ae3ee7d36 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -562,13 +562,8 @@ async def track_cost_callback(
         litellm_params = kwargs.get("litellm_params", {}) or {}
         proxy_server_request = litellm_params.get("proxy_server_request") or {}
         user_id = proxy_server_request.get("body", {}).get("user", None)
-        if "complete_streaming_response" in kwargs:
-            # for tracking streaming cost we pass the "messages" and the output_text to litellm.completion_cost
-            completion_response = kwargs["complete_streaming_response"]
-            response_cost = litellm.completion_cost(
-                completion_response=completion_response
-            )
-
+        if "response_cost" in kwargs:
+            response_cost = kwargs["response_cost"]
             user_api_key = kwargs["litellm_params"]["metadata"].get(
                 "user_api_key", None
             )
@@ -577,31 +572,6 @@ async def track_cost_callback(
                 "user_api_key_user_id", None
             )
 
-            verbose_proxy_logger.info(
-                f"streaming response_cost {response_cost}, for user_id {user_id}"
-            )
-            if user_api_key and (
-                prisma_client is not None or custom_db_client is not None
-            ):
-                await update_database(
-                    token=user_api_key,
-                    response_cost=response_cost,
-                    user_id=user_id,
-                    kwargs=kwargs,
-                    completion_response=completion_response,
-                    start_time=start_time,
-                    end_time=end_time,
-                )
-        elif kwargs["stream"] == False:  # for non streaming responses
-            response_cost = litellm.completion_cost(
-                completion_response=completion_response
-            )
-            user_api_key = kwargs["litellm_params"]["metadata"].get(
-                "user_api_key", None
-            )
-            user_id = user_id or kwargs["litellm_params"]["metadata"].get(
-                "user_api_key_user_id", None
-            )
             verbose_proxy_logger.info(
                 f"response_cost {response_cost}, for user_id {user_id}"
             )
diff --git a/litellm/tests/test_custom_callback_input.py b/litellm/tests/test_custom_callback_input.py
index 0fb69b6451..e4bd9e2c1e 100644
--- a/litellm/tests/test_custom_callback_input.py
+++ b/litellm/tests/test_custom_callback_input.py
@@ -170,6 +170,7 @@ class CompletionCustomHandler(
             )
             assert isinstance(kwargs["additional_args"], (dict, type(None)))
             assert isinstance(kwargs["log_event_type"], str)
+            assert isinstance(kwargs["response_cost"], (float, type(None)))
         except:
             print(f"Assertion Error: {traceback.format_exc()}")
             self.errors.append(traceback.format_exc())
@@ -262,6 +263,7 @@ class CompletionCustomHandler(
             assert isinstance(kwargs["additional_args"], (dict, type(None)))
             assert isinstance(kwargs["log_event_type"], str)
             assert kwargs["cache_hit"] is None or isinstance(kwargs["cache_hit"], bool)
+            assert isinstance(kwargs["response_cost"], (float, type(None)))
         except:
             print(f"Assertion Error: {traceback.format_exc()}")
             self.errors.append(traceback.format_exc())
@@ -545,8 +547,9 @@ async def test_async_chat_bedrock_stream():
 
 
 # asyncio.run(test_async_chat_bedrock_stream())
-# Text Completion
-
+# Text Completion
+
+
 ## Test OpenAI text completion + Async
 @pytest.mark.asyncio
 async def test_async_text_completion_openai_stream():
@@ -585,6 +588,7 @@ async def test_async_text_completion_openai_stream():
     except Exception as e:
         pytest.fail(f"An exception occurred: {str(e)}")
 
+
 # EMBEDDING
 ## Test OpenAI + Async
 @pytest.mark.asyncio
diff --git a/litellm/utils.py b/litellm/utils.py
index 468e671363..28b514d402 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -1064,6 +1064,13 @@ class Logging:
             self.model_call_details["log_event_type"] = "successful_api_call"
             self.model_call_details["end_time"] = end_time
             self.model_call_details["cache_hit"] = cache_hit
+            if result is not None and (
+                isinstance(result, ModelResponse)
+                or isinstance(result, EmbeddingResponse)
+            ):
+                self.model_call_details["response_cost"] = litellm.completion_cost(
+                    completion_response=result,
+                )
 
             if litellm.max_budget and self.stream:
                 time_diff = (end_time - start_time).total_seconds()
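
Note (not part of the patch): a minimal sketch of how a custom success callback
could consume the new "response_cost" key this change emits. The callback name
and the printed format below are hypothetical; only kwargs["response_cost"] and
the litellm.success_callback hook come from litellm itself.

import litellm


def log_response_cost(kwargs, completion_response, start_time, end_time):
    # "response_cost" is populated by the Logging success handler for
    # ModelResponse / EmbeddingResponse results, so it covers both
    # streaming and non-streaming calls; it may be absent or None.
    cost = kwargs.get("response_cost")
    if isinstance(cost, float):
        print(f"response_cost: {cost}, model: {kwargs.get('model')}")


litellm.success_callback = [log_response_cost]

response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hi!"}],
)

This is also why track_cost_callback in proxy_server.py can shrink to a single
"response_cost" branch: the cost is computed once in utils.py and handed to
every callback, instead of being recomputed per streaming/non-streaming path.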