diff --git a/litellm/batches/main.py b/litellm/batches/main.py index 05a6dfd517..119043e4c1 100644 --- a/litellm/batches/main.py +++ b/litellm/batches/main.py @@ -47,7 +47,7 @@ async def acreate_file( **kwargs, ) -> Coroutine[Any, Any, FileObject]: """ - Files are used to upload documents that can be used with features like Assistants, Fine-tuning, and Batch API. + Async: Files are used to upload documents that can be used with features like Assistants, Fine-tuning, and Batch API. LiteLLM Equivalent of POST: POST https://api.openai.com/v1/files """ @@ -181,7 +181,7 @@ async def acreate_batch( **kwargs, ) -> Coroutine[Any, Any, Batch]: """ - Creates and executes a batch from an uploaded file of request + Async: Creates and executes a batch from an uploaded file of request LiteLLM Equivalent of POST: https://api.openai.com/v1/batches """ @@ -311,6 +311,48 @@ def create_batch( raise e +async def aretrieve_batch( + batch_id: str, + custom_llm_provider: Literal["openai"] = "openai", + metadata: Optional[Dict[str, str]] = None, + extra_headers: Optional[Dict[str, str]] = None, + extra_body: Optional[Dict[str, str]] = None, + **kwargs, +) -> Coroutine[Any, Any, Batch]: + """ + Async: Retrieves a batch. + + LiteLLM Equivalent of GET https://api.openai.com/v1/batches/{batch_id} + """ + try: + loop = asyncio.get_event_loop() + kwargs["aretrieve_batch"] = True + + # Use a partial function to pass your keyword arguments + func = partial( + retrieve_batch, + batch_id, + custom_llm_provider, + metadata, + extra_headers, + extra_body, + **kwargs, + ) + + # Add the context to the function + ctx = contextvars.copy_context() + func_with_context = partial(ctx.run, func) + init_response = await loop.run_in_executor(None, func_with_context) + if asyncio.iscoroutine(init_response): + response = await init_response + else: + response = init_response # type: ignore + + return response + except Exception as e: + raise e + + def retrieve_batch( batch_id: str, custom_llm_provider: Literal["openai"] = "openai", @@ -318,7 +360,7 @@ def retrieve_batch( extra_headers: Optional[Dict[str, str]] = None, extra_body: Optional[Dict[str, str]] = None, **kwargs, -): +) -> Union[Batch, Coroutine[Any, Any, Batch]]: """ Retrieves a batch. @@ -409,10 +451,6 @@ def list_batch(): pass -async def aretrieve_batch(): - pass - - async def acancel_batch(): pass diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py index fa1f13c70a..43d088f0db 100644 --- a/litellm/llms/openai.py +++ b/litellm/llms/openai.py @@ -1672,6 +1672,14 @@ class OpenAIBatchesAPI(BaseLLM): response = openai_client.batches.create(**create_batch_data) return response + async def aretrieve_batch( + self, + retrieve_batch_data: RetrieveBatchRequest, + openai_client: AsyncOpenAI, + ) -> Batch: + response = await openai_client.batches.retrieve(**retrieve_batch_data) + return response + def retrieve_batch( self, _is_async: bool, @@ -1696,6 +1704,15 @@ class OpenAIBatchesAPI(BaseLLM): raise ValueError( "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment." ) + + if _is_async is True: + if not isinstance(openai_client, AsyncOpenAI): + raise ValueError( + "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client." + ) + return self.aretrieve_batch( # type: ignore + retrieve_batch_data=retrieve_batch_data, openai_client=openai_client + ) response = openai_client.batches.retrieve(**retrieve_batch_data) return response diff --git a/litellm/tests/test_openai_batches.py b/litellm/tests/test_openai_batches.py index 497662006d..2bf0090128 100644 --- a/litellm/tests/test_openai_batches.py +++ b/litellm/tests/test_openai_batches.py @@ -109,17 +109,15 @@ async def test_async_create_batch(): create_batch_response.input_file_id == batch_input_file_id ), f"Failed to create batch, expected input_file_id to be {batch_input_file_id} but got {create_batch_response.input_file_id}" - # time.sleep(30) + await asyncio.sleep(1) - # retrieved_batch = litellm.retrieve_batch( - # batch_id=create_batch_response.id, custom_llm_provider="openai" - # ) - # print("retrieved batch=", retrieved_batch) - # # just assert that we retrieved a non None batch + retrieved_batch = await litellm.aretrieve_batch( + batch_id=create_batch_response.id, custom_llm_provider="openai" + ) + print("retrieved batch=", retrieved_batch) + # just assert that we retrieved a non None batch - # assert retrieved_batch.id == create_batch_response.id - - pass + assert retrieved_batch.id == create_batch_response.id def test_retrieve_batch():