feat - add aretrieve_batch

This commit is contained in:
Ishaan Jaff 2024-05-28 17:12:41 -07:00
parent 1ef7cd923c
commit 6688215c18
3 changed files with 69 additions and 16 deletions

View file

@ -1672,6 +1672,14 @@ class OpenAIBatchesAPI(BaseLLM):
response = openai_client.batches.create(**create_batch_data)
return response
async def aretrieve_batch(
self,
retrieve_batch_data: RetrieveBatchRequest,
openai_client: AsyncOpenAI,
) -> Batch:
response = await openai_client.batches.retrieve(**retrieve_batch_data)
return response
def retrieve_batch(
self,
_is_async: bool,
@ -1696,6 +1704,15 @@ class OpenAIBatchesAPI(BaseLLM):
raise ValueError(
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
)
if _is_async is True:
if not isinstance(openai_client, AsyncOpenAI):
raise ValueError(
"OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
)
return self.aretrieve_batch( # type: ignore
retrieve_batch_data=retrieve_batch_data, openai_client=openai_client
)
response = openai_client.batches.retrieve(**retrieve_batch_data)
return response