diff --git a/litellm/batches/batch_utils.py b/litellm/batches/batch_utils.py
index 0f68193695..f24eda0432 100644
--- a/litellm/batches/batch_utils.py
+++ b/litellm/batches/batch_utils.py
@@ -277,17 +277,8 @@ def _create_standard_logging_object_for_completed_batch(
     """
     Create a standard logging object for a completed batch
     """
-    from litellm.litellm_core_utils.litellm_logging import (
-        get_standard_logging_object_payload,
-    )
-
-    standard_logging_object = get_standard_logging_object_payload(
-        kwargs=kwargs,
-        init_response_obj=None,
-        start_time=start_time,
-        end_time=end_time,
-        logging_obj=logging_obj,
-        status="success",
+    standard_logging_object = logging_obj.model_call_details.get(
+        "standard_logging_object", None
     )
 
     if standard_logging_object is None:
diff --git a/litellm/batches/main.py b/litellm/batches/main.py
index 815ddb2ed0..1c50474ccb 100644
--- a/litellm/batches/main.py
+++ b/litellm/batches/main.py
@@ -19,13 +19,14 @@ from typing import Any, Coroutine, Dict, Literal, Optional, Union
 import httpx
 
 import litellm
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
 from litellm.llms.azure.azure import AzureBatchesAPI
 from litellm.llms.openai.openai import OpenAIBatchesAPI
 from litellm.llms.vertex_ai.batches.handler import VertexAIBatchPrediction
 from litellm.secret_managers.main import get_secret_str
 from litellm.types.llms.openai import Batch, CreateBatchRequest, RetrieveBatchRequest
 from litellm.types.router import GenericLiteLLMParams
-from litellm.utils import client, supports_httpx_timeout
+from litellm.utils import client, get_litellm_params, supports_httpx_timeout
 
 from .batch_utils import batches_async_logging
@@ -114,9 +115,22 @@ def create_batch(
     try:
         optional_params = GenericLiteLLMParams(**kwargs)
         _is_async = kwargs.pop("acreate_batch", False) is True
+        litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj", None)
         ### TIMEOUT LOGIC ###
         timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
-        # set timeout for 10 minutes by default
+        litellm_params = get_litellm_params(
+            custom_llm_provider=custom_llm_provider,
+            litellm_call_id=kwargs.get("litellm_call_id", None),
+            litellm_trace_id=kwargs.get("litellm_trace_id"),
+            litellm_metadata=kwargs.get("litellm_metadata"),
+        )
+        litellm_logging_obj.update_environment_variables(
+            model=None,
+            user=None,
+            optional_params=optional_params.model_dump(),
+            litellm_params=litellm_params,
+            custom_llm_provider=custom_llm_provider,
+        )
 
         if (
             timeout is not None
diff --git a/litellm/constants.py b/litellm/constants.py
index 9fddc38e53..8ecd9154e7 100644
--- a/litellm/constants.py
+++ b/litellm/constants.py
@@ -93,5 +93,5 @@ BEDROCK_AGENT_RUNTIME_PASS_THROUGH_ROUTES = [
     "optimize-prompt/",
 ]
 
-BATCH_STATUS_POLL_INTERVAL_SECONDS = 10
-BATCH_STATUS_POLL_MAX_ATTEMPTS = 10
+BATCH_STATUS_POLL_INTERVAL_SECONDS = 3600  # 1 hour
+BATCH_STATUS_POLL_MAX_ATTEMPTS = 24  # for 24 hours
diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py
index dc558e427a..6065779d58 100644
--- a/litellm/litellm_core_utils/litellm_logging.py
+++ b/litellm/litellm_core_utils/litellm_logging.py
@@ -388,10 +388,16 @@ class Logging(LiteLLMLoggingBaseClass):
         return standard_callback_dynamic_params
 
     def update_environment_variables(
-        self, model, user, optional_params, litellm_params, **additional_params
+        self,
+        litellm_params: Dict,
+        optional_params: Dict,
+        model: Optional[str] = None,
+        user: Optional[str] = None,
+        **additional_params,
     ):
         self.optional_params = optional_params
-        self.model = model
+        if model is not None:
+            self.model = model
         self.user = user
         self.litellm_params = scrub_sensitive_keys_in_metadata(litellm_params)
         self.logger_fn = litellm_params.get("logger_fn", None)
@@ -2885,9 +2891,11 @@
     litellm_params = kwargs.get("litellm_params", {})
     proxy_server_request = litellm_params.get("proxy_server_request") or {}
     end_user_id = proxy_server_request.get("body", {}).get("user", None)
-    metadata = (
-        litellm_params.get("metadata", {}) or {}
-    )  # if litellm_params['metadata'] == None
+    metadata: dict = (
+        litellm_params.get("litellm_metadata")
+        or litellm_params.get("metadata", None)
+        or {}
+    )
     completion_start_time = kwargs.get("completion_start_time", end_time)
     call_type = kwargs.get("call_type")
     cache_hit = kwargs.get("cache_hit", False)
diff --git a/litellm/main.py b/litellm/main.py
index 77f4af9f2b..8e68ba7a33 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -1140,6 +1140,7 @@ def completion(  # type: ignore # noqa: PLR0915
             litellm_trace_id=kwargs.get("litellm_trace_id"),
             hf_model_name=hf_model_name,
             custom_prompt_dict=custom_prompt_dict,
+            litellm_metadata=kwargs.get("litellm_metadata"),
         )
         logging.update_environment_variables(
             model=model,
diff --git a/litellm/rerank_api/main.py b/litellm/rerank_api/main.py
index 315109280c..8ec05ddadf 100644
--- a/litellm/rerank_api/main.py
+++ b/litellm/rerank_api/main.py
@@ -138,7 +138,7 @@ def rerank(  # noqa: PLR0915
         litellm_logging_obj.update_environment_variables(
             model=model,
             user=user,
-            optional_params=optional_rerank_params,
+            optional_params=dict(optional_rerank_params),
             litellm_params={
                 "litellm_call_id": litellm_call_id,
                 "proxy_server_request": proxy_server_request,
diff --git a/litellm/types/utils.py b/litellm/types/utils.py
index f0ac2fcab2..7230b4a9b6 100644
--- a/litellm/types/utils.py
+++ b/litellm/types/utils.py
@@ -1592,6 +1592,7 @@ class StandardCallbackDynamicParams(TypedDict, total=False):
 
 all_litellm_params = [
     "metadata",
+    "litellm_metadata",
     "litellm_trace_id",
     "tags",
     "acompletion",
diff --git a/litellm/utils.py b/litellm/utils.py
index df704eed05..82bae30933 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -1203,11 +1203,18 @@ def client(original_function):  # noqa: PLR0915
     return wrapper
 
 
-def _is_async_request(kwargs: Optional[dict]) -> bool:
+def _is_async_request(
+    kwargs: Optional[dict],
+    is_pass_through: bool = False,
+) -> bool:
     """
     Returns True if the call type is an internal async request.
 
     eg. litellm.acompletion, litellm.aimage_generation, litellm.acreate_batch, litellm._arealtime
+
+    Args:
+        kwargs (dict): The kwargs passed to the litellm function
+        is_pass_through (bool): Whether the call is a pass-through call. By default all pass through calls are async.
""" if kwargs is None: return False @@ -1221,6 +1228,7 @@ def _is_async_request(kwargs: Optional[dict]) -> bool: or kwargs.get("arerank", False) is True or kwargs.get("_arealtime", False) is True or kwargs.get("acreate_batch", False) is True + or is_pass_through is True ): return True return False @@ -1927,7 +1935,7 @@ def register_model(model_cost: Union[str, dict]): # noqa: PLR0915 def get_litellm_params( - api_key=None, + api_key: Optional[str] = None, force_timeout=600, azure=False, logger_fn=None, @@ -1935,12 +1943,12 @@ def get_litellm_params( hugging_face=False, replicate=False, together_ai=False, - custom_llm_provider=None, - api_base=None, + custom_llm_provider: Optional[str] = None, + api_base: Optional[str] = None, litellm_call_id=None, model_alias_map=None, completion_call_id=None, - metadata=None, + metadata: Optional[dict] = None, model_info=None, proxy_server_request=None, acompletion=None, @@ -1954,10 +1962,11 @@ def get_litellm_params( text_completion=None, azure_ad_token_provider=None, user_continue_message=None, - base_model=None, - litellm_trace_id=None, + base_model: Optional[str] = None, + litellm_trace_id: Optional[str] = None, hf_model_name: Optional[str] = None, custom_prompt_dict: Optional[dict] = None, + litellm_metadata: Optional[dict] = None, ): litellm_params = { "acompletion": acompletion, @@ -1989,8 +1998,8 @@ def get_litellm_params( "litellm_trace_id": litellm_trace_id, "hf_model_name": hf_model_name, "custom_prompt_dict": custom_prompt_dict, + "litellm_metadata": litellm_metadata, } - return litellm_params diff --git a/tests/batches_tests/test_openai_batches_and_files.py b/tests/batches_tests/test_openai_batches_and_files.py index d01ab735ef..78867458a6 100644 --- a/tests/batches_tests/test_openai_batches_and_files.py +++ b/tests/batches_tests/test_openai_batches_and_files.py @@ -12,6 +12,7 @@ load_dotenv() sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system-path + import logging import time @@ -191,12 +192,18 @@ async def test_async_create_batch(provider): batch_input_file_id is not None ), "Failed to create file, expected a non null file_id but got {batch_input_file_id}" + extra_metadata_field = { + "user_api_key_alias": "special_api_key_alias", + "user_api_key_team_alias": "special_team_alias", + } create_batch_response = await litellm.acreate_batch( completion_window="24h", endpoint="/v1/chat/completions", input_file_id=batch_input_file_id, custom_llm_provider=provider, metadata={"key1": "value1", "key2": "value2"}, + # litellm specific param - used for logging metadata on logging callback + litellm_metadata=extra_metadata_field, ) print("response from litellm.create_batch=", create_batch_response) @@ -215,6 +222,18 @@ async def test_async_create_batch(provider): await asyncio.sleep(6) # Assert that the create batch event is logged on CustomLogger assert custom_logger.standard_logging_object is not None + print( + "standard_logging_object=", + json.dumps(custom_logger.standard_logging_object, indent=4, default=str), + ) + assert ( + custom_logger.standard_logging_object["metadata"]["user_api_key_alias"] + == extra_metadata_field["user_api_key_alias"] + ) + assert ( + custom_logger.standard_logging_object["metadata"]["user_api_key_team_alias"] + == extra_metadata_field["user_api_key_team_alias"] + ) retrieved_batch = await litellm.aretrieve_batch( batch_id=create_batch_response.id, custom_llm_provider=provider