forked from phoenix/litellm-mirror
batch completions for vllm now works too
parent 4a263f6ab7
commit 35cf6ef0a1
21 changed files with 149 additions and 23 deletions
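For context, the new vLLM batch path added in this commit can be exercised roughly as follows. This is a minimal usage sketch adapted from the docstring and test added below; it assumes `pip install vllm` and enough local resources to load the example model `facebook/opt-125m`, and the printed access pattern mirrors how the diff populates ModelResponse.

# Minimal usage sketch of the new vLLM batch path (adapted from this commit's docstring/test).
import litellm
from litellm import batch_completion

responses = batch_completion(
    model="vllm/facebook/opt-125m",   # the "vllm/" prefix routes to the new batch path
    messages=[
        [{"role": "user", "content": "good morning? "}],
        [{"role": "user", "content": "what's the time? "}],
    ],
)
# expected: one ModelResponse per input message list
for r in responses:
    print(r["choices"][0]["message"]["content"])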
BIN  dist/litellm-0.1.549-py3-none-any.whl  (vendored, normal file; binary not shown)
BIN  dist/litellm-0.1.549.tar.gz  (vendored, normal file; binary not shown)
BIN  dist/litellm-0.1.550-py3-none-any.whl  (vendored, normal file; binary not shown)
BIN  dist/litellm-0.1.550.tar.gz  (vendored, normal file; binary not shown)
BIN  dist/litellm-0.1.551-py3-none-any.whl  (vendored, normal file; binary not shown)
BIN  dist/litellm-0.1.551.tar.gz  (vendored, normal file; binary not shown)
BIN  dist/litellm-0.1.552-py3-none-any.whl  (vendored, normal file; binary not shown)
BIN  dist/litellm-0.1.552.tar.gz  (vendored, normal file; binary not shown)
BIN  dist/litellm-0.1.553-py3-none-any.whl  (vendored, normal file; binary not shown)
BIN  dist/litellm-0.1.553.tar.gz  (vendored, normal file; binary not shown)
BIN  dist/litellm-0.1.554-py3-none-any.whl  (vendored, normal file; binary not shown)
BIN  dist/litellm-0.1.554.tar.gz  (vendored, normal file; binary not shown)
BIN  dist/litellm-0.1.555-py3-none-any.whl  (vendored, normal file; binary not shown)
BIN  dist/litellm-0.1.555.tar.gz  (vendored, normal file; binary not shown)
BIN  3 further changed binary files (names not shown)
@@ -6,7 +6,7 @@ import time
 from typing import Callable
 from litellm.utils import ModelResponse
 from .prompt_templates.factory import prompt_factory, custom_prompt
-
+llm = None
 class VLLMError(Exception):
     def __init__(self, status_code, message):
         self.status_code = status_code
@@ -16,10 +16,12 @@ class VLLMError(Exception):
         ) # Call the base class constructor with the parameters it needs

 # check if vllm is installed
-def validate_environment():
+def validate_environment(model: str, llm: any=None):
     try:
         from vllm import LLM, SamplingParams
-        return LLM, SamplingParams
+        if llm is None:
+            llm = LLM(model=model)
+        return llm, SamplingParams
     except:
         raise VLLMError(status_code=0, message="The vllm package is not installed in your environment. Run - `pip install vllm` before proceeding.")

@@ -35,9 +37,8 @@ def completion(
     litellm_params=None,
     logger_fn=None,
 ):
-    LLM, SamplingParams = validate_environment()
     try:
-        llm = LLM(model=model)
+        llm, SamplingParams = validate_environment(model=model)
     except Exception as e:
         raise VLLMError(status_code=0, message=str(e))
     sampling_params = SamplingParams(**optional_params)
@@ -92,6 +93,85 @@ def completion(
     }
     return model_response

+def batch_completions(
+    model: str,
+    messages: list,
+    optional_params=None,
+    custom_prompt_dict={}
+):
+    """
+    Example usage:
+    import litellm
+    import os
+    from litellm import batch_completion
+
+
+    responses = batch_completion(
+        model="vllm/facebook/opt-125m",
+        messages = [
+            [
+                {
+                    "role": "user",
+                    "content": "good morning? "
+                }
+            ],
+            [
+                {
+                    "role": "user",
+                    "content": "what's the time? "
+                }
+            ]
+        ]
+    )
+    """
+    global llm
+    try:
+        llm, SamplingParams = validate_environment(model=model, llm=llm)
+    except Exception as e:
+        if "data parallel group is already initialized" in e:
+            pass
+        else:
+            raise VLLMError(status_code=0, message=str(e))
+    sampling_params = SamplingParams(**optional_params)
+    prompts = []
+    if model in custom_prompt_dict:
+        # check if the model has a registered custom prompt
+        model_prompt_details = custom_prompt_dict[model]
+        for message in messages:
+            prompt = custom_prompt(
+                role_dict=model_prompt_details["roles"],
+                initial_prompt_value=model_prompt_details["initial_prompt_value"],
+                final_prompt_value=model_prompt_details["final_prompt_value"],
+                messages=message
+            )
+            prompts.append(prompt)
+    else:
+        for message in messages:
+            prompt = prompt_factory(model=model, messages=message)
+            prompts.append(prompt)
+
+    outputs = llm.generate(prompts, sampling_params)
+
+    final_outputs = []
+    for output in outputs:
+        model_response = ModelResponse()
+        ## RESPONSE OBJECT
+        model_response["choices"][0]["message"]["content"] = output.outputs[0].text
+
+        ## CALCULATING USAGE
+        prompt_tokens = len(output.prompt_token_ids)
+        completion_tokens = len(output.outputs[0].token_ids)
+
+        model_response["created"] = time.time()
+        model_response["model"] = model
+        model_response["usage"] = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": prompt_tokens + completion_tokens,
+        }
+        final_outputs.append(model_response)
+    return final_outputs
+
 def embedding():
     # logic for parsing in - calling - parsing out model embedding calls
     pass
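The key change above is that the module now keeps a single vLLM engine around instead of constructing `LLM(model=...)` on every call. A standalone sketch of that caching pattern, assuming `vllm` is installed; `_engine` and `get_engine` are illustrative names mirroring the module-level `llm = None` plus `validate_environment(model, llm)` in the diff, not litellm's API:

# Illustrative sketch of the engine-caching pattern: build the vLLM engine once
# (loading model weights is expensive) and reuse it for every subsequent batch.
from vllm import LLM, SamplingParams

_engine = None

def get_engine(model: str):
    global _engine
    if _engine is None:
        _engine = LLM(model=model)   # heavyweight: happens only on the first call
    return _engine, SamplingParams

# Later, any number of batch calls reuse the same engine:
# engine, SamplingParams = get_engine("facebook/opt-125m")
# outputs = engine.generate(["Hello"], SamplingParams(max_tokens=32))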
@@ -828,19 +828,64 @@ def completion_with_retries(*args, **kwargs):
     return retryer(completion, *args, **kwargs)


-def batch_completion(*args, **kwargs):
-    batch_messages = args[1] if len(args) > 1 else kwargs.get("messages")
+def batch_completion(
+    model: str,
+    # Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create
+    messages: List = [],
+    functions: List = [],
+    function_call: str = "", # optional params
+    temperature: float = 1,
+    top_p: float = 1,
+    n: int = 1,
+    stream: bool = False,
+    stop=None,
+    max_tokens: float = float("inf"),
+    presence_penalty: float = 0,
+    frequency_penalty=0,
+    logit_bias: dict = {},
+    user: str = "",
+    # used by text-bison only
+    top_k=40,
+    custom_llm_provider=None,):
+    args = locals()
+    batch_messages = messages
     completions = []
-    with ThreadPoolExecutor() as executor:
-        for message_list in batch_messages:
-            if len(args) > 1:
-                args_modified = list(args)
-                args_modified[1] = message_list
-                future = executor.submit(completion, *args_modified)
-            else:
-                kwargs_modified = dict(kwargs)
-                kwargs_modified["messages"] = message_list
-                future = executor.submit(completion, *args, **kwargs_modified)
-            completions.append(future)
+    model = model
+    custom_llm_provider = None
+    if model.split("/", 1)[0] in litellm.provider_list:
+        custom_llm_provider = model.split("/", 1)[0]
+        model = model.split("/", 1)[1]
+    if custom_llm_provider == "vllm":
+        optional_params = get_optional_params(
+            functions=functions,
+            function_call=function_call,
+            temperature=temperature,
+            top_p=top_p,
+            n=n,
+            stream=stream,
+            stop=stop,
+            max_tokens=max_tokens,
+            presence_penalty=presence_penalty,
+            frequency_penalty=frequency_penalty,
+            logit_bias=logit_bias,
+            user=user,
+            # params to identify the model
+            model=model,
+            custom_llm_provider=custom_llm_provider,
+            top_k=top_k,
+        )
+        results = vllm.batch_completions(model=model, messages=batch_messages, custom_prompt_dict=litellm.custom_prompt_dict, optional_params=optional_params)
+    else:
+        def chunks(lst, n):
+            """Yield successive n-sized chunks from lst."""
+            for i in range(0, len(lst), n):
+                yield lst[i:i + n]
+        with ThreadPoolExecutor(max_workers=100) as executor:
+            for sub_batch in chunks(batch_messages, 100):
+                for message_list in sub_batch:
+                    kwargs_modified = args
+                    kwargs_modified["messages"] = message_list
+                    future = executor.submit(completion, **kwargs_modified)
+                    completions.append(future)

     # Retrieve the results from the futures
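In short, `batch_completion` now takes explicit OpenAI-style parameters, splits a provider prefix off the model string, and only the `vllm` provider takes the new single-engine batch path; every other model is still fanned out to `completion()` through a thread pool, in chunks of up to 100 message lists. A small illustration of the prefix routing (the variable names here are just for the sketch):

# Hypothetical illustration of the provider-prefix routing described above.
model = "vllm/facebook/opt-125m"
provider, model_name = model.split("/", 1)    # -> ("vllm", "facebook/opt-125m")
if provider == "vllm":
    print("whole batch goes to the in-process vLLM engine")
else:
    print("messages are chunked (<=100 each) and submitted to completion() via threads")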
@@ -9,10 +9,11 @@ sys.path.insert(
 ) # Adds the parent directory to the system path
 import litellm
 from litellm import batch_completion
-
+litellm.set_verbose=True
 messages = [[{"role": "user", "content": "Hey, how's it going"}] for _ in range(5)]
 print(messages[0:5])
 print(len(messages))
+# model = "vllm/facebook/opt-125m"
 model = "gpt-3.5-turbo"

 result = batch_completion(model=model, messages=messages)
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "0.1.549"
+version = "0.1.555"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT License"