This commit is contained in:
Krrish Dholakia 2023-08-11 09:51:14 -07:00
parent b2cf13bb1b
commit fb285c8c9f
5 changed files with 38 additions and 3 deletions

View file

@ -6,6 +6,7 @@ from copy import deepcopy
import litellm
from litellm import client, logging, exception_type, timeout, get_optional_params
import tiktoken
from concurrent.futures import ThreadPoolExecutor
encoding = tiktoken.get_encoding("cl100k_base")
from litellm.utils import get_secret, install_and_import, CustomStreamWrapper
####### ENVIRONMENT VARIABLES ###################
@ -116,8 +117,6 @@ def completion(
messages = messages,
**optional_params
)
if custom_api_base: # reset after call, if a dynamic api base was passed
openai.api_base = "https://api.openai.com/v1"
elif model in litellm.open_ai_text_completion_models:
openai.api_type = "openai"
openai.api_base = litellm.api_base if litellm.api_base is not None else "https://api.openai.com/v1"
@ -439,6 +438,25 @@ def completion(
## Map to OpenAI Exception
raise exception_type(model=model, original_exception=e)
def batch_completion(*args, **kwargs):
    """Run `completion` once per message list, concurrently, via a thread pool.

    Accepts the same arguments as `completion`, except that the `messages`
    argument (second positional argument, or the `messages` keyword) is a
    list of message lists — one per request.

    Returns:
        list: the completion results, in the same order as the input batch.

    Raises:
        ValueError: if no `messages` argument was supplied at all.
    """
    batch_messages = args[1] if len(args) > 1 else kwargs.get("messages")
    if batch_messages is None:
        # Fail fast with a clear error instead of an opaque TypeError when
        # iterating over None below.
        raise ValueError("batch_completion requires a 'messages' argument")
    completions = []
    with ThreadPoolExecutor() as executor:
        for message_list in batch_messages:
            if len(args) > 1:
                # messages came in positionally: swap in this message list.
                args_modified = list(args)
                args_modified[1] = message_list
                future = executor.submit(completion, *args_modified)
            else:
                # messages came in as a keyword: override it per request.
                kwargs_modified = dict(kwargs)
                kwargs_modified["messages"] = message_list
                future = executor.submit(completion, *args, **kwargs_modified)
            completions.append(future)
    # Futures were collected in submission order, so results stay aligned
    # with the input batch.
    results = [future.result() for future in completions]
    return results
### EMBEDDING ENDPOINTS ####################
@client
@timeout(60) ## set timeouts, in case calls hang (e.g. Azure) - default is 60s, override with `force_timeout`

View file

@ -0,0 +1,17 @@
#### What this tests ####
# This tests calling batch_completions by running 100 messages together
import sys, os
import traceback

# Make the local litellm package importable when running from the tests dir.
sys.path.insert(0, os.path.abspath('../..'))
import litellm
from litellm import batch_completion

# One hundred identical single-turn conversations.
messages = [
    [{"role": "user", "content": "Hey, how's it going"}]
    for _ in range(100)
]
print(messages[0:5])
print(len(messages))

model = "gpt-3.5-turbo"
result = batch_completion(model=model, messages=messages)
print(result)
print(len(result))

View file

@ -1,6 +1,6 @@
[tool.poetry]
name = "litellm"
version = "0.1.380"
version = "0.1.381"
description = "Library to easily interface with LLM API providers"
authors = ["BerriAI"]
license = "MIT License"