diff --git a/litellm/main.py b/litellm/main.py
index b65a309620..a476e82289 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -1208,22 +1208,31 @@ def batch_completion_models_all_responses(*args, **kwargs):
     It sends requests concurrently and collects responses from all models that respond.
     """
    import concurrent.futures
+
+    # ANSI escape codes for colored output
+    GREEN = "\033[92m"
+    RED = "\033[91m"
+    RESET = "\033[0m"
+
     if "model" in kwargs:
         kwargs.pop("model")
     if "models" in kwargs:
         models = kwargs["models"]
         kwargs.pop("models")
-        with concurrent.futures.ThreadPoolExecutor(max_workers=len(models)) as executor:
-            futures = [executor.submit(completion, *args, model=model, **kwargs) for model in models]
-
-            # Collect responses from all models that respond
-            responses = [future.result() for future in concurrent.futures.as_completed(futures) if future.result() is not None]
-
-            return responses
-
-    return []  # If no response is received from any model, return an empty list
+        responses = []
+        with concurrent.futures.ThreadPoolExecutor(max_workers=len(models)) as executor:
+            for idx, model in enumerate(models):
+                print(f"{GREEN}LiteLLM: Making request to model: {model}{RESET}")
+                future = executor.submit(completion, *args, model=model, **kwargs)
+                if future.result() is not None:
+                    responses.append(future.result())
+                    print(f"{GREEN}LiteLLM: Model {model} returned response{RESET}")
+                else:
+                    print(f"{RED}LiteLLM: Model {model} did not return a response{RESET}")
+    return responses
 
 ### EMBEDDING ENDPOINTS ####################
 @client
diff --git a/litellm/tests/test_batch_completions.py b/litellm/tests/test_batch_completions.py
index 4994fa9e40..27b2bd8476 100644
--- a/litellm/tests/test_batch_completions.py
+++ b/litellm/tests/test_batch_completions.py
@@ -38,10 +38,11 @@ def test_batch_completions_models():
 def test_batch_completion_models_all_responses():
     responses = batch_completion_models_all_responses(
         models=["gpt-3.5-turbo", "claude-instant-1.2", "command-nightly"],
-        messages=[{"role": "user", "content": "Hey, how's it going"}],
-        max_tokens=5
+        messages=[{"role": "user", "content": "write a poem"}],
+        max_tokens=500
     )
     print(responses)
+    assert(len(responses) == 3)
 
 # test_batch_completion_models_all_responses()
 # def test_batch_completions():
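
Note on the new collection loop in litellm/main.py: future.result() is awaited immediately after each executor.submit(...), so the models are effectively queried one at a time even though a ThreadPoolExecutor is created. Below is a minimal sketch, not part of this PR, of how the same per-model logging could be kept while submitting every request before waiting on any result; the helper name batch_completion_all_responses_concurrent is hypothetical and not a litellm API.

# Sketch only: submit all requests first, then gather results as they finish,
# so the ThreadPoolExecutor actually runs the model calls concurrently.
import concurrent.futures

from litellm import completion

def batch_completion_all_responses_concurrent(models, *args, **kwargs):
    responses = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=len(models)) as executor:
        # Submit every request up front; nothing blocks at this point.
        future_to_model = {
            executor.submit(completion, *args, model=model, **kwargs): model
            for model in models
        }
        # Collect each response as it completes, keeping the per-model messages.
        for future in concurrent.futures.as_completed(future_to_model):
            model = future_to_model[future]
            result = future.result()
            if result is not None:
                responses.append(result)
                print(f"LiteLLM: Model {model} returned a response")
            else:
                print(f"LiteLLM: Model {model} did not return a response")
    return responses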