forked from phoenix/litellm-mirror

Cost tracking improvements (#5828)

* feat(litellm_logging.py): update standard logging payload to include debug information for cost failures. Also includes fixes for cohere rerank cost tracking + databricks llama2 model cost tracking. Makes it easier to repro cost failures and improves reliability in prod.
* fix(proxy_server.py): emit cost failure debug info for slack alerting. Improves debug information for cost tracking failures on slack alerting.

parent 8039b95aaf
commit 2488e4b45f

6 changed files with 117 additions and 45 deletions
```diff
@@ -250,6 +250,13 @@ def cost_per_token(
         )
         )
         return prompt_cost, completion_cost
+    elif call_type == "arerank" or call_type == "rerank":
+        completion_tokens_cost_usd_dollar = rerank_cost(
+            model=model,
+            custom_llm_provider=custom_llm_provider,
+        )
+        prompt_tokens_cost_usd_dollar = 0
+        return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
     elif model in model_cost_ref:
         print_verbose(f"Success: model={model} in model_cost_map")
         print_verbose(
```
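The new branch above short-circuits token accounting for rerank calls: the prompt side is pinned to 0 and the entire request cost rides on the completion side. A minimal standalone sketch of that contract (the function name and per-unit price below are illustrative, not litellm's actual API or pricing):

```python
from typing import Tuple

# Illustrative per-search-unit price; rerank providers such as cohere
# bill per search unit / request rather than per token.
PRICE_PER_SEARCH_UNIT_USD = 0.002


def rerank_cost_sketch(search_units: int = 1) -> Tuple[float, float]:
    """Return (prompt_cost, completion_cost) for a rerank call.

    Mirrors the diff above: the prompt slot is always 0.0 and the
    request-level cost is carried in the completion slot.
    """
    prompt_tokens_cost_usd_dollar = 0.0
    completion_tokens_cost_usd_dollar = search_units * PRICE_PER_SEARCH_UNIT_USD
    return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar


print(rerank_cost_sketch())  # (0.0, 0.002) -- one search unit by default
```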
```diff
@@ -689,7 +696,18 @@ def completion_cost(
         call_type == CallTypes.speech.value or call_type == CallTypes.aspeech.value
     ):
         prompt_characters = litellm.utils._count_characters(text=prompt)
+    elif (
+        call_type == CallTypes.rerank.value or call_type == CallTypes.arerank.value
+    ):
+        if completion_response is not None and isinstance(
+            completion_response, RerankResponse
+        ):
+            meta_obj = completion_response.meta
+            billed_units = meta_obj.get("billed_units", {}) or {}
+            search_units = (
+                billed_units.get("search_units") or 1
+            )  # cohere charges per request by default.
+            completion_tokens = search_units
     # Calculate cost based on prompt_tokens, completion_tokens
     if (
         "togethercomputer" in model
```
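Cohere reports rerank billing under `meta.billed_units.search_units`, defaulting to one unit per request when the field is absent or `None`. The same fallback chain, isolated on a plain dict (the response shape is taken from the diff, not from a live `RerankResponse`):

```python
# billed_units may be missing entirely, or present with no search_units.
for meta_obj in ({}, {"billed_units": {}}, {"billed_units": {"search_units": 3}}):
    billed_units = meta_obj.get("billed_units", {}) or {}
    search_units = billed_units.get("search_units") or 1  # per-request default
    completion_tokens = search_units
    print(completion_tokens)  # 1, 1, 3
```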
```diff
@@ -794,7 +812,7 @@ def response_cost_calculator(
 ) -> Optional[float]:
     """
     Returns
-    - float or None: cost of response OR none if error.
+    - float or None: cost of response
     """
     try:
         response_cost: float = 0.0
@@ -810,15 +828,6 @@ def response_cost_calculator(
                 call_type=call_type,
                 custom_llm_provider=custom_llm_provider,
             )
-        elif isinstance(response_object, RerankResponse) and (
-            call_type == "arerank" or call_type == "rerank"
-        ):
-            response_cost = rerank_cost(
-                rerank_response=response_object,
-                model=model,
-                call_type=call_type,
-                custom_llm_provider=custom_llm_provider,
-            )
         else:
             if custom_pricing is True:  # override defaults if custom pricing is set
                 base_model = model
@@ -831,24 +840,12 @@ def response_cost_calculator(
                 custom_llm_provider=custom_llm_provider,
             )
         return response_cost
-    except litellm.NotFoundError as e:
-        verbose_logger.debug(  # debug since it can be spammy in logs, for calls
-            f"Model={model} for LLM Provider={custom_llm_provider} not found in completion cost map."
-        )
-        return None
     except Exception as e:
-        verbose_logger.debug(
-            "litellm.cost_calculator.py::response_cost_calculator - Returning None. Exception occurred - {}/n{}".format(
-                str(e), traceback.format_exc()
-            )
-        )
-        return None
+        raise e


 def rerank_cost(
-    rerank_response: RerankResponse,
     model: str,
-    call_type: Literal["rerank", "arerank"],
     custom_llm_provider: Optional[str],
 ) -> float:
     """
```
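Behavioral note on the last hunk: `response_cost_calculator` used to swallow every exception (including `litellm.NotFoundError`) and return `None`; it now re-raises, so the caller decides how to log and recover, which is exactly what the `Logging` change further down does. A small sketch of the new division of responsibility (names are stand-ins, not litellm's API):

```python
from typing import Optional


def cost_calculator_sketch(model: str) -> float:
    # Stand-in for the calculator: raises instead of returning None
    # when the model is missing from the (illustrative) cost map.
    cost_map = {"known-model": 0.01}
    if model not in cost_map:
        raise LookupError(f"Model={model} not found in cost map")
    return cost_map[model]


def caller(model: str) -> Optional[float]:
    try:
        return cost_calculator_sketch(model)
    except Exception:
        return None  # caller-side policy: record debug info, alert, etc.


print(caller("known-model"))    # 0.01
print(caller("unknown-model"))  # None
```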
```diff
@@ -41,6 +41,7 @@ from litellm.types.utils import (
     ModelResponse,
     StandardLoggingHiddenParams,
     StandardLoggingMetadata,
+    StandardLoggingModelCostFailureDebugInformation,
     StandardLoggingModelInformation,
     StandardLoggingPayload,
     StandardLoggingPayloadStatus,
```
```diff
@@ -574,7 +575,7 @@ class Logging:
             RerankResponse,
         ],
         cache_hit: Optional[bool] = None,
-    ):
+    ) -> Optional[float]:
         """
         Calculate response cost using result + logging object variables.

```
```diff
@@ -590,22 +591,53 @@ class Logging:
         if cache_hit is None:
             cache_hit = self.model_call_details.get("cache_hit", False)

-        response_cost = litellm.response_cost_calculator(
-            response_object=result,
-            model=self.model,
-            cache_hit=cache_hit,
-            custom_llm_provider=self.model_call_details.get(
-                "custom_llm_provider", None
-            ),
-            base_model=_get_base_model_from_metadata(
-                model_call_details=self.model_call_details
-            ),
-            call_type=self.call_type,
-            optional_params=self.optional_params,
-            custom_pricing=custom_pricing,
-        )
-
-        return response_cost
+        try:
+            response_cost_calculator_kwargs = {
+                "response_object": result,
+                "model": self.model,
+                "cache_hit": cache_hit,
+                "custom_llm_provider": self.model_call_details.get(
+                    "custom_llm_provider", None
+                ),
+                "base_model": _get_base_model_from_metadata(
+                    model_call_details=self.model_call_details
+                ),
+                "call_type": self.call_type,
+                "optional_params": self.optional_params,
+                "custom_pricing": custom_pricing,
+            }
+        except Exception as e:  # error creating kwargs for cost calculation
+            self.model_call_details["response_cost_failure_debug_information"] = (
+                StandardLoggingModelCostFailureDebugInformation(
+                    error_str=str(e),
+                    traceback_str=traceback.format_exc(),
+                )
+            )
+            return None
+
+        try:
+            response_cost = litellm.response_cost_calculator(
+                **response_cost_calculator_kwargs
+            )
+
+            return response_cost
+        except Exception as e:  # error calculating cost
+            self.model_call_details["response_cost_failure_debug_information"] = (
+                StandardLoggingModelCostFailureDebugInformation(
+                    error_str=str(e),
+                    traceback_str=traceback.format_exc(),
+                    model=response_cost_calculator_kwargs["model"],
+                    cache_hit=response_cost_calculator_kwargs["cache_hit"],
+                    custom_llm_provider=response_cost_calculator_kwargs[
+                        "custom_llm_provider"
+                    ],
+                    base_model=response_cost_calculator_kwargs["base_model"],
+                    call_type=response_cost_calculator_kwargs["call_type"],
+                    custom_pricing=response_cost_calculator_kwargs["custom_pricing"],
+                )
+            )
+
+            return None

     def _success_handler_helper_fn(
         self, result=None, start_time=None, end_time=None, cache_hit=None
```
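The cost path is now two guarded phases: assembling the calculator kwargs (which can fail on its own, e.g. while resolving the base model from metadata) and the calculation itself. Each phase stashes a `StandardLoggingModelCostFailureDebugInformation` record on `model_call_details` instead of letting the exception escape; the second phase can additionally include the inputs it was called with. A compressed sketch of the pattern (debug record shown as a plain dict, names shortened):

```python
import traceback
from typing import Any, Callable, Dict, Optional


def response_cost_sketch(
    details: Dict[str, Any], calculate: Callable[..., float]
) -> Optional[float]:
    try:  # phase 1: assemble inputs -- may fail before any math happens
        kwargs = {"model": details["model"], "cache_hit": details.get("cache_hit", False)}
    except Exception as e:
        details["response_cost_failure_debug_information"] = {
            "error_str": str(e),
            "traceback_str": traceback.format_exc(),
        }
        return None

    try:  # phase 2: the calculation proper
        return calculate(**kwargs)
    except Exception as e:
        details["response_cost_failure_debug_information"] = {
            "error_str": str(e),
            "traceback_str": traceback.format_exc(),
            "model": kwargs["model"],          # phase 2 knows its inputs,
            "cache_hit": kwargs["cache_hit"],  # so the record is richer
        }
        return None


details: Dict[str, Any] = {"model": "my-model"}
print(response_cost_sketch(details, lambda **kw: 1 / 0))     # None
print("response_cost_failure_debug_information" in details)  # True
```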
```diff
@@ -2501,12 +2533,16 @@ def get_standard_logging_object_payload(
             )
         except Exception:
             verbose_logger.debug(  # keep in debug otherwise it will trigger on every call
-                "Model is not mapped in model cost map. Defaulting to None model_cost_information for standard_logging_payload"
+                "Model={} is not mapped in model cost map. Defaulting to None model_cost_information for standard_logging_payload".format(
+                    model_cost_name
+                )
             )
             model_cost_information = StandardLoggingModelInformation(
                 model_map_key=model_cost_name, model_map_value=None
             )

+        response_cost: float = kwargs.get("response_cost", 0) or 0.0
+
         payload: StandardLoggingPayload = StandardLoggingPayload(
             id=str(id),
             call_type=call_type or "",
```
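The `or 0.0` matters here: with the calculator now returning `None` on failure, `kwargs.get("response_cost", 0) or 0.0` collapses a missing key, an explicit `None`, and a bare `0` all to the float `0.0`, keeping `StandardLoggingPayload.response_cost` a real float:

```python
for kwargs in ({}, {"response_cost": None}, {"response_cost": 0.25}):
    response_cost: float = kwargs.get("response_cost", 0) or 0.0
    print(response_cost)  # 0.0, 0.0, 0.25
```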
```diff
@@ -2519,7 +2555,7 @@ def get_standard_logging_object_payload(
             model=kwargs.get("model", "") or "",
             metadata=clean_metadata,
             cache_key=cache_key,
-            response_cost=kwargs.get("response_cost", 0),
+            response_cost=response_cost,
             total_tokens=usage.get("total_tokens", 0),
             prompt_tokens=usage.get("prompt_tokens", 0),
             completion_tokens=usage.get("completion_tokens", 0),
```
```diff
@@ -2537,6 +2573,9 @@ def get_standard_logging_object_payload(
             hidden_params=clean_hidden_params,
             model_map_information=model_cost_information,
             error_str=error_str,
+            response_cost_failure_debug_info=kwargs.get(
+                "response_cost_failure_debug_information"
+            ),
         )

         verbose_logger.debug(
```
```diff
@@ -49,6 +49,10 @@ def cost_per_token(model: str, usage: Usage) -> Tuple[float, float]:
         "gte-large-en"
     ):
         base_model = "databricks-gte-large-en"
+    elif model.startswith("databricks/llama-2-70b-chat") or model.startswith(
+        "llama-2-70b-chat"
+    ):
+        base_model = "databricks-llama-2-70b-chat"
     ## GET MODEL INFO
     model_info = get_model_info(model=base_model, custom_llm_provider="databricks")

```
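Databricks model names can arrive with or without the provider prefix, so the new branch matches both spellings before the pricing lookup. The same normalization in isolation (the gte condition's opening line sits above the hunk, so it is reconstructed here on the pattern of the llama-2 branch):

```python
def resolve_databricks_base_model(model: str) -> str:
    # Mirrors the startswith pairs in the diff above.
    if model.startswith("databricks/gte-large-en") or model.startswith("gte-large-en"):
        return "databricks-gte-large-en"
    if model.startswith("databricks/llama-2-70b-chat") or model.startswith(
        "llama-2-70b-chat"
    ):
        return "databricks-llama-2-70b-chat"
    return model


print(resolve_databricks_base_model("llama-2-70b-chat"))  # databricks-llama-2-70b-chat
```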
```diff
@@ -23,6 +23,14 @@ model_list:
     litellm_params:
       model: cohere/rerank-english-v3.0
       api_key: os.environ/COHERE_API_KEY
+  - model_name: "databricks/*"
+    litellm_params:
+      model: "databricks/*"
+      api_key: os.environ/DATABRICKS_API_KEY
+      api_base: os.environ/DATABRICKS_API_BASE
+  - model_name: "anthropic/*"
+    litellm_params:
+      model: "anthropic/*"

 litellm_settings:
```
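With the wildcard entries, any model name matching `databricks/*` or `anthropic/*` routes through the corresponding block, so newly mapped models can be exercised without further config changes. A client-side sketch against a locally running proxy (base URL and key are placeholders for your deployment):

```python
import openai

# Placeholder endpoint/key for a local litellm proxy.
client = openai.OpenAI(base_url="http://0.0.0.0:4000", api_key="sk-1234")

# "databricks/llama-2-70b-chat" matches the "databricks/*" wildcard entry.
response = client.chat.completions.create(
    model="databricks/llama-2-70b-chat",
    messages=[{"role": "user", "content": "hello"}],
)
print(response.choices[0].message.content)
```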
```diff
@@ -824,11 +824,15 @@ async def _PROXY_track_cost_callback(
                     "User API key and team id and user id missing from custom callback."
                 )
         else:
-            if kwargs["stream"] != True or (
-                kwargs["stream"] == True and "complete_streaming_response" in kwargs
+            if kwargs["stream"] is not True or (
+                kwargs["stream"] is True and "complete_streaming_response" in kwargs
             ):
+                cost_tracking_failure_debug_info = kwargs.get(
+                    "response_cost_failure_debug_information"
+                )
+                model = kwargs.get("model")
                 raise Exception(
-                    f"Model not in litellm model cost map. Passed model = {kwargs.get('model')} - Add custom pricing - https://docs.litellm.ai/docs/proxy/custom_pricing"
+                    f"Cost tracking failed for model={model}.\nDebug info - {cost_tracking_failure_debug_info}\nAdd custom pricing - https://docs.litellm.ai/docs/proxy/custom_pricing"
                 )
     except Exception as e:
         error_msg = f"error in tracking cost callback - {traceback.format_exc()}"
```
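Two things change here: the boolean checks move from `==`/`!=` to identity (`is True` / `is not True`), which is the lint-clean form (flake8 E712) and does not treat truthy non-booleans as `True`; and the raised message now carries the model plus the debug record gathered by the logging layer. The comparison difference in two lines:

```python
stream = 1  # truthy, but not the boolean True
print(stream == True)  # True  -- equality coerces (bool is a subclass of int)
print(stream is True)  # False -- identity is strict about the singleton True
```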
```diff
@@ -1281,6 +1281,23 @@ class StandardLoggingModelInformation(TypedDict):
     model_map_value: Optional[ModelInfo]


+class StandardLoggingModelCostFailureDebugInformation(TypedDict, total=False):
+    """
+    Debug information, if cost tracking fails.
+
+    Avoid logging sensitive information like response or optional params
+    """
+
+    error_str: Required[str]
+    traceback_str: Required[str]
+    model: str
+    cache_hit: Optional[bool]
+    custom_llm_provider: Optional[str]
+    base_model: Optional[str]
+    call_type: str
+    custom_pricing: Optional[bool]
+
+
 StandardLoggingPayloadStatus = Literal["success", "failure"]

```
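Because the TypedDict is `total=False` with only `error_str` and `traceback_str` marked `Required`, the early kwargs-failure path can build it from just those two keys while the calculation-failure path fills in the rest. A minimal check of that shape (fields trimmed to two optional ones for brevity):

```python
from typing import Optional
from typing_extensions import Required, TypedDict


class StandardLoggingModelCostFailureDebugInformation(TypedDict, total=False):
    error_str: Required[str]
    traceback_str: Required[str]
    model: str
    cache_hit: Optional[bool]


# Valid: only the Required keys are present (the phase-1 failure shape).
minimal = StandardLoggingModelCostFailureDebugInformation(
    error_str="boom", traceback_str="Traceback (most recent call last): ..."
)
print(minimal["error_str"])  # boom
```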
```diff
@@ -1288,6 +1305,9 @@ class StandardLoggingPayload(TypedDict):
     id: str
     call_type: str
     response_cost: float
+    response_cost_failure_debug_info: Optional[
+        StandardLoggingModelCostFailureDebugInformation
+    ]
     status: StandardLoggingPayloadStatus
     total_tokens: int
     prompt_tokens: int
```