Litellm dev 03 04 2025 p3 (#8997)

* fix(core_helpers.py): handle litellm_metadata instead of 'metadata'

* feat(batches/): ensure batches logs are written to db

makes the batches response dict-compatible

* fix(cost_calculator.py): handle batch response being a dictionary

* fix(batches/main.py): modify retrieve endpoints to use @client decorator

enables logging to work on the retrieve call

* fix(batches/main.py): fix retrieve batch response type to be 'dict' compatible

* fix(spend_tracking_utils.py): send unique uuid for retrieve batch call type

create batch and retrieve batch otherwise share the same id

* fix(spend_tracking_utils.py): prevent duplicate retrieve batch calls from being double counted

* refactor(batches/): refactor cost tracking for batches - do it on retrieve, and within the established litellm_logging pipeline

ensures cost is always logged to db

* fix: fix linting errors

* fix: fix linting error
Krish Dholakia 2025-03-04 21:58:03 -08:00 committed by GitHub
parent f2a9d67e05
commit b43b8dc21c
17 changed files with 314 additions and 219 deletions
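For orientation, a minimal sketch of the flow this commit instruments: cost tracking for batches now happens when the batch is retrieved, because the retrieve endpoints are wrapped with the @client decorator and go through the standard litellm_logging pipeline. The file id below is illustrative.

import asyncio
import litellm

async def main():
    # Create a batch job from a previously uploaded JSONL file (file id is hypothetical)
    created = await litellm.acreate_batch(
        completion_window="24h",
        endpoint="/v1/chat/completions",
        input_file_id="file-abc123",
        custom_llm_provider="openai",
    )
    # Retrieving the batch is now the point where cost is calculated and logged to the DB
    retrieved = await litellm.aretrieve_batch(
        batch_id=created.id,
        custom_llm_provider="openai",
    )
    print(retrieved.status)

asyncio.run(main())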

View file

@@ -1,76 +1,16 @@
-import asyncio
-import datetime
 import json
-import threading
-from typing import Any, List, Literal, Optional
+from typing import Any, List, Literal, Tuple

 import litellm
 from litellm._logging import verbose_logger
-from litellm.constants import (
-    BATCH_STATUS_POLL_INTERVAL_SECONDS,
-    BATCH_STATUS_POLL_MAX_ATTEMPTS,
-)
-from litellm.files.main import afile_content
-from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
 from litellm.types.llms.openai import Batch
-from litellm.types.utils import StandardLoggingPayload, Usage
+from litellm.types.utils import Usage
-
-
-async def batches_async_logging(
-    batch_id: str,
-    custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
-    logging_obj: Optional[LiteLLMLoggingObj] = None,
-    **kwargs,
-):
-    """
-    Async Job waits for the batch to complete and then logs the completed batch usage - cost, total tokens, prompt tokens, completion tokens
-
-    Polls retrieve_batch until it returns a batch with status "completed" or "failed"
-    """
-    from .main import aretrieve_batch
-
-    verbose_logger.debug(
-        ".....in _batches_async_logging... polling retrieve to get batch status"
-    )
-    if logging_obj is None:
-        raise ValueError(
-            "logging_obj is None cannot calculate cost / log batch creation event"
-        )
-    for _ in range(BATCH_STATUS_POLL_MAX_ATTEMPTS):
-        try:
-            start_time = datetime.datetime.now()
-            batch: Batch = await aretrieve_batch(batch_id, custom_llm_provider)
-            verbose_logger.debug(
-                "in _batches_async_logging... batch status= %s", batch.status
-            )
-
-            if batch.status == "completed":
-                end_time = datetime.datetime.now()
-                await _handle_completed_batch(
-                    batch=batch,
-                    custom_llm_provider=custom_llm_provider,
-                    logging_obj=logging_obj,
-                    start_time=start_time,
-                    end_time=end_time,
-                    **kwargs,
-                )
-                break
-            elif batch.status == "failed":
-                pass
-        except Exception as e:
-            verbose_logger.error("error in batches_async_logging", e)
-        await asyncio.sleep(BATCH_STATUS_POLL_INTERVAL_SECONDS)


 async def _handle_completed_batch(
     batch: Batch,
     custom_llm_provider: Literal["openai", "azure", "vertex_ai"],
-    logging_obj: LiteLLMLoggingObj,
-    start_time: datetime.datetime,
-    end_time: datetime.datetime,
-    **kwargs,
-) -> None:
+) -> Tuple[float, Usage]:
     """Helper function to process a completed batch and handle logging"""
     # Get batch results
     file_content_dictionary = await _get_batch_output_file_content_as_dictionary(
@@ -87,52 +27,7 @@ async def _handle_completed_batch(
         custom_llm_provider=custom_llm_provider,
     )

-    # Handle logging
-    await _log_completed_batch(
-        logging_obj=logging_obj,
-        batch_usage=batch_usage,
-        batch_cost=batch_cost,
-        start_time=start_time,
-        end_time=end_time,
-        **kwargs,
-    )
-
-
-async def _log_completed_batch(
-    logging_obj: LiteLLMLoggingObj,
-    batch_usage: Usage,
-    batch_cost: float,
-    start_time: datetime.datetime,
-    end_time: datetime.datetime,
-    **kwargs,
-) -> None:
-    """Helper function to handle all logging operations for a completed batch"""
-    logging_obj.call_type = "batch_success"
-
-    standard_logging_object = _create_standard_logging_object_for_completed_batch(
-        kwargs=kwargs,
-        start_time=start_time,
-        end_time=end_time,
-        logging_obj=logging_obj,
-        batch_usage_object=batch_usage,
-        response_cost=batch_cost,
-    )
-
-    logging_obj.model_call_details["standard_logging_object"] = standard_logging_object
-
-    # Launch async and sync logging handlers
-    asyncio.create_task(
-        logging_obj.async_success_handler(
-            result=None,
-            start_time=start_time,
-            end_time=end_time,
-            cache_hit=None,
-        )
-    )
-    threading.Thread(
-        target=logging_obj.success_handler,
-        args=(None, start_time, end_time),
-    ).start()
+    return batch_cost, batch_usage


 async def _batch_cost_calculator(
@@ -159,6 +54,8 @@ async def _get_batch_output_file_content_as_dictionary(
     """
     Get the batch output file content as a list of dictionaries
     """
+    from litellm.files.main import afile_content
+
     if custom_llm_provider == "vertex_ai":
         raise ValueError("Vertex AI does not support file content retrieval")
@@ -264,30 +161,3 @@ def _batch_response_was_successful(batch_job_output_file: dict) -> bool:
     """
     _response: dict = batch_job_output_file.get("response", None) or {}
     return _response.get("status_code", None) == 200
-
-
-def _create_standard_logging_object_for_completed_batch(
-    kwargs: dict,
-    start_time: datetime.datetime,
-    end_time: datetime.datetime,
-    logging_obj: LiteLLMLoggingObj,
-    batch_usage_object: Usage,
-    response_cost: float,
-) -> StandardLoggingPayload:
-    """
-    Create a standard logging object for a completed batch
-    """
-    standard_logging_object = logging_obj.model_call_details.get(
-        "standard_logging_object", None
-    )
-    if standard_logging_object is None:
-        raise ValueError("unable to create standard logging object for completed batch")
-
-    # Add Completed Batch Job Usage and Response Cost
-    standard_logging_object["call_type"] = "batch_success"
-    standard_logging_object["response_cost"] = response_cost
-    standard_logging_object["total_tokens"] = batch_usage_object.total_tokens
-    standard_logging_object["prompt_tokens"] = batch_usage_object.prompt_tokens
-    standard_logging_object["completion_tokens"] = batch_usage_object.completion_tokens
-    return standard_logging_object
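After the refactor the helper only computes numbers and lets the caller decide how to log them. A hedged sketch of the new (cost, usage) contract, assuming `batch` is a completed LiteLLMBatch:

from litellm.batches.batch_utils import _handle_completed_batch

async def record_batch_cost(batch):
    # The helper no longer logs by itself; it just returns cost and usage
    batch_cost, batch_usage = await _handle_completed_batch(
        batch=batch, custom_llm_provider="openai"
    )
    return batch_cost, batch_usage.total_tokens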

View file

@@ -31,10 +31,9 @@ from litellm.types.llms.openai import (
     RetrieveBatchRequest,
 )
 from litellm.types.router import GenericLiteLLMParams
+from litellm.types.utils import LiteLLMBatch
 from litellm.utils import client, get_litellm_params, supports_httpx_timeout

-from .batch_utils import batches_async_logging
-
 ####### ENVIRONMENT VARIABLES ###################
 openai_batches_instance = OpenAIBatchesAPI()
 azure_batches_instance = AzureBatchesAPI()
@@ -85,17 +84,6 @@ async def acreate_batch(
         else:
             response = init_response

-        # Start async logging job
-        if response is not None:
-            asyncio.create_task(
-                batches_async_logging(
-                    logging_obj=kwargs.get("litellm_logging_obj", None),
-                    batch_id=response.id,
-                    custom_llm_provider=custom_llm_provider,
-                    **kwargs,
-                )
-            )
-
         return response
     except Exception as e:
         raise e
@@ -111,7 +99,7 @@ def create_batch(
     extra_headers: Optional[Dict[str, str]] = None,
     extra_body: Optional[Dict[str, str]] = None,
     **kwargs,
-) -> Union[Batch, Coroutine[Any, Any, Batch]]:
+) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
     """
     Creates and executes a batch from an uploaded file of request
@@ -119,21 +107,26 @@ def create_batch(
     """
     try:
         optional_params = GenericLiteLLMParams(**kwargs)
+        litellm_call_id = kwargs.get("litellm_call_id", None)
+        proxy_server_request = kwargs.get("proxy_server_request", None)
+        model_info = kwargs.get("model_info", None)
         _is_async = kwargs.pop("acreate_batch", False) is True
         litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj", None)
         ### TIMEOUT LOGIC ###
         timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
-        litellm_params = get_litellm_params(
-            custom_llm_provider=custom_llm_provider,
-            litellm_call_id=kwargs.get("litellm_call_id", None),
-            litellm_trace_id=kwargs.get("litellm_trace_id"),
-            litellm_metadata=kwargs.get("litellm_metadata"),
-        )
         litellm_logging_obj.update_environment_variables(
             model=None,
             user=None,
             optional_params=optional_params.model_dump(),
-            litellm_params=litellm_params,
+            litellm_params={
+                "litellm_call_id": litellm_call_id,
+                "proxy_server_request": proxy_server_request,
+                "model_info": model_info,
+                "metadata": metadata,
+                "preset_cache_key": None,
+                "stream_response": {},
+                **optional_params.model_dump(exclude_unset=True),
+            },
             custom_llm_provider=custom_llm_provider,
         )
@@ -261,7 +254,7 @@ def create_batch(
             response=httpx.Response(
                 status_code=400,
                 content="Unsupported provider",
-                request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"),  # type: ignore
+                request=httpx.Request(method="create_batch", url="https://github.com/BerriAI/litellm"),  # type: ignore
             ),
         )
         return response
@@ -269,6 +262,7 @@ def create_batch(
         raise e


+@client
 async def aretrieve_batch(
     batch_id: str,
     custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
@@ -276,7 +270,7 @@ async def aretrieve_batch(
     extra_headers: Optional[Dict[str, str]] = None,
     extra_body: Optional[Dict[str, str]] = None,
     **kwargs,
-) -> Batch:
+) -> LiteLLMBatch:
     """
     Async: Retrieves a batch.
@@ -310,6 +304,7 @@ async def aretrieve_batch(
         raise e


+@client
 def retrieve_batch(
     batch_id: str,
     custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
@@ -317,7 +312,7 @@ def retrieve_batch(
     extra_headers: Optional[Dict[str, str]] = None,
     extra_body: Optional[Dict[str, str]] = None,
     **kwargs,
-) -> Union[Batch, Coroutine[Any, Any, Batch]]:
+) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
     """
     Retrieves a batch.
@@ -325,9 +320,23 @@ def retrieve_batch(
     """
     try:
         optional_params = GenericLiteLLMParams(**kwargs)
+        litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj", None)
         ### TIMEOUT LOGIC ###
         timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
-        # set timeout for 10 minutes by default
+        litellm_params = get_litellm_params(
+            custom_llm_provider=custom_llm_provider,
+            litellm_call_id=kwargs.get("litellm_call_id", None),
+            litellm_trace_id=kwargs.get("litellm_trace_id"),
+            litellm_metadata=kwargs.get("litellm_metadata"),
+        )
+        litellm_logging_obj.update_environment_variables(
+            model=None,
+            user=None,
+            optional_params=optional_params.model_dump(),
+            litellm_params=litellm_params,
+            custom_llm_provider=custom_llm_provider,
+        )
+
         if (
             timeout is not None

View file

@@ -399,9 +399,12 @@ def _select_model_name_for_cost_calc(
     if base_model is not None:
         return_model = base_model

-    completion_response_model: Optional[str] = getattr(
-        completion_response, "model", None
-    )
+    completion_response_model: Optional[str] = None
+    if completion_response is not None:
+        if isinstance(completion_response, BaseModel):
+            completion_response_model = getattr(completion_response, "model", None)
+        elif isinstance(completion_response, dict):
+            completion_response_model = completion_response.get("model", None)
     hidden_params: Optional[dict] = getattr(completion_response, "_hidden_params", None)
     if completion_response_model is None and hidden_params is not None:
         if (
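The cost-calculator change boils down to tolerating plain dict responses when picking the model name. A small standalone sketch of the same lookup (the helper name here is illustrative):

from typing import Optional
from pydantic import BaseModel

def _model_from_response(completion_response) -> Optional[str]:
    # Pydantic response objects expose .model; dict-compatible batch responses use a key
    if completion_response is None:
        return None
    if isinstance(completion_response, BaseModel):
        return getattr(completion_response, "model", None)
    if isinstance(completion_response, dict):
        return completion_response.get("model", None)
    return None

print(_model_from_response({"model": "gpt-4o-mini-2024-07-18"}))  # gpt-4o-mini-2024-07-18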

View file

@@ -816,7 +816,7 @@ def file_content(
             )
         else:
             raise litellm.exceptions.BadRequestError(
-                message="LiteLLM doesn't support {} for 'file_content'. Only 'openai' and 'azure' are supported.".format(
+                message="LiteLLM doesn't support {} for 'custom_llm_provider'. Supported providers are 'openai', 'azure', 'vertex_ai'.".format(
                     custom_llm_provider
                 ),
                 model="n/a",

View file

@@ -74,7 +74,16 @@ def get_litellm_metadata_from_kwargs(kwargs: dict):
     """
     Helper to get litellm metadata from all litellm request kwargs
     """
-    return kwargs.get("litellm_params", {}).get("metadata", {})
+    litellm_params = kwargs.get("litellm_params", {})
+    if litellm_params:
+        metadata = litellm_params.get("metadata", {})
+        litellm_metadata = litellm_params.get("litellm_metadata", {})
+        if litellm_metadata:
+            return litellm_metadata
+        elif metadata:
+            return metadata
+
+    return {}


 # Helper functions used for OTEL logging
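A quick sketch of the precedence the updated helper gives proxy metadata (key and values below are illustrative): litellm_metadata, when present, wins over the legacy metadata key.

from litellm.litellm_core_utils.core_helpers import get_litellm_metadata_from_kwargs

kwargs = {
    "litellm_params": {
        "metadata": {"user_api_key_team_id": "legacy-team"},
        "litellm_metadata": {"user_api_key_team_id": "team-123"},
    }
}
# litellm_metadata takes precedence over metadata
print(get_litellm_metadata_from_kwargs(kwargs))  # {'user_api_key_team_id': 'team-123'}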

View file

@@ -25,6 +25,7 @@ from litellm import (
     turn_off_message_logging,
 )
 from litellm._logging import _is_debugging_on, verbose_logger
+from litellm.batches.batch_utils import _handle_completed_batch
 from litellm.caching.caching import DualCache, InMemoryCache
 from litellm.caching.caching_handler import LLMCachingHandler
 from litellm.cost_calculator import _select_model_name_for_cost_calc
@@ -50,6 +51,7 @@ from litellm.types.utils import (
     CallTypes,
     EmbeddingResponse,
     ImageResponse,
+    LiteLLMBatch,
     LiteLLMLoggingBaseClass,
     ModelResponse,
     ModelResponseStream,
@@ -871,6 +873,24 @@ class Logging(LiteLLMLoggingBaseClass):
             return None

+    async def _response_cost_calculator_async(
+        self,
+        result: Union[
+            ModelResponse,
+            ModelResponseStream,
+            EmbeddingResponse,
+            ImageResponse,
+            TranscriptionResponse,
+            TextCompletionResponse,
+            HttpxBinaryResponseContent,
+            RerankResponse,
+            Batch,
+            FineTuningJob,
+        ],
+        cache_hit: Optional[bool] = None,
+    ) -> Optional[float]:
+        return self._response_cost_calculator(result=result, cache_hit=cache_hit)
+
     def should_run_callback(
         self, callback: litellm.CALLBACK_TYPES, litellm_params: dict, event_hook: str
     ) -> bool:
@@ -928,8 +948,8 @@ class Logging(LiteLLMLoggingBaseClass):
             or isinstance(result, TextCompletionResponse)
             or isinstance(result, HttpxBinaryResponseContent)  # tts
             or isinstance(result, RerankResponse)
-            or isinstance(result, Batch)
             or isinstance(result, FineTuningJob)
+            or isinstance(result, LiteLLMBatch)
         ):
             ## HIDDEN PARAMS ##
             hidden_params = getattr(result, "_hidden_params", {})
@@ -1525,6 +1545,19 @@ class Logging(LiteLLMLoggingBaseClass):
         print_verbose(
             "Logging Details LiteLLM-Async Success Call, cache_hit={}".format(cache_hit)
         )
+
+        ## CALCULATE COST FOR BATCH JOBS
+        if self.call_type == CallTypes.aretrieve_batch.value and isinstance(
+            result, LiteLLMBatch
+        ):
+            response_cost, batch_usage = await _handle_completed_batch(
+                batch=result, custom_llm_provider=self.custom_llm_provider
+            )
+
+            result._hidden_params["response_cost"] = response_cost
+            result.usage = batch_usage
+
         start_time, end_time, result = self._success_handler_helper_fn(
             start_time=start_time,
             end_time=end_time,
@@ -1532,6 +1565,7 @@
             cache_hit=cache_hit,
             standard_logging_object=kwargs.get("standard_logging_object", None),
         )
+
         ## BUILD COMPLETE STREAMED RESPONSE
         if "async_complete_streaming_response" in self.model_call_details:
             return  # break out of this.

View file

@@ -2,7 +2,7 @@
 Azure Batches API Handler
 """

-from typing import Any, Coroutine, Optional, Union
+from typing import Any, Coroutine, Optional, Union, cast

 import httpx
@@ -14,6 +14,7 @@ from litellm.types.llms.openai import (
     CreateBatchRequest,
     RetrieveBatchRequest,
 )
+from litellm.types.utils import LiteLLMBatch


 class AzureBatchesAPI:
@@ -64,9 +65,9 @@
         self,
         create_batch_data: CreateBatchRequest,
         azure_client: AsyncAzureOpenAI,
-    ) -> Batch:
+    ) -> LiteLLMBatch:
         response = await azure_client.batches.create(**create_batch_data)
-        return response
+        return LiteLLMBatch(**response.model_dump())

     def create_batch(
         self,
@@ -78,7 +79,7 @@
         timeout: Union[float, httpx.Timeout],
         max_retries: Optional[int],
         client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
-    ) -> Union[Batch, Coroutine[Any, Any, Batch]]:
+    ) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
         azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
             self.get_azure_openai_client(
                 api_key=api_key,
@@ -103,16 +104,16 @@
             return self.acreate_batch(  # type: ignore
                 create_batch_data=create_batch_data, azure_client=azure_client
             )
-        response = azure_client.batches.create(**create_batch_data)
-        return response
+        response = cast(AzureOpenAI, azure_client).batches.create(**create_batch_data)
+        return LiteLLMBatch(**response.model_dump())

     async def aretrieve_batch(
         self,
         retrieve_batch_data: RetrieveBatchRequest,
         client: AsyncAzureOpenAI,
-    ) -> Batch:
+    ) -> LiteLLMBatch:
         response = await client.batches.retrieve(**retrieve_batch_data)
-        return response
+        return LiteLLMBatch(**response.model_dump())

     def retrieve_batch(
         self,
@@ -149,8 +150,10 @@
             return self.aretrieve_batch(  # type: ignore
                 retrieve_batch_data=retrieve_batch_data, client=azure_client
             )
-        response = azure_client.batches.retrieve(**retrieve_batch_data)
-        return response
+        response = cast(AzureOpenAI, azure_client).batches.retrieve(
+            **retrieve_batch_data
+        )
+        return LiteLLMBatch(**response.model_dump())

     async def acancel_batch(
         self,

View file

@@ -37,6 +37,7 @@ from litellm.llms.custom_httpx.http_handler import _DEFAULT_TTL_FOR_HTTPX_CLIENT
 from litellm.types.utils import (
     EmbeddingResponse,
     ImageResponse,
+    LiteLLMBatch,
     ModelResponse,
     ModelResponseStream,
 )
@@ -1755,9 +1756,9 @@ class OpenAIBatchesAPI(BaseLLM):
         self,
         create_batch_data: CreateBatchRequest,
         openai_client: AsyncOpenAI,
-    ) -> Batch:
+    ) -> LiteLLMBatch:
         response = await openai_client.batches.create(**create_batch_data)
-        return response
+        return LiteLLMBatch(**response.model_dump())

     def create_batch(
         self,
@@ -1769,7 +1770,7 @@
         max_retries: Optional[int],
         organization: Optional[str],
         client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
-    ) -> Union[Batch, Coroutine[Any, Any, Batch]]:
+    ) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
         openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
             api_key=api_key,
             api_base=api_base,
@@ -1792,17 +1793,18 @@
             return self.acreate_batch(  # type: ignore
                 create_batch_data=create_batch_data, openai_client=openai_client
             )
-        response = openai_client.batches.create(**create_batch_data)
-        return response
+        response = cast(OpenAI, openai_client).batches.create(**create_batch_data)
+
+        return LiteLLMBatch(**response.model_dump())

     async def aretrieve_batch(
         self,
         retrieve_batch_data: RetrieveBatchRequest,
         openai_client: AsyncOpenAI,
-    ) -> Batch:
+    ) -> LiteLLMBatch:
         verbose_logger.debug("retrieving batch, args= %s", retrieve_batch_data)
         response = await openai_client.batches.retrieve(**retrieve_batch_data)
-        return response
+        return LiteLLMBatch(**response.model_dump())

     def retrieve_batch(
         self,
@@ -1837,8 +1839,8 @@
             return self.aretrieve_batch(  # type: ignore
                 retrieve_batch_data=retrieve_batch_data, openai_client=openai_client
             )
-        response = openai_client.batches.retrieve(**retrieve_batch_data)
-        return response
+        response = cast(OpenAI, openai_client).batches.retrieve(**retrieve_batch_data)
+        return LiteLLMBatch(**response.model_dump())

     async def acancel_batch(
         self,

View file

@@ -9,11 +9,12 @@ from litellm.llms.custom_httpx.http_handler import (
     get_async_httpx_client,
 )
 from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
-from litellm.types.llms.openai import Batch, CreateBatchRequest
+from litellm.types.llms.openai import CreateBatchRequest
 from litellm.types.llms.vertex_ai import (
     VERTEX_CREDENTIALS_TYPES,
     VertexAIBatchPredictionJob,
 )
+from litellm.types.utils import LiteLLMBatch

 from .transformation import VertexAIBatchTransformation
@@ -33,7 +34,7 @@ class VertexAIBatchPrediction(VertexLLM):
         vertex_location: Optional[str],
         timeout: Union[float, httpx.Timeout],
         max_retries: Optional[int],
-    ) -> Union[Batch, Coroutine[Any, Any, Batch]]:
+    ) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:

         sync_handler = _get_httpx_client()
@@ -101,7 +102,7 @@ class VertexAIBatchPrediction(VertexLLM):
         vertex_batch_request: VertexAIBatchPredictionJob,
         api_base: str,
         headers: Dict[str, str],
-    ) -> Batch:
+    ) -> LiteLLMBatch:
         client = get_async_httpx_client(
             llm_provider=litellm.LlmProviders.VERTEX_AI,
         )
@@ -138,7 +139,7 @@ class VertexAIBatchPrediction(VertexLLM):
         vertex_location: Optional[str],
         timeout: Union[float, httpx.Timeout],
         max_retries: Optional[int],
-    ) -> Union[Batch, Coroutine[Any, Any, Batch]]:
+    ) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
         sync_handler = _get_httpx_client()

         access_token, project_id = self._ensure_access_token(
@@ -199,7 +200,7 @@ class VertexAIBatchPrediction(VertexLLM):
         self,
         api_base: str,
         headers: Dict[str, str],
-    ) -> Batch:
+    ) -> LiteLLMBatch:
         client = get_async_httpx_client(
             llm_provider=litellm.LlmProviders.VERTEX_AI,
         )

View file

@@ -4,8 +4,9 @@ from typing import Dict
 from litellm.llms.vertex_ai.common_utils import (
     _convert_vertex_datetime_to_openai_datetime,
 )
-from litellm.types.llms.openai import Batch, BatchJobStatus, CreateBatchRequest
+from litellm.types.llms.openai import BatchJobStatus, CreateBatchRequest
 from litellm.types.llms.vertex_ai import *
+from litellm.types.utils import LiteLLMBatch


 class VertexAIBatchTransformation:
@@ -47,8 +48,8 @@ class VertexAIBatchTransformation:
     @classmethod
     def transform_vertex_ai_batch_response_to_openai_batch_response(
         cls, response: VertexBatchPredictionResponse
-    ) -> Batch:
-        return Batch(
+    ) -> LiteLLMBatch:
+        return LiteLLMBatch(
             id=cls._get_batch_id_from_vertex_ai_batch_response(response),
             completion_window="24hrs",
             created_at=_convert_vertex_datetime_to_openai_datetime(

View file

@@ -1,9 +1,13 @@
 model_list:
-  - model_name: my-langfuse-model
+  - model_name: openai/gpt-4o
     litellm_params:
-      model: langfuse/openai-model
+      model: openai/gpt-4o
       api_key: os.environ/OPENAI_API_KEY
-  - model_name: openai-model
-    litellm_params:
-      model: openai/gpt-3.5-turbo
-      api_key: os.environ/OPENAI_API_KEY
+
+files_settings:
+  - custom_llm_provider: azure
+    api_base: os.environ/AZURE_API_BASE
+    api_key: os.environ/AZURE_API_KEY
+
+general_settings:
+  store_prompts_in_spend_logs: true

View file

@@ -2,10 +2,10 @@
 # /v1/batches Endpoints
-import asyncio

 ######################################################################
-from typing import Dict, Optional
+import asyncio
+from typing import Dict, Optional, cast

 from fastapi import APIRouter, Depends, HTTPException, Path, Request, Response
@@ -199,8 +199,11 @@ async def retrieve_batch(
     ```
     """
     from litellm.proxy.proxy_server import (
+        add_litellm_data_to_request,
+        general_settings,
         get_custom_headers,
         llm_router,
+        proxy_config,
         proxy_logging_obj,
         version,
     )
@@ -212,6 +215,23 @@ async def retrieve_batch(
         batch_id=batch_id,
     )

+    data = cast(dict, _retrieve_batch_request)
+
+    # setup logging
+    data["litellm_call_id"] = request.headers.get(
+        "x-litellm-call-id", str(uuid.uuid4())
+    )
+
+    # Include original request and headers in the data
+    data = await add_litellm_data_to_request(
+        data=data,
+        request=request,
+        general_settings=general_settings,
+        user_api_key_dict=user_api_key_dict,
+        version=version,
+        proxy_config=proxy_config,
+    )
+
     if litellm.enable_loadbalancing_on_batch_endpoints is True:
         if llm_router is None:
             raise HTTPException(
@@ -221,7 +241,7 @@
                 },
             )
-        response = await llm_router.aretrieve_batch(**_retrieve_batch_request)  # type: ignore
+        response = await llm_router.aretrieve_batch(**data)  # type: ignore
     else:
         custom_llm_provider = (
             provider
@@ -229,7 +249,7 @@
             or "openai"
         )
         response = await litellm.aretrieve_batch(
-            custom_llm_provider=custom_llm_provider, **_retrieve_batch_request  # type: ignore
+            custom_llm_provider=custom_llm_provider, **data  # type: ignore
         )

     ### ALERTING ###

View file

@@ -11,6 +11,7 @@ import uuid
 from typing import TYPE_CHECKING, List, Optional, Union, cast

 from fastapi import APIRouter, Depends, HTTPException, Request, status
+from fastapi.responses import RedirectResponse

 import litellm
 from litellm._logging import verbose_proxy_logger

View file

@@ -955,6 +955,9 @@ def _set_spend_logs_payload(
     prisma_client: PrismaClient,
     spend_logs_url: Optional[str] = None,
 ):
+    verbose_proxy_logger.info(
+        "Writing spend log to db - request_id: {}".format(payload.get("request_id"))
+    )
     if prisma_client is not None and spend_logs_url is not None:
         if isinstance(payload["startTime"], datetime):
             payload["startTime"] = payload["startTime"].isoformat()
@@ -1056,7 +1059,6 @@ async def update_database(  # noqa: PLR0915
                     start_time=start_time,
                     end_time=end_time,
                 )
-
                 payload["spend"] = response_cost
                 prisma_client = _set_spend_logs_payload(
                     payload=payload,

View file

@@ -1,9 +1,10 @@
+import hashlib
 import json
 import secrets
 from datetime import datetime
 from datetime import datetime as dt
 from datetime import timezone
-from typing import List, Optional, cast
+from typing import Any, List, Optional, cast

 from pydantic import BaseModel
@@ -69,6 +70,42 @@
     return clean_metadata


+def generate_hash_from_response(response_obj: Any) -> str:
+    """
+    Generate a stable hash from a response object.
+
+    Args:
+        response_obj: The response object to hash (can be dict, list, etc.)
+
+    Returns:
+        A hex string representation of the MD5 hash
+    """
+    try:
+        # Create a stable JSON string of the entire response object
+        # Sort keys to ensure consistent ordering
+        json_str = json.dumps(response_obj, sort_keys=True)
+
+        # Generate a hash of the response object
+        unique_hash = hashlib.md5(json_str.encode()).hexdigest()
+        return unique_hash
+    except Exception:
+        # Return a fallback hash if serialization fails
+        return hashlib.md5(str(response_obj).encode()).hexdigest()
+
+
+def get_spend_logs_id(
+    call_type: str, response_obj: dict, kwargs: dict
+) -> Optional[str]:
+    if call_type == "aretrieve_batch":
+        # Generate a hash from the response object
+        id: Optional[str] = generate_hash_from_response(response_obj)
+    else:
+        id = cast(Optional[str], response_obj.get("id")) or cast(
+            Optional[str], kwargs.get("litellm_call_id")
+        )
+    return id
+
+
 def get_logging_payload(  # noqa: PLR0915
     kwargs, response_obj, start_time, end_time
 ) -> SpendLogsPayload:
@@ -94,7 +131,15 @@
     usage = cast(dict, response_obj).get("usage", None) or {}
     if isinstance(usage, litellm.Usage):
         usage = dict(usage)
-    id = cast(dict, response_obj).get("id") or kwargs.get("litellm_call_id")
+
+    if isinstance(response_obj, dict):
+        response_obj_dict = response_obj
+    elif isinstance(response_obj, BaseModel):
+        response_obj_dict = response_obj.model_dump()
+    else:
+        response_obj_dict = {}
+
+    id = get_spend_logs_id(call_type or "acompletion", response_obj_dict, kwargs)
     standard_logging_payload = cast(
         Optional[StandardLoggingPayload], kwargs.get("standard_logging_object", None)
     )
@@ -177,14 +222,8 @@
             endTime=_ensure_datetime_utc(end_time),
             completionStartTime=_ensure_datetime_utc(completion_start_time),
             model=kwargs.get("model", "") or "",
-            user=kwargs.get("litellm_params", {})
-            .get("metadata", {})
-            .get("user_api_key_user_id", "")
-            or "",
-            team_id=kwargs.get("litellm_params", {})
-            .get("metadata", {})
-            .get("user_api_key_team_id", "")
-            or "",
+            user=metadata.get("user_api_key_user_id", "") or "",
+            team_id=metadata.get("user_api_key_team_id", "") or "",
             metadata=json.dumps(clean_metadata),
             cache_key=cache_key,
             spend=kwargs.get("response_cost", 0),
@@ -314,10 +353,13 @@ _add_proxy_server_request_to_metadata(
     Only store if _should_store_prompts_and_responses_in_spend_logs() is True
     """
     if _should_store_prompts_and_responses_in_spend_logs():
-        _proxy_server_request = litellm_params.get("proxy_server_request", {})
-        _request_body = _proxy_server_request.get("body", {}) or {}
-        _request_body_json_str = json.dumps(_request_body, default=str)
-        metadata["proxy_server_request"] = _request_body_json_str
+        _proxy_server_request = cast(
+            Optional[dict], litellm_params.get("proxy_server_request", {})
+        )
+        if _proxy_server_request is not None:
+            _request_body = _proxy_server_request.get("body", {}) or {}
+            _request_body_json_str = json.dumps(_request_body, default=str)
+            metadata["proxy_server_request"] = _request_body_json_str
     return metadata
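Why a content hash works as the spend-log id for retrieve-batch calls: serializing with sorted keys makes the id depend only on the response content, so repeated retrieves of the same completed batch map to the same row instead of being double counted. A small sketch with illustrative payload values:

import hashlib
import json

def stable_id(response_obj) -> str:
    # Same content gives the same id, regardless of key order
    return hashlib.md5(json.dumps(response_obj, sort_keys=True).encode()).hexdigest()

a = {"id": "batch_123", "status": "completed", "output_file_id": "file-xyz"}
b = {"output_file_id": "file-xyz", "id": "batch_123", "status": "completed"}
assert stable_id(a) == stable_id(b)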

View file

@@ -24,6 +24,7 @@ from typing_extensions import Callable, Dict, Required, TypedDict, override
 from ..litellm_core_utils.core_helpers import map_finish_reason
 from .guardrails import GuardrailEventHooks
 from .llms.openai import (
+    Batch,
     ChatCompletionThinkingBlock,
     ChatCompletionToolCallChunk,
     ChatCompletionUsageBlock,
@@ -182,6 +183,8 @@ class CallTypes(Enum):
     arealtime = "_arealtime"
     create_batch = "create_batch"
     acreate_batch = "acreate_batch"
+    aretrieve_batch = "aretrieve_batch"
+    retrieve_batch = "retrieve_batch"
     pass_through = "pass_through_endpoint"
@@ -1963,3 +1966,27 @@ class ProviderSpecificHeader(TypedDict):
 class SelectTokenizerResponse(TypedDict):
     type: Literal["openai_tokenizer", "huggingface_tokenizer"]
     tokenizer: Any
+
+
+class LiteLLMBatch(Batch):
+    _hidden_params: dict = {}
+    usage: Optional[Usage] = None
+
+    def __contains__(self, key):
+        # Define custom behavior for the 'in' operator
+        return hasattr(self, key)
+
+    def get(self, key, default=None):
+        # Custom .get() method to access attributes with a default value if the attribute doesn't exist
+        return getattr(self, key, default)
+
+    def __getitem__(self, key):
+        # Allow dictionary-style access to attributes
+        return getattr(self, key)
+
+    def json(self, **kwargs):  # type: ignore
+        try:
+            return self.model_dump()  # noqa
+        except Exception:
+            # if using pydantic v1
+            return self.dict()
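LiteLLMBatch keeps the OpenAI Batch fields but adds dict-style access, which is what lets spend tracking treat batch responses like any other response dict. A minimal sketch with illustrative field values:

from litellm.types.utils import LiteLLMBatch

batch = LiteLLMBatch(
    id="batch_123",
    completion_window="24h",
    created_at=1741135408,
    endpoint="/v1/chat/completions",
    input_file_id="file-abc",
    object="batch",
    status="completed",
)
print(batch["id"])         # dictionary-style access via __getitem__
print(batch.get("usage"))  # .get() with a default of None
print("status" in batch)   # membership check via __contains__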

View file

@@ -2982,3 +2982,70 @@ def test_json_valid_model_cost_map():
         json.loads(json_str)
     except json.JSONDecodeError as e:
         assert False, f"Invalid JSON format: {str(e)}"
+
+
+def test_batch_cost_calculator():
+    args = {
+        "completion_response": {
+            "choices": [
+                {
+                    "content_filter_results": {
+                        "hate": {"filtered": False, "severity": "safe"},
+                        "protected_material_code": {
+                            "filtered": False,
+                            "detected": False,
+                        },
+                        "protected_material_text": {
+                            "filtered": False,
+                            "detected": False,
+                        },
+                        "self_harm": {"filtered": False, "severity": "safe"},
+                        "sexual": {"filtered": False, "severity": "safe"},
+                        "violence": {"filtered": False, "severity": "safe"},
+                    },
+                    "finish_reason": "stop",
+                    "index": 0,
+                    "logprobs": None,
+                    "message": {
+                        "content": 'As of my last update in October 2023, there are eight recognized planets in the solar system. They are:\n\n1. **Mercury** - The closest planet to the Sun, known for its extreme temperature fluctuations.\n2. **Venus** - Similar in size to Earth but with a thick atmosphere rich in carbon dioxide, leading to a greenhouse effect that makes it the hottest planet.\n3. **Earth** - The only planet known to support life, with a diverse environment and liquid water.\n4. **Mars** - Known as the Red Planet, it has the largest volcano and canyon in the solar system and features signs of past water.\n5. **Jupiter** - The largest planet in the solar system, known for its Great Red Spot and numerous moons.\n6. **Saturn** - Famous for its stunning rings, it is a gas giant also known for its extensive moon system.\n7. **Uranus** - An ice giant with a unique tilt, it rotates on its side and has a blue color due to methane in its atmosphere.\n8. **Neptune** - Another ice giant, known for its deep blue color and strong winds, it is the farthest planet from the Sun.\n\nPluto was previously classified as the ninth planet but was reclassified as a "dwarf planet" in 2006 by the International Astronomical Union.',
+                        "refusal": None,
+                        "role": "assistant",
+                    },
+                }
+            ],
+            "created": 1741135408,
+            "id": "chatcmpl-B7X96teepFM4ILP7cm4Ga62eRuV8p",
+            "model": "gpt-4o-mini-2024-07-18",
+            "object": "chat.completion",
+            "prompt_filter_results": [
+                {
+                    "prompt_index": 0,
+                    "content_filter_results": {
+                        "hate": {"filtered": False, "severity": "safe"},
+                        "jailbreak": {"filtered": False, "detected": False},
+                        "self_harm": {"filtered": False, "severity": "safe"},
+                        "sexual": {"filtered": False, "severity": "safe"},
+                        "violence": {"filtered": False, "severity": "safe"},
+                    },
+                }
+            ],
+            "system_fingerprint": "fp_b705f0c291",
+            "usage": {
+                "completion_tokens": 278,
+                "completion_tokens_details": {
+                    "accepted_prediction_tokens": 0,
+                    "audio_tokens": 0,
+                    "reasoning_tokens": 0,
+                    "rejected_prediction_tokens": 0,
+                },
+                "prompt_tokens": 20,
+                "prompt_tokens_details": {"audio_tokens": 0, "cached_tokens": 0},
+                "total_tokens": 298,
+            },
+        },
+        "model": None,
+    }
+    cost = completion_cost(**args)
+    assert cost > 0