batch completions for vllm now work too

Krrish Dholakia 2023-09-06 18:52:34 -07:00
parent e98394f0b5
commit 14fa57c185
21 changed files with 149 additions and 23 deletions

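For context, a minimal usage sketch of the interface this commit produces; the model name, prompts, and parameter values here are illustrative, not taken from the commit:

# Usage sketch for the reworked batch_completion (illustrative values).
# Each element of `messages` is one conversation; one completion is
# returned per conversation, in the same order.
import litellm

responses = litellm.batch_completion(
    model="vllm/facebook/opt-125m",  # "vllm/" prefix routes to the vLLM path below
    messages=[
        [{"role": "user", "content": "good morning"}],
        [{"role": "user", "content": "what's the capital of France?"}],
    ],
    temperature=0.2,
    max_tokens=64,
)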

@@ -693,7 +693,7 @@ def completion(
             encoding=encoding,
             logging_obj=logging
         )
         if "stream" in optional_params and optional_params["stream"] == True: ## [BETA]
             # don't try to access stream object,
             response = CustomStreamWrapper(
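The hunk above touches the streaming branch, where the raw provider stream is handed to `CustomStreamWrapper`. As a rough sketch of that iterator-wrapper pattern (simplified, not litellm's actual class):

# Simplified sketch of an iterator wrapper for streaming responses
# (pattern only; not the real CustomStreamWrapper).
class StreamWrapper:
    def __init__(self, completion_stream, model):
        self.completion_stream = completion_stream
        self.model = model

    def __iter__(self):
        return self

    def __next__(self):
        # pull the next raw chunk and normalize it into a common shape
        chunk = next(self.completion_stream)
        return {"model": self.model, "choices": [{"delta": {"content": str(chunk)}}]}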
@@ -828,23 +828,68 @@ def completion_with_retries(*args, **kwargs):
     return retryer(completion, *args, **kwargs)
 
-def batch_completion(*args, **kwargs):
-    batch_messages = args[1] if len(args) > 1 else kwargs.get("messages")
-    completions = []
-    with ThreadPoolExecutor() as executor:
-        for message_list in batch_messages:
-            if len(args) > 1:
-                args_modified = list(args)
-                args_modified[1] = message_list
-                future = executor.submit(completion, *args_modified)
-            else:
-                kwargs_modified = dict(kwargs)
-                kwargs_modified["messages"] = message_list
-                future = executor.submit(completion, *args, **kwargs_modified)
-            completions.append(future)
-    # Retrieve the results from the futures
-    results = [future.result() for future in completions]
+def batch_completion(
+    model: str,
+    # Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create
+    messages: List = [],
+    functions: List = [],
+    function_call: str = "",  # optional params
+    temperature: float = 1,
+    top_p: float = 1,
+    n: int = 1,
+    stream: bool = False,
+    stop=None,
+    max_tokens: float = float("inf"),
+    presence_penalty: float = 0,
+    frequency_penalty=0,
+    logit_bias: dict = {},
+    user: str = "",
+    # used by text-bison only
+    top_k=40,
+    custom_llm_provider=None,):
+    args = locals()
+    batch_messages = messages
+    completions = []
+    model = model
+    custom_llm_provider = None
+    if model.split("/", 1)[0] in litellm.provider_list:
+        custom_llm_provider = model.split("/", 1)[0]
+        model = model.split("/", 1)[1]
+    if custom_llm_provider == "vllm":
+        optional_params = get_optional_params(
+            functions=functions,
+            function_call=function_call,
+            temperature=temperature,
+            top_p=top_p,
+            n=n,
+            stream=stream,
+            stop=stop,
+            max_tokens=max_tokens,
+            presence_penalty=presence_penalty,
+            frequency_penalty=frequency_penalty,
+            logit_bias=logit_bias,
+            user=user,
+            # params to identify the model
+            model=model,
+            custom_llm_provider=custom_llm_provider,
+            top_k=top_k,
+        )
+        results = vllm.batch_completions(model=model, messages=batch_messages, custom_prompt_dict=litellm.custom_prompt_dict, optional_params=optional_params)
+    else:
+        def chunks(lst, n):
+            """Yield successive n-sized chunks from lst."""
+            for i in range(0, len(lst), n):
+                yield lst[i:i + n]
+        with ThreadPoolExecutor(max_workers=100) as executor:
+            for sub_batch in chunks(batch_messages, 100):
+                for message_list in sub_batch:
+                    kwargs_modified = args
+                    kwargs_modified["messages"] = message_list
+                    future = executor.submit(completion, **kwargs_modified)
+                    completions.append(future)
+
+    # Retrieve the results from the futures
+    results = [future.result() for future in completions]
     return results
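Two things worth noting in the new body: routing is a single split on the first "/", and the non-vLLM path fans out over a thread pool in sub-batches of 100, collecting futures in submission order so results come back in input order. A standalone sketch of the routing rule (abbreviated provider list; function name is ours, not litellm's):

# Standalone sketch of the provider-prefix routing (illustrative subset).
provider_list = ["openai", "anthropic", "vllm"]

def split_provider(model: str):
    # mirrors model.split("/", 1) in the diff, with a guard for bare names
    if "/" in model and model.split("/", 1)[0] in provider_list:
        prefix, rest = model.split("/", 1)
        return prefix, rest
    return None, model

assert split_provider("vllm/facebook/opt-125m") == ("vllm", "facebook/opt-125m")
assert split_provider("gpt-3.5-turbo") == (None, "gpt-3.5-turbo")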