Fix batches api cost tracking + Log batch models in spend logs / standard logging payload (#9077)

* feat(batches/): fix batch cost calculation - ensure it's accurate

use the correct cost value; previously this defaulted to the non-batch cost

* feat(batch_utils.py): log batch models to spend logs + standard logging payload

makes it easy to understand how the cost was calculated

* fix: fix stored payload for test

* test: fix test
Commit 44c9eef64f by Krish Dholakia, 2025-03-08 11:47:25 -08:00 (committed by GitHub)
Parent: 048ff931be
8 changed files with 110 additions and 7 deletions

batch_utils.py

@@ -4,13 +4,13 @@ from typing import Any, List, Literal, Tuple
 import litellm
 from litellm._logging import verbose_logger
 from litellm.types.llms.openai import Batch
-from litellm.types.utils import Usage
+from litellm.types.utils import CallTypes, Usage


 async def _handle_completed_batch(
     batch: Batch,
     custom_llm_provider: Literal["openai", "azure", "vertex_ai"],
-) -> Tuple[float, Usage]:
+) -> Tuple[float, Usage, List[str]]:
     """Helper function to process a completed batch and handle logging"""
     # Get batch results
     file_content_dictionary = await _get_batch_output_file_content_as_dictionary(
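Since the return type is widened to a three-tuple, callers of `_handle_completed_batch` now have to unpack the extra `batch_models` value. The retrieve-batch call site is not part of this diff, so the following is only an illustrative caller-side sketch:

```python
# Illustrative only: the real call site lives in litellm's batch
# retrieval / logging code, which is not shown in this commit excerpt.
batch_cost, batch_usage, batch_models = await _handle_completed_batch(
    batch=batch,                   # litellm.types.llms.openai.Batch
    custom_llm_provider="openai",  # or "azure" / "vertex_ai"
)
# batch_models can then be attached to the spend log / standard logging
# payload so it is clear which models contributed to batch_cost.
```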
@@ -27,7 +27,25 @@ async def _handle_completed_batch(
         custom_llm_provider=custom_llm_provider,
     )
-    return batch_cost, batch_usage
+    batch_models = _get_batch_models_from_file_content(file_content_dictionary)
+    return batch_cost, batch_usage, batch_models
+
+
+def _get_batch_models_from_file_content(
+    file_content_dictionary: List[dict],
+) -> List[str]:
+    """
+    Get the models from the file content
+    """
+    batch_models = []
+    for _item in file_content_dictionary:
+        if _batch_response_was_successful(_item):
+            _response_body = _get_response_from_batch_job_output_file(_item)
+            _model = _response_body.get("model")
+            if _model:
+                batch_models.append(_model)
+    return batch_models


 async def _batch_cost_calculator(
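A hedged usage sketch of the new helper. It assumes the module path `litellm/batches/batch_utils.py` and that each entry of `file_content_dictionary` is a parsed line of an OpenAI-style batch output file (a `response` wrapper with a `status_code` and the completion `body`); the sample data and expected output are illustrative, not taken from the repository:

```python
from litellm.batches.batch_utils import _get_batch_models_from_file_content

# Two parsed lines from a batch output file: one successful, one failed.
file_content_dictionary = [
    {
        "custom_id": "request-1",
        "response": {
            "status_code": 200,
            "body": {"model": "gpt-4o-mini", "usage": {"total_tokens": 42}},
        },
        "error": None,
    },
    {
        "custom_id": "request-2",
        "response": {"status_code": 500, "body": {}},
        "error": {"message": "server error"},
    },
]

# Only successful responses contribute a model name.
print(_get_batch_models_from_file_content(file_content_dictionary))
# expected: ['gpt-4o-mini']
```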
@@ -105,6 +123,7 @@ def _get_batch_job_cost_from_file_content(
                 total_cost += litellm.completion_cost(
                     completion_response=_response_body,
                     custom_llm_provider=custom_llm_provider,
+                    call_type=CallTypes.aretrieve_batch.value,
                 )
                 verbose_logger.debug("total_cost=%s", total_cost)
         return total_cost
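The new `call_type` argument tells `litellm.completion_cost` to price the response as a batch retrieval rather than a regular completion, which is where batch-discounted rates come from. A rough sketch of the difference, assuming a model that has batch pricing in litellm's cost map (the exact numbers depend on the installed cost map):

```python
import litellm
from litellm.types.utils import CallTypes

# A response body shaped like one "body" entry from a batch output file.
response_body = {
    "id": "chatcmpl-123",
    "object": "chat.completion",
    "model": "gpt-4o-mini",
    "choices": [
        {
            "index": 0,
            "message": {"role": "assistant", "content": "hi"},
            "finish_reason": "stop",
        }
    ],
    "usage": {"prompt_tokens": 100, "completion_tokens": 50, "total_tokens": 150},
}

# Default pricing path (what the old code effectively used).
regular_cost = litellm.completion_cost(
    completion_response=response_body,
    custom_llm_provider="openai",
)

# Batch pricing path, matching the change in this commit.
batch_cost = litellm.completion_cost(
    completion_response=response_body,
    custom_llm_provider="openai",
    call_type=CallTypes.aretrieve_batch.value,
)

print(regular_cost, batch_cost)  # batch_cost should be lower where batch rates exist
```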