Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-24 10:14:26 +00:00)

Commit 35cf6ef0a1 (parent 4a263f6ab7): batch completions for vllm now works too

21 changed files with 149 additions and 23 deletions
New vendored binary files (contents not shown):

BIN  dist/litellm-0.1.549-py3-none-any.whl
BIN  dist/litellm-0.1.549.tar.gz
BIN  dist/litellm-0.1.550-py3-none-any.whl
BIN  dist/litellm-0.1.550.tar.gz
BIN  dist/litellm-0.1.551-py3-none-any.whl
BIN  dist/litellm-0.1.551.tar.gz
BIN  dist/litellm-0.1.552-py3-none-any.whl
BIN  dist/litellm-0.1.552.tar.gz
BIN  dist/litellm-0.1.553-py3-none-any.whl
BIN  dist/litellm-0.1.553.tar.gz
BIN  dist/litellm-0.1.554-py3-none-any.whl
BIN  dist/litellm-0.1.554.tar.gz
BIN  dist/litellm-0.1.555-py3-none-any.whl
BIN  dist/litellm-0.1.555.tar.gz

Three additional binary files changed (contents not shown).
@@ -6,7 +6,7 @@ import time
 from typing import Callable
 from litellm.utils import ModelResponse
 from .prompt_templates.factory import prompt_factory, custom_prompt
-
+llm = None
 class VLLMError(Exception):
     def __init__(self, status_code, message):
         self.status_code = status_code
@@ -16,10 +16,12 @@ class VLLMError(Exception):
         ) # Call the base class constructor with the parameters it needs

 # check if vllm is installed
-def validate_environment():
+def validate_environment(model: str, llm: any=None):
     try:
         from vllm import LLM, SamplingParams
-        return LLM, SamplingParams
+        if llm is None:
+            llm = LLM(model=model)
+        return llm, SamplingParams
     except:
         raise VLLMError(status_code=0, message="The vllm package is not installed in your environment. Run - `pip install vllm` before proceeding.")
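
The reworked validate_environment doubles as a lazy initializer: it only constructs a vllm LLM engine when no cached instance is passed in, so later calls can reuse the engine instead of reloading model weights. A minimal standalone sketch of the same caching pattern (get_engine and the model name are illustrative, not part of the commit):

    # Sketch of the lazy, module-level engine cache mirrored by `llm = None` above.
    # `get_engine` is a hypothetical helper; "facebook/opt-125m" is just an example model.
    _engine = None

    def get_engine(model: str):
        global _engine
        try:
            from vllm import LLM  # imported lazily; only needed when vllm is actually used
        except ImportError as e:
            raise RuntimeError("vllm is not installed. Run `pip install vllm`.") from e
        if _engine is None:        # first call pays the model-load cost
            _engine = LLM(model=model)
        return _engine             # subsequent calls reuse the cached engine

    # engine = get_engine("facebook/opt-125m")
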
@@ -35,9 +37,8 @@ def completion(
     litellm_params=None,
     logger_fn=None,
 ):
-    LLM, SamplingParams = validate_environment()
     try:
-        llm = LLM(model=model)
+        llm, SamplingParams = validate_environment(model=model)
     except Exception as e:
         raise VLLMError(status_code=0, message=str(e))
     sampling_params = SamplingParams(**optional_params)
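
For context, the single-prompt path above is normally reached through litellm's top-level completion call; a hedged example (the model name is illustrative and vllm must be installed locally):

    # Illustrative call into the vllm handler via litellm's top-level API.
    # Assumes `pip install vllm` and enough local resources to load the example model.
    from litellm import completion

    response = completion(
        model="vllm/facebook/opt-125m",  # the "vllm/" prefix routes to this handler
        messages=[{"role": "user", "content": "Hello, how are you?"}],
    )
    print(response["choices"][0]["message"]["content"])
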
@@ -92,6 +93,85 @@ def completion(
     }
     return model_response

+def batch_completions(
+    model: str,
+    messages: list,
+    optional_params=None,
+    custom_prompt_dict={}
+):
+    """
+    Example usage:
+    import litellm
+    import os
+    from litellm import batch_completion
+
+    responses = batch_completion(
+        model="vllm/facebook/opt-125m",
+        messages = [
+            [
+                {
+                    "role": "user",
+                    "content": "good morning? "
+                }
+            ],
+            [
+                {
+                    "role": "user",
+                    "content": "what's the time? "
+                }
+            ]
+        ]
+    )
+    """
+    global llm
+    try:
+        llm, SamplingParams = validate_environment(model=model, llm=llm)
+    except Exception as e:
+        if "data parallel group is already initialized" in str(e):
+            pass
+        else:
+            raise VLLMError(status_code=0, message=str(e))
+    sampling_params = SamplingParams(**optional_params)
+    prompts = []
+    if model in custom_prompt_dict:
+        # check if the model has a registered custom prompt
+        model_prompt_details = custom_prompt_dict[model]
+        for message in messages:
+            prompt = custom_prompt(
+                role_dict=model_prompt_details["roles"],
+                initial_prompt_value=model_prompt_details["initial_prompt_value"],
+                final_prompt_value=model_prompt_details["final_prompt_value"],
+                messages=message
+            )
+            prompts.append(prompt)
+    else:
+        for message in messages:
+            prompt = prompt_factory(model=model, messages=message)
+            prompts.append(prompt)
+
+    outputs = llm.generate(prompts, sampling_params)
+
+    final_outputs = []
+    for output in outputs:
+        model_response = ModelResponse()
+        ## RESPONSE OBJECT
+        model_response["choices"][0]["message"]["content"] = output.outputs[0].text
+
+        ## CALCULATING USAGE
+        prompt_tokens = len(output.prompt_token_ids)
+        completion_tokens = len(output.outputs[0].token_ids)
+
+        model_response["created"] = time.time()
+        model_response["model"] = model
+        model_response["usage"] = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": prompt_tokens + completion_tokens,
+        }
+        final_outputs.append(model_response)
+    return final_outputs
+
 def embedding():
     # logic for parsing in - calling - parsing out model embedding calls
     pass
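
batch_completions returns one ModelResponse per input conversation, so callers can read generated text and token usage exactly as they would for a single completion. A short sketch of consuming such a list (summarize_batch is illustrative):

    # Sketch: iterating a list of ModelResponse objects like the `final_outputs` built above.
    def summarize_batch(responses):
        for i, resp in enumerate(responses):
            text = resp["choices"][0]["message"]["content"]  # generated text
            usage = resp["usage"]                            # token accounting set in the loop above
            print(f"[{i}] {usage['total_tokens']} tokens: {text[:80]!r}")
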
@@ -693,7 +693,7 @@ def completion(
            encoding=encoding,
            logging_obj=logging
        )

        if "stream" in optional_params and optional_params["stream"] == True: ## [BETA]
            # don't try to access stream object,
            response = CustomStreamWrapper(
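
In the streaming branch shown here, the vllm response is wrapped in CustomStreamWrapper, which makes it iterable. A hedged sketch of consuming such a stream (model name illustrative; the exact chunk fields can vary between litellm versions):

    # Sketch: consuming a streamed response. Each chunk follows an OpenAI-style
    # streaming shape; the text is typically under choices[0]["delta"].
    from litellm import completion

    response = completion(
        model="vllm/facebook/opt-125m",  # illustrative model
        messages=[{"role": "user", "content": "write a haiku"}],
        stream=True,
    )
    for chunk in response:
        print(chunk)
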
@@ -828,23 +828,68 @@ def completion_with_retries(*args, **kwargs):
     return retryer(completion, *args, **kwargs)


-def batch_completion(*args, **kwargs):
-    batch_messages = args[1] if len(args) > 1 else kwargs.get("messages")
+def batch_completion(
+    model: str,
+    # Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create
+    messages: List = [],
+    functions: List = [],
+    function_call: str = "", # optional params
+    temperature: float = 1,
+    top_p: float = 1,
+    n: int = 1,
+    stream: bool = False,
+    stop=None,
+    max_tokens: float = float("inf"),
+    presence_penalty: float = 0,
+    frequency_penalty=0,
+    logit_bias: dict = {},
+    user: str = "",
+    # used by text-bison only
+    top_k=40,
+    custom_llm_provider=None,):
+    args = locals()
+    batch_messages = messages
     completions = []
-    with ThreadPoolExecutor() as executor:
-        for message_list in batch_messages:
-            if len(args) > 1:
-                args_modified = list(args)
-                args_modified[1] = message_list
-                future = executor.submit(completion, *args_modified)
-            else:
-                kwargs_modified = dict(kwargs)
-                kwargs_modified["messages"] = message_list
-                future = executor.submit(completion, *args, **kwargs_modified)
-            completions.append(future)
+    model = model
+    custom_llm_provider = None
+    if model.split("/", 1)[0] in litellm.provider_list:
+        custom_llm_provider = model.split("/", 1)[0]
+        model = model.split("/", 1)[1]
+    if custom_llm_provider == "vllm":
+        optional_params = get_optional_params(
+            functions=functions,
+            function_call=function_call,
+            temperature=temperature,
+            top_p=top_p,
+            n=n,
+            stream=stream,
+            stop=stop,
+            max_tokens=max_tokens,
+            presence_penalty=presence_penalty,
+            frequency_penalty=frequency_penalty,
+            logit_bias=logit_bias,
+            user=user,
+            # params to identify the model
+            model=model,
+            custom_llm_provider=custom_llm_provider,
+            top_k=top_k,
+        )
+        results = vllm.batch_completions(model=model, messages=batch_messages, custom_prompt_dict=litellm.custom_prompt_dict, optional_params=optional_params)
+    else:
+        def chunks(lst, n):
+            """Yield successive n-sized chunks from lst."""
+            for i in range(0, len(lst), n):
+                yield lst[i:i + n]
+        with ThreadPoolExecutor(max_workers=100) as executor:
+            for sub_batch in chunks(batch_messages, 100):
+                for message_list in sub_batch:
+                    kwargs_modified = args
+                    kwargs_modified["messages"] = message_list
+                    future = executor.submit(completion, **kwargs_modified)
+                    completions.append(future)

-    # Retrieve the results from the futures
-    results = [future.result() for future in completions]
+        # Retrieve the results from the futures
+        results = [future.result() for future in completions]
     return results
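
In the non-vllm branch, requests are fanned out to a thread pool in sub-batches of 100 produced by the small chunks generator; the vllm branch instead hands the whole batch to vllm.batch_completions in one call. A quick standalone illustration of what chunks yields:

    # Standalone illustration of the chunks() helper used in batch_completion above.
    def chunks(lst, n):
        """Yield successive n-sized chunks from lst."""
        for i in range(0, len(lst), n):
            yield lst[i:i + n]

    print(list(chunks([1, 2, 3, 4, 5], 2)))  # [[1, 2], [3, 4], [5]]
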
@@ -9,10 +9,11 @@ sys.path.insert(
) # Adds the parent directory to the system path
import litellm
from litellm import batch_completion

litellm.set_verbose=True
messages = [[{"role": "user", "content": "Hey, how's it going"}] for _ in range(5)]
print(messages[0:5])
print(len(messages))
# model = "vllm/facebook/opt-125m"
model = "gpt-3.5-turbo"

result = batch_completion(model=model, messages=messages)
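
One response is expected per input conversation, so a quick sanity check could be appended to this test script (hedged; relies on the result and messages variables defined above):

    # Hedged follow-up for the script above: one ModelResponse per conversation.
    assert len(result) == len(messages)
    for resp in result:
        print(resp["choices"][0]["message"]["content"])
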
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "0.1.549"
+version = "0.1.555"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT License"