Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 11:14:04 +00:00)
feat - add acreate_batch

Commit: 1ef7cd923c
Parent: 758ed9e923
3 changed files with 143 additions and 53 deletions
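For orientation, the new async entry point added by this commit can be called roughly as follows. This is a minimal sketch based on the signatures introduced below; the input file name, its JSONL contents, and the metadata values are illustrative placeholders, and it assumes OPENAI_API_KEY is set in the environment.

    # Minimal usage sketch (illustrative; "batch_input.jsonl" and the metadata are placeholders).
    import asyncio
    import litellm

    async def main():
        # Upload the batch input file first (acreate_file already exists; it is touched below).
        file_obj = await litellm.acreate_file(
            file=open("batch_input.jsonl", "rb"),
            purpose="batch",
            custom_llm_provider="openai",
        )
        # Then create the batch with the new acreate_batch added in this commit.
        batch = await litellm.acreate_batch(
            completion_window="24h",
            endpoint="/v1/chat/completions",
            input_file_id=file_obj.id,
            custom_llm_provider="openai",
            metadata={"key1": "value1"},
        )
        print(batch.id)

    asyncio.run(main())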
@@ -51,30 +51,33 @@ async def acreate_file(

     LiteLLM Equivalent of POST: POST https://api.openai.com/v1/files
     """
-    loop = asyncio.get_event_loop()
-    kwargs["acreate_file"] = True
+    try:
+        loop = asyncio.get_event_loop()
+        kwargs["acreate_file"] = True

         # Use a partial function to pass your keyword arguments
         func = partial(
             create_file,
             file,
             purpose,
             custom_llm_provider,
             extra_headers,
             extra_body,
             **kwargs,
         )

         # Add the context to the function
         ctx = contextvars.copy_context()
         func_with_context = partial(ctx.run, func)
         init_response = await loop.run_in_executor(None, func_with_context)
         if asyncio.iscoroutine(init_response):
             response = await init_response
         else:
             response = init_response  # type: ignore

         return response
+    except Exception as e:
+        raise e


 def create_file(
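The wrapper above (and acreate_batch below) bridges async callers to the synchronous implementation by running it in a thread-pool executor while carrying over contextvars. A self-contained sketch of that pattern, with a stand-in blocking function instead of LiteLLM's create_* calls:

    # Standalone sketch of the async-over-sync bridge used above; `do_blocking_work`
    # is a hypothetical stand-in, not a LiteLLM function.
    import asyncio
    import contextvars
    from functools import partial

    request_id = contextvars.ContextVar("request_id", default="none")

    def do_blocking_work(x: int) -> str:
        # Runs in a worker thread, but still sees the caller's context variables.
        return f"request={request_id.get()} result={x * 2}"

    async def do_work_async(x: int) -> str:
        loop = asyncio.get_event_loop()              # same call the diff uses
        func = partial(do_blocking_work, x)          # bind arguments up front
        ctx = contextvars.copy_context()             # snapshot the current context
        func_with_context = partial(ctx.run, func)   # run the call inside that snapshot
        return await loop.run_in_executor(None, func_with_context)

    async def main():
        request_id.set("abc-123")
        print(await do_work_async(21))  # -> request=abc-123 result=42

    asyncio.run(main())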
@@ -167,6 +170,52 @@ def create_file(
         raise e


+async def acreate_batch(
+    completion_window: Literal["24h"],
+    endpoint: Literal["/v1/chat/completions", "/v1/embeddings", "/v1/completions"],
+    input_file_id: str,
+    custom_llm_provider: Literal["openai"] = "openai",
+    metadata: Optional[Dict[str, str]] = None,
+    extra_headers: Optional[Dict[str, str]] = None,
+    extra_body: Optional[Dict[str, str]] = None,
+    **kwargs,
+) -> Coroutine[Any, Any, Batch]:
+    """
+    Creates and executes a batch from an uploaded file of request
+
+    LiteLLM Equivalent of POST: https://api.openai.com/v1/batches
+    """
+    try:
+        loop = asyncio.get_event_loop()
+        kwargs["acreate_batch"] = True
+
+        # Use a partial function to pass your keyword arguments
+        func = partial(
+            create_batch,
+            completion_window,
+            endpoint,
+            input_file_id,
+            custom_llm_provider,
+            metadata,
+            extra_headers,
+            extra_body,
+            **kwargs,
+        )
+
+        # Add the context to the function
+        ctx = contextvars.copy_context()
+        func_with_context = partial(ctx.run, func)
+        init_response = await loop.run_in_executor(None, func_with_context)
+        if asyncio.iscoroutine(init_response):
+            response = await init_response
+        else:
+            response = init_response  # type: ignore
+
+        return response
+    except Exception as e:
+        raise e
+
+
 def create_batch(
     completion_window: Literal["24h"],
     endpoint: Literal["/v1/chat/completions", "/v1/embeddings", "/v1/completions"],
@@ -176,7 +225,7 @@ def create_batch(
     extra_headers: Optional[Dict[str, str]] = None,
     extra_body: Optional[Dict[str, str]] = None,
     **kwargs,
-) -> Batch:
+) -> Union[Batch, Coroutine[Any, Any, Batch]]:
     """
     Creates and executes a batch from an uploaded file of request

@@ -224,6 +273,8 @@ def create_batch(
             elif timeout is None:
                 timeout = 600.0

+            _is_async = kwargs.pop("acreate_batch", False) is True
+
             _create_batch_request = CreateBatchRequest(
                 completion_window=completion_window,
                 endpoint=endpoint,
@@ -240,6 +291,7 @@ def create_batch(
                 create_batch_data=_create_batch_request,
                 timeout=timeout,
                 max_retries=optional_params.max_retries,
+                _is_async=_is_async,
             )
         else:
             raise litellm.exceptions.BadRequestError(
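Taken together, these hunks wire up the dispatch: the async wrapper injects acreate_batch=True into kwargs, the synchronous create_batch pops it into _is_async, and the handler can then hand back either a finished Batch or an un-awaited coroutine (hence the Union[Batch, Coroutine[Any, Any, Batch]] return type). A standalone sketch of the same flag-popping pattern, using hypothetical names rather than LiteLLM's:

    # Sketch of the kwargs-flag dispatch pattern; `fetch_sync`, `fetch_async`,
    # and `get_item` are hypothetical stand-ins, not LiteLLM functions.
    import asyncio
    from typing import Any, Coroutine, Union

    async def fetch_async(item_id: str) -> dict:
        await asyncio.sleep(0)  # pretend to do async I/O
        return {"id": item_id, "mode": "async"}

    def fetch_sync(item_id: str) -> dict:
        return {"id": item_id, "mode": "sync"}

    def get_item(item_id: str, **kwargs) -> Union[dict, Coroutine[Any, Any, dict]]:
        # Pop the private flag so it is not forwarded to the underlying call.
        _is_async = kwargs.pop("aget_item", False) is True
        if _is_async:
            # Return the coroutine un-awaited; the async wrapper awaits it.
            return fetch_async(item_id)
        return fetch_sync(item_id)

    async def aget_item(item_id: str, **kwargs) -> dict:
        kwargs["aget_item"] = True
        result = get_item(item_id, **kwargs)
        if asyncio.iscoroutine(result):
            return await result
        return result

    print(get_item("batch_1"))                # {'id': 'batch_1', 'mode': 'sync'}
    print(asyncio.run(aget_item("batch_1")))  # {'id': 'batch_1', 'mode': 'async'}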
@@ -320,7 +372,10 @@ def retrieve_batch(
                 extra_body=extra_body,
             )

+            _is_async = kwargs.pop("aretrieve_batch", False) is True
+
             response = openai_batches_instance.retrieve_batch(
+                _is_async=_is_async,
                 retrieve_batch_data=_retrieve_batch_request,
                 api_base=api_base,
                 api_key=api_key,
@@ -354,11 +409,6 @@ def list_batch():
     pass


-# Async Functions
-async def acreate_batch():
-    pass
-
-
 async def aretrieve_batch():
     pass

@@ -1605,47 +1605,76 @@ class OpenAIBatchesAPI(BaseLLM):
         timeout: Union[float, httpx.Timeout],
         max_retries: Optional[int],
         organization: Optional[str],
-        client: Optional[OpenAI] = None,
-    ) -> OpenAI:
+        client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
+        _is_async: bool = False,
+    ) -> Optional[Union[OpenAI, AsyncOpenAI]]:
         received_args = locals()
+        openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = None
         if client is None:
             data = {}
             for k, v in received_args.items():
-                if k == "self" or k == "client":
+                if k == "self" or k == "client" or k == "_is_async":
                     pass
                 elif k == "api_base" and v is not None:
                     data["base_url"] = v
                 elif v is not None:
                     data[k] = v
-            openai_client = OpenAI(**data)  # type: ignore
+            if _is_async is True:
+                openai_client = AsyncOpenAI(**data)
+            else:
+                openai_client = OpenAI(**data)  # type: ignore
         else:
             openai_client = client

         return openai_client

+    async def acreate_batch(
+        self,
+        create_batch_data: CreateBatchRequest,
+        openai_client: AsyncOpenAI,
+    ) -> Batch:
+        response = await openai_client.batches.create(**create_batch_data)
+        return response
+
     def create_batch(
         self,
+        _is_async: bool,
         create_batch_data: CreateBatchRequest,
         api_key: Optional[str],
         api_base: Optional[str],
         timeout: Union[float, httpx.Timeout],
         max_retries: Optional[int],
         organization: Optional[str],
-        client: Optional[OpenAI] = None,
-    ) -> Batch:
-        openai_client: OpenAI = self.get_openai_client(
+        client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
+    ) -> Union[Batch, Coroutine[Any, Any, Batch]]:
+        openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
             api_key=api_key,
             api_base=api_base,
             timeout=timeout,
             max_retries=max_retries,
             organization=organization,
             client=client,
+            _is_async=_is_async,
         )
+        if openai_client is None:
+            raise ValueError(
+                "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
+            )
+
+        if _is_async is True:
+            if not isinstance(openai_client, AsyncOpenAI):
+                raise ValueError(
+                    "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
+                )
+            return self.acreate_batch(  # type: ignore
+                create_batch_data=create_batch_data, openai_client=openai_client
+            )
         response = openai_client.batches.create(**create_batch_data)
         return response

     def retrieve_batch(
         self,
+        _is_async: bool,
         retrieve_batch_data: RetrieveBatchRequest,
         api_key: Optional[str],
         api_base: Optional[str],
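The reworked get_openai_client builds either a sync or an async client from the same keyword arguments, and create_batch then returns the acreate_batch coroutine un-awaited when _is_async is set. A minimal sketch of that client selection, assuming the openai v1 SDK is installed; build_client and the placeholder key are illustrative, not LiteLLM code:

    # Sketch of the _is_async client selection; values are placeholders.
    from openai import AsyncOpenAI, OpenAI

    def build_client(_is_async: bool, **data):
        # The same keyword arguments (api_key, base_url, timeout, max_retries,
        # organization) construct either the sync or the async client.
        if _is_async:
            return AsyncOpenAI(**data)
        return OpenAI(**data)

    sync_client = build_client(False, api_key="sk-...", max_retries=2)
    async_client = build_client(True, api_key="sk-...", max_retries=2)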
@@ -1654,19 +1683,25 @@ class OpenAIBatchesAPI(BaseLLM):
         organization: Optional[str],
         client: Optional[OpenAI] = None,
     ):
-        openai_client: OpenAI = self.get_openai_client(
+        openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
             api_key=api_key,
             api_base=api_base,
             timeout=timeout,
             max_retries=max_retries,
             organization=organization,
             client=client,
+            _is_async=_is_async,
         )
+        if openai_client is None:
+            raise ValueError(
+                "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
+            )
         response = openai_client.batches.retrieve(**retrieve_batch_data)
         return response

     def cancel_batch(
         self,
+        _is_async: bool,
         cancel_batch_data: CancelBatchRequest,
         api_key: Optional[str],
         api_base: Optional[str],
@@ -1675,14 +1710,19 @@ class OpenAIBatchesAPI(BaseLLM):
         organization: Optional[str],
         client: Optional[OpenAI] = None,
     ):
-        openai_client: OpenAI = self.get_openai_client(
+        openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
             api_key=api_key,
             api_base=api_base,
             timeout=timeout,
             max_retries=max_retries,
             organization=organization,
             client=client,
+            _is_async=_is_async,
        )
+        if openai_client is None:
+            raise ValueError(
+                "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
+            )
         response = openai_client.batches.cancel(**cancel_batch_data)
         return response

@@ -89,25 +89,25 @@ async def test_async_create_batch():
         batch_input_file_id is not None
     ), "Failed to create file, expected a non null file_id but got {batch_input_file_id}"

-    # create_batch_response = litellm.create_batch(
-    #     completion_window="24h",
-    #     endpoint="/v1/chat/completions",
-    #     input_file_id=batch_input_file_id,
-    #     custom_llm_provider="openai",
-    #     metadata={"key1": "value1", "key2": "value2"},
-    # )
+    create_batch_response = await litellm.acreate_batch(
+        completion_window="24h",
+        endpoint="/v1/chat/completions",
+        input_file_id=batch_input_file_id,
+        custom_llm_provider="openai",
+        metadata={"key1": "value1", "key2": "value2"},
+    )

-    # print("response from litellm.create_batch=", create_batch_response)
+    print("response from litellm.create_batch=", create_batch_response)

-    # assert (
-    #     create_batch_response.id is not None
-    # ), f"Failed to create batch, expected a non null batch_id but got {create_batch_response.id}"
-    # assert (
-    #     create_batch_response.endpoint == "/v1/chat/completions"
-    # ), f"Failed to create batch, expected endpoint to be /v1/chat/completions but got {create_batch_response.endpoint}"
-    # assert (
-    #     create_batch_response.input_file_id == batch_input_file_id
-    # ), f"Failed to create batch, expected input_file_id to be {batch_input_file_id} but got {create_batch_response.input_file_id}"
+    assert (
+        create_batch_response.id is not None
+    ), f"Failed to create batch, expected a non null batch_id but got {create_batch_response.id}"
+    assert (
+        create_batch_response.endpoint == "/v1/chat/completions"
+    ), f"Failed to create batch, expected endpoint to be /v1/chat/completions but got {create_batch_response.endpoint}"
+    assert (
+        create_batch_response.input_file_id == batch_input_file_id
+    ), f"Failed to create batch, expected input_file_id to be {batch_input_file_id} but got {create_batch_response.input_file_id}"

     # time.sleep(30)

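To exercise just this test locally, one would typically filter by test name; a small sketch, assuming pytest with async test support is configured for the repo and a valid OPENAI_API_KEY is available:

    # Run only the async batch test (illustrative invocation; assumes pytest +
    # async test support are already set up for this repo).
    import pytest

    pytest.main(["-k", "test_async_create_batch", "-x", "-v"])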