adding support for vllm

Krrish Dholakia 2023-09-06 18:07:42 -07:00
parent 9abefa18b8
commit 4cfcabd919
17 changed files with 163 additions and 35 deletions
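
For context, here is a minimal sketch of how the new provider is meant to be exercised once this lands, mirroring the commented-out test added further down in this commit (the `vllm/` model prefix is what routes the call to the new handler; a local `pip install vllm` and enough hardware to load the model are assumed):

```python
# Sketch only: mirrors the test added in this commit, not an official recipe.
from litellm import completion

messages = [{"role": "user", "content": "Hello, how are you?"}]

# The "vllm/" prefix selects the new custom_llm_provider == "vllm" branch,
# so litellm loads facebook/opt-125m through a local vllm engine instead of
# calling a hosted API.
response = completion(
    model="vllm/facebook/opt-125m",
    messages=messages,
    temperature=0.2,
    max_tokens=80,
)
print(response)
```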

BIN  dist/litellm-0.1.546-py3-none-any.whl  (vendored, new binary file)
BIN  dist/litellm-0.1.546.tar.gz            (vendored, new binary file)
BIN  dist/litellm-0.1.547-py3-none-any.whl  (vendored, new binary file)
BIN  dist/litellm-0.1.547.tar.gz            (vendored, new binary file)
BIN  dist/litellm-0.1.548-py3-none-any.whl  (vendored, new binary file)
BIN  dist/litellm-0.1.548.tar.gz            (vendored, new binary file)

@@ -195,6 +195,7 @@ provider_list = [
     "azure",
     "sagemaker",
     "bedrock",
+    "vllm",
     "custom", # custom apis
 ]

@@ -54,8 +54,8 @@ def completion(
         model_prompt_details = custom_prompt_dict[model]
         prompt = custom_prompt(
             role_dict=model_prompt_details["roles"],
-            initial_prompt_value=model_prompt_details["pre_message_sep"],
-            final_prompt_value=model_prompt_details["post_message_sep"],
+            initial_prompt_value=model_prompt_details["initial_prompt_value"],
+            final_prompt_value=model_prompt_details["final_prompt_value"],
             messages=messages
         )
     else:

@@ -41,14 +41,13 @@ def completion(
     logger_fn=None,
 ):
     headers = validate_environment(api_key)
-    model = model
     if model in custom_prompt_dict:
         # check if the model has a registered custom prompt
         model_prompt_details = custom_prompt_dict[model]
         prompt = custom_prompt(
             role_dict=model_prompt_details["roles"],
-            initial_prompt_value=model_prompt_details["pre_message_sep"],
-            final_prompt_value=model_prompt_details["post_message_sep"],
+            initial_prompt_value=model_prompt_details["initial_prompt_value"],
+            final_prompt_value=model_prompt_details["final_prompt_value"],
             messages=messages
         )
     else:

litellm/llms/vllm.py  (new file, 97 lines)

@@ -0,0 +1,97 @@
import os
import json
from enum import Enum
import requests
import time
from typing import Callable
from litellm.utils import ModelResponse
from .prompt_templates.factory import prompt_factory, custom_prompt


class VLLMError(Exception):
    def __init__(self, status_code, message):
        self.status_code = status_code
        self.message = message
        super().__init__(
            self.message
        )  # Call the base class constructor with the parameters it needs


# check if vllm is installed
def validate_environment():
    try:
        from vllm import LLM, SamplingParams
        return LLM, SamplingParams
    except:
        raise VLLMError(status_code=0, message="The vllm package is not installed in your environment. Run - `pip install vllm` before proceeding.")


def completion(
    model: str,
    messages: list,
    model_response: ModelResponse,
    print_verbose: Callable,
    encoding,
    logging_obj,
    custom_prompt_dict={},
    optional_params=None,
    litellm_params=None,
    logger_fn=None,
):
    LLM, SamplingParams = validate_environment()
    try:
        llm = LLM(model=model)
    except Exception as e:
        raise VLLMError(status_code=0, message=str(e))
    sampling_params = SamplingParams(**optional_params)
    if model in custom_prompt_dict:
        # check if the model has a registered custom prompt
        model_prompt_details = custom_prompt_dict[model]
        prompt = custom_prompt(
            role_dict=model_prompt_details["roles"],
            initial_prompt_value=model_prompt_details["initial_prompt_value"],
            final_prompt_value=model_prompt_details["final_prompt_value"],
            messages=messages
        )
    else:
        prompt = prompt_factory(model=model, messages=messages)

    ## LOGGING
    logging_obj.pre_call(
        input=prompt,
        api_key="",
        additional_args={"complete_input_dict": sampling_params},
    )

    outputs = llm.generate(prompt, sampling_params)

    ## COMPLETION CALL
    if "stream" in optional_params and optional_params["stream"] == True:
        return iter(outputs)
    else:
        ## LOGGING
        logging_obj.post_call(
            input=prompt,
            api_key="",
            original_response=outputs,
            additional_args={"complete_input_dict": sampling_params},
        )
        print_verbose(f"raw model_response: {outputs}")

        ## RESPONSE OBJECT
        model_response["choices"][0]["message"]["content"] = outputs[0].outputs[0].text

        ## CALCULATING USAGE
        prompt_tokens = len(outputs[0].prompt_token_ids)
        completion_tokens = len(outputs[0].outputs[0].token_ids)
        model_response["created"] = time.time()
        model_response["model"] = model
        model_response["usage"] = {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        }
        return model_response


def embedding():
    # logic for parsing in - calling - parsing out model embedding calls
    pass
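
The streaming path above returns `iter(outputs)` and relies on the `main.py` and `CustomStreamWrapper` changes below to surface chunks to the caller. A rough consumption sketch follows; the OpenAI-style `choices[0]["delta"]["content"]` chunk layout is an assumption about the wrapper's output, not something spelled out in this diff:

```python
# Streaming sketch (assumed chunk shape): with stream=True, vllm.completion()
# returns iter(outputs), which main.py wraps in
# CustomStreamWrapper(..., custom_llm_provider="vllm").
from litellm import completion

messages = [{"role": "user", "content": "Write a haiku about GPUs."}]

response = completion(
    model="vllm/facebook/opt-125m",
    messages=messages,
    stream=True,
)
for chunk in response:
    # Each chunk is assumed to carry the newly generated text in a delta field.
    print(chunk["choices"][0]["delta"]["content"], end="")
```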

@@ -27,6 +27,7 @@ from .llms import huggingface_restapi
 from .llms import replicate
 from .llms import aleph_alpha
 from .llms import baseten
+from .llms import vllm
 import tiktoken
 from concurrent.futures import ThreadPoolExecutor
 from typing import Callable, List, Optional, Dict
@@ -670,20 +671,18 @@ def completion(
                 encoding=encoding,
                 logging_obj=logging
             )
-            # TODO: Add streaming for sagemaker
-            # if "stream" in optional_params and optional_params["stream"] == True:
-            #     # don't try to access stream object,
-            #     response = CustomStreamWrapper(
-            #         model_response, model, custom_llm_provider="ai21", logging_obj=logging
-            #     )
-            #     return response
+            if "stream" in optional_params and optional_params["stream"] == True: ## [BETA]
+                # don't try to access stream object,
+                response = CustomStreamWrapper(
+                    iter(model_response), model, custom_llm_provider="sagemaker", logging_obj=logging
+                )
+                return response
             ## RESPONSE OBJECT
             response = model_response
-        elif custom_llm_provider == "bedrock":
-            # boto3 reads keys from .env
-            model_response = bedrock.completion(
+        elif custom_llm_provider == "vllm":
+            model_response = vllm.completion(
                 model=model,
                 messages=messages,
                 model_response=model_response,
@@ -695,17 +694,15 @@ def completion(
                 logging_obj=logging
             )
-            # TODO: Add streaming for bedrock
-            # if "stream" in optional_params and optional_params["stream"] == True:
-            #     # don't try to access stream object,
-            #     response = CustomStreamWrapper(
-            #         model_response, model, custom_llm_provider="ai21", logging_obj=logging
-            #     )
-            #     return response
+            if "stream" in optional_params and optional_params["stream"] == True: ## [BETA]
+                # don't try to access stream object,
+                response = CustomStreamWrapper(
+                    model_response, model, custom_llm_provider="vllm", logging_obj=logging
+                )
+                return response
             ## RESPONSE OBJECT
             response = model_response
         elif custom_llm_provider == "ollama":
             endpoint = (
                 litellm.api_base if litellm.api_base is not None else api_base

@@ -435,7 +435,6 @@ def test_completion_together_ai():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
-test_completion_together_ai()
 # def test_customprompt_together_ai():
 #     try:
 #         litellm.register_prompt_template(
@@ -462,6 +461,20 @@ def test_completion_sagemaker():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
+######## Test VLLM ########
+# def test_completion_vllm():
+#     try:
+#         response = completion(
+#             model="vllm/facebook/opt-125m",
+#             messages=messages,
+#             temperature=0.2,
+#             max_tokens=80,
+#         )
+#         print(response)
+#     except Exception as e:
+#         pytest.fail(f"Error occurred: {e}")
+# test_completion_vllm()
 # def test_completion_custom_api_base():
 #     try:

@@ -1385,23 +1385,33 @@ def modify_integration(integration_name, integration_params):
 # custom prompt helper function
-def register_prompt_template(model: str, roles: dict, pre_message_sep: str, post_message_sep: str):
+def register_prompt_template(model: str, roles: dict, initial_prompt_value: str = "", final_prompt_value: str = ""):
     """
     Example usage:
     ```
     import litellm
     litellm.register_prompt_template(
-        model="bloomz",
-        roles={"system":"<|im_start|>system", "assistant":"<|im_start|>assistant", "user":"<|im_start|>user"}
-        pre_message_sep: "\n",
-        post_message_sep: "<|im_end|>\n"
+        model="llama-2",
+        roles={
+            "system": {
+                "pre_message": "[INST] <<SYS>>\n",
+                "post_message": "\n<</SYS>>\n [/INST]\n"
+            },
+            "user": { # follow this format https://github.com/facebookresearch/llama/blob/77062717054710e352a99add63d160274ce670c6/llama/generation.py#L348
+                "pre_message": "[INST] ",
+                "post_message": " [/INST]\n"
+            },
+            "assistant": {
+                "post_message": "\n" # follows this - https://replicate.com/blog/how-to-prompt-llama
+            }
+        },
     )
     ```
     """
     litellm.custom_prompt_dict[model] = {
         "roles": roles,
-        "pre_message_sep": pre_message_sep,
-        "post_message_sep": post_message_sep
+        "initial_prompt_value": initial_prompt_value,
+        "final_prompt_value": final_prompt_value
     }
     return litellm.custom_prompt_dict
@@ -1844,6 +1854,14 @@ def exception_type(model, original_exception, custom_llm_provider):
                     llm_provider="together_ai",
                     model=model
                 )
+        elif custom_llm_provider == "vllm":
+            if hasattr(original_exception, "status_code"):
+                if original_exception.status_code == 0:
+                    raise APIConnectionError(
+                        message=f"VLLMException - {original_exception.message}",
+                        llm_provider="vllm",
+                        model=model
+                    )
         else:
             raise original_exception
     except Exception as e:
@@ -2080,6 +2098,9 @@ class CustomStreamWrapper:
             elif self.custom_llm_provider and self.custom_llm_provider == "ai21": #ai21 doesn't provide streaming
                 chunk = next(self.completion_stream)
                 completion_obj["content"] = self.handle_ai21_chunk(chunk)
+            elif self.custom_llm_provider and self.custom_llm_provider == "vllm":
+                chunk = next(self.completion_stream)
+                completion_obj["content"] = chunk[0].outputs[0].text
             elif self.model in litellm.aleph_alpha_models: #ai21 doesn't provide streaming
                 chunk = next(self.completion_stream)
                 completion_obj["content"] = self.handle_aleph_alpha_chunk(chunk)
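
To make the reworked template shape concrete, here is an illustrative sketch of the expansion that `custom_prompt()` in the prompt factory is presumably expected to perform for a registered template: prepend `initial_prompt_value`, wrap each message in its role's `pre_message`/`post_message` strings, then append `final_prompt_value`. The helper name `expand_prompt` is hypothetical and not part of litellm:

```python
# Hypothetical re-implementation of the assembly described in the new
# register_prompt_template() docstring; not the factory's actual code.
def expand_prompt(role_dict, messages, initial_prompt_value="", final_prompt_value=""):
    prompt = initial_prompt_value
    for message in messages:
        role_markers = role_dict.get(message["role"], {})
        prompt += role_markers.get("pre_message", "")   # e.g. "[INST] "
        prompt += message["content"]
        prompt += role_markers.get("post_message", "")  # e.g. " [/INST]\n"
    return prompt + final_prompt_value

llama2_roles = {
    "system": {"pre_message": "[INST] <<SYS>>\n", "post_message": "\n<</SYS>>\n [/INST]\n"},
    "user": {"pre_message": "[INST] ", "post_message": " [/INST]\n"},
    "assistant": {"post_message": "\n"},
}
print(expand_prompt(
    llama2_roles,
    [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello!"},
    ],
))
```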

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "0.1.545"
+version = "0.1.548"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT License"