Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)
commit fb285c8c9f (parent b2cf13bb1b): update

5 changed files with 38 additions and 3 deletions
Binary file not shown.
Binary file not shown.
litellm/main.py

@@ -6,6 +6,7 @@ from copy import deepcopy
 import litellm
 from litellm import client, logging, exception_type, timeout, get_optional_params
 import tiktoken
+from concurrent.futures import ThreadPoolExecutor
 encoding = tiktoken.get_encoding("cl100k_base")
 from litellm.utils import get_secret, install_and_import, CustomStreamWrapper
 ####### ENVIRONMENT VARIABLES ###################
@@ -116,8 +117,6 @@ def completion(
             messages = messages,
             **optional_params
         )
-        if custom_api_base: # reset after call, if a dynamic api base was passsed
-            openai.api_base = "https://api.openai.com/v1"
     elif model in litellm.open_ai_text_completion_models:
         openai.api_type = "openai"
         openai.api_base = litellm.api_base if litellm.api_base is not None else "https://api.openai.com/v1"
@@ -439,6 +438,25 @@ def completion(
         ## Map to OpenAI Exception
         raise exception_type(model=model, original_exception=e)
 
+def batch_completion(*args, **kwargs):
+    batch_messages = args[1] if len(args) > 1 else kwargs.get("messages")
+    completions = []
+    with ThreadPoolExecutor() as executor:
+        for message_list in batch_messages:
+            if len(args) > 1:
+                args_modified = list(args)
+                args_modified[1] = message_list
+                future = executor.submit(completion, *args_modified)
+            else:
+                kwargs_modified = dict(kwargs)
+                kwargs_modified["messages"] = message_list
+                future = executor.submit(completion, *args, **kwargs_modified)
+            completions.append(future)
+
+    # Retrieve the results from the futures
+    results = [future.result() for future in completions]
+    return results
+
 ### EMBEDDING ENDPOINTS ####################
 @client
 @timeout(60) ## set timeouts, in case calls hang (e.g. Azure) - default is 60s, override with `force_timeout`
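The new batch_completion treats its second positional argument (or the messages keyword) as a list of message lists, submits one completion call per entry to a ThreadPoolExecutor, and collects the futures in submission order so results line up with inputs. A minimal standalone sketch of that fan-out pattern follows; fake_completion is a hypothetical stand-in for litellm's completion, and the model string is illustrative.

# Standalone sketch of the thread-pool fan-out used by batch_completion.
# fake_completion is a stand-in for a real LLM call, not litellm's API.
from concurrent.futures import ThreadPoolExecutor

def fake_completion(model, messages):
    # Pretend to call a model; just echo the prompt back.
    return {"model": model, "echo": messages[0]["content"]}

batch = [[{"role": "user", "content": f"question {i}"}] for i in range(3)]

with ThreadPoolExecutor() as executor:
    # submit() returns a Future immediately; keeping the futures in
    # submission order makes the results line up with the input batch.
    futures = [executor.submit(fake_completion, "gpt-3.5-turbo", msgs) for msgs in batch]

# result() blocks until the corresponding call has finished.
results = [f.result() for f in futures]
print(results)

Exiting the with block waits for all submitted calls to finish, so by the time result() is read here every future is already complete, and the calls themselves ran concurrently.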
litellm/tests/test_batch_completions.py (new file, 17 lines)

@@ -0,0 +1,17 @@
+#### What this tests ####
+# This tests calling batch_completions by running 100 messages together
+
+import sys, os
+import traceback
+sys.path.insert(0, os.path.abspath('../..')) # Adds the parent directory to the system path
+import litellm
+from litellm import batch_completion
+
+messages = [[{"role": "user", "content": "Hey, how's it going"}] for _ in range(100)]
+print(messages[0:5])
+print(len(messages))
+model = "gpt-3.5-turbo"
+
+result = batch_completion(model=model, messages=messages)
+print(result)
+print(len(result))
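Since the test only prints the output, a natural follow-up would be to assert the batch shape. The checks below are hypothetical additions, not part of the commit, and they assume each element of result is an OpenAI-style chat response whose text lives under choices[0].message.content:

# Hypothetical extra assertions (not in the commit): one response per
# input message list, each with non-empty OpenAI-style content.
assert len(result) == len(messages)
for response in result:
    # Assumes OpenAI-style chat completion response objects.
    content = response["choices"][0]["message"]["content"]
    assert isinstance(content, str)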
pyproject.toml

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "0.1.380"
+version = "0.1.381"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT License"