feat - add aretrieve_batch

Ishaan Jaff 2024-05-28 17:12:41 -07:00
parent c580fe03a0
commit fe704e5857
3 changed files with 69 additions and 16 deletions

@@ -47,7 +47,7 @@ async def acreate_file(
     **kwargs,
 ) -> Coroutine[Any, Any, FileObject]:
     """
-    Files are used to upload documents that can be used with features like Assistants, Fine-tuning, and Batch API.
+    Async: Files are used to upload documents that can be used with features like Assistants, Fine-tuning, and Batch API.
 
     LiteLLM Equivalent of POST: POST https://api.openai.com/v1/files
     """
@@ -181,7 +181,7 @@ async def acreate_batch(
     **kwargs,
 ) -> Coroutine[Any, Any, Batch]:
     """
-    Creates and executes a batch from an uploaded file of requests
+    Async: Creates and executes a batch from an uploaded file of requests
 
     LiteLLM Equivalent of POST: https://api.openai.com/v1/batches
     """
@@ -311,6 +311,48 @@ def create_batch(
         raise e
 
 
+async def aretrieve_batch(
+    batch_id: str,
+    custom_llm_provider: Literal["openai"] = "openai",
+    metadata: Optional[Dict[str, str]] = None,
+    extra_headers: Optional[Dict[str, str]] = None,
+    extra_body: Optional[Dict[str, str]] = None,
+    **kwargs,
+) -> Coroutine[Any, Any, Batch]:
+    """
+    Async: Retrieves a batch.
+
+    LiteLLM Equivalent of GET https://api.openai.com/v1/batches/{batch_id}
+    """
+    try:
+        loop = asyncio.get_event_loop()
+        kwargs["aretrieve_batch"] = True
+
+        # Use a partial function to pass your keyword arguments
+        func = partial(
+            retrieve_batch,
+            batch_id,
+            custom_llm_provider,
+            metadata,
+            extra_headers,
+            extra_body,
+            **kwargs,
+        )
+
+        # Add the context to the function
+        ctx = contextvars.copy_context()
+        func_with_context = partial(ctx.run, func)
+        init_response = await loop.run_in_executor(None, func_with_context)
+        if asyncio.iscoroutine(init_response):
+            response = await init_response
+        else:
+            response = init_response  # type: ignore
+
+        return response
+    except Exception as e:
+        raise e
+
+
 def retrieve_batch(
     batch_id: str,
     custom_llm_provider: Literal["openai"] = "openai",
@@ -318,7 +360,7 @@ def retrieve_batch(
     extra_headers: Optional[Dict[str, str]] = None,
     extra_body: Optional[Dict[str, str]] = None,
     **kwargs,
-):
+) -> Union[Batch, Coroutine[Any, Any, Batch]]:
     """
     Retrieves a batch.
 
@ -409,10 +451,6 @@ def list_batch():
pass
async def aretrieve_batch():
pass
async def acancel_batch():
pass

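The new litellm.aretrieve_batch can be awaited directly. A minimal usage sketch of the entrypoint added in this file (the batch id is a hypothetical placeholder, and an OPENAI_API_KEY is assumed to be set in the environment):

import asyncio

import litellm


async def main():
    # "batch_abc123" is a hypothetical id, for illustration only.
    batch = await litellm.aretrieve_batch(
        batch_id="batch_abc123",
        custom_llm_provider="openai",
    )
    # The call resolves to an OpenAI Batch object.
    print(batch.status)


asyncio.run(main())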

@@ -1672,6 +1672,14 @@ class OpenAIBatchesAPI(BaseLLM):
         response = openai_client.batches.create(**create_batch_data)
         return response
 
+    async def aretrieve_batch(
+        self,
+        retrieve_batch_data: RetrieveBatchRequest,
+        openai_client: AsyncOpenAI,
+    ) -> Batch:
+        response = await openai_client.batches.retrieve(**retrieve_batch_data)
+        return response
+
     def retrieve_batch(
         self,
         _is_async: bool,
@ -1696,6 +1704,15 @@ class OpenAIBatchesAPI(BaseLLM):
raise ValueError(
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
)
if _is_async is True:
if not isinstance(openai_client, AsyncOpenAI):
raise ValueError(
"OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
)
return self.aretrieve_batch( # type: ignore
retrieve_batch_data=retrieve_batch_data, openai_client=openai_client
)
response = openai_client.batches.retrieve(**retrieve_batch_data)
return response

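The _is_async branch above is what makes the sync/async pairing work: the synchronous retrieve_batch runs in a thread-pool executor and, when the async flag is set, hands back a coroutine that the async wrapper awaits on the caller's event loop. A stripped-down sketch of that dispatch pattern, with hypothetical names (do_work/ado_work) standing in for litellm's functions:

import asyncio
import contextvars
from functools import partial


def do_work(is_async: bool = False):
    # Stand-in for retrieve_batch: returns a coroutine when the async
    # flag is set, a plain value otherwise.
    async def _work():
        return "result"

    return _work() if is_async else "result"


async def ado_work():
    loop = asyncio.get_event_loop()
    func = partial(do_work, is_async=True)
    # Copy context variables into the executor thread, mirroring the
    # contextvars.copy_context() step in aretrieve_batch.
    ctx = contextvars.copy_context()
    init_response = await loop.run_in_executor(None, partial(ctx.run, func))
    # The executor returned a coroutine object; await it on this loop.
    if asyncio.iscoroutine(init_response):
        return await init_response
    return init_response


print(asyncio.run(ado_work()))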

@@ -109,17 +109,15 @@ async def test_async_create_batch():
     assert (
         create_batch_response.input_file_id == batch_input_file_id
     ), f"Failed to create batch, expected input_file_id to be {batch_input_file_id} but got {create_batch_response.input_file_id}"
 
     # time.sleep(30)
-
-    # retrieved_batch = litellm.retrieve_batch(
-    #     batch_id=create_batch_response.id, custom_llm_provider="openai"
-    # )
-    # print("retrieved batch=", retrieved_batch)
-    # # just assert that we retrieved a non None batch
-
-    # assert retrieved_batch.id == create_batch_response.id
-    pass
+    await asyncio.sleep(1)
+    retrieved_batch = await litellm.aretrieve_batch(
+        batch_id=create_batch_response.id, custom_llm_provider="openai"
+    )
+    print("retrieved batch=", retrieved_batch)
+    # just assert that we retrieved a non None batch
+    assert retrieved_batch.id == create_batch_response.id
 
 
 def test_retrieve_batch():