feat - add aretrieve_batch

This commit is contained in:
Ishaan Jaff 2024-05-28 17:12:41 -07:00
parent c580fe03a0
commit fe704e5857
3 changed files with 69 additions and 16 deletions

View file

@ -47,7 +47,7 @@ async def acreate_file(
**kwargs,
) -> Coroutine[Any, Any, FileObject]:
"""
Files are used to upload documents that can be used with features like Assistants, Fine-tuning, and Batch API.
Async: Files are used to upload documents that can be used with features like Assistants, Fine-tuning, and Batch API.
LiteLLM Equivalent of POST: POST https://api.openai.com/v1/files
"""
@ -181,7 +181,7 @@ async def acreate_batch(
**kwargs,
) -> Coroutine[Any, Any, Batch]:
"""
Creates and executes a batch from an uploaded file of request
Async: Creates and executes a batch from an uploaded file of request
LiteLLM Equivalent of POST: https://api.openai.com/v1/batches
"""
@ -311,6 +311,48 @@ def create_batch(
raise e
async def aretrieve_batch(
batch_id: str,
custom_llm_provider: Literal["openai"] = "openai",
metadata: Optional[Dict[str, str]] = None,
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
) -> Coroutine[Any, Any, Batch]:
"""
Async: Retrieves a batch.
LiteLLM Equivalent of GET https://api.openai.com/v1/batches/{batch_id}
"""
try:
loop = asyncio.get_event_loop()
kwargs["aretrieve_batch"] = True
# Use a partial function to pass your keyword arguments
func = partial(
retrieve_batch,
batch_id,
custom_llm_provider,
metadata,
extra_headers,
extra_body,
**kwargs,
)
# Add the context to the function
ctx = contextvars.copy_context()
func_with_context = partial(ctx.run, func)
init_response = await loop.run_in_executor(None, func_with_context)
if asyncio.iscoroutine(init_response):
response = await init_response
else:
response = init_response # type: ignore
return response
except Exception as e:
raise e
def retrieve_batch(
batch_id: str,
custom_llm_provider: Literal["openai"] = "openai",
@ -318,7 +360,7 @@ def retrieve_batch(
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
):
) -> Union[Batch, Coroutine[Any, Any, Batch]]:
"""
Retrieves a batch.
@ -409,10 +451,6 @@ def list_batch():
pass
async def aretrieve_batch():
pass
async def acancel_batch():
pass