diff --git a/litellm/batches/main.py b/litellm/batches/main.py index 71c2d3b5cf..0d9ba55587 100644 --- a/litellm/batches/main.py +++ b/litellm/batches/main.py @@ -25,7 +25,7 @@ from litellm.llms.vertex_ai.batches.handler import VertexAIBatchPrediction from litellm.secret_managers.main import get_secret_str from litellm.types.llms.openai import Batch, CreateBatchRequest, RetrieveBatchRequest from litellm.types.router import GenericLiteLLMParams -from litellm.utils import supports_httpx_timeout +from litellm.utils import client, supports_httpx_timeout ####### ENVIRONMENT VARIABLES ################### openai_batches_instance = OpenAIBatchesAPI() @@ -34,6 +34,7 @@ vertex_ai_batches_instance = VertexAIBatchPrediction(gcs_bucket_name="") ################################################# +@client async def acreate_batch( completion_window: Literal["24h"], endpoint: Literal["/v1/chat/completions", "/v1/embeddings", "/v1/completions"], @@ -80,6 +81,7 @@ async def acreate_batch( raise e +@client def create_batch( completion_window: Literal["24h"], endpoint: Literal["/v1/chat/completions", "/v1/embeddings", "/v1/completions"], diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index de53b6e486..f460cf757e 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ b/litellm/litellm_core_utils/litellm_logging.py @@ -34,7 +34,11 @@ from litellm.litellm_core_utils.redact_messages import ( redact_message_input_output_from_custom_logger, redact_message_input_output_from_logging, ) -from litellm.types.llms.openai import AllMessageValues, HttpxBinaryResponseContent +from litellm.types.llms.openai import ( + AllMessageValues, + Batch, + HttpxBinaryResponseContent, +) from litellm.types.rerank import RerankResponse from litellm.types.router import SPECIAL_MODEL_INFO_PARAMS from litellm.types.utils import ( @@ -749,6 +753,7 @@ class Logging(LiteLLMLoggingBaseClass): TextCompletionResponse, HttpxBinaryResponseContent, RerankResponse, + Batch, ], cache_hit: Optional[bool] = None, ) -> Optional[float]: @@ -865,6 +870,7 @@ class Logging(LiteLLMLoggingBaseClass): or isinstance(result, TextCompletionResponse) or isinstance(result, HttpxBinaryResponseContent) # tts or isinstance(result, RerankResponse) + or isinstance(result, Batch) ): ## RESPONSE COST ## self.model_call_details["response_cost"] = ( diff --git a/litellm/types/utils.py b/litellm/types/utils.py index c6e4b2767b..8176d9a50c 100644 --- a/litellm/types/utils.py +++ b/litellm/types/utils.py @@ -169,6 +169,8 @@ class CallTypes(Enum): rerank = "rerank" arerank = "arerank" arealtime = "_arealtime" + create_batch = "create_batch" + acreate_batch = "acreate_batch" pass_through = "pass_through_endpoint" @@ -190,6 +192,9 @@ CallTypesLiteral = Literal[ "rerank", "arerank", "_arealtime", + "create_batch", + "acreate_batch", + "pass_through_endpoint", ] diff --git a/litellm/utils.py b/litellm/utils.py index bcc69b0021..df704eed05 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -712,16 +712,8 @@ def client(original_function): # noqa: PLR0915 def wrapper(*args, **kwargs): # noqa: PLR0915 # DO NOT MOVE THIS. It always needs to run first # Check if this is an async function. 
If so only execute the async function - if ( - kwargs.get("acompletion", False) is True - or kwargs.get("aembedding", False) is True - or kwargs.get("aimg_generation", False) is True - or kwargs.get("amoderation", False) is True - or kwargs.get("atext_completion", False) is True - or kwargs.get("atranscription", False) is True - or kwargs.get("arerank", False) is True - or kwargs.get("_arealtime", False) is True - ): + call_type = original_function.__name__ + if _is_async_request(kwargs): # [OPTIONAL] CHECK MAX RETRIES / REQUEST if litellm.num_retries_per_request is not None: # check if previous_models passed in as ['litellm_params']['metadata]['previous_models'] @@ -759,20 +751,10 @@ def client(original_function): # noqa: PLR0915 ) # only set litellm_call_id if its not in kwargs - call_type = original_function.__name__ if "litellm_call_id" not in kwargs: kwargs["litellm_call_id"] = str(uuid.uuid4()) - model: Optional[str] = None - try: - model = args[0] if len(args) > 0 else kwargs["model"] - except Exception: - model = None - if ( - call_type != CallTypes.image_generation.value - and call_type != CallTypes.text_completion.value - ): - raise ValueError("model param not passed in.") + model: Optional[str] = args[0] if len(args) > 0 else kwargs.get("model", None) try: if logging_obj is None: @@ -1022,16 +1004,7 @@ def client(original_function): # noqa: PLR0915 if "litellm_call_id" not in kwargs: kwargs["litellm_call_id"] = str(uuid.uuid4()) - model = "" - try: - model = args[0] if len(args) > 0 else kwargs["model"] - except Exception: - if ( - call_type != CallTypes.aimage_generation.value # model optional - and call_type != CallTypes.atext_completion.value # can also be engine - and call_type != CallTypes.amoderation.value - ): - raise ValueError("model param not passed in.") + model: Optional[str] = args[0] if len(args) > 0 else kwargs.get("model", None) try: if logging_obj is None: @@ -1054,7 +1027,7 @@ def client(original_function): # noqa: PLR0915 ) _caching_handler_response: CachingHandlerResponse = ( await _llm_caching_handler._async_get_cache( - model=model, + model=model or "", original_function=original_function, logging_obj=logging_obj, start_time=start_time, @@ -1100,7 +1073,7 @@ def client(original_function): # noqa: PLR0915 "id", None ) result._hidden_params["api_base"] = get_api_base( - model=model, + model=model or "", optional_params=kwargs, ) result._hidden_params["response_cost"] = ( @@ -1230,6 +1203,29 @@ def client(original_function): # noqa: PLR0915 return wrapper +def _is_async_request(kwargs: Optional[dict]) -> bool: + """ + Returns True if the call type is an internal async request. + + eg. 
litellm.acompletion, litellm.aimage_generation, litellm.acreate_batch, litellm._arealtime + """ + if kwargs is None: + return False + if ( + kwargs.get("acompletion", False) is True + or kwargs.get("aembedding", False) is True + or kwargs.get("aimg_generation", False) is True + or kwargs.get("amoderation", False) is True + or kwargs.get("atext_completion", False) is True + or kwargs.get("atranscription", False) is True + or kwargs.get("arerank", False) is True + or kwargs.get("_arealtime", False) is True + or kwargs.get("acreate_batch", False) is True + ): + return True + return False + + @lru_cache(maxsize=128) def _select_tokenizer( model: str, diff --git a/tests/local_testing/test_openai_batches_and_files.py b/tests/local_testing/test_openai_batches_and_files.py index 7e0c711bdf..9c8ab79269 100644 --- a/tests/local_testing/test_openai_batches_and_files.py +++ b/tests/local_testing/test_openai_batches_and_files.py @@ -16,7 +16,7 @@ import logging import time import pytest - +from typing import Optional import litellm from litellm import create_batch, create_file from litellm._logging import verbose_logger @@ -24,14 +24,36 @@ from test_gcs_bucket import load_vertex_ai_credentials verbose_logger.setLevel(logging.DEBUG) +from litellm.integrations.custom_logger import CustomLogger +from litellm.types.utils import StandardLoggingPayload + + +class TestCustomLogger(CustomLogger): + def __init__(self): + super().__init__() + self.standard_logging_object: Optional[StandardLoggingPayload] = None + + async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): + print( + "Success event logged with kwargs=", + kwargs, + "and response_obj=", + response_obj, + ) + self.standard_logging_object = kwargs["standard_logging_object"] + @pytest.mark.parametrize("provider", ["openai"]) # , "azure" -def test_create_batch(provider): +@pytest.mark.asyncio +async def test_create_batch(provider): """ 1. Create File for Batch completion 2. Create Batch Request 3. 
Retrieve the specific batch """ + custom_logger = TestCustomLogger() + litellm.callbacks = [custom_logger, "datadog"] + if provider == "azure": # Don't have anymore Azure Quota return @@ -39,7 +61,7 @@ def test_create_batch(provider): _current_dir = os.path.dirname(os.path.abspath(__file__)) file_path = os.path.join(_current_dir, file_name) - file_obj = litellm.create_file( + file_obj = await litellm.acreate_file( file=open(file_path, "rb"), purpose="batch", custom_llm_provider=provider, @@ -51,8 +73,8 @@ def test_create_batch(provider): batch_input_file_id is not None ), "Failed to create file, expected a non null file_id but got {batch_input_file_id}" - time.sleep(5) - create_batch_response = litellm.create_batch( + await asyncio.sleep(1) + create_batch_response = await litellm.acreate_batch( completion_window="24h", endpoint="/v1/chat/completions", input_file_id=batch_input_file_id, @@ -61,7 +83,14 @@ def test_create_batch(provider): ) print("response from litellm.create_batch=", create_batch_response) + await asyncio.sleep(6) + # Assert that the create batch event is logged on CustomLogger + assert custom_logger.standard_logging_object is not None + print( + "standard_logging_object=", + json.dumps(custom_logger.standard_logging_object, indent=4), + ) assert ( create_batch_response.id is not None ), f"Failed to create batch, expected a non null batch_id but got {create_batch_response.id}" @@ -73,7 +102,7 @@ def test_create_batch(provider): create_batch_response.input_file_id == batch_input_file_id ), f"Failed to create batch, expected input_file_id to be {batch_input_file_id} but got {create_batch_response.input_file_id}" - retrieved_batch = litellm.retrieve_batch( + retrieved_batch = await litellm.aretrieve_batch( batch_id=create_batch_response.id, custom_llm_provider=provider ) print("retrieved batch=", retrieved_batch) @@ -82,10 +111,10 @@ def test_create_batch(provider): assert retrieved_batch.id == create_batch_response.id # list all batches - list_batches = litellm.list_batches(custom_llm_provider=provider, limit=2) + list_batches = await litellm.alist_batches(custom_llm_provider=provider, limit=2) print("list_batches=", list_batches) - file_content = litellm.file_content( + file_content = await litellm.afile_content( file_id=batch_input_file_id, custom_llm_provider=provider )
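Note on the routing change in litellm/utils.py: the new `_is_async_request(kwargs)` check only routes `acreate_batch` down the async branch if the async entrypoint marks its own call before delegating to the sync `create_batch`. That marking is not visible in this diff. Below is a minimal sketch, under that assumption, of the dispatch pattern the `kwargs.get("acreate_batch", False)` check implies; the function and parameter names are illustrative, and the real body of `litellm.batches.main.acreate_batch` may differ.

# Sketch only (assumption, not part of this diff): an async entrypoint flags itself so
# the @client wrapper's _is_async_request(kwargs) check takes the async code path.
import asyncio
from functools import partial


async def acreate_batch_sketch(create_batch_fn, *args, **kwargs):
    # flag read by _is_async_request inside the @client wrapper
    kwargs["acreate_batch"] = True
    loop = asyncio.get_running_loop()
    # run the sync entrypoint in a worker thread; await the result if it is a coroutine
    init_response = await loop.run_in_executor(
        None, partial(create_batch_fn, *args, **kwargs)
    )
    if asyncio.iscoroutine(init_response):
        return await init_response
    return init_response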
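End to end, decorating `create_batch`/`acreate_batch` with `@client` means batch creation now flows through the same wrapper as `litellm.acompletion`, so success callbacks and the standard logging payload fire for batch calls as well. A condensed usage sketch mirroring the updated test follows; it assumes a provider key such as OPENAI_API_KEY is set, and the logger class name and .jsonl path are illustrative.

# Usage sketch based on the updated test: attach a CustomLogger, create a batch
# asynchronously, and read the standard_logging_object populated by the @client wrapper.
import asyncio
from typing import Optional

import litellm
from litellm.integrations.custom_logger import CustomLogger
from litellm.types.utils import StandardLoggingPayload


class BatchLogger(CustomLogger):  # illustrative, mirrors the test's TestCustomLogger
    def __init__(self):
        super().__init__()
        self.standard_logging_object: Optional[StandardLoggingPayload] = None

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        self.standard_logging_object = kwargs["standard_logging_object"]


async def main():
    logger = BatchLogger()
    litellm.callbacks = [logger]

    file_obj = await litellm.acreate_file(
        file=open("batch_input.jsonl", "rb"),  # illustrative path
        purpose="batch",
        custom_llm_provider="openai",
    )
    batch = await litellm.acreate_batch(
        completion_window="24h",
        endpoint="/v1/chat/completions",
        input_file_id=file_obj.id,
        custom_llm_provider="openai",
    )
    await asyncio.sleep(5)  # give async success callbacks time to run
    print(batch.id, logger.standard_logging_object)


# asyncio.run(main())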