batch completions for vllm now work too

Krrish Dholakia 2023-09-06 18:52:34 -07:00
parent 4a263f6ab7
commit 35cf6ef0a1
21 changed files with 149 additions and 23 deletions
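
In practice the new path is driven through litellm.batch_completion with a "vllm/"-prefixed model name. A minimal usage sketch, adapted from the example in the new batch_completions docstring below; it assumes vllm is installed locally and that batch_completion hands back one OpenAI-style response per conversation, as the ModelResponse construction in this diff builds:

from litellm import batch_completion

# Each inner list is one independent chat; on the vLLM side all of them are
# generated in a single llm.generate() call.
responses = batch_completion(
    model="vllm/facebook/opt-125m",  # the "vllm/" prefix routes to the new vLLM batch path
    messages=[
        [{"role": "user", "content": "good morning?"}],
        [{"role": "user", "content": "what's the time?"}],
    ],
)
for response in responses:
    print(response["choices"][0]["message"]["content"])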

New vendored binary dist files (binary content not shown):
dist/litellm-0.1.549-py3-none-any.whl
dist/litellm-0.1.549.tar.gz
dist/litellm-0.1.550-py3-none-any.whl
dist/litellm-0.1.550.tar.gz
dist/litellm-0.1.551-py3-none-any.whl
dist/litellm-0.1.551.tar.gz
dist/litellm-0.1.552-py3-none-any.whl
dist/litellm-0.1.552.tar.gz
dist/litellm-0.1.553-py3-none-any.whl
dist/litellm-0.1.553.tar.gz
dist/litellm-0.1.554-py3-none-any.whl
dist/litellm-0.1.554.tar.gz
dist/litellm-0.1.555-py3-none-any.whl
dist/litellm-0.1.555.tar.gz


@@ -6,7 +6,7 @@ import time
 from typing import Callable
 from litellm.utils import ModelResponse
 from .prompt_templates.factory import prompt_factory, custom_prompt
-
+llm = None
 class VLLMError(Exception):
     def __init__(self, status_code, message):
         self.status_code = status_code
@@ -16,10 +16,12 @@ class VLLMError(Exception):
         ) # Call the base class constructor with the parameters it needs

 # check if vllm is installed
-def validate_environment():
+def validate_environment(model: str, llm: any=None):
     try:
         from vllm import LLM, SamplingParams
-        return LLM, SamplingParams
+        if llm is None:
+            llm = LLM(model=model)
+        return llm, SamplingParams
     except:
         raise VLLMError(status_code=0, message="The vllm package is not installed in your environment. Run - `pip install vllm` before proceeding.")

@@ -35,9 +37,8 @@ def completion(
     litellm_params=None,
     logger_fn=None,
 ):
-    LLM, SamplingParams = validate_environment()
     try:
-        llm = LLM(model=model)
+        llm, SamplingParams = validate_environment(model=model)
     except Exception as e:
         raise VLLMError(status_code=0, message=str(e))
     sampling_params = SamplingParams(**optional_params)
@@ -92,6 +93,85 @@ def completion(
     }
     return model_response

+def batch_completions(
+    model: str,
+    messages: list,
+    optional_params=None,
+    custom_prompt_dict={}
+):
+    """
+    Example usage:
+    import litellm
+    import os
+    from litellm import batch_completion
+
+    responses = batch_completion(
+        model="vllm/facebook/opt-125m",
+        messages = [
+            [
+                {
+                    "role": "user",
+                    "content": "good morning? "
+                }
+            ],
+            [
+                {
+                    "role": "user",
+                    "content": "what's the time? "
+                }
+            ]
+        ]
+    )
+    """
+    global llm
+    try:
+        llm, SamplingParams = validate_environment(model=model, llm=llm)
+    except Exception as e:
+        if "data parallel group is already initialized" in str(e):
+            pass
+        else:
+            raise VLLMError(status_code=0, message=str(e))
+    sampling_params = SamplingParams(**optional_params)
+    prompts = []
+    if model in custom_prompt_dict:
+        # check if the model has a registered custom prompt
+        model_prompt_details = custom_prompt_dict[model]
+        for message in messages:
+            prompt = custom_prompt(
+                role_dict=model_prompt_details["roles"],
+                initial_prompt_value=model_prompt_details["initial_prompt_value"],
+                final_prompt_value=model_prompt_details["final_prompt_value"],
+                messages=message
+            )
+            prompts.append(prompt)
+    else:
+        for message in messages:
+            prompt = prompt_factory(model=model, messages=message)
+            prompts.append(prompt)
+
+    outputs = llm.generate(prompts, sampling_params)
+
+    final_outputs = []
+    for output in outputs:
+        model_response = ModelResponse()
+        ## RESPONSE OBJECT
+        model_response["choices"][0]["message"]["content"] = output.outputs[0].text
+
+        ## CALCULATING USAGE
+        prompt_tokens = len(output.prompt_token_ids)
+        completion_tokens = len(output.outputs[0].token_ids)
+
+        model_response["created"] = time.time()
+        model_response["model"] = model
+        model_response["usage"] = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": prompt_tokens + completion_tokens,
+        }
+        final_outputs.append(model_response)
+    return final_outputs
+
 def embedding():
     # logic for parsing in - calling - parsing out model embedding calls
     pass
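
The module-level `llm = None` together with the reworked `validate_environment(model, llm)` amounts to a lazy, process-wide engine cache: constructing `LLM(...)` loads the model weights, so `batch_completions` reuses one engine across calls instead of rebuilding it (re-initialization is also what surfaces the "data parallel group is already initialized" error it tolerates). A minimal standalone sketch of that pattern, with hypothetical names rather than litellm's own:

_engine = None  # module-level cache, same role as `llm = None` above

def get_engine(model_name: str):
    """Build the vLLM engine on first use, then reuse it across calls."""
    global _engine
    if _engine is None:
        from vllm import LLM  # requires `pip install vllm`
        _engine = LLM(model=model_name)
    return _engine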


@@ -828,19 +828,64 @@ def completion_with_retries(*args, **kwargs):
     return retryer(completion, *args, **kwargs)

-def batch_completion(*args, **kwargs):
-    batch_messages = args[1] if len(args) > 1 else kwargs.get("messages")
+def batch_completion(
+    model: str,
+    # Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create
+    messages: List = [],
+    functions: List = [],
+    function_call: str = "", # optional params
+    temperature: float = 1,
+    top_p: float = 1,
+    n: int = 1,
+    stream: bool = False,
+    stop=None,
+    max_tokens: float = float("inf"),
+    presence_penalty: float = 0,
+    frequency_penalty=0,
+    logit_bias: dict = {},
+    user: str = "",
+    # used by text-bison only
+    top_k=40,
+    custom_llm_provider=None,):
+    args = locals()
+    batch_messages = messages
     completions = []
-    with ThreadPoolExecutor() as executor:
-        for message_list in batch_messages:
-            if len(args) > 1:
-                args_modified = list(args)
-                args_modified[1] = message_list
-                future = executor.submit(completion, *args_modified)
-            else:
-                kwargs_modified = dict(kwargs)
-                kwargs_modified["messages"] = message_list
-                future = executor.submit(completion, *args, **kwargs_modified)
-            completions.append(future)
+    model = model
+    custom_llm_provider = None
+    if model.split("/", 1)[0] in litellm.provider_list:
+        custom_llm_provider = model.split("/", 1)[0]
+        model = model.split("/", 1)[1]
+    if custom_llm_provider == "vllm":
+        optional_params = get_optional_params(
+            functions=functions,
+            function_call=function_call,
+            temperature=temperature,
+            top_p=top_p,
+            n=n,
+            stream=stream,
+            stop=stop,
+            max_tokens=max_tokens,
+            presence_penalty=presence_penalty,
+            frequency_penalty=frequency_penalty,
+            logit_bias=logit_bias,
+            user=user,
+            # params to identify the model
+            model=model,
+            custom_llm_provider=custom_llm_provider,
+            top_k=top_k,
+        )
+        results = vllm.batch_completions(model=model, messages=batch_messages, custom_prompt_dict=litellm.custom_prompt_dict, optional_params=optional_params)
+    else:
+        def chunks(lst, n):
+            """Yield successive n-sized chunks from lst."""
+            for i in range(0, len(lst), n):
+                yield lst[i:i + n]
+        with ThreadPoolExecutor(max_workers=100) as executor:
+            for sub_batch in chunks(batch_messages, 100):
+                for message_list in sub_batch:
+                    kwargs_modified = args
+                    kwargs_modified["messages"] = message_list
+                    future = executor.submit(completion, **kwargs_modified)
+                    completions.append(future)

     # Retrieve the results from the futures

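On the litellm side, the reworked batch_completion first splits a provider prefix off the model name: "vllm/..." models are handed to vllm.batch_completions in one shot, while everything else is fanned out to completion through a thread pool in sub-batches of 100. A small illustrative sketch of just that prefix-splitting step (the provider_list here is a made-up subset, not litellm's actual list):

provider_list = ["openai", "anthropic", "vllm"]  # illustrative subset only

def split_provider(model: str):
    # "vllm/facebook/opt-125m" -> ("vllm", "facebook/opt-125m");
    # bare model names like "gpt-3.5-turbo" pass through untouched.
    prefix = model.split("/", 1)[0]
    if prefix in provider_list:
        return prefix, model.split("/", 1)[1]
    return None, model

print(split_provider("vllm/facebook/opt-125m"))  # ('vllm', 'facebook/opt-125m')
print(split_provider("gpt-3.5-turbo"))           # (None, 'gpt-3.5-turbo')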

@@ -9,10 +9,11 @@ sys.path.insert(
 ) # Adds the parent directory to the system path
 import litellm
 from litellm import batch_completion
+litellm.set_verbose=True
 messages = [[{"role": "user", "content": "Hey, how's it going"}] for _ in range(5)]
 print(messages[0:5])
 print(len(messages))
+# model = "vllm/facebook/opt-125m"
 model = "gpt-3.5-turbo"
 result = batch_completion(model=model, messages=messages)


@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "0.1.549"
+version = "0.1.555"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT License"