mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 03:34:10 +00:00
Litellm dev 03 04 2025 p3 (#8997)
* fix(core_helpers.py): handle litellm_metadata instead of 'metadata' * feat(batches/): ensure batches logs are written to db makes batches response dict compatible * fix(cost_calculator.py): handle batch response being a dictionary * fix(batches/main.py): modify retrieve endpoints to use @client decorator enables logging to work on retrieve call * fix(batches/main.py): fix retrieve batch response type to be 'dict' compatible * fix(spend_tracking_utils.py): send unique uuid for retrieve batch call type create batch and retrieve batch share the same id * fix(spend_tracking_utils.py): prevent duplicate retrieve batch calls from being double counted * refactor(batches/): refactor cost tracking for batches - do it on retrieve, and within the established litellm_logging pipeline ensures cost is always logged to db * fix: fix linting errors * fix: fix linting error
This commit is contained in:
parent
f2a9d67e05
commit
b43b8dc21c
17 changed files with 314 additions and 219 deletions
|
@ -1,76 +1,16 @@
|
|||
import asyncio
|
||||
import datetime
|
||||
import json
|
||||
import threading
|
||||
from typing import Any, List, Literal, Optional
|
||||
from typing import Any, List, Literal, Tuple
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.constants import (
|
||||
BATCH_STATUS_POLL_INTERVAL_SECONDS,
|
||||
BATCH_STATUS_POLL_MAX_ATTEMPTS,
|
||||
)
|
||||
from litellm.files.main import afile_content
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.types.llms.openai import Batch
|
||||
from litellm.types.utils import StandardLoggingPayload, Usage
|
||||
|
||||
|
||||
async def batches_async_logging(
|
||||
batch_id: str,
|
||||
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
|
||||
logging_obj: Optional[LiteLLMLoggingObj] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Async Job waits for the batch to complete and then logs the completed batch usage - cost, total tokens, prompt tokens, completion tokens
|
||||
|
||||
|
||||
Polls retrieve_batch until it returns a batch with status "completed" or "failed"
|
||||
"""
|
||||
from .main import aretrieve_batch
|
||||
|
||||
verbose_logger.debug(
|
||||
".....in _batches_async_logging... polling retrieve to get batch status"
|
||||
)
|
||||
if logging_obj is None:
|
||||
raise ValueError(
|
||||
"logging_obj is None cannot calculate cost / log batch creation event"
|
||||
)
|
||||
for _ in range(BATCH_STATUS_POLL_MAX_ATTEMPTS):
|
||||
try:
|
||||
start_time = datetime.datetime.now()
|
||||
batch: Batch = await aretrieve_batch(batch_id, custom_llm_provider)
|
||||
verbose_logger.debug(
|
||||
"in _batches_async_logging... batch status= %s", batch.status
|
||||
)
|
||||
|
||||
if batch.status == "completed":
|
||||
end_time = datetime.datetime.now()
|
||||
await _handle_completed_batch(
|
||||
batch=batch,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
logging_obj=logging_obj,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
**kwargs,
|
||||
)
|
||||
break
|
||||
elif batch.status == "failed":
|
||||
pass
|
||||
except Exception as e:
|
||||
verbose_logger.error("error in batches_async_logging", e)
|
||||
await asyncio.sleep(BATCH_STATUS_POLL_INTERVAL_SECONDS)
|
||||
from litellm.types.utils import Usage
|
||||
|
||||
|
||||
async def _handle_completed_batch(
|
||||
batch: Batch,
|
||||
custom_llm_provider: Literal["openai", "azure", "vertex_ai"],
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
start_time: datetime.datetime,
|
||||
end_time: datetime.datetime,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
) -> Tuple[float, Usage]:
|
||||
"""Helper function to process a completed batch and handle logging"""
|
||||
# Get batch results
|
||||
file_content_dictionary = await _get_batch_output_file_content_as_dictionary(
|
||||
|
@ -87,52 +27,7 @@ async def _handle_completed_batch(
|
|||
custom_llm_provider=custom_llm_provider,
|
||||
)
|
||||
|
||||
# Handle logging
|
||||
await _log_completed_batch(
|
||||
logging_obj=logging_obj,
|
||||
batch_usage=batch_usage,
|
||||
batch_cost=batch_cost,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
async def _log_completed_batch(
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
batch_usage: Usage,
|
||||
batch_cost: float,
|
||||
start_time: datetime.datetime,
|
||||
end_time: datetime.datetime,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""Helper function to handle all logging operations for a completed batch"""
|
||||
logging_obj.call_type = "batch_success"
|
||||
|
||||
standard_logging_object = _create_standard_logging_object_for_completed_batch(
|
||||
kwargs=kwargs,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
logging_obj=logging_obj,
|
||||
batch_usage_object=batch_usage,
|
||||
response_cost=batch_cost,
|
||||
)
|
||||
|
||||
logging_obj.model_call_details["standard_logging_object"] = standard_logging_object
|
||||
|
||||
# Launch async and sync logging handlers
|
||||
asyncio.create_task(
|
||||
logging_obj.async_success_handler(
|
||||
result=None,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
cache_hit=None,
|
||||
)
|
||||
)
|
||||
threading.Thread(
|
||||
target=logging_obj.success_handler,
|
||||
args=(None, start_time, end_time),
|
||||
).start()
|
||||
return batch_cost, batch_usage
|
||||
|
||||
|
||||
async def _batch_cost_calculator(
|
||||
|
@ -159,6 +54,8 @@ async def _get_batch_output_file_content_as_dictionary(
|
|||
"""
|
||||
Get the batch output file content as a list of dictionaries
|
||||
"""
|
||||
from litellm.files.main import afile_content
|
||||
|
||||
if custom_llm_provider == "vertex_ai":
|
||||
raise ValueError("Vertex AI does not support file content retrieval")
|
||||
|
||||
|
@ -264,30 +161,3 @@ def _batch_response_was_successful(batch_job_output_file: dict) -> bool:
|
|||
"""
|
||||
_response: dict = batch_job_output_file.get("response", None) or {}
|
||||
return _response.get("status_code", None) == 200
|
||||
|
||||
|
||||
def _create_standard_logging_object_for_completed_batch(
|
||||
kwargs: dict,
|
||||
start_time: datetime.datetime,
|
||||
end_time: datetime.datetime,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
batch_usage_object: Usage,
|
||||
response_cost: float,
|
||||
) -> StandardLoggingPayload:
|
||||
"""
|
||||
Create a standard logging object for a completed batch
|
||||
"""
|
||||
standard_logging_object = logging_obj.model_call_details.get(
|
||||
"standard_logging_object", None
|
||||
)
|
||||
|
||||
if standard_logging_object is None:
|
||||
raise ValueError("unable to create standard logging object for completed batch")
|
||||
|
||||
# Add Completed Batch Job Usage and Response Cost
|
||||
standard_logging_object["call_type"] = "batch_success"
|
||||
standard_logging_object["response_cost"] = response_cost
|
||||
standard_logging_object["total_tokens"] = batch_usage_object.total_tokens
|
||||
standard_logging_object["prompt_tokens"] = batch_usage_object.prompt_tokens
|
||||
standard_logging_object["completion_tokens"] = batch_usage_object.completion_tokens
|
||||
return standard_logging_object
|
||||
|
|
|
@ -31,10 +31,9 @@ from litellm.types.llms.openai import (
|
|||
RetrieveBatchRequest,
|
||||
)
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.types.utils import LiteLLMBatch
|
||||
from litellm.utils import client, get_litellm_params, supports_httpx_timeout
|
||||
|
||||
from .batch_utils import batches_async_logging
|
||||
|
||||
####### ENVIRONMENT VARIABLES ###################
|
||||
openai_batches_instance = OpenAIBatchesAPI()
|
||||
azure_batches_instance = AzureBatchesAPI()
|
||||
|
@ -85,17 +84,6 @@ async def acreate_batch(
|
|||
else:
|
||||
response = init_response
|
||||
|
||||
# Start async logging job
|
||||
if response is not None:
|
||||
asyncio.create_task(
|
||||
batches_async_logging(
|
||||
logging_obj=kwargs.get("litellm_logging_obj", None),
|
||||
batch_id=response.id,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
**kwargs,
|
||||
)
|
||||
)
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
@ -111,7 +99,7 @@ def create_batch(
|
|||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
extra_body: Optional[Dict[str, str]] = None,
|
||||
**kwargs,
|
||||
) -> Union[Batch, Coroutine[Any, Any, Batch]]:
|
||||
) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
|
||||
"""
|
||||
Creates and executes a batch from an uploaded file of request
|
||||
|
||||
|
@ -119,21 +107,26 @@ def create_batch(
|
|||
"""
|
||||
try:
|
||||
optional_params = GenericLiteLLMParams(**kwargs)
|
||||
litellm_call_id = kwargs.get("litellm_call_id", None)
|
||||
proxy_server_request = kwargs.get("proxy_server_request", None)
|
||||
model_info = kwargs.get("model_info", None)
|
||||
_is_async = kwargs.pop("acreate_batch", False) is True
|
||||
litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj", None)
|
||||
### TIMEOUT LOGIC ###
|
||||
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
||||
litellm_params = get_litellm_params(
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
litellm_call_id=kwargs.get("litellm_call_id", None),
|
||||
litellm_trace_id=kwargs.get("litellm_trace_id"),
|
||||
litellm_metadata=kwargs.get("litellm_metadata"),
|
||||
)
|
||||
litellm_logging_obj.update_environment_variables(
|
||||
model=None,
|
||||
user=None,
|
||||
optional_params=optional_params.model_dump(),
|
||||
litellm_params=litellm_params,
|
||||
litellm_params={
|
||||
"litellm_call_id": litellm_call_id,
|
||||
"proxy_server_request": proxy_server_request,
|
||||
"model_info": model_info,
|
||||
"metadata": metadata,
|
||||
"preset_cache_key": None,
|
||||
"stream_response": {},
|
||||
**optional_params.model_dump(exclude_unset=True),
|
||||
},
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
)
|
||||
|
||||
|
@ -261,7 +254,7 @@ def create_batch(
|
|||
response=httpx.Response(
|
||||
status_code=400,
|
||||
content="Unsupported provider",
|
||||
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
|
||||
request=httpx.Request(method="create_batch", url="https://github.com/BerriAI/litellm"), # type: ignore
|
||||
),
|
||||
)
|
||||
return response
|
||||
|
@ -269,6 +262,7 @@ def create_batch(
|
|||
raise e
|
||||
|
||||
|
||||
@client
|
||||
async def aretrieve_batch(
|
||||
batch_id: str,
|
||||
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
|
||||
|
@ -276,7 +270,7 @@ async def aretrieve_batch(
|
|||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
extra_body: Optional[Dict[str, str]] = None,
|
||||
**kwargs,
|
||||
) -> Batch:
|
||||
) -> LiteLLMBatch:
|
||||
"""
|
||||
Async: Retrieves a batch.
|
||||
|
||||
|
@ -310,6 +304,7 @@ async def aretrieve_batch(
|
|||
raise e
|
||||
|
||||
|
||||
@client
|
||||
def retrieve_batch(
|
||||
batch_id: str,
|
||||
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
|
||||
|
@ -317,7 +312,7 @@ def retrieve_batch(
|
|||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
extra_body: Optional[Dict[str, str]] = None,
|
||||
**kwargs,
|
||||
) -> Union[Batch, Coroutine[Any, Any, Batch]]:
|
||||
) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
|
||||
"""
|
||||
Retrieves a batch.
|
||||
|
||||
|
@ -325,9 +320,23 @@ def retrieve_batch(
|
|||
"""
|
||||
try:
|
||||
optional_params = GenericLiteLLMParams(**kwargs)
|
||||
|
||||
litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj", None)
|
||||
### TIMEOUT LOGIC ###
|
||||
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
||||
# set timeout for 10 minutes by default
|
||||
litellm_params = get_litellm_params(
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
litellm_call_id=kwargs.get("litellm_call_id", None),
|
||||
litellm_trace_id=kwargs.get("litellm_trace_id"),
|
||||
litellm_metadata=kwargs.get("litellm_metadata"),
|
||||
)
|
||||
litellm_logging_obj.update_environment_variables(
|
||||
model=None,
|
||||
user=None,
|
||||
optional_params=optional_params.model_dump(),
|
||||
litellm_params=litellm_params,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
)
|
||||
|
||||
if (
|
||||
timeout is not None
|
||||
|
|
|
@ -399,9 +399,12 @@ def _select_model_name_for_cost_calc(
|
|||
if base_model is not None:
|
||||
return_model = base_model
|
||||
|
||||
completion_response_model: Optional[str] = getattr(
|
||||
completion_response, "model", None
|
||||
)
|
||||
completion_response_model: Optional[str] = None
|
||||
if completion_response is not None:
|
||||
if isinstance(completion_response, BaseModel):
|
||||
completion_response_model = getattr(completion_response, "model", None)
|
||||
elif isinstance(completion_response, dict):
|
||||
completion_response_model = completion_response.get("model", None)
|
||||
hidden_params: Optional[dict] = getattr(completion_response, "_hidden_params", None)
|
||||
if completion_response_model is None and hidden_params is not None:
|
||||
if (
|
||||
|
|
|
@ -816,7 +816,7 @@ def file_content(
|
|||
)
|
||||
else:
|
||||
raise litellm.exceptions.BadRequestError(
|
||||
message="LiteLLM doesn't support {} for 'file_content'. Only 'openai' and 'azure' are supported.".format(
|
||||
message="LiteLLM doesn't support {} for 'custom_llm_provider'. Supported providers are 'openai', 'azure', 'vertex_ai'.".format(
|
||||
custom_llm_provider
|
||||
),
|
||||
model="n/a",
|
||||
|
|
|
@ -74,7 +74,16 @@ def get_litellm_metadata_from_kwargs(kwargs: dict):
|
|||
"""
|
||||
Helper to get litellm metadata from all litellm request kwargs
|
||||
"""
|
||||
return kwargs.get("litellm_params", {}).get("metadata", {})
|
||||
litellm_params = kwargs.get("litellm_params", {})
|
||||
if litellm_params:
|
||||
metadata = litellm_params.get("metadata", {})
|
||||
litellm_metadata = litellm_params.get("litellm_metadata", {})
|
||||
if litellm_metadata:
|
||||
return litellm_metadata
|
||||
elif metadata:
|
||||
return metadata
|
||||
|
||||
return {}
|
||||
|
||||
|
||||
# Helper functions used for OTEL logging
|
||||
|
|
|
@ -25,6 +25,7 @@ from litellm import (
|
|||
turn_off_message_logging,
|
||||
)
|
||||
from litellm._logging import _is_debugging_on, verbose_logger
|
||||
from litellm.batches.batch_utils import _handle_completed_batch
|
||||
from litellm.caching.caching import DualCache, InMemoryCache
|
||||
from litellm.caching.caching_handler import LLMCachingHandler
|
||||
from litellm.cost_calculator import _select_model_name_for_cost_calc
|
||||
|
@ -50,6 +51,7 @@ from litellm.types.utils import (
|
|||
CallTypes,
|
||||
EmbeddingResponse,
|
||||
ImageResponse,
|
||||
LiteLLMBatch,
|
||||
LiteLLMLoggingBaseClass,
|
||||
ModelResponse,
|
||||
ModelResponseStream,
|
||||
|
@ -871,6 +873,24 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||
|
||||
return None
|
||||
|
||||
async def _response_cost_calculator_async(
|
||||
self,
|
||||
result: Union[
|
||||
ModelResponse,
|
||||
ModelResponseStream,
|
||||
EmbeddingResponse,
|
||||
ImageResponse,
|
||||
TranscriptionResponse,
|
||||
TextCompletionResponse,
|
||||
HttpxBinaryResponseContent,
|
||||
RerankResponse,
|
||||
Batch,
|
||||
FineTuningJob,
|
||||
],
|
||||
cache_hit: Optional[bool] = None,
|
||||
) -> Optional[float]:
|
||||
return self._response_cost_calculator(result=result, cache_hit=cache_hit)
|
||||
|
||||
def should_run_callback(
|
||||
self, callback: litellm.CALLBACK_TYPES, litellm_params: dict, event_hook: str
|
||||
) -> bool:
|
||||
|
@ -928,8 +948,8 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||
or isinstance(result, TextCompletionResponse)
|
||||
or isinstance(result, HttpxBinaryResponseContent) # tts
|
||||
or isinstance(result, RerankResponse)
|
||||
or isinstance(result, Batch)
|
||||
or isinstance(result, FineTuningJob)
|
||||
or isinstance(result, LiteLLMBatch)
|
||||
):
|
||||
## HIDDEN PARAMS ##
|
||||
hidden_params = getattr(result, "_hidden_params", {})
|
||||
|
@ -1525,6 +1545,19 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||
print_verbose(
|
||||
"Logging Details LiteLLM-Async Success Call, cache_hit={}".format(cache_hit)
|
||||
)
|
||||
|
||||
## CALCULATE COST FOR BATCH JOBS
|
||||
if self.call_type == CallTypes.aretrieve_batch.value and isinstance(
|
||||
result, LiteLLMBatch
|
||||
):
|
||||
|
||||
response_cost, batch_usage = await _handle_completed_batch(
|
||||
batch=result, custom_llm_provider=self.custom_llm_provider
|
||||
)
|
||||
|
||||
result._hidden_params["response_cost"] = response_cost
|
||||
result.usage = batch_usage
|
||||
|
||||
start_time, end_time, result = self._success_handler_helper_fn(
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
|
@ -1532,6 +1565,7 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||
cache_hit=cache_hit,
|
||||
standard_logging_object=kwargs.get("standard_logging_object", None),
|
||||
)
|
||||
|
||||
## BUILD COMPLETE STREAMED RESPONSE
|
||||
if "async_complete_streaming_response" in self.model_call_details:
|
||||
return # break out of this.
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
Azure Batches API Handler
|
||||
"""
|
||||
|
||||
from typing import Any, Coroutine, Optional, Union
|
||||
from typing import Any, Coroutine, Optional, Union, cast
|
||||
|
||||
import httpx
|
||||
|
||||
|
@ -14,6 +14,7 @@ from litellm.types.llms.openai import (
|
|||
CreateBatchRequest,
|
||||
RetrieveBatchRequest,
|
||||
)
|
||||
from litellm.types.utils import LiteLLMBatch
|
||||
|
||||
|
||||
class AzureBatchesAPI:
|
||||
|
@ -64,9 +65,9 @@ class AzureBatchesAPI:
|
|||
self,
|
||||
create_batch_data: CreateBatchRequest,
|
||||
azure_client: AsyncAzureOpenAI,
|
||||
) -> Batch:
|
||||
) -> LiteLLMBatch:
|
||||
response = await azure_client.batches.create(**create_batch_data)
|
||||
return response
|
||||
return LiteLLMBatch(**response.model_dump())
|
||||
|
||||
def create_batch(
|
||||
self,
|
||||
|
@ -78,7 +79,7 @@ class AzureBatchesAPI:
|
|||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
|
||||
) -> Union[Batch, Coroutine[Any, Any, Batch]]:
|
||||
) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
|
||||
azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
|
||||
self.get_azure_openai_client(
|
||||
api_key=api_key,
|
||||
|
@ -103,16 +104,16 @@ class AzureBatchesAPI:
|
|||
return self.acreate_batch( # type: ignore
|
||||
create_batch_data=create_batch_data, azure_client=azure_client
|
||||
)
|
||||
response = azure_client.batches.create(**create_batch_data)
|
||||
return response
|
||||
response = cast(AzureOpenAI, azure_client).batches.create(**create_batch_data)
|
||||
return LiteLLMBatch(**response.model_dump())
|
||||
|
||||
async def aretrieve_batch(
|
||||
self,
|
||||
retrieve_batch_data: RetrieveBatchRequest,
|
||||
client: AsyncAzureOpenAI,
|
||||
) -> Batch:
|
||||
) -> LiteLLMBatch:
|
||||
response = await client.batches.retrieve(**retrieve_batch_data)
|
||||
return response
|
||||
return LiteLLMBatch(**response.model_dump())
|
||||
|
||||
def retrieve_batch(
|
||||
self,
|
||||
|
@ -149,8 +150,10 @@ class AzureBatchesAPI:
|
|||
return self.aretrieve_batch( # type: ignore
|
||||
retrieve_batch_data=retrieve_batch_data, client=azure_client
|
||||
)
|
||||
response = azure_client.batches.retrieve(**retrieve_batch_data)
|
||||
return response
|
||||
response = cast(AzureOpenAI, azure_client).batches.retrieve(
|
||||
**retrieve_batch_data
|
||||
)
|
||||
return LiteLLMBatch(**response.model_dump())
|
||||
|
||||
async def acancel_batch(
|
||||
self,
|
||||
|
|
|
@ -37,6 +37,7 @@ from litellm.llms.custom_httpx.http_handler import _DEFAULT_TTL_FOR_HTTPX_CLIENT
|
|||
from litellm.types.utils import (
|
||||
EmbeddingResponse,
|
||||
ImageResponse,
|
||||
LiteLLMBatch,
|
||||
ModelResponse,
|
||||
ModelResponseStream,
|
||||
)
|
||||
|
@ -1755,9 +1756,9 @@ class OpenAIBatchesAPI(BaseLLM):
|
|||
self,
|
||||
create_batch_data: CreateBatchRequest,
|
||||
openai_client: AsyncOpenAI,
|
||||
) -> Batch:
|
||||
) -> LiteLLMBatch:
|
||||
response = await openai_client.batches.create(**create_batch_data)
|
||||
return response
|
||||
return LiteLLMBatch(**response.model_dump())
|
||||
|
||||
def create_batch(
|
||||
self,
|
||||
|
@ -1769,7 +1770,7 @@ class OpenAIBatchesAPI(BaseLLM):
|
|||
max_retries: Optional[int],
|
||||
organization: Optional[str],
|
||||
client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
|
||||
) -> Union[Batch, Coroutine[Any, Any, Batch]]:
|
||||
) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
|
||||
openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
|
@ -1792,17 +1793,18 @@ class OpenAIBatchesAPI(BaseLLM):
|
|||
return self.acreate_batch( # type: ignore
|
||||
create_batch_data=create_batch_data, openai_client=openai_client
|
||||
)
|
||||
response = openai_client.batches.create(**create_batch_data)
|
||||
return response
|
||||
response = cast(OpenAI, openai_client).batches.create(**create_batch_data)
|
||||
|
||||
return LiteLLMBatch(**response.model_dump())
|
||||
|
||||
async def aretrieve_batch(
|
||||
self,
|
||||
retrieve_batch_data: RetrieveBatchRequest,
|
||||
openai_client: AsyncOpenAI,
|
||||
) -> Batch:
|
||||
) -> LiteLLMBatch:
|
||||
verbose_logger.debug("retrieving batch, args= %s", retrieve_batch_data)
|
||||
response = await openai_client.batches.retrieve(**retrieve_batch_data)
|
||||
return response
|
||||
return LiteLLMBatch(**response.model_dump())
|
||||
|
||||
def retrieve_batch(
|
||||
self,
|
||||
|
@ -1837,8 +1839,8 @@ class OpenAIBatchesAPI(BaseLLM):
|
|||
return self.aretrieve_batch( # type: ignore
|
||||
retrieve_batch_data=retrieve_batch_data, openai_client=openai_client
|
||||
)
|
||||
response = openai_client.batches.retrieve(**retrieve_batch_data)
|
||||
return response
|
||||
response = cast(OpenAI, openai_client).batches.retrieve(**retrieve_batch_data)
|
||||
return LiteLLMBatch(**response.model_dump())
|
||||
|
||||
async def acancel_batch(
|
||||
self,
|
||||
|
|
|
@ -9,11 +9,12 @@ from litellm.llms.custom_httpx.http_handler import (
|
|||
get_async_httpx_client,
|
||||
)
|
||||
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
|
||||
from litellm.types.llms.openai import Batch, CreateBatchRequest
|
||||
from litellm.types.llms.openai import CreateBatchRequest
|
||||
from litellm.types.llms.vertex_ai import (
|
||||
VERTEX_CREDENTIALS_TYPES,
|
||||
VertexAIBatchPredictionJob,
|
||||
)
|
||||
from litellm.types.utils import LiteLLMBatch
|
||||
|
||||
from .transformation import VertexAIBatchTransformation
|
||||
|
||||
|
@ -33,7 +34,7 @@ class VertexAIBatchPrediction(VertexLLM):
|
|||
vertex_location: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
) -> Union[Batch, Coroutine[Any, Any, Batch]]:
|
||||
) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
|
||||
|
||||
sync_handler = _get_httpx_client()
|
||||
|
||||
|
@ -101,7 +102,7 @@ class VertexAIBatchPrediction(VertexLLM):
|
|||
vertex_batch_request: VertexAIBatchPredictionJob,
|
||||
api_base: str,
|
||||
headers: Dict[str, str],
|
||||
) -> Batch:
|
||||
) -> LiteLLMBatch:
|
||||
client = get_async_httpx_client(
|
||||
llm_provider=litellm.LlmProviders.VERTEX_AI,
|
||||
)
|
||||
|
@ -138,7 +139,7 @@ class VertexAIBatchPrediction(VertexLLM):
|
|||
vertex_location: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
) -> Union[Batch, Coroutine[Any, Any, Batch]]:
|
||||
) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
|
||||
sync_handler = _get_httpx_client()
|
||||
|
||||
access_token, project_id = self._ensure_access_token(
|
||||
|
@ -199,7 +200,7 @@ class VertexAIBatchPrediction(VertexLLM):
|
|||
self,
|
||||
api_base: str,
|
||||
headers: Dict[str, str],
|
||||
) -> Batch:
|
||||
) -> LiteLLMBatch:
|
||||
client = get_async_httpx_client(
|
||||
llm_provider=litellm.LlmProviders.VERTEX_AI,
|
||||
)
|
||||
|
|
|
@ -4,8 +4,9 @@ from typing import Dict
|
|||
from litellm.llms.vertex_ai.common_utils import (
|
||||
_convert_vertex_datetime_to_openai_datetime,
|
||||
)
|
||||
from litellm.types.llms.openai import Batch, BatchJobStatus, CreateBatchRequest
|
||||
from litellm.types.llms.openai import BatchJobStatus, CreateBatchRequest
|
||||
from litellm.types.llms.vertex_ai import *
|
||||
from litellm.types.utils import LiteLLMBatch
|
||||
|
||||
|
||||
class VertexAIBatchTransformation:
|
||||
|
@ -47,8 +48,8 @@ class VertexAIBatchTransformation:
|
|||
@classmethod
|
||||
def transform_vertex_ai_batch_response_to_openai_batch_response(
|
||||
cls, response: VertexBatchPredictionResponse
|
||||
) -> Batch:
|
||||
return Batch(
|
||||
) -> LiteLLMBatch:
|
||||
return LiteLLMBatch(
|
||||
id=cls._get_batch_id_from_vertex_ai_batch_response(response),
|
||||
completion_window="24hrs",
|
||||
created_at=_convert_vertex_datetime_to_openai_datetime(
|
||||
|
|
|
@ -1,9 +1,13 @@
|
|||
model_list:
|
||||
- model_name: my-langfuse-model
|
||||
- model_name: openai/gpt-4o
|
||||
litellm_params:
|
||||
model: langfuse/openai-model
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
- model_name: openai-model
|
||||
litellm_params:
|
||||
model: openai/gpt-3.5-turbo
|
||||
model: openai/gpt-4o
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
|
||||
files_settings:
|
||||
- custom_llm_provider: azure
|
||||
api_base: os.environ/AZURE_API_BASE
|
||||
api_key: os.environ/AZURE_API_KEY
|
||||
|
||||
general_settings:
|
||||
store_prompts_in_spend_logs: true
|
|
@ -2,10 +2,10 @@
|
|||
|
||||
# /v1/batches Endpoints
|
||||
|
||||
import asyncio
|
||||
|
||||
######################################################################
|
||||
from typing import Dict, Optional
|
||||
import asyncio
|
||||
from typing import Dict, Optional, cast
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Path, Request, Response
|
||||
|
||||
|
@ -199,8 +199,11 @@ async def retrieve_batch(
|
|||
```
|
||||
"""
|
||||
from litellm.proxy.proxy_server import (
|
||||
add_litellm_data_to_request,
|
||||
general_settings,
|
||||
get_custom_headers,
|
||||
llm_router,
|
||||
proxy_config,
|
||||
proxy_logging_obj,
|
||||
version,
|
||||
)
|
||||
|
@ -212,6 +215,23 @@ async def retrieve_batch(
|
|||
batch_id=batch_id,
|
||||
)
|
||||
|
||||
data = cast(dict, _retrieve_batch_request)
|
||||
|
||||
# setup logging
|
||||
data["litellm_call_id"] = request.headers.get(
|
||||
"x-litellm-call-id", str(uuid.uuid4())
|
||||
)
|
||||
|
||||
# Include original request and headers in the data
|
||||
data = await add_litellm_data_to_request(
|
||||
data=data,
|
||||
request=request,
|
||||
general_settings=general_settings,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
version=version,
|
||||
proxy_config=proxy_config,
|
||||
)
|
||||
|
||||
if litellm.enable_loadbalancing_on_batch_endpoints is True:
|
||||
if llm_router is None:
|
||||
raise HTTPException(
|
||||
|
@ -221,7 +241,7 @@ async def retrieve_batch(
|
|||
},
|
||||
)
|
||||
|
||||
response = await llm_router.aretrieve_batch(**_retrieve_batch_request) # type: ignore
|
||||
response = await llm_router.aretrieve_batch(**data) # type: ignore
|
||||
else:
|
||||
custom_llm_provider = (
|
||||
provider
|
||||
|
@ -229,7 +249,7 @@ async def retrieve_batch(
|
|||
or "openai"
|
||||
)
|
||||
response = await litellm.aretrieve_batch(
|
||||
custom_llm_provider=custom_llm_provider, **_retrieve_batch_request # type: ignore
|
||||
custom_llm_provider=custom_llm_provider, **data # type: ignore
|
||||
)
|
||||
|
||||
### ALERTING ###
|
||||
|
|
|
@ -11,6 +11,7 @@ import uuid
|
|||
from typing import TYPE_CHECKING, List, Optional, Union, cast
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from fastapi.responses import RedirectResponse
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
|
|
|
@ -955,6 +955,9 @@ def _set_spend_logs_payload(
|
|||
prisma_client: PrismaClient,
|
||||
spend_logs_url: Optional[str] = None,
|
||||
):
|
||||
verbose_proxy_logger.info(
|
||||
"Writing spend log to db - request_id: {}".format(payload.get("request_id"))
|
||||
)
|
||||
if prisma_client is not None and spend_logs_url is not None:
|
||||
if isinstance(payload["startTime"], datetime):
|
||||
payload["startTime"] = payload["startTime"].isoformat()
|
||||
|
@ -1056,7 +1059,6 @@ async def update_database( # noqa: PLR0915
|
|||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
|
||||
payload["spend"] = response_cost
|
||||
prisma_client = _set_spend_logs_payload(
|
||||
payload=payload,
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
import hashlib
|
||||
import json
|
||||
import secrets
|
||||
from datetime import datetime
|
||||
from datetime import datetime as dt
|
||||
from datetime import timezone
|
||||
from typing import List, Optional, cast
|
||||
from typing import Any, List, Optional, cast
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
@ -69,6 +70,42 @@ def _get_spend_logs_metadata(
|
|||
return clean_metadata
|
||||
|
||||
|
||||
def generate_hash_from_response(response_obj: Any) -> str:
|
||||
"""
|
||||
Generate a stable hash from a response object.
|
||||
|
||||
Args:
|
||||
response_obj: The response object to hash (can be dict, list, etc.)
|
||||
|
||||
Returns:
|
||||
A hex string representation of the MD5 hash
|
||||
"""
|
||||
try:
|
||||
# Create a stable JSON string of the entire response object
|
||||
# Sort keys to ensure consistent ordering
|
||||
json_str = json.dumps(response_obj, sort_keys=True)
|
||||
|
||||
# Generate a hash of the response object
|
||||
unique_hash = hashlib.md5(json_str.encode()).hexdigest()
|
||||
return unique_hash
|
||||
except Exception:
|
||||
# Return a fallback hash if serialization fails
|
||||
return hashlib.md5(str(response_obj).encode()).hexdigest()
|
||||
|
||||
|
||||
def get_spend_logs_id(
|
||||
call_type: str, response_obj: dict, kwargs: dict
|
||||
) -> Optional[str]:
|
||||
if call_type == "aretrieve_batch":
|
||||
# Generate a hash from the response object
|
||||
id: Optional[str] = generate_hash_from_response(response_obj)
|
||||
else:
|
||||
id = cast(Optional[str], response_obj.get("id")) or cast(
|
||||
Optional[str], kwargs.get("litellm_call_id")
|
||||
)
|
||||
return id
|
||||
|
||||
|
||||
def get_logging_payload( # noqa: PLR0915
|
||||
kwargs, response_obj, start_time, end_time
|
||||
) -> SpendLogsPayload:
|
||||
|
@ -94,7 +131,15 @@ def get_logging_payload( # noqa: PLR0915
|
|||
usage = cast(dict, response_obj).get("usage", None) or {}
|
||||
if isinstance(usage, litellm.Usage):
|
||||
usage = dict(usage)
|
||||
id = cast(dict, response_obj).get("id") or kwargs.get("litellm_call_id")
|
||||
|
||||
if isinstance(response_obj, dict):
|
||||
response_obj_dict = response_obj
|
||||
elif isinstance(response_obj, BaseModel):
|
||||
response_obj_dict = response_obj.model_dump()
|
||||
else:
|
||||
response_obj_dict = {}
|
||||
|
||||
id = get_spend_logs_id(call_type or "acompletion", response_obj_dict, kwargs)
|
||||
standard_logging_payload = cast(
|
||||
Optional[StandardLoggingPayload], kwargs.get("standard_logging_object", None)
|
||||
)
|
||||
|
@ -177,14 +222,8 @@ def get_logging_payload( # noqa: PLR0915
|
|||
endTime=_ensure_datetime_utc(end_time),
|
||||
completionStartTime=_ensure_datetime_utc(completion_start_time),
|
||||
model=kwargs.get("model", "") or "",
|
||||
user=kwargs.get("litellm_params", {})
|
||||
.get("metadata", {})
|
||||
.get("user_api_key_user_id", "")
|
||||
or "",
|
||||
team_id=kwargs.get("litellm_params", {})
|
||||
.get("metadata", {})
|
||||
.get("user_api_key_team_id", "")
|
||||
or "",
|
||||
user=metadata.get("user_api_key_user_id", "") or "",
|
||||
team_id=metadata.get("user_api_key_team_id", "") or "",
|
||||
metadata=json.dumps(clean_metadata),
|
||||
cache_key=cache_key,
|
||||
spend=kwargs.get("response_cost", 0),
|
||||
|
@ -314,10 +353,13 @@ def _add_proxy_server_request_to_metadata(
|
|||
Only store if _should_store_prompts_and_responses_in_spend_logs() is True
|
||||
"""
|
||||
if _should_store_prompts_and_responses_in_spend_logs():
|
||||
_proxy_server_request = litellm_params.get("proxy_server_request", {})
|
||||
_request_body = _proxy_server_request.get("body", {}) or {}
|
||||
_request_body_json_str = json.dumps(_request_body, default=str)
|
||||
metadata["proxy_server_request"] = _request_body_json_str
|
||||
_proxy_server_request = cast(
|
||||
Optional[dict], litellm_params.get("proxy_server_request", {})
|
||||
)
|
||||
if _proxy_server_request is not None:
|
||||
_request_body = _proxy_server_request.get("body", {}) or {}
|
||||
_request_body_json_str = json.dumps(_request_body, default=str)
|
||||
metadata["proxy_server_request"] = _request_body_json_str
|
||||
return metadata
|
||||
|
||||
|
||||
|
|
|
@ -24,6 +24,7 @@ from typing_extensions import Callable, Dict, Required, TypedDict, override
|
|||
from ..litellm_core_utils.core_helpers import map_finish_reason
|
||||
from .guardrails import GuardrailEventHooks
|
||||
from .llms.openai import (
|
||||
Batch,
|
||||
ChatCompletionThinkingBlock,
|
||||
ChatCompletionToolCallChunk,
|
||||
ChatCompletionUsageBlock,
|
||||
|
@ -182,6 +183,8 @@ class CallTypes(Enum):
|
|||
arealtime = "_arealtime"
|
||||
create_batch = "create_batch"
|
||||
acreate_batch = "acreate_batch"
|
||||
aretrieve_batch = "aretrieve_batch"
|
||||
retrieve_batch = "retrieve_batch"
|
||||
pass_through = "pass_through_endpoint"
|
||||
|
||||
|
||||
|
@ -1963,3 +1966,27 @@ class ProviderSpecificHeader(TypedDict):
|
|||
class SelectTokenizerResponse(TypedDict):
|
||||
type: Literal["openai_tokenizer", "huggingface_tokenizer"]
|
||||
tokenizer: Any
|
||||
|
||||
|
||||
class LiteLLMBatch(Batch):
|
||||
_hidden_params: dict = {}
|
||||
usage: Optional[Usage] = None
|
||||
|
||||
def __contains__(self, key):
|
||||
# Define custom behavior for the 'in' operator
|
||||
return hasattr(self, key)
|
||||
|
||||
def get(self, key, default=None):
|
||||
# Custom .get() method to access attributes with a default value if the attribute doesn't exist
|
||||
return getattr(self, key, default)
|
||||
|
||||
def __getitem__(self, key):
|
||||
# Allow dictionary-style access to attributes
|
||||
return getattr(self, key)
|
||||
|
||||
def json(self, **kwargs): # type: ignore
|
||||
try:
|
||||
return self.model_dump() # noqa
|
||||
except Exception:
|
||||
# if using pydantic v1
|
||||
return self.dict()
|
||||
|
|
|
@ -2982,3 +2982,70 @@ def test_json_valid_model_cost_map():
|
|||
json.loads(json_str)
|
||||
except json.JSONDecodeError as e:
|
||||
assert False, f"Invalid JSON format: {str(e)}"
|
||||
|
||||
|
||||
def test_batch_cost_calculator():
|
||||
|
||||
args = {
|
||||
"completion_response": {
|
||||
"choices": [
|
||||
{
|
||||
"content_filter_results": {
|
||||
"hate": {"filtered": False, "severity": "safe"},
|
||||
"protected_material_code": {
|
||||
"filtered": False,
|
||||
"detected": False,
|
||||
},
|
||||
"protected_material_text": {
|
||||
"filtered": False,
|
||||
"detected": False,
|
||||
},
|
||||
"self_harm": {"filtered": False, "severity": "safe"},
|
||||
"sexual": {"filtered": False, "severity": "safe"},
|
||||
"violence": {"filtered": False, "severity": "safe"},
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
"index": 0,
|
||||
"logprobs": None,
|
||||
"message": {
|
||||
"content": 'As of my last update in October 2023, there are eight recognized planets in the solar system. They are:\n\n1. **Mercury** - The closest planet to the Sun, known for its extreme temperature fluctuations.\n2. **Venus** - Similar in size to Earth but with a thick atmosphere rich in carbon dioxide, leading to a greenhouse effect that makes it the hottest planet.\n3. **Earth** - The only planet known to support life, with a diverse environment and liquid water.\n4. **Mars** - Known as the Red Planet, it has the largest volcano and canyon in the solar system and features signs of past water.\n5. **Jupiter** - The largest planet in the solar system, known for its Great Red Spot and numerous moons.\n6. **Saturn** - Famous for its stunning rings, it is a gas giant also known for its extensive moon system.\n7. **Uranus** - An ice giant with a unique tilt, it rotates on its side and has a blue color due to methane in its atmosphere.\n8. **Neptune** - Another ice giant, known for its deep blue color and strong winds, it is the farthest planet from the Sun.\n\nPluto was previously classified as the ninth planet but was reclassified as a "dwarf planet" in 2006 by the International Astronomical Union.',
|
||||
"refusal": None,
|
||||
"role": "assistant",
|
||||
},
|
||||
}
|
||||
],
|
||||
"created": 1741135408,
|
||||
"id": "chatcmpl-B7X96teepFM4ILP7cm4Ga62eRuV8p",
|
||||
"model": "gpt-4o-mini-2024-07-18",
|
||||
"object": "chat.completion",
|
||||
"prompt_filter_results": [
|
||||
{
|
||||
"prompt_index": 0,
|
||||
"content_filter_results": {
|
||||
"hate": {"filtered": False, "severity": "safe"},
|
||||
"jailbreak": {"filtered": False, "detected": False},
|
||||
"self_harm": {"filtered": False, "severity": "safe"},
|
||||
"sexual": {"filtered": False, "severity": "safe"},
|
||||
"violence": {"filtered": False, "severity": "safe"},
|
||||
},
|
||||
}
|
||||
],
|
||||
"system_fingerprint": "fp_b705f0c291",
|
||||
"usage": {
|
||||
"completion_tokens": 278,
|
||||
"completion_tokens_details": {
|
||||
"accepted_prediction_tokens": 0,
|
||||
"audio_tokens": 0,
|
||||
"reasoning_tokens": 0,
|
||||
"rejected_prediction_tokens": 0,
|
||||
},
|
||||
"prompt_tokens": 20,
|
||||
"prompt_tokens_details": {"audio_tokens": 0, "cached_tokens": 0},
|
||||
"total_tokens": 298,
|
||||
},
|
||||
},
|
||||
"model": None,
|
||||
}
|
||||
|
||||
cost = completion_cost(**args)
|
||||
assert cost > 0
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue