Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 03:04:13 +00:00)
feat(router.py): Support Loadbalancing batch azure api endpoints (#5469)
* feat(router.py): initial commit for loadbalancing azure batch api endpoints
  Closes https://github.com/BerriAI/litellm/issues/5396
* fix(router.py): working `router.acreate_file()`
* feat(router.py): working router.acreate_batch endpoint
* feat(router.py): expose router.aretrieve_batch function
  Make it easy for user to retrieve the batch information
* feat(router.py): support 'router.alist_batches' endpoint
  Adds support for getting all batches across all endpoints
* feat(router.py): working loadbalancing on `/v1/files`
* feat(proxy_server.py): working loadbalancing on `/v1/batches`
* feat(proxy_server.py): working loadbalancing on Retrieve + List batch
This commit is contained in:
parent 9b22359bed
commit 18da7adce9
10 changed files with 667 additions and 37 deletions
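The bullets above describe the new router-level Files and Batches surface. A minimal usage sketch follows, assuming this commit is installed; the model group name, Azure deployment names, api_base values, keys, JSONL path, and batch parameters are illustrative assumptions, not values taken from the diff.

# Hedged usage sketch - not part of the commit. Model group name, Azure
# deployment names, credentials, JSONL path, and batch parameters are assumed.
import asyncio

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "azure-batch-group",  # hypothetical model group
            "litellm_params": {
                "model": "azure/gpt-4o-batch",  # assumed Azure deployment
                "api_base": "https://endpoint-1.openai.azure.com",
                "api_key": "sk-...",
            },
        },
        {
            "model_name": "azure-batch-group",
            "litellm_params": {
                "model": "azure/gpt-4o-batch",
                "api_base": "https://endpoint-2.openai.azure.com",
                "api_key": "sk-...",
            },
        },
    ]
)


async def main() -> None:
    # Upload the batch input file; the router picks a deployment from the
    # group and rewrites the model name inside the JSONL to match it.
    file_obj = await router.acreate_file(
        model="azure-batch-group",
        file=open("batch_input.jsonl", "rb"),
        purpose="batch",
    )

    # Create the batch on the selected deployment.
    batch = await router.acreate_batch(
        model="azure-batch-group",
        input_file_id=file_obj.id,
        endpoint="/v1/chat/completions",
        completion_window="24h",
    )
    print(batch.id, batch.status)


asyncio.run(main())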
@@ -54,6 +54,10 @@ from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler
from litellm.router_strategy.lowest_tpm_rpm import LowestTPMLoggingHandler
from litellm.router_strategy.lowest_tpm_rpm_v2 import LowestTPMLoggingHandler_v2
from litellm.router_strategy.tag_based_routing import get_deployments_for_tag
from litellm.router_utils.batch_utils import (
    _get_router_metadata_variable_name,
    replace_model_in_jsonl,
)
from litellm.router_utils.client_initalization_utils import (
    set_client,
    should_initialize_sync_client,
@@ -73,6 +77,12 @@ from litellm.types.llms.openai import (
    AssistantToolParam,
    AsyncCursorPage,
    Attachment,
    Batch,
    CreateFileRequest,
    FileContentRequest,
    FileObject,
    FileTypes,
    HttpxBinaryResponseContent,
    OpenAIMessage,
    Run,
    Thread,
@@ -103,6 +113,7 @@ from litellm.utils import (
    _is_region_eu,
    calculate_max_parallel_requests,
    create_proxy_transport_and_mounts,
    get_llm_provider,
    get_utc_datetime,
)
@@ -2228,6 +2239,373 @@ class Router:
                self.fail_calls[model_name] += 1
            raise e

    #### FILES API ####

    async def acreate_file(
        self,
        model: str,
        **kwargs,
    ) -> FileObject:
        try:
            kwargs["model"] = model
            kwargs["original_function"] = self._acreate_file
            kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
            timeout = kwargs.get("request_timeout", self.timeout)
            kwargs.setdefault("metadata", {}).update({"model_group": model})
            response = await self.async_function_with_fallbacks(**kwargs)

            return response
        except Exception as e:
            asyncio.create_task(
                send_llm_exception_alert(
                    litellm_router_instance=self,
                    request_kwargs=kwargs,
                    error_traceback_str=traceback.format_exc(),
                    original_exception=e,
                )
            )
            raise e

    async def _acreate_file(
        self,
        model: str,
        **kwargs,
    ) -> FileObject:
        try:
            verbose_router_logger.debug(
                f"Inside _acreate_file()- model: {model}; kwargs: {kwargs}"
            )
            deployment = await self.async_get_available_deployment(
                model=model,
                messages=[{"role": "user", "content": "files-api-fake-text"}],
                specific_deployment=kwargs.pop("specific_deployment", None),
            )
            kwargs.setdefault("metadata", {}).update(
                {
                    "deployment": deployment["litellm_params"]["model"],
                    "model_info": deployment.get("model_info", {}),
                    "api_base": deployment.get("litellm_params", {}).get("api_base"),
                }
            )
            kwargs["model_info"] = deployment.get("model_info", {})
            data = deployment["litellm_params"].copy()
            model_name = data["model"]
            for k, v in self.default_litellm_params.items():
                if (
                    k not in kwargs
                ):  # prioritize model-specific params > default router params
                    kwargs[k] = v
                elif k == "metadata":
                    kwargs[k].update(v)

            potential_model_client = self._get_client(
                deployment=deployment, kwargs=kwargs, client_type="async"
            )
            # check if provided keys == client keys #
            dynamic_api_key = kwargs.get("api_key", None)
            if (
                dynamic_api_key is not None
                and potential_model_client is not None
                and dynamic_api_key != potential_model_client.api_key
            ):
                model_client = None
            else:
                model_client = potential_model_client
            self.total_calls[model_name] += 1

            ## REPLACE MODEL IN FILE WITH SELECTED DEPLOYMENT ##
            stripped_model, custom_llm_provider, _, _ = get_llm_provider(
                model=data["model"]
            )
            kwargs["file"] = replace_model_in_jsonl(
                file_content=kwargs["file"], new_model_name=stripped_model
            )

            response = litellm.acreate_file(
                **{
                    **data,
                    "custom_llm_provider": custom_llm_provider,
                    "caching": self.cache_responses,
                    "client": model_client,
                    "timeout": self.timeout,
                    **kwargs,
                }
            )

            rpm_semaphore = self._get_client(
                deployment=deployment,
                kwargs=kwargs,
                client_type="max_parallel_requests",
            )

            if rpm_semaphore is not None and isinstance(
                rpm_semaphore, asyncio.Semaphore
            ):
                async with rpm_semaphore:
                    """
                    - Check rpm limits before making the call
                    - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
                    """
                    await self.async_routing_strategy_pre_call_checks(
                        deployment=deployment
                    )
                    response = await response  # type: ignore
            else:
                await self.async_routing_strategy_pre_call_checks(deployment=deployment)
                response = await response  # type: ignore

            self.success_calls[model_name] += 1
            verbose_router_logger.info(
                f"litellm.acreate_file(model={model_name})\033[32m 200 OK\033[0m"
            )
            return response  # type: ignore
        except Exception as e:
            verbose_router_logger.exception(
                f"litellm.acreate_file(model={model}, {kwargs})\033[31m Exception {str(e)}\033[0m"
            )
            if model is not None:
                self.fail_calls[model] += 1
            raise e

    async def acreate_batch(
        self,
        model: str,
        **kwargs,
    ) -> Batch:
        try:
            kwargs["model"] = model
            kwargs["original_function"] = self._acreate_batch
            kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
            timeout = kwargs.get("request_timeout", self.timeout)
            kwargs.setdefault("metadata", {}).update({"model_group": model})
            response = await self.async_function_with_fallbacks(**kwargs)

            return response
        except Exception as e:
            asyncio.create_task(
                send_llm_exception_alert(
                    litellm_router_instance=self,
                    request_kwargs=kwargs,
                    error_traceback_str=traceback.format_exc(),
                    original_exception=e,
                )
            )
            raise e

    async def _acreate_batch(
        self,
        model: str,
        **kwargs,
    ) -> Batch:
        try:
            verbose_router_logger.debug(
                f"Inside _acreate_batch()- model: {model}; kwargs: {kwargs}"
            )
            deployment = await self.async_get_available_deployment(
                model=model,
                messages=[{"role": "user", "content": "files-api-fake-text"}],
                specific_deployment=kwargs.pop("specific_deployment", None),
            )
            metadata_variable_name = _get_router_metadata_variable_name(
                function_name="_acreate_batch"
            )

            kwargs.setdefault(metadata_variable_name, {}).update(
                {
                    "deployment": deployment["litellm_params"]["model"],
                    "model_info": deployment.get("model_info", {}),
                    "api_base": deployment.get("litellm_params", {}).get("api_base"),
                }
            )
            kwargs["model_info"] = deployment.get("model_info", {})
            data = deployment["litellm_params"].copy()
            model_name = data["model"]
            for k, v in self.default_litellm_params.items():
                if (
                    k not in kwargs
                ):  # prioritize model-specific params > default router params
                    kwargs[k] = v
                elif k == metadata_variable_name:
                    kwargs[k].update(v)

            potential_model_client = self._get_client(
                deployment=deployment, kwargs=kwargs, client_type="async"
            )
            # check if provided keys == client keys #
            dynamic_api_key = kwargs.get("api_key", None)
            if (
                dynamic_api_key is not None
                and potential_model_client is not None
                and dynamic_api_key != potential_model_client.api_key
            ):
                model_client = None
            else:
                model_client = potential_model_client
            self.total_calls[model_name] += 1

            ## SET CUSTOM PROVIDER TO SELECTED DEPLOYMENT ##
            _, custom_llm_provider, _, _ = get_llm_provider(model=data["model"])

            response = litellm.acreate_batch(
                **{
                    **data,
                    "custom_llm_provider": custom_llm_provider,
                    "caching": self.cache_responses,
                    "client": model_client,
                    "timeout": self.timeout,
                    **kwargs,
                }
            )

            rpm_semaphore = self._get_client(
                deployment=deployment,
                kwargs=kwargs,
                client_type="max_parallel_requests",
            )

            if rpm_semaphore is not None and isinstance(
                rpm_semaphore, asyncio.Semaphore
            ):
                async with rpm_semaphore:
                    """
                    - Check rpm limits before making the call
                    - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
                    """
                    await self.async_routing_strategy_pre_call_checks(
                        deployment=deployment
                    )
                    response = await response  # type: ignore
            else:
                await self.async_routing_strategy_pre_call_checks(deployment=deployment)
                response = await response  # type: ignore

            self.success_calls[model_name] += 1
            verbose_router_logger.info(
                f"litellm.acreate_batch(model={model_name})\033[32m 200 OK\033[0m"
            )
            return response  # type: ignore
        except Exception as e:
            verbose_router_logger.exception(
                f"litellm._acreate_batch(model={model}, {kwargs})\033[31m Exception {str(e)}\033[0m"
            )
            if model is not None:
                self.fail_calls[model] += 1
            raise e

    async def aretrieve_batch(
        self,
        **kwargs,
    ) -> Batch:
        """
        Iterate through all models in a model group to check for batch

        Future Improvement - cache the result.
        """
        try:
            filtered_model_list = self.get_model_list()
            if filtered_model_list is None:
                raise Exception("Router not yet initialized.")

            receieved_exceptions = []

            async def try_retrieve_batch(model_name):
                try:
                    # Update kwargs with the current model name or any other model-specific adjustments
                    ## SET CUSTOM PROVIDER TO SELECTED DEPLOYMENT ##
                    _, custom_llm_provider, _, _ = get_llm_provider(
                        model=model_name["litellm_params"]["model"]
                    )
                    new_kwargs = copy.deepcopy(kwargs)
                    new_kwargs.pop("custom_llm_provider", None)
                    return await litellm.aretrieve_batch(
                        custom_llm_provider=custom_llm_provider, **new_kwargs
                    )
                except Exception as e:
                    receieved_exceptions.append(e)
                    return None

            # Check all models in parallel
            results = await asyncio.gather(
                *[try_retrieve_batch(model) for model in filtered_model_list],
                return_exceptions=True,
            )

            # Check for successful responses and handle exceptions
            for result in results:
                if isinstance(result, Batch):
                    return result

            # If no valid Batch response was found, raise the first encountered exception
            if receieved_exceptions:
                raise receieved_exceptions[0]  # Raising the first exception encountered

            # If no exceptions were encountered, raise a generic exception
            raise Exception(
                "Unable to find batch in any model. Received errors - {}".format(
                    receieved_exceptions
                )
            )
        except Exception as e:
            asyncio.create_task(
                send_llm_exception_alert(
                    litellm_router_instance=self,
                    request_kwargs=kwargs,
                    error_traceback_str=traceback.format_exc(),
                    original_exception=e,
                )
            )
            raise e

    async def alist_batches(
        self,
        model: str,
        **kwargs,
    ):
        """
        Return all the batches across all deployments of a model group.
        """

        filtered_model_list = self.get_model_list(model_name=model)
        if filtered_model_list is None:
            raise Exception("Router not yet initialized.")

        async def try_retrieve_batch(model: DeploymentTypedDict):
            try:
                # Update kwargs with the current model name or any other model-specific adjustments
                return await litellm.alist_batches(
                    **{**model["litellm_params"], **kwargs}
                )
            except Exception as e:
                return None

        # Check all models in parallel
        results = await asyncio.gather(
            *[try_retrieve_batch(model) for model in filtered_model_list]
        )

        final_results = {
            "object": "list",
            "data": [],
            "first_id": None,
            "last_id": None,
            "has_more": False,
        }

        for result in results:
            if result is not None:
                ## check batch id
                if final_results["first_id"] is None:
                    final_results["first_id"] = result.first_id
                final_results["last_id"] = result.last_id
                final_results["data"].extend(result.data)  # type: ignore

                ## check 'has_more'
                if result.has_more is True:
                    final_results["has_more"] = True

        return final_results

    #### ASSISTANTS API ####

    async def acreate_assistants(
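The hunk above also adds the two fan-out lookups described in their docstrings: `router.aretrieve_batch()` tries every deployment until one recognises the batch id, and `router.alist_batches()` merges per-deployment pages into a single list-style dict. A hedged sketch of calling them, reusing the hypothetical model group from the earlier example:

from litellm import Router
from litellm.types.llms.openai import Batch


async def inspect_batches(router: Router, batch_id: str) -> Batch:
    # List batches across every deployment of the (assumed) model group;
    # the router merges the per-deployment pages into one list-style dict
    # with "data", "first_id", "last_id", and "has_more" keys.
    batches = await router.alist_batches(model="azure-batch-group")
    print(len(batches["data"]), batches["has_more"])

    # aretrieve_batch() takes no model group: it fans out over all
    # deployments and returns the first Batch whose id is recognised.
    return await router.aretrieve_batch(batch_id=batch_id)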
@@ -4132,9 +4510,18 @@ class Router:
     def get_model_names(self) -> List[str]:
         return self.model_names

-    def get_model_list(self):
+    def get_model_list(
+        self, model_name: Optional[str] = None
+    ) -> Optional[List[DeploymentTypedDict]]:
         if hasattr(self, "model_list"):
-            return self.model_list
+            if model_name is None:
+                return self.model_list
+
+            returned_models: List[DeploymentTypedDict] = []
+            for model in self.model_list:
+                if model["model_name"] == model_name:
+                    returned_models.append(model)
+            return returned_models
         return None

     def get_model_access_groups(self):
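The `get_model_list()` change above is what lets `alist_batches()` restrict its fan-out to a single model group. A small hedged sketch of the new filter; the group name is an assumption carried over from the earlier example:

from litellm import Router


def show_group_deployments(router: Router, group: str = "azure-batch-group") -> None:
    # Without an argument, get_model_list() still returns every deployment;
    # with model_name it now returns only the deployments registered under
    # that (assumed) group name, or None if the router has no model_list.
    deployments = router.get_model_list(model_name=group)
    if deployments is None:
        raise RuntimeError("Router not yet initialized.")
    for d in deployments:
        print(d["model_name"], d["litellm_params"]["model"])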