feat(proxy_server.py): enable batch completion fastest response calls on proxy

Introduces a new `fastest_response` flag that enables fastest-response batch completion calls.
Krrish Dholakia 2024-05-28 20:09:31 -07:00
parent ecd182eb6a
commit 20106715d5
3 changed files with 32 additions and 3 deletions
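For context, here is how a client would exercise the new flag. This is a minimal sketch, assuming a proxy running locally on port 4000; the URL, API key, and model names are placeholders, not part of this commit:

```python
import requests

# Placeholder proxy endpoint and key; adjust for your deployment.
PROXY_URL = "http://0.0.0.0:4000/chat/completions"
API_KEY = "sk-1234"

# A comma-separated `model` string routes the request down the batch
# completion path in chat_completion(); adding `fastest_response: true`
# selects abatch_completion_fastest_response() instead of
# abatch_completion() (see the diff below).
payload = {
    "model": "gpt-4,gpt-3.5-turbo",  # placeholder deployment names
    "messages": [{"role": "user", "content": "Hello!"}],
    "fastest_response": True,
}

resp = requests.post(
    PROXY_URL,
    headers={"Authorization": f"Bearer {API_KEY}"},
    json=payload,
)
print(resp.json())
```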

proxy_server.py

@@ -415,6 +415,7 @@ def get_custom_headers(
     api_base: Optional[str] = None,
     version: Optional[str] = None,
     model_region: Optional[str] = None,
+    fastest_response_batch_completion: Optional[bool] = None,
 ) -> dict:
     exclude_values = {"", None}
     headers = {
@@ -425,6 +426,11 @@
         "x-litellm-model-region": model_region,
         "x-litellm-key-tpm-limit": str(user_api_key_dict.tpm_limit),
         "x-litellm-key-rpm-limit": str(user_api_key_dict.rpm_limit),
+        "x-litellm-fastest_response_batch_completion": (
+            str(fastest_response_batch_completion)
+            if fastest_response_batch_completion is not None
+            else None
+        ),
     }
     try:
         return {
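Because `get_custom_headers` drops values in `exclude_values`, the new header only appears on responses that actually went through the fastest-response path. A small sketch of checking it client-side, reusing the hypothetical `resp` object from the request example above (note the header name mixes dashes and underscores, exactly as added here):

```python
# "True" when the proxy served this request via the fastest-response
# batch path; the header is absent (None) otherwise.
served_fastest = resp.headers.get("x-litellm-fastest_response_batch_completion")
print(served_fastest)
```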
@@ -4035,7 +4041,17 @@ async def chat_completion(
         elif "," in data["model"] and llm_router is not None:
             _models_csv_string = data.pop("model")
             _models = _models_csv_string.split(",")
-            tasks.append(llm_router.abatch_completion(models=_models, **data))
+            if (
+                data.get("fastest_response", None) is not None
+                and data["fastest_response"] == True
+            ):
+                tasks.append(
+                    llm_router.abatch_completion_fastest_response(
+                        models=_models, **data
+                    )
+                )
+            else:
+                tasks.append(llm_router.abatch_completion(models=_models, **data))
         elif "user_config" in data:
             # initialize a new router instance. make request using this Router
             router_config = data.pop("user_config")
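`abatch_completion_fastest_response` is defined on the router, outside this file's diff. For readers unfamiliar with the idea, here is a minimal illustrative sketch of the underlying first-completed race pattern; it is not litellm's actual implementation:

```python
import asyncio


async def fastest_response(coros):
    # Race the coroutines; return the first result and cancel the rest.
    tasks = [asyncio.create_task(c) for c in coros]
    done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
    for task in pending:
        task.cancel()
    # .result() re-raises if the first task to finish actually failed.
    return done.pop().result()


async def main():
    async def call(model: str, delay: float) -> str:
        await asyncio.sleep(delay)  # stand-in for an LLM API call
        return f"response from {model}"

    # model-b "wins" because it finishes first.
    print(await fastest_response([call("model-a", 0.3), call("model-b", 0.1)]))


asyncio.run(main())
```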
@@ -4085,6 +4101,9 @@
         model_id = hidden_params.get("model_id", None) or ""
         cache_key = hidden_params.get("cache_key", None) or ""
         api_base = hidden_params.get("api_base", None) or ""
+        fastest_response_batch_completion = hidden_params.get(
+            "fastest_response_batch_completion", None
+        )
 
         # Post Call Processing
         if llm_router is not None:
@@ -4101,6 +4120,7 @@
                 api_base=api_base,
                 version=version,
                 model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
+                fastest_response_batch_completion=fastest_response_batch_completion,
             )
         selected_data_generator = select_data_generator(
             response=response,
@@ -4121,6 +4141,7 @@
             api_base=api_base,
             version=version,
             model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
+            fastest_response_batch_completion=fastest_response_batch_completion,
         )
     )