Fix Batches API cost tracking + log batch models in spend logs / standard logging payload (#9077)

* feat(batches/): fix batch cost calculation - ensure it's accurate

use the correct cost value - previously this defaulted to the non-batch cost

* feat(batch_utils.py): log batch models to spend logs + standard logging payload

makes it easy to understand how the cost was calculated

* fix: fix stored payload for test

* test: fix test
Krish Dholakia 2025-03-08 11:47:25 -08:00 committed by GitHub
parent 8c049dfffc
commit 4330ef8e81
8 changed files with 110 additions and 7 deletions
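
For context, a rough end-to-end sketch of the flow this commit touches (the batch id and provider are illustrative; litellm.aretrieve_batch mirrors the OpenAI Batches API): once a retrieved batch reports status "completed", the logging path changed below reads the batch output file, computes cost with batch-specific pricing, and records the models seen in the output.

import asyncio
import litellm

async def main():
    # Illustrative id; assumes a batch was created earlier via litellm.acreate_batch.
    batch = await litellm.aretrieve_batch(
        batch_id="batch_abc123",
        custom_llm_provider="openai",
    )
    # When batch.status == "completed", the success-logging path in this commit
    # attaches the computed response_cost and the batch_models list to the
    # logging payload (and, via the proxy, to SpendLogs metadata).
    print(batch.status)

asyncio.run(main())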


@@ -4,13 +4,13 @@ from typing import Any, List, Literal, Tuple
 import litellm
 from litellm._logging import verbose_logger
 from litellm.types.llms.openai import Batch
-from litellm.types.utils import Usage
+from litellm.types.utils import CallTypes, Usage


 async def _handle_completed_batch(
     batch: Batch,
     custom_llm_provider: Literal["openai", "azure", "vertex_ai"],
-) -> Tuple[float, Usage]:
+) -> Tuple[float, Usage, List[str]]:
     """Helper function to process a completed batch and handle logging"""
     # Get batch results
     file_content_dictionary = await _get_batch_output_file_content_as_dictionary(
@@ -27,7 +27,25 @@ async def _handle_completed_batch
         custom_llm_provider=custom_llm_provider,
     )
-    return batch_cost, batch_usage
+    batch_models = _get_batch_models_from_file_content(file_content_dictionary)
+    return batch_cost, batch_usage, batch_models
+
+
+def _get_batch_models_from_file_content(
+    file_content_dictionary: List[dict],
+) -> List[str]:
+    """
+    Get the models from the file content
+    """
+    batch_models = []
+    for _item in file_content_dictionary:
+        if _batch_response_was_successful(_item):
+            _response_body = _get_response_from_batch_job_output_file(_item)
+            _model = _response_body.get("model")
+            if _model:
+                batch_models.append(_model)
+    return batch_models


 async def _batch_cost_calculator(
@@ -105,6 +123,7 @@ def _get_batch_job_cost_from_file_content(
             total_cost += litellm.completion_cost(
                 completion_response=_response_body,
                 custom_llm_provider=custom_llm_provider,
+                call_type=CallTypes.aretrieve_batch.value,
             )
     verbose_logger.debug("total_cost=%s", total_cost)
     return total_cost
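
The call_type hint added above is what routes per-request cost through the batch pricing path in cost_per_token (next file). A minimal sketch of the same call made directly, with a hand-written response body (model and token counts are illustrative, and the dict is reduced to the fields that matter here):

import litellm
from litellm.types.utils import CallTypes

# One parsed line from a batch output file, trimmed down for illustration.
_response_body = {
    "object": "chat.completion",
    "model": "gpt-4o-mini",
    "usage": {"prompt_tokens": 100, "completion_tokens": 50, "total_tokens": 150},
}

line_cost = litellm.completion_cost(
    completion_response=_response_body,
    custom_llm_provider="openai",
    call_type=CallTypes.aretrieve_batch.value,  # selects batch pricing
)
print(line_cost)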


@@ -239,6 +239,15 @@ def cost_per_token( # noqa: PLR0915
             custom_llm_provider=custom_llm_provider,
             billed_units=rerank_billed_units,
         )
+    elif (
+        call_type == "aretrieve_batch"
+        or call_type == "retrieve_batch"
+        or call_type == CallTypes.aretrieve_batch
+        or call_type == CallTypes.retrieve_batch
+    ):
+        return batch_cost_calculator(
+            usage=usage_block, model=model, custom_llm_provider=custom_llm_provider
+        )
     elif call_type == "atranscription" or call_type == "transcription":
         return openai_cost_per_second(
             model=model,
@@ -960,3 +969,54 @@ def default_image_cost_calculator(
     )
     return cost_info["input_cost_per_pixel"] * height * width * n
+
+
+def batch_cost_calculator(
+    usage: Usage,
+    model: str,
+    custom_llm_provider: Optional[str] = None,
+) -> Tuple[float, float]:
+    """
+    Calculate the cost of a batch job
+    """
+    _, custom_llm_provider, _, _ = litellm.get_llm_provider(
+        model=model, custom_llm_provider=custom_llm_provider
+    )
+
+    verbose_logger.info(
+        "Calculating batch cost per token. model=%s, custom_llm_provider=%s",
+        model,
+        custom_llm_provider,
+    )
+
+    try:
+        model_info: Optional[ModelInfo] = litellm.get_model_info(
+            model=model, custom_llm_provider=custom_llm_provider
+        )
+    except Exception:
+        model_info = None
+
+    if not model_info:
+        return 0.0, 0.0
+
+    input_cost_per_token_batches = model_info.get("input_cost_per_token_batches")
+    input_cost_per_token = model_info.get("input_cost_per_token")
+    output_cost_per_token_batches = model_info.get("output_cost_per_token_batches")
+    output_cost_per_token = model_info.get("output_cost_per_token")
+    total_prompt_cost = 0.0
+    total_completion_cost = 0.0
+    if input_cost_per_token_batches:
+        total_prompt_cost = usage.prompt_tokens * input_cost_per_token_batches
+    elif input_cost_per_token:
+        total_prompt_cost = (
+            usage.prompt_tokens * (input_cost_per_token) / 2
+        )  # batch cost is usually half of the regular token cost
+    if output_cost_per_token_batches:
+        total_completion_cost = usage.completion_tokens * output_cost_per_token_batches
+    elif output_cost_per_token:
+        total_completion_cost = (
+            usage.completion_tokens * (output_cost_per_token) / 2
+        )  # batch cost is usually half of the regular token cost
+
+    return total_prompt_cost, total_completion_cost
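
A quick worked example of the fallback above, assuming made-up per-token prices and no *_batches prices in the model info:

# Illustrative prices only (not taken from the model cost map).
input_cost_per_token = 2.5e-6
output_cost_per_token = 10e-6
prompt_tokens, completion_tokens = 1_000, 200

total_prompt_cost = prompt_tokens * input_cost_per_token / 2           # 0.00125
total_completion_cost = completion_tokens * output_cost_per_token / 2  # 0.001
# batch_cost_calculator would return (0.00125, 0.001) -> total batch cost 0.00225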


@@ -1613,11 +1613,12 @@ class Logging(LiteLLMLoggingBaseClass)
                 result, LiteLLMBatch
             ):
-                response_cost, batch_usage = await _handle_completed_batch(
+                response_cost, batch_usage, batch_models = await _handle_completed_batch(
                     batch=result, custom_llm_provider=self.custom_llm_provider
                 )

                 result._hidden_params["response_cost"] = response_cost
+                result._hidden_params["batch_models"] = batch_models
                 result.usage = batch_usage

         start_time, end_time, result = self._success_handler_helper_fn(
@@ -3213,6 +3214,7 @@ class StandardLoggingPayloadSetup:
             response_cost=None,
             additional_headers=None,
             litellm_overhead_time_ms=None,
+            batch_models=None,
         )
         if hidden_params is not None:
             for key in StandardLoggingHiddenParams.__annotations__.keys():
@@ -3326,6 +3328,7 @@ def get_standard_logging_object_payload(
             api_base=None,
             response_cost=None,
             litellm_overhead_time_ms=None,
+            batch_models=None,
         )
     )
@@ -3610,6 +3613,7 @@ def create_dummy_standard_logging_payload() -> StandardLoggingPayload:
         response_cost=None,
         additional_headers=None,
         litellm_overhead_time_ms=None,
+        batch_models=None,
     )

     # Convert numeric values to appropriate types


@@ -1897,6 +1897,7 @@ class SpendLogsMetadata(TypedDict):
     applied_guardrails: Optional[List[str]]
     status: StandardLoggingPayloadStatus
     proxy_server_request: Optional[str]
+    batch_models: Optional[List[str]]
     error_information: Optional[StandardLoggingPayloadErrorInformation]


@@ -35,7 +35,9 @@ def _is_master_key(api_key: str, _master_key: Optional[str]) -> bool:


 def _get_spend_logs_metadata(
-    metadata: Optional[dict], applied_guardrails: Optional[List[str]] = None
+    metadata: Optional[dict],
+    applied_guardrails: Optional[List[str]] = None,
+    batch_models: Optional[List[str]] = None,
 ) -> SpendLogsMetadata:
     if metadata is None:
         return SpendLogsMetadata(
@@ -52,6 +54,7 @@ def _get_spend_logs_metadata(
             status=None or "success",
             error_information=None,
             proxy_server_request=None,
+            batch_models=None,
         )
     verbose_proxy_logger.debug(
         "getting payload for SpendLogs, available keys in metadata: "
@@ -67,7 +70,7 @@ def _get_spend_logs_metadata(
         }
     )
     clean_metadata["applied_guardrails"] = applied_guardrails
+    clean_metadata["batch_models"] = batch_models
     return clean_metadata
@@ -192,6 +195,11 @@ def get_logging_payload( # noqa: PLR0915
             if standard_logging_payload is not None
             else None
         ),
+        batch_models=(
+            standard_logging_payload.get("hidden_params", {}).get("batch_models", None)
+            if standard_logging_payload is not None
+            else None
+        ),
     )
     special_usage_fields = ["completion_tokens", "prompt_tokens", "total_tokens"]
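
With the plumbing above, the SpendLogs metadata stored for a completed batch would now carry the models found in the batch output, roughly as follows (only a subset of SpendLogsMetadata fields shown, values illustrative):

# Sketch of the metadata dict for a completed batch job.
{
    "applied_guardrails": [],
    "batch_models": ["gpt-4o-mini", "gpt-4o"],
    "status": "success",
    "proxy_server_request": None,
    "error_information": None,
}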


@@ -117,6 +117,8 @@ class ModelInfoBase(ProviderSpecificModelInfo, total=False):
     input_cost_per_audio_per_second: Optional[float]  # only for vertex ai models
     input_cost_per_video_per_second: Optional[float]  # only for vertex ai models
     input_cost_per_second: Optional[float]  # for OpenAI Speech models
+    input_cost_per_token_batches: Optional[float]
+    output_cost_per_token_batches: Optional[float]
     output_cost_per_token: Required[float]
     output_cost_per_character: Optional[float]  # only for vertex ai models
     output_cost_per_audio_token: Optional[float]
@@ -213,6 +215,8 @@ CallTypesLiteral = Literal[
     "acreate_batch",
     "pass_through_endpoint",
     "anthropic_messages",
+    "aretrieve_batch",
+    "retrieve_batch",
 ]
@@ -1585,6 +1589,7 @@ class StandardLoggingHiddenParams(TypedDict):
     response_cost: Optional[str]
     litellm_overhead_time_ms: Optional[float]
     additional_headers: Optional[StandardLoggingAdditionalHeaders]
+    batch_models: Optional[List[str]]


 class StandardLoggingModelInformation(TypedDict):


@@ -4408,6 +4408,12 @@ def _get_model_info_helper( # noqa: PLR0915
             input_cost_per_audio_token=_model_info.get(
                 "input_cost_per_audio_token", None
             ),
+            input_cost_per_token_batches=_model_info.get(
+                "input_cost_per_token_batches"
+            ),
+            output_cost_per_token_batches=_model_info.get(
+                "output_cost_per_token_batches"
+            ),
             output_cost_per_token=_output_cost_per_token,
             output_cost_per_audio_token=_model_info.get(
                 "output_cost_per_audio_token", None


@@ -9,7 +9,7 @@
   "model": "gpt-4o",
   "user": "",
   "team_id": "",
-  "metadata": "{\"applied_guardrails\": [], \"additional_usage_values\": {\"completion_tokens_details\": null, \"prompt_tokens_details\": null}}",
+  "metadata": "{\"applied_guardrails\": [], \"batch_models\": null, \"additional_usage_values\": {\"completion_tokens_details\": null, \"prompt_tokens_details\": null}}",
   "cache_key": "Cache OFF",
   "spend": 0.00022500000000000002,
   "total_tokens": 30,