From afada01ffc40680e606ab5e18d5177430cb066a1 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Tue, 23 Jan 2024 14:39:35 -0800
Subject: [PATCH] fix(utils.py): fix streaming cost tracking

---
 litellm/proxy/proxy_server.py | 14 +++++++++-----
 litellm/proxy/utils.py        |  4 ++--
 litellm/utils.py              | 32 ++++++++++++++++++++++++++------
 pyproject.toml                |  2 ++
 4 files changed, 39 insertions(+), 13 deletions(-)

diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 398905e1a..df8b63b64 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -572,7 +572,7 @@ async def track_cost_callback(
         litellm_params = kwargs.get("litellm_params", {}) or {}
         proxy_server_request = litellm_params.get("proxy_server_request") or {}
         user_id = proxy_server_request.get("body", {}).get("user", None)
-        if "response_cost" in kwargs:
+        if kwargs.get("response_cost", None) is not None:
             response_cost = kwargs["response_cost"]
             user_api_key = kwargs["litellm_params"]["metadata"].get(
                 "user_api_key", None
@@ -598,9 +598,13 @@ async def track_cost_callback(
                 end_time=end_time,
             )
         else:
-            raise Exception(
-                f"Model not in litellm model cost map. Add custom pricing - https://docs.litellm.ai/docs/proxy/custom_pricing"
-            )
+            if (
+                kwargs["stream"] != True
+                or kwargs.get("complete_streaming_response", None) is not None
+            ):
+                raise Exception(
+                    f"Model not in litellm model cost map. Add custom pricing - https://docs.litellm.ai/docs/proxy/custom_pricing"
+                )
     except Exception as e:
         verbose_proxy_logger.debug(f"error in tracking cost callback - {str(e)}")

@@ -1514,7 +1518,7 @@ async def startup_event():
             duration=None, models=[], aliases={}, config={}, spend=0, token=master_key
         )
         verbose_proxy_logger.debug(
-            f"custom_db_client client - Inserting master key {custom_db_client}. Master_key: {master_key}"
+            f"custom_db_client client {custom_db_client}. Master_key: {master_key}"
         )
     if custom_db_client is not None and master_key is not None:
         # add master key to db
diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py
index 109141079..c06bed7fa 100644
--- a/litellm/proxy/utils.py
+++ b/litellm/proxy/utils.py
@@ -961,9 +961,9 @@ def _duration_in_seconds(duration: str):

 async def reset_budget(prisma_client: PrismaClient):
     """
-    Gets all the non-expired keys for a db, which need budget to be reset
+    Gets all the non-expired keys for a db, which need spend to be reset

-    Resets their budget
+    Resets their spend

     Updates db
     """
diff --git a/litellm/utils.py b/litellm/utils.py
index 00b76bfb5..7f8a447ad 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -1067,10 +1067,15 @@
             ## if model in model cost map - log the response cost
             ## else set cost to None
             verbose_logger.debug(f"Model={self.model}; result={result}")
-            if result is not None and (
-                isinstance(result, ModelResponse)
-                or isinstance(result, EmbeddingResponse)
-            ):
+            verbose_logger.debug(f"self.stream: {self.stream}")
+            if (
+                result is not None
+                and (
+                    isinstance(result, ModelResponse)
+                    or isinstance(result, EmbeddingResponse)
+                )
+                and self.stream != True
+            ):  # handle streaming separately
                 try:
                     self.model_call_details["response_cost"] = litellm.completion_cost(
                         completion_response=result,
@@ -1125,7 +1130,7 @@
             else:
                 self.sync_streaming_chunks.append(result)

-            if complete_streaming_response:
+            if complete_streaming_response is not None:
                 verbose_logger.debug(
                     f"Logging Details LiteLLM-Success Call streaming complete"
                 )
@@ -1418,11 +1423,23 @@
                     complete_streaming_response = None
             else:
                 self.streaming_chunks.append(result)
-        if complete_streaming_response:
+        if complete_streaming_response is not None:
             print_verbose("Async success callbacks: Got a complete streaming response")
             self.model_call_details[
                 "complete_streaming_response"
             ] = complete_streaming_response
+            try:
+                self.model_call_details["response_cost"] = litellm.completion_cost(
+                    completion_response=complete_streaming_response,
+                )
+                verbose_logger.debug(
+                    f"Model={self.model}; cost={self.model_call_details['response_cost']}"
+                )
+            except litellm.NotFoundError as e:
+                verbose_logger.debug(
+                    f"Model={self.model} not found in completion cost map."
+                )
+                self.model_call_details["response_cost"] = None

         for callback in litellm._async_success_callback:
             try:
@@ -2867,6 +2884,9 @@

     if model in model_cost_ref:
         verbose_logger.debug(f"Success: model={model} in model_cost_map")
+        verbose_logger.debug(
+            f"prompt_tokens={prompt_tokens}; completion_tokens={completion_tokens}"
+        )
         if (
             model_cost_ref[model].get("input_cost_per_token", None) is not None
             and model_cost_ref[model].get("output_cost_per_token", None) is not None
diff --git a/pyproject.toml b/pyproject.toml
index 00d424e79..0a18f6af1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -25,6 +25,7 @@ backoff = {version = "*", optional = true}
 pyyaml = {version = "^6.0.1", optional = true}
 rq = {version = "*", optional = true}
 orjson = {version = "^3.9.7", optional = true}
+apscheduler = {version = "^3.10.4", optional = true}
 streamlit = {version = "^1.29.0", optional = true}

 [tool.poetry.extras]
@@ -36,6 +37,7 @@ proxy = [
     "pyyaml",
     "rq",
     "orjson",
+    "apscheduler"
 ]

 extra_proxy = [