feat(batches): add azure openai batches endpoint support

Closes https://github.com/BerriAI/litellm/issues/5073
This commit is contained in:
Krrish Dholakia 2024-08-22 14:46:51 -07:00
parent a63c5c0020
commit 80675b22bd
7 changed files with 584 additions and 173 deletions

View file

@ -22,7 +22,8 @@ import litellm
from litellm import create_batch, create_file
def test_create_batch():
@pytest.mark.parametrize("provider", ["openai", "azure"])
def test_create_batch(provider):
"""
1. Create File for Batch completion
2. Create Batch Request
@ -35,7 +36,7 @@ def test_create_batch():
file_obj = litellm.create_file(
file=open(file_path, "rb"),
purpose="batch",
custom_llm_provider="openai",
custom_llm_provider=provider,
)
print("Response from creating file=", file_obj)
@ -44,11 +45,12 @@ def test_create_batch():
batch_input_file_id is not None
), "Failed to create file, expected a non null file_id but got {batch_input_file_id}"
time.sleep(5)
create_batch_response = litellm.create_batch(
completion_window="24h",
endpoint="/v1/chat/completions",
input_file_id=batch_input_file_id,
custom_llm_provider="openai",
custom_llm_provider=provider,
metadata={"key1": "value1", "key2": "value2"},
)
@ -59,13 +61,14 @@ def test_create_batch():
), f"Failed to create batch, expected a non null batch_id but got {create_batch_response.id}"
assert (
create_batch_response.endpoint == "/v1/chat/completions"
or create_batch_response.endpoint == "/chat/completions"
), f"Failed to create batch, expected endpoint to be /v1/chat/completions but got {create_batch_response.endpoint}"
assert (
create_batch_response.input_file_id == batch_input_file_id
), f"Failed to create batch, expected input_file_id to be {batch_input_file_id} but got {create_batch_response.input_file_id}"
retrieved_batch = litellm.retrieve_batch(
batch_id=create_batch_response.id, custom_llm_provider="openai"
batch_id=create_batch_response.id, custom_llm_provider=provider
)
print("retrieved batch=", retrieved_batch)
# just assert that we retrieved a non None batch
@ -73,11 +76,11 @@ def test_create_batch():
assert retrieved_batch.id == create_batch_response.id
# list all batches
list_batches = litellm.list_batches(custom_llm_provider="openai", limit=2)
list_batches = litellm.list_batches(custom_llm_provider=provider, limit=2)
print("list_batches=", list_batches)
file_content = litellm.file_content(
file_id=batch_input_file_id, custom_llm_provider="openai"
file_id=batch_input_file_id, custom_llm_provider=provider
)
result = file_content.content
@ -90,8 +93,9 @@ def test_create_batch():
pass
@pytest.mark.parametrize("provider", ["openai", "azure"])
@pytest.mark.asyncio()
async def test_async_create_batch():
async def test_async_create_batch(provider):
"""
1. Create File for Batch completion
2. Create Batch Request
@ -105,10 +109,11 @@ async def test_async_create_batch():
file_obj = await litellm.acreate_file(
file=open(file_path, "rb"),
purpose="batch",
custom_llm_provider="openai",
custom_llm_provider=provider,
)
print("Response from creating file=", file_obj)
await asyncio.sleep(5)
batch_input_file_id = file_obj.id
assert (
batch_input_file_id is not None
@ -118,7 +123,7 @@ async def test_async_create_batch():
completion_window="24h",
endpoint="/v1/chat/completions",
input_file_id=batch_input_file_id,
custom_llm_provider="openai",
custom_llm_provider=provider,
metadata={"key1": "value1", "key2": "value2"},
)
@ -129,6 +134,7 @@ async def test_async_create_batch():
), f"Failed to create batch, expected a non null batch_id but got {create_batch_response.id}"
assert (
create_batch_response.endpoint == "/v1/chat/completions"
or create_batch_response.endpoint == "/chat/completions"
), f"Failed to create batch, expected endpoint to be /v1/chat/completions but got {create_batch_response.endpoint}"
assert (
create_batch_response.input_file_id == batch_input_file_id
@ -137,7 +143,7 @@ async def test_async_create_batch():
await asyncio.sleep(1)
retrieved_batch = await litellm.aretrieve_batch(
batch_id=create_batch_response.id, custom_llm_provider="openai"
batch_id=create_batch_response.id, custom_llm_provider=provider
)
print("retrieved batch=", retrieved_batch)
# just assert that we retrieved a non None batch
@ -145,27 +151,27 @@ async def test_async_create_batch():
assert retrieved_batch.id == create_batch_response.id
# list all batches
list_batches = await litellm.alist_batches(custom_llm_provider="openai", limit=2)
list_batches = await litellm.alist_batches(custom_llm_provider=provider, limit=2)
print("list_batches=", list_batches)
# try to get file content for our original file
file_content = await litellm.afile_content(
file_id=batch_input_file_id, custom_llm_provider="openai"
file_id=batch_input_file_id, custom_llm_provider=provider
)
print("file content = ", file_content)
# file obj
file_obj = await litellm.afile_retrieve(
file_id=batch_input_file_id, custom_llm_provider="openai"
file_id=batch_input_file_id, custom_llm_provider=provider
)
print("file obj = ", file_obj)
assert file_obj.id == batch_input_file_id
# delete file
delete_file_response = await litellm.afile_delete(
file_id=batch_input_file_id, custom_llm_provider="openai"
file_id=batch_input_file_id, custom_llm_provider=provider
)
print("delete file response = ", delete_file_response)
@ -173,7 +179,7 @@ async def test_async_create_batch():
assert delete_file_response.id == batch_input_file_id
all_files_list = await litellm.afile_list(
custom_llm_provider="openai",
custom_llm_provider=provider,
)
print("all_files_list = ", all_files_list)