diff --git a/dist/litellm-0.1.549-py3-none-any.whl b/dist/litellm-0.1.549-py3-none-any.whl
new file mode 100644
index 000000000..6cccd853a
Binary files /dev/null and b/dist/litellm-0.1.549-py3-none-any.whl differ
diff --git a/dist/litellm-0.1.549.tar.gz b/dist/litellm-0.1.549.tar.gz
new file mode 100644
index 000000000..991d0bf6e
Binary files /dev/null and b/dist/litellm-0.1.549.tar.gz differ
diff --git a/dist/litellm-0.1.550-py3-none-any.whl b/dist/litellm-0.1.550-py3-none-any.whl
new file mode 100644
index 000000000..0cc200c98
Binary files /dev/null and b/dist/litellm-0.1.550-py3-none-any.whl differ
diff --git a/dist/litellm-0.1.550.tar.gz b/dist/litellm-0.1.550.tar.gz
new file mode 100644
index 000000000..eae89ccba
Binary files /dev/null and b/dist/litellm-0.1.550.tar.gz differ
diff --git a/dist/litellm-0.1.551-py3-none-any.whl b/dist/litellm-0.1.551-py3-none-any.whl
new file mode 100644
index 000000000..a3488bcbf
Binary files /dev/null and b/dist/litellm-0.1.551-py3-none-any.whl differ
diff --git a/dist/litellm-0.1.551.tar.gz b/dist/litellm-0.1.551.tar.gz
new file mode 100644
index 000000000..74d626c92
Binary files /dev/null and b/dist/litellm-0.1.551.tar.gz differ
diff --git a/dist/litellm-0.1.552-py3-none-any.whl b/dist/litellm-0.1.552-py3-none-any.whl
new file mode 100644
index 000000000..b1a72ce4d
Binary files /dev/null and b/dist/litellm-0.1.552-py3-none-any.whl differ
diff --git a/dist/litellm-0.1.552.tar.gz b/dist/litellm-0.1.552.tar.gz
new file mode 100644
index 000000000..755739f72
Binary files /dev/null and b/dist/litellm-0.1.552.tar.gz differ
diff --git a/dist/litellm-0.1.553-py3-none-any.whl b/dist/litellm-0.1.553-py3-none-any.whl
new file mode 100644
index 000000000..78b762088
Binary files /dev/null and b/dist/litellm-0.1.553-py3-none-any.whl differ
diff --git a/dist/litellm-0.1.553.tar.gz b/dist/litellm-0.1.553.tar.gz
new file mode 100644
index 000000000..e6f006359
Binary files /dev/null and b/dist/litellm-0.1.553.tar.gz differ
diff --git a/dist/litellm-0.1.554-py3-none-any.whl b/dist/litellm-0.1.554-py3-none-any.whl
new file mode 100644
index 000000000..bada34560
Binary files /dev/null and b/dist/litellm-0.1.554-py3-none-any.whl differ
diff --git a/dist/litellm-0.1.554.tar.gz b/dist/litellm-0.1.554.tar.gz
new file mode 100644
index 000000000..f562c135e
Binary files /dev/null and b/dist/litellm-0.1.554.tar.gz differ
diff --git a/dist/litellm-0.1.555-py3-none-any.whl b/dist/litellm-0.1.555-py3-none-any.whl
new file mode 100644
index 000000000..8534dc8c6
Binary files /dev/null and b/dist/litellm-0.1.555-py3-none-any.whl differ
diff --git a/dist/litellm-0.1.555.tar.gz b/dist/litellm-0.1.555.tar.gz
new file mode 100644
index 000000000..6d400f3e0
Binary files /dev/null and b/dist/litellm-0.1.555.tar.gz differ
diff --git a/litellm/__pycache__/__init__.cpython-311.pyc b/litellm/__pycache__/__init__.cpython-311.pyc
index cc9f665d8..c95f8ff24 100644
Binary files a/litellm/__pycache__/__init__.cpython-311.pyc and b/litellm/__pycache__/__init__.cpython-311.pyc differ
diff --git a/litellm/__pycache__/main.cpython-311.pyc b/litellm/__pycache__/main.cpython-311.pyc
index 1c43da22c..b753e814f 100644
Binary files a/litellm/__pycache__/main.cpython-311.pyc and b/litellm/__pycache__/main.cpython-311.pyc differ
diff --git a/litellm/__pycache__/utils.cpython-311.pyc b/litellm/__pycache__/utils.cpython-311.pyc
index a67a6e4e9..5b4f9b84e 100644
Binary files a/litellm/__pycache__/utils.cpython-311.pyc and b/litellm/__pycache__/utils.cpython-311.pyc differ
diff --git a/litellm/llms/vllm.py b/litellm/llms/vllm.py
index fc803ad59..0f446ae81 100644
--- a/litellm/llms/vllm.py
+++ b/litellm/llms/vllm.py
@@ -6,7 +6,7 @@ import time
 from typing import Callable
 from litellm.utils import ModelResponse
 from .prompt_templates.factory import prompt_factory, custom_prompt
-
+llm = None
 class VLLMError(Exception):
     def __init__(self, status_code, message):
         self.status_code = status_code
@@ -16,10 +16,12 @@ class VLLMError(Exception):
         )  # Call the base class constructor with the parameters it needs

 # check if vllm is installed
-def validate_environment():
+def validate_environment(model: str, llm: any=None):
     try:
         from vllm import LLM, SamplingParams
-        return LLM, SamplingParams
+        if llm is None:
+            llm = LLM(model=model)
+        return llm, SamplingParams
     except:
         raise VLLMError(status_code=0, message="The vllm package is not installed in your environment. Run - `pip install vllm` before proceeding.")

@@ -35,9 +37,8 @@ def completion(
     litellm_params=None,
     logger_fn=None,
 ):
-    LLM, SamplingParams = validate_environment()
     try:
-        llm = LLM(model=model)
+        llm, SamplingParams = validate_environment(model=model)
     except Exception as e:
         raise VLLMError(status_code=0, message=str(e))
     sampling_params = SamplingParams(**optional_params)
@@ -92,6 +93,85 @@ def completion(
     }
     return model_response

+def batch_completions(
+    model: str,
+    messages: list,
+    optional_params=None,
+    custom_prompt_dict={}
+):
+    """
+    Example usage:
+    import litellm
+    import os
+    from litellm import batch_completion
+
+
+    responses = batch_completion(
+        model="vllm/facebook/opt-125m",
+        messages = [
+            [
+                {
+                    "role": "user",
+                    "content": "good morning? "
+                }
+            ],
+            [
+                {
+                    "role": "user",
+                    "content": "what's the time? "
+                }
+            ]
+        ]
+    )
+    """
+    global llm
+    try:
+        llm, SamplingParams = validate_environment(model=model, llm=llm)
+    except Exception as e:
+        if "data parallel group is already initialized" in str(e):
+            pass
+        else:
+            raise VLLMError(status_code=0, message=str(e))
+    sampling_params = SamplingParams(**optional_params)
+    prompts = []
+    if model in custom_prompt_dict:
+        # check if the model has a registered custom prompt
+        model_prompt_details = custom_prompt_dict[model]
+        for message in messages:
+            prompt = custom_prompt(
+                role_dict=model_prompt_details["roles"],
+                initial_prompt_value=model_prompt_details["initial_prompt_value"],
+                final_prompt_value=model_prompt_details["final_prompt_value"],
+                messages=message
+            )
+            prompts.append(prompt)
+    else:
+        for message in messages:
+            prompt = prompt_factory(model=model, messages=message)
+            prompts.append(prompt)
+
+    outputs = llm.generate(prompts, sampling_params)
+
+    final_outputs = []
+    for output in outputs:
+        model_response = ModelResponse()
+        ## RESPONSE OBJECT
+        model_response["choices"][0]["message"]["content"] = output.outputs[0].text
+
+        ## CALCULATING USAGE
+        prompt_tokens = len(output.prompt_token_ids)
+        completion_tokens = len(output.outputs[0].token_ids)
+
+        model_response["created"] = time.time()
+        model_response["model"] = model
+        model_response["usage"] = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": prompt_tokens + completion_tokens,
+        }
+        final_outputs.append(model_response)
+    return final_outputs
+
 def embedding():
     # logic for parsing in - calling - parsing out model embedding calls
     pass
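Note on the vllm.py change: `validate_environment` now builds the vLLM engine lazily and caches it in the module-level `llm` variable, so `batch_completions` can reuse one engine across calls instead of paying the model-load cost every time. A minimal sketch of that caching pattern, assuming the `vllm` package is installed and using an illustrative model name:

```python
# Sketch of the lazy, module-level engine cache used in vllm.py (illustrative only).
from vllm import LLM, SamplingParams  # requires `pip install vllm`

_engine = None  # built once, then reused by later batch calls

def get_engine(model: str):
    """Return a cached vLLM engine plus SamplingParams for `model`."""
    global _engine
    if _engine is None:
        _engine = LLM(model=model)  # expensive: loads weights and allocates the KV cache
    return _engine, SamplingParams

# engine, SamplingParams = get_engine("facebook/opt-125m")  # illustrative model name
```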
diff --git a/litellm/main.py b/litellm/main.py
index a28a235a5..2e480b2dc 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -693,7 +693,7 @@ def completion(
             encoding=encoding,
             logging_obj=logging
         )
-
+
         if "stream" in optional_params and optional_params["stream"] == True: ## [BETA]
             # don't try to access stream object,
             response = CustomStreamWrapper(
@@ -828,23 +828,68 @@ def completion_with_retries(*args, **kwargs):
     return retryer(completion, *args, **kwargs)

-def batch_completion(*args, **kwargs):
-    batch_messages = args[1] if len(args) > 1 else kwargs.get("messages")
+def batch_completion(
+    model: str,
+    # Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create
+    messages: List = [],
+    functions: List = [],
+    function_call: str = "", # optional params
+    temperature: float = 1,
+    top_p: float = 1,
+    n: int = 1,
+    stream: bool = False,
+    stop=None,
+    max_tokens: float = float("inf"),
+    presence_penalty: float = 0,
+    frequency_penalty=0,
+    logit_bias: dict = {},
+    user: str = "",
+    # used by text-bison only
+    top_k=40,
+    custom_llm_provider=None,):
+    args = locals()
+    batch_messages = messages
     completions = []
-    with ThreadPoolExecutor() as executor:
-        for message_list in batch_messages:
-            if len(args) > 1:
-                args_modified = list(args)
-                args_modified[1] = message_list
-                future = executor.submit(completion, *args_modified)
-            else:
-                kwargs_modified = dict(kwargs)
-                kwargs_modified["messages"] = message_list
-                future = executor.submit(completion, *args, **kwargs_modified)
-            completions.append(future)
+    model = model
+    custom_llm_provider = None
+    if model.split("/", 1)[0] in litellm.provider_list:
+        custom_llm_provider = model.split("/", 1)[0]
+        model = model.split("/", 1)[1]
+    if custom_llm_provider == "vllm":
+        optional_params = get_optional_params(
+            functions=functions,
+            function_call=function_call,
+            temperature=temperature,
+            top_p=top_p,
+            n=n,
+            stream=stream,
+            stop=stop,
+            max_tokens=max_tokens,
+            presence_penalty=presence_penalty,
+            frequency_penalty=frequency_penalty,
+            logit_bias=logit_bias,
+            user=user,
+            # params to identify the model
+            model=model,
+            custom_llm_provider=custom_llm_provider,
+            top_k=top_k,
+        )
+        results = vllm.batch_completions(model=model, messages=batch_messages, custom_prompt_dict=litellm.custom_prompt_dict, optional_params=optional_params)
+    else:
+        def chunks(lst, n):
+            """Yield successive n-sized chunks from lst."""
+            for i in range(0, len(lst), n):
+                yield lst[i:i + n]
+        with ThreadPoolExecutor(max_workers=100) as executor:
+            for sub_batch in chunks(batch_messages, 100):
+                for message_list in sub_batch:
+                    kwargs_modified = args
+                    kwargs_modified["messages"] = message_list
+                    future = executor.submit(completion, **kwargs_modified)
+                    completions.append(future)

-    # Retrieve the results from the futures
-    results = [future.result() for future in completions]
+        # Retrieve the results from the futures
+        results = [future.result() for future in completions]
     return results
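For context on the new `batch_completion` signature above, here is a hedged usage sketch (not part of the diff): each element of `messages` is a separate conversation, `vllm/...` model names are routed to `vllm.batch_completions`, and all other models fan out over a thread pool. Model names are illustrative, and the threaded path assumes `OPENAI_API_KEY` is set in the environment.

```python
# Usage sketch for the reworked batch_completion API (illustrative only).
import litellm
from litellm import batch_completion

conversations = [
    [{"role": "user", "content": "good morning? "}],
    [{"role": "user", "content": "what's the time? "}],
]

# Threaded path: one completion() call per conversation, submitted in sub-batches of 100.
responses = batch_completion(model="gpt-3.5-turbo", messages=conversations)

# vLLM path: prompts are built per conversation and generated in a single batch.
# responses = batch_completion(model="vllm/facebook/opt-125m", messages=conversations)

print([r["choices"][0]["message"]["content"] for r in responses])
```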
= "litellm" -version = "0.1.549" +version = "0.1.555" description = "Library to easily interface with LLM API providers" authors = ["BerriAI"] license = "MIT License"