diff --git a/litellm/main.py b/litellm/main.py
index c131838970..b9acd6e4af 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -1131,6 +1131,42 @@ def batch_completion(
     results = [future.result() for future in completions]
     return results
 
+# send one request to multiple models
+# return as soon as one of the llms responds
+def batch_completion_models(*args, **kwargs):
+    """
+    Send a request to multiple language models concurrently and return the response
+    as soon as one of the models responds.
+
+    Args:
+        *args: Variable-length positional arguments passed to the completion function.
+        **kwargs: Additional keyword arguments:
+            - models (str or list of str): The language models to send requests to.
+            - Other keyword arguments to be passed to the completion function.
+
+    Returns:
+        str or None: The response from one of the language models, or None if no response is received.
+
+    Note:
+        This function utilizes a ThreadPoolExecutor to parallelize requests to multiple models.
+        It sends requests concurrently and returns the response from the first model that responds.
+    """
+    import concurrent
+    if "model" in kwargs:
+        kwargs.pop("model")
+    if "models" in kwargs:
+        models = kwargs["models"]
+        kwargs.pop("models")
+        with concurrent.futures.ThreadPoolExecutor(max_workers=len(models)) as executor:
+            futures = [executor.submit(completion, *args, model=model, **kwargs) for model in models]
+
+            for future in concurrent.futures.as_completed(futures):
+                if future.result() is not None:
+                    return future.result()
+
+    return None # If no response is received from any model
+
+
 ### EMBEDDING ENDPOINTS ####################
 @client
 @timeout( # type: ignore
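
For reference, a minimal usage sketch of the new batch_completion_models helper (not part of the patch): the model names and prompt below are illustrative, the corresponding provider API keys are assumed to be configured, and the import path litellm.main is taken from the file this diff touches (the function may also be re-exported at the package level).

# Illustrative usage sketch; model names and the prompt are placeholders,
# and provider API keys are assumed to be set in the environment.
from litellm.main import batch_completion_models

response = batch_completion_models(
    models=["gpt-3.5-turbo", "claude-instant-1", "command-nightly"],
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
)

# The helper returns the first non-None completion, or None if no model responds.
print(response)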