fix(utils.py): fix streaming cost tracking

This commit is contained in:
Krrish Dholakia 2024-01-23 14:39:35 -08:00
parent 7eb96e46a4
commit afada01ffc
4 changed files with 39 additions and 13 deletions

View file

@@ -572,7 +572,7 @@ async def track_cost_callback(
         litellm_params = kwargs.get("litellm_params", {}) or {}
         proxy_server_request = litellm_params.get("proxy_server_request") or {}
         user_id = proxy_server_request.get("body", {}).get("user", None)
-        if "response_cost" in kwargs:
+        if kwargs.get("response_cost", None) is not None:
             response_cost = kwargs["response_cost"]
             user_api_key = kwargs["litellm_params"]["metadata"].get(
                 "user_api_key", None
@@ -598,9 +598,13 @@ async def track_cost_callback(
                     end_time=end_time,
                 )
             else:
-                raise Exception(
-                    f"Model not in litellm model cost map. Add custom pricing - https://docs.litellm.ai/docs/proxy/custom_pricing"
-                )
+                if (
+                    kwargs["stream"] != True
+                    or kwargs.get("complete_streaming_response", None) is not None
+                ):
+                    raise Exception(
+                        f"Model not in litellm model cost map. Add custom pricing - https://docs.litellm.ai/docs/proxy/custom_pricing"
+                    )

     except Exception as e:
         verbose_proxy_logger.debug(f"error in tracking cost callback - {str(e)}")
@@ -1514,7 +1518,7 @@ async def startup_event():
             duration=None, models=[], aliases={}, config={}, spend=0, token=master_key
         )
         verbose_proxy_logger.debug(
-            f"custom_db_client client - Inserting master key {custom_db_client}. Master_key: {master_key}"
+            f"custom_db_client client {custom_db_client}. Master_key: {master_key}"
         )

     if custom_db_client is not None and master_key is not None:
         # add master key to db

View file

@@ -961,9 +961,9 @@ def _duration_in_seconds(duration: str):

 async def reset_budget(prisma_client: PrismaClient):
     """
-    Gets all the non-expired keys for a db, which need budget to be reset
+    Gets all the non-expired keys for a db, which need spend to be reset

-    Resets their budget
+    Resets their spend

     Updates db
     """

View file

@@ -1067,10 +1067,15 @@ class Logging:
             ## if model in model cost map - log the response cost
             ## else set cost to None
             verbose_logger.debug(f"Model={self.model}; result={result}")
-            if result is not None and (
-                isinstance(result, ModelResponse)
-                or isinstance(result, EmbeddingResponse)
-            ):
+            verbose_logger.debug(f"self.stream: {self.stream}")
+            if (
+                result is not None
+                and (
+                    isinstance(result, ModelResponse)
+                    or isinstance(result, EmbeddingResponse)
+                )
+                and self.stream != True
+            ):  # handle streaming separately
                 try:
                     self.model_call_details["response_cost"] = litellm.completion_cost(
                         completion_response=result,
@@ -1125,7 +1130,7 @@ class Logging:
                 else:
                     self.sync_streaming_chunks.append(result)

-            if complete_streaming_response:
+            if complete_streaming_response is not None:
                 verbose_logger.debug(
                     f"Logging Details LiteLLM-Success Call streaming complete"
                 )
@@ -1418,11 +1423,23 @@ class Logging:
                     complete_streaming_response = None
                 else:
                     self.streaming_chunks.append(result)

-            if complete_streaming_response:
+            if complete_streaming_response is not None:
                 print_verbose("Async success callbacks: Got a complete streaming response")
                 self.model_call_details[
                     "complete_streaming_response"
                 ] = complete_streaming_response
+                try:
+                    self.model_call_details["response_cost"] = litellm.completion_cost(
+                        completion_response=complete_streaming_response,
+                    )
+                    verbose_logger.debug(
+                        f"Model={self.model}; cost={self.model_call_details['response_cost']}"
+                    )
+                except litellm.NotFoundError as e:
+                    verbose_logger.debug(
+                        f"Model={self.model} not found in completion cost map."
+                    )
+                    self.model_call_details["response_cost"] = None

         for callback in litellm._async_success_callback:
             try:
@@ -2867,6 +2884,9 @@ def cost_per_token(
     if model in model_cost_ref:
         verbose_logger.debug(f"Success: model={model} in model_cost_map")
+        verbose_logger.debug(
+            f"prompt_tokens={prompt_tokens}; completion_tokens={completion_tokens}"
+        )
         if (
             model_cost_ref[model].get("input_cost_per_token", None) is not None
             and model_cost_ref[model].get("output_cost_per_token", None) is not None

View file

@@ -25,6 +25,7 @@ backoff = {version = "*", optional = true}
 pyyaml = {version = "^6.0.1", optional = true}
 rq = {version = "*", optional = true}
 orjson = {version = "^3.9.7", optional = true}
+apscheduler = {version = "^3.10.4", optional = true}
 streamlit = {version = "^1.29.0", optional = true}

 [tool.poetry.extras]
@@ -36,6 +37,7 @@ proxy = [
     "pyyaml",
     "rq",
     "orjson",
+    "apscheduler"
 ]

 extra_proxy = [