feat: add acreate_batch

This commit is contained in:
Ishaan Jaff 2024-05-28 17:03:29 -07:00
parent 758ed9e923
commit 1ef7cd923c
3 changed files with 143 additions and 53 deletions

View file

@ -51,30 +51,33 @@ async def acreate_file(
LiteLLM Equivalent of POST: POST https://api.openai.com/v1/files
"""
loop = asyncio.get_event_loop()
kwargs["acreate_file"] = True
try:
loop = asyncio.get_event_loop()
kwargs["acreate_file"] = True
# Use a partial function to pass your keyword arguments
func = partial(
create_file,
file,
purpose,
custom_llm_provider,
extra_headers,
extra_body,
**kwargs,
)
# Use a partial function to pass your keyword arguments
func = partial(
create_file,
file,
purpose,
custom_llm_provider,
extra_headers,
extra_body,
**kwargs,
)
# Add the context to the function
ctx = contextvars.copy_context()
func_with_context = partial(ctx.run, func)
init_response = await loop.run_in_executor(None, func_with_context)
if asyncio.iscoroutine(init_response):
response = await init_response
else:
response = init_response # type: ignore
# Add the context to the function
ctx = contextvars.copy_context()
func_with_context = partial(ctx.run, func)
init_response = await loop.run_in_executor(None, func_with_context)
if asyncio.iscoroutine(init_response):
response = await init_response
else:
response = init_response # type: ignore
return response
return response
except Exception as e:
raise e
def create_file(
@ -167,6 +170,52 @@ def create_file(
raise e
async def acreate_batch(
    completion_window: Literal["24h"],
    endpoint: Literal["/v1/chat/completions", "/v1/embeddings", "/v1/completions"],
    input_file_id: str,
    custom_llm_provider: Literal["openai"] = "openai",
    metadata: Optional[Dict[str, str]] = None,
    extra_headers: Optional[Dict[str, str]] = None,
    extra_body: Optional[Dict[str, str]] = None,
    **kwargs,
) -> Batch:
    """
    Creates and executes a batch from an uploaded file of requests.

    LiteLLM equivalent of POST https://api.openai.com/v1/batches

    Args:
        completion_window: Time window for batch completion; only "24h" is accepted.
        endpoint: OpenAI endpoint the batched requests target.
        input_file_id: ID of the previously uploaded input file.
        custom_llm_provider: Provider to route to; only "openai" is supported here.
        metadata: Optional key/value metadata attached to the batch.
        extra_headers: Extra HTTP headers forwarded to the provider.
        extra_body: Extra JSON body fields forwarded to the provider.
        **kwargs: Passed through to the synchronous ``create_batch``.

    Returns:
        The created ``Batch`` object.
    """
    loop = asyncio.get_event_loop()
    # Flag consumed by create_batch: build an async client and return a coroutine.
    kwargs["acreate_batch"] = True

    # Bind all arguments so the call can be dispatched to the default executor.
    func = partial(
        create_batch,
        completion_window,
        endpoint,
        input_file_id,
        custom_llm_provider,
        metadata,
        extra_headers,
        extra_body,
        **kwargs,
    )

    # Preserve contextvars (request-scoped state) across the executor hop.
    ctx = contextvars.copy_context()
    func_with_context = partial(ctx.run, func)
    init_response = await loop.run_in_executor(None, func_with_context)

    # In async mode create_batch returns a coroutine; await it to get the Batch.
    if asyncio.iscoroutine(init_response):
        return await init_response
    return init_response  # type: ignore
def create_batch(
completion_window: Literal["24h"],
endpoint: Literal["/v1/chat/completions", "/v1/embeddings", "/v1/completions"],
@ -176,7 +225,7 @@ def create_batch(
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
) -> Batch:
) -> Union[Batch, Coroutine[Any, Any, Batch]]:
"""
Creates and executes a batch from an uploaded file of request
@ -224,6 +273,8 @@ def create_batch(
elif timeout is None:
timeout = 600.0
_is_async = kwargs.pop("acreate_batch", False) is True
_create_batch_request = CreateBatchRequest(
completion_window=completion_window,
endpoint=endpoint,
@ -240,6 +291,7 @@ def create_batch(
create_batch_data=_create_batch_request,
timeout=timeout,
max_retries=optional_params.max_retries,
_is_async=_is_async,
)
else:
raise litellm.exceptions.BadRequestError(
@ -320,7 +372,10 @@ def retrieve_batch(
extra_body=extra_body,
)
_is_async = kwargs.pop("aretrieve_batch", False) is True
response = openai_batches_instance.retrieve_batch(
_is_async=_is_async,
retrieve_batch_data=_retrieve_batch_request,
api_base=api_base,
api_key=api_key,
@ -354,11 +409,6 @@ def list_batch():
pass
# Async Functions
async def acreate_batch():
pass
async def aretrieve_batch():
pass

View file

@ -1605,47 +1605,76 @@ class OpenAIBatchesAPI(BaseLLM):
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[OpenAI] = None,
) -> OpenAI:
client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
_is_async: bool = False,
) -> Optional[Union[OpenAI, AsyncOpenAI]]:
received_args = locals()
openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = None
if client is None:
data = {}
for k, v in received_args.items():
if k == "self" or k == "client":
if k == "self" or k == "client" or k == "_is_async":
pass
elif k == "api_base" and v is not None:
data["base_url"] = v
elif v is not None:
data[k] = v
openai_client = OpenAI(**data) # type: ignore
if _is_async is True:
openai_client = AsyncOpenAI(**data)
else:
openai_client = OpenAI(**data) # type: ignore
else:
openai_client = client
return openai_client
async def acreate_batch(
    self,
    create_batch_data: CreateBatchRequest,
    openai_client: AsyncOpenAI,
) -> Batch:
    """Create a batch via the async OpenAI client and return the resulting Batch."""
    return await openai_client.batches.create(**create_batch_data)
def create_batch(
self,
_is_async: bool,
create_batch_data: CreateBatchRequest,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[OpenAI] = None,
) -> Batch:
openai_client: OpenAI = self.get_openai_client(
client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
) -> Union[Batch, Coroutine[Any, Any, Batch]]:
openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
_is_async=_is_async,
)
if openai_client is None:
raise ValueError(
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
)
if _is_async is True:
if not isinstance(openai_client, AsyncOpenAI):
raise ValueError(
"OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
)
return self.acreate_batch( # type: ignore
create_batch_data=create_batch_data, openai_client=openai_client
)
response = openai_client.batches.create(**create_batch_data)
return response
def retrieve_batch(
self,
_is_async: bool,
retrieve_batch_data: RetrieveBatchRequest,
api_key: Optional[str],
api_base: Optional[str],
@ -1654,19 +1683,25 @@ class OpenAIBatchesAPI(BaseLLM):
organization: Optional[str],
client: Optional[OpenAI] = None,
):
openai_client: OpenAI = self.get_openai_client(
openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
_is_async=_is_async,
)
if openai_client is None:
raise ValueError(
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
)
response = openai_client.batches.retrieve(**retrieve_batch_data)
return response
def cancel_batch(
self,
_is_async: bool,
cancel_batch_data: CancelBatchRequest,
api_key: Optional[str],
api_base: Optional[str],
@ -1675,14 +1710,19 @@ class OpenAIBatchesAPI(BaseLLM):
organization: Optional[str],
client: Optional[OpenAI] = None,
):
openai_client: OpenAI = self.get_openai_client(
openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
_is_async=_is_async,
)
if openai_client is None:
raise ValueError(
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
)
response = openai_client.batches.cancel(**cancel_batch_data)
return response

View file

@ -89,25 +89,25 @@ async def test_async_create_batch():
batch_input_file_id is not None
), "Failed to create file, expected a non null file_id but got {batch_input_file_id}"
# create_batch_response = litellm.create_batch(
# completion_window="24h",
# endpoint="/v1/chat/completions",
# input_file_id=batch_input_file_id,
# custom_llm_provider="openai",
# metadata={"key1": "value1", "key2": "value2"},
# )
create_batch_response = await litellm.acreate_batch(
completion_window="24h",
endpoint="/v1/chat/completions",
input_file_id=batch_input_file_id,
custom_llm_provider="openai",
metadata={"key1": "value1", "key2": "value2"},
)
# print("response from litellm.create_batch=", create_batch_response)
print("response from litellm.create_batch=", create_batch_response)
# assert (
# create_batch_response.id is not None
# ), f"Failed to create batch, expected a non null batch_id but got {create_batch_response.id}"
# assert (
# create_batch_response.endpoint == "/v1/chat/completions"
# ), f"Failed to create batch, expected endpoint to be /v1/chat/completions but got {create_batch_response.endpoint}"
# assert (
# create_batch_response.input_file_id == batch_input_file_id
# ), f"Failed to create batch, expected input_file_id to be {batch_input_file_id} but got {create_batch_response.input_file_id}"
assert (
create_batch_response.id is not None
), f"Failed to create batch, expected a non null batch_id but got {create_batch_response.id}"
assert (
create_batch_response.endpoint == "/v1/chat/completions"
), f"Failed to create batch, expected endpoint to be /v1/chat/completions but got {create_batch_response.endpoint}"
assert (
create_batch_response.input_file_id == batch_input_file_id
), f"Failed to create batch, expected input_file_id to be {batch_input_file_id} but got {create_batch_response.input_file_id}"
# time.sleep(30)