Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 11:14:04 +00:00)

feat(router.py): Support load balancing batch Azure API endpoints (#5469)

* feat(router.py): initial commit for load balancing Azure batch API endpoints. Closes https://github.com/BerriAI/litellm/issues/5396
* fix(router.py): working `router.acreate_file()`
* feat(router.py): working `router.acreate_batch()` endpoint
* feat(router.py): expose the `router.aretrieve_batch()` function, making it easy for users to retrieve batch information
* feat(router.py): support the `router.alist_batches()` endpoint, returning all batches across all endpoints
* feat(router.py): working load balancing on `/v1/files`
* feat(proxy_server.py): working load balancing on `/v1/batches`
* feat(proxy_server.py): working load balancing on Retrieve + List batch

parent 9b22359bed
commit 18da7adce9

10 changed files with 667 additions and 37 deletions
@@ -140,6 +140,7 @@ return_response_headers: bool = (
 enable_json_schema_validation: bool = False
 ##################
 logging: bool = True
+enable_loadbalancing_on_batch_endpoints: Optional[bool] = None
 enable_caching_on_provider_specific_optional_params: bool = (
     False  # feature-flag for caching on optional params - e.g. 'top_k'
 )
@@ -1,17 +1,12 @@
 model_list:
-  - model_name: "gpt-3.5-turbo"
+  - model_name: "batch-gpt-4o-mini"
     litellm_params:
-      model: "gpt-3.5-turbo"
+      model: "azure/gpt-4o-mini"
+      api_key: os.environ/AZURE_API_KEY
+      api_base: os.environ/AZURE_API_BASE
+    model_info:
+      mode: batch

 litellm_settings:
-  max_internal_user_budget: 0.02 # amount in USD
-  internal_user_budget_duration: "1s" # reset every second
-
-general_settings:
-  master_key: sk-1234
-  alerting: ["slack"]
-  alerting_threshold: 0.0001 # (Seconds) set an artifically low threshold for testing alerting
-  alert_to_webhook_url: {
-    "spend_reports": ["https://webhook.site/7843a980-a494-4967-80fb-d502dbc16886", "https://webhook.site/28cfb179-f4fb-4408-8129-729ff55cf213"]
-  }
+  enable_loadbalancing_on_batch_endpoints: true
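With enable_loadbalancing_on_batch_endpoints set, the proxy's Files and Batches routes pick an Azure deployment per request. A minimal usage sketch against a locally running proxy (the base_url, API key, .jsonl path, and exact route shapes below are placeholders/assumptions, not part of this commit):

# Sketch: sending a batch through the proxy's load-balanced routes.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:4000", api_key="sk-<proxy-key>")

# Each .jsonl line references the router model group name ("batch-gpt-4o-mini");
# the proxy reads the first line to decide whether to route via the Router.
file_obj = client.files.create(
    file=open("batch_requests.jsonl", "rb"),
    purpose="batch",
)

# "model" in the request body is how the proxy's create_batch() recognizes a router model group.
batch = client.batches.create(
    input_file_id=file_obj.id,
    endpoint="/v1/chat/completions",
    completion_window="24h",
    extra_body={"model": "batch-gpt-4o-mini"},
)
print(batch.id, batch.status)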
@@ -31,6 +31,7 @@ from litellm._logging import verbose_proxy_logger
 from litellm.batches.main import FileObject
 from litellm.proxy._types import *
 from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
+from litellm.router import Router

 router = APIRouter()

@@ -66,6 +67,41 @@ def get_files_provider_config(
     return None


+def get_first_json_object(file_content_bytes: bytes) -> Optional[dict]:
+    try:
+        # Decode the bytes to a string and split into lines
+        file_content = file_content_bytes.decode("utf-8")
+        first_line = file_content.splitlines()[0].strip()
+
+        # Parse the JSON object from the first line
+        json_object = json.loads(first_line)
+        return json_object
+    except (json.JSONDecodeError, UnicodeDecodeError) as e:
+        return None
+
+
+def get_model_from_json_obj(json_object: dict) -> Optional[str]:
+    body = json_object.get("body", {}) or {}
+    model = body.get("model")
+
+    return model
+
+
+def is_known_model(model: Optional[str], llm_router: Optional[Router]) -> bool:
+    """
+    Returns True if the model is in the llm_router model names
+    """
+    if model is None or llm_router is None:
+        return False
+    model_names = llm_router.get_model_names()
+
+    is_in_list = False
+    if model in model_names:
+        is_in_list = True
+
+    return is_in_list
+
+
 @router.post(
     "/{provider}/v1/files",
     dependencies=[Depends(user_api_key_auth)],
@@ -109,6 +145,7 @@ async def create_file(
         add_litellm_data_to_request,
         general_settings,
         get_custom_headers,
+        llm_router,
         proxy_config,
         proxy_logging_obj,
         version,
@@ -138,8 +175,36 @@ async def create_file(
         # Prepare the file data according to FileTypes
         file_data = (file.filename, file_content, file.content_type)

+        ## check if model is a loadbalanced model
+        router_model: Optional[str] = None
+        is_router_model = False
+        if litellm.enable_loadbalancing_on_batch_endpoints is True:
+            json_obj = get_first_json_object(file_content_bytes=file_content)
+            if json_obj:
+                router_model = get_model_from_json_obj(json_object=json_obj)
+                is_router_model = is_known_model(
+                    model=router_model, llm_router=llm_router
+                )
+
         _create_file_request = CreateFileRequest(file=file_data, **data)

+        if (
+            litellm.enable_loadbalancing_on_batch_endpoints is True
+            and is_router_model
+            and router_model is not None
+        ):
+            if llm_router is None:
+                raise HTTPException(
+                    status_code=500,
+                    detail={
+                        "error": "LLM Router not initialized. Ensure models added to proxy."
+                    },
+                )
+
+            response = await llm_router.acreate_file(
+                model=router_model, **_create_file_request
+            )
+        else:
             # get configs for custom_llm_provider
             llm_provider_config = get_files_provider_config(
                 custom_llm_provider=custom_llm_provider
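The new create_file() helpers above decide routing from the first line of the uploaded .jsonl. A standalone sketch of that decision (the sample payload and model-group list are illustrative only):

# Sketch of the routing decision made in create_file().
import json

jsonl_bytes = (
    b'{"custom_id": "task-0", "method": "POST", "url": "/chat/completions", '
    b'"body": {"model": "batch-gpt-4o-mini", "messages": [{"role": "user", "content": "hi"}]}}\n'
)

# get_first_json_object(): parse only the first line of the file
first_obj = json.loads(jsonl_bytes.decode("utf-8").splitlines()[0].strip())
# get_model_from_json_obj(): pull the model out of the request body
router_model = (first_obj.get("body", {}) or {}).get("model")
# is_known_model(): stand-in for llm_router.get_model_names() on a configured router
known_model_groups = ["batch-gpt-4o-mini"]
use_router = router_model in known_model_groups
print(router_model, use_router)  # batch-gpt-4o-mini True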
@@ -199,6 +199,7 @@ from litellm.proxy.management_endpoints.team_callback_endpoints import (
     router as team_callback_router,
 )
 from litellm.proxy.management_endpoints.team_endpoints import router as team_router
+from litellm.proxy.openai_files_endpoints.files_endpoints import is_known_model
 from litellm.proxy.openai_files_endpoints.files_endpoints import (
     router as openai_files_router,
 )
@@ -4979,8 +4980,30 @@ async def create_batch(
             proxy_config=proxy_config,
         )

+        ## check if model is a loadbalanced model
+        router_model: Optional[str] = None
+        is_router_model = False
+        if litellm.enable_loadbalancing_on_batch_endpoints is True:
+            router_model = data.get("model", None)
+            is_router_model = is_known_model(model=router_model, llm_router=llm_router)
+
         _create_batch_data = CreateBatchRequest(**data)

+        if (
+            litellm.enable_loadbalancing_on_batch_endpoints is True
+            and is_router_model
+            and router_model is not None
+        ):
+            if llm_router is None:
+                raise HTTPException(
+                    status_code=500,
+                    detail={
+                        "error": "LLM Router not initialized. Ensure models added to proxy."
+                    },
+                )
+
+            response = await llm_router.acreate_batch(**_create_batch_data)  # type: ignore
+        else:
             if provider is None:
                 provider = "openai"
             response = await litellm.acreate_batch(
@@ -5017,7 +5040,7 @@ async def create_batch(
         await proxy_logging_obj.post_call_failure_hook(
             user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
         )
-        verbose_proxy_logger.error(
+        verbose_proxy_logger.exception(
             "litellm.proxy.proxy_server.create_batch(): Exception occured - {}".format(
                 str(e)
             )
@@ -5080,10 +5103,25 @@ async def retrieve_batch(
     global proxy_logging_obj
     data: Dict = {}
     try:
+        ## check if model is a loadbalanced model
+        router_model: Optional[str] = None
+        is_router_model = False
+
         _retrieve_batch_request = RetrieveBatchRequest(
             batch_id=batch_id,
         )

+        if litellm.enable_loadbalancing_on_batch_endpoints is True:
+            if llm_router is None:
+                raise HTTPException(
+                    status_code=500,
+                    detail={
+                        "error": "LLM Router not initialized. Ensure models added to proxy."
+                    },
+                )
+
+            response = await llm_router.aretrieve_batch(**_retrieve_batch_request)  # type: ignore
+        else:
             if provider is None:
                 provider = "openai"
             response = await litellm.aretrieve_batch(
@@ -5120,7 +5158,7 @@ async def retrieve_batch(
         await proxy_logging_obj.post_call_failure_hook(
             user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
         )
-        verbose_proxy_logger.error(
+        verbose_proxy_logger.exception(
             "litellm.proxy.proxy_server.retrieve_batch(): Exception occured - {}".format(
                 str(e)
             )
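On the read side, the proxy hands retrieval to the router when load balancing is enabled, so the client does not need to know which deployment created the batch. A hedged sketch (batch id, base_url, and key are placeholders):

# Sketch: retrieving a batch through the proxy.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:4000", api_key="sk-<proxy-key>")
retrieved = client.batches.retrieve("batch_abc123")  # placeholder batch id
print(retrieved.status, retrieved.output_file_id)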
@@ -54,6 +54,10 @@ from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler
 from litellm.router_strategy.lowest_tpm_rpm import LowestTPMLoggingHandler
 from litellm.router_strategy.lowest_tpm_rpm_v2 import LowestTPMLoggingHandler_v2
 from litellm.router_strategy.tag_based_routing import get_deployments_for_tag
+from litellm.router_utils.batch_utils import (
+    _get_router_metadata_variable_name,
+    replace_model_in_jsonl,
+)
 from litellm.router_utils.client_initalization_utils import (
     set_client,
     should_initialize_sync_client,
@@ -73,6 +77,12 @@ from litellm.types.llms.openai import (
     AssistantToolParam,
     AsyncCursorPage,
     Attachment,
+    Batch,
+    CreateFileRequest,
+    FileContentRequest,
+    FileObject,
+    FileTypes,
+    HttpxBinaryResponseContent,
     OpenAIMessage,
     Run,
     Thread,
@@ -103,6 +113,7 @@ from litellm.utils import (
     _is_region_eu,
     calculate_max_parallel_requests,
     create_proxy_transport_and_mounts,
+    get_llm_provider,
     get_utc_datetime,
 )

@@ -2228,6 +2239,373 @@ class Router:
                 self.fail_calls[model_name] += 1
             raise e

+    #### FILES API ####
+
+    async def acreate_file(
+        self,
+        model: str,
+        **kwargs,
+    ) -> FileObject:
+        try:
+            kwargs["model"] = model
+            kwargs["original_function"] = self._acreate_file
+            kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
+            timeout = kwargs.get("request_timeout", self.timeout)
+            kwargs.setdefault("metadata", {}).update({"model_group": model})
+            response = await self.async_function_with_fallbacks(**kwargs)
+
+            return response
+        except Exception as e:
+            asyncio.create_task(
+                send_llm_exception_alert(
+                    litellm_router_instance=self,
+                    request_kwargs=kwargs,
+                    error_traceback_str=traceback.format_exc(),
+                    original_exception=e,
+                )
+            )
+            raise e
+
+    async def _acreate_file(
+        self,
+        model: str,
+        **kwargs,
+    ) -> FileObject:
+        try:
+            verbose_router_logger.debug(
+                f"Inside _atext_completion()- model: {model}; kwargs: {kwargs}"
+            )
+            deployment = await self.async_get_available_deployment(
+                model=model,
+                messages=[{"role": "user", "content": "files-api-fake-text"}],
+                specific_deployment=kwargs.pop("specific_deployment", None),
+            )
+            kwargs.setdefault("metadata", {}).update(
+                {
+                    "deployment": deployment["litellm_params"]["model"],
+                    "model_info": deployment.get("model_info", {}),
+                    "api_base": deployment.get("litellm_params", {}).get("api_base"),
+                }
+            )
+            kwargs["model_info"] = deployment.get("model_info", {})
+            data = deployment["litellm_params"].copy()
+            model_name = data["model"]
+            for k, v in self.default_litellm_params.items():
+                if (
+                    k not in kwargs
+                ):  # prioritize model-specific params > default router params
+                    kwargs[k] = v
+                elif k == "metadata":
+                    kwargs[k].update(v)
+
+            potential_model_client = self._get_client(
+                deployment=deployment, kwargs=kwargs, client_type="async"
+            )
+            # check if provided keys == client keys #
+            dynamic_api_key = kwargs.get("api_key", None)
+            if (
+                dynamic_api_key is not None
+                and potential_model_client is not None
+                and dynamic_api_key != potential_model_client.api_key
+            ):
+                model_client = None
+            else:
+                model_client = potential_model_client
+            self.total_calls[model_name] += 1
+
+            ## REPLACE MODEL IN FILE WITH SELECTED DEPLOYMENT ##
+            stripped_model, custom_llm_provider, _, _ = get_llm_provider(
+                model=data["model"]
+            )
+            kwargs["file"] = replace_model_in_jsonl(
+                file_content=kwargs["file"], new_model_name=stripped_model
+            )
+
+            response = litellm.acreate_file(
+                **{
+                    **data,
+                    "custom_llm_provider": custom_llm_provider,
+                    "caching": self.cache_responses,
+                    "client": model_client,
+                    "timeout": self.timeout,
+                    **kwargs,
+                }
+            )
+
+            rpm_semaphore = self._get_client(
+                deployment=deployment,
+                kwargs=kwargs,
+                client_type="max_parallel_requests",
+            )
+
+            if rpm_semaphore is not None and isinstance(
+                rpm_semaphore, asyncio.Semaphore
+            ):
+                async with rpm_semaphore:
+                    """
+                    - Check rpm limits before making the call
+                    - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
+                    """
+                    await self.async_routing_strategy_pre_call_checks(
+                        deployment=deployment
+                    )
+                    response = await response  # type: ignore
+            else:
+                await self.async_routing_strategy_pre_call_checks(deployment=deployment)
+                response = await response  # type: ignore
+
+            self.success_calls[model_name] += 1
+            verbose_router_logger.info(
+                f"litellm.acreate_file(model={model_name})\033[32m 200 OK\033[0m"
+            )
+            return response  # type: ignore
+        except Exception as e:
+            verbose_router_logger.exception(
+                f"litellm.acreate_file(model={model}, {kwargs})\033[31m Exception {str(e)}\033[0m"
+            )
+            if model is not None:
+                self.fail_calls[model] += 1
+            raise e
+
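router.acreate_file() can also be called directly from Python. A sketch mirroring the test added in this commit (Azure credentials via environment variables and the .jsonl path are assumptions):

# Sketch: upload a batch input file through the Router.
import asyncio
import os

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "batch-gpt-4o-mini",
            "litellm_params": {
                "model": "azure/gpt-4o-mini",
                "api_base": os.getenv("AZURE_API_BASE"),
                "api_key": os.getenv("AZURE_API_KEY"),
            },
        }
    ]
)


async def upload_file():
    # _acreate_file() picks a deployment, rewrites the model name inside the .jsonl,
    # then uploads the file to that deployment.
    return await router.acreate_file(
        model="batch-gpt-4o-mini",
        file=open("batch_requests.jsonl", "rb"),  # placeholder path
        purpose="batch",
        custom_llm_provider="azure",
    )


print(asyncio.run(upload_file()).id)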
+    async def acreate_batch(
+        self,
+        model: str,
+        **kwargs,
+    ) -> Batch:
+        try:
+            kwargs["model"] = model
+            kwargs["original_function"] = self._acreate_batch
+            kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
+            timeout = kwargs.get("request_timeout", self.timeout)
+            kwargs.setdefault("metadata", {}).update({"model_group": model})
+            response = await self.async_function_with_fallbacks(**kwargs)
+
+            return response
+        except Exception as e:
+            asyncio.create_task(
+                send_llm_exception_alert(
+                    litellm_router_instance=self,
+                    request_kwargs=kwargs,
+                    error_traceback_str=traceback.format_exc(),
+                    original_exception=e,
+                )
+            )
+            raise e
+
+    async def _acreate_batch(
+        self,
+        model: str,
+        **kwargs,
+    ) -> Batch:
+        try:
+            verbose_router_logger.debug(
+                f"Inside _acreate_batch()- model: {model}; kwargs: {kwargs}"
+            )
+            deployment = await self.async_get_available_deployment(
+                model=model,
+                messages=[{"role": "user", "content": "files-api-fake-text"}],
+                specific_deployment=kwargs.pop("specific_deployment", None),
+            )
+            metadata_variable_name = _get_router_metadata_variable_name(
+                function_name="_acreate_batch"
+            )
+
+            kwargs.setdefault(metadata_variable_name, {}).update(
+                {
+                    "deployment": deployment["litellm_params"]["model"],
+                    "model_info": deployment.get("model_info", {}),
+                    "api_base": deployment.get("litellm_params", {}).get("api_base"),
+                }
+            )
+            kwargs["model_info"] = deployment.get("model_info", {})
+            data = deployment["litellm_params"].copy()
+            model_name = data["model"]
+            for k, v in self.default_litellm_params.items():
+                if (
+                    k not in kwargs
+                ):  # prioritize model-specific params > default router params
+                    kwargs[k] = v
+                elif k == metadata_variable_name:
+                    kwargs[k].update(v)
+
+            potential_model_client = self._get_client(
+                deployment=deployment, kwargs=kwargs, client_type="async"
+            )
+            # check if provided keys == client keys #
+            dynamic_api_key = kwargs.get("api_key", None)
+            if (
+                dynamic_api_key is not None
+                and potential_model_client is not None
+                and dynamic_api_key != potential_model_client.api_key
+            ):
+                model_client = None
+            else:
+                model_client = potential_model_client
+            self.total_calls[model_name] += 1
+
+            ## SET CUSTOM PROVIDER TO SELECTED DEPLOYMENT ##
+            _, custom_llm_provider, _, _ = get_llm_provider(model=data["model"])
+
+            response = litellm.acreate_batch(
+                **{
+                    **data,
+                    "custom_llm_provider": custom_llm_provider,
+                    "caching": self.cache_responses,
+                    "client": model_client,
+                    "timeout": self.timeout,
+                    **kwargs,
+                }
+            )
+
+            rpm_semaphore = self._get_client(
+                deployment=deployment,
+                kwargs=kwargs,
+                client_type="max_parallel_requests",
+            )
+
+            if rpm_semaphore is not None and isinstance(
+                rpm_semaphore, asyncio.Semaphore
+            ):
+                async with rpm_semaphore:
+                    """
+                    - Check rpm limits before making the call
+                    - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
+                    """
+                    await self.async_routing_strategy_pre_call_checks(
+                        deployment=deployment
+                    )
+                    response = await response  # type: ignore
+            else:
+                await self.async_routing_strategy_pre_call_checks(deployment=deployment)
+                response = await response  # type: ignore
+
+            self.success_calls[model_name] += 1
+            verbose_router_logger.info(
+                f"litellm.acreate_file(model={model_name})\033[32m 200 OK\033[0m"
+            )
+            return response  # type: ignore
+        except Exception as e:
+            verbose_router_logger.exception(
+                f"litellm._acreate_batch(model={model}, {kwargs})\033[31m Exception {str(e)}\033[0m"
+            )
+            if model is not None:
+                self.fail_calls[model] += 1
+            raise e
+
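acreate_batch() follows the same pattern as acreate_file(); note it writes routing metadata under "litellm_metadata" (via _get_router_metadata_variable_name). A sketch continuing from the previous snippet's router and uploaded file:

# Sketch: submit the uploaded .jsonl as a batch job on the same model group.
async def submit_batch(router, input_file_id: str):
    return await router.acreate_batch(
        model="batch-gpt-4o-mini",
        completion_window="24h",
        endpoint="/v1/chat/completions",
        input_file_id=input_file_id,
        custom_llm_provider="azure",
        metadata={"key1": "value1"},
    )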
+    async def aretrieve_batch(
+        self,
+        **kwargs,
+    ) -> Batch:
+        """
+        Iterate through all models in a model group to check for batch
+
+        Future Improvement - cache the result.
+        """
+        try:
+
+            filtered_model_list = self.get_model_list()
+            if filtered_model_list is None:
+                raise Exception("Router not yet initialized.")
+
+            receieved_exceptions = []
+
+            async def try_retrieve_batch(model_name):
+                try:
+                    # Update kwargs with the current model name or any other model-specific adjustments
+                    ## SET CUSTOM PROVIDER TO SELECTED DEPLOYMENT ##
+                    _, custom_llm_provider, _, _ = get_llm_provider(
+                        model=model_name["litellm_params"]["model"]
+                    )
+                    new_kwargs = copy.deepcopy(kwargs)
+                    new_kwargs.pop("custom_llm_provider", None)
+                    return await litellm.aretrieve_batch(
+                        custom_llm_provider=custom_llm_provider, **new_kwargs
+                    )
+                except Exception as e:
+                    receieved_exceptions.append(e)
+                    return None
+
+            # Check all models in parallel
+            results = await asyncio.gather(
+                *[try_retrieve_batch(model) for model in filtered_model_list],
+                return_exceptions=True,
+            )
+
+            # Check for successful responses and handle exceptions
+            for result in results:
+                if isinstance(result, Batch):
+                    return result
+
+            # If no valid Batch response was found, raise the first encountered exception
+            if receieved_exceptions:
+                raise receieved_exceptions[0]  # Raising the first exception encountered
+
+            # If no exceptions were encountered, raise a generic exception
+            raise Exception(
+                "Unable to find batch in any model. Received errors - {}".format(
+                    receieved_exceptions
+                )
+            )
+        except Exception as e:
+            asyncio.create_task(
+                send_llm_exception_alert(
+                    litellm_router_instance=self,
+                    request_kwargs=kwargs,
+                    error_traceback_str=traceback.format_exc(),
+                    original_exception=e,
+                )
+            )
+            raise e
+
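aretrieve_batch() takes no model argument: it fans the lookup out to every deployment and returns the first Batch found. A short usage sketch:

# Sketch: fetch a batch by id without knowing which deployment created it.
async def fetch_batch(router, batch_id: str):
    return await router.aretrieve_batch(
        batch_id=batch_id,
        custom_llm_provider="azure",
    )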
+    async def alist_batches(
+        self,
+        model: str,
+        **kwargs,
+    ):
+        """
+        Return all the batches across all deployments of a model group.
+        """
+
+        filtered_model_list = self.get_model_list(model_name=model)
+        if filtered_model_list is None:
+            raise Exception("Router not yet initialized.")
+
+        async def try_retrieve_batch(model: DeploymentTypedDict):
+            try:
+                # Update kwargs with the current model name or any other model-specific adjustments
+                return await litellm.alist_batches(
+                    **{**model["litellm_params"], **kwargs}
+                )
+            except Exception as e:
+                return None
+
+        # Check all models in parallel
+        results = await asyncio.gather(
+            *[try_retrieve_batch(model) for model in filtered_model_list]
+        )
+
+        final_results = {
+            "object": "list",
+            "data": [],
+            "first_id": None,
+            "last_id": None,
+            "has_more": False,
+        }
+
+        for result in results:
+            if result is not None:
+                ## check batch id
+                if final_results["first_id"] is None:
+                    final_results["first_id"] = result.first_id
+                final_results["last_id"] = result.last_id
+                final_results["data"].extend(result.data)  # type: ignore
+
+                ## check 'has_more'
+                if result.has_more is True:
+                    final_results["has_more"] = True
+
+        return final_results
+
     #### ASSISTANTS API ####

     async def acreate_assistants(
@@ -4132,9 +4510,18 @@ class Router:
     def get_model_names(self) -> List[str]:
         return self.model_names

-    def get_model_list(self):
+    def get_model_list(
+        self, model_name: Optional[str] = None
+    ) -> Optional[List[DeploymentTypedDict]]:
         if hasattr(self, "model_list"):
+            if model_name is None:
                 return self.model_list
+
+            returned_models: List[DeploymentTypedDict] = []
+            for model in self.model_list:
+                if model["model_name"] == model_name:
+                    returned_models.append(model)
+            return returned_models
         return None

     def get_model_access_groups(self):
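alist_batches() queries every deployment in the group and merges the pages into one OpenAI-style list, while the extended get_model_list(model_name=...) returns only that group's deployments. A usage sketch:

# Sketch: aggregate batches across all deployments of a model group.
async def list_group_batches(router):
    batches = await router.alist_batches(
        model="batch-gpt-4o-mini",
        custom_llm_provider="azure",
        limit=2,
    )
    return batches["data"]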
litellm/router_utils/batch_utils.py (new file, 59 lines)

@@ -0,0 +1,59 @@
+import io
+import json
+from typing import IO, Optional, Tuple, Union
+
+
+class InMemoryFile(io.BytesIO):
+    def __init__(self, content: bytes, name: str):
+        super().__init__(content)
+        self.name = name
+
+
+def replace_model_in_jsonl(
+    file_content: Union[bytes, IO, Tuple[str, bytes, str]], new_model_name: str
+) -> Optional[InMemoryFile]:
+    try:
+        # Decode the bytes to a string and split into lines
+        # If file_content is a file-like object, read the bytes
+        if hasattr(file_content, "read"):
+            file_content_bytes = file_content.read()  # type: ignore
+        elif isinstance(file_content, tuple):
+            file_content_bytes = file_content[1]
+        else:
+            file_content_bytes = file_content
+
+        # Decode the bytes to a string and split into lines
+        file_content_str = file_content_bytes.decode("utf-8")
+        lines = file_content_str.splitlines()
+        modified_lines = []
+        for line in lines:
+            # Parse each line as a JSON object
+            json_object = json.loads(line.strip())
+
+            # Replace the model name if it exists
+            if "body" in json_object:
+                json_object["body"]["model"] = new_model_name
+
+            # Convert the modified JSON object back to a string
+            modified_lines.append(json.dumps(json_object))
+
+        # Reassemble the modified lines and return as bytes
+        modified_file_content = "\n".join(modified_lines).encode("utf-8")
+        return InMemoryFile(modified_file_content, name="modified_file.jsonl")  # type: ignore
+
+    except (json.JSONDecodeError, UnicodeDecodeError, TypeError) as e:
+        return None
+
+
+def _get_router_metadata_variable_name(function_name) -> str:
+    """
+    Helper to return what the "metadata" field should be called in the request data
+
+    For all /thread or /assistant endpoints we need to call this "litellm_metadata"
+
+    For ALL other endpoints we call this "metadata
+    """
+    if "batch" in function_name:
+        return "litellm_metadata"
+    else:
+        return "metadata"
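A quick check of the new helper: replace_model_in_jsonl() rewrites the model name in every .jsonl line to the concrete deployment name before upload (the sample payload is illustrative):

# Sketch: rewriting the model name inside a batch .jsonl payload.
from litellm.router_utils.batch_utils import replace_model_in_jsonl

jsonl_bytes = (
    b'{"custom_id": "task-0", "method": "POST", "url": "/chat/completions", '
    b'"body": {"model": "my-custom-name", "messages": [{"role": "user", "content": "hi"}]}}'
)

rewritten = replace_model_in_jsonl(file_content=jsonl_bytes, new_model_name="gpt-4o-mini")
assert rewritten is not None
print(rewritten.name)                      # modified_file.jsonl
print(b"gpt-4o-mini" in rewritten.read())  # True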
litellm/tests/openai_batch_completions_router.jsonl (new file, 3 lines)

@@ -0,0 +1,3 @@
+{"custom_id": "task-0", "method": "POST", "url": "/chat/completions", "body": {"model": "my-custom-name", "messages": [{"role": "system", "content": "You are an AI assistant that helps people find information."}, {"role": "user", "content": "When was Microsoft founded?"}]}}
+{"custom_id": "task-1", "method": "POST", "url": "/chat/completions", "body": {"model": "my-custom-name", "messages": [{"role": "system", "content": "You are an AI assistant that helps people find information."}, {"role": "user", "content": "When was the first XBOX released?"}]}}
+{"custom_id": "task-2", "method": "POST", "url": "/chat/completions", "body": {"model": "my-custom-name", "messages": [{"role": "system", "content": "You are an AI assistant that helps people find information."}, {"role": "user", "content": "What is Altair Basic?"}]}}
@@ -2394,3 +2394,83 @@ async def test_router_weighted_pick(sync_mode):
         else:
             raise Exception("invalid model id returned!")
     assert model_id_1_count > model_id_2_count
+
+
+@pytest.mark.parametrize("provider", ["azure"])
+@pytest.mark.asyncio
+async def test_router_batch_endpoints(provider):
+    """
+    1. Create File for Batch completion
+    2. Create Batch Request
+    3. Retrieve the specific batch
+    """
+    print("Testing async create batch")
+
+    router = Router(
+        model_list=[
+            {
+                "model_name": "my-custom-name",
+                "litellm_params": {
+                    "model": "azure/gpt-4o-mini",
+                    "api_base": os.getenv("AZURE_API_BASE"),
+                    "api_key": os.getenv("AZURE_API_KEY"),
+                },
+            },
+        ]
+    )
+
+    file_name = "openai_batch_completions_router.jsonl"
+    _current_dir = os.path.dirname(os.path.abspath(__file__))
+    file_path = os.path.join(_current_dir, file_name)
+    file_obj = await router.acreate_file(
+        model="my-custom-name",
+        file=open(file_path, "rb"),
+        purpose="batch",
+        custom_llm_provider=provider,
+    )
+    print("Response from creating file=", file_obj)
+
+    await asyncio.sleep(10)
+    batch_input_file_id = file_obj.id
+    assert (
+        batch_input_file_id is not None
+    ), "Failed to create file, expected a non null file_id but got {batch_input_file_id}"
+
+    create_batch_response = await router.acreate_batch(
+        model="my-custom-name",
+        completion_window="24h",
+        endpoint="/v1/chat/completions",
+        input_file_id=batch_input_file_id,
+        custom_llm_provider=provider,
+        metadata={"key1": "value1", "key2": "value2"},
+    )
+
+    print("response from router.create_batch=", create_batch_response)
+
+    assert (
+        create_batch_response.id is not None
+    ), f"Failed to create batch, expected a non null batch_id but got {create_batch_response.id}"
+    assert (
+        create_batch_response.endpoint == "/v1/chat/completions"
+        or create_batch_response.endpoint == "/chat/completions"
+    ), f"Failed to create batch, expected endpoint to be /v1/chat/completions but got {create_batch_response.endpoint}"
+    assert (
+        create_batch_response.input_file_id == batch_input_file_id
+    ), f"Failed to create batch, expected input_file_id to be {batch_input_file_id} but got {create_batch_response.input_file_id}"
+
+    await asyncio.sleep(1)
+
+    retrieved_batch = await router.aretrieve_batch(
+        batch_id=create_batch_response.id,
+        custom_llm_provider=provider,
+    )
+    print("retrieved batch=", retrieved_batch)
+    # just assert that we retrieved a non None batch
+
+    assert retrieved_batch.id == create_batch_response.id
+
+    # list all batches
+    list_batches = await router.alist_batches(
+        model="my-custom-name", custom_llm_provider=provider, limit=2
+    )
+    print("list_batches=", list_batches)
@@ -4645,6 +4645,8 @@ def get_llm_provider(
     For router -> Can also give the whole litellm param dict -> this function will extract the relevant details

     Raises Error - if unable to map model to a provider
+
+    Return model, custom_llm_provider, dynamic_api_key, api_base
     """
     try:
         ## IF LITELLM PARAMS GIVEN ##
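The documented return order matters for the router code above, which unpacks the stripped model name and provider. A small sketch:

# Sketch: get_llm_provider() strips the provider prefix and returns it separately.
from litellm.utils import get_llm_provider

model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(
    model="azure/gpt-4o-mini"
)
print(model, custom_llm_provider)  # gpt-4o-mini azure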
@@ -1,2 +1,2 @@
-{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo-0125", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 10}}
+{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "my-custom-name", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 10}}
-{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo-0125", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 10}}
+{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "my-custom-name", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 10}}