This commit is contained in:
Krrish Dholakia 2023-08-11 09:51:14 -07:00
parent b2cf13bb1b
commit fb285c8c9f
5 changed files with 38 additions and 3 deletions

View file

@ -6,6 +6,7 @@ from copy import deepcopy
import litellm
from litellm import client, logging, exception_type, timeout, get_optional_params
import tiktoken
from concurrent.futures import ThreadPoolExecutor
encoding = tiktoken.get_encoding("cl100k_base")
from litellm.utils import get_secret, install_and_import, CustomStreamWrapper
####### ENVIRONMENT VARIABLES ###################
@ -116,8 +117,6 @@ def completion(
messages = messages,
**optional_params
)
if custom_api_base: # reset after call, if a dynamic api base was passsed
openai.api_base = "https://api.openai.com/v1"
elif model in litellm.open_ai_text_completion_models:
openai.api_type = "openai"
openai.api_base = litellm.api_base if litellm.api_base is not None else "https://api.openai.com/v1"
@ -439,6 +438,25 @@ def completion(
## Map to OpenAI Exception
raise exception_type(model=model, original_exception=e)
def batch_completion(*args, **kwargs):
batch_messages = args[1] if len(args) > 1 else kwargs.get("messages")
completions = []
with ThreadPoolExecutor() as executor:
for message_list in batch_messages:
if len(args) > 1:
args_modified = list(args)
args_modified[1] = message_list
future = executor.submit(completion, *args_modified)
else:
kwargs_modified = dict(kwargs)
kwargs_modified["messages"] = message_list
future = executor.submit(completion, *args, **kwargs_modified)
completions.append(future)
# Retrieve the results from the futures
results = [future.result() for future in completions]
return results
### EMBEDDING ENDPOINTS ####################
@client
@timeout(60) ## set timeouts, in case calls hang (e.g. Azure) - default is 60s, override with `force_timeout`