feat(router.py): support load balancing for Azure batch API endpoints (#5469)

* feat(router.py): initial commit for load balancing Azure batch API endpoints

Closes https://github.com/BerriAI/litellm/issues/5396

* fix(router.py): working `router.acreate_file()`

* feat(router.py): working router.acreate_batch endpoint

* feat(router.py): expose router.aretrieve_batch function

Makes it easy for users to retrieve batch information

* feat(router.py): support 'router.alist_batches' endpoint

Adds support for getting all batches across all endpoints

* feat(router.py): working loadbalancing on `/v1/files`

* feat(proxy_server.py): working loadbalancing on `/v1/batches`

* feat(proxy_server.py): working loadbalancing on Retrieve + List batch
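
For context, a minimal usage sketch of the new Router batch/file methods described above. This is illustrative only and not code from this commit: the model name, API bases, keys, and input file path are placeholders, and exact keyword arguments may differ slightly from the shipped signatures.

import asyncio

from litellm import Router

# Two Azure deployments sharing one model_name, so the router can load balance
# batch traffic across them. All values below are placeholders.
router = Router(
    model_list=[
        {
            "model_name": "batch-gpt",
            "litellm_params": {
                "model": "azure/my-batch-deployment",
                "api_base": "https://endpoint-1.openai.azure.com",
                "api_key": "<azure-key-1>",
            },
        },
        {
            "model_name": "batch-gpt",
            "litellm_params": {
                "model": "azure/my-batch-deployment",
                "api_base": "https://endpoint-2.openai.azure.com",
                "api_key": "<azure-key-2>",
            },
        },
    ]
)


async def main():
    # Upload the JSONL input file; the router picks one of the deployments
    file_obj = await router.acreate_file(
        model="batch-gpt",
        file=open("batch_input.jsonl", "rb"),
        purpose="batch",
    )

    # Start a batch against the same logical model
    batch = await router.acreate_batch(
        model="batch-gpt",
        input_file_id=file_obj.id,
        endpoint="/v1/chat/completions",
        completion_window="24h",
    )

    # Check on a single batch, and list batches across all endpoints
    retrieved = await router.aretrieve_batch(batch_id=batch.id)
    all_batches = await router.alist_batches(model="batch-gpt")
    print(retrieved, all_batches)


asyncio.run(main())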
Krish Dholakia committed 2024-09-02 21:32:55 -07:00 (via GitHub)
commit 18da7adce9, parent 9b22359bed
10 changed files with 667 additions and 37 deletions


@@ -31,6 +31,7 @@ from litellm._logging import verbose_proxy_logger
from litellm.batches.main import FileObject
from litellm.proxy._types import *
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.router import Router
router = APIRouter()
@@ -66,6 +67,41 @@ def get_files_provider_config(
    return None

def get_first_json_object(file_content_bytes: bytes) -> Optional[dict]:
    try:
        # Decode the bytes to a string and split into lines
        file_content = file_content_bytes.decode("utf-8")
        first_line = file_content.splitlines()[0].strip()

        # Parse the JSON object from the first line
        json_object = json.loads(first_line)
        return json_object
    except (json.JSONDecodeError, UnicodeDecodeError) as e:
        return None


def get_model_from_json_obj(json_object: dict) -> Optional[str]:
    body = json_object.get("body", {}) or {}

    model = body.get("model")

    return model


def is_known_model(model: Optional[str], llm_router: Optional[Router]) -> bool:
    """
    Returns True if the model is in the llm_router model names
    """
    if model is None or llm_router is None:
        return False

    model_names = llm_router.get_model_names()

    is_in_list = False
    if model in model_names:
        is_in_list = True

    return is_in_list
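
To show how these helpers fit together, here is a small illustrative snippet (not part of the diff). It assumes llm_router is the proxy's initialized Router and that a deployment named "batch-gpt" is configured; the sample JSONL line mirrors the OpenAI batch input format.

# Extract the model from the first JSONL line of a batch input file and
# check whether the router knows about it.
sample_jsonl = (
    b'{"custom_id": "req-1", "method": "POST", "url": "/v1/chat/completions", '
    b'"body": {"model": "batch-gpt", "messages": [{"role": "user", "content": "hi"}]}}\n'
)

json_obj = get_first_json_object(file_content_bytes=sample_jsonl)
router_model = get_model_from_json_obj(json_object=json_obj) if json_obj else None
print(is_known_model(model=router_model, llm_router=llm_router))
# True only when "batch-gpt" is one of the router's configured model names
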
@router.post(
"/{provider}/v1/files",
dependencies=[Depends(user_api_key_auth)],
@@ -109,6 +145,7 @@ async def create_file(
        add_litellm_data_to_request,
        general_settings,
        get_custom_headers,
        llm_router,
        proxy_config,
        proxy_logging_obj,
        version,
@@ -138,18 +175,46 @@ async def create_file(
        # Prepare the file data according to FileTypes
        file_data = (file.filename, file_content, file.content_type)

        ## check if model is a loadbalanced model
        router_model: Optional[str] = None
        is_router_model = False
        if litellm.enable_loadbalancing_on_batch_endpoints is True:
            json_obj = get_first_json_object(file_content_bytes=file_content)
            if json_obj:
                router_model = get_model_from_json_obj(json_object=json_obj)
                is_router_model = is_known_model(
                    model=router_model, llm_router=llm_router
                )

        _create_file_request = CreateFileRequest(file=file_data, **data)

        if (
            litellm.enable_loadbalancing_on_batch_endpoints is True
            and is_router_model
            and router_model is not None
        ):
            if llm_router is None:
                raise HTTPException(
                    status_code=500,
                    detail={
                        "error": "LLM Router not initialized. Ensure models added to proxy."
                    },
                )

            response = await llm_router.acreate_file(
                model=router_model, **_create_file_request
            )
        else:
            # get configs for custom_llm_provider
            llm_provider_config = get_files_provider_config(
                custom_llm_provider=custom_llm_provider
            )

            # add llm_provider_config to data
            _create_file_request.update(llm_provider_config)

            # for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch
            response = await litellm.acreate_file(**_create_file_request)  # type: ignore
        ### ALERTING ###
        asyncio.create_task(