Litellm dev 03 04 2025 p3 (#8997)

* fix(core_helpers.py): handle litellm_metadata instead of 'metadata'

* feat(batches/): ensure batches logs are written to db

makes batches response dict compatible

* fix(cost_calculator.py): handle batch response being a dictionary

* fix(batches/main.py): modify retrieve endpoints to use @client decorator

enables logging to work on retrieve call

* fix(batches/main.py): fix retrieve batch response type to be 'dict' compatible

* fix(spend_tracking_utils.py): send unique uuid for retrieve batch call type

create batch and retrieve batch share the same id

* fix(spend_tracking_utils.py): prevent duplicate retrieve batch calls from being double counted

* refactor(batches/): refactor cost tracking for batches - do it on retrieve, and within the established litellm_logging pipeline

ensures cost is always logged to db

* fix: fix linting errors

* fix: fix linting error
This commit is contained in:
Krish Dholakia 2025-03-04 21:58:03 -08:00 committed by GitHub
parent f2a9d67e05
commit b43b8dc21c
17 changed files with 314 additions and 219 deletions

View file

@ -1,76 +1,16 @@
import asyncio
import datetime
import json
import threading
from typing import Any, List, Literal, Optional
from typing import Any, List, Literal, Tuple
import litellm
from litellm._logging import verbose_logger
from litellm.constants import (
BATCH_STATUS_POLL_INTERVAL_SECONDS,
BATCH_STATUS_POLL_MAX_ATTEMPTS,
)
from litellm.files.main import afile_content
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.types.llms.openai import Batch
from litellm.types.utils import StandardLoggingPayload, Usage
async def batches_async_logging(
batch_id: str,
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
logging_obj: Optional[LiteLLMLoggingObj] = None,
**kwargs,
):
"""
Async Job waits for the batch to complete and then logs the completed batch usage - cost, total tokens, prompt tokens, completion tokens
Polls retrieve_batch until it returns a batch with status "completed" or "failed"
"""
from .main import aretrieve_batch
verbose_logger.debug(
".....in _batches_async_logging... polling retrieve to get batch status"
)
if logging_obj is None:
raise ValueError(
"logging_obj is None cannot calculate cost / log batch creation event"
)
for _ in range(BATCH_STATUS_POLL_MAX_ATTEMPTS):
try:
start_time = datetime.datetime.now()
batch: Batch = await aretrieve_batch(batch_id, custom_llm_provider)
verbose_logger.debug(
"in _batches_async_logging... batch status= %s", batch.status
)
if batch.status == "completed":
end_time = datetime.datetime.now()
await _handle_completed_batch(
batch=batch,
custom_llm_provider=custom_llm_provider,
logging_obj=logging_obj,
start_time=start_time,
end_time=end_time,
**kwargs,
)
break
elif batch.status == "failed":
pass
except Exception as e:
verbose_logger.error("error in batches_async_logging", e)
await asyncio.sleep(BATCH_STATUS_POLL_INTERVAL_SECONDS)
from litellm.types.utils import Usage
async def _handle_completed_batch(
batch: Batch,
custom_llm_provider: Literal["openai", "azure", "vertex_ai"],
logging_obj: LiteLLMLoggingObj,
start_time: datetime.datetime,
end_time: datetime.datetime,
**kwargs,
) -> None:
) -> Tuple[float, Usage]:
"""Helper function to process a completed batch and handle logging"""
# Get batch results
file_content_dictionary = await _get_batch_output_file_content_as_dictionary(
@ -87,52 +27,7 @@ async def _handle_completed_batch(
custom_llm_provider=custom_llm_provider,
)
# Handle logging
await _log_completed_batch(
logging_obj=logging_obj,
batch_usage=batch_usage,
batch_cost=batch_cost,
start_time=start_time,
end_time=end_time,
**kwargs,
)
async def _log_completed_batch(
logging_obj: LiteLLMLoggingObj,
batch_usage: Usage,
batch_cost: float,
start_time: datetime.datetime,
end_time: datetime.datetime,
**kwargs,
) -> None:
"""Helper function to handle all logging operations for a completed batch"""
logging_obj.call_type = "batch_success"
standard_logging_object = _create_standard_logging_object_for_completed_batch(
kwargs=kwargs,
start_time=start_time,
end_time=end_time,
logging_obj=logging_obj,
batch_usage_object=batch_usage,
response_cost=batch_cost,
)
logging_obj.model_call_details["standard_logging_object"] = standard_logging_object
# Launch async and sync logging handlers
asyncio.create_task(
logging_obj.async_success_handler(
result=None,
start_time=start_time,
end_time=end_time,
cache_hit=None,
)
)
threading.Thread(
target=logging_obj.success_handler,
args=(None, start_time, end_time),
).start()
return batch_cost, batch_usage
async def _batch_cost_calculator(
@ -159,6 +54,8 @@ async def _get_batch_output_file_content_as_dictionary(
"""
Get the batch output file content as a list of dictionaries
"""
from litellm.files.main import afile_content
if custom_llm_provider == "vertex_ai":
raise ValueError("Vertex AI does not support file content retrieval")
@ -264,30 +161,3 @@ def _batch_response_was_successful(batch_job_output_file: dict) -> bool:
"""
_response: dict = batch_job_output_file.get("response", None) or {}
return _response.get("status_code", None) == 200
def _create_standard_logging_object_for_completed_batch(
kwargs: dict,
start_time: datetime.datetime,
end_time: datetime.datetime,
logging_obj: LiteLLMLoggingObj,
batch_usage_object: Usage,
response_cost: float,
) -> StandardLoggingPayload:
"""
Create a standard logging object for a completed batch
"""
standard_logging_object = logging_obj.model_call_details.get(
"standard_logging_object", None
)
if standard_logging_object is None:
raise ValueError("unable to create standard logging object for completed batch")
# Add Completed Batch Job Usage and Response Cost
standard_logging_object["call_type"] = "batch_success"
standard_logging_object["response_cost"] = response_cost
standard_logging_object["total_tokens"] = batch_usage_object.total_tokens
standard_logging_object["prompt_tokens"] = batch_usage_object.prompt_tokens
standard_logging_object["completion_tokens"] = batch_usage_object.completion_tokens
return standard_logging_object

View file

@ -31,10 +31,9 @@ from litellm.types.llms.openai import (
RetrieveBatchRequest,
)
from litellm.types.router import GenericLiteLLMParams
from litellm.types.utils import LiteLLMBatch
from litellm.utils import client, get_litellm_params, supports_httpx_timeout
from .batch_utils import batches_async_logging
####### ENVIRONMENT VARIABLES ###################
openai_batches_instance = OpenAIBatchesAPI()
azure_batches_instance = AzureBatchesAPI()
@ -85,17 +84,6 @@ async def acreate_batch(
else:
response = init_response
# Start async logging job
if response is not None:
asyncio.create_task(
batches_async_logging(
logging_obj=kwargs.get("litellm_logging_obj", None),
batch_id=response.id,
custom_llm_provider=custom_llm_provider,
**kwargs,
)
)
return response
except Exception as e:
raise e
@ -111,7 +99,7 @@ def create_batch(
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
) -> Union[Batch, Coroutine[Any, Any, Batch]]:
) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
"""
Creates and executes a batch from an uploaded file of request
@ -119,21 +107,26 @@ def create_batch(
"""
try:
optional_params = GenericLiteLLMParams(**kwargs)
litellm_call_id = kwargs.get("litellm_call_id", None)
proxy_server_request = kwargs.get("proxy_server_request", None)
model_info = kwargs.get("model_info", None)
_is_async = kwargs.pop("acreate_batch", False) is True
litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj", None)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
litellm_params = get_litellm_params(
custom_llm_provider=custom_llm_provider,
litellm_call_id=kwargs.get("litellm_call_id", None),
litellm_trace_id=kwargs.get("litellm_trace_id"),
litellm_metadata=kwargs.get("litellm_metadata"),
)
litellm_logging_obj.update_environment_variables(
model=None,
user=None,
optional_params=optional_params.model_dump(),
litellm_params=litellm_params,
litellm_params={
"litellm_call_id": litellm_call_id,
"proxy_server_request": proxy_server_request,
"model_info": model_info,
"metadata": metadata,
"preset_cache_key": None,
"stream_response": {},
**optional_params.model_dump(exclude_unset=True),
},
custom_llm_provider=custom_llm_provider,
)
@ -261,7 +254,7 @@ def create_batch(
response=httpx.Response(
status_code=400,
content="Unsupported provider",
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
request=httpx.Request(method="create_batch", url="https://github.com/BerriAI/litellm"), # type: ignore
),
)
return response
@ -269,6 +262,7 @@ def create_batch(
raise e
@client
async def aretrieve_batch(
batch_id: str,
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
@ -276,7 +270,7 @@ async def aretrieve_batch(
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
) -> Batch:
) -> LiteLLMBatch:
"""
Async: Retrieves a batch.
@ -310,6 +304,7 @@ async def aretrieve_batch(
raise e
@client
def retrieve_batch(
batch_id: str,
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
@ -317,7 +312,7 @@ def retrieve_batch(
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
) -> Union[Batch, Coroutine[Any, Any, Batch]]:
) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
"""
Retrieves a batch.
@ -325,9 +320,23 @@ def retrieve_batch(
"""
try:
optional_params = GenericLiteLLMParams(**kwargs)
litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj", None)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
# set timeout for 10 minutes by default
litellm_params = get_litellm_params(
custom_llm_provider=custom_llm_provider,
litellm_call_id=kwargs.get("litellm_call_id", None),
litellm_trace_id=kwargs.get("litellm_trace_id"),
litellm_metadata=kwargs.get("litellm_metadata"),
)
litellm_logging_obj.update_environment_variables(
model=None,
user=None,
optional_params=optional_params.model_dump(),
litellm_params=litellm_params,
custom_llm_provider=custom_llm_provider,
)
if (
timeout is not None

View file

@ -399,9 +399,12 @@ def _select_model_name_for_cost_calc(
if base_model is not None:
return_model = base_model
completion_response_model: Optional[str] = getattr(
completion_response, "model", None
)
completion_response_model: Optional[str] = None
if completion_response is not None:
if isinstance(completion_response, BaseModel):
completion_response_model = getattr(completion_response, "model", None)
elif isinstance(completion_response, dict):
completion_response_model = completion_response.get("model", None)
hidden_params: Optional[dict] = getattr(completion_response, "_hidden_params", None)
if completion_response_model is None and hidden_params is not None:
if (

View file

@ -816,7 +816,7 @@ def file_content(
)
else:
raise litellm.exceptions.BadRequestError(
message="LiteLLM doesn't support {} for 'file_content'. Only 'openai' and 'azure' are supported.".format(
message="LiteLLM doesn't support {} for 'custom_llm_provider'. Supported providers are 'openai', 'azure', 'vertex_ai'.".format(
custom_llm_provider
),
model="n/a",

View file

@ -74,7 +74,16 @@ def get_litellm_metadata_from_kwargs(kwargs: dict):
"""
Helper to get litellm metadata from all litellm request kwargs
"""
return kwargs.get("litellm_params", {}).get("metadata", {})
litellm_params = kwargs.get("litellm_params", {})
if litellm_params:
metadata = litellm_params.get("metadata", {})
litellm_metadata = litellm_params.get("litellm_metadata", {})
if litellm_metadata:
return litellm_metadata
elif metadata:
return metadata
return {}
# Helper functions used for OTEL logging

View file

@ -25,6 +25,7 @@ from litellm import (
turn_off_message_logging,
)
from litellm._logging import _is_debugging_on, verbose_logger
from litellm.batches.batch_utils import _handle_completed_batch
from litellm.caching.caching import DualCache, InMemoryCache
from litellm.caching.caching_handler import LLMCachingHandler
from litellm.cost_calculator import _select_model_name_for_cost_calc
@ -50,6 +51,7 @@ from litellm.types.utils import (
CallTypes,
EmbeddingResponse,
ImageResponse,
LiteLLMBatch,
LiteLLMLoggingBaseClass,
ModelResponse,
ModelResponseStream,
@ -871,6 +873,24 @@ class Logging(LiteLLMLoggingBaseClass):
return None
async def _response_cost_calculator_async(
self,
result: Union[
ModelResponse,
ModelResponseStream,
EmbeddingResponse,
ImageResponse,
TranscriptionResponse,
TextCompletionResponse,
HttpxBinaryResponseContent,
RerankResponse,
Batch,
FineTuningJob,
],
cache_hit: Optional[bool] = None,
) -> Optional[float]:
return self._response_cost_calculator(result=result, cache_hit=cache_hit)
def should_run_callback(
self, callback: litellm.CALLBACK_TYPES, litellm_params: dict, event_hook: str
) -> bool:
@ -928,8 +948,8 @@ class Logging(LiteLLMLoggingBaseClass):
or isinstance(result, TextCompletionResponse)
or isinstance(result, HttpxBinaryResponseContent) # tts
or isinstance(result, RerankResponse)
or isinstance(result, Batch)
or isinstance(result, FineTuningJob)
or isinstance(result, LiteLLMBatch)
):
## HIDDEN PARAMS ##
hidden_params = getattr(result, "_hidden_params", {})
@ -1525,6 +1545,19 @@ class Logging(LiteLLMLoggingBaseClass):
print_verbose(
"Logging Details LiteLLM-Async Success Call, cache_hit={}".format(cache_hit)
)
## CALCULATE COST FOR BATCH JOBS
if self.call_type == CallTypes.aretrieve_batch.value and isinstance(
result, LiteLLMBatch
):
response_cost, batch_usage = await _handle_completed_batch(
batch=result, custom_llm_provider=self.custom_llm_provider
)
result._hidden_params["response_cost"] = response_cost
result.usage = batch_usage
start_time, end_time, result = self._success_handler_helper_fn(
start_time=start_time,
end_time=end_time,
@ -1532,6 +1565,7 @@ class Logging(LiteLLMLoggingBaseClass):
cache_hit=cache_hit,
standard_logging_object=kwargs.get("standard_logging_object", None),
)
## BUILD COMPLETE STREAMED RESPONSE
if "async_complete_streaming_response" in self.model_call_details:
return # break out of this.

View file

@ -2,7 +2,7 @@
Azure Batches API Handler
"""
from typing import Any, Coroutine, Optional, Union
from typing import Any, Coroutine, Optional, Union, cast
import httpx
@ -14,6 +14,7 @@ from litellm.types.llms.openai import (
CreateBatchRequest,
RetrieveBatchRequest,
)
from litellm.types.utils import LiteLLMBatch
class AzureBatchesAPI:
@ -64,9 +65,9 @@ class AzureBatchesAPI:
self,
create_batch_data: CreateBatchRequest,
azure_client: AsyncAzureOpenAI,
) -> Batch:
) -> LiteLLMBatch:
response = await azure_client.batches.create(**create_batch_data)
return response
return LiteLLMBatch(**response.model_dump())
def create_batch(
self,
@ -78,7 +79,7 @@ class AzureBatchesAPI:
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
) -> Union[Batch, Coroutine[Any, Any, Batch]]:
) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
self.get_azure_openai_client(
api_key=api_key,
@ -103,16 +104,16 @@ class AzureBatchesAPI:
return self.acreate_batch( # type: ignore
create_batch_data=create_batch_data, azure_client=azure_client
)
response = azure_client.batches.create(**create_batch_data)
return response
response = cast(AzureOpenAI, azure_client).batches.create(**create_batch_data)
return LiteLLMBatch(**response.model_dump())
async def aretrieve_batch(
self,
retrieve_batch_data: RetrieveBatchRequest,
client: AsyncAzureOpenAI,
) -> Batch:
) -> LiteLLMBatch:
response = await client.batches.retrieve(**retrieve_batch_data)
return response
return LiteLLMBatch(**response.model_dump())
def retrieve_batch(
self,
@ -149,8 +150,10 @@ class AzureBatchesAPI:
return self.aretrieve_batch( # type: ignore
retrieve_batch_data=retrieve_batch_data, client=azure_client
)
response = azure_client.batches.retrieve(**retrieve_batch_data)
return response
response = cast(AzureOpenAI, azure_client).batches.retrieve(
**retrieve_batch_data
)
return LiteLLMBatch(**response.model_dump())
async def acancel_batch(
self,

View file

@ -37,6 +37,7 @@ from litellm.llms.custom_httpx.http_handler import _DEFAULT_TTL_FOR_HTTPX_CLIENT
from litellm.types.utils import (
EmbeddingResponse,
ImageResponse,
LiteLLMBatch,
ModelResponse,
ModelResponseStream,
)
@ -1755,9 +1756,9 @@ class OpenAIBatchesAPI(BaseLLM):
self,
create_batch_data: CreateBatchRequest,
openai_client: AsyncOpenAI,
) -> Batch:
) -> LiteLLMBatch:
response = await openai_client.batches.create(**create_batch_data)
return response
return LiteLLMBatch(**response.model_dump())
def create_batch(
self,
@ -1769,7 +1770,7 @@ class OpenAIBatchesAPI(BaseLLM):
max_retries: Optional[int],
organization: Optional[str],
client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
) -> Union[Batch, Coroutine[Any, Any, Batch]]:
) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
api_key=api_key,
api_base=api_base,
@ -1792,17 +1793,18 @@ class OpenAIBatchesAPI(BaseLLM):
return self.acreate_batch( # type: ignore
create_batch_data=create_batch_data, openai_client=openai_client
)
response = openai_client.batches.create(**create_batch_data)
return response
response = cast(OpenAI, openai_client).batches.create(**create_batch_data)
return LiteLLMBatch(**response.model_dump())
async def aretrieve_batch(
self,
retrieve_batch_data: RetrieveBatchRequest,
openai_client: AsyncOpenAI,
) -> Batch:
) -> LiteLLMBatch:
verbose_logger.debug("retrieving batch, args= %s", retrieve_batch_data)
response = await openai_client.batches.retrieve(**retrieve_batch_data)
return response
return LiteLLMBatch(**response.model_dump())
def retrieve_batch(
self,
@ -1837,8 +1839,8 @@ class OpenAIBatchesAPI(BaseLLM):
return self.aretrieve_batch( # type: ignore
retrieve_batch_data=retrieve_batch_data, openai_client=openai_client
)
response = openai_client.batches.retrieve(**retrieve_batch_data)
return response
response = cast(OpenAI, openai_client).batches.retrieve(**retrieve_batch_data)
return LiteLLMBatch(**response.model_dump())
async def acancel_batch(
self,

View file

@ -9,11 +9,12 @@ from litellm.llms.custom_httpx.http_handler import (
get_async_httpx_client,
)
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
from litellm.types.llms.openai import Batch, CreateBatchRequest
from litellm.types.llms.openai import CreateBatchRequest
from litellm.types.llms.vertex_ai import (
VERTEX_CREDENTIALS_TYPES,
VertexAIBatchPredictionJob,
)
from litellm.types.utils import LiteLLMBatch
from .transformation import VertexAIBatchTransformation
@ -33,7 +34,7 @@ class VertexAIBatchPrediction(VertexLLM):
vertex_location: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
) -> Union[Batch, Coroutine[Any, Any, Batch]]:
) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
sync_handler = _get_httpx_client()
@ -101,7 +102,7 @@ class VertexAIBatchPrediction(VertexLLM):
vertex_batch_request: VertexAIBatchPredictionJob,
api_base: str,
headers: Dict[str, str],
) -> Batch:
) -> LiteLLMBatch:
client = get_async_httpx_client(
llm_provider=litellm.LlmProviders.VERTEX_AI,
)
@ -138,7 +139,7 @@ class VertexAIBatchPrediction(VertexLLM):
vertex_location: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
) -> Union[Batch, Coroutine[Any, Any, Batch]]:
) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
sync_handler = _get_httpx_client()
access_token, project_id = self._ensure_access_token(
@ -199,7 +200,7 @@ class VertexAIBatchPrediction(VertexLLM):
self,
api_base: str,
headers: Dict[str, str],
) -> Batch:
) -> LiteLLMBatch:
client = get_async_httpx_client(
llm_provider=litellm.LlmProviders.VERTEX_AI,
)

View file

@ -4,8 +4,9 @@ from typing import Dict
from litellm.llms.vertex_ai.common_utils import (
_convert_vertex_datetime_to_openai_datetime,
)
from litellm.types.llms.openai import Batch, BatchJobStatus, CreateBatchRequest
from litellm.types.llms.openai import BatchJobStatus, CreateBatchRequest
from litellm.types.llms.vertex_ai import *
from litellm.types.utils import LiteLLMBatch
class VertexAIBatchTransformation:
@ -47,8 +48,8 @@ class VertexAIBatchTransformation:
@classmethod
def transform_vertex_ai_batch_response_to_openai_batch_response(
cls, response: VertexBatchPredictionResponse
) -> Batch:
return Batch(
) -> LiteLLMBatch:
return LiteLLMBatch(
id=cls._get_batch_id_from_vertex_ai_batch_response(response),
completion_window="24hrs",
created_at=_convert_vertex_datetime_to_openai_datetime(

View file

@ -1,9 +1,13 @@
model_list:
- model_name: my-langfuse-model
- model_name: openai/gpt-4o
litellm_params:
model: langfuse/openai-model
api_key: os.environ/OPENAI_API_KEY
- model_name: openai-model
litellm_params:
model: openai/gpt-3.5-turbo
model: openai/gpt-4o
api_key: os.environ/OPENAI_API_KEY
files_settings:
- custom_llm_provider: azure
api_base: os.environ/AZURE_API_BASE
api_key: os.environ/AZURE_API_KEY
general_settings:
store_prompts_in_spend_logs: true

View file

@ -2,10 +2,10 @@
# /v1/batches Endpoints
import asyncio
######################################################################
from typing import Dict, Optional
import asyncio
from typing import Dict, Optional, cast
from fastapi import APIRouter, Depends, HTTPException, Path, Request, Response
@ -199,8 +199,11 @@ async def retrieve_batch(
```
"""
from litellm.proxy.proxy_server import (
add_litellm_data_to_request,
general_settings,
get_custom_headers,
llm_router,
proxy_config,
proxy_logging_obj,
version,
)
@ -212,6 +215,23 @@ async def retrieve_batch(
batch_id=batch_id,
)
data = cast(dict, _retrieve_batch_request)
# setup logging
data["litellm_call_id"] = request.headers.get(
"x-litellm-call-id", str(uuid.uuid4())
)
# Include original request and headers in the data
data = await add_litellm_data_to_request(
data=data,
request=request,
general_settings=general_settings,
user_api_key_dict=user_api_key_dict,
version=version,
proxy_config=proxy_config,
)
if litellm.enable_loadbalancing_on_batch_endpoints is True:
if llm_router is None:
raise HTTPException(
@ -221,7 +241,7 @@ async def retrieve_batch(
},
)
response = await llm_router.aretrieve_batch(**_retrieve_batch_request) # type: ignore
response = await llm_router.aretrieve_batch(**data) # type: ignore
else:
custom_llm_provider = (
provider
@ -229,7 +249,7 @@ async def retrieve_batch(
or "openai"
)
response = await litellm.aretrieve_batch(
custom_llm_provider=custom_llm_provider, **_retrieve_batch_request # type: ignore
custom_llm_provider=custom_llm_provider, **data # type: ignore
)
### ALERTING ###

View file

@ -11,6 +11,7 @@ import uuid
from typing import TYPE_CHECKING, List, Optional, Union, cast
from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.responses import RedirectResponse
import litellm
from litellm._logging import verbose_proxy_logger

View file

@ -955,6 +955,9 @@ def _set_spend_logs_payload(
prisma_client: PrismaClient,
spend_logs_url: Optional[str] = None,
):
verbose_proxy_logger.info(
"Writing spend log to db - request_id: {}".format(payload.get("request_id"))
)
if prisma_client is not None and spend_logs_url is not None:
if isinstance(payload["startTime"], datetime):
payload["startTime"] = payload["startTime"].isoformat()
@ -1056,7 +1059,6 @@ async def update_database( # noqa: PLR0915
start_time=start_time,
end_time=end_time,
)
payload["spend"] = response_cost
prisma_client = _set_spend_logs_payload(
payload=payload,

View file

@ -1,9 +1,10 @@
import hashlib
import json
import secrets
from datetime import datetime
from datetime import datetime as dt
from datetime import timezone
from typing import List, Optional, cast
from typing import Any, List, Optional, cast
from pydantic import BaseModel
@ -69,6 +70,42 @@ def _get_spend_logs_metadata(
return clean_metadata
def generate_hash_from_response(response_obj: Any) -> str:
"""
Generate a stable hash from a response object.
Args:
response_obj: The response object to hash (can be dict, list, etc.)
Returns:
A hex string representation of the MD5 hash
"""
try:
# Create a stable JSON string of the entire response object
# Sort keys to ensure consistent ordering
json_str = json.dumps(response_obj, sort_keys=True)
# Generate a hash of the response object
unique_hash = hashlib.md5(json_str.encode()).hexdigest()
return unique_hash
except Exception:
# Return a fallback hash if serialization fails
return hashlib.md5(str(response_obj).encode()).hexdigest()
def get_spend_logs_id(
call_type: str, response_obj: dict, kwargs: dict
) -> Optional[str]:
if call_type == "aretrieve_batch":
# Generate a hash from the response object
id: Optional[str] = generate_hash_from_response(response_obj)
else:
id = cast(Optional[str], response_obj.get("id")) or cast(
Optional[str], kwargs.get("litellm_call_id")
)
return id
def get_logging_payload( # noqa: PLR0915
kwargs, response_obj, start_time, end_time
) -> SpendLogsPayload:
@ -94,7 +131,15 @@ def get_logging_payload( # noqa: PLR0915
usage = cast(dict, response_obj).get("usage", None) or {}
if isinstance(usage, litellm.Usage):
usage = dict(usage)
id = cast(dict, response_obj).get("id") or kwargs.get("litellm_call_id")
if isinstance(response_obj, dict):
response_obj_dict = response_obj
elif isinstance(response_obj, BaseModel):
response_obj_dict = response_obj.model_dump()
else:
response_obj_dict = {}
id = get_spend_logs_id(call_type or "acompletion", response_obj_dict, kwargs)
standard_logging_payload = cast(
Optional[StandardLoggingPayload], kwargs.get("standard_logging_object", None)
)
@ -177,14 +222,8 @@ def get_logging_payload( # noqa: PLR0915
endTime=_ensure_datetime_utc(end_time),
completionStartTime=_ensure_datetime_utc(completion_start_time),
model=kwargs.get("model", "") or "",
user=kwargs.get("litellm_params", {})
.get("metadata", {})
.get("user_api_key_user_id", "")
or "",
team_id=kwargs.get("litellm_params", {})
.get("metadata", {})
.get("user_api_key_team_id", "")
or "",
user=metadata.get("user_api_key_user_id", "") or "",
team_id=metadata.get("user_api_key_team_id", "") or "",
metadata=json.dumps(clean_metadata),
cache_key=cache_key,
spend=kwargs.get("response_cost", 0),
@ -314,10 +353,13 @@ def _add_proxy_server_request_to_metadata(
Only store if _should_store_prompts_and_responses_in_spend_logs() is True
"""
if _should_store_prompts_and_responses_in_spend_logs():
_proxy_server_request = litellm_params.get("proxy_server_request", {})
_request_body = _proxy_server_request.get("body", {}) or {}
_request_body_json_str = json.dumps(_request_body, default=str)
metadata["proxy_server_request"] = _request_body_json_str
_proxy_server_request = cast(
Optional[dict], litellm_params.get("proxy_server_request", {})
)
if _proxy_server_request is not None:
_request_body = _proxy_server_request.get("body", {}) or {}
_request_body_json_str = json.dumps(_request_body, default=str)
metadata["proxy_server_request"] = _request_body_json_str
return metadata

View file

@ -24,6 +24,7 @@ from typing_extensions import Callable, Dict, Required, TypedDict, override
from ..litellm_core_utils.core_helpers import map_finish_reason
from .guardrails import GuardrailEventHooks
from .llms.openai import (
Batch,
ChatCompletionThinkingBlock,
ChatCompletionToolCallChunk,
ChatCompletionUsageBlock,
@ -182,6 +183,8 @@ class CallTypes(Enum):
arealtime = "_arealtime"
create_batch = "create_batch"
acreate_batch = "acreate_batch"
aretrieve_batch = "aretrieve_batch"
retrieve_batch = "retrieve_batch"
pass_through = "pass_through_endpoint"
@ -1963,3 +1966,27 @@ class ProviderSpecificHeader(TypedDict):
class SelectTokenizerResponse(TypedDict):
type: Literal["openai_tokenizer", "huggingface_tokenizer"]
tokenizer: Any
class LiteLLMBatch(Batch):
_hidden_params: dict = {}
usage: Optional[Usage] = None
def __contains__(self, key):
# Define custom behavior for the 'in' operator
return hasattr(self, key)
def get(self, key, default=None):
# Custom .get() method to access attributes with a default value if the attribute doesn't exist
return getattr(self, key, default)
def __getitem__(self, key):
# Allow dictionary-style access to attributes
return getattr(self, key)
def json(self, **kwargs): # type: ignore
try:
return self.model_dump() # noqa
except Exception:
# if using pydantic v1
return self.dict()

View file

@ -2982,3 +2982,70 @@ def test_json_valid_model_cost_map():
json.loads(json_str)
except json.JSONDecodeError as e:
assert False, f"Invalid JSON format: {str(e)}"
def test_batch_cost_calculator():
args = {
"completion_response": {
"choices": [
{
"content_filter_results": {
"hate": {"filtered": False, "severity": "safe"},
"protected_material_code": {
"filtered": False,
"detected": False,
},
"protected_material_text": {
"filtered": False,
"detected": False,
},
"self_harm": {"filtered": False, "severity": "safe"},
"sexual": {"filtered": False, "severity": "safe"},
"violence": {"filtered": False, "severity": "safe"},
},
"finish_reason": "stop",
"index": 0,
"logprobs": None,
"message": {
"content": 'As of my last update in October 2023, there are eight recognized planets in the solar system. They are:\n\n1. **Mercury** - The closest planet to the Sun, known for its extreme temperature fluctuations.\n2. **Venus** - Similar in size to Earth but with a thick atmosphere rich in carbon dioxide, leading to a greenhouse effect that makes it the hottest planet.\n3. **Earth** - The only planet known to support life, with a diverse environment and liquid water.\n4. **Mars** - Known as the Red Planet, it has the largest volcano and canyon in the solar system and features signs of past water.\n5. **Jupiter** - The largest planet in the solar system, known for its Great Red Spot and numerous moons.\n6. **Saturn** - Famous for its stunning rings, it is a gas giant also known for its extensive moon system.\n7. **Uranus** - An ice giant with a unique tilt, it rotates on its side and has a blue color due to methane in its atmosphere.\n8. **Neptune** - Another ice giant, known for its deep blue color and strong winds, it is the farthest planet from the Sun.\n\nPluto was previously classified as the ninth planet but was reclassified as a "dwarf planet" in 2006 by the International Astronomical Union.',
"refusal": None,
"role": "assistant",
},
}
],
"created": 1741135408,
"id": "chatcmpl-B7X96teepFM4ILP7cm4Ga62eRuV8p",
"model": "gpt-4o-mini-2024-07-18",
"object": "chat.completion",
"prompt_filter_results": [
{
"prompt_index": 0,
"content_filter_results": {
"hate": {"filtered": False, "severity": "safe"},
"jailbreak": {"filtered": False, "detected": False},
"self_harm": {"filtered": False, "severity": "safe"},
"sexual": {"filtered": False, "severity": "safe"},
"violence": {"filtered": False, "severity": "safe"},
},
}
],
"system_fingerprint": "fp_b705f0c291",
"usage": {
"completion_tokens": 278,
"completion_tokens_details": {
"accepted_prediction_tokens": 0,
"audio_tokens": 0,
"reasoning_tokens": 0,
"rejected_prediction_tokens": 0,
},
"prompt_tokens": 20,
"prompt_tokens_details": {"audio_tokens": 0, "cached_tokens": 0},
"total_tokens": 298,
},
},
"model": None,
}
cost = completion_cost(**args)
assert cost > 0