diff --git a/litellm/batches/batch_utils.py b/litellm/batches/batch_utils.py
index f24eda0432..a2e66cdd62 100644
--- a/litellm/batches/batch_utils.py
+++ b/litellm/batches/batch_utils.py
@@ -1,76 +1,16 @@
-import asyncio
-import datetime
 import json
-import threading
-from typing import Any, List, Literal, Optional
+from typing import Any, List, Literal, Tuple
 
 import litellm
 from litellm._logging import verbose_logger
-from litellm.constants import (
-    BATCH_STATUS_POLL_INTERVAL_SECONDS,
-    BATCH_STATUS_POLL_MAX_ATTEMPTS,
-)
-from litellm.files.main import afile_content
-from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
 from litellm.types.llms.openai import Batch
-from litellm.types.utils import StandardLoggingPayload, Usage
-
-
-async def batches_async_logging(
-    batch_id: str,
-    custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
-    logging_obj: Optional[LiteLLMLoggingObj] = None,
-    **kwargs,
-):
-    """
-    Async Job waits for the batch to complete and then logs the completed batch usage - cost, total tokens, prompt tokens, completion tokens
-
-
-    Polls retrieve_batch until it returns a batch with status "completed" or "failed"
-    """
-    from .main import aretrieve_batch
-
-    verbose_logger.debug(
-        ".....in _batches_async_logging... polling retrieve to get batch status"
-    )
-    if logging_obj is None:
-        raise ValueError(
-            "logging_obj is None cannot calculate cost / log batch creation event"
-        )
-    for _ in range(BATCH_STATUS_POLL_MAX_ATTEMPTS):
-        try:
-            start_time = datetime.datetime.now()
-            batch: Batch = await aretrieve_batch(batch_id, custom_llm_provider)
-            verbose_logger.debug(
-                "in _batches_async_logging... batch status= %s", batch.status
-            )
-
-            if batch.status == "completed":
-                end_time = datetime.datetime.now()
-                await _handle_completed_batch(
-                    batch=batch,
-                    custom_llm_provider=custom_llm_provider,
-                    logging_obj=logging_obj,
-                    start_time=start_time,
-                    end_time=end_time,
-                    **kwargs,
-                )
-                break
-            elif batch.status == "failed":
-                pass
-        except Exception as e:
-            verbose_logger.error("error in batches_async_logging", e)
-        await asyncio.sleep(BATCH_STATUS_POLL_INTERVAL_SECONDS)
+from litellm.types.utils import Usage
 
 
 async def _handle_completed_batch(
     batch: Batch,
     custom_llm_provider: Literal["openai", "azure", "vertex_ai"],
-    logging_obj: LiteLLMLoggingObj,
-    start_time: datetime.datetime,
-    end_time: datetime.datetime,
-    **kwargs,
-) -> None:
+) -> Tuple[float, Usage]:
     """Helper function to process a completed batch and handle logging"""
     # Get batch results
     file_content_dictionary = await _get_batch_output_file_content_as_dictionary(
@@ -87,52 +27,7 @@ async def _handle_completed_batch(
         custom_llm_provider=custom_llm_provider,
     )
 
-    # Handle logging
-    await _log_completed_batch(
-        logging_obj=logging_obj,
-        batch_usage=batch_usage,
-        batch_cost=batch_cost,
-        start_time=start_time,
-        end_time=end_time,
-        **kwargs,
-    )
-
-
-async def _log_completed_batch(
-    logging_obj: LiteLLMLoggingObj,
-    batch_usage: Usage,
-    batch_cost: float,
-    start_time: datetime.datetime,
-    end_time: datetime.datetime,
-    **kwargs,
-) -> None:
-    """Helper function to handle all logging operations for a completed batch"""
-    logging_obj.call_type = "batch_success"
-
-    standard_logging_object = _create_standard_logging_object_for_completed_batch(
-        kwargs=kwargs,
-        start_time=start_time,
-        end_time=end_time,
-        logging_obj=logging_obj,
-        batch_usage_object=batch_usage,
-        response_cost=batch_cost,
-    )
-
-    logging_obj.model_call_details["standard_logging_object"] = standard_logging_object
-
-    # Launch async and sync logging handlers
-    asyncio.create_task(
-        logging_obj.async_success_handler(
-            result=None,
-            start_time=start_time,
-            end_time=end_time,
-            cache_hit=None,
-        )
-    )
-    threading.Thread(
-        target=logging_obj.success_handler,
-        args=(None, start_time, end_time),
-    ).start()
+    return batch_cost, batch_usage
 
 
 async def _batch_cost_calculator(
@@ -159,6 +54,8 @@ async def _get_batch_output_file_content_as_dictionary(
     """
     Get the batch output file content as a list of dictionaries
     """
+    from litellm.files.main import afile_content
+
     if custom_llm_provider == "vertex_ai":
         raise ValueError("Vertex AI does not support file content retrieval")
 
@@ -264,30 +161,3 @@ def _batch_response_was_successful(batch_job_output_file: dict) -> bool:
     """
     _response: dict = batch_job_output_file.get("response", None) or {}
     return _response.get("status_code", None) == 200
-
-
-def _create_standard_logging_object_for_completed_batch(
-    kwargs: dict,
-    start_time: datetime.datetime,
-    end_time: datetime.datetime,
-    logging_obj: LiteLLMLoggingObj,
-    batch_usage_object: Usage,
-    response_cost: float,
-) -> StandardLoggingPayload:
-    """
-    Create a standard logging object for a completed batch
-    """
-    standard_logging_object = logging_obj.model_call_details.get(
-        "standard_logging_object", None
-    )
-
-    if standard_logging_object is None:
-        raise ValueError("unable to create standard logging object for completed batch")
-
-    # Add Completed Batch Job Usage and Response Cost
-    standard_logging_object["call_type"] = "batch_success"
-    standard_logging_object["response_cost"] = response_cost
-    standard_logging_object["total_tokens"] = batch_usage_object.total_tokens
-    standard_logging_object["prompt_tokens"] = batch_usage_object.prompt_tokens
-    standard_logging_object["completion_tokens"] = batch_usage_object.completion_tokens
-    return standard_logging_object
diff --git a/litellm/batches/main.py b/litellm/batches/main.py
index 32428c9c18..2f4800043c 100644
--- a/litellm/batches/main.py
+++ b/litellm/batches/main.py
@@ -31,10 +31,9 @@ from litellm.types.llms.openai import (
     RetrieveBatchRequest,
 )
 from litellm.types.router import GenericLiteLLMParams
+from litellm.types.utils import LiteLLMBatch
 from litellm.utils import client, get_litellm_params, supports_httpx_timeout
 
-from .batch_utils import batches_async_logging
-
 ####### ENVIRONMENT VARIABLES ###################
 openai_batches_instance = OpenAIBatchesAPI()
 azure_batches_instance = AzureBatchesAPI()
@@ -85,17 +84,6 @@ async def acreate_batch(
         else:
             response = init_response
 
-        # Start async logging job
-        if response is not None:
-            asyncio.create_task(
-                batches_async_logging(
-                    logging_obj=kwargs.get("litellm_logging_obj", None),
-                    batch_id=response.id,
-                    custom_llm_provider=custom_llm_provider,
-                    **kwargs,
-                )
-            )
-
         return response
     except Exception as e:
         raise e
@@ -111,7 +99,7 @@ def create_batch(
    extra_headers: Optional[Dict[str, str]] = None,
    extra_body: Optional[Dict[str, str]] = None,
    **kwargs,
-) -> Union[Batch, Coroutine[Any, Any, Batch]]:
+) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
     """
     Creates and executes a batch from an uploaded file of request
 
@@ -119,21 +107,26 @@ def create_batch(
     """
     try:
         optional_params = GenericLiteLLMParams(**kwargs)
+        litellm_call_id = kwargs.get("litellm_call_id", None)
+        proxy_server_request = kwargs.get("proxy_server_request", None)
+        model_info = kwargs.get("model_info", None)
         _is_async = kwargs.pop("acreate_batch", False) is True
         litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj", None)
 
         ### TIMEOUT LOGIC ###
         timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
-        litellm_params = get_litellm_params(
-            custom_llm_provider=custom_llm_provider,
-            litellm_call_id=kwargs.get("litellm_call_id", None),
-            litellm_trace_id=kwargs.get("litellm_trace_id"),
-            litellm_metadata=kwargs.get("litellm_metadata"),
-        )
         litellm_logging_obj.update_environment_variables(
             model=None,
             user=None,
             optional_params=optional_params.model_dump(),
-            litellm_params=litellm_params,
+            litellm_params={
+                "litellm_call_id": litellm_call_id,
+                "proxy_server_request": proxy_server_request,
+                "model_info": model_info,
+                "metadata": metadata,
+                "preset_cache_key": None,
+                "stream_response": {},
+                **optional_params.model_dump(exclude_unset=True),
+            },
             custom_llm_provider=custom_llm_provider,
         )
 
@@ -261,7 +254,7 @@ def create_batch(
                 response=httpx.Response(
                     status_code=400,
                     content="Unsupported provider",
-                    request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"),  # type: ignore
+                    request=httpx.Request(method="create_batch", url="https://github.com/BerriAI/litellm"),  # type: ignore
                 ),
             )
         return response
@@ -269,6 +262,7 @@ def create_batch(
         raise e
 
 
+@client
 async def aretrieve_batch(
     batch_id: str,
     custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
@@ -276,7 +270,7 @@ async def aretrieve_batch(
     extra_headers: Optional[Dict[str, str]] = None,
     extra_body: Optional[Dict[str, str]] = None,
     **kwargs,
-) -> Batch:
+) -> LiteLLMBatch:
     """
     Async: Retrieves a batch.
 
@@ -310,6 +304,7 @@ async def aretrieve_batch(
         raise e
 
 
+@client
 def retrieve_batch(
     batch_id: str,
     custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
@@ -317,7 +312,7 @@ def retrieve_batch(
     extra_headers: Optional[Dict[str, str]] = None,
     extra_body: Optional[Dict[str, str]] = None,
     **kwargs,
-) -> Union[Batch, Coroutine[Any, Any, Batch]]:
+) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
     """
     Retrieves a batch.
 
@@ -325,9 +320,23 @@ def retrieve_batch(
     """
     try:
         optional_params = GenericLiteLLMParams(**kwargs)
+
+        litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj", None)
         ### TIMEOUT LOGIC ###
         timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
-        # set timeout for 10 minutes by default
+        litellm_params = get_litellm_params(
+            custom_llm_provider=custom_llm_provider,
+            litellm_call_id=kwargs.get("litellm_call_id", None),
+            litellm_trace_id=kwargs.get("litellm_trace_id"),
+            litellm_metadata=kwargs.get("litellm_metadata"),
+        )
+        litellm_logging_obj.update_environment_variables(
+            model=None,
+            user=None,
+            optional_params=optional_params.model_dump(),
+            litellm_params=litellm_params,
+            custom_llm_provider=custom_llm_provider,
+        )
 
         if (
             timeout is not None
diff --git a/litellm/cost_calculator.py b/litellm/cost_calculator.py
index 07676d8a83..757cf9e238 100644
--- a/litellm/cost_calculator.py
+++ b/litellm/cost_calculator.py
@@ -399,9 +399,12 @@ def _select_model_name_for_cost_calc(
     if base_model is not None:
         return_model = base_model
 
-    completion_response_model: Optional[str] = getattr(
-        completion_response, "model", None
-    )
+    completion_response_model: Optional[str] = None
+    if completion_response is not None:
+        if isinstance(completion_response, BaseModel):
+            completion_response_model = getattr(completion_response, "model", None)
+        elif isinstance(completion_response, dict):
+            completion_response_model = completion_response.get("model", None)
     hidden_params: Optional[dict] = getattr(completion_response, "_hidden_params", None)
     if completion_response_model is None and hidden_params is not None:
         if (
diff --git a/litellm/files/main.py b/litellm/files/main.py
index 9f81b2e385..e49066e84b 100644
--- a/litellm/files/main.py
+++ b/litellm/files/main.py
@@ -816,7 +816,7 @@ def file_content(
         )
     else:
         raise litellm.exceptions.BadRequestError(
-            message="LiteLLM doesn't support {} for 'file_content'. Only 'openai' and 'azure' are supported.".format(
+            message="LiteLLM doesn't support {} for 'custom_llm_provider'. Supported providers are 'openai', 'azure', 'vertex_ai'.".format(
                 custom_llm_provider
             ),
             model="n/a",
diff --git a/litellm/litellm_core_utils/core_helpers.py b/litellm/litellm_core_utils/core_helpers.py
index ceb150946c..e571e3f6c6 100644
--- a/litellm/litellm_core_utils/core_helpers.py
+++ b/litellm/litellm_core_utils/core_helpers.py
@@ -74,7 +74,16 @@ def get_litellm_metadata_from_kwargs(kwargs: dict):
     """
     Helper to get litellm metadata from all litellm request kwargs
     """
-    return kwargs.get("litellm_params", {}).get("metadata", {})
+    litellm_params = kwargs.get("litellm_params", {})
+    if litellm_params:
+        metadata = litellm_params.get("metadata", {})
+        litellm_metadata = litellm_params.get("litellm_metadata", {})
+        if litellm_metadata:
+            return litellm_metadata
+        elif metadata:
+            return metadata
+
+    return {}
 
 
 # Helper functions used for OTEL logging
diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py
index d2fd5b6ca8..b92b0927d2 100644
--- a/litellm/litellm_core_utils/litellm_logging.py
+++ b/litellm/litellm_core_utils/litellm_logging.py
@@ -25,6 +25,7 @@ from litellm import (
     turn_off_message_logging,
 )
 from litellm._logging import _is_debugging_on, verbose_logger
+from litellm.batches.batch_utils import _handle_completed_batch
 from litellm.caching.caching import DualCache, InMemoryCache
 from litellm.caching.caching_handler import LLMCachingHandler
 from litellm.cost_calculator import _select_model_name_for_cost_calc
@@ -50,6 +51,7 @@ from litellm.types.utils import (
     CallTypes,
     EmbeddingResponse,
     ImageResponse,
+    LiteLLMBatch,
     LiteLLMLoggingBaseClass,
     ModelResponse,
     ModelResponseStream,
@@ -871,6 +873,24 @@ class Logging(LiteLLMLoggingBaseClass):
 
         return None
 
+    async def _response_cost_calculator_async(
+        self,
+        result: Union[
+            ModelResponse,
+            ModelResponseStream,
+            EmbeddingResponse,
+            ImageResponse,
+            TranscriptionResponse,
+            TextCompletionResponse,
+            HttpxBinaryResponseContent,
+            RerankResponse,
+            Batch,
+            FineTuningJob,
+        ],
+        cache_hit: Optional[bool] = None,
+    ) -> Optional[float]:
+        return self._response_cost_calculator(result=result, cache_hit=cache_hit)
+
     def should_run_callback(
         self, callback: litellm.CALLBACK_TYPES, litellm_params: dict, event_hook: str
     ) -> bool:
@@ -928,8 +948,8 @@ class Logging(LiteLLMLoggingBaseClass):
             or isinstance(result, TextCompletionResponse)
             or isinstance(result, HttpxBinaryResponseContent)  # tts
             or isinstance(result, RerankResponse)
-            or isinstance(result, Batch)
             or isinstance(result, FineTuningJob)
+            or isinstance(result, LiteLLMBatch)
         ):
             ## HIDDEN PARAMS ##
             hidden_params = getattr(result, "_hidden_params", {})
@@ -1525,6 +1545,19 @@ class Logging(LiteLLMLoggingBaseClass):
         print_verbose(
             "Logging Details LiteLLM-Async Success Call, cache_hit={}".format(cache_hit)
         )
+
+        ## CALCULATE COST FOR BATCH JOBS
+        if self.call_type == CallTypes.aretrieve_batch.value and isinstance(
+            result, LiteLLMBatch
+        ):
+
+            response_cost, batch_usage = await _handle_completed_batch(
+                batch=result, custom_llm_provider=self.custom_llm_provider
+            )
+
+            result._hidden_params["response_cost"] = response_cost
+            result.usage = batch_usage
+
         start_time, end_time, result = self._success_handler_helper_fn(
             start_time=start_time,
             end_time=end_time,
@@ -1532,6 +1565,7 @@ class Logging(LiteLLMLoggingBaseClass):
             cache_hit=cache_hit,
             standard_logging_object=kwargs.get("standard_logging_object", None),
         )
+
         ## BUILD COMPLETE STREAMED RESPONSE
         if "async_complete_streaming_response" in self.model_call_details:
             return  # break out of this.
diff --git a/litellm/llms/azure/batches/handler.py b/litellm/llms/azure/batches/handler.py
index 5fae527670..d36ae648ab 100644
--- a/litellm/llms/azure/batches/handler.py
+++ b/litellm/llms/azure/batches/handler.py
@@ -2,7 +2,7 @@
 Azure Batches API Handler
 """
 
-from typing import Any, Coroutine, Optional, Union
+from typing import Any, Coroutine, Optional, Union, cast
 
 import httpx
 
@@ -14,6 +14,7 @@ from litellm.types.llms.openai import (
     CreateBatchRequest,
     RetrieveBatchRequest,
 )
+from litellm.types.utils import LiteLLMBatch
 
 
 class AzureBatchesAPI:
@@ -64,9 +65,9 @@ class AzureBatchesAPI:
         self,
         create_batch_data: CreateBatchRequest,
         azure_client: AsyncAzureOpenAI,
-    ) -> Batch:
+    ) -> LiteLLMBatch:
         response = await azure_client.batches.create(**create_batch_data)
-        return response
+        return LiteLLMBatch(**response.model_dump())
 
     def create_batch(
         self,
@@ -78,7 +79,7 @@ class AzureBatchesAPI:
         timeout: Union[float, httpx.Timeout],
         max_retries: Optional[int],
         client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
-    ) -> Union[Batch, Coroutine[Any, Any, Batch]]:
+    ) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
         azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
             self.get_azure_openai_client(
                 api_key=api_key,
@@ -103,16 +104,16 @@ class AzureBatchesAPI:
             return self.acreate_batch(  # type: ignore
                 create_batch_data=create_batch_data, azure_client=azure_client
             )
-        response = azure_client.batches.create(**create_batch_data)
-        return response
+        response = cast(AzureOpenAI, azure_client).batches.create(**create_batch_data)
+        return LiteLLMBatch(**response.model_dump())
 
     async def aretrieve_batch(
         self,
         retrieve_batch_data: RetrieveBatchRequest,
         client: AsyncAzureOpenAI,
-    ) -> Batch:
+    ) -> LiteLLMBatch:
         response = await client.batches.retrieve(**retrieve_batch_data)
-        return response
+        return LiteLLMBatch(**response.model_dump())
 
     def retrieve_batch(
         self,
@@ -149,8 +150,10 @@ class AzureBatchesAPI:
             return self.aretrieve_batch(  # type: ignore
                 retrieve_batch_data=retrieve_batch_data, client=azure_client
            )
-        response = azure_client.batches.retrieve(**retrieve_batch_data)
-        return response
+        response = cast(AzureOpenAI, azure_client).batches.retrieve(
+            **retrieve_batch_data
+        )
+        return LiteLLMBatch(**response.model_dump())
 
     async def acancel_batch(
         self,
diff --git a/litellm/llms/openai/openai.py b/litellm/llms/openai/openai.py
index 5465a24945..3fddca53e7 100644
--- a/litellm/llms/openai/openai.py
+++ b/litellm/llms/openai/openai.py
@@ -37,6 +37,7 @@ from litellm.llms.custom_httpx.http_handler import _DEFAULT_TTL_FOR_HTTPX_CLIENT
 from litellm.types.utils import (
     EmbeddingResponse,
     ImageResponse,
+    LiteLLMBatch,
     ModelResponse,
     ModelResponseStream,
 )
@@ -1755,9 +1756,9 @@ class OpenAIBatchesAPI(BaseLLM):
         self,
         create_batch_data: CreateBatchRequest,
         openai_client: AsyncOpenAI,
-    ) -> Batch:
+    ) -> LiteLLMBatch:
         response = await openai_client.batches.create(**create_batch_data)
-        return response
+        return LiteLLMBatch(**response.model_dump())
 
     def create_batch(
         self,
@@ -1769,7 +1770,7 @@ class OpenAIBatchesAPI(BaseLLM):
         max_retries: Optional[int],
         organization: Optional[str],
         client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
-    ) -> Union[Batch, Coroutine[Any, Any, Batch]]:
+    ) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
         openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
             api_key=api_key,
             api_base=api_base,
@@ -1792,17 +1793,18 @@ class OpenAIBatchesAPI(BaseLLM):
             return self.acreate_batch(  # type: ignore
                 create_batch_data=create_batch_data, openai_client=openai_client
             )
-        response = openai_client.batches.create(**create_batch_data)
-        return response
+        response = cast(OpenAI, openai_client).batches.create(**create_batch_data)
+
+        return LiteLLMBatch(**response.model_dump())
 
     async def aretrieve_batch(
         self,
         retrieve_batch_data: RetrieveBatchRequest,
         openai_client: AsyncOpenAI,
-    ) -> Batch:
+    ) -> LiteLLMBatch:
         verbose_logger.debug("retrieving batch, args= %s", retrieve_batch_data)
         response = await openai_client.batches.retrieve(**retrieve_batch_data)
-        return response
+        return LiteLLMBatch(**response.model_dump())
 
     def retrieve_batch(
         self,
@@ -1837,8 +1839,8 @@ class OpenAIBatchesAPI(BaseLLM):
             return self.aretrieve_batch(  # type: ignore
                 retrieve_batch_data=retrieve_batch_data, openai_client=openai_client
             )
-        response = openai_client.batches.retrieve(**retrieve_batch_data)
-        return response
+        response = cast(OpenAI, openai_client).batches.retrieve(**retrieve_batch_data)
+        return LiteLLMBatch(**response.model_dump())
 
     async def acancel_batch(
         self,
diff --git a/litellm/llms/vertex_ai/batches/handler.py b/litellm/llms/vertex_ai/batches/handler.py
index 3d723d8ecf..b82268bef6 100644
--- a/litellm/llms/vertex_ai/batches/handler.py
+++ b/litellm/llms/vertex_ai/batches/handler.py
@@ -9,11 +9,12 @@ from litellm.llms.custom_httpx.http_handler import (
     get_async_httpx_client,
 )
 from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
-from litellm.types.llms.openai import Batch, CreateBatchRequest
+from litellm.types.llms.openai import CreateBatchRequest
 from litellm.types.llms.vertex_ai import (
     VERTEX_CREDENTIALS_TYPES,
     VertexAIBatchPredictionJob,
 )
+from litellm.types.utils import LiteLLMBatch
 
 from .transformation import VertexAIBatchTransformation
 
@@ -33,7 +34,7 @@ class VertexAIBatchPrediction(VertexLLM):
         vertex_location: Optional[str],
         timeout: Union[float, httpx.Timeout],
         max_retries: Optional[int],
-    ) -> Union[Batch, Coroutine[Any, Any, Batch]]:
+    ) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
 
         sync_handler = _get_httpx_client()
 
@@ -101,7 +102,7 @@ class VertexAIBatchPrediction(VertexLLM):
         vertex_batch_request: VertexAIBatchPredictionJob,
         api_base: str,
         headers: Dict[str, str],
-    ) -> Batch:
+    ) -> LiteLLMBatch:
         client = get_async_httpx_client(
             llm_provider=litellm.LlmProviders.VERTEX_AI,
         )
@@ -138,7 +139,7 @@ class VertexAIBatchPrediction(VertexLLM):
         vertex_location: Optional[str],
         timeout: Union[float, httpx.Timeout],
         max_retries: Optional[int],
-    ) -> Union[Batch, Coroutine[Any, Any, Batch]]:
+    ) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
         sync_handler = _get_httpx_client()
 
         access_token, project_id = self._ensure_access_token(
@@ -199,7 +200,7 @@ class VertexAIBatchPrediction(VertexLLM):
         self,
         api_base: str,
         headers: Dict[str, str],
-    ) -> Batch:
+    ) -> LiteLLMBatch:
         client = get_async_httpx_client(
             llm_provider=litellm.LlmProviders.VERTEX_AI,
         )
diff --git a/litellm/llms/vertex_ai/batches/transformation.py b/litellm/llms/vertex_ai/batches/transformation.py
index 32cabdcf56..a97f312d48 100644
--- a/litellm/llms/vertex_ai/batches/transformation.py
+++ b/litellm/llms/vertex_ai/batches/transformation.py
@@ -4,8 +4,9 @@ from typing import Dict
 from litellm.llms.vertex_ai.common_utils import (
     _convert_vertex_datetime_to_openai_datetime,
 )
-from litellm.types.llms.openai import Batch, BatchJobStatus, CreateBatchRequest
+from litellm.types.llms.openai import BatchJobStatus, CreateBatchRequest
 from litellm.types.llms.vertex_ai import *
+from litellm.types.utils import LiteLLMBatch
 
 
 class VertexAIBatchTransformation:
@@ -47,8 +48,8 @@ class VertexAIBatchTransformation:
     @classmethod
     def transform_vertex_ai_batch_response_to_openai_batch_response(
         cls, response: VertexBatchPredictionResponse
-    ) -> Batch:
-        return Batch(
+    ) -> LiteLLMBatch:
+        return LiteLLMBatch(
             id=cls._get_batch_id_from_vertex_ai_batch_response(response),
             completion_window="24hrs",
             created_at=_convert_vertex_datetime_to_openai_datetime(
diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml
index 7c3caf74ff..8a46573955 100644
--- a/litellm/proxy/_new_secret_config.yaml
+++ b/litellm/proxy/_new_secret_config.yaml
@@ -1,9 +1,13 @@
 model_list:
-  - model_name: my-langfuse-model
+  - model_name: openai/gpt-4o
     litellm_params:
-      model: langfuse/openai-model
+      model: openai/gpt-4o
       api_key: os.environ/OPENAI_API_KEY
-  - model_name: openai-model
-    litellm_params:
-      model: openai/gpt-3.5-turbo
-      api_key: os.environ/OPENAI_API_KEY
\ No newline at end of file
+
+files_settings:
+  - custom_llm_provider: azure
+    api_base: os.environ/AZURE_API_BASE
+    api_key: os.environ/AZURE_API_KEY
+
+general_settings:
+  store_prompts_in_spend_logs: true
\ No newline at end of file
diff --git a/litellm/proxy/batches_endpoints/endpoints.py b/litellm/proxy/batches_endpoints/endpoints.py
index 118f9964e9..e00112b8d8 100644
--- a/litellm/proxy/batches_endpoints/endpoints.py
+++ b/litellm/proxy/batches_endpoints/endpoints.py
@@ -2,10 +2,10 @@
 # /v1/batches Endpoints
 
-import asyncio
 
 ######################################################################
 
-from typing import Dict, Optional
+import asyncio
+from typing import Dict, Optional, cast
 
 from fastapi import APIRouter, Depends, HTTPException, Path, Request, Response
 
@@ -199,8 +199,11 @@ async def retrieve_batch(
     ```
     """
     from litellm.proxy.proxy_server import (
+        add_litellm_data_to_request,
+        general_settings,
         get_custom_headers,
         llm_router,
+        proxy_config,
         proxy_logging_obj,
         version,
     )
@@ -212,6 +215,23 @@ async def retrieve_batch(
         batch_id=batch_id,
     )
 
+    data = cast(dict, _retrieve_batch_request)
+
+    # setup logging
+    data["litellm_call_id"] = request.headers.get(
+        "x-litellm-call-id", str(uuid.uuid4())
+    )
+
+    # Include original request and headers in the data
+    data = await add_litellm_data_to_request(
+        data=data,
+        request=request,
+        general_settings=general_settings,
+        user_api_key_dict=user_api_key_dict,
+        version=version,
+        proxy_config=proxy_config,
+    )
+
     if litellm.enable_loadbalancing_on_batch_endpoints is True:
         if llm_router is None:
             raise HTTPException(
@@ -221,7 +241,7 @@ async def retrieve_batch(
                 },
             )
 
-        response = await llm_router.aretrieve_batch(**_retrieve_batch_request)  # type: ignore
+        response = await llm_router.aretrieve_batch(**data)  # type: ignore
     else:
         custom_llm_provider = (
             provider
@@ -229,7 +249,7 @@ async def retrieve_batch(
             or "openai"
         )
         response = await litellm.aretrieve_batch(
-            custom_llm_provider=custom_llm_provider, **_retrieve_batch_request  # type: ignore
+            custom_llm_provider=custom_llm_provider, **data  # type: ignore
         )
 
     ### ALERTING ###
diff --git a/litellm/proxy/management_endpoints/ui_sso.py b/litellm/proxy/management_endpoints/ui_sso.py
index 7a1d15fe01..0b34c575a1 100644
--- a/litellm/proxy/management_endpoints/ui_sso.py
+++ b/litellm/proxy/management_endpoints/ui_sso.py
@@ -11,6 +11,7 @@ import uuid
 from typing import TYPE_CHECKING, List, Optional, Union, cast
 
 from fastapi import APIRouter, Depends, HTTPException, Request, status
+from fastapi.responses import RedirectResponse
 
 import litellm
 from litellm._logging import verbose_proxy_logger
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 04682a4d0b..ec8d6374ce 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -955,6 +955,9 @@ def _set_spend_logs_payload(
     prisma_client: PrismaClient,
     spend_logs_url: Optional[str] = None,
 ):
+    verbose_proxy_logger.info(
+        "Writing spend log to db - request_id: {}".format(payload.get("request_id"))
+    )
     if prisma_client is not None and spend_logs_url is not None:
         if isinstance(payload["startTime"], datetime):
             payload["startTime"] = payload["startTime"].isoformat()
@@ -1056,7 +1059,6 @@ async def update_database(  # noqa: PLR0915
                 start_time=start_time,
                 end_time=end_time,
             )
-            payload["spend"] = response_cost
             prisma_client = _set_spend_logs_payload(
                 payload=payload,
diff --git a/litellm/proxy/spend_tracking/spend_tracking_utils.py b/litellm/proxy/spend_tracking/spend_tracking_utils.py
index d3720ad75d..78304fa6b0 100644
--- a/litellm/proxy/spend_tracking/spend_tracking_utils.py
+++ b/litellm/proxy/spend_tracking/spend_tracking_utils.py
@@ -1,9 +1,10 @@
+import hashlib
 import json
 import secrets
 from datetime import datetime
 from datetime import datetime as dt
 from datetime import timezone
-from typing import List, Optional, cast
+from typing import Any, List, Optional, cast
 
 from pydantic import BaseModel
 
@@ -69,6 +70,42 @@ def _get_spend_logs_metadata(
     return clean_metadata
 
 
+def generate_hash_from_response(response_obj: Any) -> str:
+    """
+    Generate a stable hash from a response object.
+
+    Args:
+        response_obj: The response object to hash (can be dict, list, etc.)
+
+    Returns:
+        A hex string representation of the MD5 hash
+    """
+    try:
+        # Create a stable JSON string of the entire response object
+        # Sort keys to ensure consistent ordering
+        json_str = json.dumps(response_obj, sort_keys=True)
+
+        # Generate a hash of the response object
+        unique_hash = hashlib.md5(json_str.encode()).hexdigest()
+        return unique_hash
+    except Exception:
+        # Return a fallback hash if serialization fails
+        return hashlib.md5(str(response_obj).encode()).hexdigest()
+
+
+def get_spend_logs_id(
+    call_type: str, response_obj: dict, kwargs: dict
+) -> Optional[str]:
+    if call_type == "aretrieve_batch":
+        # Generate a hash from the response object
+        id: Optional[str] = generate_hash_from_response(response_obj)
+    else:
+        id = cast(Optional[str], response_obj.get("id")) or cast(
+            Optional[str], kwargs.get("litellm_call_id")
+        )
+    return id
+
+
 def get_logging_payload(  # noqa: PLR0915
     kwargs, response_obj, start_time, end_time
 ) -> SpendLogsPayload:
@@ -94,7 +131,15 @@ def get_logging_payload(  # noqa: PLR0915
     usage = cast(dict, response_obj).get("usage", None) or {}
     if isinstance(usage, litellm.Usage):
         usage = dict(usage)
-    id = cast(dict, response_obj).get("id") or kwargs.get("litellm_call_id")
+
+    if isinstance(response_obj, dict):
+        response_obj_dict = response_obj
+    elif isinstance(response_obj, BaseModel):
+        response_obj_dict = response_obj.model_dump()
+    else:
+        response_obj_dict = {}
+
+    id = get_spend_logs_id(call_type or "acompletion", response_obj_dict, kwargs)
     standard_logging_payload = cast(
         Optional[StandardLoggingPayload], kwargs.get("standard_logging_object", None)
     )
@@ -177,14 +222,8 @@ def get_logging_payload(  # noqa: PLR0915
         endTime=_ensure_datetime_utc(end_time),
         completionStartTime=_ensure_datetime_utc(completion_start_time),
         model=kwargs.get("model", "") or "",
-        user=kwargs.get("litellm_params", {})
-        .get("metadata", {})
-        .get("user_api_key_user_id", "")
-        or "",
-        team_id=kwargs.get("litellm_params", {})
-        .get("metadata", {})
-        .get("user_api_key_team_id", "")
-        or "",
+        user=metadata.get("user_api_key_user_id", "") or "",
+        team_id=metadata.get("user_api_key_team_id", "") or "",
         metadata=json.dumps(clean_metadata),
         cache_key=cache_key,
         spend=kwargs.get("response_cost", 0),
@@ -314,10 +353,13 @@ def _add_proxy_server_request_to_metadata(
     Only store if _should_store_prompts_and_responses_in_spend_logs() is True
     """
     if _should_store_prompts_and_responses_in_spend_logs():
-        _proxy_server_request = litellm_params.get("proxy_server_request", {})
-        _request_body = _proxy_server_request.get("body", {}) or {}
-        _request_body_json_str = json.dumps(_request_body, default=str)
-        metadata["proxy_server_request"] = _request_body_json_str
+        _proxy_server_request = cast(
+            Optional[dict], litellm_params.get("proxy_server_request", {})
+        )
+        if _proxy_server_request is not None:
+            _request_body = _proxy_server_request.get("body", {}) or {}
+            _request_body_json_str = json.dumps(_request_body, default=str)
+            metadata["proxy_server_request"] = _request_body_json_str
     return metadata
 
 
diff --git a/litellm/types/utils.py b/litellm/types/utils.py
index dcaf5f35d1..7e6b95ab15 100644
--- a/litellm/types/utils.py
+++ b/litellm/types/utils.py
@@ -24,6 +24,7 @@ from typing_extensions import Callable, Dict, Required, TypedDict, override
 from ..litellm_core_utils.core_helpers import map_finish_reason
 from .guardrails import GuardrailEventHooks
 from .llms.openai import (
+    Batch,
     ChatCompletionThinkingBlock,
     ChatCompletionToolCallChunk,
     ChatCompletionUsageBlock,
@@ -182,6 +183,8 @@ class CallTypes(Enum):
     arealtime = "_arealtime"
     create_batch = "create_batch"
     acreate_batch = "acreate_batch"
+    aretrieve_batch = "aretrieve_batch"
+    retrieve_batch = "retrieve_batch"
     pass_through = "pass_through_endpoint"
 
 
@@ -1963,3 +1966,27 @@ class ProviderSpecificHeader(TypedDict):
 class SelectTokenizerResponse(TypedDict):
     type: Literal["openai_tokenizer", "huggingface_tokenizer"]
     tokenizer: Any
+
+
+class LiteLLMBatch(Batch):
+    _hidden_params: dict = {}
+    usage: Optional[Usage] = None
+
+    def __contains__(self, key):
+        # Define custom behavior for the 'in' operator
+        return hasattr(self, key)
+
+    def get(self, key, default=None):
+        # Custom .get() method to access attributes with a default value if the attribute doesn't exist
+        return getattr(self, key, default)
+
+    def __getitem__(self, key):
+        # Allow dictionary-style access to attributes
+        return getattr(self, key)
+
+    def json(self, **kwargs):  # type: ignore
+        try:
+            return self.model_dump()  # noqa
+        except Exception:
+            # if using pydantic v1
+            return self.dict()
diff --git a/tests/local_testing/test_completion_cost.py b/tests/local_testing/test_completion_cost.py
index 200f2c012e..5de9216316 100644
--- a/tests/local_testing/test_completion_cost.py
+++ b/tests/local_testing/test_completion_cost.py
@@ -2982,3 +2982,70 @@ def test_json_valid_model_cost_map():
         json.loads(json_str)
     except json.JSONDecodeError as e:
         assert False, f"Invalid JSON format: {str(e)}"
+
+
+def test_batch_cost_calculator():
+
+    args = {
+        "completion_response": {
+            "choices": [
+                {
+                    "content_filter_results": {
+                        "hate": {"filtered": False, "severity": "safe"},
+                        "protected_material_code": {
+                            "filtered": False,
+                            "detected": False,
+                        },
+                        "protected_material_text": {
+                            "filtered": False,
+                            "detected": False,
+                        },
+                        "self_harm": {"filtered": False, "severity": "safe"},
+                        "sexual": {"filtered": False, "severity": "safe"},
+                        "violence": {"filtered": False, "severity": "safe"},
+                    },
+                    "finish_reason": "stop",
+                    "index": 0,
+                    "logprobs": None,
+                    "message": {
+                        "content": 'As of my last update in October 2023, there are eight recognized planets in the solar system. They are:\n\n1. **Mercury** - The closest planet to the Sun, known for its extreme temperature fluctuations.\n2. **Venus** - Similar in size to Earth but with a thick atmosphere rich in carbon dioxide, leading to a greenhouse effect that makes it the hottest planet.\n3. **Earth** - The only planet known to support life, with a diverse environment and liquid water.\n4. **Mars** - Known as the Red Planet, it has the largest volcano and canyon in the solar system and features signs of past water.\n5. **Jupiter** - The largest planet in the solar system, known for its Great Red Spot and numerous moons.\n6. **Saturn** - Famous for its stunning rings, it is a gas giant also known for its extensive moon system.\n7. **Uranus** - An ice giant with a unique tilt, it rotates on its side and has a blue color due to methane in its atmosphere.\n8. **Neptune** - Another ice giant, known for its deep blue color and strong winds, it is the farthest planet from the Sun.\n\nPluto was previously classified as the ninth planet but was reclassified as a "dwarf planet" in 2006 by the International Astronomical Union.',
+                        "refusal": None,
+                        "role": "assistant",
+                    },
+                }
+            ],
+            "created": 1741135408,
+            "id": "chatcmpl-B7X96teepFM4ILP7cm4Ga62eRuV8p",
+            "model": "gpt-4o-mini-2024-07-18",
+            "object": "chat.completion",
+            "prompt_filter_results": [
+                {
+                    "prompt_index": 0,
+                    "content_filter_results": {
+                        "hate": {"filtered": False, "severity": "safe"},
+                        "jailbreak": {"filtered": False, "detected": False},
+                        "self_harm": {"filtered": False, "severity": "safe"},
+                        "sexual": {"filtered": False, "severity": "safe"},
+                        "violence": {"filtered": False, "severity": "safe"},
+                    },
+                }
+            ],
+            "system_fingerprint": "fp_b705f0c291",
+            "usage": {
+                "completion_tokens": 278,
+                "completion_tokens_details": {
+                    "accepted_prediction_tokens": 0,
+                    "audio_tokens": 0,
+                    "reasoning_tokens": 0,
+                    "rejected_prediction_tokens": 0,
+                },
+                "prompt_tokens": 20,
+                "prompt_tokens_details": {"audio_tokens": 0, "cached_tokens": 0},
+                "total_tokens": 298,
+            },
+        },
+        "model": None,
+    }
+
+    cost = completion_cost(**args)
+    assert cost > 0