diff --git a/litellm/batches/main.py b/litellm/batches/main.py index 056318c8dd..05a6dfd517 100644 --- a/litellm/batches/main.py +++ b/litellm/batches/main.py @@ -51,30 +51,33 @@ async def acreate_file( LiteLLM Equivalent of POST: POST https://api.openai.com/v1/files """ - loop = asyncio.get_event_loop() - kwargs["acreate_file"] = True + try: + loop = asyncio.get_event_loop() + kwargs["acreate_file"] = True - # Use a partial function to pass your keyword arguments - func = partial( - create_file, - file, - purpose, - custom_llm_provider, - extra_headers, - extra_body, - **kwargs, - ) + # Use a partial function to pass your keyword arguments + func = partial( + create_file, + file, + purpose, + custom_llm_provider, + extra_headers, + extra_body, + **kwargs, + ) - # Add the context to the function - ctx = contextvars.copy_context() - func_with_context = partial(ctx.run, func) - init_response = await loop.run_in_executor(None, func_with_context) - if asyncio.iscoroutine(init_response): - response = await init_response - else: - response = init_response # type: ignore + # Add the context to the function + ctx = contextvars.copy_context() + func_with_context = partial(ctx.run, func) + init_response = await loop.run_in_executor(None, func_with_context) + if asyncio.iscoroutine(init_response): + response = await init_response + else: + response = init_response # type: ignore - return response + return response + except Exception as e: + raise e def create_file( @@ -167,6 +170,52 @@ def create_file( raise e +async def acreate_batch( + completion_window: Literal["24h"], + endpoint: Literal["/v1/chat/completions", "/v1/embeddings", "/v1/completions"], + input_file_id: str, + custom_llm_provider: Literal["openai"] = "openai", + metadata: Optional[Dict[str, str]] = None, + extra_headers: Optional[Dict[str, str]] = None, + extra_body: Optional[Dict[str, str]] = None, + **kwargs, +) -> Coroutine[Any, Any, Batch]: + """ + Creates and executes a batch from an uploaded file of request + + LiteLLM Equivalent of POST: https://api.openai.com/v1/batches + """ + try: + loop = asyncio.get_event_loop() + kwargs["acreate_batch"] = True + + # Use a partial function to pass your keyword arguments + func = partial( + create_batch, + completion_window, + endpoint, + input_file_id, + custom_llm_provider, + metadata, + extra_headers, + extra_body, + **kwargs, + ) + + # Add the context to the function + ctx = contextvars.copy_context() + func_with_context = partial(ctx.run, func) + init_response = await loop.run_in_executor(None, func_with_context) + if asyncio.iscoroutine(init_response): + response = await init_response + else: + response = init_response # type: ignore + + return response + except Exception as e: + raise e + + def create_batch( completion_window: Literal["24h"], endpoint: Literal["/v1/chat/completions", "/v1/embeddings", "/v1/completions"], @@ -176,7 +225,7 @@ def create_batch( extra_headers: Optional[Dict[str, str]] = None, extra_body: Optional[Dict[str, str]] = None, **kwargs, -) -> Batch: +) -> Union[Batch, Coroutine[Any, Any, Batch]]: """ Creates and executes a batch from an uploaded file of request @@ -224,6 +273,8 @@ def create_batch( elif timeout is None: timeout = 600.0 + _is_async = kwargs.pop("acreate_batch", False) is True + _create_batch_request = CreateBatchRequest( completion_window=completion_window, endpoint=endpoint, @@ -240,6 +291,7 @@ def create_batch( create_batch_data=_create_batch_request, timeout=timeout, max_retries=optional_params.max_retries, + _is_async=_is_async, ) else: 
raise litellm.exceptions.BadRequestError( @@ -320,7 +372,10 @@ def retrieve_batch( extra_body=extra_body, ) + _is_async = kwargs.pop("aretrieve_batch", False) is True + response = openai_batches_instance.retrieve_batch( + _is_async=_is_async, retrieve_batch_data=_retrieve_batch_request, api_base=api_base, api_key=api_key, @@ -354,11 +409,6 @@ def list_batch(): pass -# Async Functions -async def acreate_batch(): - pass - - async def aretrieve_batch(): pass diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py index 05fc5784b6..fa1f13c70a 100644 --- a/litellm/llms/openai.py +++ b/litellm/llms/openai.py @@ -1605,47 +1605,76 @@ class OpenAIBatchesAPI(BaseLLM): timeout: Union[float, httpx.Timeout], max_retries: Optional[int], organization: Optional[str], - client: Optional[OpenAI] = None, - ) -> OpenAI: + client: Optional[Union[OpenAI, AsyncOpenAI]] = None, + _is_async: bool = False, + ) -> Optional[Union[OpenAI, AsyncOpenAI]]: received_args = locals() + openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = None if client is None: data = {} for k, v in received_args.items(): - if k == "self" or k == "client": + if k == "self" or k == "client" or k == "_is_async": pass elif k == "api_base" and v is not None: data["base_url"] = v elif v is not None: data[k] = v - openai_client = OpenAI(**data) # type: ignore + if _is_async is True: + openai_client = AsyncOpenAI(**data) + else: + openai_client = OpenAI(**data) # type: ignore else: openai_client = client return openai_client + async def acreate_batch( + self, + create_batch_data: CreateBatchRequest, + openai_client: AsyncOpenAI, + ) -> Batch: + response = await openai_client.batches.create(**create_batch_data) + return response + def create_batch( self, + _is_async: bool, create_batch_data: CreateBatchRequest, api_key: Optional[str], api_base: Optional[str], timeout: Union[float, httpx.Timeout], max_retries: Optional[int], organization: Optional[str], - client: Optional[OpenAI] = None, - ) -> Batch: - openai_client: OpenAI = self.get_openai_client( + client: Optional[Union[OpenAI, AsyncOpenAI]] = None, + ) -> Union[Batch, Coroutine[Any, Any, Batch]]: + openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client( api_key=api_key, api_base=api_base, timeout=timeout, max_retries=max_retries, organization=organization, client=client, + _is_async=_is_async, ) + if openai_client is None: + raise ValueError( + "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment." + ) + + if _is_async is True: + if not isinstance(openai_client, AsyncOpenAI): + raise ValueError( + "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client." + ) + return self.acreate_batch( # type: ignore + create_batch_data=create_batch_data, openai_client=openai_client + ) response = openai_client.batches.create(**create_batch_data) return response def retrieve_batch( self, + _is_async: bool, retrieve_batch_data: RetrieveBatchRequest, api_key: Optional[str], api_base: Optional[str], @@ -1654,19 +1683,25 @@ class OpenAIBatchesAPI(BaseLLM): organization: Optional[str], client: Optional[OpenAI] = None, ): - openai_client: OpenAI = self.get_openai_client( + openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client( api_key=api_key, api_base=api_base, timeout=timeout, max_retries=max_retries, organization=organization, client=client, + _is_async=_is_async, ) + if openai_client is None: + raise ValueError( + "OpenAI client is not initialized. 
Make sure api_key is passed or OPENAI_API_KEY is set in the environment." + ) response = openai_client.batches.retrieve(**retrieve_batch_data) return response def cancel_batch( self, + _is_async: bool, cancel_batch_data: CancelBatchRequest, api_key: Optional[str], api_base: Optional[str], @@ -1675,14 +1710,19 @@ class OpenAIBatchesAPI(BaseLLM): organization: Optional[str], client: Optional[OpenAI] = None, ): - openai_client: OpenAI = self.get_openai_client( + openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client( api_key=api_key, api_base=api_base, timeout=timeout, max_retries=max_retries, organization=organization, client=client, + _is_async=_is_async, ) + if openai_client is None: + raise ValueError( + "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment." + ) response = openai_client.batches.cancel(**cancel_batch_data) return response diff --git a/litellm/tests/test_openai_batches.py b/litellm/tests/test_openai_batches.py index 2de417619b..497662006d 100644 --- a/litellm/tests/test_openai_batches.py +++ b/litellm/tests/test_openai_batches.py @@ -89,25 +89,25 @@ async def test_async_create_batch(): batch_input_file_id is not None ), "Failed to create file, expected a non null file_id but got {batch_input_file_id}" - # create_batch_response = litellm.create_batch( - # completion_window="24h", - # endpoint="/v1/chat/completions", - # input_file_id=batch_input_file_id, - # custom_llm_provider="openai", - # metadata={"key1": "value1", "key2": "value2"}, - # ) + create_batch_response = await litellm.acreate_batch( + completion_window="24h", + endpoint="/v1/chat/completions", + input_file_id=batch_input_file_id, + custom_llm_provider="openai", + metadata={"key1": "value1", "key2": "value2"}, + ) - # print("response from litellm.create_batch=", create_batch_response) + print("response from litellm.create_batch=", create_batch_response) - # assert ( - # create_batch_response.id is not None - # ), f"Failed to create batch, expected a non null batch_id but got {create_batch_response.id}" - # assert ( - # create_batch_response.endpoint == "/v1/chat/completions" - # ), f"Failed to create batch, expected endpoint to be /v1/chat/completions but got {create_batch_response.endpoint}" - # assert ( - # create_batch_response.input_file_id == batch_input_file_id - # ), f"Failed to create batch, expected input_file_id to be {batch_input_file_id} but got {create_batch_response.input_file_id}" + assert ( + create_batch_response.id is not None + ), f"Failed to create batch, expected a non null batch_id but got {create_batch_response.id}" + assert ( + create_batch_response.endpoint == "/v1/chat/completions" + ), f"Failed to create batch, expected endpoint to be /v1/chat/completions but got {create_batch_response.endpoint}" + assert ( + create_batch_response.input_file_id == batch_input_file_id + ), f"Failed to create batch, expected input_file_id to be {batch_input_file_id} but got {create_batch_response.input_file_id}" # time.sleep(30)
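
For reviewers who want the async-dispatch mechanism used by acreate_file/acreate_batch in isolation, here is a minimal, self-contained sketch of the same flow. The names do_work, ado_work and _native_async_call are hypothetical stand-ins, not litellm functions; the shape mirrors the diff: the async wrapper flags the call via a kwarg, runs the sync entrypoint in the default executor under a copied contextvars context, and awaits the coroutine the sync entrypoint hands back when the flag is set.

import asyncio
import contextvars
from functools import partial


async def _native_async_call(x: int) -> int:
    # Stand-in for the provider's native async API (e.g. AsyncOpenAI.batches.create).
    await asyncio.sleep(0)
    return x * 2


def do_work(x: int, **kwargs):
    # Sync entrypoint. When the async wrapper set the flag, return the
    # un-awaited coroutine instead of a plain result, the same handshake
    # create_batch performs with the popped "acreate_batch" kwarg / _is_async.
    _is_async = kwargs.pop("ado_work", False) is True
    if _is_async:
        return _native_async_call(x)
    return x * 2


async def ado_work(x: int, **kwargs) -> int:
    # Async wrapper, mirroring acreate_batch: flag the call, run the sync
    # entrypoint in the default executor under a copied contextvars context,
    # then await whatever coroutine it returned.
    loop = asyncio.get_event_loop()
    kwargs["ado_work"] = True
    func = partial(do_work, x, **kwargs)
    ctx = contextvars.copy_context()
    init_response = await loop.run_in_executor(None, partial(ctx.run, func))
    if asyncio.iscoroutine(init_response):
        return await init_response
    return init_response


if __name__ == "__main__":
    print(asyncio.run(ado_work(21)))  # prints 42

The point of the flag-and-pop handshake is that a single sync entrypoint (here create_batch) can serve both calling styles: regular callers get a Batch back, while the async wrapper receives an un-awaited coroutine and awaits it itself.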
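
The get_openai_client change keys off the same flag to decide which client flavor to build. A minimal sketch of that selection, assuming the openai package is installed; make_client is a hypothetical helper (not the litellm method), no network calls are made at construction time, and the key below is a placeholder.

from typing import Optional, Union

from openai import AsyncOpenAI, OpenAI


def make_client(
    api_key: Optional[str] = None,
    _is_async: bool = False,
    client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
) -> Union[OpenAI, AsyncOpenAI]:
    # Reuse a caller-supplied client when given; otherwise build the flavor
    # the _is_async flag asks for, mirroring the branch added to get_openai_client.
    if client is not None:
        return client
    if _is_async:
        return AsyncOpenAI(api_key=api_key)
    return OpenAI(api_key=api_key)


sync_client = make_client(api_key="sk-placeholder")                   # OpenAI
async_client = make_client(api_key="sk-placeholder", _is_async=True)  # AsyncOpenAI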
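
Putting the new surface together, a hedged end-to-end usage sketch that mirrors the re-enabled test. Assumptions: acreate_file and acreate_batch are exported at the package level (as the test's litellm.acreate_batch call suggests), OPENAI_API_KEY is set in the environment, "batch_job.jsonl" is a placeholder path to an OpenAI-format batch input file, and the uploaded file object exposes an id field like the OpenAI FileObject.

import asyncio

import litellm


async def main() -> None:
    # Upload the batch input file (async variant of litellm.create_file).
    with open("batch_job.jsonl", "rb") as f:
        file_obj = await litellm.acreate_file(
            file=f,
            purpose="batch",
            custom_llm_provider="openai",
        )

    # Create the batch against the uploaded file, as the updated test does.
    batch = await litellm.acreate_batch(
        completion_window="24h",
        endpoint="/v1/chat/completions",
        input_file_id=file_obj.id,
        custom_llm_provider="openai",
        metadata={"key1": "value1", "key2": "value2"},
    )
    print(batch.id, batch.endpoint, batch.input_file_id)


if __name__ == "__main__":
    asyncio.run(main())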