feat - add acreate_batch

This commit is contained in:
Ishaan Jaff 2024-05-28 17:03:29 -07:00
parent 758ed9e923
commit 1ef7cd923c
3 changed files with 143 additions and 53 deletions

View file

@ -51,30 +51,33 @@ async def acreate_file(
LiteLLM Equivalent of POST: POST https://api.openai.com/v1/files
"""
loop = asyncio.get_event_loop()
kwargs["acreate_file"] = True
try:
loop = asyncio.get_event_loop()
kwargs["acreate_file"] = True
# Use a partial function to pass your keyword arguments
func = partial(
create_file,
file,
purpose,
custom_llm_provider,
extra_headers,
extra_body,
**kwargs,
)
# Use a partial function to pass your keyword arguments
func = partial(
create_file,
file,
purpose,
custom_llm_provider,
extra_headers,
extra_body,
**kwargs,
)
# Add the context to the function
ctx = contextvars.copy_context()
func_with_context = partial(ctx.run, func)
init_response = await loop.run_in_executor(None, func_with_context)
if asyncio.iscoroutine(init_response):
response = await init_response
else:
response = init_response # type: ignore
# Add the context to the function
ctx = contextvars.copy_context()
func_with_context = partial(ctx.run, func)
init_response = await loop.run_in_executor(None, func_with_context)
if asyncio.iscoroutine(init_response):
response = await init_response
else:
response = init_response # type: ignore
return response
return response
except Exception as e:
raise e
def create_file(
@ -167,6 +170,52 @@ def create_file(
raise e
async def acreate_batch(
completion_window: Literal["24h"],
endpoint: Literal["/v1/chat/completions", "/v1/embeddings", "/v1/completions"],
input_file_id: str,
custom_llm_provider: Literal["openai"] = "openai",
metadata: Optional[Dict[str, str]] = None,
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
) -> Coroutine[Any, Any, Batch]:
"""
Creates and executes a batch from an uploaded file of request
LiteLLM Equivalent of POST: https://api.openai.com/v1/batches
"""
try:
loop = asyncio.get_event_loop()
kwargs["acreate_batch"] = True
# Use a partial function to pass your keyword arguments
func = partial(
create_batch,
completion_window,
endpoint,
input_file_id,
custom_llm_provider,
metadata,
extra_headers,
extra_body,
**kwargs,
)
# Add the context to the function
ctx = contextvars.copy_context()
func_with_context = partial(ctx.run, func)
init_response = await loop.run_in_executor(None, func_with_context)
if asyncio.iscoroutine(init_response):
response = await init_response
else:
response = init_response # type: ignore
return response
except Exception as e:
raise e
def create_batch(
completion_window: Literal["24h"],
endpoint: Literal["/v1/chat/completions", "/v1/embeddings", "/v1/completions"],
@ -176,7 +225,7 @@ def create_batch(
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
) -> Batch:
) -> Union[Batch, Coroutine[Any, Any, Batch]]:
"""
Creates and executes a batch from an uploaded file of request
@ -224,6 +273,8 @@ def create_batch(
elif timeout is None:
timeout = 600.0
_is_async = kwargs.pop("acreate_batch", False) is True
_create_batch_request = CreateBatchRequest(
completion_window=completion_window,
endpoint=endpoint,
@ -240,6 +291,7 @@ def create_batch(
create_batch_data=_create_batch_request,
timeout=timeout,
max_retries=optional_params.max_retries,
_is_async=_is_async,
)
else:
raise litellm.exceptions.BadRequestError(
@ -320,7 +372,10 @@ def retrieve_batch(
extra_body=extra_body,
)
_is_async = kwargs.pop("aretrieve_batch", False) is True
response = openai_batches_instance.retrieve_batch(
_is_async=_is_async,
retrieve_batch_data=_retrieve_batch_request,
api_base=api_base,
api_key=api_key,
@ -354,11 +409,6 @@ def list_batch():
pass
# Async Functions
async def acreate_batch():
pass
async def aretrieve_batch():
pass