feat - add acreate_batch

Ishaan Jaff 2024-05-28 17:03:29 -07:00
parent 758ed9e923
commit 1ef7cd923c
3 changed files with 143 additions and 53 deletions

View file

@@ -51,30 +51,33 @@ async def acreate_file(
     LiteLLM Equivalent of POST: POST https://api.openai.com/v1/files
     """
-    loop = asyncio.get_event_loop()
-    kwargs["acreate_file"] = True
+    try:
+        loop = asyncio.get_event_loop()
+        kwargs["acreate_file"] = True
 
-    # Use a partial function to pass your keyword arguments
-    func = partial(
-        create_file,
-        file,
-        purpose,
-        custom_llm_provider,
-        extra_headers,
-        extra_body,
-        **kwargs,
-    )
+        # Use a partial function to pass your keyword arguments
+        func = partial(
+            create_file,
+            file,
+            purpose,
+            custom_llm_provider,
+            extra_headers,
+            extra_body,
+            **kwargs,
+        )
 
-    # Add the context to the function
-    ctx = contextvars.copy_context()
-    func_with_context = partial(ctx.run, func)
-    init_response = await loop.run_in_executor(None, func_with_context)
-    if asyncio.iscoroutine(init_response):
-        response = await init_response
-    else:
-        response = init_response  # type: ignore
-
-    return response
+        # Add the context to the function
+        ctx = contextvars.copy_context()
+        func_with_context = partial(ctx.run, func)
+        init_response = await loop.run_in_executor(None, func_with_context)
+        if asyncio.iscoroutine(init_response):
+            response = await init_response
+        else:
+            response = init_response  # type: ignore
+
+        return response
+    except Exception as e:
+        raise e
 
 
 def create_file(
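
The try/except added here re-raises unchanged, so the substance of the wrapper stays the same: run the synchronous implementation on the default thread executor, copying contextvars first so request-scoped state survives the hop. A minimal self-contained sketch of that pattern follows; the names are illustrative, not part of this commit:

import asyncio
import contextvars
from functools import partial

def do_work(x: int) -> int:
    # stand-in for a blocking, synchronous implementation
    return x * 2

async def ado_work(x: int) -> int:
    loop = asyncio.get_event_loop()
    # copy the caller's context so contextvars survive the executor hop
    ctx = contextvars.copy_context()
    func_with_context = partial(ctx.run, partial(do_work, x))
    return await loop.run_in_executor(None, func_with_context)

print(asyncio.run(ado_work(21)))  # 42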
@@ -167,6 +170,52 @@ def create_file(
         raise e
 
 
+async def acreate_batch(
+    completion_window: Literal["24h"],
+    endpoint: Literal["/v1/chat/completions", "/v1/embeddings", "/v1/completions"],
+    input_file_id: str,
+    custom_llm_provider: Literal["openai"] = "openai",
+    metadata: Optional[Dict[str, str]] = None,
+    extra_headers: Optional[Dict[str, str]] = None,
+    extra_body: Optional[Dict[str, str]] = None,
+    **kwargs,
+) -> Coroutine[Any, Any, Batch]:
+    """
+    Creates and executes a batch from an uploaded file of requests
+
+    LiteLLM Equivalent of POST: https://api.openai.com/v1/batches
+    """
+    try:
+        loop = asyncio.get_event_loop()
+        kwargs["acreate_batch"] = True
+
+        # Use a partial function to pass your keyword arguments
+        func = partial(
+            create_batch,
+            completion_window,
+            endpoint,
+            input_file_id,
+            custom_llm_provider,
+            metadata,
+            extra_headers,
+            extra_body,
+            **kwargs,
+        )
+
+        # Add the context to the function
+        ctx = contextvars.copy_context()
+        func_with_context = partial(ctx.run, func)
+        init_response = await loop.run_in_executor(None, func_with_context)
+        if asyncio.iscoroutine(init_response):
+            response = await init_response
+        else:
+            response = init_response  # type: ignore
+
+        return response
+    except Exception as e:
+        raise e
+
+
 def create_batch(
     completion_window: Literal["24h"],
     endpoint: Literal["/v1/chat/completions", "/v1/embeddings", "/v1/completions"],
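
For reference, a minimal usage sketch of the new coroutine, mirroring the call exercised by the updated test at the bottom of this commit. It assumes OPENAI_API_KEY is set in the environment; the file id is a hypothetical value from a prior file upload:

import asyncio
import litellm

async def main():
    batch = await litellm.acreate_batch(
        completion_window="24h",
        endpoint="/v1/chat/completions",
        input_file_id="file-abc123",  # hypothetical id returned by a prior file upload
        custom_llm_provider="openai",
        metadata={"key1": "value1"},
    )
    print("created batch:", batch.id)

asyncio.run(main())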
@@ -176,7 +225,7 @@ def create_batch(
     extra_headers: Optional[Dict[str, str]] = None,
     extra_body: Optional[Dict[str, str]] = None,
     **kwargs,
-) -> Batch:
+) -> Union[Batch, Coroutine[Any, Any, Batch]]:
     """
     Creates and executes a batch from an uploaded file of requests
@@ -224,6 +273,8 @@ def create_batch(
         elif timeout is None:
             timeout = 600.0
 
+        _is_async = kwargs.pop("acreate_batch", False) is True
+
         _create_batch_request = CreateBatchRequest(
             completion_window=completion_window,
             endpoint=endpoint,
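
The kwargs.pop("acreate_batch", ...) added here is the receiving end of the flag set by acreate_batch above; popping it keeps the marker out of the kwargs forwarded downstream. Reduced to essentials, the handshake looks like this (illustrative names, not the commit's API):

import asyncio

async def _acreate(data: str) -> str:
    return f"async:{data}"

def create(data: str, **kwargs):
    # pop, so the flag is not forwarded with the remaining kwargs
    _is_async = kwargs.pop("acreate", False) is True
    if _is_async:
        # hand the coroutine back un-awaited; the async caller awaits it
        return _acreate(data)
    return f"sync:{data}"

print(create("x"))                             # sync:x
print(asyncio.run(create("x", acreate=True)))  # async:x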
@@ -240,6 +291,7 @@ def create_batch(
                 create_batch_data=_create_batch_request,
                 timeout=timeout,
                 max_retries=optional_params.max_retries,
+                _is_async=_is_async,
             )
         else:
             raise litellm.exceptions.BadRequestError(
@@ -320,7 +372,10 @@ def retrieve_batch(
             extra_body=extra_body,
         )
 
+        _is_async = kwargs.pop("aretrieve_batch", False) is True
+
         response = openai_batches_instance.retrieve_batch(
+            _is_async=_is_async,
             retrieve_batch_data=_retrieve_batch_request,
             api_base=api_base,
             api_key=api_key,
@@ -354,11 +409,6 @@ def list_batch():
     pass
 
 
-# Async Functions
-
-async def acreate_batch():
-    pass
-
 async def aretrieve_batch():
     pass

View file

@@ -1605,47 +1605,76 @@ class OpenAIBatchesAPI(BaseLLM):
         timeout: Union[float, httpx.Timeout],
         max_retries: Optional[int],
         organization: Optional[str],
-        client: Optional[OpenAI] = None,
-    ) -> OpenAI:
+        client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
+        _is_async: bool = False,
+    ) -> Optional[Union[OpenAI, AsyncOpenAI]]:
         received_args = locals()
+        openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = None
         if client is None:
             data = {}
             for k, v in received_args.items():
-                if k == "self" or k == "client":
+                if k == "self" or k == "client" or k == "_is_async":
                     pass
                 elif k == "api_base" and v is not None:
                     data["base_url"] = v
                 elif v is not None:
                     data[k] = v
-            openai_client = OpenAI(**data)  # type: ignore
+
+            if _is_async is True:
+                openai_client = AsyncOpenAI(**data)
+            else:
+                openai_client = OpenAI(**data)  # type: ignore
         else:
             openai_client = client
 
         return openai_client
 
+    async def acreate_batch(
+        self,
+        create_batch_data: CreateBatchRequest,
+        openai_client: AsyncOpenAI,
+    ) -> Batch:
+        response = await openai_client.batches.create(**create_batch_data)
+        return response
+
     def create_batch(
         self,
+        _is_async: bool,
         create_batch_data: CreateBatchRequest,
         api_key: Optional[str],
         api_base: Optional[str],
         timeout: Union[float, httpx.Timeout],
         max_retries: Optional[int],
         organization: Optional[str],
-        client: Optional[OpenAI] = None,
-    ) -> Batch:
-        openai_client: OpenAI = self.get_openai_client(
+        client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
+    ) -> Union[Batch, Coroutine[Any, Any, Batch]]:
+        openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
             api_key=api_key,
             api_base=api_base,
             timeout=timeout,
             max_retries=max_retries,
             organization=organization,
             client=client,
+            _is_async=_is_async,
         )
+        if openai_client is None:
+            raise ValueError(
+                "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
+            )
+
+        if _is_async is True:
+            if not isinstance(openai_client, AsyncOpenAI):
+                raise ValueError(
+                    "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
+                )
+            return self.acreate_batch(  # type: ignore
+                create_batch_data=create_batch_data, openai_client=openai_client
+            )
+
         response = openai_client.batches.create(**create_batch_data)
         return response
 
     def retrieve_batch(
         self,
+        _is_async: bool,
         retrieve_batch_data: RetrieveBatchRequest,
         api_key: Optional[str],
         api_base: Optional[str],
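
Note the relay shape above: create_batch does not await self.acreate_batch(...), it returns the coroutine, and the module-level wrapper awaits it after its asyncio.iscoroutine check. The client selection itself is symmetric; a compact sketch, assuming only that the openai SDK is installed (the key below is a placeholder):

from openai import AsyncOpenAI, OpenAI

def get_client(_is_async: bool, **data):
    # same constructor kwargs either way; only the class differs
    return AsyncOpenAI(**data) if _is_async else OpenAI(**data)

client = get_client(_is_async=True, api_key="sk-placeholder")
print(type(client).__name__)  # AsyncOpenAI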
@@ -1654,19 +1683,25 @@ class OpenAIBatchesAPI(BaseLLM):
         organization: Optional[str],
         client: Optional[OpenAI] = None,
     ):
-        openai_client: OpenAI = self.get_openai_client(
+        openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
             api_key=api_key,
             api_base=api_base,
             timeout=timeout,
             max_retries=max_retries,
             organization=organization,
             client=client,
+            _is_async=_is_async,
         )
+        if openai_client is None:
+            raise ValueError(
+                "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
+            )
+
         response = openai_client.batches.retrieve(**retrieve_batch_data)
         return response
 
     def cancel_batch(
         self,
+        _is_async: bool,
         cancel_batch_data: CancelBatchRequest,
         api_key: Optional[str],
         api_base: Optional[str],
@@ -1675,14 +1710,19 @@ class OpenAIBatchesAPI(BaseLLM):
         organization: Optional[str],
         client: Optional[OpenAI] = None,
     ):
-        openai_client: OpenAI = self.get_openai_client(
+        openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
             api_key=api_key,
             api_base=api_base,
             timeout=timeout,
             max_retries=max_retries,
             organization=organization,
             client=client,
+            _is_async=_is_async,
         )
+        if openai_client is None:
+            raise ValueError(
+                "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
+            )
+
         response = openai_client.batches.cancel(**cancel_batch_data)
         return response
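
The synchronous entry points keep working unchanged: when the async flag is absent, _is_async is False and the existing code path runs as before. A sync counterpart of the earlier usage sketch, under the same assumptions (OPENAI_API_KEY set, hypothetical file id):

import litellm

batch = litellm.create_batch(
    completion_window="24h",
    endpoint="/v1/chat/completions",
    input_file_id="file-abc123",  # hypothetical id from a prior file upload
    custom_llm_provider="openai",
    metadata={"key1": "value1"},
)
print("created batch:", batch.id)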

View file

@@ -89,25 +89,25 @@ async def test_async_create_batch():
         batch_input_file_id is not None
     ), "Failed to create file, expected a non null file_id but got {batch_input_file_id}"
 
-    # create_batch_response = litellm.create_batch(
-    #     completion_window="24h",
-    #     endpoint="/v1/chat/completions",
-    #     input_file_id=batch_input_file_id,
-    #     custom_llm_provider="openai",
-    #     metadata={"key1": "value1", "key2": "value2"},
-    # )
-
-    # print("response from litellm.create_batch=", create_batch_response)
-
-    # assert (
-    #     create_batch_response.id is not None
-    # ), f"Failed to create batch, expected a non null batch_id but got {create_batch_response.id}"
-    # assert (
-    #     create_batch_response.endpoint == "/v1/chat/completions"
-    # ), f"Failed to create batch, expected endpoint to be /v1/chat/completions but got {create_batch_response.endpoint}"
-    # assert (
-    #     create_batch_response.input_file_id == batch_input_file_id
-    # ), f"Failed to create batch, expected input_file_id to be {batch_input_file_id} but got {create_batch_response.input_file_id}"
+    create_batch_response = await litellm.acreate_batch(
+        completion_window="24h",
+        endpoint="/v1/chat/completions",
+        input_file_id=batch_input_file_id,
+        custom_llm_provider="openai",
+        metadata={"key1": "value1", "key2": "value2"},
+    )
+
+    print("response from litellm.create_batch=", create_batch_response)
+
+    assert (
+        create_batch_response.id is not None
+    ), f"Failed to create batch, expected a non null batch_id but got {create_batch_response.id}"
+    assert (
+        create_batch_response.endpoint == "/v1/chat/completions"
+    ), f"Failed to create batch, expected endpoint to be /v1/chat/completions but got {create_batch_response.endpoint}"
+    assert (
+        create_batch_response.input_file_id == batch_input_file_id
+    ), f"Failed to create batch, expected input_file_id to be {batch_input_file_id} but got {create_batch_response.input_file_id}"
 
     # time.sleep(30)