diff --git a/litellm/__pycache__/__init__.cpython-311.pyc b/litellm/__pycache__/__init__.cpython-311.pyc
index 79e48505e8..55f1f9c754 100644
Binary files a/litellm/__pycache__/__init__.cpython-311.pyc and b/litellm/__pycache__/__init__.cpython-311.pyc differ
diff --git a/litellm/__pycache__/main.cpython-311.pyc b/litellm/__pycache__/main.cpython-311.pyc
index e145581291..e9b3e7f1c8 100644
Binary files a/litellm/__pycache__/main.cpython-311.pyc and b/litellm/__pycache__/main.cpython-311.pyc differ
diff --git a/litellm/main.py b/litellm/main.py
index fa88e3ca82..e5291a9b73 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -6,6 +6,7 @@ from copy import deepcopy
 import litellm
 from litellm import client, logging, exception_type, timeout, get_optional_params
 import tiktoken
+from concurrent.futures import ThreadPoolExecutor
 encoding = tiktoken.get_encoding("cl100k_base")
 from litellm.utils import get_secret, install_and_import, CustomStreamWrapper
 ####### ENVIRONMENT VARIABLES ###################
@@ -116,8 +117,6 @@ def completion(
             messages = messages,
             **optional_params
         )
-        if custom_api_base: # reset after call, if a dynamic api base was passsed
-            openai.api_base = "https://api.openai.com/v1"
     elif model in litellm.open_ai_text_completion_models:
         openai.api_type = "openai"
         openai.api_base = litellm.api_base if litellm.api_base is not None else "https://api.openai.com/v1"
@@ -439,6 +438,25 @@
         ## Map to OpenAI Exception
         raise exception_type(model=model, original_exception=e)
 
+def batch_completion(*args, **kwargs):
+    batch_messages = args[1] if len(args) > 1 else kwargs.get("messages")
+    completions = []
+    with ThreadPoolExecutor() as executor:
+        for message_list in batch_messages:
+            if len(args) > 1:
+                args_modified = list(args)
+                args_modified[1] = message_list
+                future = executor.submit(completion, *args_modified)
+            else:
+                kwargs_modified = dict(kwargs)
+                kwargs_modified["messages"] = message_list
+                future = executor.submit(completion, *args, **kwargs_modified)
+            completions.append(future)
+
+    # Retrieve the results from the futures
+    results = [future.result() for future in completions]
+    return results
+
 ### EMBEDDING ENDPOINTS ####################
 @client
 @timeout(60) ## set timeouts, in case calls hang (e.g. Azure) - default is 60s, override with `force_timeout`
diff --git a/litellm/tests/test_batch_completions.py b/litellm/tests/test_batch_completions.py
new file mode 100644
index 0000000000..ca041ec95b
--- /dev/null
+++ b/litellm/tests/test_batch_completions.py
@@ -0,0 +1,17 @@
+#### What this tests ####
+# This tests calling batch_completions by running 100 messages together
+
+import sys, os
+import traceback
+sys.path.insert(0, os.path.abspath('../..')) # Adds the parent directory to the system path
+import litellm
+from litellm import batch_completion
+
+messages = [[{"role": "user", "content": "Hey, how's it going"}] for _ in range(100)]
+print(messages[0:5])
+print(len(messages))
+model = "gpt-3.5-turbo"
+
+result = batch_completion(model=model, messages=messages)
+print(result)
+print(len(result))
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index d5eb3177c3..9fe3e3433e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "0.1.380"
+version = "0.1.381"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT License"
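
Note on the new helper: `batch_completion` fans each message list out to `completion` via a `ThreadPoolExecutor` and collects the futures in submission order, so the returned results line up with the input batches. The snippet below is a minimal, self-contained sketch of that same fan-out pattern under stated assumptions; `fake_completion` and `run_batch` are hypothetical names used only so the sketch runs offline without API keys, and are not part of the diff.

```python
# Self-contained sketch of the fan-out pattern used by batch_completion above.
# fake_completion is a hypothetical stand-in for litellm.completion so the
# example runs offline; the submit/collect structure mirrors the diff.
from concurrent.futures import ThreadPoolExecutor

def fake_completion(model, messages):
    # Stand-in: echo the last user message instead of calling an LLM API.
    return {"model": model, "echo": messages[-1]["content"]}

def run_batch(model, batch_messages):
    futures = []
    with ThreadPoolExecutor() as executor:
        for message_list in batch_messages:
            # One completion call per message list; calls run concurrently.
            futures.append(executor.submit(fake_completion, model, message_list))
    # result() blocks until each call finishes; order matches the input batches.
    return [future.result() for future in futures]

if __name__ == "__main__":
    batches = [[{"role": "user", "content": f"question {i}"}] for i in range(5)]
    print(run_batch("gpt-3.5-turbo", batches))
```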