diff --git a/litellm/main.py b/litellm/main.py
index 5da2b4a52e..cb197aef89 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -680,6 +680,7 @@ def completion(
         "region_name",
         "allowed_model_region",
         "model_config",
+        "fastest_response",
     ]
 
     default_params = openai_params + litellm_params
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 6efcb2a702..ee1cd7a642 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -415,6 +415,7 @@ def get_custom_headers(
     api_base: Optional[str] = None,
     version: Optional[str] = None,
     model_region: Optional[str] = None,
+    fastest_response_batch_completion: Optional[bool] = None,
 ) -> dict:
     exclude_values = {"", None}
     headers = {
@@ -425,6 +426,11 @@ def get_custom_headers(
         "x-litellm-model-region": model_region,
         "x-litellm-key-tpm-limit": str(user_api_key_dict.tpm_limit),
         "x-litellm-key-rpm-limit": str(user_api_key_dict.rpm_limit),
+        "x-litellm-fastest_response_batch_completion": (
+            str(fastest_response_batch_completion)
+            if fastest_response_batch_completion is not None
+            else None
+        ),
     }
     try:
         return {
@@ -4035,7 +4041,17 @@ async def chat_completion(
         elif "," in data["model"] and llm_router is not None:
             _models_csv_string = data.pop("model")
             _models = _models_csv_string.split(",")
-            tasks.append(llm_router.abatch_completion(models=_models, **data))
+            if (
+                data.get("fastest_response", None) is not None
+                and data["fastest_response"] == True
+            ):
+                tasks.append(
+                    llm_router.abatch_completion_fastest_response(
+                        models=_models, **data
+                    )
+                )
+            else:
+                tasks.append(llm_router.abatch_completion(models=_models, **data))
         elif "user_config" in data:
             # initialize a new router instance. make request using this Router
             router_config = data.pop("user_config")
@@ -4085,6 +4101,9 @@ async def chat_completion(
         model_id = hidden_params.get("model_id", None) or ""
         cache_key = hidden_params.get("cache_key", None) or ""
         api_base = hidden_params.get("api_base", None) or ""
+        fastest_response_batch_completion = hidden_params.get(
+            "fastest_response_batch_completion", None
+        )
 
         # Post Call Processing
         if llm_router is not None:
@@ -4101,6 +4120,7 @@ async def chat_completion(
                 api_base=api_base,
                 version=version,
                 model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
+                fastest_response_batch_completion=fastest_response_batch_completion,
             )
             selected_data_generator = select_data_generator(
                 response=response,
@@ -4121,6 +4141,7 @@ async def chat_completion(
                 api_base=api_base,
                 version=version,
                 model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
+                fastest_response_batch_completion=fastest_response_batch_completion,
            )
        )
 
diff --git a/litellm/router.py b/litellm/router.py
index 631360da6e..b87d0dded0 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -804,9 +804,16 @@ class Router:
             pending_tasks.append(task)
 
         responses = await asyncio.gather(*_tasks, return_exceptions=True)
-        if isinstance(responses[0], Exception):
+        if isinstance(responses[0], Exception) or isinstance(
+            responses[0], BaseException
+        ):
             raise responses[0]
-        return responses[0]  # return first value from list
+        _response: Union[ModelResponse, CustomStreamWrapper] = responses[
+            0
+        ]  # return first value from list
+
+        _response._hidden_params["fastest_response_batch_completion"] = True
+        return _response
 
     def image_generation(self, prompt: str, model: str, **kwargs):
         try:
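
For reviewers, a minimal client-side sketch of how the new flag can be exercised once this change is deployed. The proxy URL, virtual key, and model names below are placeholders/assumptions, not part of this diff:

import requests

# Assumed setup (not part of this diff): a LiteLLM proxy running locally on the
# default port, with "gpt-4o" and "claude-3-haiku" configured as model aliases.
PROXY_BASE = "http://localhost:4000"
VIRTUAL_KEY = "sk-1234"  # placeholder proxy key

resp = requests.post(
    f"{PROXY_BASE}/chat/completions",
    headers={"Authorization": f"Bearer {VIRTUAL_KEY}"},
    json={
        # A comma-separated model list sends the request down the batch-completion
        # branch of chat_completion(); fastest_response=True selects the new
        # abatch_completion_fastest_response() path instead of abatch_completion().
        "model": "gpt-4o,claude-3-haiku",
        "messages": [{"role": "user", "content": "Hello!"}],
        "fastest_response": True,
    },
)

print(resp.json())
# The proxy reports which path was taken via the new response header.
print(resp.headers.get("x-litellm-fastest_response_batch_completion"))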