feat(router.py): Support load balancing Azure Batch API endpoints (#5469)

* feat(router.py): initial commit for load balancing Azure Batch API endpoints

Closes https://github.com/BerriAI/litellm/issues/5396

* fix(router.py): working `router.acreate_file()`

* feat(router.py): working router.acreate_batch endpoint
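
A rough usage sketch of the file + batch creation flow, mirroring the test added in this commit. The JSONL path and the single deployment are illustrative; to actually load balance, more deployments sharing the `model_name` "my-custom-name" would be added to `model_list`:

    import asyncio
    import os

    from litellm import Router

    # one Azure deployment behind the "my-custom-name" alias; additional entries
    # with the same model_name could be added so requests are load balanced
    router = Router(
        model_list=[
            {
                "model_name": "my-custom-name",
                "litellm_params": {
                    "model": "azure/gpt-4o-mini",
                    "api_base": os.getenv("AZURE_API_BASE"),
                    "api_key": os.getenv("AZURE_API_KEY"),
                },
            },
        ]
    )

    async def create_file_and_batch():
        # upload the batch input file through the router
        file_obj = await router.acreate_file(
            model="my-custom-name",
            file=open("openai_batch_completions_router.jsonl", "rb"),
            purpose="batch",
            custom_llm_provider="azure",
        )
        # create the batch request against the uploaded file
        return await router.acreate_batch(
            model="my-custom-name",
            completion_window="24h",
            endpoint="/v1/chat/completions",
            input_file_id=file_obj.id,
            custom_llm_provider="azure",
            metadata={"key1": "value1"},
        )

    batch = asyncio.run(create_file_and_batch())
    print("created batch:", batch.id)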

* feat(router.py): expose router.aretrieve_batch function

Makes it easy for the user to retrieve batch information

* feat(router.py): support `router.alist_batches` endpoint

Adds support for getting all batches across all endpoints
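
A short sketch of the retrieve + list side, following the test below; `batch_id` is assumed to come from an earlier `router.acreate_batch()` call and `router` is configured as in the sketch above:

    from litellm import Router

    async def check_batch(router: Router, batch_id: str):
        # fetch a single batch by id
        retrieved = await router.aretrieve_batch(
            batch_id=batch_id,
            custom_llm_provider="azure",
        )
        print("batch status:", retrieved.status)

        # list recent batches for the deployments behind "my-custom-name"
        batches = await router.alist_batches(
            model="my-custom-name",
            custom_llm_provider="azure",
            limit=2,
        )
        print("recent batches:", batches)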

* feat(router.py): working load balancing on `/v1/files`

* feat(proxy_server.py): working load balancing on `/v1/batches`

* feat(proxy_server.py): working load balancing on retrieve + list batch
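
Since these proxy routes are meant to be OpenAI-compatible, the same flow should be drivable with the stock OpenAI SDK pointed at the proxy; this is a sketch only, with placeholder base URL and key, not verified against this commit's proxy config:

    from openai import OpenAI

    # stock OpenAI client pointed at a locally running LiteLLM proxy (placeholders)
    client = OpenAI(base_url="http://0.0.0.0:4000/v1", api_key="sk-1234")

    file_obj = client.files.create(
        file=open("openai_batch_completions_router.jsonl", "rb"),
        purpose="batch",
    )
    batch = client.batches.create(
        input_file_id=file_obj.id,
        endpoint="/v1/chat/completions",
        completion_window="24h",
    )
    print(client.batches.retrieve(batch.id))  # retrieve a single batch
    print(client.batches.list(limit=2))       # list batches
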
Krish Dholakia 2024-09-02 21:32:55 -07:00 committed by GitHub
parent 9b22359bed
commit 18da7adce9
10 changed files with 667 additions and 37 deletions

@@ -2394,3 +2394,83 @@ async def test_router_weighted_pick(sync_mode):
        else:
            raise Exception("invalid model id returned!")

    assert model_id_1_count > model_id_2_count


@pytest.mark.parametrize("provider", ["azure"])
@pytest.mark.asyncio
async def test_router_batch_endpoints(provider):
"""
1. Create File for Batch completion
2. Create Batch Request
3. Retrieve the specific batch
"""
print("Testing async create batch")
router = Router(
model_list=[
{
"model_name": "my-custom-name",
"litellm_params": {
"model": "azure/gpt-4o-mini",
"api_base": os.getenv("AZURE_API_BASE"),
"api_key": os.getenv("AZURE_API_KEY"),
},
},
]
)
file_name = "openai_batch_completions_router.jsonl"
_current_dir = os.path.dirname(os.path.abspath(__file__))
file_path = os.path.join(_current_dir, file_name)
file_obj = await router.acreate_file(
model="my-custom-name",
file=open(file_path, "rb"),
purpose="batch",
custom_llm_provider=provider,
)
print("Response from creating file=", file_obj)
await asyncio.sleep(10)
batch_input_file_id = file_obj.id
assert (
batch_input_file_id is not None
), "Failed to create file, expected a non null file_id but got {batch_input_file_id}"
    create_batch_response = await router.acreate_batch(
        model="my-custom-name",
        completion_window="24h",
        endpoint="/v1/chat/completions",
        input_file_id=batch_input_file_id,
        custom_llm_provider=provider,
        metadata={"key1": "value1", "key2": "value2"},
    )
    print("response from router.create_batch=", create_batch_response)
    assert (
        create_batch_response.id is not None
    ), f"Failed to create batch, expected a non null batch_id but got {create_batch_response.id}"
    assert (
        create_batch_response.endpoint == "/v1/chat/completions"
        or create_batch_response.endpoint == "/chat/completions"
    ), f"Failed to create batch, expected endpoint to be /v1/chat/completions but got {create_batch_response.endpoint}"
    assert (
        create_batch_response.input_file_id == batch_input_file_id
    ), f"Failed to create batch, expected input_file_id to be {batch_input_file_id} but got {create_batch_response.input_file_id}"
    await asyncio.sleep(1)
    retrieved_batch = await router.aretrieve_batch(
        batch_id=create_batch_response.id,
        custom_llm_provider=provider,
    )
    print("retrieved batch=", retrieved_batch)
    # just assert that we retrieved a non None batch
    assert retrieved_batch.id == create_batch_response.id
    # list all batches
    list_batches = await router.alist_batches(
        model="my-custom-name", custom_llm_provider=provider, limit=2
    )
    print("list_batches=", list_batches)