diff --git a/dist/litellm-0.1.546-py3-none-any.whl b/dist/litellm-0.1.546-py3-none-any.whl
new file mode 100644
index 000000000..404944357
Binary files /dev/null and b/dist/litellm-0.1.546-py3-none-any.whl differ
diff --git a/dist/litellm-0.1.546.tar.gz b/dist/litellm-0.1.546.tar.gz
new file mode 100644
index 000000000..83f9b0eac
Binary files /dev/null and b/dist/litellm-0.1.546.tar.gz differ
diff --git a/dist/litellm-0.1.547-py3-none-any.whl b/dist/litellm-0.1.547-py3-none-any.whl
new file mode 100644
index 000000000..a69f6031c
Binary files /dev/null and b/dist/litellm-0.1.547-py3-none-any.whl differ
diff --git a/dist/litellm-0.1.547.tar.gz b/dist/litellm-0.1.547.tar.gz
new file mode 100644
index 000000000..3fa99a70e
Binary files /dev/null and b/dist/litellm-0.1.547.tar.gz differ
diff --git a/dist/litellm-0.1.548-py3-none-any.whl b/dist/litellm-0.1.548-py3-none-any.whl
new file mode 100644
index 000000000..6bffc5f2e
Binary files /dev/null and b/dist/litellm-0.1.548-py3-none-any.whl differ
diff --git a/dist/litellm-0.1.548.tar.gz b/dist/litellm-0.1.548.tar.gz
new file mode 100644
index 000000000..76e65d718
Binary files /dev/null and b/dist/litellm-0.1.548.tar.gz differ
diff --git a/litellm/__init__.py b/litellm/__init__.py
index 31b167561..f3a117ed5 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -195,6 +195,7 @@ provider_list = [
     "azure",
     "sagemaker",
     "bedrock",
+    "vllm",
     "custom", # custom apis
 ]
 
diff --git a/litellm/__pycache__/__init__.cpython-311.pyc b/litellm/__pycache__/__init__.cpython-311.pyc
index 441b62dbc..cc9f665d8 100644
Binary files a/litellm/__pycache__/__init__.cpython-311.pyc and b/litellm/__pycache__/__init__.cpython-311.pyc differ
diff --git a/litellm/__pycache__/main.cpython-311.pyc b/litellm/__pycache__/main.cpython-311.pyc
index eede7b903..1c43da22c 100644
Binary files a/litellm/__pycache__/main.cpython-311.pyc and b/litellm/__pycache__/main.cpython-311.pyc differ
diff --git a/litellm/__pycache__/utils.cpython-311.pyc b/litellm/__pycache__/utils.cpython-311.pyc
index 7bdd2e976..a67a6e4e9 100644
Binary files a/litellm/__pycache__/utils.cpython-311.pyc and b/litellm/__pycache__/utils.cpython-311.pyc differ
diff --git a/litellm/llms/huggingface_restapi.py b/litellm/llms/huggingface_restapi.py
index e73eba191..818c845b7 100644
--- a/litellm/llms/huggingface_restapi.py
+++ b/litellm/llms/huggingface_restapi.py
@@ -54,8 +54,8 @@ def completion(
         model_prompt_details = custom_prompt_dict[model]
         prompt = custom_prompt(
             role_dict=model_prompt_details["roles"],
-            initial_prompt_value=model_prompt_details["pre_message_sep"],
-            final_prompt_value=model_prompt_details["post_message_sep"],
+            initial_prompt_value=model_prompt_details["initial_prompt_value"],
+            final_prompt_value=model_prompt_details["final_prompt_value"],
             messages=messages
         )
     else:
diff --git a/litellm/llms/together_ai.py b/litellm/llms/together_ai.py
index 96baccb64..4f75e6e43 100644
--- a/litellm/llms/together_ai.py
+++ b/litellm/llms/together_ai.py
@@ -41,14 +41,13 @@ def completion(
     logger_fn=None,
 ):
     headers = validate_environment(api_key)
-    model = model
     if model in custom_prompt_dict:
         # check if the model has a registered custom prompt
         model_prompt_details = custom_prompt_dict[model]
         prompt = custom_prompt(
             role_dict=model_prompt_details["roles"],
-            initial_prompt_value=model_prompt_details["pre_message_sep"],
-            final_prompt_value=model_prompt_details["post_message_sep"],
+            initial_prompt_value=model_prompt_details["initial_prompt_value"],
+            final_prompt_value=model_prompt_details["final_prompt_value"],
             messages=messages
         )
     else:
diff --git a/litellm/llms/vllm.py b/litellm/llms/vllm.py
new file mode 100644
index 000000000..fc803ad59
--- /dev/null
+++ b/litellm/llms/vllm.py
@@ -0,0 +1,97 @@
+import os
+import json
+from enum import Enum
+import requests
+import time
+from typing import Callable
+from litellm.utils import ModelResponse
+from .prompt_templates.factory import prompt_factory, custom_prompt
+
+class VLLMError(Exception):
+    def __init__(self, status_code, message):
+        self.status_code = status_code
+        self.message = message
+        super().__init__(
+            self.message
+        ) # Call the base class constructor with the parameters it needs
+
+# check if vllm is installed
+def validate_environment():
+    try:
+        from vllm import LLM, SamplingParams
+        return LLM, SamplingParams
+    except:
+        raise VLLMError(status_code=0, message="The vllm package is not installed in your environment. Run - `pip install vllm` before proceeding.")
+
+def completion(
+    model: str,
+    messages: list,
+    model_response: ModelResponse,
+    print_verbose: Callable,
+    encoding,
+    logging_obj,
+    custom_prompt_dict={},
+    optional_params=None,
+    litellm_params=None,
+    logger_fn=None,
+):
+    LLM, SamplingParams = validate_environment()
+    try:
+        llm = LLM(model=model)
+    except Exception as e:
+        raise VLLMError(status_code=0, message=str(e))
+    sampling_params = SamplingParams(**optional_params)
+    if model in custom_prompt_dict:
+        # check if the model has a registered custom prompt
+        model_prompt_details = custom_prompt_dict[model]
+        prompt = custom_prompt(
+            role_dict=model_prompt_details["roles"],
+            initial_prompt_value=model_prompt_details["initial_prompt_value"],
+            final_prompt_value=model_prompt_details["final_prompt_value"],
+            messages=messages
+        )
+    else:
+        prompt = prompt_factory(model=model, messages=messages)
+
+
+    ## LOGGING
+    logging_obj.pre_call(
+        input=prompt,
+        api_key="",
+        additional_args={"complete_input_dict": sampling_params},
+    )
+
+    outputs = llm.generate(prompt, sampling_params)
+
+
+    ## COMPLETION CALL
+    if "stream" in optional_params and optional_params["stream"] == True:
+        return iter(outputs)
+    else:
+        ## LOGGING
+        logging_obj.post_call(
+            input=prompt,
+            api_key="",
+            original_response=outputs,
+            additional_args={"complete_input_dict": sampling_params},
+        )
+        print_verbose(f"raw model_response: {outputs}")
+        ## RESPONSE OBJECT
+        model_response["choices"][0]["message"]["content"] = outputs[0].outputs[0].text
+
+    ## CALCULATING USAGE
+    prompt_tokens = len(outputs[0].prompt_token_ids)
+    completion_tokens = len(outputs[0].outputs[0].token_ids)
+
+    model_response["created"] = time.time()
+    model_response["model"] = model
+    model_response["usage"] = {
+        "prompt_tokens": prompt_tokens,
+        "completion_tokens": completion_tokens,
+        "total_tokens": prompt_tokens + completion_tokens,
+    }
+    return model_response
+
+def embedding():
+    # logic for parsing in - calling - parsing out model embedding calls
+    pass
diff --git a/litellm/main.py b/litellm/main.py
index 539ccdd12..a28a235a5 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -27,6 +27,7 @@ from .llms import huggingface_restapi
 from .llms import replicate
 from .llms import aleph_alpha
 from .llms import baseten
+from .llms import vllm
 import tiktoken
 from concurrent.futures import ThreadPoolExecutor
 from typing import Callable, List, Optional, Dict
@@ -670,20 +671,18 @@ def completion(
                 encoding=encoding,
                 logging_obj=logging
             )
-
-            # TODO: Add streaming for sagemaker
-            # if "stream" in optional_params and optional_params["stream"] == True:
-            #     # don't try to access stream object,
-            #     response = CustomStreamWrapper(
-            #         model_response, model, custom_llm_provider="ai21", logging_obj=logging
-            #     )
-            #     return response
-
+
+            if "stream" in optional_params and optional_params["stream"] == True: ## [BETA]
+                # don't try to access stream object,
+                response = CustomStreamWrapper(
+                    iter(model_response), model, custom_llm_provider="sagemaker", logging_obj=logging
+                )
+                return response
+
             ## RESPONSE OBJECT
             response = model_response
-        elif custom_llm_provider == "bedrock":
-            # boto3 reads keys from .env
-            model_response = bedrock.completion(
+        elif custom_llm_provider == "vllm":
+            model_response = vllm.completion(
                 model=model,
                 messages=messages,
                 model_response=model_response,
@@ -695,17 +694,15 @@ def completion(
                 logging_obj=logging
             )
 
-            # TODO: Add streaming for bedrock
-            # if "stream" in optional_params and optional_params["stream"] == True:
-            #     # don't try to access stream object,
-            #     response = CustomStreamWrapper(
-            #         model_response, model, custom_llm_provider="ai21", logging_obj=logging
-            #     )
-            #     return response
-
+            if "stream" in optional_params and optional_params["stream"] == True: ## [BETA]
+                # don't try to access stream object,
+                response = CustomStreamWrapper(
+                    model_response, model, custom_llm_provider="vllm", logging_obj=logging
+                )
+                return response
+
             ## RESPONSE OBJECT
             response = model_response
-
         elif custom_llm_provider == "ollama":
             endpoint = (
                 litellm.api_base if litellm.api_base is not None else api_base
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index c4b97bf2d..b24385b53 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -435,7 +435,6 @@ def test_completion_together_ai():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
-test_completion_together_ai()
 # def test_customprompt_together_ai():
 #     try:
 #         litellm.register_prompt_template(
@@ -462,6 +461,20 @@ def test_completion_sagemaker():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
+######## Test VLLM ########
+# def test_completion_vllm():
+#     try:
+#         response = completion(
+#             model="vllm/facebook/opt-125m",
+#             messages=messages,
+#             temperature=0.2,
+#             max_tokens=80,
+#         )
+#         print(response)
+#     except Exception as e:
+#         pytest.fail(f"Error occurred: {e}")
+
+# test_completion_vllm()
 
 # def test_completion_custom_api_base():
 #     try:
diff --git a/litellm/utils.py b/litellm/utils.py
index 63dae96cd..020cca093 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -1385,23 +1385,33 @@ def modify_integration(integration_name, integration_params):
 
 
 # custom prompt helper function
-def register_prompt_template(model: str, roles: dict, pre_message_sep: str, post_message_sep: str):
+def register_prompt_template(model: str, roles: dict, initial_prompt_value: str = "", final_prompt_value: str = ""):
     """
     Example usage:
    ```
     import litellm
     litellm.register_prompt_template(
-        model="bloomz",
-        roles={"system":"<|im_start|>system", "assistant":"<|im_start|>assistant", "user":"<|im_start|>user"}
-        pre_message_sep: "\n",
-        post_message_sep: "<|im_end|>\n"
+        model="llama-2",
+        roles={
+            "system": {
+                "pre_message": "[INST] <<SYS>>\n",
+                "post_message": "\n<</SYS>>\n [/INST]\n"
+            },
+            "user": { # follow this format https://github.com/facebookresearch/llama/blob/77062717054710e352a99add63d160274ce670c6/llama/generation.py#L348
+                "pre_message": "[INST] ",
+                "post_message": " [/INST]\n"
+            },
+            "assistant": {
+                "post_message": "\n" # follows this - https://replicate.com/blog/how-to-prompt-llama
+            }
+        },
     )
    ```
     """
     litellm.custom_prompt_dict[model] = {
         "roles": roles,
-        "pre_message_sep": pre_message_sep,
-        "post_message_sep": post_message_sep
+        "initial_prompt_value": initial_prompt_value,
+        "final_prompt_value": final_prompt_value
     }
     return litellm.custom_prompt_dict
 
@@ -1844,6 +1854,14 @@ def exception_type(model, original_exception, custom_llm_provider):
                         llm_provider="together_ai",
                         model=model
                     )
+            elif custom_llm_provider == "vllm":
+                if hasattr(original_exception, "status_code"):
+                    if original_exception.status_code == 0:
+                        raise APIConnectionError(
+                            message=f"VLLMException - {original_exception.message}",
+                            llm_provider="vllm",
+                            model=model
+                        )
         else:
             raise original_exception
     except Exception as e:
@@ -2080,6 +2098,9 @@ class CustomStreamWrapper:
             elif self.custom_llm_provider and self.custom_llm_provider == "ai21": #ai21 doesn't provide streaming
                 chunk = next(self.completion_stream)
                 completion_obj["content"] = self.handle_ai21_chunk(chunk)
+            elif self.custom_llm_provider and self.custom_llm_provider == "vllm":
+                chunk = next(self.completion_stream)
+                completion_obj["content"] = chunk[0].outputs[0].text
             elif self.model in litellm.aleph_alpha_models: #ai21 doesn't provide streaming
                 chunk = next(self.completion_stream)
                 completion_obj["content"] = self.handle_aleph_alpha_chunk(chunk)
diff --git a/pyproject.toml b/pyproject.toml
index 1dbba1893..07eff98bc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "0.1.545"
+version = "0.1.548"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT License"
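
Usage note (not part of the diff): a minimal, untested sketch of how the new vllm provider would be called through litellm, mirroring the commented-out test_completion_vllm test above. It assumes `completion` is exposed as `litellm.completion`, that the `vllm/` model-name prefix is what routes the call to `litellm/llms/vllm.py` (implied by the test and the `custom_llm_provider == "vllm"` branch in `main.py`, though the prefix-parsing helper is not part of this diff), and that the message content is illustrative.

```
# Sketch only: requires `pip install vllm` and local weights for facebook/opt-125m;
# the call runs the model locally via the vllm package rather than a hosted API.
import litellm

response = litellm.completion(
    model="vllm/facebook/opt-125m",  # model name taken from the commented-out test
    messages=[{"role": "user", "content": "Hello, how are you?"}],  # illustrative message
    temperature=0.2,
    max_tokens=80,
)
print(response)
```

Passing `stream=True` would instead go through the new `CustomStreamWrapper(..., custom_llm_provider="vllm", ...)` path in `main.py`, which the diff itself flags as `## [BETA]`.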