feat(proxy_server.py): enable batch completion fastest response calls on proxy

Introduces a new `fastest_response` flag for enabling fastest-response batch completion calls.
Krrish Dholakia 2024-05-28 20:09:31 -07:00
parent ecd182eb6a
commit 20106715d5
3 changed files with 32 additions and 3 deletions
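
For reference, a minimal client-side sketch of what the flag enables. The proxy URL, API key, and model names are placeholders, and passing the extra field through the OpenAI SDK's `extra_body` is an assumption; the diff below only shows that the proxy reads `fastest_response` and a comma-separated `model` string from the request body.

```python
# Minimal client-side sketch (assumed setup): call the LiteLLM proxy with a
# comma-separated model list and the new flag in the request body.
# Proxy URL, API key, and model names are placeholders.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:4000", api_key="sk-1234")

response = client.chat.completions.create(
    model="gpt-4o,claude-3-haiku",  # comma-separated list triggers batch routing
    messages=[{"role": "user", "content": "Hello"}],
    extra_body={"fastest_response": True},  # return only the fastest model's answer
)
print(response.choices[0].message.content)
```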


@@ -680,6 +680,7 @@ def completion(
         "region_name",
         "allowed_model_region",
         "model_config",
+        "fastest_response",
     ]
     default_params = openai_params + litellm_params
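Why register `fastest_response` here: params in `default_params` are consumed by LiteLLM itself rather than forwarded to the provider. A rough, self-contained sketch of that filtering pattern; the subsets and the names not present in the hunk (`kwargs`, `non_default_params`) are illustrative assumptions, not the verbatim implementation.

```python
# Rough sketch of the pattern the list above feeds into (illustrative subsets,
# assumed variable roles): anything listed in default_params is handled by
# LiteLLM itself; everything else is forwarded to the provider as an extra param.
openai_params = ["model", "messages", "temperature"]  # illustrative subset
litellm_params = ["metadata", "model_config", "fastest_response"]  # illustrative subset
default_params = openai_params + litellm_params

kwargs = {"temperature": 0.2, "fastest_response": True, "top_k": 40}
non_default_params = {k: v for k, v in kwargs.items() if k not in default_params}
print(non_default_params)  # {'top_k': 40} -- "fastest_response" never reaches the provider
```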


@@ -415,6 +415,7 @@ def get_custom_headers(
     api_base: Optional[str] = None,
     version: Optional[str] = None,
     model_region: Optional[str] = None,
+    fastest_response_batch_completion: Optional[bool] = None,
 ) -> dict:
     exclude_values = {"", None}
     headers = {
@@ -425,6 +426,11 @@ def get_custom_headers(
         "x-litellm-model-region": model_region,
         "x-litellm-key-tpm-limit": str(user_api_key_dict.tpm_limit),
         "x-litellm-key-rpm-limit": str(user_api_key_dict.rpm_limit),
+        "x-litellm-fastest_response_batch_completion": (
+            str(fastest_response_batch_completion)
+            if fastest_response_batch_completion is not None
+            else None
+        ),
     }
     try:
         return {
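The surrounding `exclude_values` / `return {` context suggests headers with empty or `None` values are filtered out, so the new header only appears when the flag was actually used. A small sketch of that assumed filtering:

```python
# Sketch of the filtering implied by exclude_values above (assumed from the
# surrounding context lines, not the verbatim implementation): headers whose
# value is "" or None are dropped, so the new header only shows up on
# fastest-response batch completion calls.
exclude_values = {"", None}
headers = {
    "x-litellm-model-id": "gpt-4o-deployment-1",  # placeholder value
    "x-litellm-fastest_response_batch_completion": None,  # flag was not used
}
print({k: v for k, v in headers.items() if v not in exclude_values})
# -> {'x-litellm-model-id': 'gpt-4o-deployment-1'}
```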
@@ -4035,6 +4041,16 @@ async def chat_completion(
         elif "," in data["model"] and llm_router is not None:
             _models_csv_string = data.pop("model")
             _models = _models_csv_string.split(",")
-            tasks.append(llm_router.abatch_completion(models=_models, **data))
+            if (
+                data.get("fastest_response", None) is not None
+                and data["fastest_response"] == True
+            ):
+                tasks.append(
+                    llm_router.abatch_completion_fastest_response(
+                        models=_models, **data
+                    )
+                )
+            else:
+                tasks.append(llm_router.abatch_completion(models=_models, **data))
         elif "user_config" in data:
             # initialize a new router instance. make request using this Router
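To make the branching concrete, a tiny stand-alone illustration of the routing decision above; the request body is a placeholder and the `target` strings merely name the router methods from the diff.

```python
# Stand-alone illustration of the routing decision above. The request body is a
# placeholder; the target strings only name the router methods from the diff.
data = {"model": "gpt-4o,claude-3-haiku", "fastest_response": True, "messages": []}

_models = data.pop("model").split(",")  # ['gpt-4o', 'claude-3-haiku']
if data.get("fastest_response", None) is not None and data["fastest_response"] == True:
    target = "llm_router.abatch_completion_fastest_response"  # first finished response wins
else:
    target = "llm_router.abatch_completion"  # one response per model is returned
print(_models, target)
```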
@@ -4085,6 +4101,9 @@ async def chat_completion(
         model_id = hidden_params.get("model_id", None) or ""
         cache_key = hidden_params.get("cache_key", None) or ""
         api_base = hidden_params.get("api_base", None) or ""
+        fastest_response_batch_completion = hidden_params.get(
+            "fastest_response_batch_completion", None
+        )

         # Post Call Processing
         if llm_router is not None:
@@ -4101,6 +4120,7 @@ async def chat_completion(
                 api_base=api_base,
                 version=version,
                 model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
+                fastest_response_batch_completion=fastest_response_batch_completion,
             )
             selected_data_generator = select_data_generator(
                 response=response,
@@ -4121,6 +4141,7 @@ async def chat_completion(
                 api_base=api_base,
                 version=version,
                 model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
+                fastest_response_batch_completion=fastest_response_batch_completion,
             )
         )
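With the header now propagated on both the streaming and non-streaming paths, a caller can check which path served the request. A hedged client-side sketch using the OpenAI SDK's raw-response accessor; the URL, key, and model names are placeholders.

```python
# Hedged client-side sketch: inspect the new response header via the OpenAI SDK's
# raw-response accessor. Proxy URL, API key, and model names are placeholders.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:4000", api_key="sk-1234")

raw = client.chat.completions.with_raw_response.create(
    model="gpt-4o,claude-3-haiku",
    messages=[{"role": "user", "content": "Hello"}],
    extra_body={"fastest_response": True},
)
print(raw.headers.get("x-litellm-fastest_response_batch_completion"))  # e.g. "True"
completion = raw.parse()  # the regular ChatCompletion object
```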


@@ -804,9 +804,16 @@ class Router:
                 pending_tasks.append(task)

             responses = await asyncio.gather(*_tasks, return_exceptions=True)
-            if isinstance(responses[0], Exception):
+            if isinstance(responses[0], Exception) or isinstance(
+                responses[0], BaseException
+            ):
                 raise responses[0]
-            return responses[0]  # return first value from list
+
+            _response: Union[ModelResponse, CustomStreamWrapper] = responses[
+                0
+            ]  # return first value from list
+            _response._hidden_params["fastest_response_batch_completion"] = True
+            return _response

     def image_generation(self, prompt: str, model: str, **kwargs):
         try:
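For readers unfamiliar with the pattern, a conceptual, self-contained illustration of a "fastest response" race in asyncio. This is not LiteLLM's actual `abatch_completion_fastest_response` implementation, which this hunk only partially shows; it just demonstrates the racing technique the feature is named after.

```python
# Conceptual illustration of the "fastest response" pattern: race one coroutine
# per model and keep whichever finishes first, cancelling the rest.
import asyncio


async def fake_llm_call(model: str, delay: float) -> str:
    await asyncio.sleep(delay)  # stand-in for a real completion call
    return f"response from {model}"


async def fastest_response() -> str:
    tasks = [
        asyncio.create_task(fake_llm_call("gpt-4o", 0.3)),
        asyncio.create_task(fake_llm_call("claude-3-haiku", 0.1)),
    ]
    done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
    for task in pending:  # the slower calls are no longer needed
        task.cancel()
    return done.pop().result()


print(asyncio.run(fastest_response()))  # -> "response from claude-3-haiku"
```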