From eeeb95a6acd132b38af020de6883a02d3a20a9fd Mon Sep 17 00:00:00 2001
From: ishaan-jaff
Date: Wed, 20 Sep 2023 08:17:46 -0700
Subject: [PATCH] add batch testing

---
 litellm/main.py                         | 25 +++++++++++++++++--------
 litellm/tests/test_batch_completions.py |  5 +++--
 2 files changed, 20 insertions(+), 10 deletions(-)

diff --git a/litellm/main.py b/litellm/main.py
index b65a30962..a476e8228 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -1208,22 +1208,31 @@ def batch_completion_models_all_responses(*args, **kwargs):
     It sends requests concurrently and collects responses from all models that respond.
     """
     import concurrent.futures
+
+    # ANSI escape codes for colored output
+    GREEN = "\033[92m"
+    RED = "\033[91m"
+    RESET = "\033[0m"
+
     if "model" in kwargs:
         kwargs.pop("model")
     if "models" in kwargs:
         models = kwargs["models"]
         kwargs.pop("models")
-        with concurrent.futures.ThreadPoolExecutor(max_workers=len(models)) as executor:
-            futures = [executor.submit(completion, *args, model=model, **kwargs) for model in models]
-            # Collect responses from all models that respond
-            responses = [future.result() for future in concurrent.futures.as_completed(futures) if future.result() is not None]
-
-            return responses
-
-    return []  # If no response is received from any model, return an empty list
+        responses = []
+        with concurrent.futures.ThreadPoolExecutor(max_workers=len(models)) as executor:
+            for idx, model in enumerate(models):
+                print(f"{GREEN}LiteLLM: Making request to model: {model}{RESET}")
+                future = executor.submit(completion, *args, model=model, **kwargs)
+                if future.result() is not None:
+                    responses.append(future.result())
+                    print(f"{GREEN}LiteLLM: Model {model} returned response{RESET}")
+                else:
+                    print(f"{RED}LiteLLM: Model {model} did not return a response{RESET}")
+        return responses
 
 
 ### EMBEDDING ENDPOINTS ####################
 @client
diff --git a/litellm/tests/test_batch_completions.py b/litellm/tests/test_batch_completions.py
index 4994fa9e4..27b2bd847 100644
--- a/litellm/tests/test_batch_completions.py
+++ b/litellm/tests/test_batch_completions.py
@@ -38,10 +38,11 @@ def test_batch_completions_models():
 def test_batch_completion_models_all_responses():
     responses = batch_completion_models_all_responses(
         models=["gpt-3.5-turbo", "claude-instant-1.2", "command-nightly"],
-        messages=[{"role": "user", "content": "Hey, how's it going"}],
-        max_tokens=5
+        messages=[{"role": "user", "content": "write a poem"}],
+        max_tokens=500
     )
     print(responses)
+    assert(len(responses) == 3)
 # test_batch_completion_models_all_responses()
 
 # def test_batch_completions():