mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 11:14:04 +00:00)
batch completions for vllm now works too
This commit is contained in:
parent e98394f0b5
commit 14fa57c185
21 changed files with 149 additions and 23 deletions
@@ -693,7 +693,7 @@ def completion(
                encoding=encoding,
                logging_obj=logging
            )

            if "stream" in optional_params and optional_params["stream"] == True: ## [BETA]
                # don't try to access stream object,
                response = CustomStreamWrapper(
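The hunk above sits in the vllm streaming branch of `completion()`, where a `CustomStreamWrapper` is returned when `stream=True`. Below is a minimal usage sketch of that path, assuming litellm's usual streaming semantics (the returned wrapper is an iterator over chunks); the model id is illustrative and not taken from this commit.

```python
# Hedged usage sketch, not part of the diff: exercise the streaming branch shown above.
# "vllm/facebook/opt-125m" is an illustrative model id, not from this commit.
import litellm

response = litellm.completion(
    model="vllm/facebook/opt-125m",
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
    stream=True,  # routes the call through the CustomStreamWrapper branch above
)
for chunk in response:  # the wrapper yields streamed chunks as they arrive
    print(chunk)
```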
@@ -828,23 +828,68 @@ def completion_with_retries(*args, **kwargs):
     return retryer(completion, *args, **kwargs)
 
 
-def batch_completion(*args, **kwargs):
-    batch_messages = args[1] if len(args) > 1 else kwargs.get("messages")
+def batch_completion(
+    model: str,
+    # Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create
+    messages: List = [],
+    functions: List = [],
+    function_call: str = "",  # optional params
+    temperature: float = 1,
+    top_p: float = 1,
+    n: int = 1,
+    stream: bool = False,
+    stop=None,
+    max_tokens: float = float("inf"),
+    presence_penalty: float = 0,
+    frequency_penalty=0,
+    logit_bias: dict = {},
+    user: str = "",
+    # used by text-bison only
+    top_k=40,
+    custom_llm_provider=None,):
+    args = locals()
+    batch_messages = messages
     completions = []
-    with ThreadPoolExecutor() as executor:
-        for message_list in batch_messages:
-            if len(args) > 1:
-                args_modified = list(args)
-                args_modified[1] = message_list
-                future = executor.submit(completion, *args_modified)
-            else:
-                kwargs_modified = dict(kwargs)
-                kwargs_modified["messages"] = message_list
-                future = executor.submit(completion, *args, **kwargs_modified)
-            completions.append(future)
+    model = model
+    custom_llm_provider = None
+    if model.split("/", 1)[0] in litellm.provider_list:
+        custom_llm_provider = model.split("/", 1)[0]
+        model = model.split("/", 1)[1]
+    if custom_llm_provider == "vllm":
+        optional_params = get_optional_params(
+            functions=functions,
+            function_call=function_call,
+            temperature=temperature,
+            top_p=top_p,
+            n=n,
+            stream=stream,
+            stop=stop,
+            max_tokens=max_tokens,
+            presence_penalty=presence_penalty,
+            frequency_penalty=frequency_penalty,
+            logit_bias=logit_bias,
+            user=user,
+            # params to identify the model
+            model=model,
+            custom_llm_provider=custom_llm_provider,
+            top_k=top_k,
+        )
+        results = vllm.batch_completions(model=model, messages=batch_messages, custom_prompt_dict=litellm.custom_prompt_dict, optional_params=optional_params)
+    else:
+        def chunks(lst, n):
+            """Yield successive n-sized chunks from lst."""
+            for i in range(0, len(lst), n):
+                yield lst[i:i + n]
+        with ThreadPoolExecutor(max_workers=100) as executor:
+            for sub_batch in chunks(batch_messages, 100):
+                for message_list in sub_batch:
+                    kwargs_modified = args
+                    kwargs_modified["messages"] = message_list
+                    future = executor.submit(completion, **kwargs_modified)
+                    completions.append(future)
 
-    # Retrieve the results from the futures
-    results = [future.result() for future in completions]
+        # Retrieve the results from the futures
+        results = [future.result() for future in completions]
     return results
 
 
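The rewritten `batch_completion()` fans each message list out to `completion()` on a thread pool, except for `vllm/...` models, which are sent to `vllm.batch_completions()` in a single call. Below is a minimal usage sketch, assuming `batch_completion` is exposed at the package level; the model id is illustrative and not taken from this commit.

```python
# Hedged usage sketch, not part of the diff: each inner list in `messages`
# is dispatched as its own completion. The model id is illustrative only.
import litellm

responses = litellm.batch_completion(
    model="vllm/facebook/opt-125m",  # the "vllm/" prefix routes to vllm.batch_completions
    messages=[
        [{"role": "user", "content": "What is the capital of France?"}],
        [{"role": "user", "content": "Write a haiku about the sea."}],
    ],
    max_tokens=64,
)
for r in responses:
    print(r)
```

For non-vllm models the same call shape applies; each inner message list is submitted to `completion()` via a `ThreadPoolExecutor`, in sub-batches of 100.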