Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-25 10:44:24 +00:00
Fix batches api cost tracking + Log batch models in spend logs / standard logging payload (#9077)
All checks were successful
Read Version from pyproject.toml / read-version (push) Successful in 42s
* feat(batches/): fix batch cost calculation - ensure it's accurate by using the correct batch cost value - prev. defaulting to the non-batch cost
* feat(batch_utils.py): log batch models to spend logs + standard logging payload - makes it easy to understand how cost was calculated
* fix: fix stored payload for test
* test: fix test
Parent: 8c049dfffc
Commit: 4330ef8e81

8 changed files with 110 additions and 7 deletions
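The cost fix hinges on batch-specific pricing: OpenAI-style batch jobs are billed at discounted batch rates, so costing them with the regular per-token prices overstates spend. Below is a minimal standalone sketch of the pricing fallback this commit introduces; the *_batches keys mirror the new model-info fields in the diff, while the rates and token counts are made up.

    from typing import Optional, Tuple

    def batch_token_cost(
        prompt_tokens: int,
        completion_tokens: int,
        input_cost_per_token: float,
        output_cost_per_token: float,
        input_cost_per_token_batches: Optional[float] = None,
        output_cost_per_token_batches: Optional[float] = None,
    ) -> Tuple[float, float]:
        # Prefer explicit batch pricing; otherwise assume roughly half the
        # regular rate, matching the fallback in batch_cost_calculator below.
        prompt_cost = prompt_tokens * (
            input_cost_per_token_batches
            if input_cost_per_token_batches
            else input_cost_per_token / 2
        )
        completion_cost = completion_tokens * (
            output_cost_per_token_batches
            if output_cost_per_token_batches
            else output_cost_per_token / 2
        )
        return prompt_cost, completion_cost

    # 1,000 prompt + 500 completion tokens at hypothetical gpt-4o-style rates:
    print(batch_token_cost(1000, 500, 2.5e-06, 1.0e-05))  # -> (0.00125, 0.0025)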
@@ -4,13 +4,13 @@ from typing import Any, List, Literal, Tuple
 import litellm
 from litellm._logging import verbose_logger
 from litellm.types.llms.openai import Batch
-from litellm.types.utils import Usage
+from litellm.types.utils import CallTypes, Usage


 async def _handle_completed_batch(
     batch: Batch,
     custom_llm_provider: Literal["openai", "azure", "vertex_ai"],
-) -> Tuple[float, Usage]:
+) -> Tuple[float, Usage, List[str]]:
     """Helper function to process a completed batch and handle logging"""
     # Get batch results
     file_content_dictionary = await _get_batch_output_file_content_as_dictionary(

@@ -27,7 +27,25 @@ async def _handle_completed_batch(
         custom_llm_provider=custom_llm_provider,
     )

-    return batch_cost, batch_usage
+    batch_models = _get_batch_models_from_file_content(file_content_dictionary)
+
+    return batch_cost, batch_usage, batch_models
+
+
+def _get_batch_models_from_file_content(
+    file_content_dictionary: List[dict],
+) -> List[str]:
+    """
+    Get the models from the file content
+    """
+    batch_models = []
+    for _item in file_content_dictionary:
+        if _batch_response_was_successful(_item):
+            _response_body = _get_response_from_batch_job_output_file(_item)
+            _model = _response_body.get("model")
+            if _model:
+                batch_models.append(_model)
+    return batch_models


 async def _batch_cost_calculator(

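For context on the new _get_batch_models_from_file_content helper: each line of an OpenAI-compatible batch output file is a JSON object whose response body names the model that actually served the request. A hedged sketch of the input shape, with made-up values loosely following the OpenAI /v1/batches output format:

    # Hypothetical parsed output-file content (one dict per line), roughly what
    # _get_batch_output_file_content_as_dictionary returns in this module:
    file_content_dictionary = [
        {
            "custom_id": "request-1",
            "response": {
                "status_code": 200,
                "body": {
                    "model": "gpt-4o-mini-2024-07-18",
                    "usage": {"prompt_tokens": 12, "completion_tokens": 7},
                },
            },
            "error": None,
        },
    ]
    # The helper keeps only successful items and collects body["model"], so this
    # example would yield ["gpt-4o-mini-2024-07-18"].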
@@ -105,6 +123,7 @@ def _get_batch_job_cost_from_file_content(
         total_cost += litellm.completion_cost(
             completion_response=_response_body,
             custom_llm_provider=custom_llm_provider,
+            call_type=CallTypes.aretrieve_batch.value,
         )
     verbose_logger.debug("total_cost=%s", total_cost)
     return total_cost

@@ -239,6 +239,15 @@ def cost_per_token(  # noqa: PLR0915
             custom_llm_provider=custom_llm_provider,
             billed_units=rerank_billed_units,
         )
+    elif (
+        call_type == "aretrieve_batch"
+        or call_type == "retrieve_batch"
+        or call_type == CallTypes.aretrieve_batch
+        or call_type == CallTypes.retrieve_batch
+    ):
+        return batch_cost_calculator(
+            usage=usage_block, model=model, custom_llm_provider=custom_llm_provider
+        )
     elif call_type == "atranscription" or call_type == "transcription":
         return openai_cost_per_second(
             model=model,

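With the new elif branch, cost_per_token routes the retrieve_batch call types to batch_cost_calculator. A hedged example of exercising that path the same way _get_batch_job_cost_from_file_content now does, via litellm.completion_cost; the response body and provider are illustrative:

    import litellm
    from litellm.types.utils import CallTypes

    # Illustrative response body taken from one line of a batch output file:
    _response_body = {
        "model": "gpt-4o-mini",
        "usage": {"prompt_tokens": 100, "completion_tokens": 50, "total_tokens": 150},
    }

    # Passing the batch call type makes cost_per_token dispatch to
    # batch_cost_calculator instead of the regular per-token path.
    cost = litellm.completion_cost(
        completion_response=_response_body,
        custom_llm_provider="openai",
        call_type=CallTypes.aretrieve_batch.value,
    )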
@@ -960,3 +969,54 @@ def default_image_cost_calculator(
     )

     return cost_info["input_cost_per_pixel"] * height * width * n
+
+
+def batch_cost_calculator(
+    usage: Usage,
+    model: str,
+    custom_llm_provider: Optional[str] = None,
+) -> Tuple[float, float]:
+    """
+    Calculate the cost of a batch job
+    """
+
+    _, custom_llm_provider, _, _ = litellm.get_llm_provider(
+        model=model, custom_llm_provider=custom_llm_provider
+    )
+
+    verbose_logger.info(
+        "Calculating batch cost per token. model=%s, custom_llm_provider=%s",
+        model,
+        custom_llm_provider,
+    )
+
+    try:
+        model_info: Optional[ModelInfo] = litellm.get_model_info(
+            model=model, custom_llm_provider=custom_llm_provider
+        )
+    except Exception:
+        model_info = None
+
+    if not model_info:
+        return 0.0, 0.0
+
+    input_cost_per_token_batches = model_info.get("input_cost_per_token_batches")
+    input_cost_per_token = model_info.get("input_cost_per_token")
+    output_cost_per_token_batches = model_info.get("output_cost_per_token_batches")
+    output_cost_per_token = model_info.get("output_cost_per_token")
+    total_prompt_cost = 0.0
+    total_completion_cost = 0.0
+    if input_cost_per_token_batches:
+        total_prompt_cost = usage.prompt_tokens * input_cost_per_token_batches
+    elif input_cost_per_token:
+        total_prompt_cost = (
+            usage.prompt_tokens * (input_cost_per_token) / 2
+        )  # batch cost is usually half of the regular token cost
+    if output_cost_per_token_batches:
+        total_completion_cost = usage.completion_tokens * output_cost_per_token_batches
+    elif output_cost_per_token:
+        total_completion_cost = (
+            usage.completion_tokens * (output_cost_per_token) / 2
+        )  # batch cost is usually half of the regular token cost
+
+    return total_prompt_cost, total_completion_cost

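A worked example of the fallback math in the new batch_cost_calculator, assuming a model whose pricing entry carries no *_batches overrides (rates are illustrative):

    from litellm.types.utils import Usage

    # Hypothetical regular pricing: $2.50 / 1M input tokens, $10 / 1M output tokens.
    usage = Usage(prompt_tokens=1_000_000, completion_tokens=200_000, total_tokens=1_200_000)

    total_prompt_cost = usage.prompt_tokens * (2.5e-06 / 2)          # 1.25 USD
    total_completion_cost = usage.completion_tokens * (1.0e-05 / 2)  # 1.00 USD
    # batch_cost_calculator would return (1.25, 1.0) for this usage block.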
@@ -1613,11 +1613,12 @@ class Logging(LiteLLMLoggingBaseClass):
                     result, LiteLLMBatch
                 ):

-                    response_cost, batch_usage = await _handle_completed_batch(
+                    response_cost, batch_usage, batch_models = await _handle_completed_batch(
                         batch=result, custom_llm_provider=self.custom_llm_provider
                     )

                     result._hidden_params["response_cost"] = response_cost
+                    result._hidden_params["batch_models"] = batch_models
                     result.usage = batch_usage

         start_time, end_time, result = self._success_handler_helper_fn(

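On the consumer side, the extra return value surfaces as hidden params on the completed batch object and flows from there into the standard logging payload. A hedged usage sketch, assuming an already-completed OpenAI batch and that the success handler has populated these fields; the batch id is a placeholder:

    import litellm

    async def inspect_batch_cost(batch_id: str):
        batch = await litellm.aretrieve_batch(
            batch_id=batch_id, custom_llm_provider="openai"
        )
        # Populated by the logging success handler for completed batches:
        print(batch._hidden_params.get("response_cost"))
        print(batch._hidden_params.get("batch_models"))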
@@ -3213,6 +3214,7 @@ class StandardLoggingPayloadSetup:
             response_cost=None,
             additional_headers=None,
             litellm_overhead_time_ms=None,
+            batch_models=None,
         )
         if hidden_params is not None:
             for key in StandardLoggingHiddenParams.__annotations__.keys():

@@ -3326,6 +3328,7 @@ def get_standard_logging_object_payload(
                 api_base=None,
                 response_cost=None,
                 litellm_overhead_time_ms=None,
+                batch_models=None,
             )
         )

@@ -3610,6 +3613,7 @@ def create_dummy_standard_logging_payload() -> StandardLoggingPayload:
         response_cost=None,
         additional_headers=None,
         litellm_overhead_time_ms=None,
+        batch_models=None,
     )

     # Convert numeric values to appropriate types

@@ -1897,6 +1897,7 @@ class SpendLogsMetadata(TypedDict):
     applied_guardrails: Optional[List[str]]
     status: StandardLoggingPayloadStatus
     proxy_server_request: Optional[str]
+    batch_models: Optional[List[str]]
     error_information: Optional[StandardLoggingPayloadErrorInformation]


@@ -35,7 +35,9 @@ def _is_master_key(api_key: str, _master_key: Optional[str]) -> bool:


 def _get_spend_logs_metadata(
-    metadata: Optional[dict], applied_guardrails: Optional[List[str]] = None
+    metadata: Optional[dict],
+    applied_guardrails: Optional[List[str]] = None,
+    batch_models: Optional[List[str]] = None,
 ) -> SpendLogsMetadata:
     if metadata is None:
         return SpendLogsMetadata(

@@ -52,6 +54,7 @@ def _get_spend_logs_metadata(
             status=None or "success",
             error_information=None,
             proxy_server_request=None,
+            batch_models=None,
         )
     verbose_proxy_logger.debug(
         "getting payload for SpendLogs, available keys in metadata: "

@@ -67,7 +70,7 @@ def _get_spend_logs_metadata(
         }
     )
     clean_metadata["applied_guardrails"] = applied_guardrails
-
+    clean_metadata["batch_models"] = batch_models
     return clean_metadata


@@ -192,6 +195,11 @@ def get_logging_payload(  # noqa: PLR0915
             if standard_logging_payload is not None
             else None
         ),
+        batch_models=(
+            standard_logging_payload.get("hidden_params", {}).get("batch_models", None)
+            if standard_logging_payload is not None
+            else None
+        ),
     )

     special_usage_fields = ["completion_tokens", "prompt_tokens", "total_tokens"]

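End to end, batch_models travels from the standard logging payload's hidden_params into the SpendLogsMetadata stored with each spend log row. A rough sketch of the resulting metadata, with illustrative values:

    # Illustrative SpendLogsMetadata contents for a completed batch job:
    spend_logs_metadata = {
        "applied_guardrails": [],
        "status": "success",
        "proxy_server_request": None,
        "error_information": None,
        "batch_models": ["gpt-4o-mini", "gpt-4o-2024-08-06"],
    }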
@@ -117,6 +117,8 @@ class ModelInfoBase(ProviderSpecificModelInfo, total=False):
     input_cost_per_audio_per_second: Optional[float]  # only for vertex ai models
     input_cost_per_video_per_second: Optional[float]  # only for vertex ai models
     input_cost_per_second: Optional[float]  # for OpenAI Speech models
+    input_cost_per_token_batches: Optional[float]
+    output_cost_per_token_batches: Optional[float]
     output_cost_per_token: Required[float]
     output_cost_per_character: Optional[float]  # only for vertex ai models
     output_cost_per_audio_token: Optional[float]

@@ -213,6 +215,8 @@ CallTypesLiteral = Literal[
     "acreate_batch",
     "pass_through_endpoint",
     "anthropic_messages",
+    "aretrieve_batch",
+    "retrieve_batch",
 ]


@@ -1585,6 +1589,7 @@ class StandardLoggingHiddenParams(TypedDict):
     response_cost: Optional[str]
     litellm_overhead_time_ms: Optional[float]
     additional_headers: Optional[StandardLoggingAdditionalHeaders]
+    batch_models: Optional[List[str]]


 class StandardLoggingModelInformation(TypedDict):

@@ -4408,6 +4408,12 @@ def _get_model_info_helper(  # noqa: PLR0915
             input_cost_per_audio_token=_model_info.get(
                 "input_cost_per_audio_token", None
             ),
+            input_cost_per_token_batches=_model_info.get(
+                "input_cost_per_token_batches"
+            ),
+            output_cost_per_token_batches=_model_info.get(
+                "output_cost_per_token_batches"
+            ),
             output_cost_per_token=_output_cost_per_token,
             output_cost_per_audio_token=_model_info.get(
                 "output_cost_per_audio_token", None

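The new ModelInfoBase fields are read from the model pricing map, so a pricing entry that opts into explicit batch rates might look like the following; the numbers are illustrative, not a real entry:

    model_pricing_entry = {
        "input_cost_per_token": 2.5e-06,
        "output_cost_per_token": 1.0e-05,
        "input_cost_per_token_batches": 1.25e-06,
        "output_cost_per_token_batches": 5.0e-06,
        "litellm_provider": "openai",
        "mode": "chat",
    }
    # When the *_batches keys are present, batch_cost_calculator uses them directly
    # instead of halving the regular per-token rates.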
@@ -9,7 +9,7 @@
   "model": "gpt-4o",
   "user": "",
   "team_id": "",
-  "metadata": "{\"applied_guardrails\": [], \"additional_usage_values\": {\"completion_tokens_details\": null, \"prompt_tokens_details\": null}}",
+  "metadata": "{\"applied_guardrails\": [], \"batch_models\": null, \"additional_usage_values\": {\"completion_tokens_details\": null, \"prompt_tokens_details\": null}}",
  "cache_key": "Cache OFF",
  "spend": 0.00022500000000000002,
  "total_tokens": 30,